From 1a993e96efa4f17f3382b51197abc297e911d3f2 Mon Sep 17 00:00:00 2001 From: Sean Lynch <42618346+swlynch99@users.noreply.github.com> Date: Mon, 3 Mar 2025 17:42:48 -0800 Subject: [PATCH] Make AsyncStream Sync even if the inner future is not The AsyncStream has a few methods that that take &self so we don't necessarily want to unconditionally impl Send for it. However, we only access the inner future by calling Future::poll. We can then make AsyncStream Send in a way that is more obviously safe by wrapping the inner stream in SyncWrapper from [0]. This gets the same effect (AsyncStream is Sync) but uses only unsafe blocks that are more obviously safe. [0]: https://internals.rust-lang.org/t/what-shall-sync-mean-across-an-await/12020/2 --- async-stream/src/async_stream.rs | 5 ++-- async-stream/src/lib.rs | 1 + async-stream/src/sync_wrapper.rs | 51 ++++++++++++++++++++++++++++++++ async-stream/tests/stream.rs | 14 +++++++++ 4 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 async-stream/src/sync_wrapper.rs diff --git a/async-stream/src/async_stream.rs b/async-stream/src/async_stream.rs index ff408ab..090a1bb 100644 --- a/async-stream/src/async_stream.rs +++ b/async-stream/src/async_stream.rs @@ -1,3 +1,4 @@ +use crate::sync_wrapper::SyncWrapper; use crate::yielder::Receiver; use futures_core::{FusedStream, Stream}; @@ -13,7 +14,7 @@ pin_project! { rx: Receiver, done: bool, #[pin] - generator: U, + generator: SyncWrapper, } } @@ -23,7 +24,7 @@ impl AsyncStream { AsyncStream { rx, done: false, - generator, + generator: SyncWrapper::new(generator), } } } diff --git a/async-stream/src/lib.rs b/async-stream/src/lib.rs index 318e404..c1401ef 100644 --- a/async-stream/src/lib.rs +++ b/async-stream/src/lib.rs @@ -158,6 +158,7 @@ mod async_stream; mod next; +mod sync_wrapper; mod yielder; /// Asynchronous stream diff --git a/async-stream/src/sync_wrapper.rs b/async-stream/src/sync_wrapper.rs new file mode 100644 index 0000000..ebef215 --- /dev/null +++ b/async-stream/src/sync_wrapper.rs @@ -0,0 +1,51 @@ +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// A wrapper around `T` that only allows mutable access. +/// +/// This allows it to unconditionally implement `Sync`, since there is nothing +/// you can do with an `&SyncWrapper`. +pub(crate) struct SyncWrapper { + inner: T, +} + +impl SyncWrapper { + pub(crate) fn new(value: T) -> Self { + Self { inner: value } + } + + pub(crate) fn get_pinned_mut(self: Pin<&mut Self>) -> Pin<&mut T> { + // We can't use pin_project! for this because it generates a project_ref + // method which would allow accessing the inner element + // + // SAFETY: this.inner is guaranteed not to move as long as this lives. + unsafe { self.map_unchecked_mut(|this| &mut this.inner) } + } +} + +// SAFETY: It is not possible to do anything with an &SyncWrapper so it is +// safe for it to be shared between threads. +// +// See [0] for more details. +// +// [0]: https://internals.rust-lang.org/t/what-shall-sync-mean-across-an-await/12020/2 +unsafe impl Sync for SyncWrapper {} + +impl Future for SyncWrapper { + type Output = T::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.get_pinned_mut().poll(cx) + } +} + +impl fmt::Debug for SyncWrapper { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // We can't format the inner value (since that would create an &T reference) + // so we just print a placeholder string. + + f.write_str("") + } +} diff --git a/async-stream/tests/stream.rs b/async-stream/tests/stream.rs index abfd1fc..23195e5 100644 --- a/async-stream/tests/stream.rs +++ b/async-stream/tests/stream.rs @@ -1,3 +1,5 @@ +use std::cell::Cell; + use async_stream::stream; use futures_core::stream::{FusedStream, Stream}; @@ -229,6 +231,18 @@ fn inner_try_stream() { }; } +#[test] +fn stream_is_sync() { + fn assert_sync(_: T) {} + + // The stream should be sync even if it contains a non-sync value. + assert_sync(stream! { + let cell = Cell::new(true); + yield 5; + drop(cell); + }); +} + #[rustversion::attr(not(stable), ignore)] #[test] fn test() {