Skip to content

Commit 148b800

Browse files
committed
Write the same data and len after ssl_write() returns would block
1 parent b91d916 commit 148b800

File tree

1 file changed

+81
-1
lines changed

1 file changed

+81
-1
lines changed

mbedtls/src/ssl/async_session.rs

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,81 @@ impl Drop for SessionWrapper<'_> {
130130
}
131131
}
132132

133+
// mbedtls_ssl_write() has some weird semantics w.r.t non-blocking I/O:
134+
//
135+
// > When this function returns MBEDTLS_ERR_SSL_WANT_WRITE/READ, it must be
136+
// > called later **with the same arguments**, until it returns a value greater
137+
// > than or equal to 0. When the function returns MBEDTLS_ERR_SSL_WANT_WRITE
138+
// > there may be some partial data in the output buffer, however this is not
139+
// > yet sent.
140+
//
141+
// WriteTracker is used to ensure we pass the same data in that scenario.
142+
//
143+
// Reference:
144+
// https://tls.mbed.org/api/ssl_8h.html#a5bbda87d484de82df730758b475f32e5
145+
struct WriteTracker {
146+
pending: Option<Box<DigestAndLen>>,
147+
}
148+
149+
struct DigestAndLen {
150+
digest: [u8; 20], // SHA-1
151+
len: usize,
152+
}
153+
154+
impl WriteTracker {
155+
fn new() -> Self {
156+
WriteTracker {
157+
pending: None,
158+
}
159+
}
160+
161+
fn digest(buf: &[u8]) -> [u8; 20] {
162+
use crate::hash::{Md, Type};
163+
let mut out = [0u8; 20];
164+
let res = Md::hash(Type::Sha1, buf, &mut out[..]);
165+
assert_eq!(res, Ok(out.len()));
166+
out
167+
}
168+
169+
fn adjust_buf<'a>(&self, buf: &'a [u8]) -> io::Result<&'a [u8]> {
170+
match self.pending.as_ref() {
171+
None => Ok(buf),
172+
Some(pending) => {
173+
if pending.len <= buf.len() {
174+
let buf = &buf[..pending.len];
175+
if Self::digest(buf) == pending.digest {
176+
return Ok(buf);
177+
}
178+
}
179+
Err(io::Error::new(
180+
io::ErrorKind::Other,
181+
"mbedtls expects the same data if the previous call to poll_write() returned Poll::Pending"
182+
))
183+
},
184+
}
185+
}
186+
187+
fn post_write(&mut self, buf: &[u8], res: &Poll<io::Result<usize>>) {
188+
match res {
189+
&Poll::Pending => {
190+
if self.pending.is_none() {
191+
self.pending = Some(Box::new(DigestAndLen {
192+
digest: Self::digest(buf),
193+
len: buf.len(),
194+
}));
195+
}
196+
},
197+
_ => {
198+
self.pending = None;
199+
}
200+
}
201+
}
202+
}
203+
133204
pub struct AsyncSession<'ctx> {
134205
session: Option<SessionWrapper<'ctx>>,
135206
ecx: ErasedContext,
207+
write_tracker: WriteTracker,
136208
}
137209

138210
unsafe impl<'c> Send for AsyncSession<'c> {}
@@ -232,7 +304,13 @@ impl AsyncWrite for AsyncSession<'_> {
232304
cx: &mut TaskContext<'_>,
233305
buf: &[u8],
234306
) -> Poll<io::Result<usize>> {
235-
self.with_context(cx, |s| s.write(buf))
307+
let buf = match self.write_tracker.adjust_buf(buf) {
308+
Ok(buf) => buf,
309+
Err(e) => return Poll::Ready(Err(e)),
310+
};
311+
let res = self.with_context(cx, |s| s.write(buf));
312+
self.write_tracker.post_write(buf, &res);
313+
res
236314
}
237315

238316
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<io::Result<()>> {
@@ -269,6 +347,7 @@ impl<'ctx, S: AsyncRead + AsyncWrite + Unpin> Future for StartHandshake<'_, 'ctx
269347
Ok(session) => Ok(HandshakeState::Ready(AsyncSession {
270348
session: Some(session.into()),
271349
ecx,
350+
write_tracker: WriteTracker::new(),
272351
})),
273352
Err(HandshakeError::WouldBlock(mid, _err)) => Ok(HandshakeState::InProgress(
274353
MidHandshakeFuture(Some(MidHandshakeFutureInner { mid, ecx })),
@@ -299,6 +378,7 @@ impl<'c> Future for MidHandshakeFuture<'c> {
299378
Ok(AsyncSession {
300379
session: Some(session.into()),
301380
ecx: inner.ecx,
381+
write_tracker: WriteTracker::new(),
302382
}).into()
303383
}
304384
Err(HandshakeError::WouldBlock(mid, _err)) => {

0 commit comments

Comments
 (0)