iota_storage/
lib.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

#![allow(dead_code)]

use std::{
    fs,
    fs::File,
    io,
    io::{BufReader, Read, Write},
    ops::Range,
    path::{Path, PathBuf},
    sync::{
        Arc,
        atomic::{AtomicU64, Ordering},
    },
};

use anyhow::{Result, anyhow};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use bytes::{Buf, Bytes};
use fastcrypto::hash::{HashFunction, Sha3_256};
use futures::StreamExt;
use iota_types::{
    committee::Committee,
    messages_checkpoint::{
        CertifiedCheckpointSummary, CheckpointSequenceNumber, VerifiedCheckpoint,
    },
    storage::WriteStore,
};
use itertools::Itertools;
use num_enum::{IntoPrimitive, TryFromPrimitive};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use tracing::debug;

use crate::blob::BlobIter;

pub mod blob;
pub mod http_key_value_store;
pub mod key_value_store;
pub mod key_value_store_metrics;
pub mod mutex_table;
pub mod object_store;
pub mod package_object_cache;
pub mod sharded_lru;
pub mod write_path_pending_tx_log;

pub const SHA3_BYTES: usize = 32;

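/// Encoding used for the entries of a storage file; only the `Blob` format is
/// currently defined.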
#[derive(
    Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, TryFromPrimitive, IntoPrimitive,
)]
#[repr(u8)]
pub enum StorageFormat {
    Blob = 0,
}

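/// Compression scheme applied to a storage file: none or zstd.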
#[derive(
    Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, TryFromPrimitive, IntoPrimitive,
)]
#[repr(u8)]
pub enum FileCompression {
    None = 0,
    Zstd,
}

impl FileCompression {
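    /// Streams `reader` through a zstd encoder (level 1) into `writer`.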
    pub fn zstd_compress<R: Read, W: Write>(reader: &mut R, writer: &mut W) -> io::Result<()> {
        // TODO: Add zstd compression level as function argument
        let mut encoder = zstd::Encoder::new(writer, 1)?;
        io::copy(reader, &mut encoder)?;
        encoder.finish()?;
        Ok(())
    }
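
    /// Compresses the file at `source` in place: the output is written to a
    /// temporary sibling file which then replaces the original. A no-op for
    /// `FileCompression::None`.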
    pub fn compress(&self, source: &std::path::Path) -> io::Result<()> {
        match self {
            FileCompression::Zstd => {
                let mut input = File::open(source)?;
                let tmp_file_name = source.with_extension("tmp");
                let mut output = File::create(&tmp_file_name)?;
                Self::zstd_compress(&mut input, &mut output)?;
                fs::rename(tmp_file_name, source)?;
            }
            FileCompression::None => {}
        }
        Ok(())
    }
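
    /// Opens the file at `source` and returns a reader over its decompressed
    /// contents.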
    pub fn decompress(&self, source: &PathBuf) -> Result<Box<dyn Read>> {
        let file = File::open(source)?;
        let res: Box<dyn Read> = match self {
            FileCompression::Zstd => Box::new(zstd::stream::Decoder::new(file)?),
            FileCompression::None => Box::new(BufReader::new(file)),
        };
        Ok(res)
    }
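
    /// Returns a reader over the decompressed contents of the given in-memory
    /// `bytes`.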
    pub fn bytes_decompress(&self, bytes: Bytes) -> Result<Box<dyn Read>> {
        let res: Box<dyn Read> = match self {
            FileCompression::Zstd => Box::new(zstd::stream::Decoder::new(bytes.reader())?),
            FileCompression::None => Box::new(BufReader::new(bytes.reader())),
        };
        Ok(res)
    }
}

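/// Computes the SHA3-256 digest of the given bytes.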
pub fn compute_sha3_checksum_for_bytes(bytes: Bytes) -> Result<[u8; 32]> {
    let mut hasher = Sha3_256::default();
    io::copy(&mut bytes.reader(), &mut hasher)?;
    Ok(hasher.finalize().digest)
}

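/// Computes the SHA3-256 digest of the contents of an already opened file.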
pub fn compute_sha3_checksum_for_file(file: &mut File) -> Result<[u8; 32]> {
    let mut hasher = Sha3_256::default();
    io::copy(file, &mut hasher)?;
    Ok(hasher.finalize().digest)
}

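/// Opens the file at `source` and computes the SHA3-256 digest of its contents.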
pub fn compute_sha3_checksum(source: &std::path::Path) -> Result<[u8; 32]> {
    let mut file = fs::File::open(source)?;
    compute_sha3_checksum_for_file(&mut file)
}

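/// Copies a storage file header (magic, storage format and compression bytes)
/// from `reader` to `writer` and, when the header requests zstd, compresses
/// the remaining body into `writer`.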
pub fn compress<R: Read, W: Write>(reader: &mut R, writer: &mut W) -> Result<()> {
    let magic = reader.read_u32::<BigEndian>()?;
    writer.write_u32::<BigEndian>(magic)?;
    let storage_format = reader.read_u8()?;
    writer.write_u8(storage_format)?;
    let file_compression = FileCompression::try_from(reader.read_u8()?)?;
    writer.write_u8(file_compression.into())?;
    match file_compression {
        FileCompression::Zstd => {
            FileCompression::zstd_compress(reader, writer)?;
        }
        FileCompression::None => {}
    }
    Ok(())
}

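/// Validates the magic bytes of a storage file and returns a reader over its
/// decompressed body together with the declared [`StorageFormat`].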
pub fn read<R: Read + 'static>(
    expected_magic: u32,
    mut reader: R,
) -> Result<(Box<dyn Read>, StorageFormat)> {
    let magic = reader.read_u32::<BigEndian>()?;
    if magic != expected_magic {
        Err(anyhow!(
            "Unexpected magic bytes in file: {:?}, expected: {:?}",
            magic,
            expected_magic
        ))
    } else {
        let storage_format = StorageFormat::try_from(reader.read_u8()?)?;
        let file_compression = FileCompression::try_from(reader.read_u8()?)?;
        let reader: Box<dyn Read> = match file_compression {
            FileCompression::Zstd => Box::new(zstd::stream::Decoder::new(reader)?),
            FileCompression::None => Box::new(BufReader::new(reader)),
        };
        Ok((reader, storage_format))
    }
}

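/// Builds an iterator of deserialized `T` values over a storage file with the
/// given magic, dispatching on the storage format declared in its header.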
pub fn make_iterator<T: DeserializeOwned, R: Read + 'static>(
    expected_magic: u32,
    reader: R,
) -> Result<impl Iterator<Item = T>> {
    let (reader, storage_format) = read(expected_magic, reader)?;
    match storage_format {
        StorageFormat::Blob => Ok(BlobIter::new(reader)),
    }
}

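/// Verifies that `checkpoint` directly extends `current` (matching previous
/// digest and a same-or-next epoch) and carries valid authority signatures
/// from `committee`; the rejected summary is returned on failure.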
#[expect(clippy::result_large_err)]
pub fn verify_checkpoint_with_committee(
    committee: Arc<Committee>,
    current: &VerifiedCheckpoint,
    checkpoint: CertifiedCheckpointSummary,
) -> Result<VerifiedCheckpoint, CertifiedCheckpointSummary> {
    assert_eq!(
        *checkpoint.sequence_number(),
        current.sequence_number().checked_add(1).unwrap()
    );

    if Some(*current.digest()) != checkpoint.previous_digest {
        debug!(
            current_checkpoint_seq = current.sequence_number(),
            current_digest = %current.digest(),
            checkpoint_seq = checkpoint.sequence_number(),
            checkpoint_digest = %checkpoint.digest(),
            checkpoint_previous_digest = ?checkpoint.previous_digest,
            "checkpoint not on same chain"
        );
        return Err(checkpoint);
    }

    let current_epoch = current.epoch();
    if checkpoint.epoch() != current_epoch
        && checkpoint.epoch() != current_epoch.checked_add(1).unwrap()
    {
        debug!(
            checkpoint_seq = checkpoint.sequence_number(),
            checkpoint_epoch = checkpoint.epoch(),
            current_checkpoint_seq = current.sequence_number(),
            current_epoch = current_epoch,
            "cannot verify checkpoint with too high of an epoch",
        );
        return Err(checkpoint);
    }

    if checkpoint.epoch() == current_epoch.checked_add(1).unwrap()
        && current.next_epoch_committee().is_none()
    {
        debug!(
            checkpoint_seq = checkpoint.sequence_number(),
            checkpoint_epoch = checkpoint.epoch(),
            current_checkpoint_seq = current.sequence_number(),
            current_epoch = current_epoch,
            "next checkpoint claims to be from the next epoch but the latest verified \
            checkpoint does not indicate that it is the last checkpoint of an epoch"
        );
        return Err(checkpoint);
    }

    checkpoint
        .verify_authority_signatures(&committee)
        .map_err(|e| {
            debug!("error verifying checkpoint: {e}");
            checkpoint.clone()
        })?;
    Ok(VerifiedCheckpoint::new_unchecked(checkpoint))
}

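/// Looks up the committee for the checkpoint's epoch in `store` and delegates
/// to [`verify_checkpoint_with_committee`].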
#[expect(clippy::result_large_err)]
pub fn verify_checkpoint<S>(
    current: &VerifiedCheckpoint,
    store: S,
    checkpoint: CertifiedCheckpointSummary,
) -> Result<VerifiedCheckpoint, CertifiedCheckpointSummary>
where
    S: WriteStore,
{
    let committee = store.get_committee(checkpoint.epoch()).unwrap_or_else(|| {
        panic!(
            "BUG: should have committee for epoch {} before we try to verify checkpoint {}",
            checkpoint.epoch(),
            checkpoint.sequence_number()
        )
    });

    verify_checkpoint_with_committee(committee, current, checkpoint)
}

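/// Verifies a contiguous range of checkpoints that are already present in
/// `store`, checking consecutive pairs with up to `max_concurrency` tasks in
/// parallel, incrementing `checkpoint_counter` once per verified pair, and
/// finally recording the end of the range as the highest verified checkpoint.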
pub async fn verify_checkpoint_range<S>(
    checkpoint_range: Range<CheckpointSequenceNumber>,
    store: S,
    checkpoint_counter: Arc<AtomicU64>,
    max_concurrency: usize,
) where
    S: WriteStore + Clone,
{
    let range_clone = checkpoint_range.clone();
    futures::stream::iter(range_clone.into_iter().tuple_windows())
        .map(|(a, b)| {
            let current = store
                .get_checkpoint_by_sequence_number(a)
                .unwrap_or_else(|| {
                    panic!("Checkpoint {a} should exist in store after summary sync but does not");
                });
            let next = store
                .get_checkpoint_by_sequence_number(b)
                .unwrap_or_else(|| {
                    panic!("Checkpoint {b} should exist in store after summary sync but does not");
                });

            let committee = store.get_committee(next.epoch()).unwrap_or_else(|| {
                panic!(
                    "BUG: should have committee for epoch {} before we try to verify checkpoint {}",
                    next.epoch(),
                    next.sequence_number()
                )
            });
            tokio::spawn(async move {
                verify_checkpoint_with_committee(committee, &current, next.clone().into())
                    .expect("Checkpoint verification failed");
            })
        })
        .buffer_unordered(max_concurrency)
        .for_each(|result| {
            result.expect("Checkpoint verification task failed");
            checkpoint_counter.fetch_add(1, Ordering::Relaxed);
            futures::future::ready(())
        })
        .await;
    let last = checkpoint_range
        .last()
        .expect("Received empty checkpoint range");
    let final_checkpoint = store
        .get_checkpoint_by_sequence_number(last)
        .expect("Expected end of checkpoint range to exist in store");
    store
        .try_update_highest_verified_checkpoint(&final_checkpoint)
        .expect("Failed to update highest verified checkpoint");
}

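/// Recursively hard-links every file under `src` into the corresponding path
/// under `dst`, recreating the directory structure.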
fn hard_link(src: impl AsRef<Path>, dst: impl AsRef<Path>) -> io::Result<()> {
    fs::create_dir_all(&dst)?;
    for entry in fs::read_dir(src)? {
        let entry = entry?;
        let ty = entry.file_type()?;
        if ty.is_dir() {
            hard_link(entry.path(), dst.as_ref().join(entry.file_name()))?;
        } else {
            fs::hard_link(entry.path(), dst.as_ref().join(entry.file_name()))?;
        }
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use tempfile::TempDir;
    use typed_store::{
        Map, reopen,
        rocks::{DBMap, MetricConf, ReadWriteOptions, open_cf},
    };

    use crate::hard_link;

    #[tokio::test]
    pub async fn test_db_hard_link() -> anyhow::Result<()> {
        let input = TempDir::new()?;
        let input_path = input.path();

        let output = TempDir::new()?;
        let output_path = output.path();

        const FIRST_CF: &str = "First_CF";
        const SECOND_CF: &str = "Second_CF";

        let db_a = open_cf(
            input_path,
            None,
            MetricConf::new("test_db_hard_link_1"),
            &[FIRST_CF, SECOND_CF],
        )
        .unwrap();

        let (db_map_1, db_map_2) = reopen!(&db_a, FIRST_CF;<i32, String>, SECOND_CF;<i32, String>);

        let keys_vals_cf1 = (1..100).map(|i| (i, i.to_string()));
        let keys_vals_cf2 = (1..100).map(|i| (i, i.to_string()));

        assert!(db_map_1.multi_insert(keys_vals_cf1).is_ok());
        assert!(db_map_2.multi_insert(keys_vals_cf2).is_ok());

        // set up db hard link
        hard_link(input_path, output_path)?;
        let db_b = open_cf(
            output_path,
            None,
            MetricConf::new("test_db_hard_link_2"),
            &[FIRST_CF, SECOND_CF],
        )
        .unwrap();

        let (db_map_1, db_map_2) = reopen!(&db_b, FIRST_CF;<i32, String>, SECOND_CF;<i32, String>);
        for i in 1..100 {
            assert!(
                db_map_1
                    .contains_key(&i)
                    .expect("Failed to call contains key")
            );
            assert!(
                db_map_2
                    .contains_key(&i)
                    .expect("Failed to call contains key")
            );
        }

        Ok(())
    }
}