1use std::{collections::HashMap, future::Future, sync::Arc};
6
7use futures::future::BoxFuture;
8pub use iota_proc_macros::*;
9
/// Evaluates `$body` on a freshly spawned scoped thread and yields its value.
///
/// This variant is compiled only when the `msim` cfg is set. The expression is
/// moved into the spawned thread, the calling thread blocks until it finishes,
/// and the result is handed back to the caller.
///
/// # Panics
///
/// Panics (through `unwrap`) if the spawned thread panicked.
// NOTE(review): presumably the extra thread exists to escape per-thread
// simulator state -- confirm against the msim documentation.
#[cfg(msim)]
#[macro_export]
macro_rules! nondeterministic {
    ($body: expr) => {
        std::thread::scope(move |scope| {
            let handle = scope.spawn(move || $body);
            handle.join().unwrap()
        })
    };
}
19
/// Evaluates `$body` in place and yields its value.
///
/// This variant is compiled when the `msim` cfg is NOT set; the macro is then
/// a transparent wrapper that expands to the expression itself.
#[cfg(not(msim))]
#[macro_export]
macro_rules! nondeterministic {
    ($body: expr) => {
        $body
    };
}
28
/// Type-erased fail point callback.
///
/// Every registered callback is stored behind this one signature; the concrete
/// return value (`()`, `bool`, `Option<T>` or a boxed future) travels inside a
/// `Box<dyn Any>` and is recovered with a downcast by the matching
/// `get_*_result` helper.
type FpCallback = dyn Fn() -> Box<dyn std::any::Any + Send + 'static> + Sync + Send;

/// Registry mapping a fail point identifier to its shared callback.
type FpMap = HashMap<&'static str, Arc<FpCallback>>;
31
/// Runs `func` with mutable access to the fail point registry and returns
/// whatever `func` returns.
///
/// This variant is compiled only when the `msim` cfg is set. The registry is a
/// `thread_local!`, so each thread sees its own, initially empty, map.
///
/// # Panics
///
/// Panics if called re-entrantly from inside `func`, because the `RefCell` is
/// already mutably borrowed at that point.
#[cfg(msim)]
fn with_fp_map<T>(func: impl FnOnce(&mut FpMap) -> T) -> T {
    use std::cell::RefCell;

    thread_local! {
        static MAP: RefCell<FpMap> = Default::default();
    }

    MAP.with(|cell| {
        let mut registry = cell.borrow_mut();
        func(&mut registry)
    })
}
40
41#[cfg(not(msim))]
42fn with_fp_map<T>(func: impl FnOnce(&mut FpMap) -> T) -> T {
43 use std::sync::Mutex;
44
45 use once_cell::sync::Lazy;
46
47 static MAP: Lazy<Mutex<FpMap>> = Lazy::new(Default::default);
48 let mut map = MAP.lock().unwrap();
49 func(&mut map)
50}
51
52fn get_callback(identifier: &'static str) -> Option<Arc<FpCallback>> {
53 with_fp_map(|map| map.get(identifier).cloned())
54}
55
/// Checks that a synchronous fail point callback produced `()`.
///
/// # Panics
///
/// Panics with "sync failpoint must return ()" when the boxed value has any
/// other type.
fn get_sync_fp_result(result: Box<dyn std::any::Any + Send + 'static>) {
    match result.downcast::<()>() {
        Ok(_) => {}
        Err(_) => panic!("sync failpoint must return ()"),
    }
}
61
62fn get_async_fp_result(result: Box<dyn std::any::Any + Send + 'static>) -> BoxFuture<'static, ()> {
63 match result.downcast::<BoxFuture<'static, ()>>() {
64 Ok(fut) => *fut,
65 Err(err) => panic!(
66 "async failpoint must return BoxFuture<'static, ()> {:?}",
67 err
68 ),
69 }
70}
71
/// Recovers the `bool` produced by a `fail_point_if` callback.
///
/// # Panics
///
/// Panics with "failpoint-if must return bool" when the boxed value is not a
/// `bool`.
fn get_fp_if_result(result: Box<dyn std::any::Any + Send + 'static>) -> bool {
    let flag = result
        .downcast::<bool>()
        .unwrap_or_else(|_| panic!("failpoint-if must return bool"));
    *flag
}
78
/// Recovers the `Option<T>` produced by a `fail_point_arg` callback.
///
/// The downcast is exact: the callback must have been registered with the same
/// `T` the call site asks for.
///
/// # Panics
///
/// Panics with "failpoint-arg must return Option<T>" when the boxed value is
/// not an `Option<T>`.
fn get_fp_some_result<T: Send + 'static>(
    result: Box<dyn std::any::Any + Send + 'static>,
) -> Option<T> {
    let payload = result
        .downcast::<Option<T>>()
        .unwrap_or_else(|_| panic!("failpoint-arg must return Option<T>"));
    *payload
}
87
88pub fn handle_fail_point(identifier: &'static str) {
89 if let Some(callback) = get_callback(identifier) {
90 get_sync_fp_result(callback());
91 tracing::trace!("hit failpoint {}", identifier);
92 }
93}
94
95pub async fn handle_fail_point_async(identifier: &'static str) {
96 if let Some(callback) = get_callback(identifier) {
97 tracing::trace!("hit async failpoint {}", identifier);
98 let fut = get_async_fp_result(callback());
99 fut.await;
100 }
101}
102
103pub fn handle_fail_point_if(identifier: &'static str) -> bool {
104 if let Some(callback) = get_callback(identifier) {
105 tracing::trace!("hit failpoint_if {}", identifier);
106 get_fp_if_result(callback())
107 } else {
108 false
109 }
110}
111
112pub fn handle_fail_point_arg<T: Send + 'static>(identifier: &'static str) -> Option<T> {
113 if let Some(callback) = get_callback(identifier) {
114 tracing::trace!("hit failpoint_arg {}", identifier);
115 get_fp_some_result(callback())
116 } else {
117 None
118 }
119}
120
121fn register_fail_point_impl(identifier: &'static str, callback: Arc<FpCallback>) {
122 with_fp_map(move |map| {
123 assert!(
124 map.insert(identifier, callback).is_none(),
125 "duplicate fail point registration"
126 );
127 })
128}
129
130fn clear_fail_point_impl(identifier: &'static str) {
131 with_fp_map(move |map| {
132 assert!(
133 map.remove(identifier).is_some(),
134 "fail point {:?} does not exist",
135 identifier
136 );
137 })
138}
139
140pub fn register_fail_point(identifier: &'static str, callback: impl Fn() + Sync + Send + 'static) {
141 register_fail_point_impl(
142 identifier,
143 Arc::new(move || {
144 callback();
145 Box::new(())
146 }),
147 );
148}
149
150pub fn register_fail_point_async<F>(
153 identifier: &'static str,
154 callback: impl Fn() -> F + Sync + Send + 'static,
155) where
156 F: Future<Output = ()> + Send + 'static,
157{
158 register_fail_point_impl(
159 identifier,
160 Arc::new(move || {
161 let result: BoxFuture<'static, ()> = Box::pin(callback());
162 Box::new(result)
163 }),
164 );
165}
166
167pub fn register_fail_point_if(
186 identifier: &'static str,
187 callback: impl Fn() -> bool + Sync + Send + 'static,
188) {
189 register_fail_point_impl(identifier, Arc::new(move || Box::new(callback())));
190}
191
192pub fn register_fail_point_arg<T: Send + 'static>(
213 identifier: &'static str,
214 callback: impl Fn() -> Option<T> + Sync + Send + 'static,
215) {
216 register_fail_point_impl(identifier, Arc::new(move || Box::new(callback())));
217}
218
219pub fn register_fail_points(
220 identifiers: &[&'static str],
221 callback: impl Fn() + Sync + Send + 'static,
222) {
223 let cb: Arc<FpCallback> = Arc::new(move || {
224 callback();
225 Box::new(())
226 });
227 for id in identifiers {
228 register_fail_point_impl(id, cb.clone());
229 }
230}
231
232pub fn clear_fail_point(identifier: &'static str) {
233 clear_fail_point_impl(identifier);
234}
235
/// Triggers the synchronous fail point named `$id`.
///
/// Active variant, compiled when either the `msim` or the `fail_points` cfg is
/// set. Expands to a call to `handle_fail_point`, which runs the registered
/// callback if there is one.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point {
    ($id: expr) => {
        $crate::handle_fail_point($id)
    };
}
245
/// Triggers the async fail point named `$id` and awaits it.
///
/// Active variant, compiled when either the `msim` or the `fail_points` cfg is
/// set. The expansion contains `.await`, so the macro can only be used inside
/// an async context.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point_async {
    ($id: expr) => {
        $crate::handle_fail_point_async($id).await
    };
}
255
/// Calls `$action` (with no arguments) when the predicate fail point named
/// `$id` is registered and returns `true`.
///
/// Active variant, compiled when either the `msim` or the `fail_points` cfg is
/// set.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point_if {
    ($id: expr, $action: expr) => {
        if $crate::handle_fail_point_if($id) {
            ($action)();
        }
    };
}
268
/// Calls `$action` with the value produced by the fail point named `$id`, when
/// that fail point is registered and returns `Some`.
///
/// Active variant, compiled when either the `msim` or the `fail_points` cfg is
/// set. The value type is inferred from the parameter type of `$action`.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point_arg {
    ($id: expr, $action: expr) => {
        if let Some(value) = $crate::handle_fail_point_arg($id) {
            ($action)(value);
        }
    };
}
281
/// Disabled variant of `fail_point!`.
///
/// Compiled when neither the `msim` nor the `fail_points` cfg is set; expands
/// to nothing, so the tag expression is never evaluated.
#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point {
    ($id: expr) => {};
}
287
/// Disabled variant of `fail_point_async!`.
///
/// Compiled when neither the `msim` nor the `fail_points` cfg is set; expands
/// to nothing, so no `.await` is emitted and the tag expression is never
/// evaluated.
#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point_async {
    ($id: expr) => {};
}
293
/// Disabled variant of `fail_point_if!`.
///
/// Compiled when neither the `msim` nor the `fail_points` cfg is set; expands
/// to nothing, so neither the tag nor the callback expression is evaluated.
#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point_if {
    ($id: expr, $action: expr) => {};
}
299
/// Disabled variant of `fail_point_arg!`.
///
/// Compiled when neither the `msim` nor the `fail_points` cfg is set; expands
/// to nothing, so neither the tag nor the callback expression is evaluated.
#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point_arg {
    ($id: expr, $action: expr) => {};
}
305
/// Emits a `tracing::info!` event with the given format arguments, but only
/// when the `REPLAY_LOG` environment variable is set.
///
/// The variable is read on every expansion site execution. It counts as set
/// only if `std::env::var` succeeds, i.e. it is present and valid Unicode; its
/// value is otherwise ignored.
#[macro_export]
macro_rules! replay_log {
    ($($fmt_args:tt)+) => {
        if std::env::var("REPLAY_LOG").is_ok() {
            tracing::info!($($fmt_args)+);
        }
    };
}
320
// Tests for the arithmetic-checking proc macros re-exported from
// `iota_proc_macros`. The same suite is run twice: once wrapped in the
// function-like `checked_arithmetic!` macro and once inside a module carrying
// the `#[with_checked_arithmetic]` attribute. The `#[should_panic]` cases
// expect an overflowing `+` to panic in each of the syntactic positions
// exercised (free fn, closure, nested fn, method, macro argument).
#[cfg(test)]
mod test {
    use super::*;

    // Expands to exactly the tokens it receives. It gives the test below an
    // item-position macro invocation to attach `#[skip_checked_arithmetic]` to.
    macro_rules! pass_through {
        ($($tokens:tt)*) => {
            $($tokens)*
        }
    }

    #[with_checked_arithmetic]
    #[test]
    fn test_skip_checked_arithmetic() {
        // NOTE(review): presumably `with_checked_arithmetic` cannot process
        // the macro invocation below unless it is explicitly skipped --
        // confirm against `iota_proc_macros`.
        #[skip_checked_arithmetic]
        pass_through! {
            fn unchecked_add(a: i32, b: i32) -> i32 {
                a + b
            }
        }

        // Only non-overflowing operands are used here, so the call succeeds
        // whether or not the addition ends up checked.
        unchecked_add(1, 2);
    }

    checked_arithmetic! {

        struct Test {
            a: i32,
            b: i32,
        }

        fn unchecked_add(a: i32, b: i32) -> i32 {
            a + b
        }

        #[test]
        fn test_checked_arithmetic_macro() {
            unchecked_add(1, 2);
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic() {
            unchecked_add(i32::MAX, 1);
        }

        // Addition hidden inside a closure.
        fn unchecked_add_hidden(a: i32, b: i32) -> i32 {
            let inner = |a: i32, b: i32| a + b;
            inner(a, b)
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic_hidden() {
            unchecked_add_hidden(i32::MAX, 1);
        }

        // Addition hidden inside a nested fn item.
        fn unchecked_add_hidden_2(a: i32, b: i32) -> i32 {
            fn inner(a: i32, b: i32) -> i32 {
                a + b
            }
            inner(a, b)
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic_hidden_2() {
            unchecked_add_hidden_2(i32::MAX, 1);
        }

        // Addition inside a method of an impl block.
        impl Test {
            fn add(&self) -> i32 {
                self.a + self.b
            }
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_impl() {
            let pair = Test { a: 1, b: i32::MAX };
            pair.add();
        }

        // Addition inside the argument list of another macro (`println!`).
        #[test]
        #[should_panic]
        fn test_macro_overflow() {
            fn f() {
                println!("{}", i32::MAX + 1);
            }

            f()
        }

        // Every binary and compound-assignment arithmetic operator must still
        // compute the ordinary result when nothing overflows.
        #[test]
        fn test_non_overflow() {
            fn f() {
                assert_eq!(1i32 + 2i32, 3i32);
                assert_eq!(3i32 - 1i32, 2i32);
                assert_eq!(4i32 * 3i32, 12i32);
                assert_eq!(12i32 / 3i32, 4i32);
                assert_eq!(12i32 % 5i32, 2i32);

                let mut a = 1i32;
                a += 2i32;
                assert_eq!(a, 3i32);

                let mut a = 3i32;
                a -= 1i32;
                assert_eq!(a, 2i32);

                let mut a = 4i32;
                a *= 3i32;
                assert_eq!(a, 12i32);

                let mut a = 12i32;
                a /= 3i32;
                assert_eq!(a, 4i32);

                let mut a = 12i32;
                a %= 5i32;
                assert_eq!(a, 2i32);
            }

            f();
        }

        // The right-hand operand expression must be evaluated exactly once.
        #[test]
        fn test_exprs_evaluated_once_right() {
            let mut invoked = false;
            let mut f = || {
                if invoked {
                    panic!("called twice");
                }
                invoked = true;
                1i32
            };

            assert_eq!(2i32 + f(), 3);
        }

        // The left-hand operand expression must be evaluated exactly once.
        #[test]
        fn test_exprs_evaluated_once_left() {
            let mut invoked = false;
            let mut f = || {
                if invoked {
                    panic!("called twice");
                }
                invoked = true;
                1i32
            };

            assert_eq!(f() + 2i32, 3);
        }

        // The place expression of a compound assignment must be evaluated
        // exactly once.
        #[test]
        fn test_assign_op_evals_once() {
            struct Foo {
                a: i32,
                called: bool,
            }

            impl Foo {
                fn get_a_mut(&mut self) -> &mut i32 {
                    if self.called {
                        panic!("called twice");
                    }
                    let field = &mut self.a;
                    self.called = true;
                    field
                }
            }

            let mut foo = Foo {
                a: 1,
                called: false,
            };

            *foo.get_a_mut() += 2;
            assert_eq!(foo.a, 3);
        }

        // Paths, associated consts, struct literals, calls, field access and
        // indexing must keep working inside the macro.
        #[test]
        fn test_more_macro_syntax() {
            struct Foo {
                a: i32,
                b: i32,
            }

            impl Foo {
                const BAR: i32 = 1;

                fn new(a: i32, b: i32) -> Foo {
                    Foo { a, b }
                }
            }

            fn new_foo(a: i32) -> Foo {
                Foo { a, b: 0 }
            }

            assert_eq!(Foo::BAR + 1, 2);
            assert_eq!(Foo::new(1, 2).b, 2);
            assert_eq!(new_foo(1).a, 1);

            let foos = [Foo::new(1, 2), Foo::new(3, 2)];

            assert_eq!(foos[0].a, 1);
            assert_eq!(foos[1].b, 2);
        }

    }

    // Same suite as above, applied through the attribute form of the macro.
    #[with_checked_arithmetic]
    mod with_checked_arithmetic_tests {

        struct Test {
            a: i32,
            b: i32,
        }

        fn unchecked_add(a: i32, b: i32) -> i32 {
            a + b
        }

        #[test]
        fn test_checked_arithmetic_macro() {
            unchecked_add(1, 2);
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic() {
            unchecked_add(i32::MAX, 1);
        }

        // Addition hidden inside a closure.
        fn unchecked_add_hidden(a: i32, b: i32) -> i32 {
            let inner = |a: i32, b: i32| a + b;
            inner(a, b)
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic_hidden() {
            unchecked_add_hidden(i32::MAX, 1);
        }

        // Addition hidden inside a nested fn item.
        fn unchecked_add_hidden_2(a: i32, b: i32) -> i32 {
            fn inner(a: i32, b: i32) -> i32 {
                a + b
            }
            inner(a, b)
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic_hidden_2() {
            unchecked_add_hidden_2(i32::MAX, 1);
        }

        // Addition inside a method of an impl block.
        impl Test {
            fn add(&self) -> i32 {
                self.a + self.b
            }
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_impl() {
            let pair = Test { a: 1, b: i32::MAX };
            pair.add();
        }

        // Addition inside the argument list of another macro (`println!`).
        #[test]
        #[should_panic]
        fn test_macro_overflow() {
            fn f() {
                println!("{}", i32::MAX + 1);
            }

            f()
        }

        // Every binary and compound-assignment arithmetic operator must still
        // compute the ordinary result when nothing overflows.
        #[test]
        fn test_non_overflow() {
            fn f() {
                assert_eq!(1i32 + 2i32, 3i32);
                assert_eq!(3i32 - 1i32, 2i32);
                assert_eq!(4i32 * 3i32, 12i32);
                assert_eq!(12i32 / 3i32, 4i32);
                assert_eq!(12i32 % 5i32, 2i32);

                let mut a = 1i32;
                a += 2i32;
                assert_eq!(a, 3i32);

                let mut a = 3i32;
                a -= 1i32;
                assert_eq!(a, 2i32);

                let mut a = 4i32;
                a *= 3i32;
                assert_eq!(a, 12i32);

                let mut a = 12i32;
                a /= 3i32;
                assert_eq!(a, 4i32);

                let mut a = 12i32;
                a %= 5i32;
                assert_eq!(a, 2i32);
            }

            f();
        }

        // The right-hand operand expression must be evaluated exactly once.
        #[test]
        fn test_exprs_evaluated_once_right() {
            let mut invoked = false;
            let mut f = || {
                if invoked {
                    panic!("called twice");
                }
                invoked = true;
                1i32
            };

            assert_eq!(2i32 + f(), 3);
        }

        // The left-hand operand expression must be evaluated exactly once.
        #[test]
        fn test_exprs_evaluated_once_left() {
            let mut invoked = false;
            let mut f = || {
                if invoked {
                    panic!("called twice");
                }
                invoked = true;
                1i32
            };

            assert_eq!(f() + 2i32, 3);
        }

        // The place expression of a compound assignment must be evaluated
        // exactly once.
        #[test]
        fn test_assign_op_evals_once() {
            struct Foo {
                a: i32,
                called: bool,
            }

            impl Foo {
                fn get_a_mut(&mut self) -> &mut i32 {
                    if self.called {
                        panic!("called twice");
                    }
                    let field = &mut self.a;
                    self.called = true;
                    field
                }
            }

            let mut foo = Foo {
                a: 1,
                called: false,
            };

            *foo.get_a_mut() += 2;
            assert_eq!(foo.a, 3);
        }

        // Paths, associated consts, struct literals, calls, field access and
        // indexing must keep working under the attribute.
        #[test]
        fn test_more_macro_syntax() {
            struct Foo {
                a: i32,
                b: i32,
            }

            impl Foo {
                const BAR: i32 = 1;

                fn new(a: i32, b: i32) -> Foo {
                    Foo { a, b }
                }
            }

            fn new_foo(a: i32) -> Foo {
                Foo { a, b: 0 }
            }

            assert_eq!(Foo::BAR + 1, 2);
            assert_eq!(Foo::new(1, 2).b, 2);
            assert_eq!(new_foo(1).a, 1);

            let foos = [Foo::new(1, 2), Foo::new(3, 2)];

            assert_eq!(foos[0].a, 1);
            assert_eq!(foos[1].b, 2);
        }
    }
}