iota_data_ingestion/
progress_store.rs1use std::{str::FromStr, time::Duration};
6
7use anyhow::Result;
8use async_trait::async_trait;
9use aws_config::{BehaviorVersion, timeout::TimeoutConfig};
10use aws_sdk_dynamodb::{
11 Client,
12 config::{Credentials, Region},
13 error::SdkError,
14 types::AttributeValue,
15};
16use iota_data_ingestion_core::ProgressStore;
17use iota_types::messages_checkpoint::CheckpointSequenceNumber;
18
19pub struct DynamoDBProgressStore {
20 client: Client,
21 table_name: String,
22}
23
24impl DynamoDBProgressStore {
25 pub async fn new(
26 aws_access_key_id: &str,
27 aws_secret_access_key: &str,
28 aws_region: String,
29 table_name: String,
30 ) -> Self {
31 let credentials = Credentials::new(
32 aws_access_key_id,
33 aws_secret_access_key,
34 None,
35 None,
36 "dynamodb",
37 );
38 let timeout_config = TimeoutConfig::builder()
39 .operation_timeout(Duration::from_secs(3))
40 .operation_attempt_timeout(Duration::from_secs(10))
41 .connect_timeout(Duration::from_secs(3))
42 .build();
43 let aws_config = aws_config::defaults(BehaviorVersion::latest())
44 .credentials_provider(credentials)
45 .region(Region::new(aws_region))
46 .timeout_config(timeout_config)
47 .load()
48 .await;
49 let client = Client::new(&aws_config);
50 Self { client, table_name }
51 }
52}
53
54#[async_trait]
55impl ProgressStore for DynamoDBProgressStore {
56 type Error = anyhow::Error;
57
58 async fn load(&mut self, task_name: String) -> Result<CheckpointSequenceNumber, Self::Error> {
59 let item = self
60 .client
61 .get_item()
62 .table_name(self.table_name.clone())
63 .key("task_name", AttributeValue::S(task_name))
64 .send()
65 .await?;
66 if let Some(output) = item.item() {
67 if let AttributeValue::N(checkpoint_number) = &output["nstate"] {
68 return Ok(CheckpointSequenceNumber::from_str(checkpoint_number)?);
69 }
70 }
71 Ok(0)
72 }
73 async fn save(
74 &mut self,
75 task_name: String,
76 checkpoint_number: CheckpointSequenceNumber,
77 ) -> Result<(), Self::Error> {
78 let backoff = backoff::ExponentialBackoff::default();
79 backoff::future::retry(backoff, || async {
80 let result = self
81 .client
82 .update_item()
83 .table_name(self.table_name.clone())
84 .key("task_name", AttributeValue::S(task_name.clone()))
85 .update_expression("SET #nstate = :newState")
86 .condition_expression("#nstate < :newState")
87 .expression_attribute_names("#nstate", "nstate")
88 .expression_attribute_values(
89 ":newState",
90 AttributeValue::N(checkpoint_number.to_string()),
91 )
92 .send()
93 .await;
94 match result {
95 Ok(_) => Ok(()),
96 Err(SdkError::ServiceError(err))
97 if err.err().is_conditional_check_failed_exception() =>
98 {
99 Ok(())
100 }
101 Err(err) => Err(backoff::Error::transient(err)),
102 }
103 })
104 .await?;
105 Ok(())
106 }
107}