iota_data_ingestion/progress_store.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{str::FromStr, time::Duration};
6
7use anyhow::Result;
8use async_trait::async_trait;
9use aws_config::{BehaviorVersion, timeout::TimeoutConfig};
10use aws_sdk_dynamodb::{
11    Client,
12    config::{Credentials, Region},
13    error::SdkError,
14    types::AttributeValue,
15};
16use iota_data_ingestion_core::ProgressStore;
17use iota_types::messages_checkpoint::CheckpointSequenceNumber;
18
19pub struct DynamoDBProgressStore {
20    client: Client,
21    table_name: String,
22}
23
24impl DynamoDBProgressStore {
25    pub async fn new(
26        aws_access_key_id: &str,
27        aws_secret_access_key: &str,
28        aws_region: String,
29        table_name: String,
30    ) -> Self {
31        let credentials = Credentials::new(
32            aws_access_key_id,
33            aws_secret_access_key,
34            None,
35            None,
36            "dynamodb",
37        );
38        let timeout_config = TimeoutConfig::builder()
39            .operation_timeout(Duration::from_secs(3))
40            .operation_attempt_timeout(Duration::from_secs(10))
41            .connect_timeout(Duration::from_secs(3))
42            .build();
43        let aws_config = aws_config::defaults(BehaviorVersion::latest())
44            .credentials_provider(credentials)
45            .region(Region::new(aws_region))
46            .timeout_config(timeout_config)
47            .load()
48            .await;
49        let client = Client::new(&aws_config);
50        Self { client, table_name }
51    }
52}
53
54#[async_trait]
55impl ProgressStore for DynamoDBProgressStore {
56    type Error = anyhow::Error;
57
58    async fn load(&mut self, task_name: String) -> Result<CheckpointSequenceNumber, Self::Error> {
59        let item = self
60            .client
61            .get_item()
62            .table_name(self.table_name.clone())
63            .key("task_name", AttributeValue::S(task_name))
64            .send()
65            .await?;
66        if let Some(output) = item.item() {
67            if let AttributeValue::N(checkpoint_number) = &output["nstate"] {
68                return Ok(CheckpointSequenceNumber::from_str(checkpoint_number)?);
69            }
70        }
71        Ok(0)
72    }
73    async fn save(
74        &mut self,
75        task_name: String,
76        checkpoint_number: CheckpointSequenceNumber,
77    ) -> Result<(), Self::Error> {
78        let backoff = backoff::ExponentialBackoff::default();
79        backoff::future::retry(backoff, || async {
80            let result = self
81                .client
82                .update_item()
83                .table_name(self.table_name.clone())
84                .key("task_name", AttributeValue::S(task_name.clone()))
85                .update_expression("SET #nstate = :newState")
86                .condition_expression("#nstate < :newState")
87                .expression_attribute_names("#nstate", "nstate")
88                .expression_attribute_values(
89                    ":newState",
90                    AttributeValue::N(checkpoint_number.to_string()),
91                )
92                .send()
93                .await;
94            match result {
95                Ok(_) => Ok(()),
96                Err(SdkError::ServiceError(err))
97                    if err.err().is_conditional_check_failed_exception() =>
98                {
99                    Ok(())
100                }
101                Err(err) => Err(backoff::Error::transient(err)),
102            }
103        })
104        .await?;
105        Ok(())
106    }
107}