iota_data_ingestion_core/
worker_pool.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use std::{
    collections::{BTreeSet, HashMap, VecDeque},
    fmt::Debug,
    sync::Arc,
    time::Instant,
};

use backoff::{ExponentialBackoff, backoff::Backoff};
use futures::StreamExt;
use iota_metrics::spawn_monitored_task;
use iota_rest_api::CheckpointData;
use iota_types::messages_checkpoint::CheckpointSequenceNumber;
use tokio::{sync::mpsc, task::JoinHandle};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};

use crate::{
    IngestionError, IngestionResult, Reducer, Worker, executor::MAX_CHECKPOINTS_IN_PROGRESS,
    reducer::reduce, util::reset_backoff,
};

type TaskName = String;
type WorkerID = usize;

/// Represents the possible message types a [`WorkerPool`] can use to
/// communicate with external components.
#[derive(Debug, Clone)]
pub enum WorkerPoolStatus {
    /// Message with information (e.g. `(<task-name>,
    /// checkpoint_sequence_number)`) about the ingestion progress.
    Running((TaskName, CheckpointSequenceNumber)),
    /// Message with information (e.g. `<task-name>`) about shutdown status.
    Shutdown(String),
}

/// Represents the possible message types a [`Worker`] can use to communicate
/// with external components.
#[derive(Debug, Clone, Copy)]
enum WorkerStatus<M> {
    /// Message with information (e.g. `(<worker-id>,
    /// checkpoint_sequence_number, <message>)`, where `<message>` is a
    /// [`Worker::Message`]) about the ingestion progress.
    Running((WorkerID, CheckpointSequenceNumber, M)),
    /// Message with information (e.g. `<worker-id>`) about shutdown status.
    Shutdown(WorkerID),
}

/// A pool of [`Worker`]s that process checkpoints concurrently.
///
/// This struct manages a collection of workers that process checkpoints in
/// parallel. It handles checkpoint distribution, progress tracking, and
/// graceful shutdown. It can optionally use a [`Reducer`] to aggregate and
/// process worker [`Messages`](Worker::Message).
///
/// # Examples
/// ## Direct Processing (Without Batching)
/// ```rust,no_run
/// use std::sync::Arc;
///
/// use async_trait::async_trait;
/// use iota_data_ingestion_core::{Worker, WorkerPool};
/// use iota_types::full_checkpoint_content::{CheckpointData, CheckpointTransaction};
/// #
/// # struct DatabaseClient;
/// #
/// # impl DatabaseClient {
/// #     pub fn new() -> Self {
/// #         Self
/// #     }
/// #
/// #     pub async fn store_transaction(&self,
/// #         _transactions: &CheckpointTransaction,
/// #     ) -> Result<(), DatabaseError> {
/// #         Ok(())
/// #     }
/// # }
/// #
/// # #[derive(Debug, Clone)]
/// # struct DatabaseError;
/// #
/// # impl std::fmt::Display for DatabaseError {
/// #     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
/// #         write!(f, "database error")
/// #     }
/// # }
/// #
/// # fn extract_transaction(checkpoint: &CheckpointData) -> CheckpointTransaction {
/// #     checkpoint.transactions.first().unwrap().clone()
/// # }
///
/// struct DirectProcessor {
///     // generic Database client.
///     client: Arc<DatabaseClient>,
/// }
///
/// #[async_trait]
/// impl Worker for DirectProcessor {
///     type Message = ();
///     type Error = DatabaseError;
///
///     async fn process_checkpoint(
///         &self,
///         checkpoint: Arc<CheckpointData>,
///     ) -> Result<Self::Message, Self::Error> {
///         // extract a particular transaction we care about.
///         let tx: CheckpointTransaction = extract_transaction(checkpoint.as_ref());
///         // store the transaction in our database of choice.
///         self.client.store_transaction(&tx).await?;
///         Ok(())
///     }
/// }
///
/// // configure worker pool for direct processing.
/// let processor = DirectProcessor {
///     client: Arc::new(DatabaseClient::new()),
/// };
/// let pool = WorkerPool::new(processor, "direct_processing".into(), 5, Default::default());
/// ```
///
/// ## Batch Processing (With Reducer)
/// ```rust,no_run
/// use std::sync::Arc;
///
/// use async_trait::async_trait;
/// use iota_data_ingestion_core::{Reducer, Worker, WorkerPool};
/// use iota_types::full_checkpoint_content::{CheckpointData, CheckpointTransaction};
/// # struct DatabaseClient;
/// #
/// # impl DatabaseClient {
/// #     pub fn new() -> Self {
/// #         Self
/// #     }
/// #
/// #     pub async fn store_transactions_batch(&self,
/// #         _transactions: &Vec<CheckpointTransaction>,
/// #     ) -> Result<(), DatabaseError> {
/// #         Ok(())
/// #     }
/// # }
/// #
/// # #[derive(Debug, Clone)]
/// # struct DatabaseError;
/// #
/// # impl std::fmt::Display for DatabaseError {
/// #     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
/// #         write!(f, "database error")
/// #     }
/// # }
///
/// // worker that accumulates transactions for batch processing.
/// struct BatchProcessor;
///
/// #[async_trait]
/// impl Worker for BatchProcessor {
///     type Message = Vec<CheckpointTransaction>;
///     type Error = DatabaseError;
///
///     async fn process_checkpoint(
///         &self,
///         checkpoint: Arc<CheckpointData>,
///     ) -> Result<Self::Message, Self::Error> {
///         // collect all checkpoint transactions for batch processing.
///         Ok(checkpoint.transactions.clone())
///     }
/// }
///
/// // batch reducer for efficient storage.
/// struct TransactionBatchReducer {
///     batch_size: usize,
///     // generic Database client.
///     client: Arc<DatabaseClient>,
/// }
///
/// #[async_trait]
/// impl Reducer<BatchProcessor> for TransactionBatchReducer {
///     async fn commit(&self, batch: &[Vec<CheckpointTransaction>]) -> Result<(), DatabaseError> {
///         let flattened: Vec<CheckpointTransaction> = batch.iter().flatten().cloned().collect();
///         // store the transaction batch in the database of choice.
///         self.client.store_transactions_batch(&flattened).await?;
///         Ok(())
///     }
///
///     fn should_close_batch(
///         &self,
///         batch: &[Vec<CheckpointTransaction>],
///         _: Option<&Vec<CheckpointTransaction>>,
///     ) -> bool {
///         batch.iter().map(|b| b.len()).sum::<usize>() >= self.batch_size
///     }
/// }
///
/// // configure worker pool with batch processing.
/// let processor = BatchProcessor;
/// let reducer = TransactionBatchReducer {
///     batch_size: 1000,
///     client: Arc::new(DatabaseClient::new()),
/// };
/// let pool = WorkerPool::new_with_reducer(
///     processor,
///     "batch_processing".into(),
///     5,
///     Default::default(),
///     reducer,
/// );
/// ```
pub struct WorkerPool<W: Worker> {
    /// A unique name of the WorkerPool task.
    pub task_name: String,
    /// How many instances of the current [`Worker`] to create; the more
    /// workers are created, the more checkpoints can be processed in parallel.
    concurrency: usize,
    /// The actual [`Worker`] instance itself.
    worker: Arc<W>,
    /// The reducer instance, responsible for batch processing.
    reducer: Option<Box<dyn Reducer<W>>>,
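    /// The backoff configuration shared by the workers (and the reducer, if
    /// any) for retrying transient failures with increasing delays.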
    backoff: Arc<ExponentialBackoff>,
}

impl<W: Worker + 'static> WorkerPool<W> {
    /// Creates a new `WorkerPool` without a reducer.
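    ///
    /// A minimal sketch of constructing a pool with a custom backoff; the
    /// `MyWorker` type here is a placeholder for any [`Worker`] implementation:
    ///
    /// ```rust,ignore
    /// use backoff::ExponentialBackoff;
    ///
    /// let backoff = ExponentialBackoff {
    ///     max_elapsed_time: Some(std::time::Duration::from_secs(300)),
    ///     ..Default::default()
    /// };
    /// let pool = WorkerPool::new(MyWorker::new(), "my_task".into(), 8, backoff);
    /// ```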
    pub fn new(
        worker: W,
        task_name: String,
        concurrency: usize,
        backoff: ExponentialBackoff,
    ) -> Self {
        Self {
            task_name,
            concurrency,
            worker: Arc::new(worker),
            reducer: None,
            backoff: Arc::new(backoff),
        }
    }

    /// Creates a new `WorkerPool` with a reducer.
    pub fn new_with_reducer<R>(
        worker: W,
        task_name: String,
        concurrency: usize,
        backoff: ExponentialBackoff,
        reducer: R,
    ) -> Self
    where
        R: Reducer<W> + 'static,
    {
        Self {
            task_name,
            concurrency,
            worker: Arc::new(worker),
            reducer: Some(Box::new(reducer)),
            backoff: Arc::new(backoff),
        }
    }

    /// Runs the worker pool's main loop: receives checkpoints, dispatches them
    /// to idle workers, reports progress via `pool_status_sender`, and shuts
    /// down gracefully once the cancellation token is triggered.
    pub async fn run(
        mut self,
        watermark: CheckpointSequenceNumber,
        mut checkpoint_receiver: mpsc::Receiver<Arc<CheckpointData>>,
        pool_status_sender: mpsc::Sender<WorkerPoolStatus>,
        token: CancellationToken,
    ) {
        info!(
            "Starting indexing pipeline {} with concurrency {}. Current watermark is {watermark}.",
            self.task_name, self.concurrency
        );
        // This channel is used to send progress data from the workers to the
        // WorkerPool main loop.
        let (progress_sender, mut progress_receiver) = mpsc::channel(MAX_CHECKPOINTS_IN_PROGRESS);
        // This channel is used to forward the workers' progress data from the
        // WorkerPool to the watermark tracking task.
        let (watermark_sender, watermark_receiver) = mpsc::channel(MAX_CHECKPOINTS_IN_PROGRESS);
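        // Dispatch bookkeeping:
        // - `idle`: IDs of workers currently free to accept a checkpoint.
        // - `checkpoints`: checkpoints buffered while all workers are busy.
        // - `workers_shutdown_signals`: IDs of workers that reported shutdown.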
        let mut idle: BTreeSet<_> = (0..self.concurrency).collect();
        let mut checkpoints = VecDeque::new();
        let mut workers_shutdown_signals = vec![];
        let (workers, workers_join_handles) = self.spawn_workers(progress_sender, token.clone());
        // Spawn a task that tracks checkpoint processing progress. The task:
        // - Receives (checkpoint_number, message) pairs from workers.
        // - Maintains checkpoint sequence order.
        // - Reports progress either:
        //   * After processing each chunk (simple tracking).
        //   * After committing batches (with reducer).
        let watermark_handle = self.spawn_watermark_tracking(
            watermark,
            watermark_receiver,
            pool_status_sender.clone(),
            token.clone(),
        );
        // main worker pool loop.
        loop {
            tokio::select! {
                Some(worker_progress_msg) = progress_receiver.recv() => {
                    match worker_progress_msg {
                        WorkerStatus::Running((worker_id, checkpoint_number, message)) => {
                            idle.insert(worker_id);
                            // Try to send progress to the watermark tracking task. If it fails
                            // (the task has exited), we just continue - we still need to wait
                            // for all workers to shut down.
                            let _ = watermark_sender.send((checkpoint_number, message)).await;

                            // By checking that the token has not been cancelled we ensure that
                            // no further checkpoints will be sent to the workers.
                            while !token.is_cancelled() && !checkpoints.is_empty() && !idle.is_empty() {
                                let checkpoint = checkpoints.pop_front().unwrap();
                                let worker_id = idle.pop_first().unwrap();
                                if workers[worker_id].send(checkpoint).await.is_err() {
                                    // The worker channel closing is a sign we need to exit this inner loop.
                                    break;
                                }
                            }
                        }
                        WorkerStatus::Shutdown(worker_id) => {
                            // Track workers that have initiated shutdown.
                            workers_shutdown_signals.push(worker_id);
                        }
                    }
                }
                // Adding an if guard to this branch ensures that no checkpoints
                // will be sent to workers once the token has been cancelled.
                Some(checkpoint) = checkpoint_receiver.recv(), if !token.is_cancelled() => {
                    let sequence_number = checkpoint.checkpoint_summary.sequence_number;
                    if sequence_number < watermark {
                        continue;
                    }
                    self.worker
                        .preprocess_hook(checkpoint.clone())
                        .map_err(|err| IngestionError::CheckpointHookProcessing(err.to_string()))
                        .expect("failed to preprocess task");
                    if idle.is_empty() {
                        checkpoints.push_back(checkpoint);
                    } else {
                        let worker_id = idle.pop_first().unwrap();
                        // If the worker channel is closed, put the checkpoint back in the queue
                        // and continue - we still need to wait for all worker shutdown signals.
                        if let Err(send_error) = workers[worker_id].send(checkpoint).await {
                            checkpoints.push_front(send_error.0);
                        };
                    }
                }
            }
            // Once all workers have signaled completion, start the graceful shutdown
            // process.
            if workers_shutdown_signals.len() == self.concurrency {
                break self
                    .workers_graceful_shutdown(
                        workers_join_handles,
                        watermark_handle,
                        pool_status_sender,
                        watermark_sender,
                    )
                    .await;
            }
        }
    }

    /// Spawns `self.concurrency` [`Worker`] instances to process checkpoints
    /// in parallel.
    fn spawn_workers(
        &self,
        progress_sender: mpsc::Sender<WorkerStatus<W::Message>>,
        token: CancellationToken,
    ) -> (Vec<mpsc::Sender<Arc<CheckpointData>>>, Vec<JoinHandle<()>>) {
        let mut worker_senders = Vec::with_capacity(self.concurrency);
        let mut workers_join_handles = Vec::with_capacity(self.concurrency);

        for worker_id in 0..self.concurrency {
            let (worker_sender, mut worker_recv) =
                mpsc::channel::<Arc<CheckpointData>>(MAX_CHECKPOINTS_IN_PROGRESS);
            let cloned_progress_sender = progress_sender.clone();
            let task_name = self.task_name.clone();
            worker_senders.push(worker_sender);

            let token = token.clone();

            let worker = self.worker.clone();
            let backoff = self.backoff.clone();
            let join_handle = spawn_monitored_task!(async move {
                loop {
                    tokio::select! {
                        // Once the token is cancelled, notify the main loop of the worker's shutdown.
                        _ = token.cancelled() => {
                            _ = cloned_progress_sender.send(WorkerStatus::Shutdown(worker_id)).await;
                            break
                        },
                        Some(checkpoint) = worker_recv.recv() => {
                            let sequence_number = checkpoint.checkpoint_summary.sequence_number;
                            info!("received checkpoint for processing {} for workflow {}", sequence_number, task_name);
                            let start_time = Instant::now();
                            let status = Self::process_checkpoint_with_retry(worker_id, &worker, checkpoint, reset_backoff(&backoff), &token).await;
                            let trigger_shutdown = matches!(status, WorkerStatus::Shutdown(_));
                            if cloned_progress_sender.send(status).await.is_err() || trigger_shutdown {
                                break;
                            }
                            info!(
                                "finished checkpoint processing {sequence_number} for workflow {task_name} in {:?}",
                                start_time.elapsed()
                            );
                        }
                    }
                }
            });
            // Keep all join handles to ensure all workers are terminated before exiting
            workers_join_handles.push(join_handle);
        }
        (worker_senders, workers_join_handles)
    }

    /// Attempts to process a checkpoint with exponential backoff retries on
    /// failure.
    ///
    /// This function repeatedly calls the
    /// [`process_checkpoint`](Worker::process_checkpoint) method of the
    /// provided [`Worker`] until either:
    /// - The checkpoint processing succeeds, returning `WorkerStatus::Running`
    ///   with the processed message.
    /// - A cancellation signal is received via the [`CancellationToken`],
    ///   returning `WorkerStatus::Shutdown(<worker-id>)`.
    /// - All retry attempts are exhausted within the backoff's maximum elapsed
    ///   time, causing a panic.
    ///
    /// # Retry Mechanism:
    /// - Uses [`ExponentialBackoff`](backoff::ExponentialBackoff) to introduce
    ///   increasing delays between retry attempts.
    /// - Checks for cancellation both before and after each processing attempt.
    /// - If a cancellation signal is received during a backoff delay, the
    ///   function exits immediately with `WorkerStatus::Shutdown(<worker-id>)`.
    ///
    /// # Panics:
    /// - If all retry attempts are exhausted within the backoff's maximum
    ///   elapsed time, indicating a persistent failure.
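    ///
    /// Note: the worker loop passes a freshly reset backoff (via `reset_backoff`)
    /// for every checkpoint, so the retry budget applies per checkpoint rather
    /// than across the worker's lifetime.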
    async fn process_checkpoint_with_retry(
        worker_id: WorkerID,
        worker: &W,
        checkpoint: Arc<CheckpointData>,
        mut backoff: ExponentialBackoff,
        token: &CancellationToken,
    ) -> WorkerStatus<W::Message> {
        let sequence_number = checkpoint.checkpoint_summary.sequence_number;
        loop {
            // check for cancellation before attempting processing.
            if token.is_cancelled() {
                return WorkerStatus::Shutdown(worker_id);
            }
            // attempt to process checkpoint.
            match worker.process_checkpoint(checkpoint.clone()).await {
                Ok(message) => return WorkerStatus::Running((worker_id, sequence_number, message)),
                Err(err) => {
                    let err = IngestionError::CheckpointProcessing(err.to_string());
                    warn!(
                        "transient worker execution error {err:?} for checkpoint {sequence_number}"
                    );
                    // check for cancellation after failed processing.
                    if token.is_cancelled() {
                        return WorkerStatus::Shutdown(worker_id);
                    }
                }
            }
            // get the next backoff duration or panic if the maximum elapsed time is exceeded.
            let duration = backoff.next_backoff().unwrap_or_else(|| {
                panic!(
                    "max elapsed time exceeded: checkpoint processing failed for checkpoint {sequence_number}"
                )
            });
            // if cancellation occurs during backoff wait, exit early with Shutdown.
            // Otherwise (if timeout expires), continue with the next retry attempt.
            if tokio::time::timeout(duration, token.cancelled())
                .await
                .is_ok()
            {
                return WorkerStatus::Shutdown(worker_id);
            }
        }
    }

    /// Spawns a task that tracks the progress of checkpoint processing,
    /// optionally with message reduction.
    ///
    /// This function spawns one of two types of tracking tasks:
    ///
    /// 1. Simple Watermark Tracking (when reducer = None):
    ///    - Reports the watermark after processing each chunk.
    ///
    /// 2. Batch Processing (when reducer = Some):
    ///    - Reports progress only after successful batch commits.
    ///    - A batch is committed based on the
    ///      [`should_close_batch`](Reducer::should_close_batch) policy.
    fn spawn_watermark_tracking(
        &mut self,
        watermark: CheckpointSequenceNumber,
        watermark_receiver: mpsc::Receiver<(CheckpointSequenceNumber, W::Message)>,
        executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
        token: CancellationToken,
    ) -> JoinHandle<Result<(), IngestionError>> {
        let task_name = self.task_name.clone();
        let backoff = self.backoff.clone();
        if let Some(reducer) = self.reducer.take() {
            return spawn_monitored_task!(reduce::<W>(
                task_name,
                watermark,
                watermark_receiver,
                executor_progress_sender,
                reducer,
                backoff,
                token
            ));
        };
        spawn_monitored_task!(simple_watermark_tracking::<W>(
            task_name,
            watermark,
            watermark_receiver,
            executor_progress_sender
        ))
    }

    /// Starts the workers' graceful shutdown.
    ///
    /// - Awaits all worker handles.
    /// - Awaits the watermark tracking (or reducer) handle.
    /// - Sends a `WorkerPoolStatus::Shutdown(<task-name>)` message notifying
    ///   external components that the worker pool has been shut down.
    async fn workers_graceful_shutdown(
        &self,
        workers_join_handles: Vec<JoinHandle<()>>,
        watermark_handle: JoinHandle<Result<(), IngestionError>>,
        executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
        watermark_sender: mpsc::Sender<(u64, <W as Worker>::Message)>,
    ) {
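        // Await all worker tasks; a panicked worker is logged but does not
        // abort the shutdown.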
        for worker in workers_join_handles {
            _ = worker
                .await
                .inspect_err(|err| tracing::error!("worker task panicked: {err}"));
        }
        // by dropping the sender we make sure that the stream will be closed and the
        // watermark tracker task will exit its loop.
        drop(watermark_sender);
        _ = watermark_handle
            .await
            .inspect_err(|err| tracing::error!("watermark task panicked: {err}"));
        _ = executor_progress_sender
            .send(WorkerPoolStatus::Shutdown(self.task_name.clone()))
            .await;
        tracing::info!("Worker pool `{}` terminated gracefully", self.task_name);
    }
}

/// Tracks checkpoint progress without reduction logic.
///
/// This function maintains a watermark of processed checkpoints by:
/// 1. Receiving batches of progress status from workers.
/// 2. Processing them in sequence order.
/// 3. Reporting progress to the executor after each chunk from the stream.
async fn simple_watermark_tracking<W: Worker>(
    task_name: String,
    mut current_checkpoint_number: CheckpointSequenceNumber,
    watermark_receiver: mpsc::Receiver<(CheckpointSequenceNumber, W::Message)>,
    executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
) -> IngestionResult<()> {
    // convert the receiver into a stream of chunks of up to
    // MAX_CHECKPOINTS_IN_PROGRESS messages. This way, each iteration of the loop
    // will process all ready messages.
    let mut stream =
        ReceiverStream::new(watermark_receiver).ready_chunks(MAX_CHECKPOINTS_IN_PROGRESS);
    // store unprocessed progress messages from workers.
    let mut unprocessed = HashMap::new();
    // track the next unprocessed checkpoint number for reporting progress
    // after each chunk of messages is received from the stream.
    let mut progress_update = None;

    while let Some(update_batch) = stream.next().await {
        unprocessed.extend(update_batch.into_iter());
        // Process messages sequentially based on checkpoint sequence number.
        // This ensures in-order processing and maintains progress integrity.
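        // For example, with current_checkpoint_number = 5 and processed messages
        // for checkpoints {5, 6, 8}, the loop below advances the watermark to 7
        // and reports it; checkpoint 8 stays buffered until 7 is processed.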
        while unprocessed.remove(&current_checkpoint_number).is_some() {
            current_checkpoint_number += 1;
            progress_update = Some(current_checkpoint_number);
        }
        // report progress update to executor.
        if let Some(watermark) = progress_update.take() {
            executor_progress_sender
                .send(WorkerPoolStatus::Running((task_name.clone(), watermark)))
                .await
                .map_err(|_| IngestionError::Channel("unable to send worker pool progress updates to executor, receiver half closed".into()))?;
        }
    }
    Ok(())
}