iota_data_ingestion_core/worker_pool.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use std::{
    collections::{BTreeSet, HashMap, VecDeque},
    fmt::Debug,
    sync::Arc,
    time::Instant,
};

use backoff::{ExponentialBackoff, backoff::Backoff};
use futures::StreamExt;
use iota_metrics::spawn_monitored_task;
use iota_rest_api::CheckpointData;
use iota_types::messages_checkpoint::CheckpointSequenceNumber;
use tokio::{sync::mpsc, task::JoinHandle};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};

use crate::{
    IngestionError, IngestionResult, Reducer, Worker, executor::MAX_CHECKPOINTS_IN_PROGRESS,
    reducer::reduce, util::reset_backoff,
};

type TaskName = String;
type WorkerID = usize;

/// Represents the possible message types a [`WorkerPool`] can use to
/// communicate with external components.
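///
/// # Example
///
/// A minimal sketch (not taken from the crate's examples) of how an external
/// component might consume these status messages; the channel wiring is
/// assumed for illustration:
///
/// ```rust,ignore
/// async fn handle_statuses(mut rx: tokio::sync::mpsc::Receiver<WorkerPoolStatus>) {
///     while let Some(status) = rx.recv().await {
///         match status {
///             // A task reported a new watermark.
///             WorkerPoolStatus::Running((task_name, checkpoint)) => {
///                 println!("task {task_name} reached checkpoint {checkpoint}");
///             }
///             // A task finished its graceful shutdown.
///             WorkerPoolStatus::Shutdown(task_name) => {
///                 println!("task {task_name} shut down");
///                 break;
///             }
///         }
///     }
/// }
/// ```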
#[derive(Debug, Clone)]
pub enum WorkerPoolStatus {
    /// Message with information (e.g. `(<task-name>,
    /// checkpoint_sequence_number)`) about the ingestion progress.
    Running((TaskName, CheckpointSequenceNumber)),
    /// Message with information (e.g. `<task-name>`) about shutdown status.
    Shutdown(String),
}

/// Represents the possible message types a [`Worker`] can use to
/// communicate with external components.
#[derive(Debug, Clone, Copy)]
enum WorkerStatus<M> {
    /// Message with information (e.g. `(<worker-id>`,
    /// `checkpoint_sequence_number`, [`Worker::Message`]) about the ingestion
    /// progress.
    Running((WorkerID, CheckpointSequenceNumber, M)),
    /// Message with information (e.g. `<worker-id>`) about shutdown status.
    Shutdown(WorkerID),
}

/// A pool of [`Worker`]s that process checkpoints concurrently.
///
/// This struct manages a collection of workers that process checkpoints in
/// parallel. It handles checkpoint distribution, progress tracking, and
/// graceful shutdown. It can optionally use a [`Reducer`] to aggregate and
/// process worker [`Messages`](Worker::Message).
///
/// # Examples
/// ## Direct Processing (Without Batching)
/// ```rust,no_run
/// use std::sync::Arc;
///
/// use async_trait::async_trait;
/// use iota_data_ingestion_core::{Worker, WorkerPool};
/// use iota_types::full_checkpoint_content::{CheckpointData, CheckpointTransaction};
/// #
/// # struct DatabaseClient;
/// #
/// # impl DatabaseClient {
/// #     pub fn new() -> Self {
/// #         Self
/// #     }
/// #
/// #     pub async fn store_transaction(&self,
/// #         _transactions: &CheckpointTransaction,
/// #     ) -> Result<(), DatabaseError> {
/// #         Ok(())
/// #     }
/// # }
/// #
/// # #[derive(Debug, Clone)]
/// # struct DatabaseError;
/// #
/// # impl std::fmt::Display for DatabaseError {
/// #     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
/// #         write!(f, "database error")
/// #     }
/// # }
/// #
/// # fn extract_transaction(checkpoint: &CheckpointData) -> CheckpointTransaction {
/// #     checkpoint.transactions.first().unwrap().clone()
/// # }
///
/// struct DirectProcessor {
///     // generic Database client.
///     client: Arc<DatabaseClient>,
/// }
///
/// #[async_trait]
/// impl Worker for DirectProcessor {
///     type Message = ();
///     type Error = DatabaseError;
///
///     async fn process_checkpoint(
///         &self,
///         checkpoint: Arc<CheckpointData>,
///     ) -> Result<Self::Message, Self::Error> {
///         // extract a particular transaction we care about.
///         let tx: CheckpointTransaction = extract_transaction(checkpoint.as_ref());
///         // store the transaction in our database of choice.
///         self.client.store_transaction(&tx).await?;
///         Ok(())
///     }
/// }
///
/// // configure worker pool for direct processing.
/// let processor = DirectProcessor {
///     client: Arc::new(DatabaseClient::new()),
/// };
/// let pool = WorkerPool::new(processor, "direct_processing".into(), 5, Default::default());
/// ```
///
/// ## Batch Processing (With Reducer)
/// ```rust,no_run
/// use std::sync::Arc;
///
/// use async_trait::async_trait;
/// use iota_data_ingestion_core::{Reducer, Worker, WorkerPool};
/// use iota_types::full_checkpoint_content::{CheckpointData, CheckpointTransaction};
/// # struct DatabaseClient;
/// #
/// # impl DatabaseClient {
/// #     pub fn new() -> Self {
/// #         Self
/// #     }
/// #
/// #     pub async fn store_transactions_batch(&self,
/// #         _transactions: &Vec<CheckpointTransaction>,
/// #     ) -> Result<(), DatabaseError> {
/// #         Ok(())
/// #     }
/// # }
/// #
/// # #[derive(Debug, Clone)]
/// # struct DatabaseError;
/// #
/// # impl std::fmt::Display for DatabaseError {
/// #     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
/// #         write!(f, "database error")
/// #     }
/// # }
///
/// // worker that accumulates transactions for batch processing.
/// struct BatchProcessor;
///
/// #[async_trait]
/// impl Worker for BatchProcessor {
///     type Message = Vec<CheckpointTransaction>;
///     type Error = DatabaseError;
///
///     async fn process_checkpoint(
///         &self,
///         checkpoint: Arc<CheckpointData>,
///     ) -> Result<Self::Message, Self::Error> {
///         // collect all checkpoint transactions for batch processing.
///         Ok(checkpoint.transactions.clone())
///     }
/// }
///
/// // batch reducer for efficient storage.
/// struct TransactionBatchReducer {
///     batch_size: usize,
///     // generic Database client.
///     client: Arc<DatabaseClient>,
/// }
///
/// #[async_trait]
/// impl Reducer<BatchProcessor> for TransactionBatchReducer {
///     async fn commit(&self, batch: &[Vec<CheckpointTransaction>]) -> Result<(), DatabaseError> {
///         let flattened: Vec<CheckpointTransaction> = batch.iter().flatten().cloned().collect();
///         // store the transaction batch in the database of choice.
///         self.client.store_transactions_batch(&flattened).await?;
///         Ok(())
///     }
///
///     fn should_close_batch(
///         &self,
///         batch: &[Vec<CheckpointTransaction>],
///         _: Option<&Vec<CheckpointTransaction>>,
///     ) -> bool {
///         batch.iter().map(|b| b.len()).sum::<usize>() >= self.batch_size
///     }
/// }
///
/// // configure worker pool with batch processing.
/// let processor = BatchProcessor;
/// let reducer = TransactionBatchReducer {
///     batch_size: 1000,
///     client: Arc::new(DatabaseClient::new()),
/// };
/// let pool = WorkerPool::new_with_reducer(
///     processor,
///     "batch_processing".into(),
///     5,
///     Default::default(),
///     reducer,
/// );
/// ```
pub struct WorkerPool<W: Worker> {
    /// A unique name of the WorkerPool task.
    pub task_name: String,
    /// How many instances of the current [`Worker`] to create; the more
    /// workers are created, the more checkpoints can be processed in parallel.
    concurrency: usize,
    /// The actual [`Worker`] instance itself.
    worker: Arc<W>,
    /// The reducer instance, responsible for batch processing.
    reducer: Option<Box<dyn Reducer<W>>>,
    backoff: Arc<ExponentialBackoff>,
}

impl<W: Worker + 'static> WorkerPool<W> {
    /// Creates a new `WorkerPool` without a reducer.
    pub fn new(
        worker: W,
        task_name: String,
        concurrency: usize,
        backoff: ExponentialBackoff,
    ) -> Self {
        Self {
            task_name,
            concurrency,
            worker: Arc::new(worker),
            reducer: None,
            backoff: Arc::new(backoff),
        }
    }

    /// Creates a new `WorkerPool` with a reducer.
    pub fn new_with_reducer<R>(
        worker: W,
        task_name: String,
        concurrency: usize,
        backoff: ExponentialBackoff,
        reducer: R,
    ) -> Self
    where
        R: Reducer<W> + 'static,
    {
        Self {
            task_name,
            concurrency,
            worker: Arc::new(worker),
            reducer: Some(Box::new(reducer)),
            backoff: Arc::new(backoff),
        }
    }

    /// Runs the worker pool main logic.
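    ///
    /// # Example
    ///
    /// A minimal sketch of driving the pool by hand (normally the executor owns
    /// this wiring); `MyWorker`, the watermark of `0`, and the channel
    /// capacities are illustrative assumptions:
    ///
    /// ```rust,ignore
    /// let pool = WorkerPool::new(MyWorker, "my_task".into(), 4, Default::default());
    /// let (checkpoint_sender, checkpoint_receiver) = tokio::sync::mpsc::channel(128);
    /// let (status_sender, mut status_receiver) = tokio::sync::mpsc::channel(128);
    /// let token = tokio_util::sync::CancellationToken::new();
    ///
    /// // Drive the pool from its own task; feed `Arc<CheckpointData>` values
    /// // into `checkpoint_sender` elsewhere.
    /// tokio::spawn(pool.run(0, checkpoint_receiver, status_sender, token.clone()));
    ///
    /// // Progress and shutdown notifications arrive as `WorkerPoolStatus` values.
    /// while let Some(status) = status_receiver.recv().await {
    ///     println!("{status:?}");
    /// }
    /// ```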
    pub async fn run(
        mut self,
        watermark: CheckpointSequenceNumber,
        mut checkpoint_receiver: mpsc::Receiver<Arc<CheckpointData>>,
        pool_status_sender: mpsc::Sender<WorkerPoolStatus>,
        token: CancellationToken,
    ) {
        info!(
            "Starting indexing pipeline {} with concurrency {}. Current watermark is {watermark}.",
            self.task_name, self.concurrency
        );
        // This channel will be used to send progress data from the workers to the
        // WorkerPool main loop.
        let (progress_sender, mut progress_receiver) = mpsc::channel(MAX_CHECKPOINTS_IN_PROGRESS);
        // This channel will be used to send the workers' progress data from the
        // WorkerPool to the watermark tracking task.
        let (watermark_sender, watermark_receiver) = mpsc::channel(MAX_CHECKPOINTS_IN_PROGRESS);
        let mut idle: BTreeSet<_> = (0..self.concurrency).collect();
        let mut checkpoints = VecDeque::new();
        let mut workers_shutdown_signals = vec![];
        let (workers, workers_join_handles) = self.spawn_workers(progress_sender, token.clone());
        // Spawn a task that tracks checkpoint processing progress. The task:
        // - Receives (checkpoint_number, message) pairs from workers.
        // - Maintains checkpoint sequence order.
        // - Reports progress either:
        //   * After processing each chunk (simple tracking).
        //   * After committing batches (with reducer).
        let watermark_handle = self.spawn_watermark_tracking(
            watermark,
            watermark_receiver,
            pool_status_sender.clone(),
            token.clone(),
        );
        // main worker pool loop.
        loop {
            tokio::select! {
                Some(worker_progress_msg) = progress_receiver.recv() => {
                    match worker_progress_msg {
                        WorkerStatus::Running((worker_id, checkpoint_number, message)) => {
                            idle.insert(worker_id);
                            if watermark_sender.send((checkpoint_number, message)).await.is_err() {
                                break;
                            }
                            // By checking that the token has not been cancelled, we ensure that no
                            // further checkpoints will be sent to the workers.
                            while !token.is_cancelled() && !checkpoints.is_empty() && !idle.is_empty() {
                                let checkpoint = checkpoints.pop_front().unwrap();
                                let worker_id = idle.pop_first().unwrap();
                                if workers[worker_id].send(checkpoint).await.is_err() {
                                    // The worker channel closing is a sign we need to exit this loop.
                                    break;
                                }
                            }
                        }
                        WorkerStatus::Shutdown(worker_id) => {
                            // Track workers that have initiated shutdown.
                            workers_shutdown_signals.push(worker_id);
                        }
                    }
                }
                // Adding an if guard to this branch ensures that no checkpoints
                // will be sent to workers once the token has been cancelled.
                Some(checkpoint) = checkpoint_receiver.recv(), if !token.is_cancelled() => {
                    let sequence_number = checkpoint.checkpoint_summary.sequence_number;
                    if sequence_number < watermark {
                        continue;
                    }
                    self.worker
                        .preprocess_hook(checkpoint.clone())
                        .map_err(|err| IngestionError::CheckpointHookProcessing(err.to_string()))
                        .expect("failed to preprocess task");
                    if idle.is_empty() {
                        checkpoints.push_back(checkpoint);
                    } else {
                        let worker_id = idle.pop_first().unwrap();
                        if workers[worker_id].send(checkpoint).await.is_err() {
                            // The worker channel closing is a sign we need to exit this loop.
                            break;
                        };
                    }
                }
            }
            // Once all workers have signaled completion, start the graceful shutdown
            // process.
            if workers_shutdown_signals.len() == self.concurrency {
                break self
                    .workers_graceful_shutdown(
                        workers_join_handles,
                        watermark_handle,
                        pool_status_sender,
                        watermark_sender,
                    )
                    .await;
            }
        }
    }

    /// Spawns workers based on `self.concurrency` to process checkpoints
    /// in parallel.
    fn spawn_workers(
        &self,
        progress_sender: mpsc::Sender<WorkerStatus<W::Message>>,
        token: CancellationToken,
    ) -> (Vec<mpsc::Sender<Arc<CheckpointData>>>, Vec<JoinHandle<()>>) {
        let mut worker_senders = Vec::with_capacity(self.concurrency);
        let mut workers_join_handles = Vec::with_capacity(self.concurrency);

        for worker_id in 0..self.concurrency {
            let (worker_sender, mut worker_recv) =
                mpsc::channel::<Arc<CheckpointData>>(MAX_CHECKPOINTS_IN_PROGRESS);
            let cloned_progress_sender = progress_sender.clone();
            let task_name = self.task_name.clone();
            worker_senders.push(worker_sender);

            let token = token.clone();

            let worker = self.worker.clone();
            let backoff = self.backoff.clone();
            let join_handle = spawn_monitored_task!(async move {
                loop {
                    tokio::select! {
                        // Once the token is cancelled, notify the main loop of the worker's shutdown.
                        _ = token.cancelled() => {
                            _ = cloned_progress_sender.send(WorkerStatus::Shutdown(worker_id)).await;
                            break
                        },
                        Some(checkpoint) = worker_recv.recv() => {
                            let sequence_number = checkpoint.checkpoint_summary.sequence_number;
                            info!("received checkpoint for processing {} for workflow {}", sequence_number, task_name);
                            let start_time = Instant::now();
                            let status = Self::process_checkpoint_with_retry(worker_id, &worker, checkpoint, reset_backoff(&backoff), &token).await;
                            let trigger_shutdown = matches!(status, WorkerStatus::Shutdown(_));
                            if cloned_progress_sender.send(status).await.is_err() || trigger_shutdown {
                                break;
                            }
                            info!(
                                "finished checkpoint processing {sequence_number} for workflow {task_name} in {:?}",
                                start_time.elapsed()
                            );
                        }
                    }
                }
            });
            // Keep all join handles to ensure all workers are terminated before exiting
            workers_join_handles.push(join_handle);
        }
        (worker_senders, workers_join_handles)
    }

    /// Attempts to process a checkpoint with exponential backoff retries on
    /// failure.
    ///
    /// This function repeatedly calls the
    /// [`process_checkpoint`](Worker::process_checkpoint) method of the
    /// provided [`Worker`] until either:
    /// - The checkpoint processing succeeds, returning `WorkerStatus::Running`
    ///   with the processed message.
    /// - A cancellation signal is received via the [`CancellationToken`],
    ///   returning `WorkerStatus::Shutdown(<worker-id>)`.
    /// - All retry attempts are exhausted within backoff's maximum elapsed
    ///   time, causing a panic.
    ///
    /// # Retry Mechanism:
    /// - Uses [`ExponentialBackoff`](backoff::ExponentialBackoff) to introduce
    ///   increasing delays between retry attempts.
    /// - Checks for cancellation both before and after each processing attempt.
    /// - If a cancellation signal is received during a backoff delay, the
    ///   function exits immediately with `WorkerStatus::Shutdown(<worker-id>)`.
    ///
    /// # Panics:
    /// - If all retry attempts are exhausted within the backoff's maximum
    ///   elapsed time, indicating a persistent failure.
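    ///
    /// # Example
    ///
    /// A sketch of how the backoff handed down from [`WorkerPool::new`] might be
    /// configured so that retries stop (and the panic above fires) after roughly
    /// one minute; the exact values are illustrative assumptions:
    ///
    /// ```rust,ignore
    /// use std::time::Duration;
    ///
    /// let backoff = backoff::ExponentialBackoffBuilder::new()
    ///     .with_initial_interval(Duration::from_millis(100))
    ///     .with_max_elapsed_time(Some(Duration::from_secs(60)))
    ///     .build();
    /// let pool = WorkerPool::new(worker, "my_task".into(), 4, backoff);
    /// ```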
    async fn process_checkpoint_with_retry(
        worker_id: WorkerID,
        worker: &W,
        checkpoint: Arc<CheckpointData>,
        mut backoff: ExponentialBackoff,
        token: &CancellationToken,
    ) -> WorkerStatus<W::Message> {
        let sequence_number = checkpoint.checkpoint_summary.sequence_number;
        loop {
            // check for cancellation before attempting processing.
            if token.is_cancelled() {
                return WorkerStatus::Shutdown(worker_id);
            }
            // attempt to process checkpoint.
            match worker.process_checkpoint(checkpoint.clone()).await {
                Ok(message) => return WorkerStatus::Running((worker_id, sequence_number, message)),
                Err(err) => {
                    let err = IngestionError::CheckpointProcessing(err.to_string());
                    warn!(
                        "transient worker execution error {err:?} for checkpoint {sequence_number}"
                    );
                    // check for cancellation after failed processing.
                    if token.is_cancelled() {
                        return WorkerStatus::Shutdown(worker_id);
                    }
                }
            }
            // get next backoff duration or panic if max retries exceeded.
            let duration = backoff
                .next_backoff()
                .expect("max elapsed time exceeded: checkpoint processing failed for checkpoint");
            // if cancellation occurs during backoff wait, exit early with Shutdown.
            // Otherwise (if timeout expires), continue with the next retry attempt.
            if tokio::time::timeout(duration, token.cancelled())
                .await
                .is_ok()
            {
                return WorkerStatus::Shutdown(worker_id);
            }
        }
    }

    /// Spawns a task that tracks the progress of checkpoint processing,
    /// optionally with message reduction.
    ///
    /// This function spawns one of two types of tracking tasks:
    ///
    /// 1. Simple Watermark Tracking (when reducer = None):
    ///    - Reports watermark after processing each chunk.
    ///
    /// 2. Batch Processing (when reducer = Some):
    ///    - Reports progress only after successful batch commits.
    ///    - A batch is committed based on
    ///      [`should_close_batch`](Reducer::should_close_batch) policy.
    fn spawn_watermark_tracking(
        &mut self,
        watermark: CheckpointSequenceNumber,
        watermark_receiver: mpsc::Receiver<(CheckpointSequenceNumber, W::Message)>,
        executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
        token: CancellationToken,
    ) -> JoinHandle<Result<(), IngestionError>> {
        let task_name = self.task_name.clone();
        let backoff = self.backoff.clone();
        if let Some(reducer) = self.reducer.take() {
            return spawn_monitored_task!(reduce::<W>(
                task_name,
                watermark,
                watermark_receiver,
                executor_progress_sender,
                reducer,
                backoff,
                token
            ));
        };
        spawn_monitored_task!(simple_watermark_tracking::<W>(
            task_name,
            watermark,
            watermark_receiver,
            executor_progress_sender
        ))
    }

    /// Starts the workers' graceful shutdown.
    ///
    /// - Awaits all worker handles.
    /// - Awaits the reducer handle.
    /// - Sends a `WorkerPoolStatus::Shutdown(<task-name>)` message notifying
    ///   external components that the worker pool has been shut down.
    async fn workers_graceful_shutdown(
        &self,
        workers_join_handles: Vec<JoinHandle<()>>,
        watermark_handle: JoinHandle<Result<(), IngestionError>>,
        executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
        watermark_sender: mpsc::Sender<(u64, <W as Worker>::Message)>,
    ) {
        for worker in workers_join_handles {
            _ = worker
                .await
                .inspect_err(|err| tracing::error!("worker task panicked: {err}"));
        }
        // by dropping the sender we make sure that the stream will be closed and the
        // watermark tracker task will exit its loop.
        drop(watermark_sender);
        _ = watermark_handle
            .await
            .inspect_err(|err| tracing::error!("watermark task panicked: {err}"));
        _ = executor_progress_sender
            .send(WorkerPoolStatus::Shutdown(self.task_name.clone()))
            .await;
        tracing::info!("Worker pool `{}` terminated gracefully", self.task_name);
    }
}

/// Tracks checkpoint progress without reduction logic.
///
/// This function maintains a watermark of processed checkpoints by:
/// 1. Receiving batches of progress status from workers.
/// 2. Processing them in sequence order.
/// 3. Reporting progress to the executor after each chunk from the stream.
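///
/// A small illustration (with assumed sequence numbers) of how the in-order
/// watermark advances, mirroring the loop below:
///
/// ```rust,ignore
/// use std::collections::HashMap;
///
/// let mut unprocessed = HashMap::new();
/// let mut watermark: u64 = 5;
/// for (sequence_number, message) in [(7u64, ()), (5, ()), (6, ())] {
///     unprocessed.insert(sequence_number, message);
///     // advance the watermark over every contiguously processed checkpoint.
///     while unprocessed.remove(&watermark).is_some() {
///         watermark += 1;
///     }
/// }
/// assert_eq!(watermark, 8); // checkpoints 5, 6 and 7 are all accounted for
/// ```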
async fn simple_watermark_tracking<W: Worker>(
    task_name: String,
    mut current_checkpoint_number: CheckpointSequenceNumber,
    watermark_receiver: mpsc::Receiver<(CheckpointSequenceNumber, W::Message)>,
    executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
) -> IngestionResult<()> {
    // convert the receiver into a stream that yields ready chunks of up to
    // MAX_CHECKPOINTS_IN_PROGRESS messages. This way, each iteration of the
    // loop processes all currently ready messages.
    let mut stream =
        ReceiverStream::new(watermark_receiver).ready_chunks(MAX_CHECKPOINTS_IN_PROGRESS);
    // store unprocessed progress messages from workers.
    let mut unprocessed = HashMap::new();
    // track the next unprocessed checkpoint number for reporting progress
    // after each chunk of messages is received from the stream.
    let mut progress_update = None;

    while let Some(update_batch) = stream.next().await {
        unprocessed.extend(update_batch.into_iter());
        // Process messages sequentially based on checkpoint sequence number.
        // This ensures in-order processing and maintains progress integrity.
        while unprocessed.remove(&current_checkpoint_number).is_some() {
            current_checkpoint_number += 1;
            progress_update = Some(current_checkpoint_number);
        }
        // report progress update to executor.
        if let Some(watermark) = progress_update.take() {
            executor_progress_sender
                .send(WorkerPoolStatus::Running((task_name.clone(), watermark)))
                .await
                .map_err(|_| IngestionError::Channel("unable to send worker pool progress updates to executor, receiver half closed".into()))?;
        }
    }
    Ok(())
}