use std::{
    cmp::Reverse,
    collections::{BinaryHeap, HashMap, VecDeque},
    fmt::Debug,
    hash::Hash,
    net::IpAddr,
    sync::Arc,
    time::{Duration, Instant, SystemTime},
};

use count_min_sketch::CountMinSketch32;
use iota_metrics::spawn_monitored_task;
use iota_types::traffic_control::{FreqThresholdConfig, PolicyConfig, PolicyType, Weight};
use parking_lot::RwLock;
use tracing::info;

const HIGHEST_RATES_CAPACITY: usize = 20;

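/// Whether a request came from a directly connected client or was proxied
/// through a fullnode.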
#[derive(Hash, Eq, PartialEq, Debug)]
enum ClientType {
    Direct,
    ThroughFullnode,
}

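/// Key for the count-min sketch: the client IP address together with how the
/// request arrived (direct or proxied).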
#[derive(Hash, Eq, PartialEq, Debug)]
struct SketchKey(IpAddr, ClientType);

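/// Bounded min-heaps of the highest request rates observed so far, tracked
/// separately for direct and proxied clients.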
struct HighestRates {
    direct: BinaryHeap<Reverse<(u64, IpAddr)>>,
    proxied: BinaryHeap<Reverse<(u64, IpAddr)>>,
    capacity: usize,
}

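/// Sliding-window traffic counter. The window is divided into
/// `window_size / update_interval` count-min sketches; as the window advances,
/// the oldest sketch is cleared and reused.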
pub struct TrafficSketch {
    /// One count-min sketch per update interval, together covering the whole
    /// window.
    sketches: VecDeque<CountMinSketch32<SketchKey>>,
    window_size: Duration,
    update_interval: Duration,
    last_reset_time: Instant,
    current_sketch_index: usize,
    /// Highest rates seen so far, exposed via `highest_direct_rate` and
    /// `highest_proxied_rate`.
    highest_rates: HighestRates,
}

impl TrafficSketch {
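    /// Creates a new sketch. The window size is rounded down to an integer
    /// multiple of the update interval, and the parameters are validated so
    /// that the number of sketches and their memory footprint stay bounded.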
    pub fn new(
        window_size: Duration,
        update_interval: Duration,
        sketch_capacity: usize,
        sketch_probability: f64,
        sketch_tolerance: f64,
        highest_rates_capacity: usize,
    ) -> Self {
        let num_sketches = window_size.as_secs() / update_interval.as_secs();
        // Round the window down so that it is an integer multiple of the
        // update interval.
        let new_window_size = Duration::from_secs(num_sketches * update_interval.as_secs());
        if new_window_size != window_size {
            info!(
                "Rounding traffic sketch window size down to {} seconds so that it is an integer multiple of the update interval of {} seconds.",
                new_window_size.as_secs(),
                update_interval.as_secs(),
            );
        }
        let window_size = new_window_size;

        assert!(
            window_size < Duration::from_secs(600),
            "window_size too large. Max 600 seconds"
        );
        assert!(
            update_interval < window_size,
            "Update interval must be smaller than window size"
        );
        assert!(
            update_interval >= Duration::from_secs(1),
            "Update interval too short, must be at least 1 second"
        );
        assert!(
            num_sketches <= 10,
            "Given parameters require too many sketches to be stored. Reduce window size or increase update interval."
        );
        let mem_estimate = (num_sketches as usize)
            * CountMinSketch32::<IpAddr>::estimate_memory(
                sketch_capacity,
                sketch_probability,
                sketch_tolerance,
            )
            .expect("Failed to estimate memory for CountMinSketch32");
        assert!(
            mem_estimate < 128_000_000,
            "Memory estimate for traffic sketch exceeds 128MB. Reduce window size or increase update interval."
        );

        let mut sketches = VecDeque::with_capacity(num_sketches as usize);
        for _ in 0..num_sketches {
            sketches.push_back(
                CountMinSketch32::<SketchKey>::new(
                    sketch_capacity,
                    sketch_probability,
                    sketch_tolerance,
                )
                .expect("Failed to create CountMinSketch32"),
            );
        }
        Self {
            sketches,
            window_size,
            update_interval,
            last_reset_time: Instant::now(),
            current_sketch_index: 0,
            highest_rates: HighestRates {
                direct: BinaryHeap::with_capacity(highest_rates_capacity),
                proxied: BinaryHeap::with_capacity(highest_rates_capacity),
                capacity: highest_rates_capacity,
            },
        }
    }

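    /// Records one request for `key`, rotating the window forward first if one
    /// or more update intervals have elapsed since the last rotation.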
    fn increment_count(&mut self, key: &SketchKey) {
        let current_time = Instant::now();
        // Rotate the window once for every full update interval that has
        // passed since the last rotation.
        let mut elapsed = current_time.duration_since(self.last_reset_time);
        while elapsed >= self.update_interval {
            self.rotate_window();
            elapsed -= self.update_interval;
        }
        // Record the request in the currently active sketch.
        self.sketches[self.current_sketch_index].increment(key);
    }

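    /// Estimated request rate for `key` in requests per second, computed over
    /// the full sliding window. Also updates the tracked highest rates.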
    fn get_request_rate(&mut self, key: &SketchKey) -> f64 {
        let count: u32 = self
            .sketches
            .iter()
            .map(|sketch| sketch.estimate(key))
            .sum();
        let rate = count as f64 / self.window_size.as_secs() as f64;
        self.update_highest_rates(key, rate);
        rate
    }

    fn update_highest_rates(&mut self, key: &SketchKey, rate: f64) {
        match key.1 {
            ClientType::Direct => {
                Self::update_highest_rate(
                    &mut self.highest_rates.direct,
                    key.0,
                    rate,
                    self.highest_rates.capacity,
                );
            }
            ClientType::ThroughFullnode => {
                Self::update_highest_rate(
                    &mut self.highest_rates.proxied,
                    key.0,
                    rate,
                    self.highest_rates.capacity,
                );
            }
        }
    }

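    /// Keeps a bounded set of the highest rates per IP. Any existing entry for
    /// `ip_addr` is dropped first; the new rate is then inserted if the heap
    /// has spare capacity or the rate beats the current minimum.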
    fn update_highest_rate(
        rate_heap: &mut BinaryHeap<Reverse<(u64, IpAddr)>>,
        ip_addr: IpAddr,
        rate: f64,
        capacity: usize,
    ) {
        // Remove any stale entry for this IP so it is only tracked once.
        rate_heap.retain(|&Reverse((_, key))| key != ip_addr);

        let rate = rate as u64;
        if rate_heap.len() < capacity {
            rate_heap.push(Reverse((rate, ip_addr)));
        } else if let Some(&Reverse((smallest_score, _))) = rate_heap.peek() {
            if rate > smallest_score {
                rate_heap.pop();
                rate_heap.push(Reverse((rate, ip_addr)));
            }
        }
    }

    pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
        self.highest_rates
            .direct
            .iter()
            .map(|Reverse(v)| v)
            .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
            .copied()
    }

    pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
        self.highest_rates
            .proxied
            .iter()
            .map(|Reverse(v)| v)
            .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
            .copied()
    }

    fn rotate_window(&mut self) {
        self.current_sketch_index = (self.current_sketch_index + 1) % self.sketches.len();
        self.sketches[self.current_sketch_index].clear();
        self.last_reset_time = Instant::now();
    }
}

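/// A single observation of client traffic: the direct and/or proxied source IP
/// together with the error and spam weights assigned to the request.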
#[derive(Clone, Debug)]
pub struct TrafficTally {
    pub direct: Option<IpAddr>,
    pub through_fullnode: Option<IpAddr>,
    pub error_weight: Weight,
    pub spam_weight: Weight,
    pub timestamp: SystemTime,
}

impl TrafficTally {
    pub fn new(
        direct: Option<IpAddr>,
        through_fullnode: Option<IpAddr>,
        error_weight: Weight,
        spam_weight: Weight,
    ) -> Self {
        Self {
            direct,
            through_fullnode,
            error_weight,
            spam_weight,
            timestamp: SystemTime::now(),
        }
    }
}

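/// Outcome of handling a tally: the direct and/or proxied client IPs that
/// should now be blocked, if any.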
#[derive(Clone, Debug, Default)]
pub struct PolicyResponse {
    pub block_client: Option<IpAddr>,
    pub block_proxied_client: Option<IpAddr>,
}

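/// Interface implemented by all traffic control policies. A policy consumes
/// traffic tallies and decides which clients, if any, to block.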
pub trait Policy {
    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse;
    fn policy_config(&self) -> &PolicyConfig;
}

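/// The available traffic control policies. The `Test*` variants are intended
/// for tests only.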
pub enum TrafficControlPolicy {
    FreqThreshold(FreqThresholdPolicy),
    NoOp(NoOpPolicy),
    TestNConnIP(TestNConnIPPolicy),
    TestPanicOnInvocation(TestPanicOnInvocationPolicy),
}

impl Policy for TrafficControlPolicy {
    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
        match self {
            TrafficControlPolicy::NoOp(policy) => policy.handle_tally(tally),
            TrafficControlPolicy::FreqThreshold(policy) => policy.handle_tally(tally),
            TrafficControlPolicy::TestNConnIP(policy) => policy.handle_tally(tally),
            TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.handle_tally(tally),
        }
    }

    fn policy_config(&self) -> &PolicyConfig {
        match self {
            TrafficControlPolicy::NoOp(policy) => policy.policy_config(),
            TrafficControlPolicy::FreqThreshold(policy) => policy.policy_config(),
            TrafficControlPolicy::TestNConnIP(policy) => policy.policy_config(),
            TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.policy_config(),
        }
    }
}

impl TrafficControlPolicy {
    pub async fn from_spam_config(policy_config: PolicyConfig) -> Self {
        Self::from_config(policy_config.clone().spam_policy_type, policy_config).await
    }

    pub async fn from_error_config(policy_config: PolicyConfig) -> Self {
        Self::from_config(policy_config.clone().error_policy_type, policy_config).await
    }

    pub async fn from_config(policy_type: PolicyType, policy_config: PolicyConfig) -> Self {
        match policy_type {
            PolicyType::NoOp => Self::NoOp(NoOpPolicy::new(policy_config)),
            PolicyType::FreqThreshold(freq_threshold_config) => Self::FreqThreshold(
                FreqThresholdPolicy::new(policy_config, freq_threshold_config),
            ),
            PolicyType::TestNConnIP(n) => {
                Self::TestNConnIP(TestNConnIPPolicy::new(policy_config, n).await)
            }
            PolicyType::TestPanicOnInvocation => {
                Self::TestPanicOnInvocation(TestPanicOnInvocationPolicy::new(policy_config))
            }
        }
    }
}

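/// Blocks clients whose request rate over the sliding window meets or exceeds
/// the configured threshold. Direct and proxied clients are tracked and
/// thresholded independently.
///
/// A minimal construction sketch (illustrative thresholds only; see the tests
/// below for a full scenario):
///
/// ```ignore
/// let policy = FreqThresholdPolicy::new(
///     PolicyConfig::default(),
///     FreqThresholdConfig {
///         client_threshold: 5,
///         proxied_client_threshold: 2,
///         ..Default::default()
///     },
/// );
/// ```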
pub struct FreqThresholdPolicy {
    config: PolicyConfig,
    sketch: TrafficSketch,
    client_threshold: u64,
    proxied_client_threshold: u64,
}

impl FreqThresholdPolicy {
    pub fn new(
        config: PolicyConfig,
        FreqThresholdConfig {
            client_threshold,
            proxied_client_threshold,
            window_size_secs,
            update_interval_secs,
            sketch_capacity,
            sketch_probability,
            sketch_tolerance,
        }: FreqThresholdConfig,
    ) -> Self {
        let sketch = TrafficSketch::new(
            Duration::from_secs(window_size_secs),
            Duration::from_secs(update_interval_secs),
            sketch_capacity,
            sketch_probability,
            sketch_tolerance,
            HIGHEST_RATES_CAPACITY,
        );
        Self {
            config,
            sketch,
            client_threshold,
            proxied_client_threshold,
        }
    }

    pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
        self.sketch.highest_direct_rate()
    }

    pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
        self.sketch.highest_proxied_rate()
    }

    pub fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
        let block_client = if let Some(source) = tally.direct {
            let key = SketchKey(source, ClientType::Direct);
            self.sketch.increment_count(&key);
            if self.sketch.get_request_rate(&key) >= self.client_threshold as f64 {
                Some(source)
            } else {
                None
            }
        } else {
            None
        };
        let block_proxied_client = if let Some(source) = tally.through_fullnode {
            let key = SketchKey(source, ClientType::ThroughFullnode);
            self.sketch.increment_count(&key);
            if self.sketch.get_request_rate(&key) >= self.proxied_client_threshold as f64 {
                Some(source)
            } else {
                None
            }
        } else {
            None
        };
        PolicyResponse {
            block_client,
            block_proxied_client,
        }
    }

    fn policy_config(&self) -> &PolicyConfig {
        &self.config
    }
}

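/// Policy that never blocks any client.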
#[derive(Clone)]
pub struct NoOpPolicy {
    config: PolicyConfig,
}

impl NoOpPolicy {
    pub fn new(config: PolicyConfig) -> Self {
        Self { config }
    }

    fn handle_tally(&mut self, _tally: TrafficTally) -> PolicyResponse {
        PolicyResponse::default()
    }

    fn policy_config(&self) -> &PolicyConfig {
        &self.config
    }
}

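/// Test policy that blocks a direct client once it has produced `threshold`
/// tallies. Counts are cleared periodically by a background task.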
#[derive(Clone)]
pub struct TestNConnIPPolicy {
    config: PolicyConfig,
    frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>,
    threshold: u64,
}

impl TestNConnIPPolicy {
    pub async fn new(config: PolicyConfig, threshold: u64) -> Self {
        let frequencies = Arc::new(RwLock::new(HashMap::new()));
        let frequencies_clone = frequencies.clone();
        spawn_monitored_task!(run_clear_frequencies(
            frequencies_clone,
            config.connection_blocklist_ttl_sec * 2,
        ));
        Self {
            config,
            frequencies,
            threshold,
        }
    }

    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
        let client = if let Some(client) = tally.direct {
            client
        } else {
            return PolicyResponse::default();
        };

        // Increment the count for this client and block it once the threshold
        // is reached.
        let mut frequencies = self.frequencies.write();
        let count = frequencies.entry(client).or_insert(0);
        *count += 1;
        PolicyResponse {
            block_client: if *count >= self.threshold {
                Some(client)
            } else {
                None
            },
            block_proxied_client: None,
        }
    }

    fn policy_config(&self) -> &PolicyConfig {
        &self.config
    }
}

async fn run_clear_frequencies(frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>, window_secs: u64) {
    loop {
        tokio::time::sleep(tokio::time::Duration::from_secs(window_secs)).await;
        frequencies.write().clear();
    }
}

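/// Test policy that panics if it ever receives a tally.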
#[derive(Clone)]
pub struct TestPanicOnInvocationPolicy {
    config: PolicyConfig,
}

impl TestPanicOnInvocationPolicy {
    pub fn new(config: PolicyConfig) -> Self {
        Self { config }
    }

    fn handle_tally(&mut self, _: TrafficTally) -> PolicyResponse {
        panic!("Tally for this policy should never be invoked")
    }

    fn policy_config(&self) -> &PolicyConfig {
        &self.config
    }
}

#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use iota_macros::sim_test;
    use iota_types::traffic_control::{
        DEFAULT_SKETCH_CAPACITY, DEFAULT_SKETCH_PROBABILITY, DEFAULT_SKETCH_TOLERANCE,
    };

    use super::*;

    #[sim_test]
    async fn test_freq_threshold_policy() {
        // Direct clients are blocked at a sustained rate of 5 req/s and
        // proxied clients at 2 req/s, measured over a 5 second window that
        // advances every second.
        let mut policy = FreqThresholdPolicy::new(
            PolicyConfig::default(),
            FreqThresholdConfig {
                client_threshold: 5,
                proxied_client_threshold: 2,
                window_size_secs: 5,
                update_interval_secs: 1,
                ..Default::default()
            },
        );
        // All three clients share the same direct IP, so direct counts
        // accumulate across them, while their proxied IPs differ.
        let alice = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))),
            error_weight: Weight::zero(),
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };
        let bob = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(4, 3, 2, 1))),
            error_weight: Weight::zero(),
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };
        let charlie = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8))),
            error_weight: Weight::zero(),
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };

        // Alice sends 2 requests; all rates stay below the thresholds.
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_proxied_client, None);
            assert_eq!(response.block_client, None);
        }

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert!(direct_rate < 1);
        assert_eq!(proxied_ip_addr, alice.through_fullnode.unwrap());
        assert!(proxied_rate < 1);

        // Bob sends 9 requests without being blocked; the 10th pushes his
        // proxied rate to the threshold.
        for _ in 0..9 {
            let response = policy.handle_tally(bob.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, bob.direct.unwrap());
        assert_eq!(direct_rate, 2);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // Let part of the window elapse. Earlier counts are still inside the
        // 5 second window, so Bob's proxied traffic remains over the threshold.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        let _ = policy.handle_tally(bob.clone());
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 3);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        for i in 0..5 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }

        // One more request from Alice pushes her proxied rate to the
        // threshold.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        let response = policy.handle_tally(alice.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, alice.through_fullnode);

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 4);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // Charlie's requests push the shared direct IP over the direct
        // threshold.
        for i in 0..2 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }
        let response = policy.handle_tally(charlie.clone());
        assert_eq!(response.block_proxied_client, None);
        assert_eq!(response.block_client, charlie.direct);

        // After another second, the oldest counts rotate out of the window
        // and Charlie is no longer blocked.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        for i in 0..3 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }

        // After a full window with no traffic, the sketch counts have expired.
        // The highest-rate heaps only refresh entries for clients that are
        // queried again, so Charlie's old proxied rate is still reported.
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
        let _ = policy.handle_tally(alice.clone());
        let _ = policy.handle_tally(bob.clone());
        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 0);
        assert_eq!(proxied_ip_addr, charlie.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 1);
    }

    #[sim_test]
    async fn test_traffic_sketch_mem_estimate() {
        // Verify that the default sketch parameters stay within the 128MB
        // budget for a typical window and update interval.
        let window_size = Duration::from_secs(30);
        let update_interval = Duration::from_secs(5);
        let mem_estimate = CountMinSketch32::<IpAddr>::estimate_memory(
            DEFAULT_SKETCH_CAPACITY,
            DEFAULT_SKETCH_PROBABILITY,
            DEFAULT_SKETCH_TOLERANCE,
        )
        .unwrap()
            * (window_size.as_secs() / update_interval.as_secs()) as usize;
        assert!(
            mem_estimate < 128_000_000,
            "Memory estimate {mem_estimate} for traffic sketch exceeds 128MB."
        );
    }
}