@@ -130,9 +130,81 @@ impl Drop for SessionWrapper<'_> {
130
130
}
131
131
}
132
132
133
+ // mbedtls_ssl_write() has some weird semantics w.r.t non-blocking I/O:
134
+ //
135
+ // > When this function returns MBEDTLS_ERR_SSL_WANT_WRITE/READ, it must be
136
+ // > called later **with the same arguments**, until it returns a value greater
137
+ // > than or equal to 0. When the function returns MBEDTLS_ERR_SSL_WANT_WRITE
138
+ // > there may be some partial data in the output buffer, however this is not
139
+ // > yet sent.
140
+ //
141
+ // WriteTracker is used to ensure we pass the same data in that scenario.
142
+ //
143
+ // Reference:
144
+ // https://tls.mbed.org/api/ssl_8h.html#a5bbda87d484de82df730758b475f32e5
145
+ struct WriteTracker {
146
+ pending : Option < Box < DigestAndLen > > ,
147
+ }
148
+
149
+ struct DigestAndLen {
150
+ digest : [ u8 ; 20 ] , // SHA-1
151
+ len : usize ,
152
+ }
153
+
154
+ impl WriteTracker {
155
+ fn new ( ) -> Self {
156
+ WriteTracker {
157
+ pending : None ,
158
+ }
159
+ }
160
+
161
+ fn digest ( buf : & [ u8 ] ) -> [ u8 ; 20 ] {
162
+ use crate :: hash:: { Md , Type } ;
163
+ let mut out = [ 0u8 ; 20 ] ;
164
+ let res = Md :: hash ( Type :: Sha1 , buf, & mut out[ ..] ) ;
165
+ assert_eq ! ( res, Ok ( out. len( ) ) ) ;
166
+ out
167
+ }
168
+
169
+ fn adjust_buf < ' a > ( & self , buf : & ' a [ u8 ] ) -> io:: Result < & ' a [ u8 ] > {
170
+ match self . pending . as_ref ( ) {
171
+ None => Ok ( buf) ,
172
+ Some ( pending) => {
173
+ if pending. len <= buf. len ( ) {
174
+ let buf = & buf[ ..pending. len ] ;
175
+ if Self :: digest ( buf) == pending. digest {
176
+ return Ok ( buf) ;
177
+ }
178
+ }
179
+ Err ( io:: Error :: new (
180
+ io:: ErrorKind :: Other ,
181
+ "mbedtls expects the same data if the previous call to poll_write() returned Poll::Pending"
182
+ ) )
183
+ } ,
184
+ }
185
+ }
186
+
187
+ fn post_write ( & mut self , buf : & [ u8 ] , res : & Poll < io:: Result < usize > > ) {
188
+ match res {
189
+ & Poll :: Pending => {
190
+ if self . pending . is_none ( ) {
191
+ self . pending = Some ( Box :: new ( DigestAndLen {
192
+ digest : Self :: digest ( buf) ,
193
+ len : buf. len ( ) ,
194
+ } ) ) ;
195
+ }
196
+ } ,
197
+ _ => {
198
+ self . pending = None ;
199
+ }
200
+ }
201
+ }
202
+ }
203
+
133
204
pub struct AsyncSession < ' ctx > {
134
205
session : Option < SessionWrapper < ' ctx > > ,
135
206
ecx : ErasedContext ,
207
+ write_tracker : WriteTracker ,
136
208
}
137
209
138
210
unsafe impl < ' c > Send for AsyncSession < ' c > { }
@@ -232,7 +304,13 @@ impl AsyncWrite for AsyncSession<'_> {
232
304
cx : & mut TaskContext < ' _ > ,
233
305
buf : & [ u8 ] ,
234
306
) -> Poll < io:: Result < usize > > {
235
- self . with_context ( cx, |s| s. write ( buf) )
307
+ let buf = match self . write_tracker . adjust_buf ( buf) {
308
+ Ok ( buf) => buf,
309
+ Err ( e) => return Poll :: Ready ( Err ( e) ) ,
310
+ } ;
311
+ let res = self . with_context ( cx, |s| s. write ( buf) ) ;
312
+ self . write_tracker . post_write ( buf, & res) ;
313
+ res
236
314
}
237
315
238
316
fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut TaskContext < ' _ > ) -> Poll < io:: Result < ( ) > > {
@@ -269,6 +347,7 @@ impl<'ctx, S: AsyncRead + AsyncWrite + Unpin> Future for StartHandshake<'_, 'ctx
269
347
Ok ( session) => Ok ( HandshakeState :: Ready ( AsyncSession {
270
348
session : Some ( session. into ( ) ) ,
271
349
ecx,
350
+ write_tracker : WriteTracker :: new ( ) ,
272
351
} ) ) ,
273
352
Err ( HandshakeError :: WouldBlock ( mid, _err) ) => Ok ( HandshakeState :: InProgress (
274
353
MidHandshakeFuture ( Some ( MidHandshakeFutureInner { mid, ecx } ) ) ,
@@ -299,6 +378,7 @@ impl<'c> Future for MidHandshakeFuture<'c> {
299
378
Ok ( AsyncSession {
300
379
session : Some ( session. into ( ) ) ,
301
380
ecx : inner. ecx ,
381
+ write_tracker : WriteTracker :: new ( ) ,
302
382
} ) . into ( )
303
383
}
304
384
Err ( HandshakeError :: WouldBlock ( mid, _err) ) => {
0 commit comments