iota_data_ingestion_core/
worker_pool.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::{BTreeSet, HashMap, VecDeque},
7    fmt::Debug,
8    sync::Arc,
9    time::Instant,
10};
11
12use backoff::{ExponentialBackoff, backoff::Backoff};
13use futures::StreamExt;
14use iota_metrics::spawn_monitored_task;
15use iota_types::{
16    full_checkpoint_content::CheckpointData, messages_checkpoint::CheckpointSequenceNumber,
17};
18use tokio::{sync::mpsc, task::JoinHandle};
19use tokio_stream::wrappers::ReceiverStream;
20use tokio_util::sync::CancellationToken;
21use tracing::{info, warn};
22
23use crate::{
24    IngestionError, IngestionResult, Reducer, Worker, executor::MAX_CHECKPOINTS_IN_PROGRESS,
25    reducer::reduce, util::reset_backoff,
26};
27
/// Human-readable identifier of a [`WorkerPool`] task, used in progress and
/// shutdown messages.
type TaskName = String;
/// Index of a worker within a [`WorkerPool`] (`0..concurrency`).
type WorkerID = usize;
30
31/// Represents the possible message types a [`WorkerPool`] can communicate with
32/// external components.
33#[derive(Debug, Clone)]
34pub enum WorkerPoolStatus {
35    /// Message with information (e.g. `(<task-name>,
36    /// checkpoint_sequence_number)`) about the ingestion progress.
37    Running((TaskName, CheckpointSequenceNumber)),
38    /// Message with information (e.g. `<task-name>`) about shutdown status.
39    Shutdown(String),
40}
41
/// Represents the possible message types a [`Worker`] can communicate with
/// external components.
///
/// Note: the derived `Clone`/`Copy` impls are only available when the message
/// type `M` is itself `Clone`/`Copy` (derive adds those bounds).
#[derive(Debug, Clone, Copy)]
enum WorkerStatus<M> {
    /// Message with information (e.g. (`<worker-id>`,
    /// `checkpoint_sequence_number`, [`Worker::Message`])) about the ingestion
    /// progress.
    Running((WorkerID, CheckpointSequenceNumber, M)),
    /// Message with information (e.g. `<worker-id>`) about shutdown status.
    Shutdown(WorkerID),
}
53
54/// A pool of [`Worker`]'s that process checkpoints concurrently.
55///
56/// This struct manages a collection of workers that process checkpoints in
57/// parallel. It handles checkpoint distribution, progress tracking, and
58/// graceful shutdown. It can optionally use a [`Reducer`] to aggregate and
59/// process worker [`Messages`](Worker::Message).
60///
61/// # Examples
62/// ## Direct Processing (Without Batching)
63/// ```rust,no_run
64/// use std::sync::Arc;
65///
66/// use async_trait::async_trait;
67/// use iota_data_ingestion_core::{Worker, WorkerPool};
68/// use iota_types::full_checkpoint_content::{CheckpointData, CheckpointTransaction};
69/// #
70/// # struct DatabaseClient;
71/// #
72/// # impl DatabaseClient {
73/// #     pub fn new() -> Self {
74/// #         Self
75/// #     }
76/// #
77/// #     pub async fn store_transaction(&self,
78/// #         _transactions: &CheckpointTransaction,
79/// #     ) -> Result<(), DatabaseError> {
80/// #         Ok(())
81/// #     }
82/// # }
83/// #
84/// # #[derive(Debug, Clone)]
85/// # struct DatabaseError;
86/// #
87/// # impl std::fmt::Display for DatabaseError {
88/// #     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
89/// #         write!(f, "database error")
90/// #     }
91/// # }
92/// #
93/// # fn extract_transaction(checkpoint: &CheckpointData) -> CheckpointTransaction {
94/// #     checkpoint.transactions.first().unwrap().clone()
95/// # }
96///
97/// struct DirectProcessor {
98///     // generic Database client.
99///     client: Arc<DatabaseClient>,
100/// }
101///
102/// #[async_trait]
103/// impl Worker for DirectProcessor {
104///     type Message = ();
105///     type Error = DatabaseError;
106///
107///     async fn process_checkpoint(
108///         &self,
109///         checkpoint: Arc<CheckpointData>,
110///     ) -> Result<Self::Message, Self::Error> {
///         // extract a particular transaction we care about.
112///         let tx: CheckpointTransaction = extract_transaction(checkpoint.as_ref());
113///         // store the transaction in our database of choice.
114///         self.client.store_transaction(&tx).await?;
115///         Ok(())
116///     }
117/// }
118///
119/// // configure worker pool for direct processing.
120/// let processor = DirectProcessor {
121///     client: Arc::new(DatabaseClient::new()),
122/// };
123/// let pool = WorkerPool::new(processor, "direct_processing".into(), 5, Default::default());
124/// ```
125///
126/// ## Batch Processing (With Reducer)
127/// ```rust,no_run
128/// use std::sync::Arc;
129///
130/// use async_trait::async_trait;
131/// use iota_data_ingestion_core::{Reducer, Worker, WorkerPool};
132/// use iota_types::full_checkpoint_content::{CheckpointData, CheckpointTransaction};
133/// # struct DatabaseClient;
134/// #
135/// # impl DatabaseClient {
136/// #     pub fn new() -> Self {
137/// #         Self
138/// #     }
139/// #
140/// #     pub async fn store_transactions_batch(&self,
141/// #         _transactions: &Vec<CheckpointTransaction>,
142/// #     ) -> Result<(), DatabaseError> {
143/// #         Ok(())
144/// #     }
145/// # }
146/// #
147/// # #[derive(Debug, Clone)]
148/// # struct DatabaseError;
149/// #
150/// # impl std::fmt::Display for DatabaseError {
151/// #     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
152/// #         write!(f, "database error")
153/// #     }
154/// # }
155///
156/// // worker that accumulates transactions for batch processing.
157/// struct BatchProcessor;
158///
159/// #[async_trait]
160/// impl Worker for BatchProcessor {
161///     type Message = Vec<CheckpointTransaction>;
162///     type Error = DatabaseError;
163///
164///     async fn process_checkpoint(
165///         &self,
166///         checkpoint: Arc<CheckpointData>,
167///     ) -> Result<Self::Message, Self::Error> {
168///         // collect all checkpoint transactions for batch processing.
169///         Ok(checkpoint.transactions.clone())
170///     }
171/// }
172///
173/// // batch reducer for efficient storage.
174/// struct TransactionBatchReducer {
175///     batch_size: usize,
176///     // generic Database client.
177///     client: Arc<DatabaseClient>,
178/// }
179///
180/// #[async_trait]
181/// impl Reducer<BatchProcessor> for TransactionBatchReducer {
182///     async fn commit(&self, batch: &[Vec<CheckpointTransaction>]) -> Result<(), DatabaseError> {
183///         let flattened: Vec<CheckpointTransaction> = batch.iter().flatten().cloned().collect();
184///         // store the transaction batch in the database of choice.
185///         self.client.store_transactions_batch(&flattened).await?;
186///         Ok(())
187///     }
188///
189///     fn should_close_batch(
190///         &self,
191///         batch: &[Vec<CheckpointTransaction>],
192///         _: Option<&Vec<CheckpointTransaction>>,
193///     ) -> bool {
194///         batch.iter().map(|b| b.len()).sum::<usize>() >= self.batch_size
195///     }
196/// }
197///
198/// // configure worker pool with batch processing.
199/// let processor = BatchProcessor;
200/// let reducer = TransactionBatchReducer {
201///     batch_size: 1000,
202///     client: Arc::new(DatabaseClient::new()),
203/// };
204/// let pool = WorkerPool::new_with_reducer(
205///     processor,
206///     "batch_processing".into(),
207///     5,
208///     Default::default(),
209///     reducer,
210/// );
211/// ```
pub struct WorkerPool<W: Worker> {
    /// A unique name of the WorkerPool task.
    pub task_name: String,
    /// How many instances of the current [`Worker`] to create; the more
    /// workers are created, the more checkpoints they can process in parallel.
    concurrency: usize,
    /// The actual [`Worker`] instance itself.
    worker: Arc<W>,
    /// The reducer instance, responsible for batch processing.
    reducer: Option<Box<dyn Reducer<W>>>,
    /// Backoff policy cloned into each spawned worker and used to retry failed
    /// checkpoint processing attempts.
    backoff: Arc<ExponentialBackoff>,
}
224
225impl<W: Worker + 'static> WorkerPool<W> {
226    /// Creates a new `WorkerPool` without a reducer.
227    pub fn new(
228        worker: W,
229        task_name: String,
230        concurrency: usize,
231        backoff: ExponentialBackoff,
232    ) -> Self {
233        Self {
234            task_name,
235            concurrency,
236            worker: Arc::new(worker),
237            reducer: None,
238            backoff: Arc::new(backoff),
239        }
240    }
241
242    /// Creates a new `WorkerPool` with a reducer.
243    pub fn new_with_reducer<R>(
244        worker: W,
245        task_name: String,
246        concurrency: usize,
247        backoff: ExponentialBackoff,
248        reducer: R,
249    ) -> Self
250    where
251        R: Reducer<W> + 'static,
252    {
253        Self {
254            task_name,
255            concurrency,
256            worker: Arc::new(worker),
257            reducer: Some(Box::new(reducer)),
258            backoff: Arc::new(backoff),
259        }
260    }
261
    /// Runs the worker pool main logic.
    ///
    /// Distributes incoming checkpoints to idle workers, forwards worker
    /// progress to the watermark-tracking task, and reports pool status via
    /// `pool_status_sender`. Runs until all workers have signaled shutdown
    /// (triggered by `token` cancellation).
    ///
    /// - `watermark`: checkpoints with a lower sequence number are skipped.
    /// - `checkpoint_receiver`: source of checkpoints to process.
    /// - `pool_status_sender`: emits [`WorkerPoolStatus`] progress/shutdown
    ///   messages to the executor.
    /// - `token`: cancellation signal that initiates graceful shutdown.
    pub async fn run(
        mut self,
        watermark: CheckpointSequenceNumber,
        mut checkpoint_receiver: mpsc::Receiver<Arc<CheckpointData>>,
        pool_status_sender: mpsc::Sender<WorkerPoolStatus>,
        token: CancellationToken,
    ) {
        info!(
            "Starting indexing pipeline {} with concurrency {}. Current watermark is {watermark}.",
            self.task_name, self.concurrency
        );
        // This channel will be used to send progress data from Workers to WorkerPool
        // main loop.
        let (progress_sender, mut progress_receiver) = mpsc::channel(MAX_CHECKPOINTS_IN_PROGRESS);
        // This channel will be used to send Workers progress data from WorkerPool to
        // watermark tracking task.
        let (watermark_sender, watermark_receiver) = mpsc::channel(MAX_CHECKPOINTS_IN_PROGRESS);
        // IDs of workers currently not processing a checkpoint.
        let mut idle: BTreeSet<_> = (0..self.concurrency).collect();
        // Checkpoints buffered while no worker is idle.
        let mut checkpoints = VecDeque::new();
        let mut workers_shutdown_signals = vec![];
        let (workers, workers_join_handles) = self.spawn_workers(progress_sender, token.clone());
        // Spawn a task that tracks checkpoint processing progress. The task:
        // - Receives (checkpoint_number, message) pairs from workers.
        // - Maintains checkpoint sequence order.
        // - Reports progress either:
        //   * After processing each chunk (simple tracking).
        //   * After committing batches (with reducer).
        let watermark_handle = self.spawn_watermark_tracking(
            watermark,
            watermark_receiver,
            pool_status_sender.clone(),
            token.clone(),
        );
        // main worker pool loop.
        loop {
            tokio::select! {
                Some(worker_progress_msg) = progress_receiver.recv() => {
                    match worker_progress_msg {
                        WorkerStatus::Running((worker_id, checkpoint_number, message)) => {
                            // The worker is free again to take new work.
                            idle.insert(worker_id);
                            // Try to send progress to reducer. If it fails (reducer has exited),
                            // we just continue - we still need to wait for all workers to shutdown.
                            let _ = watermark_sender.send((checkpoint_number, message)).await;

                            // By checking if token was not cancelled we ensure that no
                            // further checkpoints will be sent to the workers.
                            while !token.is_cancelled() && !checkpoints.is_empty() && !idle.is_empty() {
                                let checkpoint = checkpoints.pop_front().unwrap();
                                let worker_id = idle.pop_first().unwrap();
                                if workers[worker_id].send(checkpoint).await.is_err() {
                                    // The worker channel closing is a sign we need to exit this inner loop.
                                    break;
                                }
                            }
                        }
                        WorkerStatus::Shutdown(worker_id) => {
                            // Track workers that have initiated shutdown.
                            workers_shutdown_signals.push(worker_id);
                        }
                    }
                }
                // Adding an if guard to this branch ensure that no checkpoints
                // will be sent to workers once the token has been cancelled.
                Some(checkpoint) = checkpoint_receiver.recv(), if !token.is_cancelled() => {
                    let sequence_number = checkpoint.checkpoint_summary.sequence_number;
                    // Skip checkpoints already below the starting watermark.
                    if sequence_number < watermark {
                        continue;
                    }
                    // A failing preprocess hook is treated as fatal and panics
                    // this pool task.
                    self.worker
                        .preprocess_hook(checkpoint.clone())
                        .map_err(|err| IngestionError::CheckpointHookProcessing(err.to_string()))
                        .expect("failed to preprocess task");
                    if idle.is_empty() {
                        checkpoints.push_back(checkpoint);
                    } else {
                        let worker_id = idle.pop_first().unwrap();
                        // If worker channel is closed, put the checkpoint back in queue
                        // and continue - we still need to wait for all worker shutdown signals.
                        if let Err(send_error) = workers[worker_id].send(checkpoint).await {
                            checkpoints.push_front(send_error.0);
                        };
                    }
                }
            }
            // Once all workers have signaled completion, start the graceful shutdown
            // process.
            if workers_shutdown_signals.len() == self.concurrency {
                break self
                    .workers_graceful_shutdown(
                        workers_join_handles,
                        watermark_handle,
                        pool_status_sender,
                        watermark_sender,
                    )
                    .await;
            }
        }
    }
361
    /// Spawn workers based on `self.concurrency` to process checkpoints
    /// in parallel.
    ///
    /// Returns one checkpoint-input channel sender per worker (indexed by
    /// worker ID) together with the workers' join handles, which the caller
    /// must await during shutdown.
    fn spawn_workers(
        &self,
        progress_sender: mpsc::Sender<WorkerStatus<W::Message>>,
        token: CancellationToken,
    ) -> (Vec<mpsc::Sender<Arc<CheckpointData>>>, Vec<JoinHandle<()>>) {
        let mut worker_senders = Vec::with_capacity(self.concurrency);
        let mut workers_join_handles = Vec::with_capacity(self.concurrency);

        for worker_id in 0..self.concurrency {
            // Per-worker channel through which the pool feeds checkpoints.
            let (worker_sender, mut worker_recv) =
                mpsc::channel::<Arc<CheckpointData>>(MAX_CHECKPOINTS_IN_PROGRESS);
            let cloned_progress_sender = progress_sender.clone();
            let task_name = self.task_name.clone();
            worker_senders.push(worker_sender);

            let token = token.clone();

            let worker = self.worker.clone();
            let backoff = self.backoff.clone();
            let join_handle = spawn_monitored_task!(async move {
                loop {
                    tokio::select! {
                        // Once token is cancelled, notify worker's shutdown to the main loop
                        _ = token.cancelled() => {
                            _ = cloned_progress_sender.send(WorkerStatus::Shutdown(worker_id)).await;
                            break
                        },
                        Some(checkpoint) = worker_recv.recv() => {
                            let sequence_number = checkpoint.checkpoint_summary.sequence_number;
                            info!("received checkpoint for processing {} for workflow {}", sequence_number, task_name);
                            let start_time = Instant::now();
                            // Retries with a fresh backoff per checkpoint; may resolve to
                            // Shutdown if the token is cancelled mid-retry.
                            let status = Self::process_checkpoint_with_retry(worker_id, &worker, checkpoint, reset_backoff(&backoff), &token).await;
                            let trigger_shutdown = matches!(status, WorkerStatus::Shutdown(_));
                            // Exit if the pool stopped listening or a shutdown was produced.
                            if cloned_progress_sender.send(status).await.is_err() || trigger_shutdown {
                                break;
                            }
                            info!(
                                "finished checkpoint processing {sequence_number} for workflow {task_name} in {:?}",
                                start_time.elapsed()
                            );
                        }
                    }
                }
            });
            // Keep all join handles to ensure all workers are terminated before exiting
            workers_join_handles.push(join_handle);
        }
        (worker_senders, workers_join_handles)
    }
413
414    /// Attempts to process a checkpoint with exponential backoff retries on
415    /// failure.
416    ///
417    /// This function repeatedly calls the
418    /// [`process_checkpoint`](Worker::process_checkpoint) method of the
419    /// provided [`Worker`] until either:
420    /// - The checkpoint processing succeeds, returning `WorkerStatus::Running`
421    ///   with the processed message.
422    /// - A cancellation signal is received via the [`CancellationToken`],
423    ///   returning `WorkerStatus::Shutdown(<worker-id>)`.
424    /// - All retry attempts are exhausted within backoff's maximum elapsed
425    ///   time, causing a panic.
426    ///
427    /// # Retry Mechanism:
428    /// - Uses [`ExponentialBackoff`](backoff::ExponentialBackoff) to introduce
429    ///   increasing delays between retry attempts.
430    /// - Checks for cancellation both before and after each processing attempt.
431    /// - If a cancellation signal is received during a backoff delay, the
432    ///   function exits immediately with `WorkerStatus::Shutdown(<worker-id>)`.
433    ///
434    /// # Panics:
435    /// - If all retry attempts are exhausted within the backoff's maximum
436    ///   elapsed time, indicating a persistent failure.
437    async fn process_checkpoint_with_retry(
438        worker_id: WorkerID,
439        worker: &W,
440        checkpoint: Arc<CheckpointData>,
441        mut backoff: ExponentialBackoff,
442        token: &CancellationToken,
443    ) -> WorkerStatus<W::Message> {
444        let sequence_number = checkpoint.checkpoint_summary.sequence_number;
445        loop {
446            // check for cancellation before attempting processing.
447            if token.is_cancelled() {
448                return WorkerStatus::Shutdown(worker_id);
449            }
450            // attempt to process checkpoint.
451            match worker.process_checkpoint(checkpoint.clone()).await {
452                Ok(message) => return WorkerStatus::Running((worker_id, sequence_number, message)),
453                Err(err) => {
454                    let err = IngestionError::CheckpointProcessing(err.to_string());
455                    warn!(
456                        "transient worker execution error {err:?} for checkpoint {sequence_number}"
457                    );
458                    // check for cancellation after failed processing.
459                    if token.is_cancelled() {
460                        return WorkerStatus::Shutdown(worker_id);
461                    }
462                }
463            }
464            // get next backoff duration or panic if max retries exceeded.
465            let duration = backoff
466                .next_backoff()
467                .expect("max elapsed time exceeded: checkpoint processing failed for checkpoint {sequence_number}");
468            // if cancellation occurs during backoff wait, exit early with Shutdown.
469            // Otherwise (if timeout expires), continue with the next retry attempt.
470            if tokio::time::timeout(duration, token.cancelled())
471                .await
472                .is_ok()
473            {
474                return WorkerStatus::Shutdown(worker_id);
475            }
476        }
477    }
478
479    /// Spawns a task that tracks the progress of checkpoint processing,
480    /// optionally with message reduction.
481    ///
482    /// This function spawns one of two types of tracking tasks:
483    ///
484    /// 1. Simple Watermark Tracking (when reducer = None):
485    ///    - Reports watermark after processing each chunk.
486    ///
487    /// 2. Batch Processing (when reducer = Some):
488    ///    - Reports progress only after successful batch commits.
489    ///    - A batch is committed based on
490    ///      [`should_close_batch`](Reducer::should_close_batch) policy.
491    fn spawn_watermark_tracking(
492        &mut self,
493        watermark: CheckpointSequenceNumber,
494        watermark_receiver: mpsc::Receiver<(CheckpointSequenceNumber, W::Message)>,
495        executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
496        token: CancellationToken,
497    ) -> JoinHandle<Result<(), IngestionError>> {
498        let task_name = self.task_name.clone();
499        let backoff = self.backoff.clone();
500        if let Some(reducer) = self.reducer.take() {
501            return spawn_monitored_task!(reduce::<W>(
502                task_name,
503                watermark,
504                watermark_receiver,
505                executor_progress_sender,
506                reducer,
507                backoff,
508                token
509            ));
510        };
511        spawn_monitored_task!(simple_watermark_tracking::<W>(
512            task_name,
513            watermark,
514            watermark_receiver,
515            executor_progress_sender
516        ))
517    }
518
519    /// Start the workers graceful shutdown.
520    ///
521    /// - Awaits all worker handles.
522    /// - Awaits the reducer handle.
523    /// - Send `WorkerPoolStatus::Shutdown(<task-name>)` message notifying
524    ///   external components that Worker Pool has been shutdown.
525    async fn workers_graceful_shutdown(
526        &self,
527        workers_join_handles: Vec<JoinHandle<()>>,
528        watermark_handle: JoinHandle<Result<(), IngestionError>>,
529        executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
530        watermark_sender: mpsc::Sender<(u64, <W as Worker>::Message)>,
531    ) {
532        for worker in workers_join_handles {
533            _ = worker
534                .await
535                .inspect_err(|err| tracing::error!("worker task panicked: {err}"));
536        }
537        // by dropping the sender we make sure that the stream will be closed and the
538        // watermark tracker task will exit its loop.
539        drop(watermark_sender);
540        _ = watermark_handle
541            .await
542            .inspect_err(|err| tracing::error!("watermark task panicked: {err}"));
543        _ = executor_progress_sender
544            .send(WorkerPoolStatus::Shutdown(self.task_name.clone()))
545            .await;
546        tracing::info!("Worker pool `{}` terminated gracefully", self.task_name);
547    }
548}
549
550/// Tracks checkpoint progress without reduction logic.
551///
552/// This function maintains a watermark of processed checkpoints by worker:
553/// 1. Receiving batches of progress status from workers.
554/// 2. Processing them in sequence order.
555/// 3. Reporting progress to the executor after each chunk from the stream.
556async fn simple_watermark_tracking<W: Worker>(
557    task_name: String,
558    mut current_checkpoint_number: CheckpointSequenceNumber,
559    watermark_receiver: mpsc::Receiver<(CheckpointSequenceNumber, W::Message)>,
560    executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
561) -> IngestionResult<()> {
562    // convert to a stream of MAX_CHECKPOINTS_IN_PROGRESS size. This way, each
563    // iteration of the loop will process all ready messages.
564    let mut stream =
565        ReceiverStream::new(watermark_receiver).ready_chunks(MAX_CHECKPOINTS_IN_PROGRESS);
566    // store unprocessed progress messages from workers.
567    let mut unprocessed = HashMap::new();
568    // track the next unprocessed checkpoint number for reporting progress
569    // after each chunk of messages is received from the stream.
570    let mut progress_update = None;
571
572    while let Some(update_batch) = stream.next().await {
573        unprocessed.extend(update_batch.into_iter());
574        // Process messages sequentially based on checkpoint sequence number.
575        // This ensures in-order processing and maintains progress integrity.
576        while unprocessed.remove(&current_checkpoint_number).is_some() {
577            current_checkpoint_number += 1;
578            progress_update = Some(current_checkpoint_number);
579        }
580        // report progress update to executor.
581        if let Some(watermark) = progress_update.take() {
582            executor_progress_sender
583                .send(WorkerPoolStatus::Running((task_name.clone(), watermark)))
584                .await
585                .map_err(|_| IngestionError::Channel("unable to send worker pool progress updates to executor, receiver half closed".into()))?;
586        }
587    }
588    Ok(())
589}