typed_store/database.rs

// Copyright (c) 2026 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use std::{
    borrow::Borrow,
    marker::PhantomData,
    ops::{Bound, Deref, RangeBounds},
    path::Path,
    sync::Arc,
    time::Duration,
};

use fastcrypto::hash::{Digest, HashFunction};
use iota_common::debug_fatal;
use iota_macros::fail_point;
use prometheus::{Histogram, HistogramTimer};
use rocksdb::{DBPinnableSlice, Error, LiveFile, ReadOptions, WriteBatch, checkpoint::Checkpoint};
use serde::{Serialize, de::DeserializeOwned};
use tokio::sync::oneshot;
use tracing::{debug, error, instrument, warn};
use typed_store_error::TypedStoreError;

use crate::{
    DbIterator,
    memstore::{InMemoryBatch, InMemoryDB},
    metrics::{DBMetrics, RocksDBPerfContext, SamplingInterval},
    rocks::{
        RocksDB,
        errors::{typed_store_err_from_bcs_err, typed_store_err_from_rocks_err},
        options::ReadWriteOptions,
        rocks_cf, rocks_util,
        safe_iter::{SafeIter, SafeRevIter},
    },
    traits::{Map, TableSummary},
    util::{be_fix_int_ser, iterator_bounds, iterator_bounds_with_range},
};

#[derive(Clone)]
pub(crate) enum ColumnFamily {
    Rocks(String),
    #[allow(dead_code)]
    InMemory(String),
}

impl std::fmt::Debug for ColumnFamily {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ColumnFamily::Rocks(name) => write!(f, "RocksDB cf: {}", name),
            ColumnFamily::InMemory(name) => write!(f, "InMemory cf: {}", name),
        }
    }
}

impl ColumnFamily {
    pub(crate) fn name(&self) -> &str {
        match self {
            ColumnFamily::Rocks(name) => name,
            ColumnFamily::InMemory(name) => name,
        }
    }

    pub(crate) fn rocks_cf<'a>(
        &self,
        rocks_db: &'a RocksDB,
    ) -> Arc<rocksdb::BoundColumnFamily<'a>> {
        match &self {
            ColumnFamily::Rocks(name) => rocks_db
                .underlying
                .cf_handle(name)
                .expect("Map-keying column family should have been checked at DB creation"),
            _ => unreachable!("invariant is checked by the caller"),
        }
    }
}

pub(crate) enum Storage {
    Rocks(RocksDB),
    #[allow(dead_code)]
    InMemory(InMemoryDB),
}

impl std::fmt::Debug for Storage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Storage::Rocks(db) => write!(f, "RocksDB Storage {:?}", db),
            Storage::InMemory(db) => write!(f, "InMemoryDB Storage {:?}", db),
        }
    }
}

pub(crate) enum GetResult<'a> {
    Rocks(DBPinnableSlice<'a>),
    InMemory(Vec<u8>),
}

impl Deref for GetResult<'_> {
    type Target = [u8];
    fn deref(&self) -> &[u8] {
        match self {
            GetResult::Rocks(d) => d.deref(),
            GetResult::InMemory(d) => d.deref(),
        }
    }
}

pub enum StorageWriteBatch {
    Rocks(rocksdb::WriteBatch),
    InMemory(InMemoryBatch),
}

#[derive(Debug, Default)]
pub struct MetricConf {
    pub db_name: String,
    pub read_sample_interval: SamplingInterval,
    pub write_sample_interval: SamplingInterval,
    pub iter_sample_interval: SamplingInterval,
}

impl MetricConf {
    pub fn new(db_name: &str) -> Self {
        if db_name.is_empty() {
            error!("A meaningful db name should be used for metrics reporting.")
        }
        Self {
            db_name: db_name.to_string(),
            read_sample_interval: SamplingInterval::default(),
            write_sample_interval: SamplingInterval::default(),
            iter_sample_interval: SamplingInterval::default(),
        }
    }

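    /// Sets the read sampling interval. Note that this resets the write and
    /// iter sampling intervals to their defaults rather than preserving them.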
    pub fn with_sampling(self, read_interval: SamplingInterval) -> Self {
        Self {
            db_name: self.db_name,
            read_sample_interval: read_interval,
            write_sample_interval: SamplingInterval::default(),
            iter_sample_interval: SamplingInterval::default(),
        }
    }
}

const CF_METRICS_REPORT_PERIOD_SECS: u64 = 30;

#[derive(Debug)]
pub struct Database {
    pub(crate) storage: Storage,
    pub(crate) metric_conf: MetricConf,
}

impl Drop for Database {
    fn drop(&mut self) {
        DBMetrics::get().decrement_num_active_dbs(&self.metric_conf.db_name);
    }
}

impl Database {
    pub(crate) fn new(storage: Storage, metric_conf: MetricConf) -> Self {
        DBMetrics::get().increment_num_active_dbs(&metric_conf.db_name);
        Self {
            storage,
            metric_conf,
        }
    }

    pub(crate) fn get<K: AsRef<[u8]>>(
        &self,
        cf: &ColumnFamily,
        key: K,
        readopts: &ReadOptions,
    ) -> Result<Option<GetResult<'_>>, TypedStoreError> {
        match (&self.storage, cf) {
            (Storage::Rocks(db), ColumnFamily::Rocks(_)) => Ok(db
                .underlying
                .get_pinned_cf_opt(&cf.rocks_cf(db), key, readopts)
                .map_err(typed_store_err_from_rocks_err)?
                .map(GetResult::Rocks)),
            (Storage::InMemory(db), ColumnFamily::InMemory(cf_name)) => {
                Ok(db.get(cf_name, key).map(GetResult::InMemory))
            }
            _ => Err(TypedStoreError::RocksDB(
                "typed store invariant violation".to_string(),
            )),
        }
    }

    pub(crate) fn multi_get<I, K>(
        &self,
        cf: &ColumnFamily,
        keys: I,
        readopts: &ReadOptions,
    ) -> Vec<Result<Option<GetResult<'_>>, TypedStoreError>>
    where
        I: IntoIterator<Item = K>,
        K: AsRef<[u8]>,
    {
        match (&self.storage, cf) {
            (Storage::Rocks(db), ColumnFamily::Rocks(_)) => {
                let keys_vec: Vec<K> = keys.into_iter().collect();
                let res = db.underlying.batched_multi_get_cf_opt(
                    &cf.rocks_cf(db),
                    keys_vec.iter(),
                    // sorted_input
                    false,
                    readopts,
                );
                res.into_iter()
                    .map(|r| {
                        r.map_err(typed_store_err_from_rocks_err)
                            .map(|item| item.map(GetResult::Rocks))
                    })
                    .collect()
            }
            (Storage::InMemory(db), ColumnFamily::InMemory(cf_name)) => db
                .multi_get(cf_name, keys)
                .into_iter()
                .map(|r| Ok(r.map(GetResult::InMemory)))
                .collect(),
            _ => unreachable!("typed store invariant violation"),
        }
    }

    pub fn cf_handle(&self, name: &str) -> Option<()> {
        match &self.storage {
            Storage::Rocks(db) => db.underlying.cf_handle(name).map(|_| ()),
            Storage::InMemory(db) => db.has_cf(name).then_some(()),
        }
    }

    pub fn drop_cf(&self, name: &str) -> Result<(), rocksdb::Error> {
        match &self.storage {
            Storage::Rocks(db) => db.underlying.drop_cf(name),
            Storage::InMemory(db) => {
                db.drop_cf(name);
                Ok(())
            }
        }
    }

    pub(crate) fn delete_cf<K: AsRef<[u8]>>(
        &self,
        cf: &ColumnFamily,
        key: K,
    ) -> Result<(), TypedStoreError> {
        fail_point!("delete-cf-before");
        let ret = match (&self.storage, cf) {
            (Storage::Rocks(db), ColumnFamily::Rocks(_)) => db
                .underlying
                .delete_cf(&cf.rocks_cf(db), key)
                .map_err(typed_store_err_from_rocks_err),
            (Storage::InMemory(db), ColumnFamily::InMemory(cf_name)) => {
                db.delete(cf_name, key.as_ref());
                Ok(())
            }
            _ => Err(TypedStoreError::RocksDB(
                "typed store invariant violation".to_string(),
            )),
        };
        fail_point!("delete-cf-after");
        #[allow(clippy::let_and_return)]
        ret
    }

    pub fn path_for_pruning(&self) -> &Path {
        match &self.storage {
            Storage::Rocks(rocks) => rocks.underlying.path(),
            _ => unimplemented!("method is only supported for rocksdb backend"),
        }
    }

    pub(crate) fn put_cf(
        &self,
        cf: &ColumnFamily,
        key: Vec<u8>,
        value: Vec<u8>,
    ) -> Result<(), TypedStoreError> {
        fail_point!("put-cf-before");
        let ret = match (&self.storage, cf) {
            (Storage::Rocks(db), ColumnFamily::Rocks(_)) => db
                .underlying
                .put_cf(&cf.rocks_cf(db), key, value)
                .map_err(typed_store_err_from_rocks_err),
            (Storage::InMemory(db), ColumnFamily::InMemory(cf_name)) => {
                db.put(cf_name, key, value);
                Ok(())
            }
            _ => Err(TypedStoreError::RocksDB(
                "typed store invariant violation".to_string(),
            )),
        };
        fail_point!("put-cf-after");
        #[allow(clippy::let_and_return)]
        ret
    }

    pub(crate) fn key_may_exist_cf<K: AsRef<[u8]>>(
        &self,
        cf: &ColumnFamily,
        key: K,
        readopts: &ReadOptions,
    ) -> bool {
        match &self.storage {
            // [`rocksdb::DBWithThreadMode::key_may_exist_cf`] can have false positives,
            // but no false negatives. We use it to short-circuit the absent case
            Storage::Rocks(rocks) => {
                rocks
                    .underlying
                    .key_may_exist_cf_opt(&rocks_cf(rocks, cf.name()), key, readopts)
            }
            _ => true,
        }
    }

    pub fn write(&self, batch: StorageWriteBatch) -> Result<(), TypedStoreError> {
        fail_point!("batch-write-before");
        let ret = match (&self.storage, batch) {
            (Storage::Rocks(rocks), StorageWriteBatch::Rocks(batch)) => rocks
                .underlying
                .write(batch)
                .map_err(typed_store_err_from_rocks_err),
            (Storage::InMemory(db), StorageWriteBatch::InMemory(batch)) => {
                db.write(batch);
                Ok(())
            }
            _ => Err(TypedStoreError::RocksDB(
                "using invalid batch type for the database".to_string(),
            )),
        };
        fail_point!("batch-write-after");

        #[allow(clippy::let_and_return)]
        ret
    }

    pub(crate) fn compact_range_cf<K: AsRef<[u8]>>(
        &self,
        cf: &ColumnFamily,
        start: Option<K>,
        end: Option<K>,
    ) {
        if let Storage::Rocks(rocksdb) = &self.storage {
            rocksdb
                .underlying
                .compact_range_cf(&rocks_cf(rocksdb, cf.name()), start, end);
        }
    }

    pub fn checkpoint(&self, path: &Path) -> Result<(), TypedStoreError> {
        // TODO: implement for other storage types
        if let Storage::Rocks(rocks) = &self.storage {
            let checkpoint =
                Checkpoint::new(&rocks.underlying).map_err(typed_store_err_from_rocks_err)?;
            checkpoint
                .create_checkpoint(path)
                .map_err(|e| TypedStoreError::RocksDB(e.to_string()))?;
        }
        Ok(())
    }

    pub fn get_sampling_interval(&self) -> SamplingInterval {
        self.metric_conf.read_sample_interval.new_from_self()
    }

    pub fn multiget_sampling_interval(&self) -> SamplingInterval {
        self.metric_conf.read_sample_interval.new_from_self()
    }

    pub fn write_sampling_interval(&self) -> SamplingInterval {
        self.metric_conf.write_sample_interval.new_from_self()
    }

    pub fn iter_sampling_interval(&self) -> SamplingInterval {
        self.metric_conf.iter_sample_interval.new_from_self()
    }

    pub(crate) fn db_name(&self) -> String {
        let name = &self.metric_conf.db_name;
        if name.is_empty() {
            "default".to_string()
        } else {
            name.clone()
        }
    }

    pub fn live_files(&self) -> Result<Vec<LiveFile>, Error> {
        match &self.storage {
            Storage::Rocks(rocks) => rocks.underlying.live_files(),
            _ => Ok(vec![]),
        }
    }

    pub(crate) fn try_catch_up_with_primary(&self) -> Result<(), TypedStoreError> {
        if let Storage::Rocks(rocks) = &self.storage {
            rocks
                .underlying
                .try_catch_up_with_primary()
                .map_err(typed_store_err_from_rocks_err)?;
        }
        Ok(())
    }
}

fn rocks_cf_from_db<'a>(
    db: &'a Database,
    cf_name: &str,
) -> Result<Arc<rocksdb::BoundColumnFamily<'a>>, TypedStoreError> {
    match &db.storage {
        Storage::Rocks(rocksdb) => Ok(rocksdb
            .underlying
            .cf_handle(cf_name)
            .expect("Map-keying column family should have been checked at DB creation")),
        _ => Err(TypedStoreError::RocksDB(
            "using invalid batch type for the database".to_string(),
        )),
    }
}

/// An interface to a RocksDB database, keyed by a column family.
#[derive(Clone, Debug)]
pub struct DBMap<K, V> {
    pub db: Arc<Database>,
    _phantom: PhantomData<fn(K) -> V>,
    column_family: ColumnFamily,
    pub opts: ReadWriteOptions,
    db_metrics: Arc<DBMetrics>,
    get_sample_interval: SamplingInterval,
    multiget_sample_interval: SamplingInterval,
    write_sample_interval: SamplingInterval,
    iter_sample_interval: SamplingInterval,
    _metrics_task_cancel_handle: Arc<oneshot::Sender<()>>,
}

unsafe impl<K: Send, V: Send> Send for DBMap<K, V> {}

impl<K, V> DBMap<K, V> {
    pub(crate) fn new(
        db: Arc<Database>,
        opts: &ReadWriteOptions,
        column_family: ColumnFamily,
        is_deprecated: bool,
    ) -> Self {
        let db_cloned = Arc::downgrade(&db);
        let db_metrics = DBMetrics::get();
        let db_metrics_cloned = db_metrics.clone();
        let cf = column_family.name().to_string();

        let (sender, mut recv) = tokio::sync::oneshot::channel();
        if !is_deprecated && matches!(db.storage, Storage::Rocks(_)) {
            tokio::task::spawn(async move {
                let mut interval =
                    tokio::time::interval(Duration::from_secs(CF_METRICS_REPORT_PERIOD_SECS));
                loop {
                    tokio::select! {
                        _ = interval.tick() => {
                            if let Some(db) = db_cloned.upgrade() {
                                let cf = cf.clone();
                                let db_metrics = db_metrics.clone();
                                if let Err(e) = tokio::task::spawn_blocking(move || {
                                    Self::report_rocksdb_metrics(&db, &cf, &db_metrics);
                                }).await {
                                    error!("Failed to log metrics with error: {}", e);
                                }
                            } else {
                                break;
                            }
                        }
                        _ = &mut recv => break,
                    }
                }
                debug!("Exiting the cf metric logging task for DBMap: {}", &cf);
            });
        }
        DBMap {
            db: db.clone(),
            opts: opts.clone(),
            _phantom: PhantomData,
            column_family,
            db_metrics: db_metrics_cloned,
            _metrics_task_cancel_handle: Arc::new(sender),
            get_sample_interval: db.get_sampling_interval(),
            multiget_sample_interval: db.multiget_sampling_interval(),
            write_sample_interval: db.write_sampling_interval(),
            iter_sample_interval: db.iter_sampling_interval(),
        }
    }

    /// Reopens an open database as a typed map operating under a specific
    /// column family. If no column family is passed, the default column
    /// family is used.
    #[instrument(level = "debug", skip(db), err)]
    pub fn reopen(
        db: &Arc<Database>,
        opt_cf: Option<&str>,
        rw_options: &ReadWriteOptions,
        is_deprecated: bool,
    ) -> Result<Self, TypedStoreError> {
        let cf_key = opt_cf
            .unwrap_or(rocksdb::DEFAULT_COLUMN_FAMILY_NAME)
            .to_owned();

        let column_family = match &db.storage {
            Storage::Rocks(_) => ColumnFamily::Rocks(cf_key),
            Storage::InMemory(_) => ColumnFamily::InMemory(cf_key),
        };
        Ok(DBMap::new(
            db.clone(),
            rw_options,
            column_family,
            is_deprecated,
        ))
    }

    pub fn cf_name(&self) -> &str {
        self.column_family.name()
    }

    pub fn batch(&self) -> DBBatch {
        let batch = match &self.db.storage {
            Storage::Rocks(_) => StorageWriteBatch::Rocks(WriteBatch::default()),
            Storage::InMemory(_) => StorageWriteBatch::InMemory(InMemoryBatch::default()),
        };
        DBBatch::new(
            &self.db,
            batch,
            &self.db_metrics,
            &self.write_sample_interval,
        )
    }

    pub fn compact_range<J: Serialize>(&self, start: &J, end: &J) -> Result<(), TypedStoreError> {
        let from_buf = be_fix_int_ser(start);
        let to_buf = be_fix_int_ser(end);
        self.db
            .compact_range_cf(&self.column_family, Some(from_buf), Some(to_buf));
        Ok(())
    }

    pub fn compact_range_raw(
        &self,
        cf_name: &str,
        start: Vec<u8>,
        end: Vec<u8>,
    ) -> Result<(), TypedStoreError> {
        let cf = match &self.db.storage {
            Storage::Rocks(_) => ColumnFamily::Rocks(cf_name.to_string()),
            Storage::InMemory(_) => ColumnFamily::InMemory(cf_name.to_string()),
        };
        self.db.compact_range_cf(&cf, Some(start), Some(end));
        Ok(())
    }

    /// Returns a vector of raw values corresponding to the keys provided.
    fn multi_get_pinned<J>(
        &self,
        keys: impl IntoIterator<Item = J>,
    ) -> Result<Vec<Option<GetResult<'_>>>, TypedStoreError>
    where
        J: Borrow<K>,
        K: Serialize,
    {
        let _timer = self
            .db_metrics
            .op_metrics
            .rocksdb_multiget_latency_seconds
            .with_label_values(&[self.cf_name()])
            .start_timer();
        let perf_ctx = if self.multiget_sample_interval.sample() {
            Some(RocksDBPerfContext)
        } else {
            None
        };
        let keys_bytes = keys.into_iter().map(|k| be_fix_int_ser(k.borrow()));
        let results: Result<Vec<_>, TypedStoreError> = self
            .db
            .multi_get(&self.column_family, keys_bytes, &self.opts.readopts())
            .into_iter()
            .collect();
        let entries = results?;
        let entry_size = entries
            .iter()
            .flatten()
            .map(|entry| entry.len())
            .sum::<usize>();
        self.db_metrics
            .op_metrics
            .rocksdb_multiget_bytes
            .with_label_values(&[self.cf_name()])
            .observe(entry_size as f64);
        if perf_ctx.is_some() {
            self.db_metrics
                .read_perf_ctx_metrics
                .report_metrics(self.cf_name());
        }
        Ok(entries)
    }

    pub fn checkpoint_db(&self, path: &Path) -> Result<(), TypedStoreError> {
        self.db.checkpoint(path)
    }

    pub fn table_summary(&self) -> eyre::Result<TableSummary>
    where
        K: Serialize + DeserializeOwned,
        V: Serialize + DeserializeOwned,
    {
        let mut num_keys = 0;
        let mut key_bytes_total = 0;
        let mut value_bytes_total = 0;
        let mut key_hist = hdrhistogram::Histogram::<u64>::new_with_max(100000, 2).unwrap();
        let mut value_hist = hdrhistogram::Histogram::<u64>::new_with_max(100000, 2).unwrap();
        for item in self.safe_iter() {
            let (key, value) = item?;
            num_keys += 1;
            let key_len = be_fix_int_ser(key.borrow()).len();
            let value_len = bcs::to_bytes(value.borrow())?.len();
            key_bytes_total += key_len;
            value_bytes_total += value_len;
            key_hist.record(key_len as u64)?;
            value_hist.record(value_len as u64)?;
        }
        Ok(TableSummary {
            num_keys,
            key_bytes_total,
            value_bytes_total,
            key_hist,
            value_hist,
        })
    }

    // Creates metrics and context for tracking iterator usage and performance.
    fn create_iter_context(
        &self,
    ) -> (
        Option<HistogramTimer>,
        Option<Histogram>,
        Option<Histogram>,
        Option<RocksDBPerfContext>,
    ) {
        let timer = self
            .db_metrics
            .op_metrics
            .rocksdb_iter_latency_seconds
            .with_label_values(&[self.cf_name()])
            .start_timer();
        let bytes_scanned = self
            .db_metrics
            .op_metrics
            .rocksdb_iter_bytes
            .with_label_values(&[self.cf_name()]);
        let keys_scanned = self
            .db_metrics
            .op_metrics
            .rocksdb_iter_keys
            .with_label_values(&[self.cf_name()]);
        let perf_ctx = if self.iter_sample_interval.sample() {
            Some(RocksDBPerfContext)
        } else {
            None
        };
        (
            Some(timer),
            Some(bytes_scanned),
            Some(keys_scanned),
            perf_ctx,
        )
    }

    /// Creates a safe reversed iterator with optional bounds.
    /// Both upper bound and lower bound are included.
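    ///
    /// A minimal usage sketch, assuming a hypothetical `table: DBMap<u32,
    /// String>` with keys 1..=5 already inserted:
    ///
    /// ```ignore
    /// let entries: Vec<_> = table
    ///     .reversed_safe_iter_with_bounds(Some(2), Some(4))?
    ///     .collect::<Result<Vec<_>, _>>()?;
    /// // Both bounds are included, so this yields keys 4, 3, 2 in
    /// // descending order.
    /// ```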
    #[allow(clippy::complexity)]
    pub fn reversed_safe_iter_with_bounds(
        &self,
        lower_bound: Option<K>,
        upper_bound: Option<K>,
    ) -> Result<DbIterator<'_, (K, V)>, TypedStoreError>
    where
        K: Serialize + DeserializeOwned,
        V: Serialize + DeserializeOwned,
    {
        let (it_lower_bound, it_upper_bound) = iterator_bounds_with_range::<K>((
            lower_bound
                .as_ref()
                .map(Bound::Included)
                .unwrap_or(Bound::Unbounded),
            upper_bound
                .as_ref()
                .map(Bound::Included)
                .unwrap_or(Bound::Unbounded),
        ));
        match &self.db.storage {
            Storage::Rocks(db) => {
                let readopts = rocks_util::apply_range_bounds(
                    self.opts.readopts(),
                    it_lower_bound,
                    it_upper_bound,
                );
                let upper_bound_key = upper_bound.as_ref().map(|k| be_fix_int_ser(&k));
                let db_iter = db
                    .underlying
                    .raw_iterator_cf_opt(&rocks_cf(db, self.column_family.name()), readopts);
                let (_timer, bytes_scanned, keys_scanned, _perf_ctx) = self.create_iter_context();
                let iter = SafeIter::new(
                    self.cf_name().to_string(),
                    db_iter,
                    _timer,
                    _perf_ctx,
                    bytes_scanned,
                    keys_scanned,
                    Some(self.db_metrics.clone()),
                );
                Ok(Box::new(SafeRevIter::new(iter, upper_bound_key)))
            }
            Storage::InMemory(db) => Ok(db.iterator(
                self.column_family.name(),
                it_lower_bound,
                it_upper_bound,
                true,
            )),
        }
    }
}

/// Provides a mutable struct to form a collection of database write
/// operations, and execute them.
///
/// Batching write and delete operations is faster than performing them one by
/// one and ensures their atomicity, i.e. they are all written or none is.
/// This is also true of operations across column families in the same
/// database.
///
/// Serialization / deserialization and naming of column families are
/// performed by passing a DBMap<K, V> with each operation.
///
/// ```
/// use core::fmt::Error;
/// use std::sync::Arc;
///
/// use prometheus::Registry;
/// use tempfile::tempdir;
/// use typed_store::{Map, metrics::DBMetrics, rocks::*};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Error> {
///     let rocks = open_cf_opts(
///         tempfile::tempdir().unwrap(),
///         None,
///         MetricConf::default(),
///         &[
///             ("First_CF", rocksdb::Options::default()),
///             ("Second_CF", rocksdb::Options::default()),
///         ],
///     )
///     .unwrap();
///
///     let db_cf_1 = DBMap::reopen(
///         &rocks,
///         Some("First_CF"),
///         &ReadWriteOptions::default(),
///         false,
///     )
///     .expect("Failed to open storage");
///     let keys_vals_1 = (1..100).map(|i| (i, i.to_string()));
///
///     let db_cf_2 = DBMap::reopen(
///         &rocks,
///         Some("Second_CF"),
///         &ReadWriteOptions::default(),
///         false,
///     )
///     .expect("Failed to open storage");
///     let keys_vals_2 = (1000..1100).map(|i| (i, i.to_string()));
///
///     let mut batch = db_cf_1.batch();
///     batch
///         .insert_batch(&db_cf_1, keys_vals_1.clone())
///         .expect("Failed to batch insert")
///         .insert_batch(&db_cf_2, keys_vals_2.clone())
///         .expect("Failed to batch insert");
///
///     let _ = batch.write().expect("Failed to execute batch");
///     for (k, v) in keys_vals_1 {
///         let val = db_cf_1.get(&k).expect("Failed to get inserted key");
///         assert_eq!(Some(v), val);
///     }
///
///     for (k, v) in keys_vals_2 {
///         let val = db_cf_2.get(&k).expect("Failed to get inserted key");
///         assert_eq!(Some(v), val);
///     }
///     Ok(())
/// }
/// ```
pub struct DBBatch {
    database: Arc<Database>,
    batch: StorageWriteBatch,
    db_metrics: Arc<DBMetrics>,
    write_sample_interval: SamplingInterval,
}

impl DBBatch {
    /// Create a new batch associated with a DB reference.
    ///
    /// Use `open_cf` to get the DB reference, or use an existing open
    /// database.
    pub fn new(
        dbref: &Arc<Database>,
        batch: StorageWriteBatch,
        db_metrics: &Arc<DBMetrics>,
        write_sample_interval: &SamplingInterval,
    ) -> Self {
        DBBatch {
            database: dbref.clone(),
            batch,
            db_metrics: db_metrics.clone(),
            write_sample_interval: write_sample_interval.clone(),
        }
    }

    /// Consume the batch and write its operations to the database
    #[instrument(level = "trace", skip_all, err)]
    pub fn write(self) -> Result<(), TypedStoreError> {
        let db_name = self.database.db_name();
        let timer = self
            .db_metrics
            .op_metrics
            .rocksdb_batch_commit_latency_seconds
            .with_label_values(&[&db_name])
            .start_timer();
        let batch_size = self.size_in_bytes();

        let perf_ctx = if self.write_sample_interval.sample() {
            Some(RocksDBPerfContext)
        } else {
            None
        };
        self.database.write(self.batch)?;
        self.db_metrics
            .op_metrics
            .rocksdb_batch_commit_bytes
            .with_label_values(&[&db_name])
            .observe(batch_size as f64);

        if perf_ctx.is_some() {
            self.db_metrics
                .write_perf_ctx_metrics
                .report_metrics(&db_name);
        }
        let elapsed = timer.stop_and_record();
        if elapsed > 1.0 {
            warn!(?elapsed, ?db_name, "very slow batch write");
            self.db_metrics
                .op_metrics
                .rocksdb_very_slow_batch_writes_count
                .with_label_values(&[&db_name])
                .inc();
            self.db_metrics
                .op_metrics
                .rocksdb_very_slow_batch_writes_duration_ms
                .with_label_values(&[&db_name])
                .inc_by((elapsed * 1000.0) as u64);
        }
        Ok(())
    }

    pub fn size_in_bytes(&self) -> usize {
        match self.batch {
            StorageWriteBatch::Rocks(ref b) => b.size_in_bytes(),
            StorageWriteBatch::InMemory(_) => 0,
        }
    }

    pub fn delete_batch<J: Borrow<K>, K: Serialize, V>(
        &mut self,
        db: &DBMap<K, V>,
        purged_vals: impl IntoIterator<Item = J>,
    ) -> Result<(), TypedStoreError> {
        if !Arc::ptr_eq(&db.db, &self.database) {
            return Err(TypedStoreError::CrossDBBatch);
        }

        purged_vals
            .into_iter()
            .try_for_each::<_, Result<_, TypedStoreError>>(|k| {
                let k_buf = be_fix_int_ser(k.borrow());
                match (&mut self.batch, &db.column_family) {
                    (StorageWriteBatch::Rocks(b), ColumnFamily::Rocks(name)) => {
                        b.delete_cf(&rocks_cf_from_db(&self.database, name)?, k_buf)
                    }
                    (StorageWriteBatch::InMemory(b), ColumnFamily::InMemory(name)) => {
                        b.delete_cf(name, k_buf)
                    }
                    _ => Err(TypedStoreError::RocksDB(
                        "typed store invariant violation".to_string(),
                    ))?,
                }
                Ok(())
            })?;
        Ok(())
    }

    /// Deletes a range of keys between `from` (inclusive) and `to`
    /// (non-inclusive) by writing a range delete tombstone in the db map.
    /// The effect of this write is visible immediately, i.e. you won't see
    /// old values when you do a lookup or scan.
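    ///
    /// A minimal sketch, assuming a hypothetical `table: DBMap<u32, String>`
    /// that already holds keys 1..=10:
    ///
    /// ```ignore
    /// let mut batch = table.batch();
    /// // Schedules deletion of keys 3..=6; key 7 is excluded.
    /// batch.schedule_delete_range(&table, &3, &7)?;
    /// batch.write()?;
    /// ```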
    pub fn schedule_delete_range<K: Serialize, V>(
        &mut self,
        db: &DBMap<K, V>,
        from: &K,
        to: &K,
    ) -> Result<(), TypedStoreError> {
        if !Arc::ptr_eq(&db.db, &self.database) {
            return Err(TypedStoreError::CrossDBBatch);
        }

        let from_buf = be_fix_int_ser(from);
        let to_buf = be_fix_int_ser(to);

        if let StorageWriteBatch::Rocks(b) = &mut self.batch {
            b.delete_range_cf(
                &rocks_cf_from_db(&self.database, db.cf_name())?,
                from_buf,
                to_buf,
            );
        }
        Ok(())
    }

    /// Inserts a range of (key, value) pairs given as an iterator.
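    ///
    /// A minimal sketch, assuming a hypothetical `table: DBMap<u32, String>`;
    /// `insert_batch` returns `&mut Self`, so calls can be chained:
    ///
    /// ```ignore
    /// let mut batch = table.batch();
    /// batch.insert_batch(&table, (0..3).map(|i| (i, i.to_string())))?;
    /// batch.write()?;
    /// ```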
    pub fn insert_batch<J: Borrow<K>, K: Serialize, U: Borrow<V>, V: Serialize>(
        &mut self,
        db: &DBMap<K, V>,
        new_vals: impl IntoIterator<Item = (J, U)>,
    ) -> Result<&mut Self, TypedStoreError> {
        if !Arc::ptr_eq(&db.db, &self.database) {
            return Err(TypedStoreError::CrossDBBatch);
        }
        let mut total = 0usize;
        new_vals
            .into_iter()
            .try_for_each::<_, Result<_, TypedStoreError>>(|(k, v)| {
                let k_buf = be_fix_int_ser(k.borrow());
                let v_buf = bcs::to_bytes(v.borrow()).map_err(typed_store_err_from_bcs_err)?;
                total += k_buf.len() + v_buf.len();
                if db.opts.log_value_hash {
                    let key_hash = default_hash(&k_buf);
                    let value_hash = default_hash(&v_buf);
                    debug!(
                        "Insert to DB table: {:?}, key_hash: {:?}, value_hash: {:?}",
                        db.cf_name(),
                        key_hash,
                        value_hash
                    );
                }
                match (&mut self.batch, &db.column_family) {
                    (StorageWriteBatch::Rocks(b), ColumnFamily::Rocks(name)) => {
                        b.put_cf(&rocks_cf_from_db(&self.database, name)?, k_buf, v_buf)
                    }
                    (StorageWriteBatch::InMemory(b), ColumnFamily::InMemory(name)) => {
                        b.put_cf(name, k_buf, v_buf)
                    }
                    _ => Err(TypedStoreError::RocksDB(
                        "typed store invariant violation".to_string(),
                    ))?,
                }
                Ok(())
            })?;
        self.db_metrics
            .op_metrics
            .rocksdb_batch_put_bytes
            .with_label_values(&[db.cf_name()])
            .observe(total as f64);
        Ok(self)
    }
}

impl<'a, K, V> Map<'a, K, V> for DBMap<K, V>
where
    K: Serialize + DeserializeOwned,
    V: Serialize + DeserializeOwned,
{
    type Error = TypedStoreError;

    #[instrument(level = "trace", skip_all, err)]
    fn contains_key(&self, key: &K) -> Result<bool, TypedStoreError> {
        let key_buf = be_fix_int_ser(key);
        let readopts = self.opts.readopts();
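        // `key_may_exist_cf` can return false positives but never false
        // negatives, so a negative answer short-circuits the check; a
        // positive answer still requires a `get` to confirm the key exists.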
        Ok(self
            .db
            .key_may_exist_cf(&self.column_family, &key_buf, &readopts)
            && self
                .db
                .get(&self.column_family, &key_buf, &readopts)?
                .is_some())
    }

    #[instrument(level = "trace", skip_all, err)]
    fn multi_contains_keys<J>(
        &self,
        keys: impl IntoIterator<Item = J>,
    ) -> Result<Vec<bool>, Self::Error>
    where
        J: Borrow<K>,
    {
        let values = self.multi_get_pinned(keys)?;
        Ok(values.into_iter().map(|v| v.is_some()).collect())
    }

    #[instrument(level = "trace", skip_all, err)]
    fn get(&self, key: &K) -> Result<Option<V>, TypedStoreError> {
        let _timer = self
            .db_metrics
            .op_metrics
            .rocksdb_get_latency_seconds
            .with_label_values(&[self.cf_name()])
            .start_timer();
        let perf_ctx = if self.get_sample_interval.sample() {
            Some(RocksDBPerfContext)
        } else {
            None
        };
        let key_buf = be_fix_int_ser(key);
        let res = self
            .db
            .get(&self.column_family, &key_buf, &self.opts.readopts())?;
        self.db_metrics
            .op_metrics
            .rocksdb_get_bytes
            .with_label_values(&[self.cf_name()])
            .observe(res.as_ref().map_or(0.0, |v| v.len() as f64));
        if perf_ctx.is_some() {
            self.db_metrics
                .read_perf_ctx_metrics
                .report_metrics(self.cf_name());
        }
        match res {
            Some(data) => {
                let value = bcs::from_bytes(&data).map_err(typed_store_err_from_bcs_err);
                if value.is_err() {
                    let key_hash = default_hash(&key_buf);
                    let value_hash = default_hash(&data);
                    debug_fatal!(
                        "Failed to deserialize value from DB table {:?}, key_hash: {:?}, value_hash: {:?}, error: {:?}",
                        self.cf_name(),
                        key_hash,
                        value_hash,
                        value.as_ref().err().unwrap()
                    );
                }
                Ok(Some(value?))
            }
            None => Ok(None),
        }
    }

    #[instrument(level = "trace", skip_all, err)]
    fn insert(&self, key: &K, value: &V) -> Result<(), TypedStoreError> {
        let timer = self
            .db_metrics
            .op_metrics
            .rocksdb_put_latency_seconds
            .with_label_values(&[self.cf_name()])
            .start_timer();
        let perf_ctx = if self.write_sample_interval.sample() {
            Some(RocksDBPerfContext)
        } else {
            None
        };
        let key_buf = be_fix_int_ser(key);
        let value_buf = bcs::to_bytes(value).map_err(typed_store_err_from_bcs_err)?;
        self.db_metrics
            .op_metrics
            .rocksdb_put_bytes
            .with_label_values(&[self.cf_name()])
            .observe((key_buf.len() + value_buf.len()) as f64);
        if perf_ctx.is_some() {
            self.db_metrics
                .write_perf_ctx_metrics
                .report_metrics(self.cf_name());
        }
        self.db.put_cf(&self.column_family, key_buf, value_buf)?;

        let elapsed = timer.stop_and_record();
        if elapsed > 1.0 {
            warn!(?elapsed, cf = ?self.cf_name(), "very slow insert");
            self.db_metrics
                .op_metrics
                .rocksdb_very_slow_puts_count
                .with_label_values(&[self.cf_name()])
                .inc();
            self.db_metrics
                .op_metrics
                .rocksdb_very_slow_puts_duration_ms
                .with_label_values(&[self.cf_name()])
                .inc_by((elapsed * 1000.0) as u64);
        }

        Ok(())
    }

    #[instrument(level = "trace", skip_all, err)]
    fn remove(&self, key: &K) -> Result<(), TypedStoreError> {
        let _timer = self
            .db_metrics
            .op_metrics
            .rocksdb_delete_latency_seconds
            .with_label_values(&[self.cf_name()])
            .start_timer();
        let perf_ctx = if self.write_sample_interval.sample() {
            Some(RocksDBPerfContext)
        } else {
            None
        };
        let key_buf = be_fix_int_ser(key);
        self.db.delete_cf(&self.column_family, key_buf)?;
        self.db_metrics
            .op_metrics
            .rocksdb_deletes
            .with_label_values(&[self.cf_name()])
            .inc();
        if perf_ctx.is_some() {
            self.db_metrics
                .write_perf_ctx_metrics
                .report_metrics(self.cf_name());
        }
        Ok(())
    }

    /// Writes a range delete tombstone to delete all entries in the db map.
    /// The effect of this write is visible immediately, i.e. you won't see
    /// old values when you do a lookup or scan.
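    ///
    /// A minimal sketch, assuming a hypothetical `table: DBMap<u32, String>`:
    ///
    /// ```ignore
    /// table.multi_insert((0..100).map(|i| (i, i.to_string())))?;
    /// // Clears the table's contents with a single range tombstone write.
    /// table.schedule_delete_all()?;
    /// ```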
    #[instrument(level = "trace", skip_all, err)]
    fn schedule_delete_all(&self) -> Result<(), TypedStoreError> {
        let first_key = self.safe_iter().next().transpose()?.map(|(k, _v)| k);
        let last_key = self
            .reversed_safe_iter_with_bounds(None, None)?
            .next()
            .transpose()?
            .map(|(k, _v)| k);
        if let Some((first_key, last_key)) = first_key.zip(last_key) {
            let mut batch = self.batch();
            batch.schedule_delete_range(self, &first_key, &last_key)?;
            batch.write()?;
        }
        Ok(())
    }

    fn is_empty(&self) -> bool {
        self.safe_iter().next().is_none()
    }

    fn safe_iter(&'a self) -> DbIterator<'a, (K, V)> {
        match &self.db.storage {
            Storage::Rocks(db) => {
                let db_iter = db.underlying.raw_iterator_cf_opt(
                    &rocks_cf(db, self.column_family.name()),
                    self.opts.readopts(),
                );
                let (_timer, bytes_scanned, keys_scanned, _perf_ctx) = self.create_iter_context();
                Box::new(SafeIter::new(
                    self.cf_name().to_string(),
                    db_iter,
                    _timer,
                    _perf_ctx,
                    bytes_scanned,
                    keys_scanned,
                    Some(self.db_metrics.clone()),
                ))
            }
            Storage::InMemory(db) => db.iterator(self.column_family.name(), None, None, false),
        }
    }

    fn safe_iter_with_bounds(
        &'a self,
        lower_bound: Option<K>,
        upper_bound: Option<K>,
    ) -> DbIterator<'a, (K, V)> {
        let (lower_bound, upper_bound) = iterator_bounds(lower_bound, upper_bound);
        match &self.db.storage {
            Storage::Rocks(db) => {
                let readopts =
                    rocks_util::apply_range_bounds(self.opts.readopts(), lower_bound, upper_bound);
                let db_iter = db
                    .underlying
                    .raw_iterator_cf_opt(&rocks_cf(db, self.column_family.name()), readopts);
                let (_timer, bytes_scanned, keys_scanned, _perf_ctx) = self.create_iter_context();
                Box::new(SafeIter::new(
                    self.cf_name().to_string(),
                    db_iter,
                    _timer,
                    _perf_ctx,
                    bytes_scanned,
                    keys_scanned,
                    Some(self.db_metrics.clone()),
                ))
            }
            Storage::InMemory(db) => {
                db.iterator(self.column_family.name(), lower_bound, upper_bound, false)
            }
        }
    }

    fn safe_range_iter(&'a self, range: impl RangeBounds<K>) -> DbIterator<'a, (K, V)> {
        let (lower_bound, upper_bound) = iterator_bounds_with_range(range);
        match &self.db.storage {
            Storage::Rocks(db) => {
                let readopts =
                    rocks_util::apply_range_bounds(self.opts.readopts(), lower_bound, upper_bound);
                let db_iter = db
                    .underlying
                    .raw_iterator_cf_opt(&rocks_cf(db, self.column_family.name()), readopts);
                let (_timer, bytes_scanned, keys_scanned, _perf_ctx) = self.create_iter_context();
                Box::new(SafeIter::new(
                    self.cf_name().to_string(),
                    db_iter,
                    _timer,
                    _perf_ctx,
                    bytes_scanned,
                    keys_scanned,
                    Some(self.db_metrics.clone()),
                ))
            }
            Storage::InMemory(db) => {
                db.iterator(self.column_family.name(), lower_bound, upper_bound, false)
            }
        }
    }

    /// Returns a vector of values corresponding to the keys provided.
    #[instrument(level = "trace", skip_all, err)]
    fn multi_get<J>(
        &self,
        keys: impl IntoIterator<Item = J>,
    ) -> Result<Vec<Option<V>>, TypedStoreError>
    where
        J: Borrow<K>,
    {
        let results = self.multi_get_pinned(keys)?;
        let values_parsed: Result<Vec<_>, TypedStoreError> = results
            .into_iter()
            .map(|value_byte| match value_byte {
                Some(data) => Ok(Some(
                    bcs::from_bytes(&data).map_err(typed_store_err_from_bcs_err)?,
                )),
                None => Ok(None),
            })
            .collect();

        values_parsed
    }

    /// Convenience method for batch insertion
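    ///
    /// A minimal sketch, assuming a hypothetical `table: DBMap<u32, String>`;
    /// equivalent to building a batch, inserting the pairs, and writing it:
    ///
    /// ```ignore
    /// table.multi_insert((0..3).map(|i| (i, i.to_string())))?;
    /// assert_eq!(table.get(&1)?, Some("1".to_string()));
    /// ```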
    #[instrument(level = "trace", skip_all, err)]
    fn multi_insert<J, U>(
        &self,
        key_val_pairs: impl IntoIterator<Item = (J, U)>,
    ) -> Result<(), Self::Error>
    where
        J: Borrow<K>,
        U: Borrow<V>,
    {
        let mut batch = self.batch();
        batch.insert_batch(self, key_val_pairs)?;
        batch.write()
    }

    /// Convenience method for batch removal
    #[instrument(level = "trace", skip_all, err)]
    fn multi_remove<J>(&self, keys: impl IntoIterator<Item = J>) -> Result<(), Self::Error>
    where
        J: Borrow<K>,
    {
        let mut batch = self.batch();
        batch.delete_batch(self, keys)?;
        batch.write()
    }

    /// Try to catch up with primary when running as secondary
    #[instrument(level = "trace", skip_all, err)]
    fn try_catch_up_with_primary(&self) -> Result<(), Self::Error> {
        self.db.try_catch_up_with_primary()
    }
}

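/// Blake2b-256 digest used to fingerprint keys and values in debug logs
/// without emitting the raw bytes.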
fn default_hash(value: &[u8]) -> Digest<32> {
    let mut hasher = fastcrypto::hash::Blake2b256::default();
    hasher.update(value);
    hasher.finalize()
}