diff --git a/src/lib.rs b/src/lib.rs index 71be0b3..3bf005a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -266,3 +266,33 @@ impl core::fmt::Display for TryReserveError { #[cfg(feature = "std")] #[cfg_attr(docsrs, doc(cfg(feature = "std")))] impl std::error::Error for TryReserveError {} + +// NOTE: This is copied from the slice module in the std lib. +/// The error type returned by [`get_disjoint_indices_mut`][`RingMap::get_disjoint_indices_mut`]. +/// +/// It indicates one of two possible errors: +/// - An index is out-of-bounds. +/// - The same index appeared multiple times in the array. +// (or different but overlapping indices when ranges are provided) +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum GetDisjointMutError { + /// An index provided was out-of-bounds for the slice. + IndexOutOfBounds, + /// Two indices provided were overlapping. + OverlappingIndices, +} + +impl core::fmt::Display for GetDisjointMutError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let msg = match self { + GetDisjointMutError::IndexOutOfBounds => "an index is out of bounds", + GetDisjointMutError::OverlappingIndices => "there were overlapping indices", + }; + + core::fmt::Display::fmt(msg, f) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl std::error::Error for GetDisjointMutError {} diff --git a/src/map.rs b/src/map.rs index 101f2ae..651efb9 100644 --- a/src/map.rs +++ b/src/map.rs @@ -41,7 +41,7 @@ use std::collections::hash_map::RandomState; use self::core::RingMapCore; use crate::util::third; -use crate::{Bucket, Entries, Equivalent, HashValue, TryReserveError}; +use crate::{Bucket, Entries, Equivalent, GetDisjointMutError, HashValue, TryReserveError}; /// A hash table where the iteration order of the key-value pairs is independent /// of the hash values of the keys. @@ -825,6 +825,33 @@ where } } + /// Return the values for `N` keys. If any key is duplicated, this function will panic. + /// + /// # Examples + /// + /// ``` + /// let mut map = ringmap::RingMap::from([(1, 'a'), (3, 'b'), (2, 'c')]); + /// assert_eq!(map.get_disjoint_mut([&2, &1]), [Some(&mut 'c'), Some(&mut 'a')]); + /// ``` + pub fn get_disjoint_mut(&mut self, keys: [&Q; N]) -> [Option<&mut V>; N] + where + Q: ?Sized + Hash + Equivalent, + { + let indices = keys.map(|key| self.get_index_of(key)); + let (head, tail) = self.as_mut_slices(); + match Slice::get_disjoint_opt_mut(head, tail, indices) { + Err(GetDisjointMutError::IndexOutOfBounds) => { + unreachable!( + "Internal error: indices should never be OOB as we got them from get_index_of" + ); + } + Err(GetDisjointMutError::OverlappingIndices) => { + panic!("duplicate keys found"); + } + Ok(key_values) => key_values.map(|kv_opt| kv_opt.map(|kv| kv.1)), + } + } + /// Remove the key-value pair equivalent to `key` and return its value. /// /// Like [`VecDeque::remove`], the pair is removed by shifting all of the @@ -1286,6 +1313,26 @@ impl RingMap { Some(IndexedEntry::new(&mut self.core, index)) } + /// Get an array of `N` key-value pairs by `N` indices + /// + /// Valid indices are *0 <= index < self.len()* and each index needs to be unique. + /// + /// # Examples + /// + /// ``` + /// let mut map = ringmap::RingMap::from([(1, 'a'), (3, 'b'), (2, 'c')]); + /// assert_eq!(map.get_disjoint_indices_mut([2, 0]), Ok([(&2, &mut 'c'), (&1, &mut 'a')])); + /// ``` + pub fn get_disjoint_indices_mut( + &mut self, + indices: [usize; N], + ) -> Result<[(&K, &mut V); N], GetDisjointMutError> { + let indices = indices.map(Some); + let (head, tail) = self.as_mut_slices(); + let key_values = Slice::get_disjoint_opt_mut(head, tail, indices)?; + Ok(key_values.map(Option::unwrap)) + } + /// Get the first key-value pair /// /// Computes in **O(1)** time. diff --git a/src/map/slice.rs b/src/map/slice.rs index 488c559..1654ce2 100644 --- a/src/map/slice.rs +++ b/src/map/slice.rs @@ -1,5 +1,6 @@ use super::{Bucket, IntoIter, IntoKeys, IntoValues, Iter, IterMut, Keys, Values, ValuesMut}; use crate::util::{slice_eq, try_simplify_range}; +use crate::GetDisjointMutError; use alloc::boxed::Box; use alloc::collections::VecDeque; @@ -264,6 +265,61 @@ impl Slice { self.entries .partition_point(move |a| pred(&a.key, &a.value)) } + + /// Get an array of `N` key-value pairs by `N` indices + /// + /// Valid indices are *0 <= index < self.len()* and each index needs to be unique. + pub fn get_disjoint_mut( + &mut self, + indices: [usize; N], + ) -> Result<[(&K, &mut V); N], GetDisjointMutError> { + let indices = indices.map(Some); + let empty_tail = Self::new_mut(); + let key_values = Self::get_disjoint_opt_mut(self, empty_tail, indices)?; + Ok(key_values.map(Option::unwrap)) + } + + #[allow(unsafe_code)] + pub(crate) fn get_disjoint_opt_mut<'a, const N: usize>( + head: &mut Self, + tail: &mut Self, + indices: [Option; N], + ) -> Result<[Option<(&'a K, &'a mut V)>; N], GetDisjointMutError> { + let mid = head.len(); + let len = mid + tail.len(); + + // SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data. + for i in 0..N { + if let Some(idx) = indices[i] { + if idx >= len { + return Err(GetDisjointMutError::IndexOutOfBounds); + } else if indices[..i].contains(&Some(idx)) { + return Err(GetDisjointMutError::OverlappingIndices); + } + } + } + + let head_ptr = head.entries.as_mut_ptr(); + let tail_ptr = tail.entries.as_mut_ptr(); + let out = indices.map(|idx_opt| { + match idx_opt { + Some(idx) => { + // SAFETY: The base pointers are valid as they come from slices and the reference is always + // in-bounds & unique as we've already checked the indices above. + unsafe { + let ptr = match idx.checked_sub(mid) { + None => head_ptr.add(idx), + Some(tidx) => tail_ptr.add(tidx), + }; + Some((*ptr).ref_mut()) + } + } + None => None, + } + }); + + Ok(out) + } } impl<'a, K, V> IntoIterator for &'a Slice { diff --git a/src/map/tests.rs b/src/map/tests.rs index b22f995..352bedc 100644 --- a/src/map/tests.rs +++ b/src/map/tests.rs @@ -923,3 +923,228 @@ move_index_oob!(test_move_index_out_of_bounds_0_10, 0, 10); move_index_oob!(test_move_index_out_of_bounds_0_max, 0, usize::MAX); move_index_oob!(test_move_index_out_of_bounds_10_0, 10, 0); move_index_oob!(test_move_index_out_of_bounds_max_0, usize::MAX, 0); + +#[test] +fn disjoint_mut_empty_map() { + let mut map: RingMap = RingMap::default(); + assert_eq!( + map.get_disjoint_mut([&0, &1, &2, &3]), + [None, None, None, None] + ); +} + +#[test] +fn disjoint_mut_empty_param() { + let mut map: RingMap = RingMap::default(); + map.insert(1, 10); + assert_eq!(map.get_disjoint_mut([] as [&u32; 0]), []); +} + +#[test] +fn disjoint_mut_single_fail() { + let mut map: RingMap = RingMap::default(); + map.insert(1, 10); + assert_eq!(map.get_disjoint_mut([&0]), [None]); +} + +#[test] +fn disjoint_mut_single_success() { + let mut map: RingMap = RingMap::default(); + map.insert(1, 10); + assert_eq!(map.get_disjoint_mut([&1]), [Some(&mut 10)]); +} + +#[test] +fn disjoint_mut_multi_success() { + let mut map: RingMap = RingMap::default(); + map.insert(1, 100); + map.insert(2, 200); + map.insert(3, 300); + map.insert(4, 400); + assert_eq!( + map.get_disjoint_mut([&1, &2]), + [Some(&mut 100), Some(&mut 200)] + ); + assert_eq!( + map.get_disjoint_mut([&1, &3]), + [Some(&mut 100), Some(&mut 300)] + ); + assert_eq!( + map.get_disjoint_mut([&3, &1, &4, &2]), + [ + Some(&mut 300), + Some(&mut 100), + Some(&mut 400), + Some(&mut 200) + ] + ); +} + +#[test] +fn disjoint_mut_multi_success_double_ended() { + let mut map: RingMap = RingMap::default(); + map.push_back(2, 200); + map.push_back(3, 300); + map.push_back(4, 400); + map.push_front(1, 100); + { + let (head, tail) = map.as_mut_slices(); + assert!(!head.is_empty() && !tail.is_empty()); + } + assert_eq!( + map.get_disjoint_mut([&1, &2]), + [Some(&mut 100), Some(&mut 200)] + ); + assert_eq!( + map.get_disjoint_mut([&1, &3]), + [Some(&mut 100), Some(&mut 300)] + ); + assert_eq!( + map.get_disjoint_mut([&3, &1, &4, &2]), + [ + Some(&mut 300), + Some(&mut 100), + Some(&mut 400), + Some(&mut 200) + ] + ); +} + +#[test] +fn disjoint_mut_multi_success_unsized_key() { + let mut map: RingMap<&'static str, u32> = RingMap::default(); + map.insert("1", 100); + map.insert("2", 200); + map.insert("3", 300); + map.insert("4", 400); + + assert_eq!( + map.get_disjoint_mut(["1", "2"]), + [Some(&mut 100), Some(&mut 200)] + ); + assert_eq!( + map.get_disjoint_mut(["1", "3"]), + [Some(&mut 100), Some(&mut 300)] + ); + assert_eq!( + map.get_disjoint_mut(["3", "1", "4", "2"]), + [ + Some(&mut 300), + Some(&mut 100), + Some(&mut 400), + Some(&mut 200) + ] + ); +} + +#[test] +fn disjoint_mut_multi_success_borrow_key() { + let mut map: RingMap = RingMap::default(); + map.insert("1".into(), 100); + map.insert("2".into(), 200); + map.insert("3".into(), 300); + map.insert("4".into(), 400); + + assert_eq!( + map.get_disjoint_mut(["1", "2"]), + [Some(&mut 100), Some(&mut 200)] + ); + assert_eq!( + map.get_disjoint_mut(["1", "3"]), + [Some(&mut 100), Some(&mut 300)] + ); + assert_eq!( + map.get_disjoint_mut(["3", "1", "4", "2"]), + [ + Some(&mut 300), + Some(&mut 100), + Some(&mut 400), + Some(&mut 200) + ] + ); +} + +#[test] +fn disjoint_mut_multi_fail_missing() { + let mut map: RingMap = RingMap::default(); + map.insert(1, 100); + map.insert(2, 200); + map.insert(3, 300); + map.insert(4, 400); + + assert_eq!(map.get_disjoint_mut([&1, &5]), [Some(&mut 100), None]); + assert_eq!(map.get_disjoint_mut([&5, &6]), [None, None]); + assert_eq!( + map.get_disjoint_mut([&1, &5, &4]), + [Some(&mut 100), None, Some(&mut 400)] + ); +} + +#[test] +#[should_panic] +fn disjoint_mut_multi_fail_duplicate_panic() { + let mut map: RingMap = RingMap::default(); + map.insert(1, 100); + map.get_disjoint_mut([&1, &2, &1]); +} + +#[test] +fn disjoint_indices_mut_fail_oob() { + let mut map: RingMap = RingMap::default(); + map.insert(1, 10); + map.insert(321, 20); + assert_eq!( + map.get_disjoint_indices_mut([1, 3]), + Err(crate::GetDisjointMutError::IndexOutOfBounds) + ); +} + +#[test] +fn disjoint_indices_mut_empty() { + let mut map: RingMap = RingMap::default(); + map.insert(1, 10); + map.insert(321, 20); + assert_eq!(map.get_disjoint_indices_mut([]), Ok([])); +} + +#[test] +fn disjoint_indices_mut_success() { + let mut map: RingMap = RingMap::default(); + map.insert(1, 10); + map.insert(321, 20); + assert_eq!(map.get_disjoint_indices_mut([0]), Ok([(&1, &mut 10)])); + + assert_eq!(map.get_disjoint_indices_mut([1]), Ok([(&321, &mut 20)])); + assert_eq!( + map.get_disjoint_indices_mut([0, 1]), + Ok([(&1, &mut 10), (&321, &mut 20)]) + ); +} + +#[test] +fn disjoint_indices_mut_success_double_ended() { + let mut map: RingMap = RingMap::default(); + map.push_back(1, 10); + map.push_back(321, 20); + map.push_front(42, 0o42); + { + let (head, tail) = map.as_mut_slices(); + assert!(!head.is_empty() && !tail.is_empty()); + } + assert_eq!(map.get_disjoint_indices_mut([1]), Ok([(&1, &mut 10)])); + assert_eq!( + map.get_disjoint_indices_mut([0, 2]), + Ok([(&42, &mut 0o42), (&321, &mut 20)]) + ); +} + +#[test] +fn disjoint_indices_mut_fail_duplicate() { + let mut map: RingMap = RingMap::default(); + map.insert(1, 10); + map.insert(321, 20); + assert_eq!( + map.get_disjoint_indices_mut([1, 0, 1]), + Err(crate::GetDisjointMutError::OverlappingIndices) + ); +}