Skip to main content

iota_data_ingestion_core/
worker_pool.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::{BTreeSet, HashMap, VecDeque},
7    fmt::Debug,
8    sync::Arc,
9    time::Instant,
10};
11
12use backoff::{ExponentialBackoff, backoff::Backoff};
13use futures::StreamExt;
14use iota_metrics::spawn_monitored_task;
15use iota_types::{
16    full_checkpoint_content::CheckpointData, messages_checkpoint::CheckpointSequenceNumber,
17};
18use tokio::{sync::mpsc, task::JoinHandle};
19use tokio_stream::wrappers::ReceiverStream;
20use tokio_util::sync::CancellationToken;
21use tracing::{info, warn};
22
23use crate::{
24    IngestionError, IngestionResult, Reducer, Worker, executor::MAX_CHECKPOINTS_IN_PROGRESS,
25    reducer::reduce, util::reset_backoff,
26};
27
/// Index of a worker task inside a [`WorkerPool`], in `0..concurrency`.
type WorkerID = usize;
/// Name identifying a [`WorkerPool`] in logs and status messages.
type TaskName = String;
30
31/// Represents the possible message types a [`WorkerPool`] can communicate with
32/// external components.
33#[derive(Debug, Clone)]
34pub enum WorkerPoolStatus {
35    /// Message with information (e.g. `(<task-name>,
36    /// checkpoint_sequence_number)`) about the ingestion progress.
37    Running((TaskName, CheckpointSequenceNumber)),
38    /// Message with information (e.g. `<task-name>`) about shutdown status.
39    Shutdown(String),
40}
41
42/// Represents the possible message types a [`Worker`] can communicate with
43/// external components
44#[derive(Debug, Clone, Copy)]
45enum WorkerStatus<M> {
46    /// Message with information (e.g. `(<worker-id>`,
47    /// `checkpoint_sequence_number`, Option<[`Worker::Message`]>) about the
48    /// ingestion progress.
49    ///
50    /// The `Option<M>` is used to indicate that the worker skipped
51    /// processing the checkpoint. Useful for filtered checkpoints where non
52    /// matching checkpoints should not be forwarded to worker. In this case the
53    /// `checkpoint_sequence_number` is needed to track the progress.
54    Running((WorkerID, CheckpointSequenceNumber, Option<M>)),
55    /// Message with information (e.g. `<worker-id>`) about shutdown status.
56    Shutdown(WorkerID),
57}
58
59/// A pool of [`Worker`]'s that process checkpoints concurrently.
60///
61/// This struct manages a collection of workers that process checkpoints in
62/// parallel. It handles checkpoint distribution, progress tracking, and
63/// graceful shutdown. It can optionally use a [`Reducer`] to aggregate and
64/// process worker [`Messages`](Worker::Message).
65///
66/// # Examples
67/// ## Direct Processing (Without Batching)
68/// ```rust,no_run
69/// use std::sync::Arc;
70///
71/// use async_trait::async_trait;
72/// use iota_data_ingestion_core::{Worker, WorkerPool};
73/// use iota_types::full_checkpoint_content::{CheckpointData, CheckpointTransaction};
74/// #
75/// # struct DatabaseClient;
76/// #
77/// # impl DatabaseClient {
78/// #     pub fn new() -> Self {
79/// #         Self
80/// #     }
81/// #
82/// #     pub async fn store_transaction(&self,
83/// #         _transactions: &CheckpointTransaction,
84/// #     ) -> Result<(), DatabaseError> {
85/// #         Ok(())
86/// #     }
87/// # }
88/// #
89/// # #[derive(Debug, Clone)]
90/// # struct DatabaseError;
91/// #
92/// # impl std::fmt::Display for DatabaseError {
93/// #     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
94/// #         write!(f, "database error")
95/// #     }
96/// # }
97/// #
98/// # fn extract_transaction(checkpoint: &CheckpointData) -> CheckpointTransaction {
99/// #     checkpoint.transactions.first().unwrap().clone()
100/// # }
101///
102/// struct DirectProcessor {
103///     // generic Database client.
104///     client: Arc<DatabaseClient>,
105/// }
106///
107/// #[async_trait]
108/// impl Worker for DirectProcessor {
109///     type Message = ();
110///     type Error = DatabaseError;
111///
112///     async fn process_checkpoint(
113///         &self,
114///         checkpoint: Arc<CheckpointData>,
115///     ) -> Result<Self::Message, Self::Error> {
116///         // extract a particulat transaction we care about.
117///         let tx: CheckpointTransaction = extract_transaction(checkpoint.as_ref());
118///         // store the transaction in our database of choice.
119///         self.client.store_transaction(&tx).await?;
120///         Ok(())
121///     }
122/// }
123///
124/// // configure worker pool for direct processing.
125/// let processor = DirectProcessor {
126///     client: Arc::new(DatabaseClient::new()),
127/// };
128/// let pool = WorkerPool::new(processor, "direct_processing".into(), 5, Default::default());
129/// ```
130///
131/// ## Batch Processing (With Reducer)
132/// ```rust,no_run
133/// use std::sync::Arc;
134///
135/// use async_trait::async_trait;
136/// use iota_data_ingestion_core::{Reducer, Worker, WorkerPool};
137/// use iota_types::full_checkpoint_content::{CheckpointData, CheckpointTransaction};
138/// # struct DatabaseClient;
139/// #
140/// # impl DatabaseClient {
141/// #     pub fn new() -> Self {
142/// #         Self
143/// #     }
144/// #
145/// #     pub async fn store_transactions_batch(&self,
146/// #         _transactions: &Vec<CheckpointTransaction>,
147/// #     ) -> Result<(), DatabaseError> {
148/// #         Ok(())
149/// #     }
150/// # }
151/// #
152/// # #[derive(Debug, Clone)]
153/// # struct DatabaseError;
154/// #
155/// # impl std::fmt::Display for DatabaseError {
156/// #     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
157/// #         write!(f, "database error")
158/// #     }
159/// # }
160///
161/// // worker that accumulates transactions for batch processing.
162/// struct BatchProcessor;
163///
164/// #[async_trait]
165/// impl Worker for BatchProcessor {
166///     type Message = Vec<CheckpointTransaction>;
167///     type Error = DatabaseError;
168///
169///     async fn process_checkpoint(
170///         &self,
171///         checkpoint: Arc<CheckpointData>,
172///     ) -> Result<Self::Message, Self::Error> {
173///         // collect all checkpoint transactions for batch processing.
174///         Ok(checkpoint.transactions.clone())
175///     }
176/// }
177///
178/// // batch reducer for efficient storage.
179/// struct TransactionBatchReducer {
180///     batch_size: usize,
181///     // generic Database client.
182///     client: Arc<DatabaseClient>,
183/// }
184///
185/// #[async_trait]
186/// impl Reducer<BatchProcessor> for TransactionBatchReducer {
187///     async fn commit(&self, batch: &[Vec<CheckpointTransaction>]) -> Result<(), DatabaseError> {
188///         let flattened: Vec<CheckpointTransaction> = batch.iter().flatten().cloned().collect();
189///         // store the transaction batch in the database of choice.
190///         self.client.store_transactions_batch(&flattened).await?;
191///         Ok(())
192///     }
193///
194///     fn should_close_batch(
195///         &self,
196///         batch: &[Vec<CheckpointTransaction>],
197///         _: Option<&Vec<CheckpointTransaction>>,
198///     ) -> bool {
199///         batch.iter().map(|b| b.len()).sum::<usize>() >= self.batch_size
200///     }
201/// }
202///
203/// // configure worker pool with batch processing.
204/// let processor = BatchProcessor;
205/// let reducer = TransactionBatchReducer {
206///     batch_size: 1000,
207///     client: Arc::new(DatabaseClient::new()),
208/// };
209/// let pool = WorkerPool::new_with_reducer(
210///     processor,
211///     "batch_processing".into(),
212///     5,
213///     Default::default(),
214///     reducer,
215/// );
216/// ```
217pub struct WorkerPool<W: Worker> {
218    /// An unique name of the WorkerPool task.
219    pub task_name: String,
220    /// How many instances of the current [`Worker`] to create, more workers are
221    /// created more checkpoints they can process in parallel.
222    concurrency: usize,
223    /// The actual [`Worker`] instance itself.
224    worker: Arc<W>,
225    /// The reducer instance, responsible for batch processing.
226    reducer: Option<Box<dyn Reducer<W>>>,
227    backoff: Arc<ExponentialBackoff>,
228}
229
230impl<W: Worker + 'static> WorkerPool<W> {
231    /// Creates a new `WorkerPool` without a reducer.
232    pub fn new(
233        worker: W,
234        task_name: String,
235        concurrency: usize,
236        backoff: ExponentialBackoff,
237    ) -> Self {
238        Self {
239            task_name,
240            concurrency,
241            worker: Arc::new(worker),
242            reducer: None,
243            backoff: Arc::new(backoff),
244        }
245    }
246
247    /// Creates a new `WorkerPool` with a reducer.
248    pub fn new_with_reducer<R>(
249        worker: W,
250        task_name: String,
251        concurrency: usize,
252        backoff: ExponentialBackoff,
253        reducer: R,
254    ) -> Self
255    where
256        R: Reducer<W> + 'static,
257    {
258        Self {
259            task_name,
260            concurrency,
261            worker: Arc::new(worker),
262            reducer: Some(Box::new(reducer)),
263            backoff: Arc::new(backoff),
264        }
265    }
266
267    /// Runs the worker pool main logic.
268    pub async fn run(
269        mut self,
270        watermark: CheckpointSequenceNumber,
271        mut checkpoint_receiver: mpsc::Receiver<Arc<CheckpointData>>,
272        pool_status_sender: mpsc::Sender<WorkerPoolStatus>,
273        token: CancellationToken,
274    ) {
275        info!(
276            "Starting indexing pipeline {} with concurrency {}. Current watermark is {watermark}.",
277            self.task_name, self.concurrency
278        );
279        // This channel will be used to send progress data from Workers to WorkerPool
280        // main loop.
281        let (progress_sender, mut progress_receiver) = mpsc::channel(MAX_CHECKPOINTS_IN_PROGRESS);
282        // This channel will be used to send Workers progress data from WorkerPool to
283        // watermark tracking task.
284        let (watermark_sender, watermark_receiver) = mpsc::channel(MAX_CHECKPOINTS_IN_PROGRESS);
285        let mut idle: BTreeSet<_> = (0..self.concurrency).collect();
286        let mut checkpoints = VecDeque::new();
287        let mut workers_shutdown_signals = Vec::with_capacity(self.concurrency);
288        let (workers, workers_join_handles) = self.spawn_workers(progress_sender, token.clone());
289        // Spawn a task that tracks checkpoint processing progress. The task:
290        // - Receives (checkpoint_number, message) pairs from workers.
291        // - Maintains checkpoint sequence order.
292        // - Reports progress either:
293        //   * After processing each chunk (simple tracking).
294        //   * After committing batches (with reducer).
295        let watermark_handle = self.spawn_watermark_tracking(
296            watermark,
297            watermark_receiver,
298            pool_status_sender.clone(),
299            token.clone(),
300        );
301        // main worker pool loop.
302        loop {
303            tokio::select! {
304                Some(worker_progress_msg) = progress_receiver.recv() => {
305                    match worker_progress_msg {
306                        WorkerStatus::Running((worker_id, checkpoint_number, message)) => {
307                            idle.insert(worker_id);
308                            // Try to send progress to reducer. If it fails (reducer has exited),
309                            // we just continue - we still need to wait for all workers to shutdown.
310                            let _ = watermark_sender.send((checkpoint_number, message)).await;
311
312                            // By checking if token was not cancelled we ensure that no
313                            // further checkpoints will be sent to the workers.
314                            while !token.is_cancelled() && !checkpoints.is_empty() && !idle.is_empty() {
315                                let checkpoint = checkpoints.pop_front().unwrap();
316                                let worker_id = idle.pop_first().unwrap();
317                                if workers[worker_id].send(checkpoint).await.is_err() {
318                                    // The worker channel closing is a sign we need to exit this inner loop.
319                                    break;
320                                }
321                            }
322                        }
323                        WorkerStatus::Shutdown(worker_id) => {
324                            // Track workers that have initiated shutdown.
325                            workers_shutdown_signals.push(worker_id);
326                        }
327                    }
328                }
329                // Adding an if guard to this branch ensure that no checkpoints
330                // will be sent to workers once the token has been cancelled.
331                Some(checkpoint) = checkpoint_receiver.recv(), if !token.is_cancelled() => {
332                    let sequence_number = checkpoint.checkpoint_summary.sequence_number;
333                    if sequence_number < watermark {
334                        continue;
335                    }
336
337                    if !Self::should_skip_filtered_checkpoint(&checkpoint) {
338                        self.worker
339                            .preprocess_hook(checkpoint.clone())
340                            .map_err(|err| IngestionError::CheckpointHookProcessing(err.to_string()))
341                            .expect("failed to preprocess task");
342                    }
343
344                    if idle.is_empty() {
345                        checkpoints.push_back(checkpoint);
346                    } else {
347                        let worker_id = idle.pop_first().unwrap();
348                        // If worker channel is closed, put the checkpoint back in queue
349                        // and continue - we still need to wait for all worker shutdown signals.
350                        if let Err(send_error) = workers[worker_id].send(checkpoint).await {
351                            checkpoints.push_front(send_error.0);
352                        };
353                    }
354                }
355            }
356            // Once all workers have signaled completion, start the graceful shutdown
357            // process.
358            if workers_shutdown_signals.len() == self.concurrency {
359                break self
360                    .workers_graceful_shutdown(
361                        workers_join_handles,
362                        watermark_handle,
363                        pool_status_sender,
364                        watermark_sender,
365                    )
366                    .await;
367            }
368        }
369    }
370
371    /// Spawn workers based on `self.concurrency` to process checkpoints
372    /// in parallel.
373    fn spawn_workers(
374        &self,
375        progress_sender: mpsc::Sender<WorkerStatus<W::Message>>,
376        token: CancellationToken,
377    ) -> (Vec<mpsc::Sender<Arc<CheckpointData>>>, Vec<JoinHandle<()>>) {
378        let mut worker_senders = Vec::with_capacity(self.concurrency);
379        let mut workers_join_handles = Vec::with_capacity(self.concurrency);
380
381        for worker_id in 0..self.concurrency {
382            let (worker_sender, mut worker_recv) =
383                mpsc::channel::<Arc<CheckpointData>>(MAX_CHECKPOINTS_IN_PROGRESS);
384            let cloned_progress_sender = progress_sender.clone();
385            let task_name = self.task_name.clone();
386            worker_senders.push(worker_sender);
387
388            let token = token.clone();
389
390            let worker = self.worker.clone();
391            let backoff = self.backoff.clone();
392            let join_handle = spawn_monitored_task!(async move {
393                loop {
394                    tokio::select! {
395                        // Once token is cancelled, notify worker's shutdown to the main loop
396                        _ = token.cancelled() => {
397                            _ = cloned_progress_sender.send(WorkerStatus::Shutdown(worker_id)).await;
398                            break
399                        },
400                        Some(checkpoint) = worker_recv.recv() => {
401                            let sequence_number = checkpoint.checkpoint_summary.sequence_number;
402                            info!("received checkpoint for processing {sequence_number} for workflow {task_name}", );
403                            let start_time = Instant::now();
404                            let status = Self::process_checkpoint_with_retry(worker_id, &worker, checkpoint, reset_backoff(&backoff), &token).await;
405                            if matches!(status, WorkerStatus::Running((_,_, None))) {
406                                info!("checkpoint {sequence_number} for workflow {task_name} filtered out");
407                            }
408                            let trigger_shutdown = matches!(status, WorkerStatus::Shutdown(_));
409                            if cloned_progress_sender.send(status).await.is_err() || trigger_shutdown {
410                                break;
411                            }
412                            info!(
413                                "finished checkpoint processing {sequence_number} for workflow {task_name} in {:?}",
414                                start_time.elapsed()
415                            );
416                        }
417                    }
418                }
419            });
420            // Keep all join handles to ensure all workers are terminated before exiting
421            workers_join_handles.push(join_handle);
422        }
423        (worker_senders, workers_join_handles)
424    }
425
426    /// Returns `true` if the checkpoint was entirely stripped of its
427    /// transactions by a server-side filter, indicating a filtered-out
428    /// checkpoint.
429    ///
430    /// The fullnode's gRPC `stream_checkpoints` endpoint applies configured
431    /// `TransactionFilter`s to the expanded `transactions` payload but
432    /// leaves `checkpoint_contents` (the list of all transaction digests in
433    /// the original checkpoint) completely untouched.
434    fn should_skip_filtered_checkpoint(checkpoint: &CheckpointData) -> bool {
435        !checkpoint.checkpoint_contents.inner().is_empty() && checkpoint.transactions.is_empty()
436    }
437
438    /// Attempts to process a checkpoint with exponential backoff retries on
439    /// failure.
440    ///
441    /// This function repeatedly calls the
442    /// [`process_checkpoint`](Worker::process_checkpoint) method of the
443    /// provided [`Worker`] until either:
444    /// - The checkpoint processing succeeds, returning `WorkerStatus::Running`
445    ///   with the processed message.
446    /// - A cancellation signal is received via the [`CancellationToken`],
447    ///   returning `WorkerStatus::Shutdown(<worker-id>)`.
448    /// - All retry attempts are exhausted within backoff's maximum elapsed
449    ///   time, causing a panic.
450    ///
451    /// # Retry Mechanism:
452    /// - Uses [`ExponentialBackoff`](backoff::ExponentialBackoff) to introduce
453    ///   increasing delays between retry attempts.
454    /// - Checks for cancellation both before and after each processing attempt.
455    /// - If a cancellation signal is received during a backoff delay, the
456    ///   function exits immediately with `WorkerStatus::Shutdown(<worker-id>)`.
457    ///
458    /// # Panics:
459    /// - If all retry attempts are exhausted within the backoff's maximum
460    ///   elapsed time, indicating a persistent failure.
461    async fn process_checkpoint_with_retry(
462        worker_id: WorkerID,
463        worker: &W,
464        checkpoint: Arc<CheckpointData>,
465        mut backoff: ExponentialBackoff,
466        token: &CancellationToken,
467    ) -> WorkerStatus<W::Message> {
468        let sequence_number = checkpoint.checkpoint_summary.sequence_number;
469
470        if Self::should_skip_filtered_checkpoint(&checkpoint) {
471            return if token.is_cancelled() {
472                WorkerStatus::Shutdown(worker_id)
473            } else {
474                WorkerStatus::Running((worker_id, sequence_number, None))
475            };
476        }
477
478        loop {
479            // check for cancellation before attempting processing.
480            if token.is_cancelled() {
481                return WorkerStatus::Shutdown(worker_id);
482            }
483
484            // attempt to process checkpoint.
485            match worker.process_checkpoint(checkpoint.clone()).await {
486                Ok(message) => {
487                    return WorkerStatus::Running((worker_id, sequence_number, Some(message)));
488                }
489                Err(err) => {
490                    let err = IngestionError::CheckpointProcessing(err.to_string());
491                    warn!(
492                        "transient worker execution error {err:?} for checkpoint {sequence_number}"
493                    );
494                    // check for cancellation after failed processing.
495                    if token.is_cancelled() {
496                        return WorkerStatus::Shutdown(worker_id);
497                    }
498                }
499            }
500            // get next backoff duration or panic if max retries exceeded.
501            let duration = backoff
502                .next_backoff()
503                .expect("max elapsed time exceeded: checkpoint processing failed for checkpoint {sequence_number}");
504            // if cancellation occurs during backoff wait, exit early with Shutdown.
505            // Otherwise (if timeout expires), continue with the next retry attempt.
506            if tokio::time::timeout(duration, token.cancelled())
507                .await
508                .is_ok()
509            {
510                return WorkerStatus::Shutdown(worker_id);
511            }
512        }
513    }
514
515    /// Spawns a task that tracks the progress of checkpoint processing,
516    /// optionally with message reduction.
517    ///
518    /// This function spawns one of two types of tracking tasks:
519    ///
520    /// 1. Simple Watermark Tracking (when reducer = None):
521    ///    - Reports watermark after processing each chunk.
522    ///
523    /// 2. Batch Processing (when reducer = Some):
524    ///    - Reports progress only after successful batch commits.
525    ///    - A batch is committed based on
526    ///      [`should_close_batch`](Reducer::should_close_batch) policy.
527    fn spawn_watermark_tracking(
528        &mut self,
529        watermark: CheckpointSequenceNumber,
530        watermark_receiver: mpsc::Receiver<(CheckpointSequenceNumber, Option<W::Message>)>,
531        executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
532        token: CancellationToken,
533    ) -> JoinHandle<Result<(), IngestionError>> {
534        let task_name = self.task_name.clone();
535        let backoff = self.backoff.clone();
536        if let Some(reducer) = self.reducer.take() {
537            return spawn_monitored_task!(reduce::<W>(
538                task_name,
539                watermark,
540                watermark_receiver,
541                executor_progress_sender,
542                reducer,
543                backoff,
544                token
545            ));
546        };
547        spawn_monitored_task!(simple_watermark_tracking::<W>(
548            task_name,
549            watermark,
550            watermark_receiver,
551            executor_progress_sender
552        ))
553    }
554
555    /// Start the workers graceful shutdown.
556    ///
557    /// - Awaits all worker handles.
558    /// - Awaits the reducer handle.
559    /// - Send `WorkerPoolStatus::Shutdown(<task-name>)` message notifying
560    ///   external components that Worker Pool has been shutdown.
561    async fn workers_graceful_shutdown(
562        &self,
563        workers_join_handles: Vec<JoinHandle<()>>,
564        watermark_handle: JoinHandle<Result<(), IngestionError>>,
565        executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
566        watermark_sender: mpsc::Sender<(CheckpointSequenceNumber, Option<<W as Worker>::Message>)>,
567    ) {
568        for worker in workers_join_handles {
569            _ = worker
570                .await
571                .inspect_err(|err| tracing::error!("worker task panicked: {err}"));
572        }
573        // by dropping the sender we make sure that the stream will be closed and the
574        // watermark tracker task will exit its loop.
575        drop(watermark_sender);
576        _ = watermark_handle
577            .await
578            .inspect_err(|err| tracing::error!("watermark task panicked: {err}"));
579        _ = executor_progress_sender
580            .send(WorkerPoolStatus::Shutdown(self.task_name.clone()))
581            .await;
582        tracing::info!("Worker pool `{}` terminated gracefully", self.task_name);
583    }
584}
585
586/// Tracks checkpoint progress without reduction logic.
587///
588/// This function maintains a watermark of processed checkpoints by worker:
589/// 1. Receiving batches of progress status from workers.
590/// 2. Processing them in sequence order.
591/// 3. Reporting progress to the executor after each chunk from the stream.
592async fn simple_watermark_tracking<W: Worker>(
593    task_name: String,
594    mut current_checkpoint_number: CheckpointSequenceNumber,
595    watermark_receiver: mpsc::Receiver<(CheckpointSequenceNumber, Option<W::Message>)>,
596    executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
597) -> IngestionResult<()> {
598    // convert to a stream of MAX_CHECKPOINTS_IN_PROGRESS size. This way, each
599    // iteration of the loop will process all ready messages.
600    let mut stream =
601        ReceiverStream::new(watermark_receiver).ready_chunks(MAX_CHECKPOINTS_IN_PROGRESS);
602    // store unprocessed progress messages from workers.
603    let mut unprocessed = HashMap::new();
604    // track the next unprocessed checkpoint number for reporting progress
605    // after each chunk of messages is received from the stream.
606    let mut progress_update = None;
607
608    while let Some(update_batch) = stream.next().await {
609        unprocessed.extend(update_batch);
610        // Process messages sequentially based on checkpoint sequence number.
611        // This ensures in-order processing and maintains progress integrity.
612        while unprocessed.remove(&current_checkpoint_number).is_some() {
613            current_checkpoint_number += 1;
614            progress_update = Some(current_checkpoint_number);
615        }
616        // report progress update to executor.
617        if let Some(watermark) = progress_update.take() {
618            executor_progress_sender
619                .send(WorkerPoolStatus::Running((task_name.clone(), watermark)))
620                .await
621                .map_err(|_| IngestionError::Channel("unable to send worker pool progress updates to executor, receiver half closed".into()))?;
622        }
623    }
624    Ok(())
625}