iota_storage/
lib.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

#![allow(dead_code)]

use std::{
    fs,
    fs::File,
    io,
    io::{BufReader, Read, Write},
    ops::Range,
    path::{Path, PathBuf},
    sync::{
        Arc,
        atomic::{AtomicU64, Ordering},
    },
};

use anyhow::{Result, anyhow};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use bytes::{Buf, Bytes};
use fastcrypto::hash::{HashFunction, Sha3_256};
use futures::StreamExt;
use iota_types::{
    committee::Committee,
    messages_checkpoint::{
        CertifiedCheckpointSummary, CheckpointSequenceNumber, VerifiedCheckpoint,
    },
    storage::WriteStore,
};
use itertools::Itertools;
use num_enum::{IntoPrimitive, TryFromPrimitive};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use tracing::debug;

use crate::blob::BlobIter;

pub mod blob;
pub mod http_key_value_store;
pub mod key_value_store;
pub mod key_value_store_metrics;
pub mod mutex_table;
pub mod object_store;
pub mod package_object_cache;
pub mod sharded_lru;
pub mod write_path_pending_tx_log;

pub const SHA3_BYTES: usize = 32;

#[derive(
    Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, TryFromPrimitive, IntoPrimitive,
)]
#[repr(u8)]
pub enum StorageFormat {
    Blob = 0,
}

#[derive(
    Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, TryFromPrimitive, IntoPrimitive,
)]
#[repr(u8)]
pub enum FileCompression {
    None = 0,
    Zstd,
}

impl FileCompression {
    pub fn zstd_compress<R: Read, W: Write>(reader: &mut R, writer: &mut W) -> io::Result<()> {
        // TODO: Add zstd compression level as function argument
        let mut encoder = zstd::Encoder::new(writer, 1)?;
        io::copy(reader, &mut encoder)?;
        encoder.finish()?;
        Ok(())
    }
    pub fn compress(&self, source: &std::path::Path) -> io::Result<()> {
        match self {
            FileCompression::Zstd => {
                let mut input = File::open(source)?;
                let tmp_file_name = source.with_extension("tmp");
                let mut output = File::create(&tmp_file_name)?;
                Self::zstd_compress(&mut input, &mut output)?;
                fs::rename(tmp_file_name, source)?;
            }
            FileCompression::None => {}
        }
        Ok(())
    }
    pub fn decompress(&self, source: &PathBuf) -> Result<Box<dyn Read>> {
        let file = File::open(source)?;
        let res: Box<dyn Read> = match self {
            FileCompression::Zstd => Box::new(zstd::stream::Decoder::new(file)?),
            FileCompression::None => Box::new(BufReader::new(file)),
        };
        Ok(res)
    }
    pub fn bytes_decompress(&self, bytes: Bytes) -> Result<Box<dyn Read>> {
        let res: Box<dyn Read> = match self {
            FileCompression::Zstd => Box::new(zstd::stream::Decoder::new(bytes.reader())?),
            FileCompression::None => Box::new(BufReader::new(bytes.reader())),
        };
        Ok(res)
    }
}
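
// A minimal usage sketch for `FileCompression` (the path below is hypothetical and
// only illustrates the compress-in-place / decompress round trip):
//
//     let compression = FileCompression::Zstd;
//     compression.compress(Path::new("checkpoints.blob"))?;
//     let mut payload = Vec::new();
//     compression
//         .decompress(&PathBuf::from("checkpoints.blob"))?
//         .read_to_end(&mut payload)?;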

/// Computes the SHA3-256 digest of `bytes`.
pub fn compute_sha3_checksum_for_bytes(bytes: Bytes) -> Result<[u8; 32]> {
    let mut hasher = Sha3_256::default();
    io::copy(&mut bytes.reader(), &mut hasher)?;
    Ok(hasher.finalize().digest)
}

/// Computes the SHA3-256 digest of `file`, reading from its current position to EOF.
pub fn compute_sha3_checksum_for_file(file: &mut File) -> Result<[u8; 32]> {
    let mut hasher = Sha3_256::default();
    io::copy(file, &mut hasher)?;
    Ok(hasher.finalize().digest)
}

/// Opens the file at `source` and computes its SHA3-256 digest.
pub fn compute_sha3_checksum(source: &std::path::Path) -> Result<[u8; 32]> {
    let mut file = fs::File::open(source)?;
    compute_sha3_checksum_for_file(&mut file)
}
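
// A minimal sketch of checksumming a file on disk (the path is hypothetical, for
// illustration only):
//
//     let digest: [u8; 32] = compute_sha3_checksum(Path::new("epoch_0/1.chk"))?;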

/// Copies the file header (magic, storage format, compression flag) from `reader` to
/// `writer` verbatim, then compresses the remaining payload according to that flag.
pub fn compress<R: Read, W: Write>(reader: &mut R, writer: &mut W) -> Result<()> {
    let magic = reader.read_u32::<BigEndian>()?;
    writer.write_u32::<BigEndian>(magic)?;
    let storage_format = reader.read_u8()?;
    writer.write_u8(storage_format)?;
    let file_compression = FileCompression::try_from(reader.read_u8()?)?;
    writer.write_u8(file_compression.into())?;
    match file_compression {
        FileCompression::Zstd => {
            FileCompression::zstd_compress(reader, writer)?;
        }
        FileCompression::None => {}
    }
    Ok(())
}

/// Checks the magic number at the start of `reader`, then returns a reader over the
/// decompressed payload together with its [`StorageFormat`].
pub fn read<R: Read + 'static>(
    expected_magic: u32,
    mut reader: R,
) -> Result<(Box<dyn Read>, StorageFormat)> {
    let magic = reader.read_u32::<BigEndian>()?;
    if magic != expected_magic {
        Err(anyhow!(
            "Unexpected magic bytes in file: {:?}, expected: {:?}",
            magic,
            expected_magic
        ))
    } else {
        let storage_format = StorageFormat::try_from(reader.read_u8()?)?;
        let file_compression = FileCompression::try_from(reader.read_u8()?)?;
        let reader: Box<dyn Read> = match file_compression {
            FileCompression::Zstd => Box::new(zstd::stream::Decoder::new(reader)?),
            FileCompression::None => Box::new(BufReader::new(reader)),
        };
        Ok((reader, storage_format))
    }
}
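
// Files using this header layout begin with a 4-byte big-endian magic number, one
// byte identifying the `StorageFormat`, and one byte identifying the
// `FileCompression`, followed by the (possibly compressed) payload. A minimal read
// sketch (`EXAMPLE_MAGIC` and the path are hypothetical, not constants defined here):
//
//     const EXAMPLE_MAGIC: u32 = 0x00C4_E700;
//     let file = File::open("checkpoints.chk")?;
//     let (payload_reader, format) = read(EXAMPLE_MAGIC, BufReader::new(file))?;
//     assert_eq!(format, StorageFormat::Blob);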

/// Reads the file header via [`read`] and returns an iterator over the `T` values
/// encoded in the payload.
pub fn make_iterator<T: DeserializeOwned, R: Read + 'static>(
    expected_magic: u32,
    reader: R,
) -> Result<impl Iterator<Item = T>> {
    let (reader, storage_format) = read(expected_magic, reader)?;
    match storage_format {
        StorageFormat::Blob => Ok(BlobIter::new(reader)),
    }
}

/// Verifies that `checkpoint` is the direct successor of `current` and carries valid
/// authority signatures for `committee`; on failure, the unverified summary is
/// returned as the error so the caller can retry or discard it.
#[expect(clippy::result_large_err)]
pub fn verify_checkpoint_with_committee(
    committee: Arc<Committee>,
    current: &VerifiedCheckpoint,
    checkpoint: CertifiedCheckpointSummary,
) -> Result<VerifiedCheckpoint, CertifiedCheckpointSummary> {
    assert_eq!(
        *checkpoint.sequence_number(),
        current.sequence_number().checked_add(1).unwrap()
    );

    if Some(*current.digest()) != checkpoint.previous_digest {
        debug!(
            current_checkpoint_seq = current.sequence_number(),
            current_digest =% current.digest(),
            checkpoint_seq = checkpoint.sequence_number(),
            checkpoint_digest =% checkpoint.digest(),
            checkpoint_previous_digest =? checkpoint.previous_digest,
            "checkpoint not on same chain"
        );
        return Err(checkpoint);
    }

    let current_epoch = current.epoch();
    if checkpoint.epoch() != current_epoch
        && checkpoint.epoch() != current_epoch.checked_add(1).unwrap()
    {
        debug!(
            checkpoint_seq = checkpoint.sequence_number(),
            checkpoint_epoch = checkpoint.epoch(),
            current_checkpoint_seq = current.sequence_number(),
            current_epoch = current_epoch,
            "cannot verify checkpoint with too high of an epoch",
        );
        return Err(checkpoint);
    }

    if checkpoint.epoch() == current_epoch.checked_add(1).unwrap()
        && current.next_epoch_committee().is_none()
    {
        debug!(
            checkpoint_seq = checkpoint.sequence_number(),
            checkpoint_epoch = checkpoint.epoch(),
            current_checkpoint_seq = current.sequence_number(),
            current_epoch = current_epoch,
            "next checkpoint claims to be from the next epoch but the latest verified \
            checkpoint does not indicate that it is the last checkpoint of an epoch"
        );
        return Err(checkpoint);
    }

    checkpoint
        .verify_authority_signatures(&committee)
        .map_err(|e| {
            debug!("error verifying checkpoint: {e}");
            checkpoint.clone()
        })?;
    Ok(VerifiedCheckpoint::new_unchecked(checkpoint))
}

/// Looks up the committee for `checkpoint`'s epoch in `store` and delegates to
/// [`verify_checkpoint_with_committee`].
#[expect(clippy::result_large_err)]
pub fn verify_checkpoint<S>(
    current: &VerifiedCheckpoint,
    store: S,
    checkpoint: CertifiedCheckpointSummary,
) -> Result<VerifiedCheckpoint, CertifiedCheckpointSummary>
where
    S: WriteStore,
{
    let committee = store
        .get_committee(checkpoint.epoch())
        .expect("store operation should not fail")
        .unwrap_or_else(|| {
            panic!(
                "BUG: should have committee for epoch {} before we try to verify checkpoint {}",
                checkpoint.epoch(),
                checkpoint.sequence_number()
            )
        });

    verify_checkpoint_with_committee(committee, current, checkpoint)
}
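
// A minimal verification sketch; `store`, `highest_verified`, and `unverified_summary`
// are assumed to come from the surrounding sync logic (error handling is elided):
//
//     let verified = verify_checkpoint(&highest_verified, &store, unverified_summary)
//         .expect("checkpoint failed verification");
//     store
//         .update_highest_verified_checkpoint(&verified)
//         .expect("store operation should not fail");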

/// Verifies a contiguous range of checkpoint summaries that are already present in
/// `store`, checking up to `max_concurrency` consecutive pairs in parallel, then
/// records the end of the range as the highest verified checkpoint.
pub async fn verify_checkpoint_range<S>(
    checkpoint_range: Range<CheckpointSequenceNumber>,
    store: S,
    checkpoint_counter: Arc<AtomicU64>,
    max_concurrency: usize,
) where
    S: WriteStore + Clone,
{
    let range_clone = checkpoint_range.clone();
    futures::stream::iter(range_clone.into_iter().tuple_windows())
        .map(|(a, b)| {
            let current = store
                .get_checkpoint_by_sequence_number(a)
                .expect("store operation should not fail")
                .unwrap_or_else(|| {
                    panic!(
                        "Checkpoint {a} should exist in store after summary sync but does not"
                    );
                });
            let next = store
                .get_checkpoint_by_sequence_number(b)
                .expect("store operation should not fail")
                .unwrap_or_else(|| {
                    panic!(
                        "Checkpoint {b} should exist in store after summary sync but does not"
                    );
                });
            let committee = store
                .get_committee(next.epoch())
                .expect("store operation should not fail")
                .unwrap_or_else(|| {
                    panic!(
                        "BUG: should have committee for epoch {} before we try to verify checkpoint {}",
                        next.epoch(),
                        next.sequence_number()
                    )
                });
            tokio::spawn(async move {
                verify_checkpoint_with_committee(committee, &current, next.clone().into())
                    .expect("Checkpoint verification failed");
            })
        })
        .buffer_unordered(max_concurrency)
        .for_each(|result| {
            result.expect("Checkpoint verification task failed");
            checkpoint_counter.fetch_add(1, Ordering::Relaxed);
            futures::future::ready(())
        })
        .await;
    let last = checkpoint_range
        .last()
        .expect("Received empty checkpoint range");
    let final_checkpoint = store
        .get_checkpoint_by_sequence_number(last)
        .expect("Failed to fetch checkpoint")
        .expect("Expected end of checkpoint range to exist in store");
    store
        .update_highest_verified_checkpoint(&final_checkpoint)
        .expect("Failed to update highest verified checkpoint");
}

/// Recursively hard-links every file under `src` into the corresponding location
/// under `dst`, creating directories as needed.
fn hard_link(src: impl AsRef<Path>, dst: impl AsRef<Path>) -> io::Result<()> {
    fs::create_dir_all(&dst)?;
    for entry in fs::read_dir(src)? {
        let entry = entry?;
        let ty = entry.file_type()?;
        if ty.is_dir() {
            hard_link(entry.path(), dst.as_ref().join(entry.file_name()))?;
        } else {
            fs::hard_link(entry.path(), dst.as_ref().join(entry.file_name()))?;
        }
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use tempfile::TempDir;
    use typed_store::{
        Map, reopen,
        rocks::{DBMap, MetricConf, ReadWriteOptions, open_cf},
    };

    use crate::hard_link;

    #[tokio::test]
    pub async fn test_db_hard_link() -> anyhow::Result<()> {
        let input = TempDir::new()?;
        let input_path = input.path();

        let output = TempDir::new()?;
        let output_path = output.path();

        const FIRST_CF: &str = "First_CF";
        const SECOND_CF: &str = "Second_CF";

        let db_a = open_cf(
            input_path,
            None,
            MetricConf::new("test_db_hard_link_1"),
            &[FIRST_CF, SECOND_CF],
        )
        .unwrap();

        let (db_map_1, db_map_2) = reopen!(&db_a, FIRST_CF;<i32, String>, SECOND_CF;<i32, String>);

        let keys_vals_cf1 = (1..100).map(|i| (i, i.to_string()));
        let keys_vals_cf2 = (1..100).map(|i| (i, i.to_string()));

        assert!(db_map_1.multi_insert(keys_vals_cf1).is_ok());
        assert!(db_map_2.multi_insert(keys_vals_cf2).is_ok());

        // set up db hard link
        hard_link(input_path, output_path)?;
        let db_b = open_cf(
            output_path,
            None,
            MetricConf::new("test_db_hard_link_2"),
            &[FIRST_CF, SECOND_CF],
        )
        .unwrap();

        let (db_map_1, db_map_2) = reopen!(&db_b, FIRST_CF;<i32, String>, SECOND_CF;<i32, String>);
        for i in 1..100 {
            assert!(
                db_map_1
                    .contains_key(&i)
                    .expect("Failed to call contains key")
            );
            assert!(
                db_map_2
                    .contains_key(&i)
                    .expect("Failed to call contains key")
            );
        }

        Ok(())
    }
}