iota_data_ingestion_core/progress_store/
mod.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::HashMap,
7    fmt::{Debug, Display},
8};
9
10use async_trait::async_trait;
11use iota_types::messages_checkpoint::CheckpointSequenceNumber;
12mod file;
13pub use file::FileProgressStore;
14
15use crate::{IngestionError, IngestionResult};
16
17pub type ExecutorProgress = HashMap<String, CheckpointSequenceNumber>;
18
19/// A trait defining the interface for persistent storage of checkpoint
20/// progress.
21///
22/// This trait allows for loading and saving the progress of a task, represented
23/// by a `task_name` & `CheckpointSequenceNumber` as key value pairs.
24/// Implementations of this trait are responsible for persisting this progress
25/// across restarts or failures.
26#[async_trait]
27pub trait ProgressStore: Send {
28    type Error: Debug + Display;
29
30    /// Loads the last saved checkpoint sequence number for a given task.
31    async fn load(&mut self, task_name: String) -> Result<CheckpointSequenceNumber, Self::Error>;
32
33    /// Saves the current checkpoint sequence number for a given task.
34    async fn save(
35        &mut self,
36        task_name: String,
37        checkpoint_number: CheckpointSequenceNumber,
38    ) -> Result<(), Self::Error>;
39}
40
41pub struct ProgressStoreWrapper<P> {
42    progress_store: P,
43    pending_state: ExecutorProgress,
44}
45
46#[async_trait]
47impl<P: ProgressStore> ProgressStore for ProgressStoreWrapper<P> {
48    type Error = IngestionError;
49
50    async fn load(&mut self, task_name: String) -> Result<CheckpointSequenceNumber, Self::Error> {
51        let watermark = self
52            .progress_store
53            .load(task_name.clone())
54            .await
55            .map_err(|err| IngestionError::ProgressStore(err.to_string()))?;
56        self.pending_state.insert(task_name, watermark);
57        Ok(watermark)
58    }
59
60    async fn save(
61        &mut self,
62        task_name: String,
63        checkpoint_number: CheckpointSequenceNumber,
64    ) -> Result<(), Self::Error> {
65        self.progress_store
66            .save(task_name.clone(), checkpoint_number)
67            .await
68            .map_err(|err| IngestionError::ProgressStore(err.to_string()))?;
69        self.pending_state.insert(task_name, checkpoint_number);
70        Ok(())
71    }
72}
73
74impl<P: ProgressStore> ProgressStoreWrapper<P> {
75    pub fn new(progress_store: P) -> Self {
76        Self {
77            progress_store,
78            pending_state: HashMap::new(),
79        }
80    }
81
82    pub fn min_watermark(&self) -> IngestionResult<CheckpointSequenceNumber> {
83        self.pending_state
84            .values()
85            .min()
86            .cloned()
87            .ok_or(IngestionError::EmptyWorkerPool)
88    }
89
90    pub fn stats(&self) -> ExecutorProgress {
91        self.pending_state.clone()
92    }
93}
94
95/// A simple, in-memory progress store primarily used for unit testing.
96///
97/// # Note
98///
99/// Provides `save` and `load`, but the `save` is not persistent.
100///
101/// # Example
102/// ```rust
103/// use iota_data_ingestion_core::{ProgressStore, ShimProgressStore};
104///
105/// #[tokio::main]
106/// async fn main() {
107///     let mut store = ShimProgressStore(10);
108///     // will not save the data.
109///     store.save("task1".into(), 42).await.unwrap();
110///     // ignores the task_name argument.
111///     let checkpoint = store.load("task1".into()).await.unwrap();
112///     assert_eq!(checkpoint, 10);
113///     let checkpoint = store.load("task2".into()).await.unwrap();
114///     assert_eq!(checkpoint, 10);
115/// }
116/// ```
117pub struct ShimProgressStore(pub u64);
118
119#[async_trait]
120impl ProgressStore for ShimProgressStore {
121    type Error = IngestionError;
122
123    async fn load(&mut self, _: String) -> Result<CheckpointSequenceNumber, Self::Error> {
124        Ok(self.0)
125    }
126    async fn save(&mut self, _: String, _: CheckpointSequenceNumber) -> Result<(), Self::Error> {
127        Ok(())
128    }
129}