diff --git a/futures-util/src/stream/stream/mod.rs b/futures-util/src/stream/stream/mod.rs index ee30f8da6f..084095318b 100644 --- a/futures-util/src/stream/stream/mod.rs +++ b/futures-util/src/stream/stream/mod.rs @@ -156,6 +156,10 @@ pub use self::ready_chunks::ReadyChunks; mod scan; pub use self::scan::Scan; +mod try_scan; +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::try_scan::TryScan; + #[cfg_attr(target_os = "none", cfg(target_has_atomic = "ptr"))] #[cfg(feature = "alloc")] mod buffer_unordered; @@ -1038,6 +1042,52 @@ pub trait StreamExt: Stream { assert_stream::(TakeUntil::new(self, fut)) } + /// Combinator similar to [`fold`](StreamExt::fold) that holds internal state + /// and produces a new stream. + /// + /// Accepts initial state and closure which will be applied to each element + /// of the stream until provided closure returns `None`. Once `None` is + /// returned, stream will be terminated. + /// + /// This method is similar to [`scan`](StreamExt::scan), but will + /// exit early if an error is encountered in either the stream or the + /// provided closure. + /// + /// # Examples + /// + /// ``` + /// # futures::executor::block_on(async { + /// use futures::future; + /// use futures::stream::{self, StreamExt, TryStreamExt}; + /// + /// let stream = stream::iter(1..=10); + /// + /// let stream = stream.try_scan(0, |mut state, x| { + /// state += x; + /// future::ready(if state < 10 { Ok::<_, ()>(Some((state, x))) } else { Ok(None) }) + /// }); + /// + /// assert_eq!(Ok(vec![1, 2, 3]), stream.try_collect::>().await); + /// + /// let stream = stream::iter(1..=10); + /// + /// let stream = stream.try_scan(0, |mut state, x| { + /// state += x; + /// future::ready(if state < 10 { Ok(Some((state, x))) } else { Err(()) }) + /// }); + /// + /// assert_eq!(Err(()), stream.try_collect::>().await); + /// # }); + /// ``` + fn try_scan(self, initial_state: S, f: F) -> TryScan + where + F: FnMut(S, Self::Item) -> Fut, + Fut: TryFuture>, + Self: Sized, + { + assert_stream::, _>(TryScan::new(self, initial_state, f)) + } + /// Runs this stream to completion, executing the provided asynchronous /// closure for each element on the stream. /// diff --git a/futures-util/src/stream/stream/try_scan.rs b/futures-util/src/stream/stream/try_scan.rs new file mode 100644 index 0000000000..5a3aa93bcb --- /dev/null +++ b/futures-util/src/stream/stream/try_scan.rs @@ -0,0 +1,127 @@ +use crate::unfold_state::UnfoldState; +use core::fmt; +use core::pin::Pin; +use futures_core::stream::{FusedStream, Stream}; +use futures_core::task::{Context, Poll}; +use futures_core::{ready, TryFuture}; +#[cfg(feature = "sink")] +use futures_sink::Sink; +use pin_project_lite::pin_project; + +pin_project! { + /// TryStream for the [`try_scan`](super::StreamExt::try_scan) method. + #[must_use = "streams do nothing unless polled"] + pub struct TryScan { + #[pin] + stream: St, + f: F, + #[pin] + state: UnfoldState, + } +} + +impl fmt::Debug for TryScan +where + St: Stream + fmt::Debug, + St::Item: fmt::Debug, + S: fmt::Debug, + Fut: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TryScan") + .field("stream", &self.stream) + .field("state", &self.state) + .field("done_taking", &self.is_done_taking()) + .finish() + } +} + +impl TryScan { + /// Checks if internal state is `None`. + fn is_done_taking(&self) -> bool { + matches!(self.state, UnfoldState::Empty) + } +} + +impl TryScan +where + St: Stream, + F: FnMut(S, St::Item) -> Fut, + Fut: TryFuture>, +{ + pub(super) fn new(stream: St, initial_state: S, f: F) -> Self { + Self { stream, f, state: UnfoldState::Value { value: initial_state } } + } +} + +impl Stream for TryScan +where + St: Stream, + F: FnMut(S, St::Item) -> Fut, + Fut: TryFuture>, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.is_done_taking() { + return Poll::Ready(None); + } + + let mut this = self.project(); + + Poll::Ready(loop { + if let Some(fut) = this.state.as_mut().project_future() { + match ready!(fut.try_poll(cx)) { + Ok(None) => { + this.state.set(UnfoldState::Empty); + break None; + } + Ok(Some((state, item))) => { + this.state.set(UnfoldState::Value { value: state }); + break Some(Ok(item)); + } + Err(e) => { + this.state.set(UnfoldState::Empty); + break Some(Err(e)); + } + } + } else if let Some(item) = ready!(this.stream.as_mut().poll_next(cx)) { + let state = this.state.as_mut().take_value().unwrap(); + this.state.set(UnfoldState::Future { future: (this.f)(state, item) }) + } else { + break None; + } + }) + } + + fn size_hint(&self) -> (usize, Option) { + if self.is_done_taking() { + (0, Some(0)) + } else { + self.stream.size_hint() // can't know a lower bound, due to the predicate + } + } +} + +impl FusedStream for TryScan +where + St: FusedStream, + F: FnMut(S, St::Item) -> Fut, + Fut: TryFuture>, +{ + fn is_terminated(&self) -> bool { + self.is_done_taking() + || !matches!(self.state, UnfoldState::Future { .. }) && self.stream.is_terminated() + } +} + +// Forwarding impl of Sink from the underlying stream +#[cfg(feature = "sink")] +impl Sink for TryScan +where + St: Stream + Sink, +{ + type Error = St::Error; + + delegate_sink!(stream, Item); +}