@@ -1036,8 +1036,9 @@ mod simd {
1036
1036
/// guarantees.
1037
1037
///
1038
1038
/// Since `Unalign` has no alignment requirement, the inner `T` may not be
1039
- /// properly aligned in memory. There are four ways to access the inner `T`:
1039
+ /// properly aligned in memory. There are five ways to access the inner `T`:
1040
1040
/// - by value, using [`get`] or [`into_inner`]
1041
+ /// - by reference inside of a callback, using [`update`]
1041
1042
/// - fallibly by reference, using [`try_deref`] or [`try_deref_mut`]; these can
1042
1043
/// fail if the `Unalign` does not satisfy `T`'s alignment requirement at
1043
1044
/// runtime
@@ -1050,17 +1051,21 @@ mod simd {
1050
1051
/// [or ABI]: https://github.com/google/zerocopy/issues/164
1051
1052
/// [`get`]: Unalign::get
1052
1053
/// [`into_inner`]: Unalign::into_inner
1054
+ /// [`update`]: Unalign::update
1053
1055
/// [`try_deref`]: Unalign::try_deref
1054
1056
/// [`try_deref_mut`]: Unalign::try_deref_mut
1055
1057
/// [`deref_unchecked`]: Unalign::deref_unchecked
1056
1058
/// [`deref_mut_unchecked`]: Unalign::deref_mut_unchecked
1057
1059
// NOTE: This type is sound to use with types that need to be dropped. The
1058
1060
// reason is that the compiler-generated drop code automatically moves all
1059
1061
// values to aligned memory slots before dropping them in-place. This is not
1060
- // well-documented, but it's hinted at in places like [1] and [2].
1062
+ // well-documented, but it's hinted at in places like [1] and [2]. However, this
1063
+ // also means that `T` must be `Sized`; unless something changes, we can never
1064
+ // support unsized `T`. [3]
1061
1065
//
1062
1066
// [1] https://github.com/rust-lang/rust/issues/54148#issuecomment-420529646
1063
1067
// [2] https://github.com/google/zerocopy/pull/126#discussion_r1018512323
1068
+ // [3] https://github.com/google/zerocopy/issues/209
1064
1069
#[ allow( missing_debug_implementations) ]
1065
1070
#[ derive( FromBytes , Unaligned , Default , Copy ) ]
1066
1071
#[ repr( C , packed) ]
@@ -1223,6 +1228,61 @@ impl<T> Unalign<T> {
1223
1228
pub fn set ( & mut self , t : T ) {
1224
1229
* self = Unalign :: new ( t) ;
1225
1230
}
1231
+
1232
+ /// Updates the inner `T` by calling a function on it.
1233
+ ///
1234
+ /// For large types, this method may be expensive, as it requires copying
1235
+ /// `2 * size_of::<T>()` bytes. \[1\]
1236
+ ///
1237
+ /// \[1\] Since the inner `T` may not be aligned, it would not be sound to
1238
+ /// invoke `f` on it directly. Instead, `update` moves it into a
1239
+ /// properly-aligned location in the local stack frame, calls `f` on it, and
1240
+ /// then moves it back to its original location in `self`.
1241
+ pub fn update < O , F : FnOnce ( & mut T ) -> O > ( & mut self , f : F ) -> O {
1242
+ // On drop, this moves `copy` out of itself and uses `ptr::write` to
1243
+ // overwrite `slf`.
1244
+ struct WriteBackOnDrop < T > {
1245
+ copy : ManuallyDrop < T > ,
1246
+ slf : * mut Unalign < T > ,
1247
+ }
1248
+
1249
+ impl < T > Drop for WriteBackOnDrop < T > {
1250
+ fn drop ( & mut self ) {
1251
+ // SAFETY: See inline comments.
1252
+ unsafe {
1253
+ // SAFETY: We never use `copy` again as required by
1254
+ // `ManuallyDrop::take`.
1255
+ let copy = ManuallyDrop :: take ( & mut self . copy ) ;
1256
+ // SAFETY: `slf` is the raw pointer value of `self`. We know
1257
+ // it is valid for writes and properly aligned because
1258
+ // `self` is a mutable reference, which guarantees both of
1259
+ // these properties.
1260
+ ptr:: write ( self . slf , Unalign :: new ( copy) ) ;
1261
+ }
1262
+ }
1263
+ }
1264
+
1265
+ // SAFETY: We know that `self` is valid for reads, properly aligned, and
1266
+ // points to an initialized `Unalign<T>` because it is a mutable
1267
+ // reference, which guarantees all of these properties.
1268
+ //
1269
+ // Since `T: !Copy`, it would be unsound in the general case to allow
1270
+ // both the original `Unalign<T>` and the copy to be used by safe code.
1271
+ // We guarantee that the copy is used to overwrite the original in the
1272
+ // `Drop::drop` impl of `WriteBackOnDrop`. So long as this `drop` is
1273
+ // called before any other safe code executes, soundness is upheld.
1274
+ // While this method can terminate in two ways (by returning normally or
1275
+ // by unwinding due to a panic in `f`), in both cases, `write_back` is
1276
+ // dropped - and its `drop` called - before any other safe code can
1277
+ // execute.
1278
+ let copy = unsafe { ptr:: read ( self ) } . into_inner ( ) ;
1279
+ let mut write_back = WriteBackOnDrop { copy : ManuallyDrop :: new ( copy) , slf : self } ;
1280
+
1281
+ let ret = f ( & mut write_back. copy ) ;
1282
+
1283
+ drop ( write_back) ;
1284
+ ret
1285
+ }
1226
1286
}
1227
1287
1228
1288
impl < T : Copy > Unalign < T > {
@@ -2874,7 +2934,7 @@ pub use alloc_support::*;
2874
2934
mod tests {
2875
2935
#![ allow( clippy:: unreadable_literal) ]
2876
2936
2877
- use core:: ops:: Deref ;
2937
+ use core:: { ops:: Deref , panic :: AssertUnwindSafe } ;
2878
2938
2879
2939
use static_assertions:: assert_impl_all;
2880
2940
@@ -3025,6 +3085,28 @@ mod tests {
3025
3085
} ;
3026
3086
}
3027
3087
3088
+ #[ test]
3089
+ fn test_unalign_update ( ) {
3090
+ let mut u = Unalign :: new ( AU64 ( 123 ) ) ;
3091
+ u. update ( |a| a. 0 += 1 ) ;
3092
+ assert_eq ! ( u. get( ) , AU64 ( 124 ) ) ;
3093
+
3094
+ // Test that, even if the callback panics, the original is still
3095
+ // correctly overwritten. Use a `Box` so that Miri is more likely to
3096
+ // catch any unsoundness (which would likely result in two `Box`es for
3097
+ // the same heap object, which is the sort of thing that Miri would
3098
+ // probably catch).
3099
+ let mut u = Unalign :: new ( Box :: new ( AU64 ( 123 ) ) ) ;
3100
+ let res = std:: panic:: catch_unwind ( AssertUnwindSafe ( || {
3101
+ u. update ( |a| {
3102
+ a. 0 += 1 ;
3103
+ panic ! ( ) ;
3104
+ } )
3105
+ } ) ) ;
3106
+ assert ! ( res. is_err( ) ) ;
3107
+ assert_eq ! ( u. into_inner( ) , Box :: new( AU64 ( 124 ) ) ) ;
3108
+ }
3109
+
3028
3110
#[ test]
3029
3111
fn test_read_write ( ) {
3030
3112
const VAL : u64 = 0x12345678 ;
0 commit comments