Skip to content

Commit 334cc34

Browse files
wedsonafmatthewtgilbride
authored andcommitted
rust: rbtree: add mutable iterator
Add mutable Iterator implementation for `RBTree`, allowing iteration over (key, value) pairs in key order. Only values are mutable, as mutating keys implies modifying a node's position in the tree. Mutable iteration is used by the binder driver during shutdown to clean up the tree maintained by the "range allocator" [1]. Link: https://lore.kernel.org/rust-for-linux/[email protected]/ [1] Signed-off-by: Wedson Almeida Filho <[email protected]> Signed-off-by: Matt Gilbride <[email protected]> Reviewed-by: Alice Ryhl <[email protected]> Tested-by: Alice Ryhl <[email protected]>
1 parent e09ec5b commit 334cc34

File tree

1 file changed

+86
-12
lines changed

1 file changed

+86
-12
lines changed

rust/kernel/rbtree.rs

Lines changed: 86 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,26 @@ impl<K, V> RBTree<K, V> {
197197
// INVARIANT: `bindings::rb_first` returns a valid pointer to a tree node given a valid pointer to a tree root.
198198
Iter {
199199
_tree: PhantomData,
200-
// SAFETY: `self.root` is a valid pointer to the tree root.
201-
next: unsafe { bindings::rb_first(&self.root) },
200+
iter_raw: IterRaw {
201+
// SAFETY: by the invariants, all pointers are valid.
202+
next: unsafe { bindings::rb_first(&self.root) },
203+
_phantom: PhantomData,
204+
},
205+
}
206+
}
207+
208+
/// Returns a mutable iterator over the tree nodes, sorted by key.
209+
pub fn iter_mut(&mut self) -> IterMut<'_, K, V> {
210+
IterMut {
211+
_tree: PhantomData,
212+
// INVARIANT:
213+
// - `self.root` is a valid pointer to a tree root.
214+
// - `bindings::rb_first` produces a valid pointer to a node given `root` is valid.
215+
iter_raw: IterRaw {
216+
// SAFETY: by the invariants, all pointers are valid.
217+
next: unsafe { bindings::rb_first(&self.root) },
218+
_phantom: PhantomData,
219+
},
202220
}
203221
}
204222

@@ -211,6 +229,11 @@ impl<K, V> RBTree<K, V> {
211229
pub fn values(&self) -> impl Iterator<Item = &'_ V> {
212230
self.iter().map(|(_, v)| v)
213231
}
232+
233+
/// Returns a mutable iterator over the values of the nodes in the tree, sorted by key.
234+
pub fn values_mut(&mut self) -> impl Iterator<Item = &'_ mut V> {
235+
self.iter_mut().map(|(_, v)| v)
236+
}
214237
}
215238

216239
impl<K, V> RBTree<K, V>
@@ -414,13 +437,9 @@ impl<'a, K, V> IntoIterator for &'a RBTree<K, V> {
414437
/// An iterator over the nodes of a [`RBTree`].
415438
///
416439
/// Instances are created by calling [`RBTree::iter`].
417-
///
418-
/// # Invariants
419-
/// - `self.next` is a valid pointer.
420-
/// - `self.next` points to a node stored inside of a valid `RBTree`.
421440
pub struct Iter<'a, K, V> {
422441
_tree: PhantomData<&'a RBTree<K, V>>,
423-
next: *mut bindings::rb_node,
442+
iter_raw: IterRaw<K, V>,
424443
}
425444

426445
// SAFETY: The [`Iter`] gives out immutable references to K and V, so it has the same
@@ -434,21 +453,76 @@ unsafe impl<'a, K: Sync, V: Sync> Sync for Iter<'a, K, V> {}
434453
impl<'a, K, V> Iterator for Iter<'a, K, V> {
435454
type Item = (&'a K, &'a V);
436455

456+
fn next(&mut self) -> Option<Self::Item> {
457+
self.iter_raw.next().map(|(k, v)|
458+
// SAFETY: Due to `self._tree`, `k` and `v` are valid for the lifetime of `'a`.
459+
unsafe { (&*k, &*v) })
460+
}
461+
}
462+
463+
impl<'a, K, V> IntoIterator for &'a mut RBTree<K, V> {
464+
type Item = (&'a K, &'a mut V);
465+
type IntoIter = IterMut<'a, K, V>;
466+
467+
fn into_iter(self) -> Self::IntoIter {
468+
self.iter_mut()
469+
}
470+
}
471+
472+
/// A mutable iterator over the nodes of a [`RBTree`].
473+
///
474+
/// Instances are created by calling [`RBTree::iter_mut`].
475+
pub struct IterMut<'a, K, V> {
476+
_tree: PhantomData<&'a mut RBTree<K, V>>,
477+
iter_raw: IterRaw<K, V>,
478+
}
479+
480+
// SAFETY: The [`RBTreeIterator`] gives out mutable references to K and V, so it has the same
481+
// thread safety requirements as mutable references.
482+
unsafe impl<'a, K: Send, V: Send> Send for IterMut<'a, K, V> {}
483+
484+
// SAFETY: The [`RBTreeIterator`] gives out mutable references to K and V, so it has the same
485+
// thread safety requirements as mutable references.
486+
unsafe impl<'a, K: Sync, V: Sync> Sync for IterMut<'a, K, V> {}
487+
488+
impl<'a, K, V> Iterator for IterMut<'a, K, V> {
489+
type Item = (&'a K, &'a mut V);
490+
491+
fn next(&mut self) -> Option<Self::Item> {
492+
self.iter_raw.next().map(|(k, v)|
493+
// SAFETY: Due to `&mut self`, we have exclusive access to `k` and `v`, for the lifetime of `'a`.
494+
unsafe { (&*k, &mut *v) })
495+
}
496+
}
497+
498+
/// A raw iterator over the nodes of a [`RBTree`].
499+
///
500+
/// # Invariants
501+
/// - `self.next` is a valid pointer.
502+
/// - `self.next` points to a node stored inside of a valid `RBTree`.
503+
struct IterRaw<K, V> {
504+
next: *mut bindings::rb_node,
505+
_phantom: PhantomData<fn() -> (K, V)>,
506+
}
507+
508+
impl<K, V> Iterator for IterRaw<K, V> {
509+
type Item = (*mut K, *mut V);
510+
437511
fn next(&mut self) -> Option<Self::Item> {
438512
if self.next.is_null() {
439513
return None;
440514
}
441515

442-
// SAFETY: By the type invariant of `Iter`, `self.next` is a valid node in an `RBTree`,
516+
// SAFETY: By the type invariant of `IterRaw`, `self.next` is a valid node in an `RBTree`,
443517
// and by the type invariant of `RBTree`, all nodes point to the links field of `Node<K, V>` objects.
444-
let cur = unsafe { container_of!(self.next, Node<K, V>, links) };
518+
let cur: *mut Node<K, V> =
519+
unsafe { container_of!(self.next, Node<K, V>, links) }.cast_mut();
445520

446521
// SAFETY: `self.next` is a valid tree node by the type invariants.
447522
self.next = unsafe { bindings::rb_next(self.next) };
448523

449-
// SAFETY: By the same reasoning above, it is safe to dereference the node. Additionally,
450-
// it is ok to return a reference to members because the iterator must outlive it.
451-
Some(unsafe { (&(*cur).key, &(*cur).value) })
524+
// SAFETY: By the same reasoning above, it is safe to dereference the node.
525+
Some(unsafe { (addr_of_mut!((*cur).key), addr_of_mut!((*cur).value)) })
452526
}
453527
}
454528

0 commit comments

Comments
 (0)