iota_proc_macros/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use proc_macro::TokenStream;
6use quote::{ToTokens, quote, quote_spanned};
7use syn::{
8    Attribute, BinOp, Data, DataEnum, DeriveInput, Expr, ExprBinary, ExprMacro, Item, ItemMacro,
9    Stmt, StmtMacro, Token, UnOp,
10    fold::{Fold, fold_expr, fold_item_macro, fold_stmt},
11    parse::Parser,
12    parse_macro_input, parse2,
13    punctuated::Punctuated,
14    spanned::Spanned,
15};
16
/// Attribute macro that prepends a static-state warm-up step to the body of
/// the annotated function.
///
/// The rewritten body first spawns a dedicated thread and joins it, and only
/// then runs the original body. The spawned thread:
/// - calls `init_for_testing`, `get_denied_certificates`,
///   `BuiltInFramework::all_package_ids` and `IotaGasStatus::new_unmetered`;
/// - inserts into an `LruCache`;
/// - builds a Move package;
/// - connects two anemo networks.
///
/// This sets up the lazily-initialized statics these calls touch outside of
/// any test thread. The comment at the top of the generated block explains
/// why this matters for determinism.
///
/// Arguments passed to the attribute are ignored.
///
/// The expanding crate must provide the `SIMTEST_STATIC_INIT_MOVE`
/// environment variable at compile time. The generated code reads it with
/// `env!` and passes it as the path to `BuildConfig::build`.
///
/// # Panics
///
/// The generated code panics if any `unwrap`ped warm-up step fails, or if the
/// warm-up thread panics (its `join` result is unwrapped).
#[proc_macro_attribute]
pub fn init_static_initializers(_args: TokenStream, item: TokenStream) -> TokenStream {
    let mut input = parse_macro_input!(item as syn::ItemFn);

    // Replace the function body with `{ <spawn + join warm-up thread>; <original body> }`.
    let body = &input.block;
    input.block = syn::parse2(quote! {
        {
            // We have some lazily-initialized static state in the program. The initializers
            // alter the thread-local hash container state any time they create a new hash
            // container. Therefore, we need to ensure that these initializers are run in a
            // separate thread before the first test thread is launched. Otherwise, they would
            // run inside of the first test thread, but not subsequent ones.
            //
            // Note that none of this has any effect on process-level determinism. Without this
            // code, we can still get the same test results from two processes started with the
            // same seed.
            //
            // However, when using sim_test(check_determinism) or MSIM_TEST_CHECK_DETERMINISM=1,
            // we want the same test invocation to be deterministic when run twice
            // _in the same process_, so we need to take care of this. This will also
            // be very important for being able to reproduce a failure that occurs in the Nth
            // iteration of a multi-iteration test run.
            std::thread::spawn(|| {
                // NOTE(review): `ProtocolConfig` is not referenced anywhere below; confirm
                // whether this import is still needed.
                use iota_protocol_config::ProtocolConfig;
                ::iota_simulator::telemetry_subscribers::init_for_testing();
                ::iota_simulator::iota_types::execution::get_denied_certificates();
                ::iota_simulator::iota_framework::BuiltInFramework::all_package_ids();
                ::iota_simulator::iota_types::gas::IotaGasStatus::new_unmetered();

                // For reasons I can't understand, LruCache causes divergent behavior the second
                // time one is constructed and inserted into, so construct one before the first
                // test run for determinism.
                let mut cache = ::iota_simulator::lru::LruCache::new(1.try_into().unwrap());
                cache.put(1, 1);

                {
                    // Initialize the static initializers here:
                    // https://github.com/move-language/move/blob/652badf6fd67e1d4cc2aa6dc69d63ad14083b673/language/tools/move-package/src/package_lock.rs#L12
                    use std::path::PathBuf;
                    use iota_simulator::iota_move_build::{BuildConfig, IotaPackageHooks};
                    use iota_simulator::tempfile::TempDir;
                    use iota_simulator::move_package::package_hooks::register_package_hooks;

                    register_package_hooks(Box::new(IotaPackageHooks {}));
                    // NOTE(review): `path` is never mutated, so this `mut` looks unnecessary.
                    let mut path = PathBuf::from(env!("SIMTEST_STATIC_INIT_MOVE"));
                    let mut build_config = BuildConfig::default();

                    // Install into a fresh temporary directory; `into_path` keeps the directory
                    // on disk instead of deleting it when the `TempDir` is dropped.
                    build_config.config.install_dir = Some(TempDir::new().unwrap().into_path());
                    let _all_module_bytes = build_config
                        .build(&path)
                        .unwrap()
                        .get_package_bytes(/* with_unpublished_deps */ false);
                }


                use ::iota_simulator::anemo_tower::callback::CallbackLayer;
                use ::iota_simulator::anemo_tower::trace::DefaultMakeSpan;
                use ::iota_simulator::anemo_tower::trace::DefaultOnFailure;
                use ::iota_simulator::anemo_tower::trace::TraceLayer;
                use ::iota_metrics::metrics_network::{NetworkMetrics, MetricsMakeCallbackHandler};

                use std::sync::Arc;
                use ::iota_simulator::fastcrypto::traits::KeyPair;
                use ::iota_simulator::rand_crate::rngs::{StdRng, OsRng};
                use ::iota_simulator::rand::SeedableRng;
                use ::iota_simulator::tower::ServiceBuilder;

                // anemo uses x509-parser, which has many lazy static variables. start a network to
                // initialize all that static state before the first test.
                let rt = ::iota_simulator::runtime::Runtime::new();
                rt.block_on(async move {
                    use ::iota_simulator::anemo::{Network, Request};

                    // Builds an anemo network bound to 127.0.0.1:`port`. It serves an empty
                    // router wrapped in tracing + metrics layers and uses a freshly generated
                    // Ed25519 private key.
                    let make_network = |port: u16| {
                        let registry = prometheus::Registry::new();
                        let inbound_network_metrics =
                            NetworkMetrics::new("iota", "inbound", &registry);
                        let outbound_network_metrics =
                            NetworkMetrics::new("iota", "outbound", &registry);

                        let service = ServiceBuilder::new()
                            .layer(
                                TraceLayer::new_for_server_errors()
                                    .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
                                    .on_failure(DefaultOnFailure::new().level(tracing::Level::WARN)),
                            )
                            .layer(CallbackLayer::new(MetricsMakeCallbackHandler::new(
                                Arc::new(inbound_network_metrics),
                                usize::MAX,
                            )))
                            .service(::iota_simulator::anemo::Router::new());

                        // NOTE(review): `outbound_layer` is built but never attached to the
                        // network below. Constructing it may be intentional (static init);
                        // confirm.
                        let outbound_layer = ServiceBuilder::new()
                            .layer(
                                TraceLayer::new_for_client_and_server_errors()
                                    .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
                                    .on_failure(DefaultOnFailure::new().level(tracing::Level::WARN)),
                            )
                            .layer(CallbackLayer::new(MetricsMakeCallbackHandler::new(
                                Arc::new(outbound_network_metrics),
                                usize::MAX,
                            )))
                            .into_inner();


                        Network::bind(format!("127.0.0.1:{}", port))
                            .server_name("static-init-network")
                            .private_key(
                                ::iota_simulator::fastcrypto::ed25519::Ed25519KeyPair::generate(&mut StdRng::from_rng(OsRng).unwrap())
                                    .private()
                                    .0
                                    .to_bytes(),
                            )
                            .start(service)
                            .unwrap()
                    };
                    // NOTE(review): binding ports 80/81 is presumably fine because this runs
                    // inside `::iota_simulator::runtime::Runtime`; confirm outside msim.
                    let n1 = make_network(80);
                    let n2 = make_network(81);

                    // Connecting exercises the TLS/x509 code paths mentioned above.
                    let _peer = n1.connect(n2.local_addr()).await.unwrap();
                });
            }).join().unwrap();

            #body
        }
    })
    // The template wraps an already-parsed function body, so this should only
    // fail if the template itself is malformed.
    .expect("Parsing failure");

    let result = quote! {
        #input
    };

    result.into()
}
151
152/// The iota_test macro will invoke either `#[msim::test]` or `#[tokio::test]`,
153/// depending on whether the simulator config var is enabled.
154///
155/// This should be used for tests that can meaningfully run in either
156/// environment.
157#[proc_macro_attribute]
158pub fn iota_test(args: TokenStream, item: TokenStream) -> TokenStream {
159    let input = parse_macro_input!(item as syn::ItemFn);
160    let arg_parser = Punctuated::<syn::Meta, Token![,]>::parse_terminated;
161    let args = arg_parser.parse(args).unwrap().into_iter();
162
163    let header = if cfg!(msim) {
164        quote! {
165            #[::iota_simulator::sim_test(crate = "iota_simulator", #(#args)* )]
166        }
167    } else {
168        quote! {
169            #[::tokio::test(#(#args)*)]
170        }
171    };
172
173    let result = quote! {
174        #header
175        #[::iota_macros::init_static_initializers]
176        #input
177    };
178
179    result.into()
180}
181
182/// The `sim_test` macro will invoke `#[msim::test]` if the simulator config var
183/// (`msim`) is enabled.
184///
185/// On this premise, this macro can be used in order to pass any
186/// simulator-specific arguments, such as `check_determinism`,
187/// which is not understood by tokio.
188///
189/// If the simulator config var is disabled, tests will run via
190/// `#[tokio::test]`, unless disabled by setting the environment variable
191/// `IOTA_SKIP_SIMTESTS`.
192#[proc_macro_attribute]
193pub fn sim_test(args: TokenStream, item: TokenStream) -> TokenStream {
194    let input = parse_macro_input!(item as syn::ItemFn);
195    let arg_parser = Punctuated::<syn::Meta, Token![,]>::parse_terminated;
196    let args = arg_parser.parse(args).unwrap().into_iter();
197
198    let ignore = input
199        .attrs
200        .iter()
201        .find(|attr| attr.path().is_ident("ignore"))
202        .map_or(quote! {}, |_| quote! { #[ignore] });
203
204    let result = if cfg!(msim) {
205        let sig = &input.sig;
206        let return_type = &sig.output;
207        let body = &input.block;
208        quote! {
209            #[::iota_simulator::sim_test(crate = "iota_simulator", #(#args),*)]
210            #[::iota_macros::init_static_initializers]
211            #ignore
212            #sig {
213                async fn body_fn() #return_type { #body }
214
215                let ret = body_fn().await;
216
217                ::iota_simulator::task::shutdown_all_nodes();
218
219                // all node handles should have been dropped after the above block exits, but task
220                // shutdown is asynchronous, so we need a brief delay before checking for leaks.
221                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
222
223                assert_eq!(
224                    iota_simulator::NodeLeakDetector::get_current_node_count(),
225                    0,
226                    "IotaNode leak detected"
227                );
228
229                ret
230            }
231        }
232    } else {
233        let fn_name = &input.sig.ident;
234        let sig = &input.sig;
235        let body = &input.block;
236        quote! {
237            #[expect(clippy::needless_return)]
238            #[tokio::test]
239            #ignore
240            #sig {
241                if std::env::var("IOTA_SKIP_SIMTESTS").is_ok() {
242                    println!("not running test {} in `cargo test`: IOTA_SKIP_SIMTESTS is set", stringify!(#fn_name));
243
244                    struct Ret;
245
246                    impl From<Ret> for () {
247                        fn from(_ret: Ret) -> Self {
248                        }
249                    }
250
251                    impl<E> From<Ret> for Result<(), E> {
252                        fn from(_ret: Ret) -> Self {
253                            Ok(())
254                        }
255                    }
256
257                    return Ret.into();
258                }
259
260                #body
261            }
262        }
263    };
264
265    result.into()
266}
267
268#[proc_macro]
269pub fn checked_arithmetic(input: TokenStream) -> TokenStream {
270    let input_file = CheckArithmetic.fold_file(parse_macro_input!(input));
271
272    let output_items = input_file.items;
273
274    let output = quote! {
275        #(#output_items)*
276    };
277
278    TokenStream::from(output)
279}
280
281#[proc_macro_attribute]
282pub fn with_checked_arithmetic(_attr: TokenStream, item: TokenStream) -> TokenStream {
283    let input_item = parse_macro_input!(item as Item);
284    match input_item {
285        Item::Fn(input_fn) => {
286            let transformed_fn = CheckArithmetic.fold_item_fn(input_fn);
287            TokenStream::from(quote! { #transformed_fn })
288        }
289        Item::Impl(input_impl) => {
290            let transformed_impl = CheckArithmetic.fold_item_impl(input_impl);
291            TokenStream::from(quote! { #transformed_impl })
292        }
293        item => {
294            let transformed_impl = CheckArithmetic.fold_item(item);
295            TokenStream::from(quote! { #transformed_impl })
296        }
297    }
298}
299
300struct CheckArithmetic;
301
302impl CheckArithmetic {
303    fn maybe_skip_macro(&self, attrs: &mut Vec<Attribute>) -> bool {
304        if let Some(idx) = attrs
305            .iter()
306            .position(|attr| attr.path().is_ident("skip_checked_arithmetic"))
307        {
308            // Skip processing macro because it is annotated with
309            // #[skip_checked_arithmetic]
310            attrs.remove(idx);
311            true
312        } else {
313            false
314        }
315    }
316
317    fn process_macro_contents(
318        &mut self,
319        tokens: proc_macro2::TokenStream,
320    ) -> syn::Result<proc_macro2::TokenStream> {
321        // Parse the macro's contents as a comma-separated list of expressions.
322        let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
323        let Ok(exprs) = parser.parse(tokens.clone().into()) else {
324            return Err(syn::Error::new_spanned(
325                tokens,
326                "could not process macro contents - use #[skip_checked_arithmetic] to skip this macro",
327            ));
328        };
329
330        // Fold each sub expression.
331        let folded_exprs = exprs
332            .into_iter()
333            .map(|expr| self.fold_expr(expr))
334            .collect::<Vec<_>>();
335
336        // Convert the folded expressions back into tokens and reconstruct the macro.
337        let mut folded_tokens = proc_macro2::TokenStream::new();
338        for (i, folded_expr) in folded_exprs.into_iter().enumerate() {
339            if i > 0 {
340                folded_tokens.extend(std::iter::once::<proc_macro2::TokenTree>(
341                    proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone).into(),
342                ));
343            }
344            folded_expr.to_tokens(&mut folded_tokens);
345        }
346
347        Ok(folded_tokens)
348    }
349}
350
351impl Fold for CheckArithmetic {
352    fn fold_stmt(&mut self, stmt: Stmt) -> Stmt {
353        let stmt = fold_stmt(self, stmt);
354        if let Stmt::Macro(stmt_macro) = stmt {
355            let StmtMacro {
356                mut attrs,
357                mut mac,
358                semi_token,
359            } = stmt_macro;
360
361            if self.maybe_skip_macro(&mut attrs) {
362                Stmt::Macro(StmtMacro {
363                    attrs,
364                    mac,
365                    semi_token,
366                })
367            } else {
368                match self.process_macro_contents(mac.tokens.clone()) {
369                    Ok(folded_tokens) => {
370                        mac.tokens = folded_tokens;
371                        Stmt::Macro(StmtMacro {
372                            attrs,
373                            mac,
374                            semi_token,
375                        })
376                    }
377                    Err(error) => parse2(error.to_compile_error()).unwrap(),
378                }
379            }
380        } else {
381            stmt
382        }
383    }
384
385    fn fold_item_macro(&mut self, mut item_macro: ItemMacro) -> ItemMacro {
386        if !self.maybe_skip_macro(&mut item_macro.attrs) {
387            let err = syn::Error::new_spanned(
388                item_macro.to_token_stream(),
389                "cannot process macros - use #[skip_checked_arithmetic] to skip \
390                    processing this macro",
391            );
392
393            return parse2(err.to_compile_error()).unwrap();
394        }
395        fold_item_macro(self, item_macro)
396    }
397
398    fn fold_expr(&mut self, expr: Expr) -> Expr {
399        let span = expr.span();
400        let expr = fold_expr(self, expr);
401        let expr = match expr {
402            Expr::Macro(expr_macro) => {
403                let ExprMacro { mut attrs, mut mac } = expr_macro;
404
405                if self.maybe_skip_macro(&mut attrs) {
406                    return Expr::Macro(ExprMacro { attrs, mac });
407                } else {
408                    match self.process_macro_contents(mac.tokens.clone()) {
409                        Ok(folded_tokens) => {
410                            mac.tokens = folded_tokens;
411                            let expr_macro = Expr::Macro(ExprMacro { attrs, mac });
412                            quote!(#expr_macro)
413                        }
414                        Err(error) => {
415                            return Expr::Verbatim(error.to_compile_error());
416                        }
417                    }
418                }
419            }
420
421            Expr::Binary(expr_binary) => {
422                let ExprBinary {
423                    attrs,
424                    mut left,
425                    op,
426                    mut right,
427                } = expr_binary;
428
429                fn remove_parens(expr: &mut Expr) {
430                    if let Expr::Paren(paren) = expr {
431                        // i don't even think rust allows this, but just in case
432                        assert!(paren.attrs.is_empty(), "TODO: attrs on parenthesized");
433                        *expr = *paren.expr.clone();
434                    }
435                }
436
437                macro_rules! wrap_op {
438                    ($left: expr, $right: expr, $method: ident, $span: expr) => {{
439                        // Remove parens from exprs since both sides get assigned to tmp variables.
440                        // otherwise we get lint errors
441                        remove_parens(&mut $left);
442                        remove_parens(&mut $right);
443
444                        quote_spanned!($span => {
445                            // assign in one stmt in case either #left or #right contains
446                            // references to `left` or `right` symbols.
447                            let (left, right) = (#left, #right);
448                            left.$method(right)
449                                .unwrap_or_else(||
450                                    panic!(
451                                        "Overflow or underflow in {} {} + {}",
452                                        stringify!($method),
453                                        left,
454                                        right,
455                                    )
456                                )
457                        })
458                    }};
459                }
460
461                macro_rules! wrap_op_assign {
462                    ($left: expr, $right: expr, $method: ident, $span: expr) => {{
463                        // Remove parens from exprs since both sides get assigned to tmp variables.
464                        // otherwise we get lint errors
465                        remove_parens(&mut $left);
466                        remove_parens(&mut $right);
467
468                        quote_spanned!($span => {
469                            // assign in one stmt in case either #left or #right contains
470                            // references to `left` or `right` symbols.
471                            let (left, right) = (&mut #left, #right);
472                            *left = (*left).$method(right)
473                                .unwrap_or_else(||
474                                    panic!(
475                                        "Overflow or underflow in {} {} + {}",
476                                        stringify!($method),
477                                        *left,
478                                        right
479                                    )
480                                )
481                        })
482                    }};
483                }
484
485                match op {
486                    BinOp::Add(_) => {
487                        wrap_op!(left, right, checked_add, span)
488                    }
489                    BinOp::Sub(_) => {
490                        wrap_op!(left, right, checked_sub, span)
491                    }
492                    BinOp::Mul(_) => {
493                        wrap_op!(left, right, checked_mul, span)
494                    }
495                    BinOp::Div(_) => {
496                        wrap_op!(left, right, checked_div, span)
497                    }
498                    BinOp::Rem(_) => {
499                        wrap_op!(left, right, checked_rem, span)
500                    }
501                    BinOp::AddAssign(_) => {
502                        wrap_op_assign!(left, right, checked_add, span)
503                    }
504                    BinOp::SubAssign(_) => {
505                        wrap_op_assign!(left, right, checked_sub, span)
506                    }
507                    BinOp::MulAssign(_) => {
508                        wrap_op_assign!(left, right, checked_mul, span)
509                    }
510                    BinOp::DivAssign(_) => {
511                        wrap_op_assign!(left, right, checked_div, span)
512                    }
513                    BinOp::RemAssign(_) => {
514                        wrap_op_assign!(left, right, checked_rem, span)
515                    }
516                    _ => {
517                        let expr_binary = ExprBinary {
518                            attrs,
519                            left,
520                            op,
521                            right,
522                        };
523                        quote_spanned!(span => #expr_binary)
524                    }
525                }
526            }
527            Expr::Unary(expr_unary) => {
528                let op = &expr_unary.op;
529                let operand = &expr_unary.expr;
530                match op {
531                    UnOp::Neg(_) => {
532                        quote_spanned!(span => #operand.checked_neg().expect("Overflow or underflow in negation"))
533                    }
534                    _ => quote_spanned!(span => #expr_unary),
535                }
536            }
537            _ => quote_spanned!(span => #expr),
538        };
539
540        parse2(expr).unwrap()
541    }
542}
543
544/// This proc macro generates a function `order_to_variant_map` which returns a
545/// map of the position of each variant to the name of the variant.
546/// It is intended to catch changes in enum order when backward compat is
547/// required.
548/// ```rust,ignore
549///    /// Example for this enum
550///    #[derive(EnumVariantOrder)]
551///    pub enum MyEnum {
552///         A,
553///         B(u64),
554///         C{x: bool, y: i8},
555///     }
556///     let order_map = MyEnum::order_to_variant_map();
557///     assert!(order_map.get(0).unwrap() == "A");
558///     assert!(order_map.get(1).unwrap() == "B");
559///     assert!(order_map.get(2).unwrap() == "C");
560/// ```
561#[proc_macro_derive(EnumVariantOrder)]
562pub fn enum_variant_order_derive(input: TokenStream) -> TokenStream {
563    let ast = parse_macro_input!(input as DeriveInput);
564    let name = &ast.ident;
565
566    if let Data::Enum(DataEnum { variants, .. }) = ast.data {
567        let variant_entries = variants
568            .iter()
569            .enumerate()
570            .map(|(index, variant)| {
571                let variant_name = variant.ident.to_string();
572                quote! {
573                    map.insert( #index as u64, (#variant_name).to_string());
574                }
575            })
576            .collect::<Vec<_>>();
577
578        let deriv = quote! {
579            impl iota_enum_compat_util::EnumOrderMap for #name {
580                fn order_to_variant_map() -> std::collections::BTreeMap<u64, String > {
581                    let mut map = std::collections::BTreeMap::new();
582                    #(#variant_entries)*
583                    map
584                }
585            }
586        };
587
588        deriv.into()
589    } else {
590        panic!("EnumVariantOrder can only be used with enums.");
591    }
592}