// iota_common/stream_ext.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use std::{future::Future, panic, pin::pin};

use futures::stream::{Stream, StreamExt};
use tokio::task::JoinSet;
9
/// Extension trait introducing `try_for_each_spawned` to all streams.
pub trait TrySpawnStreamExt: Stream {
    /// Attempts to run this stream to completion, executing the provided
    /// asynchronous closure on each element from the stream as elements
    /// become available.
    ///
    /// This is similar to [StreamExt::for_each_concurrent], but it may take
    /// advantage of any parallelism available in the underlying runtime,
    /// because each unit of work is spawned as its own tokio task.
    ///
    /// The first argument is an optional limit on the number of tasks to spawn
    /// concurrently. Values of `0` and `None` are interpreted as no limit,
    /// and any other value will result in no more than that many tasks
    /// being spawned at one time.
    ///
    /// Returns `Ok(())` once the stream is exhausted and all spawned tasks
    /// have finished, or the first error produced by any worker.
    ///
    /// # Panics
    ///
    /// This function will panic if any of the spawned futures panics (the
    /// worker's panic is resumed on the calling task). It will return early
    /// with success if the runtime it is running on is shut down, and will
    /// return early with an error propagated from any worker that produces
    /// an error (already-spawned workers are still awaited before returning).
    fn try_for_each_spawned<Fut, F, E>(
        self,
        limit: impl Into<Option<usize>>,
        f: F,
    ) -> impl Future<Output = Result<(), E>>
    where
        Fut: Future<Output = Result<(), E>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        E: Send + 'static;
}
41
impl<S: Stream + Sized + 'static> TrySpawnStreamExt for S {
    async fn try_for_each_spawned<Fut, F, E>(
        self,
        limit: impl Into<Option<usize>>,
        mut f: F,
    ) -> Result<(), E>
    where
        Fut: Future<Output = Result<(), E>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        E: Send + 'static,
    {
        // Maximum number of tasks to spawn concurrently. `Some(0)` and `None`
        // both mean "unlimited", modelled as the largest representable limit.
        let limit = match limit.into() {
            Some(0) | None => usize::MAX,
            Some(n) => n,
        };

        // Number of permits to spawn tasks left. One permit is consumed per
        // spawned worker and returned when that worker is joined, so
        // `permits == limit` exactly when no workers are in flight.
        let mut permits = limit;
        // Handles for already spawned tasks.
        let mut join_set = JoinSet::new();
        // Whether the worker pool has stopped accepting new items and is draining.
        let mut draining = false;
        // Error that occurred in one of the workers, to be propagated to the
        // caller on exit. Only the first error is kept; later errors are
        // dropped while the pool drains.
        let mut error = None;

        // Pin the stream on the stack so `next()` can be polled in the loop.
        let mut self_ = pin!(self);

        loop {
            tokio::select! {
                // Only pull a new item while we are accepting work and have a
                // spare permit; otherwise this arm is disabled.
                next = self_.next(), if !draining && permits > 0 => {
                    if let Some(item) = next {
                        permits -= 1;
                        join_set.spawn(f(item));
                    } else {
                        // If the stream is empty, signal that the worker pool is going to
                        // start draining now, so that once we get all our permits back, we
                        // know we can wind down the pool.
                        draining = true;
                    }
                }

                // This arm is disabled when `join_set` is empty (join_next()
                // yields `None`).
                Some(res) = join_set.join_next() => {
                    match res {
                        // First worker error: remember it and stop accepting
                        // new items so the pool can wind down.
                        Ok(Err(e)) if error.is_none() => {
                            error = Some(e);
                            permits += 1;
                            draining = true;
                        }

                        // Worker succeeded (or errored after the first error,
                        // which is ignored): just return its permit.
                        Ok(_) => permits += 1,

                        // Worker panicked, propagate the panic.
                        Err(e) if e.is_panic() => {
                            panic::resume_unwind(e.into_panic())
                        }

                        // Worker was cancelled -- this can only happen if its join handle was
                        // cancelled (not possible because that was created in this function),
                        // or the runtime it was running in was wound down, in which case,
                        // prepare the worker pool to drain.
                        Err(e) => {
                            assert!(e.is_cancelled());
                            permits += 1;
                            draining = true;
                        }
                    }
                }

                // Reached only when both arms above are disabled: no workers
                // remain in `join_set` and we are either draining or out of
                // permits (the latter implies a non-empty join set, so in
                // practice `draining` is set and all permits are back).
                else => {
                    // Not accepting any more items from the stream, and all our workers are
                    // idle, so we stop.
                    if permits == limit && draining {
                        break;
                    }
                }
            }
        }

        if let Some(e) = error { Err(e) } else { Ok(()) }
    }
}
125
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc, Mutex,
            atomic::{AtomicUsize, Ordering},
        },
        time::Duration,
    };

    use futures::stream;

    use super::*;

    #[tokio::test]
    async fn explicit_sequential_iteration() {
        // A limit of 1 degenerates into sequential processing: items must be
        // recorded in stream order even though later items sleep for less
        // time than earlier ones.
        let seen = Arc::new(Mutex::new(Vec::new()));
        let outcome = stream::iter(0..20)
            .try_for_each_spawned(1, |i| {
                let seen = seen.clone();
                async move {
                    tokio::time::sleep(Duration::from_millis(20 - i)).await;
                    seen.lock().unwrap().push(i);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(outcome.is_ok());

        let seen = Arc::try_unwrap(seen).unwrap().into_inner().unwrap();
        assert_eq!((0..20).collect::<Vec<_>>(), seen);
    }

    #[tokio::test]
    async fn concurrent_iteration() {
        // With 16 workers the processing order is unspecified, but the
        // accumulated sum is order-independent.
        let total = Arc::new(AtomicUsize::new(0));
        let outcome = stream::iter(0..100)
            .try_for_each_spawned(16, |i| {
                let total = total.clone();
                async move {
                    total.fetch_add(i, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(99 * 100 / 2, Arc::try_unwrap(total).unwrap().into_inner());
    }

    #[tokio::test]
    async fn implicit_unlimited_iteration() {
        // `None` means no concurrency limit at all.
        let total = Arc::new(AtomicUsize::new(0));
        let outcome = stream::iter(0..100)
            .try_for_each_spawned(None, |i| {
                let total = total.clone();
                async move {
                    total.fetch_add(i, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(99 * 100 / 2, Arc::try_unwrap(total).unwrap().into_inner());
    }

    #[tokio::test]
    async fn explicit_unlimited_iteration() {
        // A limit of `0` is also interpreted as "unlimited".
        let total = Arc::new(AtomicUsize::new(0));
        let outcome = stream::iter(0..100)
            .try_for_each_spawned(0, |i| {
                let total = total.clone();
                async move {
                    total.fetch_add(i, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(99 * 100 / 2, Arc::try_unwrap(total).unwrap().into_inner());
    }

    #[tokio::test]
    async fn max_concurrency() {
        // Tracks the highest number of workers ever observed in flight; it
        // must never exceed the configured limit of 4, and all workers must
        // have finished by the time the driver returns.
        #[derive(Default, Debug)]
        struct Jobs {
            max: AtomicUsize,
            curr: AtomicUsize,
        }

        let jobs = Arc::new(Jobs::default());

        let outcome = stream::iter(0..32)
            .try_for_each_spawned(4, |_| {
                let jobs = jobs.clone();
                async move {
                    jobs.curr.fetch_add(1, Ordering::Relaxed);
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    let in_flight = jobs.curr.fetch_sub(1, Ordering::Relaxed);
                    jobs.max.fetch_max(in_flight, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(outcome.is_ok());

        let Jobs { max, curr } = Arc::try_unwrap(jobs).unwrap();
        assert_eq!(curr.into_inner(), 0);
        assert!(max.into_inner() <= 4);
    }

    #[tokio::test]
    async fn error_propagation() {
        // The first worker error (at item 42) stops the pool from accepting
        // further items; everything processed before it was recorded.
        let seen = Arc::new(Mutex::new(Vec::new()));
        let outcome = stream::iter(0..100)
            .try_for_each_spawned(None, |i| {
                let seen = seen.clone();
                async move {
                    if i < 42 {
                        seen.lock().unwrap().push(i);
                        Ok(())
                    } else {
                        Err(())
                    }
                }
            })
            .await;

        assert!(outcome.is_err());

        let seen = Arc::try_unwrap(seen).unwrap().into_inner().unwrap();
        assert_eq!((0..42).collect::<Vec<_>>(), seen);
    }

    #[tokio::test]
    #[should_panic]
    async fn panic_propagation() {
        // A panicking worker (item 42 fails its assert) must have its panic
        // resumed on the driver task rather than silently swallowed.
        let _ = stream::iter(0..100)
            .try_for_each_spawned(None, |i| async move {
                assert!(i < 42);
                Ok::<(), ()>(())
            })
            .await;
    }
}