@@ -68,13 +68,16 @@ where
68
68
failure,
69
69
)
70
70
. map ( drop)
71
- . map_err ( NonNullPtr :: from_inner)
71
+ . map_err ( |value| {
72
+ // SAFETY: `value` cam from a `NonNullPtr::into_inner` call.
73
+ unsafe { NonNullPtr :: from_inner ( value) }
74
+ } )
72
75
}
73
76
74
77
#[ inline]
75
78
fn load ( & self , order : Ordering ) -> Option < NonNullPtr < N > > {
76
- InnerNonZero :: new ( self . inner . load ( order ) ) . map ( |inner| NonNullPtr {
77
- inner,
79
+ Some ( NonNullPtr {
80
+ inner : InnerNonZero :: new ( self . inner . load ( order ) ) ? ,
78
81
_marker : PhantomData ,
79
82
} )
80
83
}
@@ -115,33 +118,41 @@ where
115
118
}
116
119
117
120
#[ inline]
118
- pub fn from_static_mut_ref ( ref_ : & ' static mut N ) -> NonNullPtr < N > {
119
- let non_null = NonNull :: from ( ref_ ) ;
120
- Self :: from_non_null ( non_null )
121
+ pub fn from_static_mut_ref ( reference : & ' static mut N ) -> NonNullPtr < N > {
122
+ // SAFETY: `reference` is a static mutable reference, i.e. a valid pointer.
123
+ unsafe { Self :: new_unchecked ( initial_tag ( ) , NonNull :: from ( reference ) ) }
121
124
}
122
125
123
- fn from_non_null ( ptr : NonNull < N > ) -> Self {
124
- let address = ptr. as_ptr ( ) as Address ;
125
- let tag = initial_tag ( ) . get ( ) ;
126
-
127
- let value = ( Inner :: from ( tag) << Address :: BITS ) | Inner :: from ( address) ;
126
+ /// # Safety
127
+ ///
128
+ /// - `ptr` must be a valid pointer.
129
+ #[ inline]
130
+ unsafe fn new_unchecked ( tag : Tag , ptr : NonNull < N > ) -> Self {
131
+ let value =
132
+ ( Inner :: from ( tag. get ( ) ) << Address :: BITS ) | Inner :: from ( ptr. as_ptr ( ) as Address ) ;
128
133
129
134
Self {
135
+ // SAFETY: `value` is constructed from a `Tag` which is non-zero and half the
136
+ // size of the `InnerNonZero` type, and a `NonNull<N>` pointer.
130
137
inner : unsafe { InnerNonZero :: new_unchecked ( value) } ,
131
138
_marker : PhantomData ,
132
139
}
133
140
}
134
141
142
+ /// # Safety
143
+ ///
144
+ /// - `value` must come from a `Self::into_inner` call.
135
145
#[ inline]
136
- fn from_inner ( value : Inner ) -> Option < Self > {
137
- InnerNonZero :: new ( value ) . map ( |inner| Self {
138
- inner,
146
+ unsafe fn from_inner ( value : Inner ) -> Option < Self > {
147
+ Some ( Self {
148
+ inner : InnerNonZero :: new ( value ) ? ,
139
149
_marker : PhantomData ,
140
150
} )
141
151
}
142
152
143
153
#[ inline]
144
154
fn non_null ( & self ) -> NonNull < N > {
155
+ // SAFETY: `Self` can only be constructed using a `NonNull<N>`.
145
156
unsafe { NonNull :: new_unchecked ( self . as_ptr ( ) ) }
146
157
}
147
158
@@ -152,17 +163,15 @@ where
152
163
153
164
#[ inline]
154
165
fn tag ( & self ) -> Tag {
166
+ // SAFETY: `self.inner` was constructed from a non-zero `Tag`.
155
167
unsafe { Tag :: new_unchecked ( ( self . inner . get ( ) >> Address :: BITS ) as Address ) }
156
168
}
157
169
158
- fn increase_tag ( & mut self ) {
159
- let address = self . as_ptr ( ) as Address ;
160
-
161
- let new_tag = self . tag ( ) . checked_add ( 1 ) . unwrap_or_else ( initial_tag) . get ( ) ;
162
-
163
- let value = ( Inner :: from ( new_tag) << Address :: BITS ) | Inner :: from ( address) ;
170
+ fn increment_tag ( & mut self ) {
171
+ let new_tag = self . tag ( ) . checked_add ( 1 ) . unwrap_or_else ( initial_tag) ;
164
172
165
- self . inner = unsafe { InnerNonZero :: new_unchecked ( value) } ;
173
+ // SAFETY: `self.non_null()` is a valid pointer.
174
+ * self = unsafe { Self :: new_unchecked ( new_tag, self . non_null ( ) ) } ;
166
175
}
167
176
}
168
177
@@ -210,7 +219,40 @@ where
210
219
. compare_and_exchange_weak ( Some ( top) , next, Ordering :: Release , Ordering :: Relaxed )
211
220
. is_ok ( )
212
221
{
213
- top. increase_tag ( ) ;
222
+ // Prevent the ABA problem (https://en.wikipedia.org/wiki/Treiber_stack#Correctness).
223
+ //
224
+ // Without this, the following would be possible:
225
+ //
226
+ // | Thread 1 | Thread 2 | Stack |
227
+ // |-------------------------------|-------------------------|------------------------------|
228
+ // | push((1, 1)) | | (1, 1) |
229
+ // | push((1, 2)) | | (1, 2) -> (1, 1) |
230
+ // | p = try_pop()::load // (1, 2) | | (1, 2) -> (1, 1) |
231
+ // | | p = try_pop() // (1, 2) | (1, 1) |
232
+ // | | push((1, 3)) | (1, 3) -> (1, 1) |
233
+ // | | push(p) | (1, 2) -> (1, 3) -> (1, 1) |
234
+ // | try_pop()::cas(p, p.next) | | (1, 1) |
235
+ //
236
+ // As can be seen, the `cas` operation succeeds, wrongly removing pointer `3` from the stack.
237
+ //
238
+ // By incrementing the tag before returning the pointer, it cannot be pushed again with the,
239
+ // same tag, preventing the `try_pop()::cas(p, p.next)` operation from succeeding.
240
+ //
241
+ // With this fix, `try_pop()` in thread 2 returns `(2, 2)` and the comparison between
242
+ // `(1, 2)` and `(2, 2)` fails, restarting the loop and correctly removing the new top:
243
+ //
244
+ // | Thread 1 | Thread 2 | Stack |
245
+ // |-------------------------------|-------------------------|------------------------------|
246
+ // | push((1, 1)) | | (1, 1) |
247
+ // | push((1, 2)) | | (1, 2) -> (1, 1) |
248
+ // | p = try_pop()::load // (1, 2) | | (1, 2) -> (1, 1) |
249
+ // | | p = try_pop() // (2, 2) | (1, 1) |
250
+ // | | push((1, 3)) | (1, 3) -> (1, 1) |
251
+ // | | push(p) | (2, 2) -> (1, 3) -> (1, 1) |
252
+ // | try_pop()::cas(p, p.next) | | (2, 2) -> (1, 3) -> (1, 1) |
253
+ // | p = try_pop()::load // (2, 2) | | (2, 2) -> (1, 3) -> (1, 1) |
254
+ // | try_pop()::cas(p, p.next) | | (1, 3) -> (1, 1) |
255
+ top. increment_tag ( ) ;
214
256
215
257
return Some ( top) ;
216
258
}
0 commit comments