iota_data_ingestion_core/worker_pool.rs
// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use std::{
    collections::{BTreeSet, HashMap, VecDeque},
    fmt::Debug,
    sync::Arc,
    time::Instant,
};

use backoff::{ExponentialBackoff, backoff::Backoff};
use futures::StreamExt;
use iota_metrics::spawn_monitored_task;
use iota_rest_api::CheckpointData;
use iota_types::messages_checkpoint::CheckpointSequenceNumber;
use tokio::{sync::mpsc, task::JoinHandle};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};

use crate::{
    IngestionError, IngestionResult, Reducer, Worker, executor::MAX_CHECKPOINTS_IN_PROGRESS,
    reducer::reduce, util::reset_backoff,
};

type TaskName = String;
type WorkerID = usize;

/// Represents the possible message types a [`WorkerPool`] can communicate with
/// external components.
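///
/// A sketch of how an external component might consume these messages (the
/// `pool_status_receiver` channel is illustrative, not part of this module):
///
/// ```rust,ignore
/// while let Some(status) = pool_status_receiver.recv().await {
///     match status {
///         WorkerPoolStatus::Running((task_name, watermark)) => {
///             tracing::info!("task {task_name} reached checkpoint {watermark}");
///         }
///         WorkerPoolStatus::Shutdown(task_name) => {
///             tracing::info!("task {task_name} has shut down");
///             break;
///         }
///     }
/// }
/// ```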
#[derive(Debug, Clone)]
pub enum WorkerPoolStatus {
    /// Message with information (e.g. `(<task-name>,
    /// checkpoint_sequence_number)`) about the ingestion progress.
    Running((TaskName, CheckpointSequenceNumber)),
    /// Message with information (e.g. `<task-name>`) about shutdown status.
    Shutdown(String),
}

/// Represents the possible message types a [`Worker`] can communicate with
/// external components.
#[derive(Debug, Clone, Copy)]
enum WorkerStatus<M> {
    /// Message with information (e.g. `(<worker-id>`,
    /// `checkpoint_sequence_number`, [`Worker::Message`]) about the ingestion
    /// progress.
    Running((WorkerID, CheckpointSequenceNumber, M)),
    /// Message with information (e.g. `<worker-id>`) about shutdown status.
    Shutdown(WorkerID),
}

/// A pool of [`Worker`]s that process checkpoints concurrently.
///
/// This struct manages a collection of workers that process checkpoints in
/// parallel. It handles checkpoint distribution, progress tracking, and
/// graceful shutdown. It can optionally use a [`Reducer`] to aggregate and
/// process worker [`Messages`](Worker::Message).
///
/// # Examples
/// ## Direct Processing (Without Batching)
/// ```rust,no_run
/// use std::sync::Arc;
///
/// use async_trait::async_trait;
/// use iota_data_ingestion_core::{Worker, WorkerPool};
/// use iota_types::full_checkpoint_content::{CheckpointData, CheckpointTransaction};
/// #
/// # struct DatabaseClient;
/// #
/// # impl DatabaseClient {
/// #     pub fn new() -> Self {
/// #         Self
/// #     }
/// #
/// #     pub async fn store_transaction(&self,
/// #         _transactions: &CheckpointTransaction,
/// #     ) -> Result<(), DatabaseError> {
/// #         Ok(())
/// #     }
/// # }
/// #
/// # #[derive(Debug, Clone)]
/// # struct DatabaseError;
/// #
/// # impl std::fmt::Display for DatabaseError {
/// #     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
/// #         write!(f, "database error")
/// #     }
/// # }
/// #
/// # fn extract_transaction(checkpoint: &CheckpointData) -> CheckpointTransaction {
/// #     checkpoint.transactions.first().unwrap().clone()
/// # }
///
/// struct DirectProcessor {
///     // generic Database client.
///     client: Arc<DatabaseClient>,
/// }
///
/// #[async_trait]
/// impl Worker for DirectProcessor {
///     type Message = ();
///     type Error = DatabaseError;
///
///     async fn process_checkpoint(
///         &self,
///         checkpoint: Arc<CheckpointData>,
///     ) -> Result<Self::Message, Self::Error> {
///         // extract a particular transaction we care about.
///         let tx: CheckpointTransaction = extract_transaction(checkpoint.as_ref());
///         // store the transaction in our database of choice.
///         self.client.store_transaction(&tx).await?;
///         Ok(())
///     }
/// }
///
/// // configure worker pool for direct processing.
/// let processor = DirectProcessor {
///     client: Arc::new(DatabaseClient::new()),
/// };
/// let pool = WorkerPool::new(processor, "direct_processing".into(), 5, Default::default());
/// ```
///
/// ## Batch Processing (With Reducer)
/// ```rust,no_run
/// use std::sync::Arc;
///
/// use async_trait::async_trait;
/// use iota_data_ingestion_core::{Reducer, Worker, WorkerPool};
/// use iota_types::full_checkpoint_content::{CheckpointData, CheckpointTransaction};
/// # struct DatabaseClient;
/// #
/// # impl DatabaseClient {
/// #     pub fn new() -> Self {
/// #         Self
/// #     }
/// #
/// #     pub async fn store_transactions_batch(&self,
/// #         _transactions: &Vec<CheckpointTransaction>,
/// #     ) -> Result<(), DatabaseError> {
/// #         Ok(())
/// #     }
/// # }
/// #
/// # #[derive(Debug, Clone)]
/// # struct DatabaseError;
/// #
/// # impl std::fmt::Display for DatabaseError {
/// #     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
/// #         write!(f, "database error")
/// #     }
/// # }
///
/// // worker that accumulates transactions for batch processing.
/// struct BatchProcessor;
///
/// #[async_trait]
/// impl Worker for BatchProcessor {
///     type Message = Vec<CheckpointTransaction>;
///     type Error = DatabaseError;
///
///     async fn process_checkpoint(
///         &self,
///         checkpoint: Arc<CheckpointData>,
///     ) -> Result<Self::Message, Self::Error> {
///         // collect all checkpoint transactions for batch processing.
///         Ok(checkpoint.transactions.clone())
///     }
/// }
///
/// // batch reducer for efficient storage.
/// struct TransactionBatchReducer {
///     batch_size: usize,
///     // generic Database client.
///     client: Arc<DatabaseClient>,
/// }
///
/// #[async_trait]
/// impl Reducer<BatchProcessor> for TransactionBatchReducer {
///     async fn commit(&self, batch: &[Vec<CheckpointTransaction>]) -> Result<(), DatabaseError> {
///         let flattened: Vec<CheckpointTransaction> = batch.iter().flatten().cloned().collect();
///         // store the transaction batch in the database of choice.
///         self.client.store_transactions_batch(&flattened).await?;
///         Ok(())
///     }
///
///     fn should_close_batch(
///         &self,
///         batch: &[Vec<CheckpointTransaction>],
///         _: Option<&Vec<CheckpointTransaction>>,
///     ) -> bool {
///         batch.iter().map(|b| b.len()).sum::<usize>() >= self.batch_size
///     }
/// }
///
/// // configure worker pool with batch processing.
/// let processor = BatchProcessor;
/// let reducer = TransactionBatchReducer {
///     batch_size: 1000,
///     client: Arc::new(DatabaseClient::new()),
/// };
/// let pool = WorkerPool::new_with_reducer(
///     processor,
///     "batch_processing".into(),
///     5,
///     Default::default(),
///     reducer,
/// );
/// ```
pub struct WorkerPool<W: Worker> {
    /// A unique name of the WorkerPool task.
    pub task_name: String,
    /// How many instances of the current [`Worker`] to create; the more
    /// workers are created, the more checkpoints can be processed in parallel.
    concurrency: usize,
    /// The actual [`Worker`] instance itself.
    worker: Arc<W>,
    /// The reducer instance, responsible for batch processing.
    reducer: Option<Box<dyn Reducer<W>>>,
    backoff: Arc<ExponentialBackoff>,
}

impl<W: Worker + 'static> WorkerPool<W> {
    /// Creates a new `WorkerPool` without a reducer.
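    ///
    /// A minimal sketch of constructing a pool with a tuned retry backoff
    /// (the builder values and `my_worker` are illustrative;
    /// `ExponentialBackoffBuilder` comes from the `backoff` crate):
    ///
    /// ```rust,ignore
    /// let backoff = backoff::ExponentialBackoffBuilder::new()
    ///     .with_initial_interval(std::time::Duration::from_millis(100))
    ///     .with_max_elapsed_time(Some(std::time::Duration::from_secs(60)))
    ///     .build();
    /// let pool = WorkerPool::new(my_worker, "my_task".into(), 4, backoff);
    /// ```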
    pub fn new(
        worker: W,
        task_name: String,
        concurrency: usize,
        backoff: ExponentialBackoff,
    ) -> Self {
        Self {
            task_name,
            concurrency,
            worker: Arc::new(worker),
            reducer: None,
            backoff: Arc::new(backoff),
        }
    }

    /// Creates a new `WorkerPool` with a reducer.
    pub fn new_with_reducer<R>(
        worker: W,
        task_name: String,
        concurrency: usize,
        backoff: ExponentialBackoff,
        reducer: R,
    ) -> Self
    where
        R: Reducer<W> + 'static,
    {
        Self {
            task_name,
            concurrency,
            worker: Arc::new(worker),
            reducer: Some(Box::new(reducer)),
            backoff: Arc::new(backoff),
        }
    }

    /// Runs the worker pool's main loop: distributes incoming checkpoints to
    /// workers, tracks their progress, and coordinates a graceful shutdown
    /// once the cancellation token fires.
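    ///
    /// A rough sketch of driving the pool directly (the `MyWorker` type,
    /// channel names and buffer sizes are illustrative; the surrounding
    /// executor normally owns these channels and calls this method for you):
    ///
    /// ```rust,ignore
    /// let pool = WorkerPool::new(MyWorker, "my_task".into(), 4, Default::default());
    /// let (checkpoint_sender, checkpoint_receiver) = mpsc::channel(1024);
    /// let (pool_status_sender, mut pool_status_receiver) = mpsc::channel(1024);
    /// let token = CancellationToken::new();
    ///
    /// // Start processing from checkpoint 0.
    /// tokio::spawn(pool.run(0, checkpoint_receiver, pool_status_sender, token.clone()));
    ///
    /// // Feed checkpoints through `checkpoint_sender` and watch progress and
    /// // shutdown messages arrive on `pool_status_receiver`.
    /// ```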
    pub async fn run(
        mut self,
        watermark: CheckpointSequenceNumber,
        mut checkpoint_receiver: mpsc::Receiver<Arc<CheckpointData>>,
        pool_status_sender: mpsc::Sender<WorkerPoolStatus>,
        token: CancellationToken,
    ) {
        info!(
            "Starting indexing pipeline {} with concurrency {}. Current watermark is {watermark}.",
            self.task_name, self.concurrency
        );
        // This channel will be used to send progress data from Workers to WorkerPool
        // main loop.
        let (progress_sender, mut progress_receiver) = mpsc::channel(MAX_CHECKPOINTS_IN_PROGRESS);
        // This channel will be used to send Workers progress data from WorkerPool to
        // watermark tracking task.
        let (watermark_sender, watermark_receiver) = mpsc::channel(MAX_CHECKPOINTS_IN_PROGRESS);
        let mut idle: BTreeSet<_> = (0..self.concurrency).collect();
        let mut checkpoints = VecDeque::new();
        let mut workers_shutdown_signals = vec![];
        let (workers, workers_join_handles) = self.spawn_workers(progress_sender, token.clone());
        // Spawn a task that tracks checkpoint processing progress. The task:
        // - Receives (checkpoint_number, message) pairs from workers.
        // - Maintains checkpoint sequence order.
        // - Reports progress either:
        //   * After processing each chunk (simple tracking).
        //   * After committing batches (with reducer).
        let watermark_handle = self.spawn_watermark_tracking(
            watermark,
            watermark_receiver,
            pool_status_sender.clone(),
            token.clone(),
        );
        // main worker pool loop.
        loop {
            tokio::select! {
                Some(worker_progress_msg) = progress_receiver.recv() => {
                    match worker_progress_msg {
                        WorkerStatus::Running((worker_id, checkpoint_number, message)) => {
                            idle.insert(worker_id);
                            // Try to send progress to reducer. If it fails (reducer has exited),
                            // we just continue - we still need to wait for all workers to shut down.
                            let _ = watermark_sender.send((checkpoint_number, message)).await;

                            // By checking that the token was not cancelled we ensure that no
                            // further checkpoints will be sent to the workers.
                            while !token.is_cancelled() && !checkpoints.is_empty() && !idle.is_empty() {
                                let checkpoint = checkpoints.pop_front().unwrap();
                                let worker_id = idle.pop_first().unwrap();
                                if workers[worker_id].send(checkpoint).await.is_err() {
                                    // The worker channel closing is a sign we need to exit this inner loop.
                                    break;
                                }
                            }
                        }
                        WorkerStatus::Shutdown(worker_id) => {
                            // Track workers that have initiated shutdown.
                            workers_shutdown_signals.push(worker_id);
                        }
                    }
                }
                // Adding an if guard to this branch ensures that no checkpoints
                // will be sent to workers once the token has been cancelled.
                Some(checkpoint) = checkpoint_receiver.recv(), if !token.is_cancelled() => {
                    let sequence_number = checkpoint.checkpoint_summary.sequence_number;
                    if sequence_number < watermark {
                        continue;
                    }
                    self.worker
                        .preprocess_hook(checkpoint.clone())
                        .map_err(|err| IngestionError::CheckpointHookProcessing(err.to_string()))
                        .expect("failed to preprocess task");
                    if idle.is_empty() {
                        checkpoints.push_back(checkpoint);
                    } else {
                        let worker_id = idle.pop_first().unwrap();
                        // If worker channel is closed, put the checkpoint back in queue
                        // and continue - we still need to wait for all worker shutdown signals.
                        if let Err(send_error) = workers[worker_id].send(checkpoint).await {
                            checkpoints.push_front(send_error.0);
                        };
                    }
                }
            }
            // Once all workers have signaled completion, start the graceful shutdown
            // process.
            if workers_shutdown_signals.len() == self.concurrency {
                break self
                    .workers_graceful_shutdown(
                        workers_join_handles,
                        watermark_handle,
                        pool_status_sender,
                        watermark_sender,
                    )
                    .await;
            }
        }
    }

    /// Spawns workers based on `self.concurrency` to process checkpoints
    /// in parallel.
    fn spawn_workers(
        &self,
        progress_sender: mpsc::Sender<WorkerStatus<W::Message>>,
        token: CancellationToken,
    ) -> (Vec<mpsc::Sender<Arc<CheckpointData>>>, Vec<JoinHandle<()>>) {
        let mut worker_senders = Vec::with_capacity(self.concurrency);
        let mut workers_join_handles = Vec::with_capacity(self.concurrency);

        for worker_id in 0..self.concurrency {
            let (worker_sender, mut worker_recv) =
                mpsc::channel::<Arc<CheckpointData>>(MAX_CHECKPOINTS_IN_PROGRESS);
            let cloned_progress_sender = progress_sender.clone();
            let task_name = self.task_name.clone();
            worker_senders.push(worker_sender);

            let token = token.clone();

            let worker = self.worker.clone();
            let backoff = self.backoff.clone();
            let join_handle = spawn_monitored_task!(async move {
                loop {
                    tokio::select! {
                        // Once the token is cancelled, notify the main loop of the worker's shutdown.
                        _ = token.cancelled() => {
                            _ = cloned_progress_sender.send(WorkerStatus::Shutdown(worker_id)).await;
                            break
                        },
                        Some(checkpoint) = worker_recv.recv() => {
                            let sequence_number = checkpoint.checkpoint_summary.sequence_number;
                            info!("received checkpoint for processing {} for workflow {}", sequence_number, task_name);
                            let start_time = Instant::now();
                            let status = Self::process_checkpoint_with_retry(worker_id, &worker, checkpoint, reset_backoff(&backoff), &token).await;
                            let trigger_shutdown = matches!(status, WorkerStatus::Shutdown(_));
                            if cloned_progress_sender.send(status).await.is_err() || trigger_shutdown {
                                break;
                            }
                            info!(
                                "finished checkpoint processing {sequence_number} for workflow {task_name} in {:?}",
                                start_time.elapsed()
                            );
                        }
                    }
                }
            });
            // Keep all join handles to ensure all workers are terminated before exiting.
            workers_join_handles.push(join_handle);
        }
        (worker_senders, workers_join_handles)
    }

    /// Attempts to process a checkpoint with exponential backoff retries on
    /// failure.
    ///
    /// This function repeatedly calls the
    /// [`process_checkpoint`](Worker::process_checkpoint) method of the
    /// provided [`Worker`] until either:
    /// - The checkpoint processing succeeds, returning `WorkerStatus::Running`
    ///   with the processed message.
    /// - A cancellation signal is received via the [`CancellationToken`],
    ///   returning `WorkerStatus::Shutdown(<worker-id>)`.
    /// - All retry attempts are exhausted within backoff's maximum elapsed
    ///   time, causing a panic.
    ///
    /// # Retry Mechanism
    /// - Uses [`ExponentialBackoff`](backoff::ExponentialBackoff) to introduce
    ///   increasing delays between retry attempts.
    /// - Checks for cancellation both before and after each processing attempt.
    /// - If a cancellation signal is received during a backoff delay, the
    ///   function exits immediately with `WorkerStatus::Shutdown(<worker-id>)`.
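    ///
    /// As an illustration (the concrete delays depend on the
    /// [`ExponentialBackoff`] configured for the pool): with an initial
    /// interval of 1s and a multiplier of 2, failed attempts are retried
    /// roughly 1s, 2s, 4s, ... apart until the maximum elapsed time is hit.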
432 ///
433 /// # Panics:
434 /// - If all retry attempts are exhausted within the backoff's maximum
435 /// elapsed time, indicating a persistent failure.
436 async fn process_checkpoint_with_retry(
437 worker_id: WorkerID,
438 worker: &W,
439 checkpoint: Arc<CheckpointData>,
440 mut backoff: ExponentialBackoff,
441 token: &CancellationToken,
442 ) -> WorkerStatus<W::Message> {
443 let sequence_number = checkpoint.checkpoint_summary.sequence_number;
444 loop {
445 // check for cancellation before attempting processing.
446 if token.is_cancelled() {
447 return WorkerStatus::Shutdown(worker_id);
448 }
449 // attempt to process checkpoint.
450 match worker.process_checkpoint(checkpoint.clone()).await {
451 Ok(message) => return WorkerStatus::Running((worker_id, sequence_number, message)),
452 Err(err) => {
453 let err = IngestionError::CheckpointProcessing(err.to_string());
454 warn!(
455 "transient worker execution error {err:?} for checkpoint {sequence_number}"
456 );
457 // check for cancellation after failed processing.
458 if token.is_cancelled() {
459 return WorkerStatus::Shutdown(worker_id);
460 }
461 }
462 }
            // get next backoff duration or panic if max retries exceeded.
            let duration = backoff.next_backoff().unwrap_or_else(|| {
                panic!("max elapsed time exceeded: checkpoint processing failed for checkpoint {sequence_number}")
            });
            // if cancellation occurs during backoff wait, exit early with Shutdown.
            // Otherwise (if timeout expires), continue with the next retry attempt.
            if tokio::time::timeout(duration, token.cancelled())
                .await
                .is_ok()
            {
                return WorkerStatus::Shutdown(worker_id);
            }
        }
    }

    /// Spawns a task that tracks the progress of checkpoint processing,
    /// optionally with message reduction.
    ///
    /// This function spawns one of two types of tracking tasks:
    ///
    /// 1. Simple Watermark Tracking (when reducer = None):
    ///    - Reports watermark after processing each chunk.
    ///
    /// 2. Batch Processing (when reducer = Some):
    ///    - Reports progress only after successful batch commits.
    ///    - A batch is committed based on the
    ///      [`should_close_batch`](Reducer::should_close_batch) policy.
    fn spawn_watermark_tracking(
        &mut self,
        watermark: CheckpointSequenceNumber,
        watermark_receiver: mpsc::Receiver<(CheckpointSequenceNumber, W::Message)>,
        executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
        token: CancellationToken,
    ) -> JoinHandle<Result<(), IngestionError>> {
        let task_name = self.task_name.clone();
        let backoff = self.backoff.clone();
        if let Some(reducer) = self.reducer.take() {
            return spawn_monitored_task!(reduce::<W>(
                task_name,
                watermark,
                watermark_receiver,
                executor_progress_sender,
                reducer,
                backoff,
                token
            ));
        };
        spawn_monitored_task!(simple_watermark_tracking::<W>(
            task_name,
            watermark,
            watermark_receiver,
            executor_progress_sender
        ))
    }

    /// Starts the workers' graceful shutdown.
    ///
    /// - Awaits all worker handles.
    /// - Awaits the watermark tracking (reducer) handle.
    /// - Sends a `WorkerPoolStatus::Shutdown(<task-name>)` message notifying
    ///   external components that the worker pool has been shut down.
    async fn workers_graceful_shutdown(
        &self,
        workers_join_handles: Vec<JoinHandle<()>>,
        watermark_handle: JoinHandle<Result<(), IngestionError>>,
        executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
        watermark_sender: mpsc::Sender<(u64, <W as Worker>::Message)>,
    ) {
        for worker in workers_join_handles {
            _ = worker
                .await
                .inspect_err(|err| tracing::error!("worker task panicked: {err}"));
        }
        // by dropping the sender we make sure that the stream will be closed and the
        // watermark tracker task will exit its loop.
        drop(watermark_sender);
        _ = watermark_handle
            .await
            .inspect_err(|err| tracing::error!("watermark task panicked: {err}"));
        _ = executor_progress_sender
            .send(WorkerPoolStatus::Shutdown(self.task_name.clone()))
            .await;
        tracing::info!("Worker pool `{}` terminated gracefully", self.task_name);
    }
}

/// Tracks checkpoint progress without reduction logic.
///
/// This function maintains a watermark of processed checkpoints by:
/// 1. Receiving batches of progress status from workers.
/// 2. Processing them in sequence order.
/// 3. Reporting progress to the executor after each chunk from the stream.
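///
/// For example, with a starting watermark of 10, receiving messages for
/// checkpoints 12, 10 and 11 (in any order within a chunk) advances the
/// reported watermark to 13, while a lone message for 15 is held in the
/// `unprocessed` map until 13 and 14 have been seen.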
async fn simple_watermark_tracking<W: Worker>(
    task_name: String,
    mut current_checkpoint_number: CheckpointSequenceNumber,
    watermark_receiver: mpsc::Receiver<(CheckpointSequenceNumber, W::Message)>,
    executor_progress_sender: mpsc::Sender<WorkerPoolStatus>,
) -> IngestionResult<()> {
    // convert to a stream of MAX_CHECKPOINTS_IN_PROGRESS size. This way, each
    // iteration of the loop will process all ready messages.
    let mut stream =
        ReceiverStream::new(watermark_receiver).ready_chunks(MAX_CHECKPOINTS_IN_PROGRESS);
    // store unprocessed progress messages from workers.
    let mut unprocessed = HashMap::new();
    // track the next unprocessed checkpoint number for reporting progress
    // after each chunk of messages is received from the stream.
    let mut progress_update = None;

    while let Some(update_batch) = stream.next().await {
        unprocessed.extend(update_batch.into_iter());
        // Process messages sequentially based on checkpoint sequence number.
        // This ensures in-order processing and maintains progress integrity.
        while unprocessed.remove(&current_checkpoint_number).is_some() {
            current_checkpoint_number += 1;
            progress_update = Some(current_checkpoint_number);
        }
        // report progress update to executor.
        if let Some(watermark) = progress_update.take() {
            executor_progress_sender
                .send(WorkerPoolStatus::Running((task_name.clone(), watermark)))
                .await
                .map_err(|_| IngestionError::Channel("unable to send worker pool progress updates to executor, receiver half closed".into()))?;
        }
    }
    Ok(())
}