1use std::{collections::HashMap, future::Future, sync::Arc};
6
7use futures::future::BoxFuture;
8pub use iota_proc_macros::*;
9
/// Evaluates `$e` on a freshly spawned scoped OS thread and yields the value it produces.
///
/// NOTE(review): presumably this exists so the expression runs outside whatever the msim
/// simulator intercepts on the current thread — confirm against the msim documentation.
///
/// Panics if the spawned thread panics (the join result is unwrapped).
#[cfg(msim)]
#[macro_export]
macro_rules! nondeterministic {
    ($e:expr) => {
        std::thread::scope(move |scope| {
            let handle = scope.spawn(move || $e);
            handle.join().unwrap()
        })
    };
}

/// Without `cfg(msim)` this is a transparent wrapper: `$e` is evaluated in place, on the
/// calling thread.
#[cfg(not(msim))]
#[macro_export]
macro_rules! nondeterministic {
    ($e:expr) => {
        $e
    };
}
28
/// Type-erased value produced by a registered fail point callback. Each `get_*_result`
/// helper downcasts it to the concrete type its flavour of fail point expects.
type FpResult = Box<dyn std::any::Any + Send + 'static>;

/// A registered fail point callback. It is shareable across threads and handed out behind an
/// [`Arc`].
type FpCallback = dyn Fn() -> FpResult + Sync + Send;

/// Registry mapping each fail point identifier to its callback.
type FpMap = HashMap<&'static str, Arc<FpCallback>>;
31
/// Runs `func` with mutable access to the fail point registry and returns its result.
///
/// Under msim the registry lives in a `thread_local!`, so it is not shared between threads.
/// NOTE(review): msim may scope thread-locals per simulated node rather than per OS thread —
/// confirm if that distinction matters to callers.
#[cfg(msim)]
fn with_fp_map<T>(func: impl FnOnce(&mut FpMap) -> T) -> T {
    thread_local! {
        static MAP: std::cell::RefCell<FpMap> = std::cell::RefCell::new(FpMap::default());
    }

    MAP.with(|cell| {
        let mut registry = cell.borrow_mut();
        func(&mut registry)
    })
}
40
41#[cfg(not(msim))]
42fn with_fp_map<T>(func: impl FnOnce(&mut FpMap) -> T) -> T {
43 use std::sync::Mutex;
44
45 use once_cell::sync::Lazy;
46
47 static MAP: Lazy<Mutex<FpMap>> = Lazy::new(Default::default);
48 let mut map = MAP.lock().unwrap();
49 func(&mut map)
50}
51
52fn get_callback(identifier: &'static str) -> Option<Arc<FpCallback>> {
53 with_fp_map(|map| map.get(identifier).cloned())
54}
55
/// Validates the value produced by a synchronous fail point callback.
///
/// # Panics
///
/// Panics if the callback produced anything other than `()`.
fn get_sync_fp_result(result: Box<dyn std::any::Any + Send + 'static>) {
    let returned_unit = result.downcast::<()>().is_ok();
    if !returned_unit {
        panic!("sync failpoint must return ()");
    }
}
61
62fn get_async_fp_result(result: Box<dyn std::any::Any + Send + 'static>) -> BoxFuture<'static, ()> {
63 match result.downcast::<BoxFuture<'static, ()>>() {
64 Ok(fut) => *fut,
65 Err(err) => panic!("async failpoint must return BoxFuture<'static, ()> {err:?}"),
66 }
67}
68
/// Extracts the boolean produced by a `fail_point_if` callback.
///
/// # Panics
///
/// Panics if the callback produced anything other than a `bool`.
fn get_fp_if_result(result: Box<dyn std::any::Any + Send + 'static>) -> bool {
    result
        .downcast::<bool>()
        .map(|flag| *flag)
        .unwrap_or_else(|_| panic!("failpoint-if must return bool"))
}
75
/// Extracts the `Option<T>` produced by a `fail_point_arg` callback.
///
/// # Panics
///
/// Panics if the callback produced anything other than an `Option<T>` for the `T` requested
/// by the caller.
fn get_fp_some_result<T: Send + 'static>(
    result: Box<dyn std::any::Any + Send + 'static>,
) -> Option<T> {
    if let Ok(boxed) = result.downcast::<Option<T>>() {
        return *boxed;
    }
    panic!("failpoint-arg must return Option<T>");
}
84
85pub fn handle_fail_point(identifier: &'static str) {
86 if let Some(callback) = get_callback(identifier) {
87 get_sync_fp_result(callback());
88 tracing::trace!("hit failpoint {}", identifier);
89 }
90}
91
92pub async fn handle_fail_point_async(identifier: &'static str) {
93 if let Some(callback) = get_callback(identifier) {
94 tracing::trace!("hit async failpoint {}", identifier);
95 let fut = get_async_fp_result(callback());
96 fut.await;
97 }
98}
99
100pub fn handle_fail_point_if(identifier: &'static str) -> bool {
101 if let Some(callback) = get_callback(identifier) {
102 tracing::trace!("hit failpoint_if {}", identifier);
103 get_fp_if_result(callback())
104 } else {
105 false
106 }
107}
108
109pub fn handle_fail_point_arg<T: Send + 'static>(identifier: &'static str) -> Option<T> {
110 if let Some(callback) = get_callback(identifier) {
111 tracing::trace!("hit failpoint_arg {}", identifier);
112 get_fp_some_result(callback())
113 } else {
114 None
115 }
116}
117
118fn register_fail_point_impl(identifier: &'static str, callback: Arc<FpCallback>) {
119 with_fp_map(move |map| {
120 assert!(
121 map.insert(identifier, callback).is_none(),
122 "duplicate fail point registration"
123 );
124 })
125}
126
127fn clear_fail_point_impl(identifier: &'static str) {
128 with_fp_map(move |map| {
129 assert!(
130 map.remove(identifier).is_some(),
131 "fail point {identifier:?} does not exist"
132 );
133 })
134}
135
136pub fn register_fail_point(identifier: &'static str, callback: impl Fn() + Sync + Send + 'static) {
137 register_fail_point_impl(
138 identifier,
139 Arc::new(move || {
140 callback();
141 Box::new(())
142 }),
143 );
144}
145
146pub fn register_fail_point_async<F>(
149 identifier: &'static str,
150 callback: impl Fn() -> F + Sync + Send + 'static,
151) where
152 F: Future<Output = ()> + Send + 'static,
153{
154 register_fail_point_impl(
155 identifier,
156 Arc::new(move || {
157 let result: BoxFuture<'static, ()> = Box::pin(callback());
158 Box::new(result)
159 }),
160 );
161}
162
163pub fn register_fail_point_if(
182 identifier: &'static str,
183 callback: impl Fn() -> bool + Sync + Send + 'static,
184) {
185 register_fail_point_impl(identifier, Arc::new(move || Box::new(callback())));
186}
187
188pub fn register_fail_point_arg<T: Send + 'static>(
209 identifier: &'static str,
210 callback: impl Fn() -> Option<T> + Sync + Send + 'static,
211) {
212 register_fail_point_impl(identifier, Arc::new(move || Box::new(callback())));
213}
214
215pub fn register_fail_points(
216 identifiers: &[&'static str],
217 callback: impl Fn() + Sync + Send + 'static,
218) {
219 let cb: Arc<FpCallback> = Arc::new(move || {
220 callback();
221 Box::new(())
222 });
223 for id in identifiers {
224 register_fail_point_impl(id, cb.clone());
225 }
226}
227
228pub fn clear_fail_point(identifier: &'static str) {
229 clear_fail_point_impl(identifier);
230}
231
/// Hits the synchronous fail point named by `$name` via [`handle_fail_point`].
///
/// Active only when built with `cfg(msim)` or `cfg(fail_points)`.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point {
    ($name:expr) => {
        $crate::handle_fail_point($name)
    };
}

/// Hits the async fail point named by `$name` via [`handle_fail_point_async`].
///
/// The expansion contains `.await`, so it can only be used inside an async context.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point_async {
    ($name:expr) => {
        $crate::handle_fail_point_async($name).await
    };
}

/// Calls the zero-argument closure `$action` when the predicate registered for `$name`
/// (see [`handle_fail_point_if`]) returns `true`.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point_if {
    ($name:expr, $action:expr) => {
        if $crate::handle_fail_point_if($name) {
            ($action)();
        }
    };
}

/// Calls the one-argument closure `$action` with the value produced by the callback registered
/// for `$name` (see [`handle_fail_point_arg`]), if that value is `Some`.
#[cfg(any(msim, fail_points))]
#[macro_export]
macro_rules! fail_point_arg {
    ($name:expr, $action:expr) => {
        if let Some(value) = $crate::handle_fail_point_arg($name) {
            ($action)(value);
        }
    };
}

// With neither `msim` nor `fail_points` enabled, every fail point macro below expands to
// nothing. The arguments only have to parse as expressions; they are never evaluated or
// type-checked.

#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point {
    ($name:expr) => {};
}

#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point_async {
    ($name:expr) => {};
}

#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point_if {
    ($name:expr, $action:expr) => {};
}

#[cfg(not(any(msim, fail_points)))]
#[macro_export]
macro_rules! fail_point_arg {
    ($name:expr, $action:expr) => {};
}
301
/// Forwards its arguments to `tracing::info!`, but only when the `REPLAY_LOG` environment
/// variable is set to a valid Unicode value.
///
/// The environment is consulted on every invocation. The expansion names `tracing` by plain
/// path, so it resolves at the call site: the calling crate must be able to see `tracing`.
#[macro_export]
macro_rules! replay_log {
    ($($tokens:tt)+) => {
        if std::env::var("REPLAY_LOG").is_ok() {
            tracing::info!($($tokens)+);
        }
    };
}
316
// Tests for the checked-arithmetic proc macros used below (`with_checked_arithmetic`,
// `skip_checked_arithmetic` and `checked_arithmetic!`). They are not defined in this file.
// Presumably they come from `iota_proc_macros` via the crate-root glob re-export, reached here
// through `use super::*`.
//
// NOTE(review): overflow checks are on by default in debug/test builds, and there plain `i32`
// arithmetic already panics on overflow. The `#[should_panic]` tests therefore only
// distinguish the macros from ordinary arithmetic when overflow checks are disabled.
#[cfg(test)]
mod test {
    use super::*;

    // Re-emits its input tokens unchanged; used below as the target of
    // `#[skip_checked_arithmetic]`.
    macro_rules! pass_through {
        ($($tt:tt)*) => {
            $($tt)*
        }
    }

    #[with_checked_arithmetic]
    #[test]
    fn test_skip_checked_arithmetic() {
        // Marks the item generated by `pass_through!` to be skipped by the enclosing
        // `#[with_checked_arithmetic]` rewrite.
        #[skip_checked_arithmetic]
        pass_through! {
            fn unchecked_add(a: i32, b: i32) -> i32 {
                a + b
            }
        }

        // NOTE(review): (1, 2) cannot overflow, so this only shows that the skipped item still
        // compiles and runs. It does not demonstrate that its addition is left unchecked.
        unchecked_add(1, 2);
    }

    // Every item inside `checked_arithmetic! { ... }` is handed to the macro. The
    // `#[should_panic]` tests expect overflowing arithmetic in these items to panic.
    checked_arithmetic! {

        struct Test {
            a: i32,
            b: i32,
        }

        fn unchecked_add(a: i32, b: i32) -> i32 {
            a + b
        }

        #[test]
        fn test_checked_arithmetic_macro() {
            unchecked_add(1, 2);
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic() {
            unchecked_add(i32::MAX, 1);
        }

        // The overflowing addition sits inside a closure body.
        fn unchecked_add_hidden(a: i32, b: i32) -> i32 {
            let inner = |a: i32, b: i32| a + b;
            inner(a, b)
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic_hidden() {
            unchecked_add_hidden(i32::MAX, 1);
        }

        // The overflowing addition sits inside a nested fn item.
        fn unchecked_add_hidden_2(a: i32, b: i32) -> i32 {
            fn inner(a: i32, b: i32) -> i32 {
                a + b
            }
            inner(a, b)
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic_hidden_2() {
            unchecked_add_hidden_2(i32::MAX, 1);
        }

        // The overflowing addition sits inside an inherent method.
        impl Test {
            fn add(&self) -> i32 {
                self.a + self.b
            }
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_impl() {
            let t = Test { a: 1, b: i32::MAX };
            t.add();
        }

        // NOTE(review): rustc's deny-by-default `arithmetic_overflow` lint normally rejects a
        // constant `i32::MAX + 1` at compile time. That this compiles suggests the macro
        // rewrote the expression — confirm.
        #[test]
        #[should_panic]
        fn test_macro_overflow() {
            fn f() {
                println!("{}", i32::MAX + 1);
            }

            f()
        }

        // Non-overflowing uses of `+ - * / %` and their compound-assignment forms must still
        // produce the ordinary results.
        #[test]
        fn test_non_overflow() {
            fn f() {
                assert_eq!(1i32 + 2i32, 3i32);
                assert_eq!(3i32 - 1i32, 2i32);
                assert_eq!(4i32 * 3i32, 12i32);
                assert_eq!(12i32 / 3i32, 4i32);
                assert_eq!(12i32 % 5i32, 2i32);

                let mut a = 1i32;
                a += 2i32;
                assert_eq!(a, 3i32);

                let mut a = 3i32;
                a -= 1i32;
                assert_eq!(a, 2i32);

                let mut a = 4i32;
                a *= 3i32;
                assert_eq!(a, 12i32);

                let mut a = 12i32;
                a /= 3i32;
                assert_eq!(a, 4i32);

                let mut a = 12i32;
                a %= 5i32;
                assert_eq!(a, 2i32);
            }

            f();
        }


        // The right operand must be evaluated exactly once: `f` panics on a second call.
        #[test]
        fn test_exprs_evaluated_once_right() {
            let mut called = false;
            let mut f = || {
                if called {
                    panic!("called twice");
                }
                called = true;
                1i32
            };

            assert_eq!(2i32 + f(), 3);
        }

        // The left operand must be evaluated exactly once: `f` panics on a second call.
        #[test]
        fn test_exprs_evaluated_once_left() {
            let mut called = false;
            let mut f = || {
                if called {
                    panic!("called twice");
                }
                called = true;
                1i32
            };

            assert_eq!(f() + 2i32, 3);
        }

        // The place expression on the left of `+=` must be evaluated exactly once:
        // `get_a_mut` panics on a second call.
        #[test]
        fn test_assign_op_evals_once() {
            struct Foo {
                a: i32,
                called: bool,
            }

            impl Foo {
                fn get_a_mut(&mut self) -> &mut i32 {
                    if self.called {
                        panic!("called twice");
                    }
                    let ret = &mut self.a;
                    self.called = true;
                    ret
                }
            }

            let mut foo = Foo { a: 1, called: false };

            *foo.get_a_mut() += 2;
            assert_eq!(foo.a, 3);
        }

        // Exercises assorted expression forms inside the macro: associated consts, path calls,
        // field access, array literals and indexing.
        #[test]
        fn test_more_macro_syntax() {
            struct Foo {
                a: i32,
                b: i32,
            }

            impl Foo {
                const BAR: i32 = 1;

                fn new(a: i32, b: i32) -> Foo {
                    Foo { a, b }
                }
            }

            fn new_foo(a: i32) -> Foo {
                Foo { a, b: 0 }
            }

            assert_eq!(Foo::BAR + 1, 2);
            assert_eq!(Foo::new(1, 2).b, 2);
            assert_eq!(new_foo(1).a, 1);

            let v = [Foo::new(1, 2), Foo::new(3, 2)];

            assert_eq!(v[0].a, 1);
            assert_eq!(v[1].b, 2);
        }

    }

    // The same suite again, this time with `#[with_checked_arithmetic]` applied to a whole
    // inline module instead of wrapping the items in `checked_arithmetic!`.
    #[with_checked_arithmetic]
    mod with_checked_arithmetic_tests {

        struct Test {
            a: i32,
            b: i32,
        }

        fn unchecked_add(a: i32, b: i32) -> i32 {
            a + b
        }

        #[test]
        fn test_checked_arithmetic_macro() {
            unchecked_add(1, 2);
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic() {
            unchecked_add(i32::MAX, 1);
        }

        // The overflowing addition sits inside a closure body.
        fn unchecked_add_hidden(a: i32, b: i32) -> i32 {
            let inner = |a: i32, b: i32| a + b;
            inner(a, b)
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic_hidden() {
            unchecked_add_hidden(i32::MAX, 1);
        }

        // The overflowing addition sits inside a nested fn item.
        fn unchecked_add_hidden_2(a: i32, b: i32) -> i32 {
            fn inner(a: i32, b: i32) -> i32 {
                a + b
            }
            inner(a, b)
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_macro_panic_hidden_2() {
            unchecked_add_hidden_2(i32::MAX, 1);
        }

        // The overflowing addition sits inside an inherent method.
        impl Test {
            fn add(&self) -> i32 {
                self.a + self.b
            }
        }

        #[test]
        #[should_panic]
        fn test_checked_arithmetic_impl() {
            let t = Test { a: 1, b: i32::MAX };
            t.add();
        }

        // NOTE(review): as in the `checked_arithmetic!` suite, a constant `i32::MAX + 1` would
        // normally be rejected by rustc's `arithmetic_overflow` lint — confirm the rewrite is
        // what lets this compile.
        #[test]
        #[should_panic]
        fn test_macro_overflow() {
            fn f() {
                println!("{}", i32::MAX + 1);
            }

            f()
        }

        // Non-overflowing uses of `+ - * / %` and their compound-assignment forms must still
        // produce the ordinary results.
        #[test]
        fn test_non_overflow() {
            fn f() {
                assert_eq!(1i32 + 2i32, 3i32);
                assert_eq!(3i32 - 1i32, 2i32);
                assert_eq!(4i32 * 3i32, 12i32);
                assert_eq!(12i32 / 3i32, 4i32);
                assert_eq!(12i32 % 5i32, 2i32);

                let mut a = 1i32;
                a += 2i32;
                assert_eq!(a, 3i32);

                let mut a = 3i32;
                a -= 1i32;
                assert_eq!(a, 2i32);

                let mut a = 4i32;
                a *= 3i32;
                assert_eq!(a, 12i32);

                let mut a = 12i32;
                a /= 3i32;
                assert_eq!(a, 4i32);

                let mut a = 12i32;
                a %= 5i32;
                assert_eq!(a, 2i32);
            }

            f();
        }

        // The right operand must be evaluated exactly once: `f` panics on a second call.
        #[test]
        fn test_exprs_evaluated_once_right() {
            let mut called = false;
            let mut f = || {
                if called {
                    panic!("called twice");
                }
                called = true;
                1i32
            };

            assert_eq!(2i32 + f(), 3);
        }

        // The left operand must be evaluated exactly once: `f` panics on a second call.
        #[test]
        fn test_exprs_evaluated_once_left() {
            let mut called = false;
            let mut f = || {
                if called {
                    panic!("called twice");
                }
                called = true;
                1i32
            };

            assert_eq!(f() + 2i32, 3);
        }

        // The place expression on the left of `+=` must be evaluated exactly once:
        // `get_a_mut` panics on a second call.
        #[test]
        fn test_assign_op_evals_once() {
            struct Foo {
                a: i32,
                called: bool,
            }

            impl Foo {
                fn get_a_mut(&mut self) -> &mut i32 {
                    if self.called {
                        panic!("called twice");
                    }
                    let ret = &mut self.a;
                    self.called = true;
                    ret
                }
            }

            let mut foo = Foo {
                a: 1,
                called: false,
            };

            *foo.get_a_mut() += 2;
            assert_eq!(foo.a, 3);
        }

        // Exercises assorted expression forms under the rewrite: associated consts, path calls,
        // field access, array literals and indexing.
        #[test]
        fn test_more_macro_syntax() {
            struct Foo {
                a: i32,
                b: i32,
            }

            impl Foo {
                const BAR: i32 = 1;

                fn new(a: i32, b: i32) -> Foo {
                    Foo { a, b }
                }
            }

            fn new_foo(a: i32) -> Foo {
                Foo { a, b: 0 }
            }

            assert_eq!(Foo::BAR + 1, 2);
            assert_eq!(Foo::new(1, 2).b, 2);
            assert_eq!(new_foo(1).a, 1);

            let v = [Foo::new(1, 2), Foo::new(3, 2)];

            assert_eq!(v[0].a, 1);
            assert_eq!(v[1].b, 2);
        }
    }
}