1use std::collections::BTreeMap;
6
7use itertools::Itertools;
8use proc_macro::TokenStream;
9use proc_macro2::Ident;
10use quote::quote;
11use syn::{
12 AngleBracketedGenericArguments, Attribute, Generics, ItemStruct, Lit, Meta, PathArguments,
13 Type::{self},
14 parse_macro_input,
15};
16
17const DEFAULT_DB_OPTIONS_CUSTOM_FN: &str = "typed_store::rocks::default_db_options";
19const DB_OPTIONS_CUSTOM_FUNCTION: &str = "default_options_override_fn";
22const DB_OPTIONS_RENAME: &str = "rename";
24const DB_OPTIONS_DEPRECATE: &str = "deprecated";
26
27enum GeneralTableOptions {
29 OverrideFunction(String),
30}
31
32impl Default for GeneralTableOptions {
33 fn default() -> Self {
34 Self::OverrideFunction(DEFAULT_DB_OPTIONS_CUSTOM_FN.to_owned())
35 }
36}
37
38fn extract_struct_info(input: ItemStruct) -> ExtractedStructInfo {
41 let mut deprecated_cfs = vec![];
42
43 let info = input.fields.iter().map(|f| {
44 let attrs: BTreeMap<_, _> = f
45 .attrs
46 .iter()
47 .filter(|a| {
48 a.path.is_ident(DB_OPTIONS_CUSTOM_FUNCTION)
49 || a.path.is_ident(DB_OPTIONS_RENAME)
50 || a.path.is_ident(DB_OPTIONS_DEPRECATE)
51 })
52 .map(|a| (a.path.get_ident().unwrap().to_string(), a))
53 .collect();
54
55 let options = if let Some(options) = attrs.get(DB_OPTIONS_CUSTOM_FUNCTION) {
56 GeneralTableOptions::OverrideFunction(get_options_override_function(options).unwrap())
57 } else {
58 GeneralTableOptions::default()
59 };
60
61 let ty = &f.ty;
62 if let Type::Path(p) = ty {
63 let type_info = &p.path.segments.first().unwrap();
64 let inner_type =
65 if let PathArguments::AngleBracketed(angle_bracket_type) = &type_info.arguments {
66 angle_bracket_type.clone()
67 } else {
68 panic!("All struct members must be of type DBMap");
69 };
70
71 let type_str = format!("{}", &type_info.ident);
72 if type_str == "DBMap" {
73 let field_name = f.ident.as_ref().unwrap().clone();
74 let cf_name = if let Some(rename) = attrs.get(DB_OPTIONS_RENAME) {
75 match rename.parse_meta().expect("Cannot parse meta of attribute") {
76 Meta::NameValue(val) => {
77 if let Lit::Str(s) = val.lit {
78 s.parse().expect("Rename value must be identifier")
80 } else {
81 panic!("Expected string value for rename")
82 }
83 }
84 _ => panic!("Expected string value for rename"),
85 }
86 } else {
87 field_name.clone()
88 };
89 if attrs.contains_key(DB_OPTIONS_DEPRECATE) {
90 deprecated_cfs.push(field_name.clone());
91 }
92
93 return ((field_name, cf_name, type_str), (inner_type, options));
94 } else {
95 panic!("All struct members must be of type DBMap");
96 }
97 }
98 panic!("All struct members must be of type DBMap");
99 });
100
101 let (field_info, inner_types_with_opts): (Vec<_>, Vec<_>) = info.unzip();
102 let (field_names, cf_names, simple_field_type_names): (Vec<_>, Vec<_>, Vec<_>) =
103 field_info.into_iter().multiunzip();
104
105 if let Some(first) = simple_field_type_names.first() {
107 simple_field_type_names.iter().for_each(|q| {
108 if q != first {
109 panic!("All struct members must be of same type");
110 }
111 })
112 } else {
113 panic!("Cannot derive on empty struct");
114 };
115
116 let (inner_types, options): (Vec<_>, Vec<_>) = inner_types_with_opts.into_iter().unzip();
117
118 ExtractedStructInfo {
119 field_names,
120 cf_names,
121 inner_types,
122 derived_table_options: options,
123 deprecated_cfs,
124 }
125}
126
127fn get_options_override_function(attr: &Attribute) -> syn::Result<String> {
130 let meta = attr.parse_meta()?;
131
132 let val = match meta.clone() {
133 Meta::NameValue(val) => val,
134 _ => {
135 return Err(syn::Error::new_spanned(
136 meta,
137 format!(
138 "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
139 ),
140 ));
141 }
142 };
143
144 if !val.path.is_ident(DB_OPTIONS_CUSTOM_FUNCTION) {
145 return Err(syn::Error::new_spanned(
146 meta,
147 format!(
148 "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
149 ),
150 ));
151 }
152
153 let fn_name = match val.lit {
154 Lit::Str(fn_name) => fn_name,
155 _ => {
156 return Err(syn::Error::new_spanned(
157 meta,
158 format!(
159 "Expected function name in format `#[{DB_OPTIONS_CUSTOM_FUNCTION} = {{function_name}}]`"
160 ),
161 ));
162 }
163 };
164 Ok(fn_name.value())
165}
166
167fn extract_generics_names(generics: &Generics) -> Vec<Ident> {
168 generics
169 .params
170 .iter()
171 .map(|g| match g {
172 syn::GenericParam::Type(t) => t.ident.clone(),
173 _ => panic!("Unsupported generic type"),
174 })
175 .collect()
176}
177
178struct ExtractedStructInfo {
179 field_names: Vec<Ident>,
180 cf_names: Vec<Ident>,
181 inner_types: Vec<AngleBracketedGenericArguments>,
182 derived_table_options: Vec<GeneralTableOptions>,
183 deprecated_cfs: Vec<Ident>,
184}
185
186#[proc_macro_derive(DBMapUtils, attributes(default_options_override_fn, rename))]
187pub fn derive_dbmap_utils_general(input: TokenStream) -> TokenStream {
188 let input = parse_macro_input!(input as ItemStruct);
189 let name = &input.ident;
190 let generics = &input.generics;
191 let generics_names = extract_generics_names(generics);
192
193 let ExtractedStructInfo {
195 field_names,
196 cf_names,
197 inner_types,
198 derived_table_options,
199 deprecated_cfs,
200 } = extract_struct_info(input.clone());
201
202 let (key_names, value_names): (Vec<_>, Vec<_>) = inner_types
203 .iter()
204 .map(|q| (q.args.first().unwrap(), q.args.last().unwrap()))
205 .unzip();
206
207 let default_options_override_fn_names: Vec<proc_macro2::TokenStream> = derived_table_options
208 .iter()
209 .map(|q| {
210 let GeneralTableOptions::OverrideFunction(fn_name) = q;
211 fn_name.parse().unwrap()
212 })
213 .collect();
214
215 let generics_bounds =
216 "std::fmt::Debug + serde::Serialize + for<'de> serde::de::Deserialize<'de>";
217 let generics_bounds_token: proc_macro2::TokenStream = generics_bounds.parse().unwrap();
218
219 let config_struct_name_str = format!("{name}Configurator");
220 let config_struct_name: proc_macro2::TokenStream = config_struct_name_str.parse().unwrap();
221
222 let intermediate_db_map_struct_name_str = format!("{name}IntermediateDBMapStructPrimary");
223 let intermediate_db_map_struct_name: proc_macro2::TokenStream =
224 intermediate_db_map_struct_name_str.parse().unwrap();
225
226 let secondary_db_map_struct_name_str = format!("{name}ReadOnly");
227 let secondary_db_map_struct_name: proc_macro2::TokenStream =
228 secondary_db_map_struct_name_str.parse().unwrap();
229
230 TokenStream::from(quote! {
231
232 pub struct #config_struct_name {
236 #(
237 pub #field_names : typed_store::rocks::DBOptions,
238 )*
239 }
240
241 impl #config_struct_name {
242 pub fn init() -> Self {
244 Self {
245 #(
246 #field_names : typed_store::rocks::default_db_options(),
247 )*
248 }
249 }
250
251 pub fn build(&self) -> typed_store::rocks::DBMapTableConfigMap {
253 typed_store::rocks::DBMapTableConfigMap::new([
254 #(
255 (stringify!(#field_names).to_owned(), self.#field_names.clone()),
256 )*
257 ].into_iter().collect())
258 }
259 }
260
261 impl <
262 #(
263 #generics_names: #generics_bounds_token,
264 )*
265 > #name #generics {
266
267 pub fn configurator() -> #config_struct_name {
268 #config_struct_name::init()
269 }
270 }
271
272 struct #intermediate_db_map_struct_name #generics {
277 #(
278 pub #field_names : DBMap #inner_types,
279 )*
280 }
281
282
283 impl <
284 #(
285 #generics_names: #generics_bounds_token,
286 )*
287 > #intermediate_db_map_struct_name #generics {
288 pub fn open_tables_impl(
291 path: std::path::PathBuf,
292 as_secondary_with_path: Option<std::path::PathBuf>,
293 is_transaction: bool,
294 metric_conf: typed_store::rocks::MetricConf,
295 global_db_options_override: Option<typed_store::rocksdb::Options>,
296 tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>,
297 remove_deprecated_tables: bool,
298 ) -> Self {
299 let path = &path;
300 let default_cf_opt = if let Some(opt) = global_db_options_override.as_ref() {
301 typed_store::rocks::DBOptions {
302 options: opt.clone(),
303 rw_options: typed_store::rocks::default_db_options().rw_options,
304 }
305 } else {
306 typed_store::rocks::default_db_options()
307 };
308 let (db, rwopt_cfs) = {
309 let opt_cfs = match tables_db_options_override {
310 None => [
311 #(
312 (stringify!(#cf_names).to_owned(), #default_options_override_fn_names()),
313 )*
314 ],
315 Some(o) => [
316 #(
317 (stringify!(#cf_names).to_owned(), o.to_map().get(stringify!(#cf_names)).unwrap_or(&default_cf_opt).clone()),
318 )*
319 ]
320 };
321 let rwopt_cfs: std::collections::HashMap<String, typed_store::rocks::ReadWriteOptions> = opt_cfs.iter().map(|q| (q.0.as_str().to_string(), q.1.rw_options.clone())).collect();
323 let opt_cfs: Vec<_> = opt_cfs.iter().map(|q| (q.0.as_str(), q.1.options.clone())).collect();
324 let db = match (as_secondary_with_path.clone(), is_transaction) {
325 (Some(p), _) => typed_store::rocks::open_cf_opts_secondary(path, Some(&p), global_db_options_override, metric_conf, &opt_cfs),
326 (_, true) => typed_store::rocks::open_cf_opts_transactional(path, global_db_options_override, metric_conf, &opt_cfs),
327 _ => typed_store::rocks::open_cf_opts(path, global_db_options_override, metric_conf, &opt_cfs)
328 };
329 db.map(|d| (d, rwopt_cfs))
330 }.expect(&format!("Cannot open DB at {:?}", path));
331 let deprecated_tables = vec![#(stringify!(#deprecated_cfs),)*];
332 let (
333 #(
334 #field_names
335 ),*
336 ) = (#(
337 DBMap::#inner_types::reopen(&db, Some(stringify!(#cf_names)), rwopt_cfs.get(stringify!(#cf_names)).unwrap_or(&typed_store::rocks::ReadWriteOptions::default()), remove_deprecated_tables && deprecated_tables.contains(&stringify!(#cf_names))).expect(&format!("Cannot open {} CF.", stringify!(#cf_names))[..])
338 ),*);
339
340 if as_secondary_with_path.is_none() && remove_deprecated_tables {
341 #(
342 db.drop_cf(stringify!(#deprecated_cfs)).expect("failed to drop a deprecated cf");
343 )*
344 }
345 Self {
346 #(
347 #field_names,
348 )*
349 }
350 }
351 }
352
353
354 impl <
357 #(
358 #generics_names: #generics_bounds_token,
359 )*
360 > #name #generics {
361 #[expect(unused_parens)]
366 pub fn open_tables_read_write(
367 path: std::path::PathBuf,
368 metric_conf: typed_store::rocks::MetricConf,
369 global_db_options_override: Option<typed_store::rocksdb::Options>,
370 tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>
371 ) -> Self {
372 let inner = #intermediate_db_map_struct_name::open_tables_impl(path, None, false, metric_conf, global_db_options_override, tables_db_options_override, false);
373 Self {
374 #(
375 #field_names: inner.#field_names,
376 )*
377 }
378 }
379
380 #[expect(unused_parens)]
381 pub fn open_tables_read_write_with_deprecation_option(
382 path: std::path::PathBuf,
383 metric_conf: typed_store::rocks::MetricConf,
384 global_db_options_override: Option<typed_store::rocksdb::Options>,
385 tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>,
386 remove_deprecated_tables: bool,
387 ) -> Self {
388 let inner = #intermediate_db_map_struct_name::open_tables_impl(path, None, false, metric_conf, global_db_options_override, tables_db_options_override, remove_deprecated_tables);
389 Self {
390 #(
391 #field_names: inner.#field_names,
392 )*
393 }
394 }
395
396 #[expect(unused_parens)]
401 pub fn open_tables_transactional(
402 path: std::path::PathBuf,
403 metric_conf: typed_store::rocks::MetricConf,
404 global_db_options_override: Option<typed_store::rocksdb::Options>,
405 tables_db_options_override: Option<typed_store::rocks::DBMapTableConfigMap>
406 ) -> Self {
407 let inner = #intermediate_db_map_struct_name::open_tables_impl(path, None, true, metric_conf, global_db_options_override, tables_db_options_override, false);
408 Self {
409 #(
410 #field_names: inner.#field_names,
411 )*
412 }
413 }
414
415 pub fn describe_tables() -> std::collections::BTreeMap<String, (String, String)> {
417 vec![#(
418 (stringify!(#field_names).to_owned(), (stringify!(#key_names).to_owned(), stringify!(#value_names).to_owned())),
419 )*].into_iter().collect()
420 }
421
422 pub fn get_read_only_handle (
424 primary_path: std::path::PathBuf,
425 with_secondary_path: Option<std::path::PathBuf>,
426 global_db_options_override: Option<typed_store::rocksdb::Options>,
427 metric_conf: typed_store::rocks::MetricConf,
428 ) -> #secondary_db_map_struct_name #generics {
429 #secondary_db_map_struct_name::open_tables_read_only(primary_path, with_secondary_path, metric_conf, global_db_options_override)
430 }
431 }
432
433
434 pub struct #secondary_db_map_struct_name #generics {
438 #(
439 pub #field_names : DBMap #inner_types,
440 )*
441 }
442
443 impl <
444 #(
445 #generics_names: #generics_bounds_token,
446 )*
447 > #secondary_db_map_struct_name #generics {
448 pub fn open_tables_read_only(
450 primary_path: std::path::PathBuf,
451 with_secondary_path: Option<std::path::PathBuf>,
452 metric_conf: typed_store::rocks::MetricConf,
453 global_db_options_override: Option<typed_store::rocksdb::Options>,
454 ) -> Self {
455 let inner = match with_secondary_path {
456 Some(q) => #intermediate_db_map_struct_name::open_tables_impl(primary_path, Some(q), false, metric_conf, global_db_options_override, None, false),
457 None => {
458 let p: std::path::PathBuf = tempfile::tempdir()
459 .expect("Failed to open temporary directory")
460 .keep();
461 #intermediate_db_map_struct_name::open_tables_impl(primary_path, Some(p), false, metric_conf, global_db_options_override, None, false)
462 }
463 };
464 Self {
465 #(
466 #field_names: inner.#field_names,
467 )*
468 }
469 }
470
471 fn cf_name_to_table_name(cf_name: &str) -> eyre::Result<&'static str> {
472 Ok(match cf_name {
473 #(
474 stringify!(#cf_names) => stringify!(#field_names),
475 )*
476 _ => eyre::bail!("No such cf name: {}", cf_name),
477 })
478 }
479
480 pub fn dump(&self, cf_name: &str, page_size: u16, page_number: usize) -> eyre::Result<std::collections::BTreeMap<String, String>> {
483 let table_name = Self::cf_name_to_table_name(cf_name)?;
484
485 Ok(match table_name {
486 #(
487 stringify!(#field_names) => {
488 typed_store::traits::Map::try_catch_up_with_primary(&self.#field_names)?;
489 typed_store::traits::Map::unbounded_iter(&self.#field_names)
490 .skip((page_number * (page_size) as usize))
491 .take(page_size as usize)
492 .map(|(k, v)| (format!("{:?}", k), format!("{:?}", v)))
493 .collect::<std::collections::BTreeMap<_, _>>()
494 }
495 )*
496
497 _ => eyre::bail!("No such table name: {}", table_name),
498 })
499 }
500
501 pub fn table_summary(&self, table_name: &str) -> eyre::Result<typed_store::traits::TableSummary> {
504 let mut count = 0;
505 let mut key_bytes = 0;
506 let mut value_bytes = 0;
507 match table_name {
508 #(
509 stringify!(#field_names) => {
510 typed_store::traits::Map::try_catch_up_with_primary(&self.#field_names)?;
511 self.#field_names.table_summary()
512 }
513 )*
514
515 _ => eyre::bail!("No such table name: {}", table_name),
516 }
517 }
518
519 pub fn count_keys(&self, table_name: &str) -> eyre::Result<usize> {
522 Ok(match table_name {
523 #(
524 stringify!(#field_names) => {
525 typed_store::traits::Map::try_catch_up_with_primary(&self.#field_names)?;
526 typed_store::traits::Map::unbounded_iter(&self.#field_names).count()
527 }
528 )*
529
530 _ => eyre::bail!("No such table name: {}", table_name),
531 })
532 }
533
534 pub fn describe_tables() -> std::collections::BTreeMap<String, (String, String)> {
535 vec![#(
536 (stringify!(#field_names).to_owned(), (stringify!(#key_names).to_owned(), stringify!(#value_names).to_owned())),
537 )*].into_iter().collect()
538 }
539
540 pub fn try_catch_up_with_primary_all(&self) -> eyre::Result<()> {
543 #(
544 typed_store::traits::Map::try_catch_up_with_primary(&self.#field_names)?;
545 )*
546 Ok(())
547 }
548 }
549
550 impl <
551 #(
552 #generics_names: #generics_bounds_token,
553 )*
554 > TypedStoreDebug for #secondary_db_map_struct_name #generics {
555 fn dump_table(
556 &self,
557 table_name: String,
558 page_size: u16,
559 page_number: usize,
560 ) -> eyre::Result<std::collections::BTreeMap<String, String>> {
561 self.dump(table_name.as_str(), page_size, page_number)
562 }
563
564 fn primary_db_name(&self) -> String {
565 stringify!(#name).to_owned()
566 }
567
568 fn describe_all_tables(&self) -> std::collections::BTreeMap<String, (String, String)> {
569 Self::describe_tables()
570 }
571
572 fn count_table_keys(&self, table_name: String) -> eyre::Result<usize> {
573 self.count_keys(table_name.as_str())
574 }
575
576 fn table_summary(&self, table_name: String) -> eyre::Result<TableSummary> {
577 self.table_summary(table_name.as_str())
578 }
579 }
580 })
581}