Skip to content

Commit e50dd5c

Browse files
committed
Add Unalign::update method (#212)
Adds the `Unalign::update` method, which allows updating an `Unalign` in-place via a callback. This works by temporarily moving the `Unalign` into the local stack frame in order to call the callback. Closes #206
1 parent 82264e5 commit e50dd5c

File tree

1 file changed

+85
-3
lines changed

1 file changed

+85
-3
lines changed

src/lib.rs

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,8 +1036,9 @@ mod simd {
10361036
/// guarantees.
10371037
///
10381038
/// Since `Unalign` has no alignment requirement, the inner `T` may not be
1039-
/// properly aligned in memory. There are four ways to access the inner `T`:
1039+
/// properly aligned in memory. There are five ways to access the inner `T`:
10401040
/// - by value, using [`get`] or [`into_inner`]
1041+
/// - by reference inside of a callback, using [`update`]
10411042
/// - fallibly by reference, using [`try_deref`] or [`try_deref_mut`]; these can
10421043
/// fail if the `Unalign` does not satisfy `T`'s alignment requirement at
10431044
/// runtime
@@ -1050,17 +1051,21 @@ mod simd {
10501051
/// [or ABI]: https://github.com/google/zerocopy/issues/164
10511052
/// [`get`]: Unalign::get
10521053
/// [`into_inner`]: Unalign::into_inner
1054+
/// [`update`]: Unalign::update
10531055
/// [`try_deref`]: Unalign::try_deref
10541056
/// [`try_deref_mut`]: Unalign::try_deref_mut
10551057
/// [`deref_unchecked`]: Unalign::deref_unchecked
10561058
/// [`deref_mut_unchecked`]: Unalign::deref_mut_unchecked
10571059
// NOTE: This type is sound to use with types that need to be dropped. The
10581060
// reason is that the compiler-generated drop code automatically moves all
10591061
// values to aligned memory slots before dropping them in-place. This is not
1060-
// well-documented, but it's hinted at in places like [1] and [2].
1062+
// well-documented, but it's hinted at in places like [1] and [2]. However, this
1063+
// also means that `T` must be `Sized`; unless something changes, we can never
1064+
// support unsized `T`. [3]
10611065
//
10621066
// [1] https://github.com/rust-lang/rust/issues/54148#issuecomment-420529646
10631067
// [2] https://github.com/google/zerocopy/pull/126#discussion_r1018512323
1068+
// [3] https://github.com/google/zerocopy/issues/209
10641069
#[allow(missing_debug_implementations)]
10651070
#[derive(FromBytes, Unaligned, Default, Copy)]
10661071
#[repr(C, packed)]
@@ -1223,6 +1228,61 @@ impl<T> Unalign<T> {
12231228
pub fn set(&mut self, t: T) {
12241229
*self = Unalign::new(t);
12251230
}
1231+
1232+
/// Updates the inner `T` by calling a function on it.
1233+
///
1234+
/// For large types, this method may be expensive, as it requires copying
1235+
/// `2 * size_of::<T>()` bytes. \[1\]
1236+
///
1237+
/// \[1\] Since the inner `T` may not be aligned, it would not be sound to
1238+
/// invoke `f` on it directly. Instead, `update` moves it into a
1239+
/// properly-aligned location in the local stack frame, calls `f` on it, and
1240+
/// then moves it back to its original location in `self`.
1241+
pub fn update<O, F: FnOnce(&mut T) -> O>(&mut self, f: F) -> O {
1242+
// On drop, this moves `copy` out of itself and uses `ptr::write` to
1243+
// overwrite `slf`.
1244+
struct WriteBackOnDrop<T> {
1245+
copy: ManuallyDrop<T>,
1246+
slf: *mut Unalign<T>,
1247+
}
1248+
1249+
impl<T> Drop for WriteBackOnDrop<T> {
1250+
fn drop(&mut self) {
1251+
// SAFETY: See inline comments.
1252+
unsafe {
1253+
// SAFETY: We never use `copy` again as required by
1254+
// `ManuallyDrop::take`.
1255+
let copy = ManuallyDrop::take(&mut self.copy);
1256+
// SAFETY: `slf` is the raw pointer value of `self`. We know
1257+
// it is valid for writes and properly aligned because
1258+
// `self` is a mutable reference, which guarantees both of
1259+
// these properties.
1260+
ptr::write(self.slf, Unalign::new(copy));
1261+
}
1262+
}
1263+
}
1264+
1265+
// SAFETY: We know that `self` is valid for reads, properly aligned, and
1266+
// points to an initialized `Unalign<T>` because it is a mutable
1267+
// reference, which guarantees all of these properties.
1268+
//
1269+
// Since `T: !Copy`, it would be unsound in the general case to allow
1270+
// both the original `Unalign<T>` and the copy to be used by safe code.
1271+
// We guarantee that the copy is used to overwrite the original in the
1272+
// `Drop::drop` impl of `WriteBackOnDrop`. So long as this `drop` is
1273+
// called before any other safe code executes, soundness is upheld.
1274+
// While this method can terminate in two ways (by returning normally or
1275+
// by unwinding due to a panic in `f`), in both cases, `write_back` is
1276+
// dropped - and its `drop` called - before any other safe code can
1277+
// execute.
1278+
let copy = unsafe { ptr::read(self) }.into_inner();
1279+
let mut write_back = WriteBackOnDrop { copy: ManuallyDrop::new(copy), slf: self };
1280+
1281+
let ret = f(&mut write_back.copy);
1282+
1283+
drop(write_back);
1284+
ret
1285+
}
12261286
}
12271287

12281288
impl<T: Copy> Unalign<T> {
@@ -2874,7 +2934,7 @@ pub use alloc_support::*;
28742934
mod tests {
28752935
#![allow(clippy::unreadable_literal)]
28762936

2877-
use core::ops::Deref;
2937+
use core::{ops::Deref, panic::AssertUnwindSafe};
28782938

28792939
use static_assertions::assert_impl_all;
28802940

@@ -3025,6 +3085,28 @@ mod tests {
30253085
};
30263086
}
30273087

3088+
#[test]
3089+
fn test_unalign_update() {
3090+
let mut u = Unalign::new(AU64(123));
3091+
u.update(|a| a.0 += 1);
3092+
assert_eq!(u.get(), AU64(124));
3093+
3094+
// Test that, even if the callback panics, the original is still
3095+
// correctly overwritten. Use a `Box` so that Miri is more likely to
3096+
// catch any unsoundness (which would likely result in two `Box`es for
3097+
// the same heap object, which is the sort of thing that Miri would
3098+
// probably catch).
3099+
let mut u = Unalign::new(Box::new(AU64(123)));
3100+
let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
3101+
u.update(|a| {
3102+
a.0 += 1;
3103+
panic!();
3104+
})
3105+
}));
3106+
assert!(res.is_err());
3107+
assert_eq!(u.into_inner(), Box::new(AU64(124)));
3108+
}
3109+
30283110
#[test]
30293111
fn test_read_write() {
30303112
const VAL: u64 = 0x12345678;

0 commit comments

Comments
 (0)