Skip to content

Commit 02d3be6

Browse files
committed
Add parallel_stream::any
1 parent 8182694 commit 02d3be6

File tree

3 files changed

+115
-1
lines changed

3 files changed

+115
-1
lines changed

src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ mod par_stream;
5252
pub use from_parallel_stream::FromParallelStream;
5353
pub use from_stream::{from_stream, FromStream};
5454
pub use into_parallel_stream::IntoParallelStream;
55-
pub use par_stream::{ForEach, Map, NextFuture, ParallelStream, Take};
55+
pub use par_stream::{ForEach, Map, NextFuture, ParallelStream, Take, Any};
5656

5757
pub mod prelude;
5858
pub mod vec;

src/par_stream/any.rs

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
use async_std::prelude::*;
2+
use async_std::future::Future;
3+
use pin_project_lite::pin_project;
4+
use async_std::task::{self, Context, Poll};
5+
use async_std::sync::{self, Receiver, Sender};
6+
7+
use std::pin::Pin;
8+
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9+
use std::sync::Arc;
10+
11+
use crate::ParallelStream;
12+
13+
pin_project! {
14+
/// Calls a closure on each element until true or exhausted.
15+
#[derive(Debug)]
16+
pub struct Any {
17+
#[pin]
18+
receiver: Receiver<()>,
19+
// Track whether the input stream has been exhausted.
20+
exhausted: Arc<AtomicBool>,
21+
// Count how many tasks are executing.
22+
ref_count: Arc<AtomicU64>,
23+
// Track the boolean value as executed.
24+
value: Arc<AtomicBool>,
25+
}
26+
}
27+
28+
impl Any {
29+
/// Creates a new instance of `Any`.
30+
pub fn new<S, F, Fut>(mut stream: S, mut f: F) -> Self
31+
where
32+
S: ParallelStream,
33+
F: FnMut(S::Item) -> Fut + Send + Sync + Copy + 'static,
34+
Fut: Future<Output = bool> + Send,
35+
{
36+
let exhausted = Arc::new(AtomicBool::new(false));
37+
let value = Arc::new(AtomicBool::new(false));
38+
let ref_count = Arc::new(AtomicU64::new(0));
39+
let (sender, receiver): (Sender<()>, Receiver<()>) = sync::channel(1);
40+
let _limit = stream.get_limit();
41+
42+
// Initialize the return type here to prevent borrowing issues.
43+
let this = Self {
44+
receiver,
45+
exhausted: exhausted.clone(),
46+
ref_count: ref_count.clone(),
47+
value: value.clone(),
48+
};
49+
50+
task::spawn(async move {
51+
while let Some(item) = stream.next().await {
52+
let sender = sender.clone();
53+
let exhausted = exhausted.clone();
54+
let ref_count = ref_count.clone();
55+
let value = value.clone();
56+
57+
ref_count.fetch_add(1, Ordering::SeqCst);
58+
59+
task::spawn(async move {
60+
// Execute the closure.
61+
let res = f(item).await;
62+
63+
// Wake up the receiver if we know we're done.
64+
ref_count.fetch_sub(1, Ordering::SeqCst);
65+
if res {
66+
value.fetch_or(true, Ordering::SeqCst);
67+
sender.send(()).await;
68+
} else if exhausted.load(Ordering::SeqCst) && ref_count.load(Ordering::SeqCst) == 0 {
69+
sender.send(()).await;
70+
}
71+
});
72+
}
73+
74+
// The input stream will no longer yield items.
75+
exhausted.store(true, Ordering::SeqCst);
76+
});
77+
78+
this
79+
}
80+
}
81+
82+
impl Future for Any {
83+
type Output = bool;
84+
85+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
86+
let this = self.project();
87+
task::ready!(this.receiver.poll_next(cx));
88+
Poll::Ready(this.value.load(Ordering::SeqCst))
89+
}
90+
}
91+
92+
#[async_std::test]
93+
async fn smoke() {
94+
let s = async_std::stream::repeat(5usize);
95+
let result = crate::from_stream(s)
96+
.take(3)
97+
.any(|n| async move {
98+
n * 2 < 9
99+
})
100+
.await;
101+
102+
assert_eq!(result, false);
103+
}

src/par_stream/mod.rs

+11
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ pub use for_each::ForEach;
99
pub use map::Map;
1010
pub use next::NextFuture;
1111
pub use take::Take;
12+
pub use any::Any;
1213

1314
mod for_each;
1415
mod map;
1516
mod next;
1617
mod take;
18+
mod any;
1719

1820
/// Parallel version of the standard `Stream` trait.
1921
pub trait ParallelStream: Sized + Send + Sync + Unpin + 'static {
@@ -63,6 +65,15 @@ pub trait ParallelStream: Sized + Send + Sync + Unpin + 'static {
6365
ForEach::new(self, f)
6466
}
6567

68+
/// Applies `f` to each item of this stream in parallel and returns true if at least one element satisfies `f`.
69+
fn any<F, Fut>(self, f: F) -> Any
70+
where
71+
F: FnMut(Self::Item) -> Fut + Send + Sync + Copy + 'static,
72+
Fut: Future<Output = bool> + Send,
73+
{
74+
Any::new(self, f)
75+
}
76+
6677
/// Transforms a stream into a collection.
6778
///
6879
///`collect()` can take anything streamable, and turn it into a relevant

0 commit comments

Comments
 (0)