Skip to content

Implement get_disjoint_mut #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,33 @@ impl core::fmt::Display for TryReserveError {
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl std::error::Error for TryReserveError {}

// NOTE: This is copied from the slice module in the std lib.
/// The error type returned by [`get_disjoint_indices_mut`][`RingMap::get_disjoint_indices_mut`].
///
/// It indicates one of two possible errors:
/// - An index is out-of-bounds.
/// - The same index appeared multiple times in the array.
// (or different but overlapping indices when ranges are provided)
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GetDisjointMutError {
/// An index provided was out-of-bounds for the slice.
IndexOutOfBounds,
/// Two indices provided were overlapping.
OverlappingIndices,
}

impl core::fmt::Display for GetDisjointMutError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let msg = match self {
GetDisjointMutError::IndexOutOfBounds => "an index is out of bounds",
GetDisjointMutError::OverlappingIndices => "there were overlapping indices",
};

core::fmt::Display::fmt(msg, f)
}
}

#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl std::error::Error for GetDisjointMutError {}
49 changes: 48 additions & 1 deletion src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use std::collections::hash_map::RandomState;

use self::core::RingMapCore;
use crate::util::third;
use crate::{Bucket, Entries, Equivalent, HashValue, TryReserveError};
use crate::{Bucket, Entries, Equivalent, GetDisjointMutError, HashValue, TryReserveError};

/// A hash table where the iteration order of the key-value pairs is independent
/// of the hash values of the keys.
Expand Down Expand Up @@ -825,6 +825,33 @@ where
}
}

/// Return the values for `N` keys. If any key is duplicated, this function will panic.
///
/// # Examples
///
/// ```
/// let mut map = ringmap::RingMap::from([(1, 'a'), (3, 'b'), (2, 'c')]);
/// assert_eq!(map.get_disjoint_mut([&2, &1]), [Some(&mut 'c'), Some(&mut 'a')]);
/// ```
pub fn get_disjoint_mut<Q, const N: usize>(&mut self, keys: [&Q; N]) -> [Option<&mut V>; N]
where
Q: ?Sized + Hash + Equivalent<K>,
{
let indices = keys.map(|key| self.get_index_of(key));
let (head, tail) = self.as_mut_slices();
match Slice::get_disjoint_opt_mut(head, tail, indices) {
Err(GetDisjointMutError::IndexOutOfBounds) => {
unreachable!(
"Internal error: indices should never be OOB as we got them from get_index_of"
);
}
Err(GetDisjointMutError::OverlappingIndices) => {
panic!("duplicate keys found");
}
Ok(key_values) => key_values.map(|kv_opt| kv_opt.map(|kv| kv.1)),
}
}

/// Remove the key-value pair equivalent to `key` and return its value.
///
/// Like [`VecDeque::remove`], the pair is removed by shifting all of the
Expand Down Expand Up @@ -1286,6 +1313,26 @@ impl<K, V, S> RingMap<K, V, S> {
Some(IndexedEntry::new(&mut self.core, index))
}

/// Get an array of `N` key-value pairs by `N` indices
///
/// Valid indices are *0 <= index < self.len()* and each index needs to be unique.
///
/// # Examples
///
/// ```
/// let mut map = ringmap::RingMap::from([(1, 'a'), (3, 'b'), (2, 'c')]);
/// assert_eq!(map.get_disjoint_indices_mut([2, 0]), Ok([(&2, &mut 'c'), (&1, &mut 'a')]));
/// ```
pub fn get_disjoint_indices_mut<const N: usize>(
&mut self,
indices: [usize; N],
) -> Result<[(&K, &mut V); N], GetDisjointMutError> {
let indices = indices.map(Some);
let (head, tail) = self.as_mut_slices();
let key_values = Slice::get_disjoint_opt_mut(head, tail, indices)?;
Ok(key_values.map(Option::unwrap))
}

/// Get the first key-value pair
///
/// Computes in **O(1)** time.
Expand Down
56 changes: 56 additions & 0 deletions src/map/slice.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::{Bucket, IntoIter, IntoKeys, IntoValues, Iter, IterMut, Keys, Values, ValuesMut};
use crate::util::{slice_eq, try_simplify_range};
use crate::GetDisjointMutError;

use alloc::boxed::Box;
use alloc::collections::VecDeque;
Expand Down Expand Up @@ -264,6 +265,61 @@ impl<K, V> Slice<K, V> {
self.entries
.partition_point(move |a| pred(&a.key, &a.value))
}

/// Get an array of `N` key-value pairs by `N` indices
///
/// Valid indices are *0 <= index < self.len()* and each index needs to be unique.
pub fn get_disjoint_mut<const N: usize>(
&mut self,
indices: [usize; N],
) -> Result<[(&K, &mut V); N], GetDisjointMutError> {
let indices = indices.map(Some);
let empty_tail = Self::new_mut();
let key_values = Self::get_disjoint_opt_mut(self, empty_tail, indices)?;
Ok(key_values.map(Option::unwrap))
}

#[allow(unsafe_code)]
pub(crate) fn get_disjoint_opt_mut<'a, const N: usize>(
head: &mut Self,
tail: &mut Self,
indices: [Option<usize>; N],
) -> Result<[Option<(&'a K, &'a mut V)>; N], GetDisjointMutError> {
let mid = head.len();
let len = mid + tail.len();

// SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data.
for i in 0..N {
if let Some(idx) = indices[i] {
if idx >= len {
return Err(GetDisjointMutError::IndexOutOfBounds);
} else if indices[..i].contains(&Some(idx)) {
return Err(GetDisjointMutError::OverlappingIndices);
}
}
}

let head_ptr = head.entries.as_mut_ptr();
let tail_ptr = tail.entries.as_mut_ptr();
let out = indices.map(|idx_opt| {
match idx_opt {
Some(idx) => {
// SAFETY: The base pointers are valid as they come from slices and the reference is always
// in-bounds & unique as we've already checked the indices above.
unsafe {
let ptr = match idx.checked_sub(mid) {
None => head_ptr.add(idx),
Some(tidx) => tail_ptr.add(tidx),
};
Some((*ptr).ref_mut())
}
}
None => None,
}
});

Ok(out)
}
}

impl<'a, K, V> IntoIterator for &'a Slice<K, V> {
Expand Down
Loading