iota_macros/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{collections::HashMap, future::Future, sync::Arc};
6
7use futures::future::BoxFuture;
8pub use iota_proc_macros::*;
9
/// Evaluates an expression in a new thread which will not be subject to
/// interception of getrandom(), clock_gettime(), etc.
///
/// The expression runs on a scoped thread, so it may borrow from the caller's
/// stack, and the macro evaluates to the expression's value.
///
/// # Panics
///
/// If the expression panics, the original panic payload is re-raised on the
/// calling thread, so the caller observes the real panic message instead of a
/// generic `unwrap` failure.
#[cfg(msim)]
#[macro_export]
macro_rules! nondeterministic {
    ($expr: expr) => {
        // Absolute `::std` paths keep the expansion independent of whatever
        // the name `std` resolves to at the call site.
        ::std::thread::scope(move |s| {
            s.spawn(move || $expr)
                .join()
                .unwrap_or_else(|payload| ::std::panic::resume_unwind(payload))
        })
    };
}
19
/// Counterpart of the `cfg(msim)` macro for all other builds: expands to the
/// given expression itself, evaluated in place on the calling thread.
#[cfg(not(msim))]
#[macro_export]
macro_rules! nondeterministic {
    ($body:expr) => {
        $body
    };
}
28
/// Type-erased fail point callback: takes no arguments and returns its result
/// boxed as `dyn Any`. The `get_*_result` helpers below downcast that box to
/// the concrete type each kind of fail point expects.
type FpCallback = dyn Fn() -> Box<dyn std::any::Any + Send + 'static> + Send + Sync;
/// Registry of fail point callbacks, keyed by fail point identifier.
type FpMap = HashMap<&'static str, Arc<FpCallback>>;
31
/// Runs `func` with mutable access to the fail point registry and returns
/// whatever `func` returns.
///
/// Under `cfg(msim)` the registry is a thread-local, so each thread sees only
/// the fail points registered on that same thread.
///
/// # Panics
///
/// Panics if called re-entrantly from within `func`, because the registry
/// lives in a `RefCell` that is already mutably borrowed.
#[cfg(msim)]
fn with_fp_map<T>(func: impl FnOnce(&mut FpMap) -> T) -> T {
    thread_local! {
        static MAP: std::cell::RefCell<FpMap> = Default::default();
    }

    MAP.with(|cell| {
        let mut registry = cell.borrow_mut();
        func(&mut registry)
    })
}
40
41#[cfg(not(msim))]
42fn with_fp_map<T>(func: impl FnOnce(&mut FpMap) -> T) -> T {
43    use std::sync::Mutex;
44
45    use once_cell::sync::Lazy;
46
47    static MAP: Lazy<Mutex<FpMap>> = Lazy::new(Default::default);
48    let mut map = MAP.lock().unwrap();
49    func(&mut map)
50}
51
52fn get_callback(identifier: &'static str) -> Option<Arc<FpCallback>> {
53    with_fp_map(|map| map.get(identifier).cloned())
54}
55
/// Checks that a synchronous fail point callback produced `()`.
///
/// # Panics
///
/// Panics if the boxed value is of any type other than `()`.
fn get_sync_fp_result(result: Box<dyn std::any::Any + Send + 'static>) {
    assert!(result.is::<()>(), "sync failpoint must return ()");
}
61
62fn get_async_fp_result(result: Box<dyn std::any::Any + Send + 'static>) -> BoxFuture<'static, ()> {
63    match result.downcast::<BoxFuture<'static, ()>>() {
64        Ok(fut) => *fut,
65        Err(err) => panic!(
66            "async failpoint must return BoxFuture<'static, ()> {:?}",
67            err
68        ),
69    }
70}
71
/// Extracts the `bool` produced by a `fail_point_if` registration callback.
///
/// # Panics
///
/// Panics if the boxed value is not a `bool`.
fn get_fp_if_result(result: Box<dyn std::any::Any + Send + 'static>) -> bool {
    let flag = result
        .downcast::<bool>()
        .unwrap_or_else(|_| panic!("failpoint-if must return bool"));
    *flag
}
78
/// Extracts the `Option<T>` produced by a `fail_point_arg` registration
/// callback.
///
/// `T` must be exactly the type the callback was registered with: the value
/// is recovered by downcasting, so e.g. a callback returning `Option<i32>`
/// does not satisfy a call site expecting `Option<u64>`.
///
/// # Panics
///
/// Panics if the boxed value is not an `Option<T>`. The message names the
/// concrete type that was expected so such mismatches can be diagnosed.
fn get_fp_some_result<T: Send + 'static>(
    result: Box<dyn std::any::Any + Send + 'static>,
) -> Option<T> {
    match result.downcast::<Option<T>>() {
        Ok(opt) => *opt,
        // The original wording is kept as a prefix so that anything matching
        // on it keeps working.
        Err(_) => panic!(
            "failpoint-arg must return Option<T> (expected {})",
            std::any::type_name::<Option<T>>()
        ),
    }
}
87
88pub fn handle_fail_point(identifier: &'static str) {
89    if let Some(callback) = get_callback(identifier) {
90        get_sync_fp_result(callback());
91        tracing::trace!("hit failpoint {}", identifier);
92    }
93}
94
95pub async fn handle_fail_point_async(identifier: &'static str) {
96    if let Some(callback) = get_callback(identifier) {
97        tracing::trace!("hit async failpoint {}", identifier);
98        let fut = get_async_fp_result(callback());
99        fut.await;
100    }
101}
102
103pub fn handle_fail_point_if(identifier: &'static str) -> bool {
104    if let Some(callback) = get_callback(identifier) {
105        tracing::trace!("hit failpoint_if {}", identifier);
106        get_fp_if_result(callback())
107    } else {
108        false
109    }
110}
111
112pub fn handle_fail_point_arg<T: Send + 'static>(identifier: &'static str) -> Option<T> {
113    if let Some(callback) = get_callback(identifier) {
114        tracing::trace!("hit failpoint_arg {}", identifier);
115        get_fp_some_result(callback())
116    } else {
117        None
118    }
119}
120
121fn register_fail_point_impl(identifier: &'static str, callback: Arc<FpCallback>) {
122    with_fp_map(move |map| {
123        assert!(
124            map.insert(identifier, callback).is_none(),
125            "duplicate fail point registration"
126        );
127    })
128}
129
130fn clear_fail_point_impl(identifier: &'static str) {
131    with_fp_map(move |map| {
132        assert!(
133            map.remove(identifier).is_some(),
134            "fail point {:?} does not exist",
135            identifier
136        );
137    })
138}
139
140pub fn register_fail_point(identifier: &'static str, callback: impl Fn() + Sync + Send + 'static) {
141    register_fail_point_impl(
142        identifier,
143        Arc::new(move || {
144            callback();
145            Box::new(())
146        }),
147    );
148}
149
150/// Register an asynchronous fail point. Because it is async it can yield
151/// execution of the calling task, e.g. by sleeping.
152pub fn register_fail_point_async<F>(
153    identifier: &'static str,
154    callback: impl Fn() -> F + Sync + Send + 'static,
155) where
156    F: Future<Output = ()> + Send + 'static,
157{
158    register_fail_point_impl(
159        identifier,
160        Arc::new(move || {
161            let result: BoxFuture<'static, ()> = Box::pin(callback());
162            Box::new(result)
163        }),
164    );
165}
166
167/// Register code to run locally if the fail point is hit. Example:
168///
169/// In the test:
170///
171/// ```ignore
172///     register_fail_point_if("foo", || {
173///         iota_simulator::current_simnode_id() == 2
174///     });
175/// ```
176///
177/// In the code:
178///
179/// ```ignore
180///     let mut was_hit = false;
181///     fail_point_if("foo", || {
182///        was_hit = true;
183///     });
184/// ```
185pub fn register_fail_point_if(
186    identifier: &'static str,
187    callback: impl Fn() -> bool + Sync + Send + 'static,
188) {
189    register_fail_point_impl(identifier, Arc::new(move || Box::new(callback())));
190}
191
192/// Register code to run locally if the fail point is hit, with a value provided
193/// by the test. If the registered callback returns a Some(v), then the `v` is
194/// passed to the callback in the test.
195///
196/// In the test:
197///
198/// ```ignore
199///     register_fail_point_arg("foo", || {
200///         Some(42)
201///     });
202/// ```
203///
204/// In the code:
205///
206/// ```ignore
207///     let mut value = 0;
208///     fail_point_arg!("foo", |arg| {
209///        value = arg;
210///     });
211/// ```
212pub fn register_fail_point_arg<T: Send + 'static>(
213    identifier: &'static str,
214    callback: impl Fn() -> Option<T> + Sync + Send + 'static,
215) {
216    register_fail_point_impl(identifier, Arc::new(move || Box::new(callback())));
217}
218
219pub fn register_fail_points(
220    identifiers: &[&'static str],
221    callback: impl Fn() + Sync + Send + 'static,
222) {
223    let cb: Arc<FpCallback> = Arc::new(move || {
224        callback();
225        Box::new(())
226    });
227    for id in identifiers {
228        register_fail_point_impl(id, cb.clone());
229    }
230}
231
232pub fn clear_fail_point(identifier: &'static str) {
233    clear_fail_point_impl(identifier);
234}
235
/// Marks a fail point. If a test registered a callback under the given
/// identifier, that callback runs here; otherwise nothing happens.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point {
    ($identifier:expr) => {
        $crate::handle_fail_point($identifier)
    };
}
245
/// Marks an async fail point. If a test registered an async callback under
/// the given identifier, the future it produces is awaited here, so this must
/// be used inside an async context.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point_async {
    ($identifier:expr) => {
        $crate::handle_fail_point_async($identifier).await
    };
}
255
/// Marks a conditional fail point: the callback given at the call site is
/// run only if the fail point is enabled, i.e. a predicate is registered
/// under the identifier and it returns `true`.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point_if {
    ($identifier:expr, $on_hit:expr) => {
        if $crate::handle_fail_point_if($identifier) {
            ($on_hit)();
        }
    };
}
268
/// Marks a fail point that receives a value. If the callback registered under
/// the identifier returns `Some(v)`, the callback given at the call site is
/// run with `v`. Otherwise the fail point is skipped.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point_arg {
    ($identifier:expr, $on_hit:expr) => {
        if let Some(value) = $crate::handle_fail_point_arg($identifier) {
            ($on_hit)(value);
        }
    };
}
281
/// Disabled form of `fail_point!`, used when neither `msim` nor `fail_points`
/// is set: expands to nothing, so the identifier expression is not evaluated.
#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point {
    ($identifier:expr) => {};
}
287
/// Disabled form of `fail_point_async!`, used when neither `msim` nor
/// `fail_points` is set: expands to nothing, so nothing is evaluated or
/// awaited.
#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point_async {
    ($identifier:expr) => {};
}
293
/// Disabled form of `fail_point_if!`, used when neither `msim` nor
/// `fail_points` is set: expands to nothing, so the call-site callback is
/// never constructed or run.
#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point_if {
    ($identifier:expr, $on_hit:expr) => {};
}
299
/// Disabled form of `fail_point_arg!`, used when neither `msim` nor
/// `fail_points` is set: expands to nothing, so the call-site callback is
/// never constructed or run.
#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point_arg {
    ($identifier:expr, $on_hit:expr) => {};
}
305
/// Use to write INFO level logs only when REPLAY_LOG
/// environment variable is set. Useful for log lines that
/// are only relevant to test infra which still may need to
/// run a release build. Also note that since logs of a chain
/// replay are exceedingly verbose, this will allow one to bubble
/// up "debug level" info while running with RUST_LOG=info.
///
/// The variable's value is irrelevant; only its presence is checked, and the
/// check is repeated on every invocation. The arguments are forwarded
/// verbatim to `tracing::info!`, so the calling crate must have `tracing`
/// in scope.
#[macro_export]
macro_rules! replay_log {
    ($($arg:tt)+) => {
        // `var_os` tests for presence only. `var(..).is_ok()` would also
        // demand valid Unicode (treating a set-but-non-UTF-8 value as unset)
        // and would allocate a `String` just to throw it away.
        if ::std::env::var_os("REPLAY_LOG").is_some() {
            tracing::info!($($arg)+);
        }
    };
}
320
// These tests need to be run in release mode, since debug mode does overflow
// checks by default!
//
// The suite below exercises the `checked_arithmetic!`, `with_checked_arithmetic`
// and `skip_checked_arithmetic` macros brought in through `use super::*`. The
// `#[should_panic]` tests expect arithmetic overflow inside macro-processed
// code to panic. The exact syntactic shape of each test body (closures, nested
// fns, compound assignment, ...) is the thing under test, so keep the bodies
// as written.
#[cfg(test)]
mod test {
    use super::*;

    // Uncomment to test error messages
    // #[with_checked_arithmetic]
    // struct TestStruct;

    // Forwards its tokens unchanged; used below to put an item behind a macro
    // invocation.
    macro_rules! pass_through {
        ($($tt:tt)*) => {
            $($tt)*
        }
    }

    #[with_checked_arithmetic]
    #[test]
    fn test_skip_checked_arithmetic() {
        // comment out this attr to test the error message
        #[skip_checked_arithmetic]
        pass_through! {
            fn unchecked_add(a: i32, b: i32) -> i32 {
                a + b
            }
        }

        // this will not panic even if we pass in (i32::MAX, 1), because we skipped
        // processing the item macro, so we also need to make sure it doesn't
        // panic in debug mode.
        unchecked_add(1, 2);
    }

    // Function-like form of the macro: every item inside the braces is
    // processed by `checked_arithmetic!`.
    checked_arithmetic! {

    struct Test {
        a: i32,
        b: i32,
    }

    fn unchecked_add(a: i32, b: i32) -> i32 {
        a + b
    }

    #[test]
    fn test_checked_arithmetic_macro() {
        unchecked_add(1, 2);
    }

    #[test]
    #[should_panic]
    fn test_checked_arithmetic_macro_panic() {
        unchecked_add(i32::MAX, 1);
    }

    // Overflowing addition hidden inside a closure.
    fn unchecked_add_hidden(a: i32, b: i32) -> i32 {
        let inner = |a: i32, b: i32| a + b;
        inner(a, b)
    }

    #[test]
    #[should_panic]
    fn test_checked_arithmetic_macro_panic_hidden() {
        unchecked_add_hidden(i32::MAX, 1);
    }

    // Overflowing addition hidden inside a nested fn item.
    fn unchecked_add_hidden_2(a: i32, b: i32) -> i32 {
        fn inner(a: i32, b: i32) -> i32 {
            a + b
        }
        inner(a, b)
    }

    #[test]
    #[should_panic]
    fn test_checked_arithmetic_macro_panic_hidden_2() {
        unchecked_add_hidden_2(i32::MAX, 1);
    }

    // Overflowing addition inside a method of an impl block.
    impl Test {
        fn add(&self) -> i32 {
            self.a + self.b
        }
    }

    #[test]
    #[should_panic]
    fn test_checked_arithmetic_impl() {
        let t = Test { a: 1, b: i32::MAX };
        t.add();
    }

    // Overflowing addition inside the arguments of another macro (`println!`).
    #[test]
    #[should_panic]
    fn test_macro_overflow() {
        fn f() {
            println!("{}", i32::MAX + 1);
        }

        f()
    }

    // Make sure that we still do addition correctly!
    #[test]
    fn test_non_overflow() {
        fn f() {
            assert_eq!(1i32 + 2i32, 3i32);
            assert_eq!(3i32 - 1i32, 2i32);
            assert_eq!(4i32 * 3i32, 12i32);
            assert_eq!(12i32 / 3i32, 4i32);
            assert_eq!(12i32 % 5i32, 2i32);

            let mut a = 1i32;
            a += 2i32;
            assert_eq!(a, 3i32);

            let mut a = 3i32;
            a -= 1i32;
            assert_eq!(a, 2i32);

            let mut a = 4i32;
            a *= 3i32;
            assert_eq!(a, 12i32);

            let mut a = 12i32;
            a /= 3i32;
            assert_eq!(a, 4i32);

            let mut a = 12i32;
            a %= 5i32;
            assert_eq!(a, 2i32);
        }

        f();
    }


    // The operand closures panic on a second call, so these two tests fail if
    // an operand expression is evaluated more than once.
    #[test]
    fn test_exprs_evaluated_once_right() {
        let mut called = false;
        let mut f = || {
            if called {
                panic!("called twice");
            }
            called = true;
            1i32
        };

        assert_eq!(2i32 + f(), 3);
    }

    #[test]
    fn test_exprs_evaluated_once_left() {
        let mut called = false;
        let mut f = || {
            if called {
                panic!("called twice");
            }
            called = true;
            1i32
        };

        assert_eq!(f() + 2i32, 3);
    }

    // Same single-evaluation check for the left-hand side of a compound
    // assignment.
    #[test]
    fn test_assign_op_evals_once() {
        struct Foo {
            a: i32,
            called: bool,
        }

        impl Foo {
            fn get_a_mut(&mut self) -> &mut i32 {
                if self.called {
                    panic!("called twice");
                }
                let ret = &mut self.a;
                self.called = true;
                ret
            }
        }

        let mut foo = Foo { a: 1, called: false };

        *foo.get_a_mut() += 2;
        assert_eq!(foo.a, 3);
    }

    #[test]
    fn test_more_macro_syntax() {
        struct Foo {
            a: i32,
            b: i32,
        }

        impl Foo {
            const BAR: i32 = 1;

            fn new(a: i32, b: i32) -> Foo {
                Foo { a, b }
            }
        }

        fn new_foo(a: i32) -> Foo {
            Foo { a, b: 0 }
        }

        // verify that we translate the contents of macros correctly
        assert_eq!(Foo::BAR + 1, 2);
        assert_eq!(Foo::new(1, 2).b, 2);
        assert_eq!(new_foo(1).a, 1);

        let v = [Foo::new(1, 2), Foo::new(3, 2)];

        assert_eq!(v[0].a, 1);
        assert_eq!(v[1].b, 2);
    }

    }

    // Attribute form of the macro, applied to a whole module. The items and
    // tests mirror the `checked_arithmetic!` block above.
    #[with_checked_arithmetic]
    mod with_checked_arithmetic_tests {

        struct Test {
            a: i32,
            b: i32,
        }

        fn unchecked_add(a: i32, b: i32) -> i32 {
            a + b
        }

        #[test]
        fn test_checked_arithmetic_macro() {
            unchecked_add(1, 2);
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic() {
            unchecked_add(i32::MAX, 1);
        }

        // Overflowing addition hidden inside a closure.
        fn unchecked_add_hidden(a: i32, b: i32) -> i32 {
            let inner = |a: i32, b: i32| a + b;
            inner(a, b)
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic_hidden() {
            unchecked_add_hidden(i32::MAX, 1);
        }

        // Overflowing addition hidden inside a nested fn item.
        fn unchecked_add_hidden_2(a: i32, b: i32) -> i32 {
            fn inner(a: i32, b: i32) -> i32 {
                a + b
            }
            inner(a, b)
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic_hidden_2() {
            unchecked_add_hidden_2(i32::MAX, 1);
        }

        // Overflowing addition inside a method of an impl block.
        impl Test {
            fn add(&self) -> i32 {
                self.a + self.b
            }
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_impl() {
            let t = Test { a: 1, b: i32::MAX };
            t.add();
        }

        // Overflowing addition inside the arguments of another macro
        // (`println!`).
        #[test]
        #[should_panic]
        fn test_macro_overflow() {
            fn f() {
                println!("{}", i32::MAX + 1);
            }

            f()
        }

        // Make sure that we still do addition correctly!
        #[test]
        fn test_non_overflow() {
            fn f() {
                assert_eq!(1i32 + 2i32, 3i32);
                assert_eq!(3i32 - 1i32, 2i32);
                assert_eq!(4i32 * 3i32, 12i32);
                assert_eq!(12i32 / 3i32, 4i32);
                assert_eq!(12i32 % 5i32, 2i32);

                let mut a = 1i32;
                a += 2i32;
                assert_eq!(a, 3i32);

                let mut a = 3i32;
                a -= 1i32;
                assert_eq!(a, 2i32);

                let mut a = 4i32;
                a *= 3i32;
                assert_eq!(a, 12i32);

                let mut a = 12i32;
                a /= 3i32;
                assert_eq!(a, 4i32);

                let mut a = 12i32;
                a %= 5i32;
                assert_eq!(a, 2i32);
            }

            f();
        }

        // The operand closures panic on a second call, so these two tests
        // fail if an operand expression is evaluated more than once.
        #[test]
        fn test_exprs_evaluated_once_right() {
            let mut called = false;
            let mut f = || {
                if called {
                    panic!("called twice");
                }
                called = true;
                1i32
            };

            assert_eq!(2i32 + f(), 3);
        }

        #[test]
        fn test_exprs_evaluated_once_left() {
            let mut called = false;
            let mut f = || {
                if called {
                    panic!("called twice");
                }
                called = true;
                1i32
            };

            assert_eq!(f() + 2i32, 3);
        }

        // Same single-evaluation check for the left-hand side of a compound
        // assignment.
        #[test]
        fn test_assign_op_evals_once() {
            struct Foo {
                a: i32,
                called: bool,
            }

            impl Foo {
                fn get_a_mut(&mut self) -> &mut i32 {
                    if self.called {
                        panic!("called twice");
                    }
                    let ret = &mut self.a;
                    self.called = true;
                    ret
                }
            }

            let mut foo = Foo {
                a: 1,
                called: false,
            };

            *foo.get_a_mut() += 2;
            assert_eq!(foo.a, 3);
        }

        #[test]
        fn test_more_macro_syntax() {
            struct Foo {
                a: i32,
                b: i32,
            }

            impl Foo {
                const BAR: i32 = 1;

                fn new(a: i32, b: i32) -> Foo {
                    Foo { a, b }
                }
            }

            fn new_foo(a: i32) -> Foo {
                Foo { a, b: 0 }
            }

            // verify that we translate the contents of macros correctly
            assert_eq!(Foo::BAR + 1, 2);
            assert_eq!(Foo::new(1, 2).b, 2);
            assert_eq!(new_foo(1).a, 1);

            let v = [Foo::new(1, 2), Foo::new(3, 2)];

            assert_eq!(v[0].a, 1);
            assert_eq!(v[1].b, 2);
        }
    }
}
729        }
730    }
731}