iota_common/
stream_ext.rs1use std::{future::Future, panic, pin::pin};
6
7use futures::stream::{Stream, StreamExt};
8use tokio::task::JoinSet;
9
10pub trait TrySpawnStreamExt: Stream {
12 fn try_for_each_spawned<Fut, F, E>(
32 self,
33 limit: impl Into<Option<usize>>,
34 f: F,
35 ) -> impl Future<Output = Result<(), E>>
36 where
37 Fut: Future<Output = Result<(), E>> + Send + 'static,
38 F: FnMut(Self::Item) -> Fut,
39 E: Send + 'static;
40}
41
42impl<S: Stream + Sized + 'static> TrySpawnStreamExt for S {
43 async fn try_for_each_spawned<Fut, F, E>(
44 self,
45 limit: impl Into<Option<usize>>,
46 mut f: F,
47 ) -> Result<(), E>
48 where
49 Fut: Future<Output = Result<(), E>> + Send + 'static,
50 F: FnMut(Self::Item) -> Fut,
51 E: Send + 'static,
52 {
53 let limit = match limit.into() {
55 Some(0) | None => usize::MAX,
56 Some(n) => n,
57 };
58
59 let mut permits = limit;
61 let mut join_set = JoinSet::new();
63 let mut draining = false;
65 let mut error = None;
68
69 let mut self_ = pin!(self);
70
71 loop {
72 tokio::select! {
73 next = self_.next(), if !draining && permits > 0 => {
74 if let Some(item) = next {
75 permits -= 1;
76 join_set.spawn(f(item));
77 } else {
78 draining = true;
82 }
83 }
84
85 Some(res) = join_set.join_next() => {
86 match res {
87 Ok(Err(e)) if error.is_none() => {
88 error = Some(e);
89 permits += 1;
90 draining = true;
91 }
92
93 Ok(_) => permits += 1,
94
95 Err(e) if e.is_panic() => {
97 panic::resume_unwind(e.into_panic())
98 }
99
100 Err(e) => {
105 assert!(e.is_cancelled());
106 permits += 1;
107 draining = true;
108 }
109 }
110 }
111
112 else => {
113 if permits == limit && draining {
116 break;
117 }
118 }
119 }
120 }
121
122 if let Some(e) = error { Err(e) } else { Ok(()) }
123 }
124}
125
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc, Mutex,
            atomic::{AtomicUsize, Ordering},
        },
        time::Duration,
    };

    use futures::stream;

    use super::*;

    #[tokio::test]
    async fn explicit_sequential_iteration() {
        let actual = Arc::new(Mutex::new(vec![]));
        let result = stream::iter(0..20)
            .try_for_each_spawned(1, |i| {
                let actual = actual.clone();
                async move {
                    tokio::time::sleep(Duration::from_millis(20 - i)).await;
                    actual.lock().unwrap().push(i);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner().unwrap();
        let expect: Vec<_> = (0..20).collect();
        assert_eq!(expect, actual);
    }

    #[tokio::test]
    async fn concurrent_iteration() {
        let actual = Arc::new(AtomicUsize::new(0));
        let result = stream::iter(0..100)
            .try_for_each_spawned(16, |i| {
                let actual = actual.clone();
                async move {
                    actual.fetch_add(i, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner();
        let expect = 99 * 100 / 2;
        assert_eq!(expect, actual);
    }

    #[tokio::test]
    async fn implicit_unlimited_iteration() {
        let actual = Arc::new(AtomicUsize::new(0));
        let result = stream::iter(0..100)
            .try_for_each_spawned(None, |i| {
                let actual = actual.clone();
                async move {
                    actual.fetch_add(i, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner();
        let expect = 99 * 100 / 2;
        assert_eq!(expect, actual);
    }

    #[tokio::test]
    async fn explicit_unlimited_iteration() {
        let actual = Arc::new(AtomicUsize::new(0));
        let result = stream::iter(0..100)
            .try_for_each_spawned(0, |i| {
                let actual = actual.clone();
                async move {
                    actual.fetch_add(i, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner();
        let expect = 99 * 100 / 2;
        assert_eq!(expect, actual);
    }

    #[tokio::test]
    async fn empty_stream() {
        // Nothing to spawn: the pool should wind down immediately with success.
        let result = stream::iter(0..0)
            .try_for_each_spawned(4, |_| async { Ok::<(), ()>(()) })
            .await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn max_concurrency() {
        #[derive(Default, Debug)]
        struct Jobs {
            max: AtomicUsize,
            curr: AtomicUsize,
        }

        let jobs = Arc::new(Jobs::default());

        let result = stream::iter(0..32)
            .try_for_each_spawned(4, |_| {
                let jobs = jobs.clone();
                async move {
                    jobs.curr.fetch_add(1, Ordering::Relaxed);
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    let prev = jobs.curr.fetch_sub(1, Ordering::Relaxed);
                    jobs.max.fetch_max(prev, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let Jobs { max, curr } = Arc::try_unwrap(jobs).unwrap();
        assert_eq!(curr.into_inner(), 0);
        assert!(max.into_inner() <= 4);
    }

    #[tokio::test]
    async fn error_propagation() {
        let actual = Arc::new(Mutex::new(vec![]));
        let result = stream::iter(0..100)
            .try_for_each_spawned(None, |i| {
                let actual = actual.clone();
                async move {
                    if i < 42 {
                        actual.lock().unwrap().push(i);
                        Ok(())
                    } else {
                        Err(())
                    }
                }
            })
            .await;

        assert!(result.is_err());

        // With unbounded concurrency the tasks may complete in any order, so
        // compare the set of processed items rather than the push order.
        let mut actual = Arc::try_unwrap(actual).unwrap().into_inner().unwrap();
        actual.sort_unstable();
        let expect: Vec<_> = (0..42).collect();
        assert_eq!(expect, actual);
    }

    #[tokio::test]
    async fn stops_pulling_after_error() {
        // With a limit of 1, items are processed strictly one at a time. The
        // failure of item 5 must therefore stop the stream from being pulled
        // again, so `f` is invoked exactly six times (items 0..=5).
        let calls = AtomicUsize::new(0);
        let result = stream::iter(0..100)
            .try_for_each_spawned(1, |i| {
                calls.fetch_add(1, Ordering::Relaxed);
                async move { if i < 5 { Ok(()) } else { Err(i) } }
            })
            .await;

        assert_eq!(result, Err(5));
        assert_eq!(calls.load(Ordering::Relaxed), 6);
    }

    #[tokio::test]
    #[should_panic]
    async fn panic_propagation() {
        let _ = stream::iter(0..100)
            .try_for_each_spawned(None, |i| async move {
                assert!(i < 42);
                Ok::<(), ()>(())
            })
            .await;
    }
}