Skip to content

Commit 8f55ee3

Browse files
committed
Add Unalign::update method
Adds the `Unalign::update` method, which allows updating an `Unalign` in-place via a callback. This works by temporarily moving the `Unalign` into the local stack frame in order to call the callback. Closes #206
1 parent b5b30d0 commit 8f55ee3

File tree

1 file changed

+82
-2
lines changed

1 file changed

+82
-2
lines changed

src/lib.rs

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,10 +1163,13 @@ mod simd {
11631163
// NOTE: This type is sound to use with types that need to be dropped. The
11641164
// reason is that the compiler-generated drop code automatically moves all
11651165
// values to aligned memory slots before dropping them in-place. This is not
1166-
// well-documented, but it's hinted at in places like [1] and [2].
1166+
// well-documented, but it's hinted at in places like [1] and [2]. However, this
1167+
// also means that `T` must be `Sized`; unless something changes, we can never
1168+
// support unsized `T`. [3]
11671169
//
11681170
// [1] https://github.com/rust-lang/rust/issues/54148#issuecomment-420529646
11691171
// [2] https://github.com/google/zerocopy/pull/126#discussion_r1018512323
1172+
// [3] https://github.com/google/zerocopy/issues/209
11701173
#[allow(missing_debug_implementations)]
11711174
#[derive(FromZeroes, FromBytes, Unaligned, Default, Copy)]
11721175
#[repr(C, packed)]
@@ -1329,6 +1332,61 @@ impl<T> Unalign<T> {
13291332
pub fn set(&mut self, t: T) {
13301333
*self = Unalign::new(t);
13311334
}
1335+
1336+
/// Updates the inner `T` by calling a function on it.
1337+
///
1338+
/// For large types, this method may be expensive, as it requires copying `2
1339+
/// * size_of::<T>()` bytes. \[1\]
1340+
///
1341+
/// \[1\] Since the inner `T` may not be aligned, it would not be sound to
1342+
/// invoke `f` on it directly. Instead, `update` moves it into a
1343+
/// properly-aligned location in the local stack frame, calls `f` on it, and
1344+
/// then moves it back to its original location in `self.
1345+
pub fn update<O, F: FnOnce(&mut T) -> O>(&mut self, f: F) -> O {
1346+
// On drop, this moves `copy` out of itself and uses `ptr::write` to
1347+
// overwrite `slf`.
1348+
struct WriteBackOnDrop<T> {
1349+
copy: ManuallyDrop<T>,
1350+
slf: *mut Unalign<T>,
1351+
}
1352+
1353+
impl<T> Drop for WriteBackOnDrop<T> {
1354+
fn drop(&mut self) {
1355+
// SAFETY: See inline comments.
1356+
unsafe {
1357+
// SAFETY: We never use `copy` again as required by
1358+
// `ManuallyDrop::take`.
1359+
let copy = ManuallyDrop::take(&mut self.copy);
1360+
// SAFETY: `slf` is the raw pointer value of `self`. We know
1361+
// it is valid for writes and properly aligned because
1362+
// `self` is a mutable reference, which guarantees both of
1363+
// these properties.
1364+
ptr::write(self.slf, Unalign::new(copy));
1365+
}
1366+
}
1367+
}
1368+
1369+
// SAFETY: We know that `self` is valid for reads, properly aligned, and
1370+
// points to an initialized `Unalign<T>` because it is a mutable
1371+
// reference, which guarantees all of these properties.
1372+
//
1373+
// Since `T: !Copy`, it would be unsound in the general case to allow
1374+
// both the original `Unalign<T>` and the copy to be used by safe code.
1375+
// We guarantee that the copy is used to overwrite the original in the
1376+
// `Drop::drop` impl of `WriteBackOnDrop`. So long as this `drop` is
1377+
// called before any other safe code executes, soundness is upheld.
1378+
// While this method can terminate in two ways (by returning normally or
1379+
// by unwinding due to a panic in `f`), in both cases, `write_back` is
1380+
// dropped - and its `drop` called - before any other safe code can
1381+
// execute.
1382+
let copy = unsafe { ptr::read(self) }.into_inner();
1383+
let mut write_back = WriteBackOnDrop { copy: ManuallyDrop::new(copy), slf: self };
1384+
1385+
let ret = f(&mut write_back.copy);
1386+
1387+
drop(write_back);
1388+
ret
1389+
}
13321390
}
13331391

13341392
impl<T: Copy> Unalign<T> {
@@ -2949,7 +3007,7 @@ pub use alloc_support::*;
29493007
mod tests {
29503008
#![allow(clippy::unreadable_literal)]
29513009

2952-
use core::ops::Deref;
3010+
use core::{ops::Deref, panic::AssertUnwindSafe};
29533011

29543012
use static_assertions::assert_impl_all;
29553013

@@ -3129,6 +3187,28 @@ mod tests {
31293187
};
31303188
}
31313189

3190+
#[test]
3191+
fn test_unalign_update() {
3192+
let mut u = Unalign::new(AU64(123));
3193+
u.update(|a| a.0 += 1);
3194+
assert_eq!(u.get(), AU64(124));
3195+
3196+
// Test that, even if the callback panics, the original is still
3197+
// correctly overwritten. Use a `Box` so that Miri is more likely to
3198+
// catch any unsoundness (which would likely result in two `Box`es for
3199+
// the same heap object, which is the sort of thing that Miri would
3200+
// probably catch).
3201+
let mut u = Unalign::new(Box::new(AU64(123)));
3202+
let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
3203+
u.update(|a| {
3204+
a.0 += 1;
3205+
panic!();
3206+
})
3207+
}));
3208+
assert!(res.is_err());
3209+
assert_eq!(u.into_inner(), Box::new(AU64(124)));
3210+
}
3211+
31323212
#[test]
31333213
fn test_read_write() {
31343214
const VAL: u64 = 0x12345678;

0 commit comments

Comments
 (0)