typed_store_derive/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::collections::BTreeMap;
6
7use itertools::Itertools;
8use proc_macro::TokenStream;
9use proc_macro2::Ident;
10use quote::quote;
11use syn::{
12    AngleBracketedGenericArguments, Attribute, Generics, ItemStruct, Lit, Meta, PathArguments,
13    Type::{self},
14    parse_macro_input,
15};
16
/// Path of the options function used for a table that carries no
/// `default_options_override_fn` attribute.
const DEFAULT_DB_OPTIONS_CUSTOM_FN: &str = "typed_store::rocks::default_db_options";
/// Field attribute naming a function that returns the options for that table,
/// replacing the default options function.
const DB_OPTIONS_CUSTOM_FUNCTION: &str = "default_options_override_fn";
/// Field attribute giving the column family a name different from the field
/// identifier.
const DB_OPTIONS_RENAME: &str = "rename";
/// Field attribute marking the column family as deprecated.
const DB_OPTIONS_DEPRECATE: &str = "deprecated";
26
/// Where the options of a single table come from.
///
/// Only one form exists: the textual path of a function that the generated
/// code calls with no arguments to obtain the table's options.
enum GeneralTableOptions {
    /// Path of the options function, stored as source text.
    OverrideFunction(String),
}
31
32impl Default for GeneralTableOptions {
33    fn default() -> Self {
34        Self::OverrideFunction(DEFAULT_DB_OPTIONS_CUSTOM_FN.to_owned())
35    }
36}
37
38// Extracts the field names, field types, inner types (K,V in {map_type_name}<K,
39// V>), and the options attrs
40fn extract_struct_info(input: ItemStruct) -> ExtractedStructInfo {
41    let mut deprecated_cfs = vec![];
42
43    let info = input.fields.iter().map(|f| {
44        let attrs: BTreeMap<_, _> = f
45            .attrs
46            .iter()
47            .filter(|a| {
48                a.path.is_ident(DB_OPTIONS_CUSTOM_FUNCTION)
49                    || a.path.is_ident(DB_OPTIONS_RENAME)
50                    || a.path.is_ident(DB_OPTIONS_DEPRECATE)
51            })
52            .map(|a| (a.path.get_ident().unwrap().to_string(), a))
53            .collect();
54
55        let options = if let Some(options) = attrs.get(DB_OPTIONS_CUSTOM_FUNCTION) {
56            GeneralTableOptions::OverrideFunction(get_options_override_function(options).unwrap())
57        } else {
58            GeneralTableOptions::default()
59        };
60
61        let ty = &f.ty;
62        if let Type::Path(p) = ty {
63            let type_info = &p.path.segments.first().unwrap();
64            let inner_type =
65                if let PathArguments::AngleBracketed(angle_bracket_type) = &type_info.arguments {
66                    angle_bracket_type.clone()
67                } else {
68                    panic!("All struct members must be of type DBMap");
69                };
70
71            let type_str = format!("{}", &type_info.ident);
72            if type_str == "DBMap" {
73                let field_name = f.ident.as_ref().unwrap().clone();
74                let cf_name = if let Some(rename) = attrs.get(DB_OPTIONS_RENAME) {
75                    match rename.parse_meta().expect("Cannot parse meta of attribute") {
76                        Meta::NameValue(val) => {
77                            if let Lit::Str(s) = val.lit {
78                                // convert to ident
79                                s.parse().expect("Rename value must be identifier")
80                            } else {
81                                panic!("Expected string value for rename")
82                            }
83                        }
84                        _ => panic!("Expected string value for rename"),
85                    }
86                } else {
87                    field_name.clone()
88                };
89                if attrs.contains_key(DB_OPTIONS_DEPRECATE) {
90                    deprecated_cfs.push(field_name.clone());
91                }
92
93                return ((field_name, cf_name, type_str), (inner_type, options));
94            } else {
95                panic!("All struct members must be of type DBMap");
96            }
97        }
98        panic!("All struct members must be of type DBMap");
99    });
100
101    let (field_info, inner_types_with_opts): (Vec<_>, Vec<_>) = info.unzip();
102    let (field_names, cf_names, simple_field_type_names): (Vec<_>, Vec<_>, Vec<_>) =
103        field_info.into_iter().multiunzip();
104
105    // Check for homogeneous types
106    if let Some(first) = simple_field_type_names.first() {
107        simple_field_type_names.iter().for_each(|q| {
108            if q != first {
109                panic!("All struct members must be of same type");
110            }
111        })
112    } else {
113        panic!("Cannot derive on empty struct");
114    };
115
116    let (inner_types, options): (Vec<_>, Vec<_>) = inner_types_with_opts.into_iter().unzip();
117
118    ExtractedStructInfo {
119        field_names,
120        cf_names,
121        inner_types,
122        derived_table_options: options,
123        deprecated_cfs,
124    }
125}
126
127/// Extracts the table options override function
128/// The function must take no args and return Options
129fn get_options_override_function(attr: &Attribute) -> syn::Result<String> {
130    let meta = attr.parse_meta()?;
131
132    let val = match meta.clone() {
133        Meta::NameValue(val) => val,
134        _ => {
135            return Err(syn::Error::new_spanned(
136                meta,
137                format!(
138                    "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
139                ),
140            ));
141        }
142    };
143
144    if !val.path.is_ident(DB_OPTIONS_CUSTOM_FUNCTION) {
145        return Err(syn::Error::new_spanned(
146            meta,
147            format!(
148                "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
149            ),
150        ));
151    }
152
153    let fn_name = match val.lit {
154        Lit::Str(fn_name) => fn_name,
155        _ => {
156            return Err(syn::Error::new_spanned(
157                meta,
158                format!(
159                    "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
160                ),
161            ));
162        }
163    };
164    Ok(fn_name.value())
165}
166
167fn extract_generics_names(generics: &Generics) -> Vec<Ident> {
168    generics
169        .params
170        .iter()
171        .map(|g| match g {
172            syn::GenericParam::Type(t) => t.ident.clone(),
173            _ => panic!("Unsupported generic type"),
174        })
175        .collect()
176}
177
178struct ExtractedStructInfo {
179    field_names: Vec<Ident>,
180    cf_names: Vec<Ident>,
181    inner_types: Vec<AngleBracketedGenericArguments>,
182    derived_table_options: Vec<GeneralTableOptions>,
183    deprecated_cfs: Vec<Ident>,
184}
185
186#[proc_macro_derive(DBMapUtils, attributes(default_options_override_fn, rename))]
187pub fn derive_dbmap_utils_general(input: TokenStream) -> TokenStream {
188    let input = parse_macro_input!(input as ItemStruct);
189    let name = &input.ident;
190    let generics = &input.generics;
191    let generics_names = extract_generics_names(generics);
192
193    // TODO: use `parse_quote` over `parse()`
194    let ExtractedStructInfo {
195        field_names,
196        cf_names,
197        inner_types,
198        derived_table_options,
199        deprecated_cfs,
200    } = extract_struct_info(input.clone());
201
202    let (key_names, value_names): (Vec<_>, Vec<_>) = inner_types
203        .iter()
204        .map(|q| (q.args.first().unwrap(), q.args.last().unwrap()))
205        .unzip();
206
207    let default_options_override_fn_names: Vec<proc_macro2::TokenStream> = derived_table_options
208        .iter()
209        .map(|q| {
210            let GeneralTableOptions::OverrideFunction(fn_name) = q;
211            fn_name.parse().unwrap()
212        })
213        .collect();
214
215    let generics_bounds =
216        "std::fmt::Debug + serde::Serialize + for<'de> serde::de::Deserialize<'de>";
217    let generics_bounds_token: proc_macro2::TokenStream = generics_bounds.parse().unwrap();
218
219    let config_struct_name_str = format!("{name}Configurator");
220    let config_struct_name: proc_macro2::TokenStream = config_struct_name_str.parse().unwrap();
221
222    let intermediate_db_map_struct_name_str = format!("{name}IntermediateDBMapStructPrimary");
223    let intermediate_db_map_struct_name: proc_macro2::TokenStream =
224        intermediate_db_map_struct_name_str.parse().unwrap();
225
226    let secondary_db_map_struct_name_str = format!("{name}ReadOnly");
227    let secondary_db_map_struct_name: proc_macro2::TokenStream =
228        secondary_db_map_struct_name_str.parse().unwrap();
229
230    TokenStream::from(quote! {
231
232        // <----------- This section generates the configurator struct -------------->
233
234        /// Create config structs for configuring DBMap tables
235        pub struct #config_struct_name {
236            #(
237                pub #field_names : typed_store::rocks::DBOptions,
238            )*
239        }
240
241        impl #config_struct_name {
242            /// Initialize to defaults
243            pub fn init() -> Self {
244                Self {
245                    #(
246                        #field_names : typed_store::rocks::default_db_options(),
247                    )*
248                }
249            }
250
251            /// Build a config
252            pub fn build(&self) -> typed_store::rocks::DBMapTableConfigMap {
253                typed_store::rocks::DBMapTableConfigMap::new([
254                    #(
255                        (stringify!(#field_names).to_owned(), self.#field_names.clone()),
256                    )*
257                ].into_iter().collect())
258            }
259        }
260
261        impl <
262                #(
263                    #generics_names: #generics_bounds_token,
264                )*
265            > #name #generics {
266
267                pub fn configurator() -> #config_struct_name {
268                    #config_struct_name::init()
269                }
270        }
271
272        // <----------- This section generates the core open logic for opening DBMaps -------------->
273
274        /// Create an intermediate struct used to open the DBMap tables in primary mode
275        /// This is only used internally
276        struct #intermediate_db_map_struct_name #generics {
277                #(
278                    pub #field_names : DBMap #inner_types,
279                )*
280        }
281
282
283        impl <
284                #(
285                    #generics_names: #generics_bounds_token,
286                )*
287            > #intermediate_db_map_struct_name #generics {
288            /// Opens a set of tables in read-write mode
289            /// If as_secondary_with_path is set, the DB is opened in read only mode with the path specified
290            pub fn open_tables_impl(
291                path: std::path::PathBuf,
292                as_secondary_with_path: Option<std::path::PathBuf>,
293                is_transaction: bool,
294                metric_conf: typed_store::rocks::MetricConf,
295                global_db_options_override: Option<typed_store::rocksdb::Options>,
296                tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>,
297                remove_deprecated_tables: bool,
298            ) -> Self {
299                let path = &path;
300                let default_cf_opt = if let Some(opt) = global_db_options_override.as_ref() {
301                    typed_store::rocks::DBOptions {
302                        options: opt.clone(),
303                        rw_options: typed_store::rocks::default_db_options().rw_options,
304                    }
305                } else {
306                    typed_store::rocks::default_db_options()
307                };
308                let (db, rwopt_cfs) = {
309                    let opt_cfs = match tables_db_options_override {
310                        None => [
311                            #(
312                                (stringify!(#cf_names).to_owned(), #default_options_override_fn_names()),
313                            )*
314                        ],
315                        Some(o) => [
316                            #(
317                                (stringify!(#cf_names).to_owned(), o.to_map().get(stringify!(#cf_names)).unwrap_or(&default_cf_opt).clone()),
318                            )*
319                        ]
320                    };
321                    // Safe to call unwrap because we will have at least one field_name entry in the struct
322                    let rwopt_cfs: std::collections::HashMap<String, typed_store::rocks::ReadWriteOptions> = opt_cfs.iter().map(|q| (q.0.as_str().to_string(), q.1.rw_options.clone())).collect();
323                    let opt_cfs: Vec<_> = opt_cfs.iter().map(|q| (q.0.as_str(), q.1.options.clone())).collect();
324                    let db = match (as_secondary_with_path.clone(), is_transaction) {
325                        (Some(p), _) => typed_store::rocks::open_cf_opts_secondary(path, Some(&p), global_db_options_override, metric_conf, &opt_cfs),
326                        (_, true) => typed_store::rocks::open_cf_opts_transactional(path, global_db_options_override, metric_conf, &opt_cfs),
327                        _ => typed_store::rocks::open_cf_opts(path, global_db_options_override, metric_conf, &opt_cfs)
328                    };
329                    db.map(|d| (d, rwopt_cfs))
330                }.expect(&format!("Cannot open DB at {:?}", path));
331                let deprecated_tables = vec![#(stringify!(#deprecated_cfs),)*];
332                let (
333                        #(
334                            #field_names
335                        ),*
336                ) = (#(
337                        DBMap::#inner_types::reopen(&db, Some(stringify!(#cf_names)), rwopt_cfs.get(stringify!(#cf_names)).unwrap_or(&typed_store::rocks::ReadWriteOptions::default()), remove_deprecated_tables && deprecated_tables.contains(&stringify!(#cf_names))).expect(&format!("Cannot open {} CF.", stringify!(#cf_names))[..])
338                    ),*);
339
340                if as_secondary_with_path.is_none() && remove_deprecated_tables {
341                    #(
342                        db.drop_cf(stringify!(#deprecated_cfs)).expect("failed to drop a deprecated cf");
343                    )*
344                }
345                Self {
346                    #(
347                        #field_names,
348                    )*
349                }
350            }
351        }
352
353
354        // <----------- This section generates the read-write open logic and other common utils -------------->
355
356        impl <
357                #(
358                    #generics_names: #generics_bounds_token,
359                )*
360            > #name #generics {
361            /// Opens a set of tables in read-write mode
362            /// Only one process is allowed to do this at a time
363            /// `global_db_options_override` apply to the whole DB
364            /// `tables_db_options_override` apply to each table. If `None`, the attributes from `default_options_override_fn` are used if any
365            #[expect(unused_parens)]
366            pub fn open_tables_read_write(
367                path: std::path::PathBuf,
368                metric_conf: typed_store::rocks::MetricConf,
369                global_db_options_override: Option<typed_store::rocksdb::Options>,
370                tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>
371            ) -> Self {
372                let inner = #intermediate_db_map_struct_name::open_tables_impl(path, None, false, metric_conf, global_db_options_override, tables_db_options_override, false);
373                Self {
374                    #(
375                        #field_names: inner.#field_names,
376                    )*
377                }
378            }
379
380            #[expect(unused_parens)]
381            pub fn open_tables_read_write_with_deprecation_option(
382                path: std::path::PathBuf,
383                metric_conf: typed_store::rocks::MetricConf,
384                global_db_options_override: Option<typed_store::rocksdb::Options>,
385                tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>,
386                remove_deprecated_tables: bool,
387            ) -> Self {
388                let inner = #intermediate_db_map_struct_name::open_tables_impl(path, None, false, metric_conf, global_db_options_override, tables_db_options_override, remove_deprecated_tables);
389                Self {
390                    #(
391                        #field_names: inner.#field_names,
392                    )*
393                }
394            }
395
396            /// Opens a set of tables in transactional read-write mode
397            /// Only one process is allowed to do this at a time
398            /// `global_db_options_override` apply to the whole DB
399            /// `tables_db_options_override` apply to each table. If `None`, the attributes from `default_options_override_fn` are used if any
400            #[expect(unused_parens)]
401            pub fn open_tables_transactional(
402                path: std::path::PathBuf,
403                metric_conf: typed_store::rocks::MetricConf,
404                global_db_options_override: Option<typed_store::rocksdb::Options>,
405                tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>
406            ) -> Self {
407                let inner = #intermediate_db_map_struct_name::open_tables_impl(path, None, true, metric_conf, global_db_options_override, tables_db_options_override, false);
408                Self {
409                    #(
410                        #field_names: inner.#field_names,
411                    )*
412                }
413            }
414
415            /// Returns a list of the tables name and type pairs
416            pub fn describe_tables() -> std::collections::BTreeMap<String, (String, String)> {
417                vec![#(
418                    (stringify!(#field_names).to_owned(), (stringify!(#key_names).to_owned(), stringify!(#value_names).to_owned())),
419                )*].into_iter().collect()
420            }
421
422            /// This opens the DB in read only mode and returns a struct which exposes debug features
423            pub fn get_read_only_handle (
424                primary_path: std::path::PathBuf,
425                with_secondary_path: Option<std::path::PathBuf>,
426                global_db_options_override: Option<typed_store::rocksdb::Options>,
427                metric_conf: typed_store::rocks::MetricConf,
428                ) -> #secondary_db_map_struct_name #generics {
429                #secondary_db_map_struct_name::open_tables_read_only(primary_path, with_secondary_path, metric_conf, global_db_options_override)
430            }
431        }
432
433
434        // <----------- This section generates the features that use read-only open logic -------------->
435        /// Create an intermediate struct used to open the DBMap tables in secondary mode
436        /// This is only used internally
437        pub struct #secondary_db_map_struct_name #generics {
438            #(
439                pub #field_names : DBMap #inner_types,
440            )*
441        }
442
443        impl <
444                #(
445                    #generics_names: #generics_bounds_token,
446                )*
447            > #secondary_db_map_struct_name #generics {
448            /// Open in read only mode. No limitation on number of processes to do this
449            pub fn open_tables_read_only(
450                primary_path: std::path::PathBuf,
451                with_secondary_path: Option<std::path::PathBuf>,
452                metric_conf: typed_store::rocks::MetricConf,
453                global_db_options_override: Option<typed_store::rocksdb::Options>,
454            ) -> Self {
455                let inner = match with_secondary_path {
456                    Some(q) => #intermediate_db_map_struct_name::open_tables_impl(primary_path, Some(q), false, metric_conf, global_db_options_override, None, false),
457                    None => {
458                        let p: std::path::PathBuf = tempfile::tempdir()
459                        .expect("Failed to open temporary directory")
460                        .keep();
461                        #intermediate_db_map_struct_name::open_tables_impl(primary_path, Some(p), false, metric_conf, global_db_options_override, None, false)
462                    }
463                };
464                Self {
465                    #(
466                        #field_names: inner.#field_names,
467                    )*
468                }
469            }
470
471            fn cf_name_to_table_name(cf_name: &str) -> eyre::Result<&'static str> {
472                Ok(match cf_name {
473                    #(
474                        stringify!(#cf_names) => stringify!(#field_names),
475                    )*
476                    _ => eyre::bail!("No such cf name: {}", cf_name),
477                })
478            }
479
480            /// Dump all key-value pairs in the page at the given table name
481            /// Tables must be opened in read only mode using `open_tables_read_only`
482            pub fn dump(&self, cf_name: &str, page_size: u16, page_number: usize) -> eyre::Result<std::collections::BTreeMap<String, String>> {
483                let table_name = Self::cf_name_to_table_name(cf_name)?;
484
485                Ok(match table_name {
486                    #(
487                        stringify!(#field_names) => {
488                            typed_store::traits::Map::try_catch_up_with_primary(&self.#field_names)?;
489                            typed_store::traits::Map::unbounded_iter(&self.#field_names)
490                                .skip((page_number * (page_size) as usize))
491                                .take(page_size as usize)
492                                .map(|(k, v)| (format!("{:?}", k), format!("{:?}", v)))
493                                .collect::<std::collections::BTreeMap<_, _>>()
494                        }
495                    )*
496
497                    _ => eyre::bail!("No such table name: {}", table_name),
498                })
499            }
500
501            /// Get key value sizes from the db
502            /// Tables must be opened in read only mode using `open_tables_read_only`
503            pub fn table_summary(&self, table_name: &str) -> eyre::Result<typed_store::traits::TableSummary> {
504                let mut count = 0;
505                let mut key_bytes = 0;
506                let mut value_bytes = 0;
507                match table_name {
508                    #(
509                        stringify!(#field_names) => {
510                            typed_store::traits::Map::try_catch_up_with_primary(&self.#field_names)?;
511                            self.#field_names.table_summary()
512                        }
513                    )*
514
515                    _ => eyre::bail!("No such table name: {}", table_name),
516                }
517            }
518
519            /// Count the keys in this table
520            /// Tables must be opened in read only mode using `open_tables_read_only`
521            pub fn count_keys(&self, table_name: &str) -> eyre::Result<usize> {
522                Ok(match table_name {
523                    #(
524                        stringify!(#field_names) => {
525                            typed_store::traits::Map::try_catch_up_with_primary(&self.#field_names)?;
526                            typed_store::traits::Map::unbounded_iter(&self.#field_names).count()
527                        }
528                    )*
529
530                    _ => eyre::bail!("No such table name: {}", table_name),
531                })
532            }
533
534            pub fn describe_tables() -> std::collections::BTreeMap<String, (String, String)> {
535                vec![#(
536                    (stringify!(#field_names).to_owned(), (stringify!(#key_names).to_owned(), stringify!(#value_names).to_owned())),
537                )*].into_iter().collect()
538            }
539
540            /// Try catch up with primary for all tables. This can be a slow operation
541            /// Tables must be opened in read only mode using `open_tables_read_only`
542            pub fn try_catch_up_with_primary_all(&self) -> eyre::Result<()> {
543                #(
544                    typed_store::traits::Map::try_catch_up_with_primary(&self.#field_names)?;
545                )*
546                Ok(())
547            }
548        }
549
550        impl <
551                #(
552                    #generics_names: #generics_bounds_token,
553                )*
554            > TypedStoreDebug for #secondary_db_map_struct_name #generics {
555                fn dump_table(
556                    &self,
557                    table_name: String,
558                    page_size: u16,
559                    page_number: usize,
560                ) -> eyre::Result<std::collections::BTreeMap<String, String>> {
561                    self.dump(table_name.as_str(), page_size, page_number)
562                }
563
564                fn primary_db_name(&self) -> String {
565                    stringify!(#name).to_owned()
566                }
567
568                fn describe_all_tables(&self) -> std::collections::BTreeMap<String, (String, String)> {
569                    Self::describe_tables()
570                }
571
572                fn count_table_keys(&self, table_name: String) -> eyre::Result<usize> {
573                    self.count_keys(table_name.as_str())
574                }
575
576                fn table_summary(&self, table_name: String) -> eyre::Result<TableSummary> {
577                    self.table_summary(table_name.as_str())
578                }
579        }
580    })
581}