iota_protocol_config_macros/
lib.rs1extern crate proc_macro;
6
7use proc_macro::TokenStream;
8use quote::quote;
9use syn::{Data, DeriveInput, Fields, Type, parse_macro_input};
10
11#[proc_macro_derive(ProtocolConfigAccessors)]
45pub fn accessors_macro(input: TokenStream) -> TokenStream {
46 let ast = parse_macro_input!(input as DeriveInput);
47
48 let struct_name = &ast.ident;
49 let data = &ast.data;
50 let mut inner_types = vec![];
51
52 let tokens = match data {
53 Data::Struct(data_struct) => match &data_struct.fields {
54 Fields::Named(fields_named) => fields_named.named.iter().filter_map(|field| {
56 let field_name = field.ident.as_ref().expect("Field must be named");
58 let field_type = &field.ty;
59 match field_type {
61 Type::Path(type_path)
62 if type_path
63 .path
64 .segments
65 .last()
66 .is_some_and(|segment| segment.ident == "Option") =>
67 {
68 let inner_type = if let syn::PathArguments::AngleBracketed(
70 angle_bracketed_generic_arguments,
71 ) = &type_path.path.segments.last().unwrap().arguments
72 {
73 if let Some(syn::GenericArgument::Type(ty)) =
74 angle_bracketed_generic_arguments.args.first()
75 {
76 ty.clone()
77 } else {
78 panic!("Expected a type argument.");
79 }
80 } else {
81 panic!("Expected angle bracketed arguments.");
82 };
83
84 let as_option_name = format!("{field_name}_as_option");
85 let as_option_name: proc_macro2::TokenStream =
86 as_option_name.parse().unwrap();
87 let test_setter_name: proc_macro2::TokenStream =
88 format!("set_{field_name}_for_testing").parse().unwrap();
89 let test_un_setter_name: proc_macro2::TokenStream =
90 format!("disable_{field_name}_for_testing").parse().unwrap();
91 let test_setter_from_str_name: proc_macro2::TokenStream =
92 format!("set_{field_name}_from_str_for_testing").parse().unwrap();
93
94 let getter = quote! {
95 pub fn #field_name(&self) -> #inner_type {
97 self.#field_name.expect(Self::CONSTANT_ERR_MSG)
98 }
99
100 pub fn #as_option_name(&self) -> #field_type {
101 self.#field_name
102 }
103 };
104
105 let test_setter = quote! {
106 pub fn #test_setter_name(&mut self, val: #inner_type) {
108 self.#field_name = Some(val);
109 }
110
111 pub fn #test_setter_from_str_name(&mut self, val: String) {
113 use std::str::FromStr;
114 self.#test_setter_name(#inner_type::from_str(&val).unwrap());
115 }
116
117 pub fn #test_un_setter_name(&mut self) {
119 self.#field_name = None;
120 }
121 };
122
123 let value_setter = quote! {
124 stringify!(#field_name) => self.#test_setter_from_str_name(val),
125 };
126
127
128 let value_lookup = quote! {
129 stringify!(#field_name) => self.#field_name.map(|v| ProtocolConfigValue::#inner_type(v)),
130 };
131
132 let field_name_str = quote! {
133 stringify!(#field_name)
134 };
135
136 if inner_types.contains(&inner_type) {
138 None
139 } else {
140 inner_types.push(inner_type.clone());
141 Some(quote! {
142 #inner_type
143 })
144 };
145
146 Some(((getter, (test_setter, value_setter)), (value_lookup, field_name_str)))
147 }
148 _ => None,
149 }
150 }),
151 _ => panic!("Only named fields are supported."),
152 },
153 _ => panic!("Only structs supported."),
154 };
155
156 #[expect(clippy::type_complexity)]
157 let ((getters, (test_setters, value_setters)), (value_lookup, field_names_str)): (
158 (Vec<_>, (Vec<_>, Vec<_>)),
159 (Vec<_>, Vec<_>),
160 ) = tokens.unzip();
161 let output = quote! {
162 impl #struct_name {
164 const CONSTANT_ERR_MSG: &'static str = "protocol constant not present in current protocol version";
165 #(#getters)*
166
167 pub fn lookup_attr(&self, value: String) -> Option<ProtocolConfigValue> {
169 match value.as_str() {
170 #(#value_lookup)*
171 _ => None,
172 }
173 }
174
175 pub fn attr_map(&self) -> std::collections::BTreeMap<String, Option<ProtocolConfigValue>> {
177 vec![
178 #(((#field_names_str).to_owned(), self.lookup_attr((#field_names_str).to_owned())),)*
179 ].into_iter().collect()
180 }
181
182 pub fn lookup_feature(&self, value: String) -> Option<bool> {
184 self.feature_flags.lookup_attr(value)
185 }
186
187 pub fn feature_map(&self) -> std::collections::BTreeMap<String, bool> {
188 self.feature_flags.attr_map()
189 }
190 }
191
192 impl #struct_name {
194 #(#test_setters)*
195
196 pub fn set_attr_for_testing(&mut self, attr: String, val: String) {
197 match attr.as_str() {
198 #(#value_setters)*
199 _ => panic!("Attempting to set unknown attribute: {}", attr),
200 }
201 }
202 }
203
204 #[expect(non_camel_case_types)]
205 #[derive(Clone, Serialize, Debug, PartialEq, Deserialize, schemars::JsonSchema)]
206 pub enum ProtocolConfigValue {
207 #(#inner_types(#inner_types),)*
208 }
209
210 impl std::fmt::Display for ProtocolConfigValue {
211 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212 use std::fmt::Write;
213 let mut writer = String::new();
214 match self {
215 #(
216 ProtocolConfigValue::#inner_types(x) => {
217 write!(writer, "{}", x)?;
218 }
219 )*
220 }
221 write!(f, "{}", writer)
222 }
223 }
224 };
225
226 TokenStream::from(output)
227}
228
229#[proc_macro_derive(ProtocolConfigOverride)]
230pub fn protocol_config_override_macro(input: TokenStream) -> TokenStream {
231 let ast = parse_macro_input!(input as DeriveInput);
232
233 let struct_name = &ast.ident;
235 let optional_struct_name =
236 syn::Ident::new(&format!("{}Optional", struct_name), struct_name.span());
237
238 let fields = match &ast.data {
240 Data::Struct(data_struct) => match &data_struct.fields {
241 Fields::Named(fields_named) => &fields_named.named,
242 _ => panic!("ProtocolConfig must have named fields"),
243 },
244 _ => panic!("ProtocolConfig must be a struct"),
245 };
246
247 let optional_fields = fields.iter().map(|field| {
249 let field_name = &field.ident;
250 let field_type = &field.ty;
251 quote! {
252 #field_name: Option<#field_type>
253 }
254 });
255
256 let update_fields = fields.iter().map(|field| {
258 let field_name = &field.ident;
259 quote! {
260 if let Some(value) = self.#field_name {
261 tracing::warn!(
262 "ProtocolConfig field \"{}\" has been overridden with the value: {value:?}",
263 stringify!(#field_name),
264 );
265 config.#field_name = value;
266 }
267 }
268 });
269
270 let output = quote! {
272 #[derive(serde::Deserialize, Debug)]
273 pub struct #optional_struct_name {
274 #(#optional_fields,)*
275 }
276
277 impl #optional_struct_name {
278 pub fn apply_to(self, config: &mut #struct_name) {
279 #(#update_fields)*
280 }
281 }
282 };
283
284 TokenStream::from(output)
285}
286
287#[proc_macro_derive(ProtocolConfigFeatureFlagsGetters)]
288pub fn feature_flag_getters_macro(input: TokenStream) -> TokenStream {
289 let ast = parse_macro_input!(input as DeriveInput);
290
291 let struct_name = &ast.ident;
292 let data = &ast.data;
293
294 let getters = match data {
295 Data::Struct(data_struct) => match &data_struct.fields {
296 Fields::Named(fields_named) => fields_named.named.iter().filter_map(|field| {
298 let field_name = field.ident.as_ref().expect("Field must be named");
300 let field_type = &field.ty;
301 match field_type {
303 Type::Path(type_path)
304 if type_path
305 .path
306 .segments
307 .last()
308 .is_some_and(|segment| segment.ident == "bool") =>
309 {
310 Some((
311 quote! {
312 pub fn #field_name(&self) -> #field_type {
314 self.#field_name
315 }
316 },
317 (
318 quote! {
319 stringify!(#field_name) => Some(self.#field_name),
320 },
321 quote! {
322 stringify!(#field_name)
323 },
324 ),
325 ))
326 }
327 _ => None,
328 }
329 }),
330 _ => panic!("Only named fields are supported."),
331 },
332 _ => panic!("Only structs supported."),
333 };
334
335 let (by_fn_getters, (string_name_getters, field_names)): (Vec<_>, (Vec<_>, Vec<_>)) =
336 getters.unzip();
337
338 let output = quote! {
339 impl #struct_name {
341 #(#by_fn_getters)*
342
343 pub fn lookup_attr(&self, value: String) -> Option<bool> {
345 match value.as_str() {
346 #(#string_name_getters)*
347 _ => None,
348 }
349 }
350
351 pub fn attr_map(&self) -> std::collections::BTreeMap<String, bool> {
353 vec![
354 #(((#field_names).to_owned(), self.lookup_attr((#field_names).to_owned()).unwrap()),)*
356 ].into_iter().collect()
357 }
358 }
359 };
360
361 TokenStream::from(output)
362}