iota_macros/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{collections::HashMap, future::Future, sync::Arc};
6
7use futures::future::BoxFuture;
8pub use iota_proc_macros::*;
9
/// Evaluates `$e` on a freshly spawned scoped thread and returns its value.
///
/// The spawned thread is not subject to the simulator's interception of
/// getrandom(), clock_gettime(), etc. If the expression panics, the panic
/// surfaces here through the unwrapped join.
#[cfg(msim)]
#[macro_export]
macro_rules! nondeterministic {
    ($e:expr) => {
        std::thread::scope(move |scope| {
            let handle = scope.spawn(move || $e);
            handle.join().unwrap()
        })
    };
}
19
/// Non-`msim` build: evaluates the expression in place and yields its value,
/// without spawning any extra thread.
#[cfg(not(msim))]
#[macro_export]
macro_rules! nondeterministic {
    ($e:expr) => {
        $e
    };
}
28
/// Type-erased fail point callback. The registration functions wrap the user's
/// callback so that its result is returned boxed as `Any`; the `handle_*`
/// functions downcast it back to the type they expect.
type FpCallback = dyn Fn() -> Box<dyn std::any::Any + Send + 'static> + Sync + Send;

/// Registry of fail point callbacks, keyed by fail point identifier.
type FpMap = HashMap<&'static str, Arc<FpCallback>>;
31
/// `msim` build: runs `func` with mutable access to the fail point registry.
///
/// The registry lives in a `thread_local!`, so each thread has its own set of
/// fail points. `func` must not re-enter `with_fp_map`, because the
/// `RefCell` is mutably borrowed for the duration of the call.
#[cfg(msim)]
fn with_fp_map<T>(func: impl FnOnce(&mut FpMap) -> T) -> T {
    thread_local! {
        static MAP: std::cell::RefCell<FpMap> = Default::default();
    }

    MAP.with(|cell| {
        let mut registry = cell.borrow_mut();
        func(&mut registry)
    })
}
40
41#[cfg(not(msim))]
42fn with_fp_map<T>(func: impl FnOnce(&mut FpMap) -> T) -> T {
43    use std::sync::Mutex;
44
45    use once_cell::sync::Lazy;
46
47    static MAP: Lazy<Mutex<FpMap>> = Lazy::new(Default::default);
48    let mut map = MAP.lock().unwrap();
49    func(&mut map)
50}
51
52fn get_callback(identifier: &'static str) -> Option<Arc<FpCallback>> {
53    with_fp_map(|map| map.get(identifier).cloned())
54}
55
/// Validates the boxed result of a synchronous fail point callback.
///
/// # Panics
///
/// Panics if the callback did not return `()`.
fn get_sync_fp_result(result: Box<dyn std::any::Any + Send + 'static>) {
    assert!(
        result.downcast::<()>().is_ok(),
        "sync failpoint must return ()"
    );
}
61
62fn get_async_fp_result(result: Box<dyn std::any::Any + Send + 'static>) -> BoxFuture<'static, ()> {
63    match result.downcast::<BoxFuture<'static, ()>>() {
64        Ok(fut) => *fut,
65        Err(err) => panic!("async failpoint must return BoxFuture<'static, ()> {err:?}"),
66    }
67}
68
/// Extracts the enable flag from the boxed result of a `fail_point_if`
/// callback.
///
/// # Panics
///
/// Panics if the callback did not return a `bool`.
fn get_fp_if_result(result: Box<dyn std::any::Any + Send + 'static>) -> bool {
    let enabled = result
        .downcast::<bool>()
        .unwrap_or_else(|_| panic!("failpoint-if must return bool"));
    *enabled
}
75
/// Extracts the `Option<T>` produced by a `fail_point_arg` callback.
///
/// # Panics
///
/// Panics if the callback returned anything other than `Option<T>`. The
/// message names the expected type. The usual cause is a mismatch between the
/// type the test registered (e.g. `Some(42)` inferred as `i32`) and the type
/// the call site expects, and a generic "Option<T>" would not reveal which `T`
/// was wanted.
fn get_fp_some_result<T: Send + 'static>(
    result: Box<dyn std::any::Any + Send + 'static>,
) -> Option<T> {
    match result.downcast::<Option<T>>() {
        Ok(opt) => *opt,
        Err(_) => panic!(
            "failpoint-arg must return {}",
            std::any::type_name::<Option<T>>()
        ),
    }
}
84
85pub fn handle_fail_point(identifier: &'static str) {
86    if let Some(callback) = get_callback(identifier) {
87        get_sync_fp_result(callback());
88        tracing::trace!("hit failpoint {}", identifier);
89    }
90}
91
92pub async fn handle_fail_point_async(identifier: &'static str) {
93    if let Some(callback) = get_callback(identifier) {
94        tracing::trace!("hit async failpoint {}", identifier);
95        let fut = get_async_fp_result(callback());
96        fut.await;
97    }
98}
99
100pub fn handle_fail_point_if(identifier: &'static str) -> bool {
101    if let Some(callback) = get_callback(identifier) {
102        tracing::trace!("hit failpoint_if {}", identifier);
103        get_fp_if_result(callback())
104    } else {
105        false
106    }
107}
108
109pub fn handle_fail_point_arg<T: Send + 'static>(identifier: &'static str) -> Option<T> {
110    if let Some(callback) = get_callback(identifier) {
111        tracing::trace!("hit failpoint_arg {}", identifier);
112        get_fp_some_result(callback())
113    } else {
114        None
115    }
116}
117
118fn register_fail_point_impl(identifier: &'static str, callback: Arc<FpCallback>) {
119    with_fp_map(move |map| {
120        assert!(
121            map.insert(identifier, callback).is_none(),
122            "duplicate fail point registration"
123        );
124    })
125}
126
127fn clear_fail_point_impl(identifier: &'static str) {
128    with_fp_map(move |map| {
129        assert!(
130            map.remove(identifier).is_some(),
131            "fail point {identifier:?} does not exist"
132        );
133    })
134}
135
136pub fn register_fail_point(identifier: &'static str, callback: impl Fn() + Sync + Send + 'static) {
137    register_fail_point_impl(
138        identifier,
139        Arc::new(move || {
140            callback();
141            Box::new(())
142        }),
143    );
144}
145
146/// Register an asynchronous fail point. Because it is async it can yield
147/// execution of the calling task, e.g. by sleeping.
148pub fn register_fail_point_async<F>(
149    identifier: &'static str,
150    callback: impl Fn() -> F + Sync + Send + 'static,
151) where
152    F: Future<Output = ()> + Send + 'static,
153{
154    register_fail_point_impl(
155        identifier,
156        Arc::new(move || {
157            let result: BoxFuture<'static, ()> = Box::pin(callback());
158            Box::new(result)
159        }),
160    );
161}
162
163/// Register code to run locally if the fail point is hit. Example:
164///
165/// In the test:
166///
167/// ```ignore
168///     register_fail_point_if("foo", || {
169///         iota_simulator::current_simnode_id() == 2
170///     });
171/// ```
172///
173/// In the code:
174///
175/// ```ignore
176///     let mut was_hit = false;
177///     fail_point_if("foo", || {
178///        was_hit = true;
179///     });
180/// ```
181pub fn register_fail_point_if(
182    identifier: &'static str,
183    callback: impl Fn() -> bool + Sync + Send + 'static,
184) {
185    register_fail_point_impl(identifier, Arc::new(move || Box::new(callback())));
186}
187
188/// Register code to run locally if the fail point is hit, with a value provided
189/// by the test. If the registered callback returns a Some(v), then the `v` is
190/// passed to the callback in the test.
191///
192/// In the test:
193///
194/// ```ignore
195///     register_fail_point_arg("foo", || {
196///         Some(42)
197///     });
198/// ```
199///
200/// In the code:
201///
202/// ```ignore
203///     let mut value = 0;
204///     fail_point_arg!("foo", |arg| {
205///        value = arg;
206///     });
207/// ```
208pub fn register_fail_point_arg<T: Send + 'static>(
209    identifier: &'static str,
210    callback: impl Fn() -> Option<T> + Sync + Send + 'static,
211) {
212    register_fail_point_impl(identifier, Arc::new(move || Box::new(callback())));
213}
214
215pub fn register_fail_points(
216    identifiers: &[&'static str],
217    callback: impl Fn() + Sync + Send + 'static,
218) {
219    let cb: Arc<FpCallback> = Arc::new(move || {
220        callback();
221        Box::new(())
222    });
223    for id in identifiers {
224        register_fail_point_impl(id, cb.clone());
225    }
226}
227
228pub fn clear_fail_point(identifier: &'static str) {
229    clear_fail_point_impl(identifier);
230}
231
/// Trigger a fail point: runs the synchronous callback a test registered for
/// this identifier, if any. Only active in `msim` or `fail_points` builds.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point {
    ($id:expr) => {
        $crate::handle_fail_point($id)
    };
}
241
/// Trigger an async fail point: awaits the future produced by the callback a
/// test registered for this identifier, if any. Must be used inside an async
/// context. Only active in `msim` or `fail_points` builds.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point_async {
    ($id:expr) => {
        $crate::handle_fail_point_async($id).await
    };
}
251
/// Trigger a failpoint that runs `$on_hit` (a zero-argument callable) at the
/// call site when enabled. Whether it is enabled is decided by the `bool`
/// returned from the callback registered with `register_fail_point_if`.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point_if {
    ($id:expr, $on_hit:expr) => {
        if $crate::handle_fail_point_if($id) {
            ($on_hit)();
        }
    };
}
264
/// Trigger a failpoint that runs a callback at the callsite if it is enabled.
/// If the registration callback returns `Some(v)`, then `v` is passed to the
/// callback at this callsite. Otherwise the failpoint is skipped.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point_arg {
    ($tag: expr, $callback: expr) => {
        if let Some(arg) = $crate::handle_fail_point_arg($tag) {
            ($callback)(arg);
        }
    };
}
277
/// Disabled build: `fail_point!` expands to nothing and its argument is not
/// evaluated.
#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point {
    ($id:expr) => {};
}
283
/// Disabled build: `fail_point_async!` expands to nothing (no `.await` is
/// emitted) and its argument is not evaluated.
#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point_async {
    ($id:expr) => {};
}
289
/// Disabled build: `fail_point_if!` expands to nothing; neither the tag nor
/// the callback is evaluated.
#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point_if {
    ($id:expr, $on_hit:expr) => {};
}
295
/// Disabled build: `fail_point_arg!` expands to nothing; neither the tag nor
/// the callback is evaluated.
#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point_arg {
    ($id:expr, $on_hit:expr) => {};
}
301
/// Use to write INFO level logs only when the `REPLAY_LOG` environment
/// variable is set. Useful for log lines that are only relevant to test infra
/// which still may need to run a release build. Also note that since logs of a
/// chain replay are exceedingly verbose, this allows one to bubble up "debug
/// level" info while running with `RUST_LOG=info`.
///
/// Only the presence of the variable matters; its value is ignored and need
/// not be valid UTF-8.
#[macro_export]
macro_rules! replay_log {
    ($($arg:tt)+) => {
        // `var_os` checks presence without allocating a `String`. It also
        // accepts non-UTF-8 values, which `std::env::var` reports as an error
        // and which would therefore have silently disabled the log line.
        if std::env::var_os("REPLAY_LOG").is_some() {
            tracing::info!($($arg)+);
        }
    };
}
316
// These tests are only meaningful in release mode: debug builds already enable
// overflow checks by default, so the `#[should_panic]` overflow tests below
// would pass even if the checked-arithmetic macros did nothing.
#[cfg(test)]
mod test {
    use super::*;

    // Uncomment to see the compile-time error message produced when
    // `#[with_checked_arithmetic]` is applied to this item.
    // #[with_checked_arithmetic]
    // struct TestStruct;

    // Expands to its input tokens unchanged; used to wrap an item in a macro
    // invocation for `test_skip_checked_arithmetic`.
    macro_rules! pass_through {
        ($($tt:tt)*) => {
            $($tt)*
        }
    }

    // Checks that `#[skip_checked_arithmetic]` lets a macro invocation inside a
    // `#[with_checked_arithmetic]` function be left unprocessed.
    #[with_checked_arithmetic]
    #[test]
    fn test_skip_checked_arithmetic() {
        // Comment out this attribute to see the error message reported for the
        // macro invocation when it is not skipped.
        #[skip_checked_arithmetic]
        pass_through! {
            fn unchecked_add(a: i32, b: i32) -> i32 {
                a + b
            }
        }

        // Because the item macro was skipped, this would not panic even with
        // (i32::MAX, 1) in a release build. Small operands are used so that it
        // also does not panic in debug builds, where overflow checks are on.
        unchecked_add(1, 2);
    }

    // Items wrapped in `checked_arithmetic!` are expected to panic on integer
    // overflow even in release builds; the `#[should_panic]` tests below
    // depend on it.
    checked_arithmetic! {

    struct Test {
        a: i32,
        b: i32,
    }

    fn unchecked_add(a: i32, b: i32) -> i32 {
        a + b
    }

    #[test]
    fn test_checked_arithmetic_macro() {
        unchecked_add(1, 2);
    }

    #[test]
    #[should_panic]
    fn test_checked_arithmetic_macro_panic() {
        unchecked_add(i32::MAX, 1);
    }

    // The overflowing `+` is inside a closure; it must still panic.
    fn unchecked_add_hidden(a: i32, b: i32) -> i32 {
        let inner = |a: i32, b: i32| a + b;
        inner(a, b)
    }

    #[test]
    #[should_panic]
    fn test_checked_arithmetic_macro_panic_hidden() {
        unchecked_add_hidden(i32::MAX, 1);
    }

    // The overflowing `+` is inside a nested fn item; it must still panic.
    fn unchecked_add_hidden_2(a: i32, b: i32) -> i32 {
        fn inner(a: i32, b: i32) -> i32 {
            a + b
        }
        inner(a, b)
    }

    #[test]
    #[should_panic]
    fn test_checked_arithmetic_macro_panic_hidden_2() {
        unchecked_add_hidden_2(i32::MAX, 1);
    }

    // Arithmetic inside `impl` methods must be checked too.
    impl Test {
        fn add(&self) -> i32 {
            self.a + self.b
        }
    }

    #[test]
    #[should_panic]
    fn test_checked_arithmetic_impl() {
        let t = Test { a: 1, b: i32::MAX };
        t.add();
    }

    // The overflowing expression is an argument of `println!`; arithmetic
    // inside macro arguments must be checked as well. (Note: rustc's
    // `arithmetic_overflow` lint normally rejects `i32::MAX + 1` at compile
    // time.)
    #[test]
    #[should_panic]
    fn test_macro_overflow() {
        fn f() {
            println!("{}", i32::MAX + 1);
        }

        f()
    }

    // Make sure non-overflowing arithmetic (binary and compound-assignment
    // forms of `+ - * / %`) still computes the right values!
    #[test]
    fn test_non_overflow() {
        fn f() {
            assert_eq!(1i32 + 2i32, 3i32);
            assert_eq!(3i32 - 1i32, 2i32);
            assert_eq!(4i32 * 3i32, 12i32);
            assert_eq!(12i32 / 3i32, 4i32);
            assert_eq!(12i32 % 5i32, 2i32);

            let mut a = 1i32;
            a += 2i32;
            assert_eq!(a, 3i32);

            let mut a = 3i32;
            a -= 1i32;
            assert_eq!(a, 2i32);

            let mut a = 4i32;
            a *= 3i32;
            assert_eq!(a, 12i32);

            let mut a = 12i32;
            a /= 3i32;
            assert_eq!(a, 4i32);

            let mut a = 12i32;
            a %= 5i32;
            assert_eq!(a, 2i32);
        }

        f();
    }


    // `f` panics if called twice, so the right-hand operand must be evaluated
    // exactly once.
    #[test]
    fn test_exprs_evaluated_once_right() {
        let mut called = false;
        let mut f = || {
            if called {
                panic!("called twice");
            }
            called = true;
            1i32
        };

        assert_eq!(2i32 + f(), 3);
    }

    // `f` panics if called twice, so the left-hand operand must be evaluated
    // exactly once.
    #[test]
    fn test_exprs_evaluated_once_left() {
        let mut called = false;
        let mut f = || {
            if called {
                panic!("called twice");
            }
            called = true;
            1i32
        };

        assert_eq!(f() + 2i32, 3);
    }

    // `get_a_mut` panics if called twice, so a compound assignment must
    // evaluate its place expression exactly once.
    #[test]
    fn test_assign_op_evals_once() {
        struct Foo {
            a: i32,
            called: bool,
        }

        impl Foo {
            fn get_a_mut(&mut self) -> &mut i32 {
                if self.called {
                    panic!("called twice");
                }
                let ret = &mut self.a;
                self.called = true;
                ret
            }
        }

        let mut foo = Foo { a: 1, called: false };

        *foo.get_a_mut() += 2;
        assert_eq!(foo.a, 3);
    }

    // Exercises associated consts, associated fns, struct literals, and
    // indexing inside macro arguments.
    #[test]
    fn test_more_macro_syntax() {
        struct Foo {
            a: i32,
            b: i32,
        }

        impl Foo {
            const BAR: i32 = 1;

            fn new(a: i32, b: i32) -> Foo {
                Foo { a, b }
            }
        }

        fn new_foo(a: i32) -> Foo {
            Foo { a, b: 0 }
        }

        // verify that we translate the contents of macros correctly
        assert_eq!(Foo::BAR + 1, 2);
        assert_eq!(Foo::new(1, 2).b, 2);
        assert_eq!(new_foo(1).a, 1);

        let v = [Foo::new(1, 2), Foo::new(3, 2)];

        assert_eq!(v[0].a, 1);
        assert_eq!(v[1].b, 2);
    }

    }

    // The same cases as in the `checked_arithmetic!` block above, but with
    // `#[with_checked_arithmetic]` applied to a whole module instead.
    #[with_checked_arithmetic]
    mod with_checked_arithmetic_tests {

        struct Test {
            a: i32,
            b: i32,
        }

        fn unchecked_add(a: i32, b: i32) -> i32 {
            a + b
        }

        #[test]
        fn test_checked_arithmetic_macro() {
            unchecked_add(1, 2);
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic() {
            unchecked_add(i32::MAX, 1);
        }

        // The overflowing `+` is inside a closure; it must still panic.
        fn unchecked_add_hidden(a: i32, b: i32) -> i32 {
            let inner = |a: i32, b: i32| a + b;
            inner(a, b)
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic_hidden() {
            unchecked_add_hidden(i32::MAX, 1);
        }

        // The overflowing `+` is inside a nested fn item; it must still panic.
        fn unchecked_add_hidden_2(a: i32, b: i32) -> i32 {
            fn inner(a: i32, b: i32) -> i32 {
                a + b
            }
            inner(a, b)
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic_hidden_2() {
            unchecked_add_hidden_2(i32::MAX, 1);
        }

        // Arithmetic inside `impl` methods must be checked too.
        impl Test {
            fn add(&self) -> i32 {
                self.a + self.b
            }
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_impl() {
            let t = Test { a: 1, b: i32::MAX };
            t.add();
        }

        // Arithmetic inside macro arguments (`println!`) must be checked too.
        #[test]
        #[should_panic]
        fn test_macro_overflow() {
            fn f() {
                println!("{}", i32::MAX + 1);
            }

            f()
        }

        // Make sure non-overflowing arithmetic (binary and compound-assignment
        // forms of `+ - * / %`) still computes the right values!
        #[test]
        fn test_non_overflow() {
            fn f() {
                assert_eq!(1i32 + 2i32, 3i32);
                assert_eq!(3i32 - 1i32, 2i32);
                assert_eq!(4i32 * 3i32, 12i32);
                assert_eq!(12i32 / 3i32, 4i32);
                assert_eq!(12i32 % 5i32, 2i32);

                let mut a = 1i32;
                a += 2i32;
                assert_eq!(a, 3i32);

                let mut a = 3i32;
                a -= 1i32;
                assert_eq!(a, 2i32);

                let mut a = 4i32;
                a *= 3i32;
                assert_eq!(a, 12i32);

                let mut a = 12i32;
                a /= 3i32;
                assert_eq!(a, 4i32);

                let mut a = 12i32;
                a %= 5i32;
                assert_eq!(a, 2i32);
            }

            f();
        }

        // `f` panics if called twice, so the right-hand operand must be
        // evaluated exactly once.
        #[test]
        fn test_exprs_evaluated_once_right() {
            let mut called = false;
            let mut f = || {
                if called {
                    panic!("called twice");
                }
                called = true;
                1i32
            };

            assert_eq!(2i32 + f(), 3);
        }

        // `f` panics if called twice, so the left-hand operand must be
        // evaluated exactly once.
        #[test]
        fn test_exprs_evaluated_once_left() {
            let mut called = false;
            let mut f = || {
                if called {
                    panic!("called twice");
                }
                called = true;
                1i32
            };

            assert_eq!(f() + 2i32, 3);
        }

        // `get_a_mut` panics if called twice, so a compound assignment must
        // evaluate its place expression exactly once.
        #[test]
        fn test_assign_op_evals_once() {
            struct Foo {
                a: i32,
                called: bool,
            }

            impl Foo {
                fn get_a_mut(&mut self) -> &mut i32 {
                    if self.called {
                        panic!("called twice");
                    }
                    let ret = &mut self.a;
                    self.called = true;
                    ret
                }
            }

            let mut foo = Foo {
                a: 1,
                called: false,
            };

            *foo.get_a_mut() += 2;
            assert_eq!(foo.a, 3);
        }

        // Exercises associated consts, associated fns, struct literals, and
        // indexing inside macro arguments.
        #[test]
        fn test_more_macro_syntax() {
            struct Foo {
                a: i32,
                b: i32,
            }

            impl Foo {
                const BAR: i32 = 1;

                fn new(a: i32, b: i32) -> Foo {
                    Foo { a, b }
                }
            }

            fn new_foo(a: i32) -> Foo {
                Foo { a, b: 0 }
            }

            // verify that we translate the contents of macros correctly
            assert_eq!(Foo::BAR + 1, 2);
            assert_eq!(Foo::new(1, 2).b, 2);
            assert_eq!(new_foo(1).a, 1);

            let v = [Foo::new(1, 2), Foo::new(3, 2)];

            assert_eq!(v[0].a, 1);
            assert_eq!(v[1].b, 2);
        }
    }
}