1use std::collections::BTreeMap;
6
7use proc_macro::TokenStream;
8use proc_macro2::Ident;
9use quote::quote;
10use syn::{
11 AngleBracketedGenericArguments, Attribute, Generics, ItemStruct, Lit, Meta, PathArguments,
12 Type::{self},
13 parse_macro_input,
14};
15
/// Function used to build a table's options when its field has no
/// `#[default_options_override_fn]` attribute.
const DEFAULT_DB_OPTIONS_CUSTOM_FN: &str = "typed_store::rocks::default_db_options";
/// Field attribute naming a function that supplies a table's options:
/// `#[default_options_override_fn = "path::to::fn"]`.
const DB_OPTIONS_CUSTOM_FUNCTION: &str = "default_options_override_fn";
/// Field attribute overriding the column family name: `#[rename = "cf_name"]`.
const DB_OPTIONS_RENAME: &str = "rename";
/// Field attribute marking a table as deprecated, optionally with a migration hook:
/// `#[deprecated_db_map]` or `#[deprecated_db_map(migration = "path::to::fn")]`.
const DB_OPTIONS_DEPRECATED_TABLE: &str = "deprecated_db_map";
29
30enum GeneralTableOptions {
32 OverrideFunction(String),
33}
34
35impl Default for GeneralTableOptions {
36 fn default() -> Self {
37 Self::OverrideFunction(DEFAULT_DB_OPTIONS_CUSTOM_FN.to_owned())
38 }
39}
40
41fn parse_deprecated_db_map_migration(attr: &Attribute) -> Option<syn::Path> {
44 match attr.parse_meta() {
45 Ok(Meta::Path(_)) => None, Ok(Meta::List(ml)) => {
47 for nested in &ml.nested {
48 if let syn::NestedMeta::Meta(Meta::NameValue(nv)) = nested {
49 if nv.path.is_ident("migration") {
50 if let Lit::Str(s) = &nv.lit {
51 let fn_path: syn::Path =
52 s.parse().expect("migration value must be a valid path");
53 return Some(fn_path);
54 }
55 }
56 }
57 }
58 None
59 }
60 _ => None,
61 }
62}
63
64fn extract_struct_info(input: ItemStruct) -> ExtractedStructInfo {
67 let mut active_table_fields = TableFields::default();
68 let mut deprecated_table_fields = TableFields::default();
69 let mut deprecated_cfs_with_migration_opts: Vec<(Ident, Option<syn::Path>)> = Vec::new();
70
71 for f in &input.fields {
72 let attrs: BTreeMap<_, _> = f
73 .attrs
74 .iter()
75 .filter(|a| {
76 a.path.is_ident(DB_OPTIONS_CUSTOM_FUNCTION)
77 || a.path.is_ident(DB_OPTIONS_RENAME)
78 || a.path.is_ident(DB_OPTIONS_DEPRECATED_TABLE)
79 })
80 .map(|a| (a.path.get_ident().unwrap().to_string(), a))
81 .collect();
82
83 let options = if let Some(options) = attrs.get(DB_OPTIONS_CUSTOM_FUNCTION) {
84 GeneralTableOptions::OverrideFunction(get_options_override_function(options).unwrap())
85 } else {
86 GeneralTableOptions::default()
87 };
88
89 let Type::Path(p) = &f.ty else {
90 panic!("All struct members must be of type DBMap or Option<DBMap>");
91 };
92
93 let is_deprecated = attrs.contains_key(DB_OPTIONS_DEPRECATED_TABLE);
94
95 let type_info = &p.path.segments.first().unwrap();
96
97 let (db_map_ident_str, inner_type) = if is_deprecated {
100 if type_info.ident == "Option" {
101 let option_args = match &type_info.arguments {
103 PathArguments::AngleBracketed(ab) => ab,
104 _ => panic!(
105 "Expected Option<DBMap<K, V>> for deprecated field `{}`",
106 f.ident.as_ref().unwrap()
107 ),
108 };
109 let inner_path = match option_args.args.first() {
110 Some(syn::GenericArgument::Type(Type::Path(p))) => p,
111 _ => panic!(
112 "Expected Option<DBMap<K, V>> for deprecated field `{}`",
113 f.ident.as_ref().unwrap()
114 ),
115 };
116 let inner_info = inner_path.path.segments.first().unwrap();
117 let inner_type = match &inner_info.arguments {
118 PathArguments::AngleBracketed(ab) => ab.clone(),
119 _ => panic!(
120 "Expected DBMap<K, V> inside Option for deprecated field `{}`",
121 f.ident.as_ref().unwrap()
122 ),
123 };
124 (inner_info.ident.to_string(), inner_type)
125 } else {
126 panic!(
127 "Deprecated field `{}` must use Option<DBMap<K, V>> instead of DBMap<K, V>",
128 f.ident.as_ref().unwrap()
129 );
130 }
131 } else {
132 let inner_type =
133 if let PathArguments::AngleBracketed(angle_bracket_type) = &type_info.arguments {
134 angle_bracket_type.clone()
135 } else {
136 panic!("All struct members must be of type DBMap");
137 };
138 (type_info.ident.to_string(), inner_type)
139 };
140
141 assert!(
142 db_map_ident_str == "DBMap",
143 "All struct members must be of type DBMap (or Option<DBMap> for deprecated fields)"
144 );
145
146 let field_name = f.ident.as_ref().unwrap().clone();
147 let cf_name = if let Some(rename) = attrs.get(DB_OPTIONS_RENAME) {
148 match rename.parse_meta().expect("Cannot parse meta of attribute") {
149 Meta::NameValue(val) => {
150 if let Lit::Str(s) = val.lit {
151 s.parse().expect("Rename value must be identifier")
153 } else {
154 panic!("Expected string value for rename")
155 }
156 }
157 _ => panic!("Expected string value for rename"),
158 }
159 } else {
160 field_name.clone()
161 };
162
163 let deprecated_cf_with_migration_opt = attrs
167 .get(DB_OPTIONS_DEPRECATED_TABLE)
168 .map(|attr| parse_deprecated_db_map_migration(attr));
169
170 let target_table_fields = if let Some(migration_opt) = deprecated_cf_with_migration_opt {
171 deprecated_cfs_with_migration_opts.push((cf_name.clone(), migration_opt));
172 &mut deprecated_table_fields
173 } else {
174 &mut active_table_fields
175 };
176
177 target_table_fields.field_names.push(field_name);
178 target_table_fields.cf_names.push(cf_name);
179 target_table_fields.inner_types.push(inner_type);
180 target_table_fields.derived_table_options.push(options);
181 }
182
183 ExtractedStructInfo {
184 active_table_fields,
185 deprecated_table_fields,
186 deprecated_cfs_with_migration_opts,
187 }
188}
189
190fn get_options_override_function(attr: &Attribute) -> syn::Result<String> {
193 let meta = attr.parse_meta()?;
194
195 let val = match meta.clone() {
196 Meta::NameValue(val) => val,
197 _ => {
198 return Err(syn::Error::new_spanned(
199 meta,
200 format!(
201 "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
202 ),
203 ));
204 }
205 };
206
207 if !val.path.is_ident(DB_OPTIONS_CUSTOM_FUNCTION) {
208 return Err(syn::Error::new_spanned(
209 meta,
210 format!(
211 "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
212 ),
213 ));
214 }
215
216 let fn_name = match val.lit {
217 Lit::Str(fn_name) => fn_name,
218 _ => {
219 return Err(syn::Error::new_spanned(
220 meta,
221 format!(
222 "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
223 ),
224 ));
225 }
226 };
227 Ok(fn_name.value())
228}
229
230fn extract_generics_names(generics: &Generics) -> Vec<Ident> {
231 generics
232 .params
233 .iter()
234 .map(|g| match g {
235 syn::GenericParam::Type(t) => t.ident.clone(),
236 _ => panic!("Unsupported generic type"),
237 })
238 .collect()
239}
240
/// Per-field data for one group of tables (active or deprecated).
///
/// The four vectors are always pushed together, so index `i` in each describes the same field.
#[derive(Default)]
struct TableFields {
    /// Struct field identifiers.
    field_names: Vec<Ident>,
    /// Column family names: the field name, or the value of `#[rename = "..."]`.
    cf_names: Vec<Ident>,
    /// The `<K, V>` generic arguments of each field's `DBMap`.
    inner_types: Vec<AngleBracketedGenericArguments>,
    /// Options derived from `#[default_options_override_fn]`, or the default when absent.
    derived_table_options: Vec<GeneralTableOptions>,
}

/// Everything `extract_struct_info` gathers from the derive input.
struct ExtractedStructInfo {
    /// Fields without `#[deprecated_db_map]`.
    active_table_fields: TableFields,
    /// Fields marked `#[deprecated_db_map]` (typed `Option<DBMap<K, V>>`).
    deprecated_table_fields: TableFields,
    /// Column family name and optional migration function of each deprecated field, in field
    /// order.
    deprecated_cfs_with_migration_opts: Vec<(Ident, Option<syn::Path>)>,
}
258
259#[proc_macro_derive(
260 DBMapUtils,
261 attributes(default_options_override_fn, rename, deprecated_db_map)
262)]
263pub fn derive_dbmap_utils_general(input: TokenStream) -> TokenStream {
264 let input = parse_macro_input!(input as ItemStruct);
265 let name = &input.ident;
266 let generics = &input.generics;
267 let generics_names = extract_generics_names(generics);
268
269 let ExtractedStructInfo {
271 active_table_fields,
272 deprecated_table_fields,
273 deprecated_cfs_with_migration_opts,
274 } = extract_struct_info(input.clone());
275
276 let active_field_names = &active_table_fields.field_names;
278 let active_cf_names = &active_table_fields.cf_names;
279 let active_inner_types = &active_table_fields.inner_types;
280 let deprecated_field_names = &deprecated_table_fields.field_names;
281 let deprecated_cf_names = &deprecated_table_fields.cf_names;
282 let deprecated_inner_types = &deprecated_table_fields.inner_types;
283
284 let all_field_names: Vec<_> = active_field_names
287 .iter()
288 .chain(deprecated_field_names.iter())
289 .collect();
290
291 let active_default_options_override_fn_names: Vec<proc_macro2::TokenStream> =
292 active_table_fields
293 .derived_table_options
294 .iter()
295 .map(|q| {
296 let GeneralTableOptions::OverrideFunction(fn_name) = q;
297 fn_name.parse().unwrap()
298 })
299 .collect();
300
301 let active_key_names: Vec<_> = active_inner_types
302 .iter()
303 .map(|q| q.args.first().unwrap())
304 .collect();
305 let active_value_names: Vec<_> = active_inner_types
306 .iter()
307 .map(|q| q.args.last().unwrap())
308 .collect();
309
310 let generics_bounds =
311 "std::fmt::Debug + serde::Serialize + for<'de> serde::de::Deserialize<'de>";
312 let generics_bounds_token: proc_macro2::TokenStream = generics_bounds.parse().unwrap();
313
314 let config_struct_name_str = format!("{name}Configurator");
315 let config_struct_name: proc_macro2::TokenStream = config_struct_name_str.parse().unwrap();
316
317 let intermediate_db_map_struct_name_str = format!("{name}IntermediateDBMapStructPrimary");
318 let intermediate_db_map_struct_name: proc_macro2::TokenStream =
319 intermediate_db_map_struct_name_str.parse().unwrap();
320
321 let secondary_db_map_struct_name_str = format!("{name}ReadOnly");
322 let secondary_db_map_struct_name: proc_macro2::TokenStream =
323 secondary_db_map_struct_name_str.parse().unwrap();
324
325 let deprecation_cleanup: Vec<proc_macro2::TokenStream> = deprecated_cfs_with_migration_opts
330 .iter()
331 .map(|(cf_name, migration)| {
332 let migration_call = if let Some(fn_path) = migration {
333 quote! { #fn_path(&db).expect("deprecated table migration failed"); }
334 } else {
335 quote! {}
336 };
337 quote! {
338 if db.cf_handle(stringify!(#cf_name)).is_some() {
339 #migration_call
340 db.drop_cf(stringify!(#cf_name)).expect("failed to drop a deprecated cf");
341 }
342 }
343 })
344 .collect();
345
346 TokenStream::from(quote! {
347
348 pub struct #config_struct_name {
353 #(
354 pub #active_field_names : typed_store::rocks::DBOptions,
355 )*
356 }
357
358 impl #config_struct_name {
359 pub fn init() -> Self {
361 Self {
362 #(
363 #active_field_names : typed_store::rocks::default_db_options(),
364 )*
365 }
366 }
367
368 pub fn build(&self) -> typed_store::rocks::DBMapTableConfigMap {
370 typed_store::rocks::DBMapTableConfigMap::new([
371 #(
372 (stringify!(#active_field_names).to_owned(), self.#active_field_names.clone()),
373 )*
374 ].into_iter().collect())
375 }
376 }
377
378 impl <
379 #(
380 #generics_names: #generics_bounds_token,
381 )*
382 > #name #generics {
383
384 pub fn configurator() -> #config_struct_name {
385 #config_struct_name::init()
386 }
387 }
388
389 struct #intermediate_db_map_struct_name #generics {
394 #(
395 pub #active_field_names : DBMap #active_inner_types,
396 )*
397 #(
398 pub #deprecated_field_names : Option<DBMap #deprecated_inner_types>,
399 )*
400 }
401
402
403 impl <
404 #(
405 #generics_names: #generics_bounds_token,
406 )*
407 > #intermediate_db_map_struct_name #generics {
408 pub fn open_tables_impl(
411 path: std::path::PathBuf,
412 as_secondary_with_path: Option<std::path::PathBuf>,
413 metric_conf: typed_store::rocks::MetricConf,
414 global_db_options_override: Option<typed_store::rocksdb::Options>,
415 tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>,
416 ) -> Self {
417 let path = &path;
418 let default_cf_opt = if let Some(opt) = global_db_options_override.as_ref() {
419 typed_store::rocks::DBOptions {
420 options: opt.clone(),
421 rw_options: typed_store::rocks::default_db_options().rw_options,
422 }
423 } else {
424 typed_store::rocks::default_db_options()
425 };
426 let (db, rwopt_cfs) = {
427 let opt_cfs = match tables_db_options_override {
428 None => [
429 #(
430 (stringify!(#active_cf_names).to_owned(), #active_default_options_override_fn_names()),
431 )*
432 ],
433 Some(o) => [
434 #(
435 (stringify!(#active_cf_names).to_owned(), o.to_map().get(stringify!(#active_cf_names)).unwrap_or(&default_cf_opt).clone()),
436 )*
437 ]
438 };
439 let rwopt_cfs: std::collections::HashMap<String, typed_store::rocks::ReadWriteOptions> = opt_cfs.iter().map(|q| (q.0.as_str().to_string(), q.1.rw_options.clone())).collect();
441 let opt_cfs: Vec<_> = opt_cfs.iter().map(|q| (q.0.as_str(), q.1.options.clone())).collect();
442 let db = match as_secondary_with_path.clone() {
443 Some(p) => typed_store::rocks::open_cf_opts_secondary(path, Some(&p), global_db_options_override, metric_conf, &opt_cfs),
444 _ => typed_store::rocks::open_cf_opts(path, global_db_options_override, metric_conf, &opt_cfs)
445 };
446 db.map(|d| (d, rwopt_cfs))
447 }.expect(&format!("Cannot open DB at {:?}", path));
448 let (
449 #(
450 #active_field_names
451 ),*
452 ) = (#(
453 DBMap::#active_inner_types::reopen(&db, Some(stringify!(#active_cf_names)), rwopt_cfs.get(stringify!(#active_cf_names)).unwrap_or(&typed_store::rocks::ReadWriteOptions::default()), false).expect(&format!("Cannot open {} CF.", stringify!(#active_cf_names))[..])
454 ),*);
455
456 #(
460 #[allow(unused_mut)]
461 let mut #deprecated_field_names = if db.cf_handle(stringify!(#deprecated_cf_names)).is_some() {
462 Some(DBMap::#deprecated_inner_types::reopen(
463 &db,
464 Some(stringify!(#deprecated_cf_names)),
465 rwopt_cfs.get(stringify!(#deprecated_cf_names)).unwrap_or(&typed_store::rocks::ReadWriteOptions::default()),
466 true,
467 ).expect(&format!("Cannot open deprecated {} CF.", stringify!(#deprecated_cf_names))[..]))
468 } else {
469 None
470 };
471 )*
472
473 if as_secondary_with_path.is_none() {
474 #(#deprecation_cleanup)*
475 #(
477 #deprecated_field_names = None;
478 )*
479 }
480 Self {
481 #(
482 #all_field_names,
483 )*
484 }
485 }
486 }
487
488
489 impl <
492 #(
493 #generics_names: #generics_bounds_token,
494 )*
495 > #name #generics {
496 #[expect(unused_parens)]
501 pub fn open_tables_read_write(
502 path: std::path::PathBuf,
503 metric_conf: typed_store::rocks::MetricConf,
504 global_db_options_override: Option<typed_store::rocksdb::Options>,
505 tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>
506 ) -> Self {
507 let inner = #intermediate_db_map_struct_name::open_tables_impl(path, None, metric_conf, global_db_options_override, tables_db_options_override);
508 Self {
509 #(
510 #all_field_names: inner.#all_field_names,
511 )*
512 }
513 }
514
515 pub fn describe_tables() -> std::collections::BTreeMap<String, (String, String)> {
517 vec![#(
518 (stringify!(#active_field_names).to_owned(), (stringify!(#active_key_names).to_owned(), stringify!(#active_value_names).to_owned())),
519 )*].into_iter().collect()
520 }
521
522 pub fn get_read_only_handle (
524 primary_path: std::path::PathBuf,
525 with_secondary_path: Option<std::path::PathBuf>,
526 global_db_options_override: Option<typed_store::rocksdb::Options>,
527 metric_conf: typed_store::rocks::MetricConf,
528 ) -> #secondary_db_map_struct_name #generics {
529 #secondary_db_map_struct_name::open_tables_read_only(primary_path, with_secondary_path, metric_conf, global_db_options_override)
530 }
531 }
532
533
534 pub struct #secondary_db_map_struct_name #generics {
538 #(
539 pub #active_field_names : DBMap #active_inner_types,
540 )*
541 #(
542 pub #deprecated_field_names : Option<DBMap #deprecated_inner_types>,
543 )*
544 }
545
546 impl <
547 #(
548 #generics_names: #generics_bounds_token,
549 )*
550 > #secondary_db_map_struct_name #generics {
551 pub fn open_tables_read_only(
553 primary_path: std::path::PathBuf,
554 with_secondary_path: Option<std::path::PathBuf>,
555 metric_conf: typed_store::rocks::MetricConf,
556 global_db_options_override: Option<typed_store::rocksdb::Options>,
557 ) -> Self {
558 let inner = match with_secondary_path {
559 Some(q) => #intermediate_db_map_struct_name::open_tables_impl(primary_path, Some(q), metric_conf, global_db_options_override, None),
560 None => {
561 let p: std::path::PathBuf = tempfile::tempdir()
562 .expect("Failed to open temporary directory")
563 .keep();
564 #intermediate_db_map_struct_name::open_tables_impl(primary_path, Some(p), metric_conf, global_db_options_override, None)
565 }
566 };
567 Self {
568 #(
569 #all_field_names: inner.#all_field_names,
570 )*
571 }
572 }
573
574 fn cf_name_to_table_name(cf_name: &str) -> eyre::Result<&'static str> {
575 Ok(match cf_name {
576 #(
577 stringify!(#active_cf_names) => stringify!(#active_field_names),
578 )*
579 _ => eyre::bail!("No such cf name: {}", cf_name),
580 })
581 }
582
583 pub fn dump(&self, cf_name: &str, page_size: u16, page_number: usize) -> eyre::Result<std::collections::BTreeMap<String, String>> {
586 let table_name = Self::cf_name_to_table_name(cf_name)?;
587
588 Ok(match table_name {
589 #(
590 stringify!(#active_field_names) => {
591 typed_store::traits::Map::try_catch_up_with_primary(&self.#active_field_names)?;
592 typed_store::traits::Map::safe_iter(&self.#active_field_names)
593 .skip((page_number * (page_size) as usize))
594 .take(page_size as usize)
595 .map(|result| result.map(|(k, v)| (format!("{:?}", k), format!("{:?}", v))))
596 .collect::<eyre::Result<std::collections::BTreeMap<_, _>, _>>()?
597 }
598 )*
599
600 _ => eyre::bail!("No such table name: {}", table_name),
601 })
602 }
603
604 pub fn table_summary(&self, table_name: &str) -> eyre::Result<typed_store::traits::TableSummary> {
607 let mut count = 0;
608 let mut key_bytes = 0;
609 let mut value_bytes = 0;
610 match table_name {
611 #(
612 stringify!(#active_field_names) => {
613 typed_store::traits::Map::try_catch_up_with_primary(&self.#active_field_names)?;
614 self.#active_field_names.table_summary()
615 }
616 )*
617
618 _ => eyre::bail!("No such table name: {}", table_name),
619 }
620 }
621
622 pub fn count_keys(&self, table_name: &str) -> eyre::Result<usize> {
625 Ok(match table_name {
626 #(
627 stringify!(#active_field_names) => {
628 typed_store::traits::Map::try_catch_up_with_primary(&self.#active_field_names)?;
629 typed_store::traits::Map::safe_iter(&self.#active_field_names)
630 .collect::<Result<Vec<_>, _>>()?
631 .len()
632 }
633 )*
634
635 _ => eyre::bail!("No such table name: {}", table_name),
636 })
637 }
638
639 pub fn describe_tables() -> std::collections::BTreeMap<String, (String, String)> {
640 vec![#(
641 (stringify!(#active_field_names).to_owned(), (stringify!(#active_key_names).to_owned(), stringify!(#active_value_names).to_owned())),
642 )*].into_iter().collect()
643 }
644
645 pub fn try_catch_up_with_primary_all(&self) -> eyre::Result<()> {
648 #(
649 typed_store::traits::Map::try_catch_up_with_primary(&self.#active_field_names)?;
650 )*
651 Ok(())
652 }
653 }
654
655 impl <
656 #(
657 #generics_names: #generics_bounds_token,
658 )*
659 > TypedStoreDebug for #secondary_db_map_struct_name #generics {
660 fn dump_table(
661 &self,
662 table_name: String,
663 page_size: u16,
664 page_number: usize,
665 ) -> eyre::Result<std::collections::BTreeMap<String, String>> {
666 self.dump(table_name.as_str(), page_size, page_number)
667 }
668
669 fn primary_db_name(&self) -> String {
670 stringify!(#name).to_owned()
671 }
672
673 fn describe_all_tables(&self) -> std::collections::BTreeMap<String, (String, String)> {
674 Self::describe_tables()
675 }
676
677 fn count_table_keys(&self, table_name: String) -> eyre::Result<usize> {
678 self.count_keys(table_name.as_str())
679 }
680
681 fn table_summary(&self, table_name: String) -> eyre::Result<TableSummary> {
682 self.table_summary(table_name.as_str())
683 }
684 }
685 })
686}