Skip to content

Commit d3f1968

Browse files
committed
Implement get_many_mut
1 parent 9a5b1fa commit d3f1968

File tree

2 files changed

+107
-0
lines changed

2 files changed

+107
-0
lines changed

src/map.rs

+46
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,52 @@ where
477477
}
478478
}
479479

480+
pub fn get_many_mut<'a, 'b, Q: ?Sized, const N: usize>(
481+
&'a mut self,
482+
keys: [&'b Q; N],
483+
) -> Option<[&'a mut V; N]>
484+
where
485+
Q: Hash + Equivalent<K>,
486+
{
487+
let indices = keys.map(|key| self.get_index_of(key));
488+
if indices.iter().any(Option::is_none) {
489+
return None;
490+
}
491+
let indices = indices.map(Option::unwrap);
492+
493+
// SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data
494+
for i in 0..N {
495+
let idx = indices[i];
496+
if indices[i + 1..N].contains(&idx) {
497+
return None;
498+
}
499+
}
500+
501+
// Replace with MaybeUninit::uninit_array when that is stable
502+
// SAFETY: Creating MaybeUninit from uninit is always safe
503+
#[allow(unsafe_code)]
504+
let mut out: [std::mem::MaybeUninit<&'a mut V>; N] =
505+
unsafe { std::mem::MaybeUninit::uninit().assume_init() };
506+
507+
let entries = self.as_entries_mut();
508+
for (elem, idx) in out.iter_mut().zip(indices) {
509+
let v: &mut V = &mut entries[idx].value;
510+
// SAFETY: As we know that each index is unique, it is OK to discard the mutable
511+
// borrow lifetime of v, we will never mutably borrow an element twice.
512+
// The pointer is valid and aligned as we get it from MaybeUninit.
513+
#[allow(unsafe_code)]
514+
unsafe { std::ptr::write(elem.as_mut_ptr(), &mut *(v as *mut V)) };
515+
}
516+
517+
// Can't transmute a const-sized array:
518+
// https://github.com/rust-lang/rust/issues/61956
519+
// This is the workaround.
520+
// SAFETY: This is fine as the references all are from unique entries that we own and all of
521+
// them have been properly initialized by the above loop.
522+
#[allow(unsafe_code)]
523+
Some(unsafe { std::mem::transmute_copy::<_, [&'a mut V; N]>(&out) })
524+
}
525+
480526
/// Remove the key-value pair equivalent to `key` and return
481527
/// its value.
482528
///

src/map/tests.rs

+61
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,64 @@ fn from_array() {
418418

419419
assert_eq!(map, expected)
420420
}
421+
422+
#[test]
423+
fn many_mut_empty() {
424+
let mut map: IndexMap<u32, u32> = IndexMap::default();
425+
assert!(map.get_many_mut([&0, &1, &2, &3]).is_none());
426+
}
427+
428+
#[test]
429+
fn many_mut_single_fail() {
430+
let mut map: IndexMap<u32, u32> = IndexMap::default();
431+
map.insert(1, 10);
432+
assert!(map.get_many_mut([&0]).is_none());
433+
}
434+
435+
#[test]
436+
fn many_mut_single_success() {
437+
let mut map: IndexMap<u32, u32> = IndexMap::default();
438+
map.insert(1, 10);
439+
assert_eq!(map.get_many_mut([&1]), Some([&mut 10]));
440+
}
441+
442+
#[test]
443+
fn many_mut_multi_success() {
444+
let mut map: IndexMap<u32, u32> = IndexMap::default();
445+
map.insert(1, 10);
446+
map.insert(1123, 100);
447+
map.insert(321, 20);
448+
map.insert(1337, 30);
449+
assert_eq!(map.get_many_mut([&1, &1123]), Some([&mut 10, &mut 100]));
450+
assert_eq!(map.get_many_mut([&1, &1337]), Some([&mut 10, &mut 30]));
451+
assert_eq!(
452+
map.get_many_mut([&1337, &321, &1, &1123]),
453+
Some([&mut 30, &mut 20, &mut 10, &mut 100])
454+
);
455+
}
456+
457+
#[test]
458+
fn many_mut_multi_fail_missing() {
459+
let mut map: IndexMap<u32, u32> = IndexMap::default();
460+
map.insert(1, 10);
461+
map.insert(1123, 100);
462+
map.insert(321, 20);
463+
map.insert(1337, 30);
464+
assert_eq!(map.get_many_mut([&121, &1123]), None);
465+
assert_eq!(map.get_many_mut([&1, &1337, &56]), None);
466+
assert_eq!(map.get_many_mut([&1337, &123, &321, &1, &1123]), None);
467+
}
468+
469+
#[test]
470+
fn many_mut_multi_fail_duplicate() {
471+
let mut map: IndexMap<u32, u32> = IndexMap::default();
472+
map.insert(1, 10);
473+
map.insert(1123, 100);
474+
map.insert(321, 20);
475+
map.insert(1337, 30);
476+
assert_eq!(map.get_many_mut([&1, &1]), None);
477+
assert_eq!(
478+
map.get_many_mut([&1337, &123, &321, &1337, &1, &1123]),
479+
None
480+
);
481+
}

0 commit comments

Comments
 (0)