1pub mod metrics;
6pub mod nodefw_client;
7pub mod nodefw_test_server;
8pub mod policies;
9
10use std::{
11 fmt::Debug,
12 fs,
13 net::{IpAddr, Ipv4Addr},
14 ops::Add,
15 sync::Arc,
16 time::{Duration, Instant, SystemTime},
17};
18
19use dashmap::DashMap;
20use fs::File;
21use iota_metrics::spawn_monitored_task;
22use iota_types::traffic_control::{PolicyConfig, RemoteFirewallConfig, Weight};
23use prometheus::IntGauge;
24use rand::Rng;
25use tokio::{
26 sync::{mpsc, mpsc::error::TrySendError},
27 time,
28};
29use tracing::{debug, error, info, trace, warn};
30
31use self::metrics::TrafficControllerMetrics;
32use crate::traffic_controller::{
33 nodefw_client::{BlockAddress, BlockAddresses, NodeFWClient},
34 policies::{Policy, PolicyResponse, TrafficControlPolicy, TrafficTally},
35};
36
/// Minimum number of seconds between two publications of the highest observed
/// spam/error rates by the tally loop.
pub const METRICS_INTERVAL_SECS: u64 = 2;
/// Timeout (five minutes) used for the no-tally dead-man's switch when no
/// `RemoteFirewallConfig` provides `drain_timeout_secs`.
pub const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 5 * 60;
39
40type Blocklist = Arc<DashMap<IpAddr, SystemTime>>;
41
42#[derive(Clone)]
43struct Blocklists {
44 clients: Blocklist,
45 proxied_clients: Blocklist,
46}
47
48#[derive(Clone)]
49pub struct TrafficController {
50 tally_channel: mpsc::Sender<TrafficTally>,
51 blocklists: Blocklists,
52 metrics: Arc<TrafficControllerMetrics>,
53 dry_run_mode: bool,
54}
55
56impl Debug for TrafficController {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("TrafficController")
64 .field(
65 "connection_ip_blocklist_len",
66 &self.metrics.connection_ip_blocklist_len.get(),
67 )
68 .field(
69 "proxy_ip_blocklist_len",
70 &self.metrics.proxy_ip_blocklist_len.get(),
71 )
72 .finish()
73 }
74}
75
76impl TrafficController {
77 pub fn spawn(
78 policy_config: PolicyConfig,
79 metrics: TrafficControllerMetrics,
80 fw_config: Option<RemoteFirewallConfig>,
81 ) -> Self {
82 let metrics = Arc::new(metrics);
83 let (tx, rx) = mpsc::channel(policy_config.channel_capacity);
84 let mem_drainfile_present = fw_config
89 .as_ref()
90 .map(|config| config.drain_path.exists())
91 .unwrap_or(false);
92 metrics
93 .deadmans_switch_enabled
94 .set(mem_drainfile_present as i64);
95
96 let ret = Self {
97 tally_channel: tx,
98 blocklists: Blocklists {
99 clients: Arc::new(DashMap::new()),
100 proxied_clients: Arc::new(DashMap::new()),
101 },
102 metrics: metrics.clone(),
103 dry_run_mode: policy_config.dry_run,
104 };
105 let tally_loop_blocklists = ret.blocklists.clone();
106 let clear_loop_blocklists = ret.blocklists.clone();
107 let tally_loop_metrics = metrics.clone();
108 let clear_loop_metrics = metrics.clone();
109 spawn_monitored_task!(run_tally_loop(
110 rx,
111 policy_config,
112 fw_config,
113 tally_loop_blocklists,
114 tally_loop_metrics,
115 mem_drainfile_present,
116 ));
117 spawn_monitored_task!(run_clear_blocklists_loop(
118 clear_loop_blocklists,
119 clear_loop_metrics,
120 ));
121 ret
122 }
123
124 pub fn spawn_for_test(
125 policy_config: PolicyConfig,
126 fw_config: Option<RemoteFirewallConfig>,
127 ) -> Self {
128 let metrics = TrafficControllerMetrics::new(&prometheus::Registry::new());
129 Self::spawn(policy_config, metrics, fw_config)
130 }
131
132 pub fn tally(&self, tally: TrafficTally) {
133 match self.tally_channel.try_send(tally) {
139 Err(TrySendError::Full(_)) => {
140 warn!("TrafficController tally channel full, dropping tally");
141 self.metrics.tally_channel_overflow.inc();
142 }
146 Err(TrySendError::Closed(_)) => {
147 panic!("TrafficController tally channel closed unexpectedly");
148 }
149 Ok(_) => {}
150 }
151 }
152
153 pub async fn check(&self, client: &Option<IpAddr>, proxied_client: &Option<IpAddr>) -> bool {
155 match (
156 self.check_impl(client, proxied_client).await,
157 self.dry_run_mode(),
158 ) {
159 (true, _) => true,
161 (false, true) => {
163 debug!(
164 "Dry run mode: Blocked request from client {:?}, proxied client: {:?}",
165 client, proxied_client
166 );
167 self.metrics.num_dry_run_blocked_requests.inc();
168 true
169 }
170 (false, false) => false,
172 }
173 }
174
175 pub async fn check_impl(
177 &self,
178 client: &Option<IpAddr>,
179 proxied_client: &Option<IpAddr>,
180 ) -> bool {
181 let client_check = self.check_and_clear_blocklist(
182 client,
183 self.blocklists.clients.clone(),
184 &self.metrics.connection_ip_blocklist_len,
185 );
186 let proxied_client_check = self.check_and_clear_blocklist(
187 proxied_client,
188 self.blocklists.proxied_clients.clone(),
189 &self.metrics.proxy_ip_blocklist_len,
190 );
191 let (client_check, proxied_client_check) =
192 futures::future::join(client_check, proxied_client_check).await;
193 client_check && proxied_client_check
194 }
195
196 pub fn dry_run_mode(&self) -> bool {
197 self.dry_run_mode
198 }
199
200 async fn check_and_clear_blocklist(
201 &self,
202 client: &Option<IpAddr>,
203 blocklist: Blocklist,
204 blocklist_len_gauge: &IntGauge,
205 ) -> bool {
206 let client = match client {
207 Some(client) => client,
208 None => return true,
209 };
210 let now = SystemTime::now();
211 let (should_block, should_remove) = {
214 match blocklist.get(client) {
215 Some(expiration) if now >= *expiration => (false, true),
216 None => (false, false),
217 _ => (true, false),
218 }
219 };
220 if should_remove {
221 blocklist_len_gauge.dec();
222 blocklist.remove(client);
223 }
224 !should_block
225 }
226}
227
228async fn run_clear_blocklists_loop(blocklists: Blocklists, metrics: Arc<TrafficControllerMetrics>) {
235 loop {
236 tokio::time::sleep(Duration::from_secs(3)).await;
237 let now = SystemTime::now();
238 blocklists.clients.retain(|_, expiration| now < *expiration);
239 blocklists
240 .proxied_clients
241 .retain(|_, expiration| now < *expiration);
242 metrics
243 .connection_ip_blocklist_len
244 .set(blocklists.clients.len() as i64);
245 metrics
246 .proxy_ip_blocklist_len
247 .set(blocklists.proxied_clients.len() as i64);
248 }
249}
250
251async fn run_tally_loop(
252 mut receiver: mpsc::Receiver<TrafficTally>,
253 policy_config: PolicyConfig,
254 fw_config: Option<RemoteFirewallConfig>,
255 blocklists: Blocklists,
256 metrics: Arc<TrafficControllerMetrics>,
257 mut mem_drainfile_present: bool,
258) {
259 let mut spam_policy = TrafficControlPolicy::from_spam_config(policy_config.clone()).await;
260 let mut error_policy = TrafficControlPolicy::from_error_config(policy_config.clone()).await;
261 let spam_blocklists = Arc::new(blocklists.clone());
262 let error_blocklists = Arc::new(blocklists);
263 let node_fw_client = fw_config
264 .as_ref()
265 .map(|fw_config| NodeFWClient::new(fw_config.remote_fw_url.clone()));
266
267 let timeout = fw_config
268 .as_ref()
269 .map(|fw_config| fw_config.drain_timeout_secs)
270 .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS);
271 let mut metric_timer = Instant::now();
272
273 loop {
274 tokio::select! {
275 received = receiver.recv() => {
276 metrics.tallies.inc();
277 match received {
278 Some(tally) => {
279 if let Err(err) = handle_spam_tally(
281 &mut spam_policy,
282 &policy_config,
283 &node_fw_client,
284 &fw_config,
285 tally.clone(),
286 spam_blocklists.clone(),
287 metrics.clone(),
288 mem_drainfile_present,
289 )
290 .await {
291 warn!("Error handling spam tally: {}", err);
292 }
293 if let Err(err) = handle_error_tally(
294 &mut error_policy,
295 &policy_config,
296 &node_fw_client,
297 &fw_config,
298 tally,
299 error_blocklists.clone(),
300 metrics.clone(),
301 mem_drainfile_present,
302 )
303 .await {
304 warn!("Error handling error tally: {}", err);
305 }
306 }
307 None => {
308 info!("TrafficController tally channel closed by all senders");
309 return;
310 }
311 }
312 }
313 _ = tokio::time::sleep(tokio::time::Duration::from_secs(timeout)) => {
315 if let Some(fw_config) = &fw_config {
316 error!("No traffic tallies received in {} seconds.", timeout);
317 if mem_drainfile_present {
318 continue;
319 }
320 if !fw_config.drain_path.exists() {
321 mem_drainfile_present = true;
322 warn!("Draining Node firewall.");
323 File::create(&fw_config.drain_path)
324 .expect("Failed to touch nodefw drain file");
325 metrics.deadmans_switch_enabled.set(1);
326 }
327 }
328 }
329 }
330
331 if metric_timer.elapsed() > Duration::from_secs(METRICS_INTERVAL_SECS) {
334 if let TrafficControlPolicy::FreqThreshold(spam_policy) = &spam_policy {
335 if let Some(highest_direct_rate) = spam_policy.highest_direct_rate() {
336 metrics
337 .highest_direct_spam_rate
338 .set(highest_direct_rate.0 as i64);
339 trace!("Recent highest direct spam rate: {:?}", highest_direct_rate);
340 }
341 if let Some(highest_proxied_rate) = spam_policy.highest_proxied_rate() {
342 metrics
343 .highest_proxied_spam_rate
344 .set(highest_proxied_rate.0 as i64);
345 trace!(
346 "Recent highest proxied spam rate: {:?}",
347 highest_proxied_rate
348 );
349 }
350 }
351 if let TrafficControlPolicy::FreqThreshold(error_policy) = &error_policy {
352 if let Some(highest_direct_rate) = error_policy.highest_direct_rate() {
353 metrics
354 .highest_direct_error_rate
355 .set(highest_direct_rate.0 as i64);
356 trace!(
357 "Recent highest direct error rate: {:?}",
358 highest_direct_rate
359 );
360 }
361 if let Some(highest_proxied_rate) = error_policy.highest_proxied_rate() {
362 metrics
363 .highest_proxied_error_rate
364 .set(highest_proxied_rate.0 as i64);
365 trace!(
366 "Recent highest proxied error rate: {:?}",
367 highest_proxied_rate
368 );
369 }
370 }
371 metric_timer = Instant::now();
372 }
373 }
374}
375
376async fn handle_error_tally(
377 policy: &mut TrafficControlPolicy,
378 policy_config: &PolicyConfig,
379 nodefw_client: &Option<NodeFWClient>,
380 fw_config: &Option<RemoteFirewallConfig>,
381 tally: TrafficTally,
382 blocklists: Arc<Blocklists>,
383 metrics: Arc<TrafficControllerMetrics>,
384 mem_drainfile_present: bool,
385) -> Result<(), reqwest::Error> {
386 if !tally.error_weight.is_sampled() {
387 return Ok(());
388 }
389 let resp = policy.handle_tally(tally.clone());
390 metrics.error_tally_handled.inc();
391 if let Some(fw_config) = fw_config {
392 if fw_config.delegate_error_blocking && !mem_drainfile_present {
393 let client = nodefw_client
394 .as_ref()
395 .expect("Expected NodeFWClient for blocklist delegation");
396 return delegate_policy_response(
397 resp,
398 policy_config,
399 client,
400 fw_config.destination_port,
401 metrics.clone(),
402 )
403 .await;
404 }
405 }
406 handle_policy_response(resp, policy_config, blocklists, metrics).await;
407 Ok(())
408}
409
410async fn handle_spam_tally(
411 policy: &mut TrafficControlPolicy,
412 policy_config: &PolicyConfig,
413 nodefw_client: &Option<NodeFWClient>,
414 fw_config: &Option<RemoteFirewallConfig>,
415 tally: TrafficTally,
416 blocklists: Arc<Blocklists>,
417 metrics: Arc<TrafficControllerMetrics>,
418 mem_drainfile_present: bool,
419) -> Result<(), reqwest::Error> {
420 if !(tally.spam_weight.is_sampled() && policy_config.spam_sample_rate.is_sampled()) {
421 return Ok(());
422 }
423 let resp = policy.handle_tally(tally.clone());
424 metrics.tally_handled.inc();
425 if let Some(fw_config) = fw_config {
426 if fw_config.delegate_spam_blocking && !mem_drainfile_present {
427 let client = nodefw_client
428 .as_ref()
429 .expect("Expected NodeFWClient for blocklist delegation");
430 return delegate_policy_response(
431 resp,
432 policy_config,
433 client,
434 fw_config.destination_port,
435 metrics.clone(),
436 )
437 .await;
438 }
439 }
440 handle_policy_response(resp, policy_config, blocklists, metrics).await;
441 Ok(())
442}
443
444async fn handle_policy_response(
445 response: PolicyResponse,
446 policy_config: &PolicyConfig,
447 blocklists: Arc<Blocklists>,
448 metrics: Arc<TrafficControllerMetrics>,
449) {
450 let PolicyResponse {
451 block_client,
452 block_proxied_client,
453 } = response;
454 let PolicyConfig {
455 connection_blocklist_ttl_sec,
456 proxy_blocklist_ttl_sec,
457 ..
458 } = policy_config;
459 if let Some(client) = block_client {
460 if blocklists
461 .clients
462 .insert(
463 client,
464 SystemTime::now() + Duration::from_secs(*connection_blocklist_ttl_sec),
465 )
466 .is_none()
467 {
468 debug!("Blocking client: {:?}", client);
470 metrics.connection_ip_blocklist_len.inc();
471 }
472 }
473 if let Some(client) = block_proxied_client {
474 if blocklists
475 .proxied_clients
476 .insert(
477 client,
478 SystemTime::now() + Duration::from_secs(*proxy_blocklist_ttl_sec),
479 )
480 .is_none()
481 {
482 debug!("Blocking proxied client: {:?}", client);
484 metrics.proxy_ip_blocklist_len.inc();
485 }
486 }
487}
488
489async fn delegate_policy_response(
490 response: PolicyResponse,
491 policy_config: &PolicyConfig,
492 node_fw_client: &NodeFWClient,
493 destination_port: u16,
494 metrics: Arc<TrafficControllerMetrics>,
495) -> Result<(), reqwest::Error> {
496 let PolicyResponse {
497 block_client,
498 block_proxied_client,
499 } = response;
500 let PolicyConfig {
501 connection_blocklist_ttl_sec,
502 proxy_blocklist_ttl_sec,
503 ..
504 } = policy_config;
505 let mut addresses = vec![];
506 if let Some(client_id) = block_client {
507 debug!("Delegating client blocking to firewall");
508 addresses.push(BlockAddress {
509 source_address: client_id.to_string(),
510 destination_port,
511 ttl: *connection_blocklist_ttl_sec,
512 });
513 }
514 if let Some(ip) = block_proxied_client {
515 debug!("Delegating proxied client blocking to firewall");
516 addresses.push(BlockAddress {
517 source_address: ip.to_string(),
518 destination_port,
519 ttl: *proxy_blocklist_ttl_sec,
520 });
521 }
522 if addresses.is_empty() {
523 Ok(())
524 } else {
525 metrics
526 .blocks_delegated_to_firewall
527 .inc_by(addresses.len() as u64);
528 match node_fw_client
529 .block_addresses(BlockAddresses { addresses })
530 .await
531 {
532 Ok(()) => Ok(()),
533 Err(err) => {
534 metrics.firewall_delegation_request_fail.inc();
535 Err(err)
536 }
537 }
538 }
539}
540
/// Results of a traffic simulation, either for one simulated client or for
/// several clients combined with `+`.
///
/// When combined:
/// - counters and `total_time_blocked` are summed;
/// - `time_to_first_block` is summed, so it can later be averaged;
/// - `abs_time_to_first_block` keeps the earliest value.
///
/// The derived `Default` is all zeros / `None`.
#[derive(Debug, Clone, Default)]
pub struct TrafficSimMetrics {
    /// Number of requests checked.
    pub num_requests: u64,
    /// Number of requests that were rejected.
    pub num_blocked: u64,
    /// Elapsed time until the first rejection (summed across clients when
    /// combined).
    pub time_to_first_block: Option<Duration>,
    /// Earliest time until a first rejection across all combined clients.
    pub abs_time_to_first_block: Option<Duration>,
    /// Accumulated time spent in a blocked state.
    pub total_time_blocked: Duration,
    /// Number of transitions from allowed to blocked.
    pub num_blocklist_adds: u64,
}

impl Add for TrafficSimMetrics {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        /// Merges two optional durations: applies `merge` when both are
        /// present, otherwise keeps whichever one is present (or `None`).
        fn combine(
            lhs: Option<Duration>,
            rhs: Option<Duration>,
            merge: fn(Duration, Duration) -> Duration,
        ) -> Option<Duration> {
            match (lhs, rhs) {
                (Some(a), Some(b)) => Some(merge(a, b)),
                (only, None) | (None, only) => only,
            }
        }

        let time_to_first_block = combine(
            self.time_to_first_block,
            other.time_to_first_block,
            |a, b| a + b,
        );
        let abs_time_to_first_block = combine(
            self.abs_time_to_first_block,
            other.abs_time_to_first_block,
            |a, b| a.min(b),
        );

        Self {
            num_requests: self.num_requests + other.num_requests,
            num_blocked: self.num_blocked + other.num_blocked,
            time_to_first_block,
            abs_time_to_first_block,
            total_time_blocked: self.total_time_blocked + other.total_time_blocked,
            num_blocklist_adds: self.num_blocklist_adds + other.num_blocklist_adds,
        }
    }
}
591
592pub struct TrafficSim {
593 pub traffic_controller: TrafficController,
594}
595
596impl TrafficSim {
597 pub async fn run(
598 policy: PolicyConfig,
599 num_clients: u8,
600 per_client_tps: usize,
601 duration: Duration,
602 report: bool,
603 ) -> TrafficSimMetrics {
604 assert!(
605 per_client_tps <= 10_000,
606 "per_client_tps must be less than 10,000. For higher values, increase num_clients"
607 );
608 assert!(num_clients < 20, "num_clients must be greater than 0");
609 assert!(num_clients > 0);
610 assert!(per_client_tps > 0);
611 assert!(duration.as_secs() > 0);
612
613 let controller = TrafficController::spawn_for_test(policy.clone(), None);
614 let tasks = (0..num_clients).map(|task_num| {
615 tokio::spawn(Self::run_single_client(
616 controller.clone(),
617 duration,
618 task_num,
619 per_client_tps,
620 ))
621 });
622
623 let status_task = if report {
624 Some(tokio::spawn(async move {
625 println!(
626 "Running naive traffic simulation for {} seconds",
627 duration.as_secs()
628 );
629 println!("Policy: {:#?}", policy);
630 println!("Num clients: {}", num_clients);
631 println!("TPS per client: {}", per_client_tps);
632 println!(
633 "Target total TPS: {}",
634 per_client_tps * num_clients as usize
635 );
636 println!("\n");
637 for _ in 0..duration.as_secs() {
638 print!(".");
639 tokio::time::sleep(Duration::from_secs(1)).await;
640 }
641 println!();
642 }))
643 } else {
644 None
645 };
646
647 let metrics = futures::future::join_all(tasks).await.into_iter().fold(
648 TrafficSimMetrics::default(),
649 |acc, run_client_ret| {
650 if run_client_ret.is_err() {
651 error!(
652 "Error running traffic sim client: {:?}",
653 run_client_ret.err()
654 );
655 acc
656 } else {
657 let metrics = run_client_ret.unwrap();
658 acc + metrics
659 }
660 },
661 );
662
663 if report {
664 status_task.unwrap().await.unwrap();
665 Self::report_metrics(metrics.clone(), duration, per_client_tps, num_clients);
666 }
667 metrics
668 }
669
670 async fn run_single_client(
671 controller: TrafficController,
672 duration: Duration,
673 task_num: u8,
674 per_client_tps: usize,
675 ) -> TrafficSimMetrics {
676 let sleep_time = Duration::from_micros(rand::thread_rng().gen_range(0..100));
680 tokio::time::sleep(sleep_time).await;
681
682 let mut num_requests = 0;
684 let mut num_blocked = 0;
685 let mut time_to_first_block = None;
686 let mut total_time_blocked = Duration::from_micros(0);
687 let mut num_blocklist_adds = 0;
688 let mut currently_blocked = false;
690 let mut time_blocked_start = Instant::now();
691 let start = Instant::now();
692
693 let sleep_time = Duration::from_micros(1_000_000 / per_client_tps as u64);
695 let mut interval_ticker = time::interval(sleep_time);
696
697 while start.elapsed() < duration {
698 let client = Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, task_num)));
699 let allowed = controller.check(&client, &None).await;
700 if allowed {
701 if currently_blocked {
702 total_time_blocked += time_blocked_start.elapsed();
703 currently_blocked = false;
704 }
705 controller.tally(TrafficTally::new(
706 client,
707 None,
709 Weight::one(),
711 Weight::one(),
712 ));
713 } else {
714 if !currently_blocked {
715 time_blocked_start = Instant::now();
716 currently_blocked = true;
717 num_blocklist_adds += 1;
718 if time_to_first_block.is_none() {
719 time_to_first_block = Some(start.elapsed());
720 }
721 }
722 num_blocked += 1;
723 }
724 num_requests += 1;
725
726 interval_ticker.tick().await;
727 }
728 TrafficSimMetrics {
729 num_requests,
730 num_blocked,
731 time_to_first_block,
732 abs_time_to_first_block: time_to_first_block,
733 total_time_blocked,
734 num_blocklist_adds,
735 }
736 }
737
738 fn report_metrics(
739 metrics: TrafficSimMetrics,
740 duration: Duration,
741 per_client_tps: usize,
742 num_clients: u8,
743 ) {
744 println!("TrafficSim metrics:");
745 println!("-------------------");
746 println!(
748 "Num expected requests: {}",
749 per_client_tps * (num_clients as usize) * duration.as_secs() as usize
750 );
751 println!("Num actual requests: {}", metrics.num_requests);
752 println!("Num blocked requests: {}", metrics.num_blocked);
756 println!(
760 "Num times added to blocklist: {}",
761 metrics.num_blocklist_adds
762 );
763 let avg_first_block_time = metrics
767 .time_to_first_block
768 .map(|ttf| ttf / num_clients as u32);
769 println!("Average time to first block: {:?}", avg_first_block_time);
770 println!(
774 "Absolute time to first block (across all clients): {:?}",
775 metrics.abs_time_to_first_block
776 );
777 let avg_time_blocked = if metrics.num_blocklist_adds > 0 {
779 metrics.total_time_blocked.as_millis() as u64 / metrics.num_blocklist_adds
780 } else {
781 0
782 };
783 println!(
784 "Average time blocked (ttl): {:?}",
785 Duration::from_millis(avg_time_blocked)
786 );
787 }
788}