1use std::{
7 cmp::Reverse,
8 collections::{BinaryHeap, HashMap, VecDeque},
9 fmt::Debug,
10 hash::Hash,
11 net::IpAddr,
12 sync::Arc,
13 time::{Duration, Instant, SystemTime},
14};
15
16use count_min_sketch::CountMinSketch32;
17use iota_metrics::spawn_monitored_task;
18use iota_types::traffic_control::{FreqThresholdConfig, PolicyConfig, PolicyType, Weight};
19use parking_lot::RwLock;
20use tracing::{info, trace};
21
/// Capacity handed to `TrafficSketch::new` as `highest_rates_capacity` by
/// `FreqThresholdPolicy::new`; it bounds each of the highest-rate heaps.
const HIGHEST_RATES_CAPACITY: usize = 20;
23
/// Tag stored in a `SketchKey` so that the same IP address is counted
/// separately depending on which `TrafficTally` field it was taken from.
#[derive(Hash, Eq, PartialEq, Debug)]
enum ClientType {
    /// The address came from `TrafficTally::direct`.
    Direct,
    /// The address came from `TrafficTally::through_fullnode`.
    ThroughFullnode,
}
30
/// Key that is hashed into the count-min sketches of a `TrafficSketch`.
#[derive(Hash, Eq, PartialEq, Debug)]
struct SketchKey {
    /// Per-policy value; `FreqThresholdPolicy` fills it with the random salt it
    /// generated at construction time.
    salt: u64,
    /// Address of the client being counted.
    ip_addr: IpAddr,
    /// Whether the address was seen as a direct or a proxied client.
    client_type: ClientType,
}
37
/// Bounded record of the highest `(rate, ip)` pairs observed, kept separately
/// for direct and proxied clients.
///
/// Entries are wrapped in `Reverse`, so the top of each `BinaryHeap` is the
/// smallest recorded rate, which is the entry evicted when a larger rate arrives.
struct HighestRates {
    /// Rates recorded for keys with `ClientType::Direct`.
    direct: BinaryHeap<Reverse<(u64, IpAddr)>>,
    /// Rates recorded for keys with `ClientType::ThroughFullnode`.
    proxied: BinaryHeap<Reverse<(u64, IpAddr)>>,
    /// Maximum number of entries kept in each heap.
    capacity: usize,
}
43
/// Sliding-window request counter built from a ring of count-min sketches,
/// one per update interval.
pub struct TrafficSketch {
    /// One sketch per update interval; together they cover `window_size`.
    sketches: VecDeque<CountMinSketch32<SketchKey>>,
    /// Total window length (rounded down to a multiple of `update_interval`).
    window_size: Duration,
    /// Time span covered by a single sketch.
    update_interval: Duration,
    /// Instant of the most recent window rotation (or of construction).
    last_reset_time: Instant,
    /// Index into `sketches` of the sketch currently receiving increments.
    current_sketch_index: usize,
    /// Highest rates computed so far, split by client type.
    highest_rates: HighestRates,
}
70
71impl TrafficSketch {
72 pub fn new(
73 window_size: Duration,
74 update_interval: Duration,
75 sketch_capacity: usize,
76 sketch_probability: f64,
77 sketch_tolerance: f64,
78 highest_rates_capacity: usize,
79 ) -> Self {
80 let num_sketches = window_size.as_secs() / update_interval.as_secs();
82 let new_window_size = Duration::from_secs(num_sketches * update_interval.as_secs());
83 if new_window_size != window_size {
84 info!(
85 "Rounding traffic sketch window size down to {} seconds to make it an integer multiple of update interval {} seconds.",
86 new_window_size.as_secs(),
87 update_interval.as_secs(),
88 );
89 }
90 let window_size = new_window_size;
91
92 assert!(
93 window_size < Duration::from_secs(600),
94 "window_size too large. Max 600 seconds"
95 );
96 assert!(
97 update_interval < window_size,
98 "Update interval may not be larger than window size"
99 );
100 assert!(
101 update_interval >= Duration::from_secs(1),
102 "Update interval too short, must be at least 1 second"
103 );
104 assert!(
105 num_sketches <= 10,
106 "Given parameters require too many sketches to be stored. Reduce window size or increase update interval."
107 );
108 let mem_estimate = (num_sketches as usize)
109 * CountMinSketch32::<IpAddr>::estimate_memory(
110 sketch_capacity,
111 sketch_probability,
112 sketch_tolerance,
113 )
114 .expect("Failed to estimate memory for CountMinSketch32");
115 assert!(
116 mem_estimate < 128_000_000,
117 "Memory estimate for traffic sketch exceeds 128MB. Reduce window size or increase update interval."
118 );
119
120 let mut sketches = VecDeque::with_capacity(num_sketches as usize);
121 for _ in 0..num_sketches {
122 sketches.push_back(
123 CountMinSketch32::<SketchKey>::new(
124 sketch_capacity,
125 sketch_probability,
126 sketch_tolerance,
127 )
128 .expect("Failed to create CountMinSketch32"),
129 );
130 }
131 Self {
132 sketches,
133 window_size,
134 update_interval,
135 last_reset_time: Instant::now(),
136 current_sketch_index: 0,
137 highest_rates: HighestRates {
138 direct: BinaryHeap::with_capacity(highest_rates_capacity),
139 proxied: BinaryHeap::with_capacity(highest_rates_capacity),
140 capacity: highest_rates_capacity,
141 },
142 }
143 }
144
145 fn increment_count(&mut self, key: &SketchKey) {
146 let current_time = Instant::now();
148 let mut elapsed = current_time.duration_since(self.last_reset_time);
149 while elapsed >= self.update_interval {
150 self.rotate_window();
151 elapsed -= self.update_interval;
152 }
153 self.sketches[self.current_sketch_index].increment(key);
155 }
156
157 fn get_request_rate(&mut self, key: &SketchKey) -> f64 {
158 let count: u32 = self
159 .sketches
160 .iter()
161 .map(|sketch| sketch.estimate(key))
162 .sum();
163 let rate = count as f64 / self.window_size.as_secs() as f64;
164 self.update_highest_rates(key, rate);
165 rate
166 }
167
168 fn update_highest_rates(&mut self, key: &SketchKey, rate: f64) {
169 match key.client_type {
170 ClientType::Direct => {
171 Self::update_highest_rate(
172 &mut self.highest_rates.direct,
173 key.ip_addr,
174 rate,
175 self.highest_rates.capacity,
176 );
177 }
178 ClientType::ThroughFullnode => {
179 Self::update_highest_rate(
180 &mut self.highest_rates.proxied,
181 key.ip_addr,
182 rate,
183 self.highest_rates.capacity,
184 );
185 }
186 }
187 }
188
189 fn update_highest_rate(
190 rate_heap: &mut BinaryHeap<Reverse<(u64, IpAddr)>>,
191 ip_addr: IpAddr,
192 rate: f64,
193 capacity: usize,
194 ) {
195 rate_heap.retain(|&Reverse((_, key))| key != ip_addr);
198
199 let rate = rate as u64;
200 if rate_heap.len() < capacity {
201 rate_heap.push(Reverse((rate, ip_addr)));
202 } else if let Some(&Reverse((smallest_score, _))) = rate_heap.peek() {
203 if rate > smallest_score {
204 rate_heap.pop();
205 rate_heap.push(Reverse((rate, ip_addr)));
206 }
207 }
208 }
209
210 pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
211 self.highest_rates
212 .direct
213 .iter()
214 .map(|Reverse(v)| v)
215 .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
216 .copied()
217 }
218
219 pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
220 self.highest_rates
221 .proxied
222 .iter()
223 .map(|Reverse(v)| v)
224 .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
225 .copied()
226 }
227
228 fn rotate_window(&mut self) {
229 self.current_sketch_index = (self.current_sketch_index + 1) % self.sketches.len();
230 self.sketches[self.current_sketch_index].clear();
231 self.last_reset_time = Instant::now();
232 }
233}
234
/// A single traffic observation handed to a `Policy`.
#[derive(Clone, Debug)]
pub struct TrafficTally {
    /// Address of the directly connected client, if known.
    pub direct: Option<IpAddr>,
    /// Address of the client behind a fullnode, if known.
    pub through_fullnode: Option<IpAddr>,
    /// Optional weight and message for an error; not read by the policies in
    /// this file.
    pub error_info: Option<(Weight, String)>,
    /// Spam weight of the observation; not read by the policies in this file.
    pub spam_weight: Weight,
    /// Wall-clock time at which the tally was created.
    pub timestamp: SystemTime,
}
243
244impl TrafficTally {
245 pub fn new(
246 direct: Option<IpAddr>,
247 through_fullnode: Option<IpAddr>,
248 error_info: Option<(Weight, String)>,
249 spam_weight: Weight,
250 ) -> Self {
251 Self {
252 direct,
253 through_fullnode,
254 error_info,
255 spam_weight,
256 timestamp: SystemTime::now(),
257 }
258 }
259}
260
/// Outcome of handling one tally: which addresses, if any, should be blocked.
/// The default value blocks nothing.
#[derive(Clone, Debug, Default)]
pub struct PolicyResponse {
    /// Direct client address to block, if any.
    pub block_client: Option<IpAddr>,
    /// Proxied client address to block, if any.
    pub block_proxied_client: Option<IpAddr>,
}
266
/// Common interface of the traffic-control policies.
pub trait Policy {
    /// Processes one tally and reports which addresses should be blocked.
    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse;
    /// Returns the configuration this policy was created with.
    fn policy_config(&self) -> &PolicyConfig;
}
273
/// Closed set of available policies; its `Policy` implementation dispatches to
/// the wrapped concrete policy.
pub enum TrafficControlPolicy {
    /// Blocks clients whose estimated request rate reaches a threshold.
    FreqThreshold(FreqThresholdPolicy),
    /// Never blocks anything.
    NoOp(NoOpPolicy),
    /// Blocks a direct client once its tally count reaches a fixed threshold.
    TestNConnIP(TestNConnIPPolicy),
    /// Panics whenever a tally is handled.
    TestPanicOnInvocation(TestPanicOnInvocationPolicy),
}
283
284impl Policy for TrafficControlPolicy {
285 fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
286 match self {
287 TrafficControlPolicy::NoOp(policy) => policy.handle_tally(tally),
288 TrafficControlPolicy::FreqThreshold(policy) => policy.handle_tally(tally),
289 TrafficControlPolicy::TestNConnIP(policy) => policy.handle_tally(tally),
290 TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.handle_tally(tally),
291 }
292 }
293
294 fn policy_config(&self) -> &PolicyConfig {
295 match self {
296 TrafficControlPolicy::NoOp(policy) => policy.policy_config(),
297 TrafficControlPolicy::FreqThreshold(policy) => policy.policy_config(),
298 TrafficControlPolicy::TestNConnIP(policy) => policy.policy_config(),
299 TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.policy_config(),
300 }
301 }
302}
303
304impl TrafficControlPolicy {
305 pub async fn from_spam_config(policy_config: PolicyConfig) -> Self {
306 Self::from_config(policy_config.clone().spam_policy_type, policy_config).await
307 }
308 pub async fn from_error_config(policy_config: PolicyConfig) -> Self {
309 Self::from_config(policy_config.clone().error_policy_type, policy_config).await
310 }
311 pub async fn from_config(policy_type: PolicyType, policy_config: PolicyConfig) -> Self {
312 match policy_type {
313 PolicyType::NoOp => Self::NoOp(NoOpPolicy::new(policy_config)),
314 PolicyType::FreqThreshold(freq_threshold_config) => Self::FreqThreshold(
315 FreqThresholdPolicy::new(policy_config, freq_threshold_config),
316 ),
317 PolicyType::TestNConnIP(n) => {
318 Self::TestNConnIP(TestNConnIPPolicy::new(policy_config, n).await)
319 }
320 PolicyType::TestPanicOnInvocation => {
321 Self::TestPanicOnInvocation(TestPanicOnInvocationPolicy::new(policy_config))
322 }
323 }
324 }
325}
326
/// Policy that blocks a client once its estimated request rate, tracked by a
/// `TrafficSketch`, reaches a configured threshold.
pub struct FreqThresholdPolicy {
    /// Configuration returned by `policy_config`.
    config: PolicyConfig,
    /// Sliding-window counter used to estimate per-client request rates.
    sketch: TrafficSketch,
    /// Rate at or above which a direct client is reported for blocking.
    client_threshold: u64,
    /// Rate at or above which a proxied client is reported for blocking.
    proxied_client_threshold: u64,
    /// Random value generated at construction and placed in every sketch key.
    salt: u64,
}
341
342impl FreqThresholdPolicy {
343 pub fn new(
344 config: PolicyConfig,
345 FreqThresholdConfig {
346 client_threshold,
347 proxied_client_threshold,
348 window_size_secs,
349 update_interval_secs,
350 sketch_capacity,
351 sketch_probability,
352 sketch_tolerance,
353 }: FreqThresholdConfig,
354 ) -> Self {
355 let sketch = TrafficSketch::new(
356 Duration::from_secs(window_size_secs),
357 Duration::from_secs(update_interval_secs),
358 sketch_capacity,
359 sketch_probability,
360 sketch_tolerance,
361 HIGHEST_RATES_CAPACITY,
362 );
363 Self {
364 config,
365 sketch,
366 client_threshold,
367 proxied_client_threshold,
368 salt: rand::random(),
369 }
370 }
371
372 pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
373 self.sketch.highest_direct_rate()
374 }
375
376 pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
377 self.sketch.highest_proxied_rate()
378 }
379
380 pub fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
381 let block_client = if let Some(source) = tally.direct {
382 let key = SketchKey {
383 salt: self.salt,
384 ip_addr: source,
385 client_type: ClientType::Direct,
386 };
387 self.sketch.increment_count(&key);
388 let req_rate = self.sketch.get_request_rate(&key);
389 trace!(
390 "FreqThresholdPolicy handling tally -- req_rate: {:?}, client_threshold: {:?}, client: {:?}",
391 req_rate, self.client_threshold, source,
392 );
393 if req_rate >= self.client_threshold as f64 {
394 Some(source)
395 } else {
396 None
397 }
398 } else {
399 None
400 };
401 let block_proxied_client = if let Some(source) = tally.through_fullnode {
402 let key = SketchKey {
403 salt: self.salt,
404 ip_addr: source,
405 client_type: ClientType::ThroughFullnode,
406 };
407 self.sketch.increment_count(&key);
408 if self.sketch.get_request_rate(&key) >= self.proxied_client_threshold as f64 {
409 Some(source)
410 } else {
411 None
412 }
413 } else {
414 None
415 };
416 PolicyResponse {
417 block_client,
418 block_proxied_client,
419 }
420 }
421
422 fn policy_config(&self) -> &PolicyConfig {
423 &self.config
424 }
425}
426
/// Policy that never asks for any client to be blocked.
#[derive(Clone)]
pub struct NoOpPolicy {
    /// Configuration returned by `policy_config`.
    config: PolicyConfig,
}
433
434impl NoOpPolicy {
435 pub fn new(config: PolicyConfig) -> Self {
436 Self { config }
437 }
438
439 fn handle_tally(&mut self, _tally: TrafficTally) -> PolicyResponse {
440 PolicyResponse::default()
441 }
442
443 fn policy_config(&self) -> &PolicyConfig {
444 &self.config
445 }
446}
447
/// Test policy that blocks a direct client once the number of tallies seen
/// from it reaches `threshold`.
#[derive(Clone)]
pub struct TestNConnIPPolicy {
    /// Configuration returned by `policy_config`.
    config: PolicyConfig,
    /// Per-address tally counts, shared with the background task that
    /// periodically clears them.
    frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>,
    /// Count at or above which a client is reported for blocking.
    threshold: u64,
}
454
455impl TestNConnIPPolicy {
456 pub async fn new(config: PolicyConfig, threshold: u64) -> Self {
457 let frequencies = Arc::new(RwLock::new(HashMap::new()));
458 let frequencies_clone = frequencies.clone();
459 spawn_monitored_task!(run_clear_frequencies(
460 frequencies_clone,
461 config.connection_blocklist_ttl_sec * 2,
462 ));
463 Self {
464 config,
465 frequencies,
466 threshold,
467 }
468 }
469
470 fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
471 let client = if let Some(client) = tally.direct {
472 client
473 } else {
474 return PolicyResponse::default();
475 };
476
477 let mut frequencies = self.frequencies.write();
479 let count = frequencies.entry(client).or_insert(0);
480 *count += 1;
481 PolicyResponse {
482 block_client: if *count >= self.threshold {
483 Some(client)
484 } else {
485 None
486 },
487 block_proxied_client: None,
488 }
489 }
490
491 fn policy_config(&self) -> &PolicyConfig {
492 &self.config
493 }
494}
495
496async fn run_clear_frequencies(frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>, window_secs: u64) {
497 loop {
498 tokio::time::sleep(tokio::time::Duration::from_secs(window_secs)).await;
499 frequencies.write().clear();
500 }
501}
502
/// Test policy whose `handle_tally` panics when called.
#[derive(Clone)]
pub struct TestPanicOnInvocationPolicy {
    /// Configuration returned by `policy_config`.
    config: PolicyConfig,
}
507
508impl TestPanicOnInvocationPolicy {
509 pub fn new(config: PolicyConfig) -> Self {
510 Self { config }
511 }
512
513 fn handle_tally(&mut self, _: TrafficTally) -> PolicyResponse {
514 panic!("Tally for this policy should never be invoked")
515 }
516
517 fn policy_config(&self) -> &PolicyConfig {
518 &self.config
519 }
520}
521
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use iota_macros::sim_test;
    use iota_types::traffic_control::{
        DEFAULT_SKETCH_CAPACITY, DEFAULT_SKETCH_PROBABILITY, DEFAULT_SKETCH_TOLERANCE,
    };

    use super::*;

    /// Walks a `FreqThresholdPolicy` through several bursts of tallies separated
    /// by sleeps and checks both the blocking decisions and the reported
    /// highest rates.
    #[sim_test]
    async fn test_freq_threshold_policy() {
        // 5-second window split into 1-second sketches; direct clients are
        // blocked at a rate of 5, proxied clients at a rate of 2.
        let mut policy = FreqThresholdPolicy::new(
            PolicyConfig::default(),
            FreqThresholdConfig {
                client_threshold: 5,
                proxied_client_threshold: 2,
                window_size_secs: 5,
                update_interval_secs: 1,
                ..Default::default()
            },
        );
        // All three tallies share the same direct address and differ only in
        // their proxied address.
        let alice = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };
        let bob = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(4, 3, 2, 1))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };
        let charlie = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };

        // Two tallies over a 5-second window stay below both thresholds.
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_proxied_client, None);
            assert_eq!(response.block_client, None);
        }

        // 2 requests / 5 seconds truncates to a recorded rate of 0.
        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert!(direct_rate < 1);
        assert_eq!(proxied_ip_addr, alice.through_fullnode.unwrap());
        assert!(proxied_rate < 1);

        // Nine tallies through bob's proxied address stay below the proxied
        // threshold ...
        for _ in 0..9 {
            let response = policy.handle_tally(bob.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        // ... and the tenth reaches it (10 / 5 = 2).
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, bob.direct.unwrap());
        assert_eq!(direct_rate, 2);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // Advance time by two update intervals; presumably the earlier counts
        // are still inside the 5-second window.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        // bob's proxied address is expected to remain at or above its threshold.
        let _ = policy.handle_tally(bob.clone());
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 3);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // One more interval, then five more tallies for alice: still not blocked.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        for i in 0..5 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {i}");
            assert_eq!(response.block_proxied_client, None);
        }

        // The next tally for alice is expected to reach the proxied threshold.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        let response = policy.handle_tally(alice.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, alice.through_fullnode);

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 4);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // The direct address is shared by every tally, so its count keeps
        // growing; the third tally from charlie is expected to reach the direct
        // threshold.
        for i in 0..2 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {i}");
            assert_eq!(response.block_proxied_client, None);
        }
        let response = policy.handle_tally(charlie.clone());
        assert_eq!(response.block_proxied_client, None);
        assert_eq!(response.block_client, charlie.direct);

        // After another interval the direct address is expected to be below the
        // threshold again (presumably because the oldest sketch was rotated
        // out).
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        for i in 0..3 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {i}");
            assert_eq!(response.block_proxied_client, None);
        }

        // Sleep for a full window. Fresh tallies then yield a direct rate of 0,
        // while the proxied maximum is still charlie's entry recorded before the
        // sleep: heap entries are only replaced when their address is tallied
        // again.
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
        let _ = policy.handle_tally(alice.clone());
        let _ = policy.handle_tally(bob.clone());
        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 0);
        assert_eq!(proxied_ip_addr, charlie.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 1);
    }

    /// Checks that the default sketch parameters, with a 30-second window and a
    /// 5-second update interval, stay under the 128MB limit.
    #[sim_test]
    async fn test_traffic_sketch_mem_estimate() {
        let window_size = Duration::from_secs(30);
        let update_interval = Duration::from_secs(5);
        // Per-sketch estimate multiplied by the number of sketches in the window.
        let mem_estimate = CountMinSketch32::<IpAddr>::estimate_memory(
            DEFAULT_SKETCH_CAPACITY,
            DEFAULT_SKETCH_PROBABILITY,
            DEFAULT_SKETCH_TOLERANCE,
        )
        .unwrap()
            * (window_size.as_secs() / update_interval.as_secs()) as usize;
        assert!(
            mem_estimate < 128_000_000,
            "Memory estimate {mem_estimate} for traffic sketch exceeds 128MB."
        );
    }
}
706}