Skip to content

Commit c73b81e

Browse files
committed
Add Ref to the sync module.
1 parent d72964e commit c73b81e

File tree

3 files changed

+198
-0
lines changed

3 files changed

+198
-0
lines changed

rust/kernel/file_operations.rs

+11
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use alloc::sync::Arc;
1313
use crate::bindings;
1414
use crate::c_types;
1515
use crate::error::{Error, KernelResult};
16+
use crate::sync::{Ref, RefCounted};
1617
use crate::user_ptr::{UserSlicePtr, UserSlicePtrReader, UserSlicePtrWriter};
1718

1819
/// Wraps the kernel's `struct file`.
@@ -506,6 +507,16 @@ impl<T> PointerWrapper<T> for Box<T> {
506507
}
507508
}
508509

510+
impl<T: RefCounted> PointerWrapper<T> for Ref<T> {
511+
fn into_pointer(self) -> *const T {
512+
Ref::into_raw(self)
513+
}
514+
515+
unsafe fn from_pointer(ptr: *const T) -> Self {
516+
Ref::from_raw(ptr as _)
517+
}
518+
}
519+
509520
impl<T> PointerWrapper<T> for Arc<T> {
510521
fn into_pointer(self) -> *const T {
511522
Arc::into_raw(self)

rust/kernel/sync/arc.rs

+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
// SPDX-License-Identifier: GPL-2.0
2+
3+
//! A reference-counted pointer.
4+
//!
5+
//! This module implements a way for users to create reference-counted objects and pointers to
6+
//! them. Such a pointer automatically increments and decrements the count, and drops the
7+
//! underlying object when it reaches zero. It is also safe to use concurrently from multiple
8+
//! threads.
9+
//!
10+
//! It is different from the standard library's [`Arc`] in two ways: it does not support weak
11+
//! references, which allows it to be smaller -- a single pointer-sized integer; it allows users to
12+
//! safely increment the reference count from a single reference to the underlying object.
13+
//!
14+
//! [`Arc`]: https://doc.rust-lang.org/std/sync/struct.Arc.html
15+
16+
use crate::KernelResult;
17+
use alloc::boxed::Box;
18+
use core::{
19+
mem::ManuallyDrop,
20+
ops::Deref,
21+
ptr::NonNull,
22+
sync::atomic::{fence, AtomicUsize, Ordering},
23+
};
24+
25+
/// A reference-counted pointer to an instance of `T`.
26+
///
27+
/// The reference count is incremented when new instances of [`Ref`] are created, and decremented
28+
/// when they are dropped. When the count reaches zero, the underlying `T` is also dropped.
29+
///
30+
/// # Invariants
31+
///
32+
/// The value stored in [`RefCounted::get_count`] corresponds to the number of instances of [`Ref`]
33+
/// that point to that instance of `T`.
34+
pub struct Ref<T: RefCounted + ?Sized> {
35+
ptr: NonNull<T>,
36+
}
37+
38+
// SAFETY: It is safe to send `Ref<T>` to another thread when the underlying `T` is `Sync` because
39+
// it effectively means sharing `&T` (which is safe because `T` is `Sync`); additionally, it needs
40+
// `T` to be `Send` because any thread that has a `Ref<T>` may ultimately access `T` directly, for
41+
// example, when the reference count reaches zero and `T` is dropped.
42+
unsafe impl<T: RefCounted + ?Sized + Sync + Send> Send for Ref<T> {}
43+
44+
// SAFETY: It is safe to send `&Ref<T>` to another thread when the underlying `T` is `Sync` for
45+
// the same reason as above. `T` needs to be `Send` as well because a thread can clone a `&Ref<T>`
46+
// into a `Ref<T>`, which may lead to `T` being accessed by the same reasoning as above.
47+
unsafe impl<T: RefCounted + ?Sized + Sync + Send> Sync for Ref<T> {}
48+
49+
impl<T: RefCounted> Ref<T> {
50+
/// Constructs a new reference counted instance of `T`.
51+
pub fn try_new(contents: T) -> KernelResult<Self> {
52+
let boxed = Box::try_new(contents)?;
53+
boxed.get_count().count.store(1, Ordering::Relaxed);
54+
let ptr = NonNull::from(boxed.deref());
55+
Box::into_raw(boxed);
56+
Ok(Ref { ptr })
57+
}
58+
}
59+
60+
impl<T: RefCounted + ?Sized> Ref<T> {
61+
/// Creates a new reference-counted pointer to the given instance of `T`.
62+
///
63+
/// It works by incrementing the current reference count as part of constructing the new
64+
/// pointer.
65+
pub fn new_from(obj: &T) -> Self {
66+
let ref_count = obj.get_count();
67+
let cur = ref_count.count.fetch_add(1, Ordering::Relaxed);
68+
if cur == usize::MAX {
69+
panic!("Reference count overflowed");
70+
}
71+
Self {
72+
ptr: NonNull::from(obj),
73+
}
74+
}
75+
76+
/// Returns a mutable reference to `T` iff the reference count is one. Otherwise returns
77+
/// [`None`].
78+
pub fn get_mut(&mut self) -> Option<&mut T> {
79+
// Synchronises with the decrement in `drop`.
80+
if self.get_count().count.load(Ordering::Acquire) != 1 {
81+
return None;
82+
}
83+
// SAFETY: Since there is only one reference, we know it isn't possible for another thread
84+
// to concurrently call this.
85+
Some(unsafe { self.ptr.as_mut() })
86+
}
87+
88+
/// Determines if two reference-counted pointers point to the same underlying instance of `T`.
89+
pub fn ptr_eq(a: &Self, b: &Self) -> bool {
90+
core::ptr::eq(a.ptr.as_ptr(), b.ptr.as_ptr())
91+
}
92+
93+
/// Deconstructs a [`Ref`] object into a raw pointer.
94+
///
95+
/// It can be reconstructed once via [`Ref::from_raw`].
96+
pub fn into_raw(obj: Self) -> *const T {
97+
let no_drop = ManuallyDrop::new(obj);
98+
no_drop.ptr.as_ptr()
99+
}
100+
101+
/// Recreates a [`Ref`] instance previously deconstructed via [`Ref::into_raw`].
102+
///
103+
/// # Safety
104+
///
105+
/// `ptr` must have been returned by a previous call to [`Ref::into_raw`]. Additionally, it
106+
/// can only be called once for each previous call to [``Ref::into_raw`].
107+
pub unsafe fn from_raw(ptr: *const T) -> Self {
108+
Ref {
109+
ptr: NonNull::new(ptr as _).unwrap(),
110+
}
111+
}
112+
}
113+
114+
impl<T: RefCounted + ?Sized> Deref for Ref<T> {
115+
type Target = T;
116+
117+
fn deref(&self) -> &Self::Target {
118+
// SAFETY: By the type invariant, there is necessarily a reference to the object, so it is
119+
// safe to dereference it.
120+
unsafe { self.ptr.as_ref() }
121+
}
122+
}
123+
124+
impl<T: RefCounted + ?Sized> Clone for Ref<T> {
125+
fn clone(&self) -> Self {
126+
Self::new_from(self)
127+
}
128+
}
129+
130+
impl<T: RefCounted + ?Sized> Drop for Ref<T> {
131+
fn drop(&mut self) {
132+
{
133+
// SAFETY: By the type invariant, there is necessarily a reference to the object.
134+
let obj = unsafe { self.ptr.as_ref() };
135+
136+
// Synchronises with the acquire below or with the acquire in `get_mut`.
137+
if obj.get_count().count.fetch_sub(1, Ordering::Release) != 1 {
138+
return;
139+
}
140+
}
141+
142+
// Synchronises with the release when decrementing above. This ensures that modifications
143+
// from all previous threads/CPUs are visible to the underlying object's `drop`.
144+
fence(Ordering::Acquire);
145+
146+
// The count reached zero, we must free the memory.
147+
//
148+
// SAFETY: The pointer was initialised from the result of `Box::into_raw`.
149+
unsafe { Box::from_raw(self.ptr.as_ptr()) };
150+
}
151+
}
152+
153+
/// Trait for reference counted objects.
154+
///
155+
/// # Safety
156+
///
157+
/// Implementers of [`RefCounted`] must ensure that all of their constructors call
158+
/// [`Ref::try_new`].
159+
pub unsafe trait RefCounted {
160+
/// Returns a pointer to the object field holds the reference count.
161+
fn get_count(&self) -> &RefCount;
162+
}
163+
164+
/// Holds the reference count of an object.
165+
///
166+
/// It is meant to be embedded in objects to be reference-counted, with [`RefCounted::get_count`]
167+
/// returning a reference to it.
168+
pub struct RefCount {
169+
count: AtomicUsize,
170+
}
171+
172+
impl RefCount {
173+
/// Constructs a new instance of [`RefCount`].
174+
pub fn new() -> Self {
175+
Self {
176+
count: AtomicUsize::new(1),
177+
}
178+
}
179+
}
180+
181+
impl Default for RefCount {
182+
fn default() -> Self {
183+
Self::new()
184+
}
185+
}

rust/kernel/sync/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020
use crate::{bindings, CStr};
2121
use core::pin::Pin;
2222

23+
mod arc;
2324
mod condvar;
2425
mod guard;
2526
mod locked_by;
2627
mod mutex;
2728
mod spinlock;
2829

30+
pub use arc::{Ref, RefCount, RefCounted};
2931
pub use condvar::CondVar;
3032
pub use guard::{Guard, Lock};
3133
pub use locked_by::LockedBy;

0 commit comments

Comments
 (0)