iota_protocol_config_macros/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5extern crate proc_macro;
6
7use proc_macro::TokenStream;
8use quote::quote;
9use syn::{Data, DeriveInput, Fields, Type, parse_macro_input};
10
11/// This proc macro generates getters, attribute lookup, etc for protocol config
12/// fields of type `Option<T>` and for the feature flags
13/// Example for a field: `new_constant: Option<u64>`, and for feature flags
14/// `feature: bool`, we derive
15/// ```rust,ignore
16///     /// Returns the value of the field if exists at the given version, otherwise panic
17///     pub fn new_constant(&self) -> u64 {
18///         self.new_constant.expect(Self::CONSTANT_ERR_MSG)
19///     }
20///     /// Returns the value of the field if exists at the given version, otherwise None.
21///     pub fn new_constant_as_option(&self) -> Option<u64> {
22///         self.new_constant
23///     }
24///     // We auto derive an enum such that the variants are all the types of the fields
25///     pub enum ProtocolConfigValue {
26///        u32(u32),
27///        u64(u64),
28///        ..............
29///     }
30///     // This enum is used to return field values so that the type is also encoded in the response
31///
32///     /// Returns the value of the field if exists at the given version, otherwise None
33///     pub fn lookup_attr(&self, value: String) -> Option<ProtocolConfigValue>;
34///
35///     /// Returns a map of all configs to values
36///     pub fn attr_map(&self) -> std::collections::BTreeMap<String, Option<ProtocolConfigValue>>;
37///
38///     /// Returns a feature by the string name or None if it doesn't exist
39///     pub fn lookup_feature(&self, value: String) -> Option<bool>;
40///
41///     /// Returns a map of all features to values
42///     pub fn feature_map(&self) -> std::collections::BTreeMap<String, bool>;
43/// ```
44#[proc_macro_derive(ProtocolConfigAccessors)]
45pub fn accessors_macro(input: TokenStream) -> TokenStream {
46    let ast = parse_macro_input!(input as DeriveInput);
47
48    let struct_name = &ast.ident;
49    let data = &ast.data;
50    let mut inner_types = vec![];
51
52    let tokens = match data {
53        Data::Struct(data_struct) => match &data_struct.fields {
54            // Operate on each field of the ProtocolConfig struct
55            Fields::Named(fields_named) => fields_named.named.iter().filter_map(|field| {
56                // Extract field name and type
57                let field_name = field.ident.as_ref().expect("Field must be named");
58                let field_type = &field.ty;
59                // Check if field is of type Option<T>
60                match field_type {
61                    Type::Path(type_path)
62                        if type_path
63                            .path
64                            .segments
65                            .last()
66                            .is_some_and(|segment| segment.ident == "Option") =>
67                    {
68                        // Extract inner type T from Option<T>
69                        let inner_type = if let syn::PathArguments::AngleBracketed(
70                            angle_bracketed_generic_arguments,
71                        ) = &type_path.path.segments.last().unwrap().arguments
72                        {
73                            if let Some(syn::GenericArgument::Type(ty)) =
74                                angle_bracketed_generic_arguments.args.first()
75                            {
76                                ty.clone()
77                            } else {
78                                panic!("Expected a type argument.");
79                            }
80                        } else {
81                            panic!("Expected angle bracketed arguments.");
82                        };
83
84                        let as_option_name = format!("{field_name}_as_option");
85                        let as_option_name: proc_macro2::TokenStream =
86                        as_option_name.parse().unwrap();
87                        let test_setter_name: proc_macro2::TokenStream =
88                            format!("set_{field_name}_for_testing").parse().unwrap();
89                        let test_un_setter_name: proc_macro2::TokenStream =
90                            format!("disable_{field_name}_for_testing").parse().unwrap();
91                        let test_setter_from_str_name: proc_macro2::TokenStream =
92                            format!("set_{field_name}_from_str_for_testing").parse().unwrap();
93
94                        let getter = quote! {
95                            // Derive the getter
96                            pub fn #field_name(&self) -> #inner_type {
97                                self.#field_name.expect(Self::CONSTANT_ERR_MSG)
98                            }
99
100                            pub fn #as_option_name(&self) -> #field_type {
101                                self.#field_name
102                            }
103                        };
104
105                        let test_setter = quote! {
106                            // Derive the setter
107                            pub fn #test_setter_name(&mut self, val: #inner_type) {
108                                self.#field_name = Some(val);
109                            }
110
111                            // Derive the setter from String
112                            pub fn #test_setter_from_str_name(&mut self, val: String) {
113                                use std::str::FromStr;
114                                self.#test_setter_name(#inner_type::from_str(&val).unwrap());
115                            }
116
117                            // Derive the un-setter
118                            pub fn #test_un_setter_name(&mut self) {
119                                self.#field_name = None;
120                            }
121                        };
122
123                        let value_setter = quote! {
124                            stringify!(#field_name) => self.#test_setter_from_str_name(val),
125                        };
126
127
128                        let value_lookup = quote! {
129                            stringify!(#field_name) => self.#field_name.map(|v| ProtocolConfigValue::#inner_type(v)),
130                        };
131
132                        let field_name_str = quote! {
133                            stringify!(#field_name)
134                        };
135
136                        // Track all the types seen
137                        if inner_types.contains(&inner_type) {
138                            None
139                        } else {
140                            inner_types.push(inner_type.clone());
141                            Some(quote! {
142                               #inner_type
143                            })
144                        };
145
146                        Some(((getter, (test_setter, value_setter)), (value_lookup, field_name_str)))
147                    }
148                    _ => None,
149                }
150            }),
151            _ => panic!("Only named fields are supported."),
152        },
153        _ => panic!("Only structs supported."),
154    };
155
156    #[expect(clippy::type_complexity)]
157    let ((getters, (test_setters, value_setters)), (value_lookup, field_names_str)): (
158        (Vec<_>, (Vec<_>, Vec<_>)),
159        (Vec<_>, Vec<_>),
160    ) = tokens.unzip();
161    let output = quote! {
162        // For each getter, expand it out into a function in the impl block
163        impl #struct_name {
164            const CONSTANT_ERR_MSG: &'static str = "protocol constant not present in current protocol version";
165            #(#getters)*
166
167            /// Lookup a config attribute by its string representation
168            pub fn lookup_attr(&self, value: String) -> Option<ProtocolConfigValue> {
169                match value.as_str() {
170                    #(#value_lookup)*
171                    _ => None,
172                }
173            }
174
175            /// Get a map of all config attribute from string representations
176            pub fn attr_map(&self) -> std::collections::BTreeMap<String, Option<ProtocolConfigValue>> {
177                vec![
178                    #(((#field_names_str).to_owned(), self.lookup_attr((#field_names_str).to_owned())),)*
179                    ].into_iter().collect()
180            }
181
182            /// Get the feature flags
183            pub fn lookup_feature(&self, value: String) -> Option<bool> {
184                self.feature_flags.lookup_attr(value)
185            }
186
187            pub fn feature_map(&self) -> std::collections::BTreeMap<String, bool> {
188                self.feature_flags.attr_map()
189            }
190        }
191
192        // For each attr, derive a setter from the raw value and from string repr
193        impl #struct_name {
194            #(#test_setters)*
195
196            pub fn set_attr_for_testing(&mut self, attr: String, val: String) {
197                match attr.as_str() {
198                    #(#value_setters)*
199                    _ => panic!("Attempting to set unknown attribute: {}", attr),
200                }
201            }
202        }
203
204        #[expect(non_camel_case_types)]
205        #[derive(Clone, Serialize, Debug, PartialEq, Deserialize, schemars::JsonSchema)]
206        pub enum ProtocolConfigValue {
207            #(#inner_types(#inner_types),)*
208        }
209
210        impl std::fmt::Display for ProtocolConfigValue {
211            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212                use std::fmt::Write;
213                let mut writer = String::new();
214                match self {
215                    #(
216                        ProtocolConfigValue::#inner_types(x) => {
217                            write!(writer, "{}", x)?;
218                        }
219                    )*
220                }
221                write!(f, "{}", writer)
222            }
223        }
224    };
225
226    TokenStream::from(output)
227}
228
229#[proc_macro_derive(ProtocolConfigOverride)]
230pub fn protocol_config_override_macro(input: TokenStream) -> TokenStream {
231    let ast = parse_macro_input!(input as DeriveInput);
232
233    // Create a new struct name by appending "Optional".
234    let struct_name = &ast.ident;
235    let optional_struct_name =
236        syn::Ident::new(&format!("{}Optional", struct_name), struct_name.span());
237
238    // Extract the fields from the struct
239    let fields = match &ast.data {
240        Data::Struct(data_struct) => match &data_struct.fields {
241            Fields::Named(fields_named) => &fields_named.named,
242            _ => panic!("ProtocolConfig must have named fields"),
243        },
244        _ => panic!("ProtocolConfig must be a struct"),
245    };
246
247    // Create new fields with types wrapped in Option.
248    let optional_fields = fields.iter().map(|field| {
249        let field_name = &field.ident;
250        let field_type = &field.ty;
251        quote! {
252            #field_name: Option<#field_type>
253        }
254    });
255
256    // Generate the function to update the original struct.
257    let update_fields = fields.iter().map(|field| {
258        let field_name = &field.ident;
259        quote! {
260            if let Some(value) = self.#field_name {
261                tracing::warn!(
262                    "ProtocolConfig field \"{}\" has been overridden with the value: {value:?}",
263                    stringify!(#field_name),
264                );
265                config.#field_name = value;
266            }
267        }
268    });
269
270    // Generate the new struct definition.
271    let output = quote! {
272        #[derive(serde::Deserialize, Debug)]
273        pub struct #optional_struct_name {
274            #(#optional_fields,)*
275        }
276
277        impl #optional_struct_name {
278            pub fn apply_to(self, config: &mut #struct_name) {
279                #(#update_fields)*
280            }
281        }
282    };
283
284    TokenStream::from(output)
285}
286
287#[proc_macro_derive(ProtocolConfigFeatureFlagsGetters)]
288pub fn feature_flag_getters_macro(input: TokenStream) -> TokenStream {
289    let ast = parse_macro_input!(input as DeriveInput);
290
291    let struct_name = &ast.ident;
292    let data = &ast.data;
293
294    let getters = match data {
295        Data::Struct(data_struct) => match &data_struct.fields {
296            // Operate on each field of the ProtocolConfig struct
297            Fields::Named(fields_named) => fields_named.named.iter().filter_map(|field| {
298                // Extract field name and type
299                let field_name = field.ident.as_ref().expect("Field must be named");
300                let field_type = &field.ty;
301                // Check if field is of type bool
302                match field_type {
303                    Type::Path(type_path)
304                        if type_path
305                            .path
306                            .segments
307                            .last()
308                            .is_some_and(|segment| segment.ident == "bool") =>
309                    {
310                        Some((
311                            quote! {
312                                // Derive the getter
313                                pub fn #field_name(&self) -> #field_type {
314                                    self.#field_name
315                                }
316                            },
317                            (
318                                quote! {
319                                    stringify!(#field_name) => Some(self.#field_name),
320                                },
321                                quote! {
322                                    stringify!(#field_name)
323                                },
324                            ),
325                        ))
326                    }
327                    _ => None,
328                }
329            }),
330            _ => panic!("Only named fields are supported."),
331        },
332        _ => panic!("Only structs supported."),
333    };
334
335    let (by_fn_getters, (string_name_getters, field_names)): (Vec<_>, (Vec<_>, Vec<_>)) =
336        getters.unzip();
337
338    let output = quote! {
339        // For each getter, expand it out into a function in the impl block
340        impl #struct_name {
341            #(#by_fn_getters)*
342
343            /// Lookup a feature flag by its string representation
344            pub fn lookup_attr(&self, value: String) -> Option<bool> {
345                match value.as_str() {
346                    #(#string_name_getters)*
347                    _ => None,
348                }
349            }
350
351            /// Get a map of all feature flags from string representations
352            pub fn attr_map(&self) -> std::collections::BTreeMap<String, bool> {
353                vec![
354                    // Okay to unwrap since we added all above
355                    #(((#field_names).to_owned(), self.lookup_attr((#field_names).to_owned()).unwrap()),)*
356                    ].into_iter().collect()
357            }
358        }
359    };
360
361    TokenStream::from(output)
362}