typed_store_derive/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::collections::BTreeMap;
6
7use proc_macro::TokenStream;
8use proc_macro2::Ident;
9use quote::quote;
10use syn::{
11    AngleBracketedGenericArguments, Attribute, Generics, ItemStruct, Lit, Meta, PathArguments,
12    Type::{self},
13    parse_macro_input,
14};
15
// Path of the options function used for a table when the field carries no
// `default_options_override_fn` attribute.
const DEFAULT_DB_OPTIONS_CUSTOM_FN: &str = "typed_store::rocks::default_db_options";
// Field attribute naming a custom function which returns the options and
// overrides the defaults for this table.
// Usage: `#[default_options_override_fn = "function_name"]`
const DB_OPTIONS_CUSTOM_FUNCTION: &str = "default_options_override_fn";
// Field attribute that gives the column family a different name than the
// field identifier.
// Usage: `#[rename = "new_name"]`
const DB_OPTIONS_RENAME: &str = "rename";
// Field attribute that deprecates a column family, with optional migration
// support.
// Usage: `#[deprecated_db_map]` or `#[deprecated_db_map(migration =
// "migration_fn_path")]`
// Hint: we can't use `#[deprecated]` because it doesn't allow us to specify a
// migration parameter.
const DB_OPTIONS_DEPRECATED_TABLE: &str = "deprecated_db_map";
29
/// How the options for a single table are obtained.
///
/// The only supported form is the path of a function that the generated code
/// calls with no arguments to build the table's options.
enum GeneralTableOptions {
    /// Source text of the path to the options function; it is parsed into
    /// tokens and invoked as `path()` by the derive output.
    OverrideFunction(String),
}
34
35impl Default for GeneralTableOptions {
36    fn default() -> Self {
37        Self::OverrideFunction(DEFAULT_DB_OPTIONS_CUSTOM_FN.to_owned())
38    }
39}
40
41/// Parse the migration function path from `#[deprecated_db_map(migration =
42/// "fn_path")]`
43fn parse_deprecated_db_map_migration(attr: &Attribute) -> Option<syn::Path> {
44    match attr.parse_meta() {
45        Ok(Meta::Path(_)) => None, // #[deprecated_db_map] with no args
46        Ok(Meta::List(ml)) => {
47            for nested in &ml.nested {
48                if let syn::NestedMeta::Meta(Meta::NameValue(nv)) = nested {
49                    if nv.path.is_ident("migration") {
50                        if let Lit::Str(s) = &nv.lit {
51                            let fn_path: syn::Path =
52                                s.parse().expect("migration value must be a valid path");
53                            return Some(fn_path);
54                        }
55                    }
56                }
57            }
58            None
59        }
60        _ => None,
61    }
62}
63
64// Extracts the field names, field types, inner types (K,V in {map_type_name}<K,
65// V>), and the options attrs
66fn extract_struct_info(input: ItemStruct) -> ExtractedStructInfo {
67    let mut active_table_fields = TableFields::default();
68    let mut deprecated_table_fields = TableFields::default();
69    let mut deprecated_cfs_with_migration_opts: Vec<(Ident, Option<syn::Path>)> = Vec::new();
70
71    for f in &input.fields {
72        let attrs: BTreeMap<_, _> = f
73            .attrs
74            .iter()
75            .filter(|a| {
76                a.path.is_ident(DB_OPTIONS_CUSTOM_FUNCTION)
77                    || a.path.is_ident(DB_OPTIONS_RENAME)
78                    || a.path.is_ident(DB_OPTIONS_DEPRECATED_TABLE)
79            })
80            .map(|a| (a.path.get_ident().unwrap().to_string(), a))
81            .collect();
82
83        let options = if let Some(options) = attrs.get(DB_OPTIONS_CUSTOM_FUNCTION) {
84            GeneralTableOptions::OverrideFunction(get_options_override_function(options).unwrap())
85        } else {
86            GeneralTableOptions::default()
87        };
88
89        let Type::Path(p) = &f.ty else {
90            panic!("All struct members must be of type DBMap or Option<DBMap>");
91        };
92
93        let is_deprecated = attrs.contains_key(DB_OPTIONS_DEPRECATED_TABLE);
94
95        let type_info = &p.path.segments.first().unwrap();
96
97        // For deprecated fields, unwrap Option<DBMap<K,V>> to extract DBMap<K,V>.
98        // Active fields must be DBMap<K,V> directly.
99        let (db_map_ident_str, inner_type) = if is_deprecated {
100            if type_info.ident == "Option" {
101                // Extract DBMap<K,V> from Option<DBMap<K,V>>
102                let option_args = match &type_info.arguments {
103                    PathArguments::AngleBracketed(ab) => ab,
104                    _ => panic!(
105                        "Expected Option<DBMap<K, V>> for deprecated field `{}`",
106                        f.ident.as_ref().unwrap()
107                    ),
108                };
109                let inner_path = match option_args.args.first() {
110                    Some(syn::GenericArgument::Type(Type::Path(p))) => p,
111                    _ => panic!(
112                        "Expected Option<DBMap<K, V>> for deprecated field `{}`",
113                        f.ident.as_ref().unwrap()
114                    ),
115                };
116                let inner_info = inner_path.path.segments.first().unwrap();
117                let inner_type = match &inner_info.arguments {
118                    PathArguments::AngleBracketed(ab) => ab.clone(),
119                    _ => panic!(
120                        "Expected DBMap<K, V> inside Option for deprecated field `{}`",
121                        f.ident.as_ref().unwrap()
122                    ),
123                };
124                (inner_info.ident.to_string(), inner_type)
125            } else {
126                panic!(
127                    "Deprecated field `{}` must use Option<DBMap<K, V>> instead of DBMap<K, V>",
128                    f.ident.as_ref().unwrap()
129                );
130            }
131        } else {
132            let inner_type =
133                if let PathArguments::AngleBracketed(angle_bracket_type) = &type_info.arguments {
134                    angle_bracket_type.clone()
135                } else {
136                    panic!("All struct members must be of type DBMap");
137                };
138            (type_info.ident.to_string(), inner_type)
139        };
140
141        assert!(
142            db_map_ident_str == "DBMap",
143            "All struct members must be of type DBMap (or Option<DBMap> for deprecated fields)"
144        );
145
146        let field_name = f.ident.as_ref().unwrap().clone();
147        let cf_name = if let Some(rename) = attrs.get(DB_OPTIONS_RENAME) {
148            match rename.parse_meta().expect("Cannot parse meta of attribute") {
149                Meta::NameValue(val) => {
150                    if let Lit::Str(s) = val.lit {
151                        // convert to ident
152                        s.parse().expect("Rename value must be identifier")
153                    } else {
154                        panic!("Expected string value for rename")
155                    }
156                }
157                _ => panic!("Expected string value for rename"),
158            }
159        } else {
160            field_name.clone()
161        };
162
163        // None: active cf
164        // Some(None): deprecated cf with no migration
165        // Some(Some(fn_path)): deprecated cf with migration
166        let deprecated_cf_with_migration_opt = attrs
167            .get(DB_OPTIONS_DEPRECATED_TABLE)
168            .map(|attr| parse_deprecated_db_map_migration(attr));
169
170        let target_table_fields = if let Some(migration_opt) = deprecated_cf_with_migration_opt {
171            deprecated_cfs_with_migration_opts.push((cf_name.clone(), migration_opt));
172            &mut deprecated_table_fields
173        } else {
174            &mut active_table_fields
175        };
176
177        target_table_fields.field_names.push(field_name);
178        target_table_fields.cf_names.push(cf_name);
179        target_table_fields.inner_types.push(inner_type);
180        target_table_fields.derived_table_options.push(options);
181    }
182
183    ExtractedStructInfo {
184        active_table_fields,
185        deprecated_table_fields,
186        deprecated_cfs_with_migration_opts,
187    }
188}
189
190/// Extracts the table options override function
191/// The function must take no args and return Options
192fn get_options_override_function(attr: &Attribute) -> syn::Result<String> {
193    let meta = attr.parse_meta()?;
194
195    let val = match meta.clone() {
196        Meta::NameValue(val) => val,
197        _ => {
198            return Err(syn::Error::new_spanned(
199                meta,
200                format!(
201                    "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
202                ),
203            ));
204        }
205    };
206
207    if !val.path.is_ident(DB_OPTIONS_CUSTOM_FUNCTION) {
208        return Err(syn::Error::new_spanned(
209            meta,
210            format!(
211                "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
212            ),
213        ));
214    }
215
216    let fn_name = match val.lit {
217        Lit::Str(fn_name) => fn_name,
218        _ => {
219            return Err(syn::Error::new_spanned(
220                meta,
221                format!(
222                    "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
223                ),
224            ));
225        }
226    };
227    Ok(fn_name.value())
228}
229
230fn extract_generics_names(generics: &Generics) -> Vec<Ident> {
231    generics
232        .params
233        .iter()
234        .map(|g| match g {
235            syn::GenericParam::Type(t) => t.ident.clone(),
236            _ => panic!("Unsupported generic type"),
237        })
238        .collect()
239}
240
/// Parallel vecs describing a set of DB column families.
///
/// Entries are pushed together, one per struct field, so index `i` in every
/// vec refers to the same field.
#[derive(Default)]
struct TableFields {
    /// Identifier of the struct field.
    field_names: Vec<Ident>,
    /// Column family name: the `#[rename = "..."]` value if present, otherwise
    /// the field identifier.
    cf_names: Vec<Ident>,
    /// The `<K, V>` generic arguments of the field's `DBMap<K, V>` type.
    inner_types: Vec<AngleBracketedGenericArguments>,
    /// How the table's options are obtained.
    derived_table_options: Vec<GeneralTableOptions>,
}
249
/// Result of analysing the fields of the struct the derive is applied to.
struct ExtractedStructInfo {
    /// Active (non-deprecated) tables
    active_table_fields: TableFields,
    /// Deprecated tables (fields tagged `#[deprecated_db_map]`)
    deprecated_table_fields: TableFields,
    /// CF names of the deprecated tables paired with their optional migration
    /// function paths (`None` when no migration was specified)
    deprecated_cfs_with_migration_opts: Vec<(Ident, Option<syn::Path>)>,
}
258
259#[proc_macro_derive(
260    DBMapUtils,
261    attributes(default_options_override_fn, rename, deprecated_db_map)
262)]
263pub fn derive_dbmap_utils_general(input: TokenStream) -> TokenStream {
264    let input = parse_macro_input!(input as ItemStruct);
265    let name = &input.ident;
266    let generics = &input.generics;
267    let generics_names = extract_generics_names(generics);
268
269    // TODO: use `parse_quote` over `parse()`
270    let ExtractedStructInfo {
271        active_table_fields,
272        deprecated_table_fields,
273        deprecated_cfs_with_migration_opts,
274    } = extract_struct_info(input.clone());
275
276    // Bind fields for use in quote! macro
277    let active_field_names = &active_table_fields.field_names;
278    let active_cf_names = &active_table_fields.cf_names;
279    let active_inner_types = &active_table_fields.inner_types;
280    let deprecated_field_names = &deprecated_table_fields.field_names;
281    let deprecated_cf_names = &deprecated_table_fields.cf_names;
282    let deprecated_inner_types = &deprecated_table_fields.inner_types;
283
284    // Combined field names for struct definitions (need all fields, active +
285    // deprecated)
286    let all_field_names: Vec<_> = active_field_names
287        .iter()
288        .chain(deprecated_field_names.iter())
289        .collect();
290
291    let active_default_options_override_fn_names: Vec<proc_macro2::TokenStream> =
292        active_table_fields
293            .derived_table_options
294            .iter()
295            .map(|q| {
296                let GeneralTableOptions::OverrideFunction(fn_name) = q;
297                fn_name.parse().unwrap()
298            })
299            .collect();
300
301    let active_key_names: Vec<_> = active_inner_types
302        .iter()
303        .map(|q| q.args.first().unwrap())
304        .collect();
305    let active_value_names: Vec<_> = active_inner_types
306        .iter()
307        .map(|q| q.args.last().unwrap())
308        .collect();
309
310    let generics_bounds =
311        "std::fmt::Debug + serde::Serialize + for<'de> serde::de::Deserialize<'de>";
312    let generics_bounds_token: proc_macro2::TokenStream = generics_bounds.parse().unwrap();
313
314    let intermediate_db_map_struct_name_str = format!("{name}IntermediateDBMapStructPrimary");
315    let intermediate_db_map_struct_name: proc_macro2::TokenStream =
316        intermediate_db_map_struct_name_str.parse().unwrap();
317
318    let secondary_db_map_struct_name_str = format!("{name}ReadOnly");
319    let secondary_db_map_struct_name: proc_macro2::TokenStream =
320        secondary_db_map_struct_name_str.parse().unwrap();
321
322    // Generate deprecation cleanup code: for each deprecated CF, optionally run
323    // migration then drop the CF. Uses cf_handle check for idempotency.
324    // Note: Migration functions are only called when the CF exists on disk.
325    // They receive `&Arc<RocksDB>` and may assume the deprecated CF is present.
326    let deprecation_cleanup: Vec<proc_macro2::TokenStream> = deprecated_cfs_with_migration_opts
327        .iter()
328        .map(|(cf_name, migration)| {
329            let migration_call = if let Some(fn_path) = migration {
330                quote! { #fn_path(&db).expect("deprecated table migration failed"); }
331            } else {
332                quote! {}
333            };
334            quote! {
335                if db.cf_handle(stringify!(#cf_name)).is_some() {
336                    #migration_call
337                    db.drop_cf(stringify!(#cf_name)).expect("failed to drop a deprecated cf");
338                }
339            }
340        })
341        .collect();
342
343    TokenStream::from(quote! {
344
345        // <----------- This section generates the core open logic for opening DBMaps -------------->
346
347        /// Create an intermediate struct used to open the DBMap tables in primary mode
348        /// This is only used internally
349        struct #intermediate_db_map_struct_name #generics {
350                #(
351                    pub #active_field_names : DBMap #active_inner_types,
352                )*
353                #(
354                    pub #deprecated_field_names : Option<DBMap #deprecated_inner_types>,
355                )*
356        }
357
358        impl <
359                #(
360                    #generics_names: #generics_bounds_token,
361                )*
362            > #intermediate_db_map_struct_name #generics {
363            /// Opens a set of tables in read-write mode
364            /// If as_secondary_with_path is set, the DB is opened in read only mode with the path specified
365            pub fn open_tables_impl(
366                path: std::path::PathBuf,
367                as_secondary_with_path: Option<std::path::PathBuf>,
368                metric_conf: typed_store::rocks::MetricConf,
369                global_db_options_override: Option<typed_store::rocksdb::Options>,
370                tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>,
371            ) -> Self {
372                let path = &path;
373                let default_cf_opt = if let Some(opt) = global_db_options_override.as_ref() {
374                    typed_store::rocks::DBOptions {
375                        options: opt.clone(),
376                        rw_options: typed_store::rocks::default_db_options().rw_options,
377                    }
378                } else {
379                    typed_store::rocks::default_db_options()
380                };
381                let (db, rwopt_cfs) = {
382                    let opt_cfs = match tables_db_options_override {
383                        None => [
384                            #(
385                                (stringify!(#active_cf_names).to_owned(), #active_default_options_override_fn_names()),
386                            )*
387                        ],
388                        Some(o) => [
389                            #(
390                                (stringify!(#active_cf_names).to_owned(), o.to_map().get(stringify!(#active_cf_names)).unwrap_or(&default_cf_opt).clone()),
391                            )*
392                        ]
393                    };
394                    // Safe to call unwrap because we will have at least one field_name entry in the struct
395                    let rwopt_cfs: std::collections::HashMap<String, typed_store::rocks::ReadWriteOptions> = opt_cfs.iter().map(|q| (q.0.as_str().to_string(), q.1.rw_options.clone())).collect();
396                    let opt_cfs: Vec<_> = opt_cfs.iter().map(|q| (q.0.as_str(), q.1.options.clone())).collect();
397                    let db = match as_secondary_with_path.clone() {
398                        Some(p) => typed_store::rocks::open_cf_opts_secondary(path, Some(&p), global_db_options_override, metric_conf, &opt_cfs),
399                        _ => typed_store::rocks::open_cf_opts(path, global_db_options_override, metric_conf, &opt_cfs)
400                    };
401                    db.map(|d| (d, rwopt_cfs))
402                }.expect(&format!("Cannot open DB at {:?}", path));
403                let (
404                        #(
405                            #active_field_names
406                        ),*
407                ) = (#(
408                        DBMap::#active_inner_types::reopen(&db, Some(stringify!(#active_cf_names)), rwopt_cfs.get(stringify!(#active_cf_names)).unwrap_or(&typed_store::rocks::ReadWriteOptions::default()), false).expect(&format!("Cannot open {} CF.", stringify!(#active_cf_names))[..])
409                    ),*);
410
411                // Open deprecated CFs only if they exist on disk.
412                // `mut` is needed because primary mode reassigns to `None` after cleanup,
413                // but in secondary mode no reassignment happens — hence `allow(unused_mut)`.
414                #(
415                    #[allow(unused_mut)]
416                    let mut #deprecated_field_names = if db.cf_handle(stringify!(#deprecated_cf_names)).is_some() {
417                        Some(DBMap::#deprecated_inner_types::reopen(
418                            &db,
419                            Some(stringify!(#deprecated_cf_names)),
420                            rwopt_cfs.get(stringify!(#deprecated_cf_names)).unwrap_or(&typed_store::rocks::ReadWriteOptions::default()),
421                            true,
422                        ).expect(&format!("Cannot open deprecated {} CF.", stringify!(#deprecated_cf_names))[..]))
423                    } else {
424                        None
425                    };
426                )*
427
428                if as_secondary_with_path.is_none() {
429                    #(#deprecation_cleanup)*
430                    // After cleanup, CF handles for deprecated tables are stale
431                    #(
432                        #deprecated_field_names = None;
433                    )*
434                }
435                Self {
436                    #(
437                        #all_field_names,
438                    )*
439                }
440            }
441        }
442
443
444        // <----------- This section generates the read-write open logic and other common utils -------------->
445
446        impl <
447                #(
448                    #generics_names: #generics_bounds_token,
449                )*
450            > #name #generics {
451            /// Opens a set of tables in read-write mode
452            /// Only one process is allowed to do this at a time
453            /// `global_db_options_override` apply to the whole DB
454            /// `tables_db_options_override` apply to each table. If `None`, the attributes from `default_options_override_fn` are used if any
455            #[expect(unused_parens)]
456            pub fn open_tables_read_write(
457                path: std::path::PathBuf,
458                metric_conf: typed_store::rocks::MetricConf,
459                global_db_options_override: Option<typed_store::rocksdb::Options>,
460                tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>
461            ) -> Self {
462                let inner = #intermediate_db_map_struct_name::open_tables_impl(path, None, metric_conf, global_db_options_override, tables_db_options_override);
463                Self {
464                    #(
465                        #all_field_names: inner.#all_field_names,
466                    )*
467                }
468            }
469
470            /// Returns a list of the tables name and type pairs
471            pub fn describe_tables() -> std::collections::BTreeMap<String, (String, String)> {
472                vec![#(
473                    (stringify!(#active_cf_names).to_owned(), (stringify!(#active_key_names).to_owned(), stringify!(#active_value_names).to_owned())),
474                )*].into_iter().collect()
475            }
476
477            /// This opens the DB in read only mode and returns a struct which exposes debug features
478            pub fn get_read_only_handle (
479                primary_path: std::path::PathBuf,
480                with_secondary_path: Option<std::path::PathBuf>,
481                global_db_options_override: Option<typed_store::rocksdb::Options>,
482                metric_conf: typed_store::rocks::MetricConf,
483                ) -> #secondary_db_map_struct_name #generics {
484                #secondary_db_map_struct_name::open_tables_read_only(primary_path, with_secondary_path, metric_conf, global_db_options_override)
485            }
486        }
487
488
489        // <----------- This section generates the features that use read-only open logic -------------->
490        /// Create an intermediate struct used to open the DBMap tables in secondary mode
491        /// This is only used internally
492        pub struct #secondary_db_map_struct_name #generics {
493            #(
494                pub #active_field_names : DBMap #active_inner_types,
495            )*
496            #(
497                pub #deprecated_field_names : Option<DBMap #deprecated_inner_types>,
498            )*
499        }
500
501        impl <
502                #(
503                    #generics_names: #generics_bounds_token,
504                )*
505            > #secondary_db_map_struct_name #generics {
506            /// Open in read only mode. No limitation on number of processes to do this
507            pub fn open_tables_read_only(
508                primary_path: std::path::PathBuf,
509                with_secondary_path: Option<std::path::PathBuf>,
510                metric_conf: typed_store::rocks::MetricConf,
511                global_db_options_override: Option<typed_store::rocksdb::Options>,
512            ) -> Self {
513                let inner = match with_secondary_path {
514                    Some(q) => #intermediate_db_map_struct_name::open_tables_impl(primary_path, Some(q), metric_conf, global_db_options_override, None),
515                    None => {
516                        let p: std::path::PathBuf = tempfile::tempdir()
517                        .expect("Failed to open temporary directory")
518                        .keep();
519                        #intermediate_db_map_struct_name::open_tables_impl(primary_path, Some(p), metric_conf, global_db_options_override, None)
520                    }
521                };
522                Self {
523                    #(
524                        #all_field_names: inner.#all_field_names,
525                    )*
526                }
527            }
528
529            fn cf_name_to_table_name(cf_name: &str) -> eyre::Result<&'static str> {
530                Ok(match cf_name {
531                    #(
532                        stringify!(#active_cf_names) => stringify!(#active_field_names),
533                    )*
534                    _ => eyre::bail!("No such cf name: {}", cf_name),
535                })
536            }
537
538            /// Dump all key-value pairs in the page at the given table name
539            /// Tables must be opened in read only mode using `open_tables_read_only`
540            pub fn dump(&self, cf_name: &str, page_size: u16, page_number: usize) -> eyre::Result<std::collections::BTreeMap<String, String>> {
541                let table_name = Self::cf_name_to_table_name(cf_name)?;
542
543                Ok(match table_name {
544                    #(
545                        stringify!(#active_field_names) => {
546                            typed_store::traits::Map::try_catch_up_with_primary(&self.#active_field_names)?;
547                            typed_store::traits::Map::safe_iter(&self.#active_field_names)
548                                .skip((page_number * (page_size) as usize))
549                                .take(page_size as usize)
550                                .map(|result| result.map(|(k, v)| (format!("{:?}", k), format!("{:?}", v))))
551                                .collect::<eyre::Result<std::collections::BTreeMap<_, _>, _>>()?
552                        }
553                    )*
554
555                    _ => eyre::bail!("No such table name: {}", table_name),
556                })
557            }
558
559            /// Get key value sizes from the db
560            /// Tables must be opened in read only mode using `open_tables_read_only`
561            pub fn table_summary(&self, table_name: &str) -> eyre::Result<typed_store::traits::TableSummary> {
562                let mut count = 0;
563                let mut key_bytes = 0;
564                let mut value_bytes = 0;
565                match table_name {
566                    #(
567                        stringify!(#active_field_names) => {
568                            typed_store::traits::Map::try_catch_up_with_primary(&self.#active_field_names)?;
569                            self.#active_field_names.table_summary()
570                        }
571                    )*
572
573                    _ => eyre::bail!("No such table name: {}", table_name),
574                }
575            }
576
577            pub fn describe_tables() -> std::collections::BTreeMap<String, (String, String)> {
578                vec![#(
579                    (stringify!(#active_cf_names).to_owned(), (stringify!(#active_key_names).to_owned(), stringify!(#active_value_names).to_owned())),
580                )*].into_iter().collect()
581            }
582
583            /// Try catch up with primary for all tables. This can be a slow operation
584            /// Tables must be opened in read only mode using `open_tables_read_only`
585            pub fn try_catch_up_with_primary_all(&self) -> eyre::Result<()> {
586                #(
587                    typed_store::traits::Map::try_catch_up_with_primary(&self.#active_field_names)?;
588                )*
589                Ok(())
590            }
591        }
592    })
593}