iota_core/traffic_controller/policies.rs

// Copyright (c) 2021, Facebook, Inc. and its affiliates
// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use std::{
    cmp::Reverse,
    collections::{BinaryHeap, HashMap, VecDeque},
    fmt::Debug,
    hash::Hash,
    net::IpAddr,
    sync::Arc,
    time::{Duration, Instant, SystemTime},
};

use count_min_sketch::CountMinSketch32;
use iota_metrics::spawn_monitored_task;
use iota_types::traffic_control::{FreqThresholdConfig, PolicyConfig, PolicyType, Weight};
use parking_lot::RwLock;
use tracing::{info, trace};

const HIGHEST_RATES_CAPACITY: usize = 20;

/// The type of request client.
#[derive(Hash, Eq, PartialEq, Debug)]
enum ClientType {
    Direct,
    ThroughFullnode,
}

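/// Key used to count requests in the sketch: the client IP address and
/// client type, mixed with a per-node random salt so that count-min
/// collisions (false positives) are not correlated across nodes.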
#[derive(Hash, Eq, PartialEq, Debug)]
struct SketchKey {
    salt: u64,
    ip_addr: IpAddr,
    client_type: ClientType,
}

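/// Bounded min-heaps of the highest recently observed request rates per
/// unique IP address, kept separately for direct and proxied clients.
/// Used only for metrics collection and logging.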
struct HighestRates {
    direct: BinaryHeap<Reverse<(u64, IpAddr)>>,
    proxied: BinaryHeap<Reverse<(u64, IpAddr)>>,
    capacity: usize,
}

pub struct TrafficSketch {
    /// Circular buffer of Count Min Sketches representing a sliding window
    /// of traffic data. Note that the 32 in CountMinSketch32 represents
    /// the number of bits used to represent the count in the sketch. Since
    /// we only count on a sketch for a window of `update_interval`, we only
    /// need enough precision to represent the max expected unique IP addresses
    /// we may see in that window. For a 10 second period, we might
    /// conservatively expect 100,000, which can be represented in 17 bits,
    /// but not 16. We can potentially lower the memory consumption by using
    /// CountMinSketch16, which will reliably support up to ~65,000 unique
    /// IP addresses in the window.
    sketches: VecDeque<CountMinSketch32<SketchKey>>,
    window_size: Duration,
    update_interval: Duration,
    last_reset_time: Instant,
    current_sketch_index: usize,
    /// Used for metrics collection and logging purposes,
    /// as CountMinSketch does not provide this directly.
    /// Note that this is an imperfect metric, since we preserve
    /// the highest N rates (by unique IP) that we have seen,
    /// but update rates (down or up) as they change, so that
    /// the metric is not monotonic and reflects recent traffic.
    /// However, this should only lead to inaccuracy in edge cases
    /// with very low traffic.
    highest_rates: HighestRates,
}

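// Illustrative rate math (using the default-style parameters from the tests
// below): with window_size = 30s and update_interval = 5s the ring holds
// 30 / 5 = 6 sketches, and a key counted 90 times across the live sketches
// is reported as 90 / 30 = 3 requests per second.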
impl TrafficSketch {
    pub fn new(
        window_size: Duration,
        update_interval: Duration,
        sketch_capacity: usize,
        sketch_probability: f64,
        sketch_tolerance: f64,
        highest_rates_capacity: usize,
    ) -> Self {
        // Intentionally round down via integer division; we can't have a partial sketch.
        let num_sketches = window_size.as_secs() / update_interval.as_secs();
        let new_window_size = Duration::from_secs(num_sketches * update_interval.as_secs());
        if new_window_size != window_size {
            info!(
                "Rounding traffic sketch window size down to {} seconds to make it an integer multiple of update interval {} seconds.",
                new_window_size.as_secs(),
                update_interval.as_secs(),
            );
        }
        let window_size = new_window_size;

        assert!(
            window_size < Duration::from_secs(600),
            "window_size too large. Max 600 seconds"
        );
        assert!(
            update_interval < window_size,
            "Update interval may not be larger than window size"
        );
        assert!(
            update_interval >= Duration::from_secs(1),
            "Update interval too short, must be at least 1 second"
        );
        assert!(
            num_sketches <= 10,
            "Given parameters require too many sketches to be stored. Reduce window size or increase update interval."
        );
        let mem_estimate = (num_sketches as usize)
            * CountMinSketch32::<IpAddr>::estimate_memory(
                sketch_capacity,
                sketch_probability,
                sketch_tolerance,
            )
            .expect("Failed to estimate memory for CountMinSketch32");
        assert!(
            mem_estimate < 128_000_000,
            "Memory estimate for traffic sketch exceeds 128MB. Reduce window size or increase update interval."
        );

        let mut sketches = VecDeque::with_capacity(num_sketches as usize);
        for _ in 0..num_sketches {
            sketches.push_back(
                CountMinSketch32::<SketchKey>::new(
                    sketch_capacity,
                    sketch_probability,
                    sketch_tolerance,
                )
                .expect("Failed to create CountMinSketch32"),
            );
        }
        Self {
            sketches,
            window_size,
            update_interval,
            last_reset_time: Instant::now(),
            current_sketch_index: 0,
            highest_rates: HighestRates {
                direct: BinaryHeap::with_capacity(highest_rates_capacity),
                proxied: BinaryHeap::with_capacity(highest_rates_capacity),
                capacity: highest_rates_capacity,
            },
        }
    }

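    /// Record one occurrence of `key` in the sketch for the current update
    /// interval, rotating out any expired intervals first.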
    fn increment_count(&mut self, key: &SketchKey) {
        // reset all expired intervals
        let current_time = Instant::now();
        let mut elapsed = current_time.duration_since(self.last_reset_time);
        while elapsed >= self.update_interval {
            self.rotate_window();
            elapsed -= self.update_interval;
        }
        // Increment in the current active sketch
        self.sketches[self.current_sketch_index].increment(key);
    }

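    /// Estimate the request rate for `key` in requests per second by summing
    /// the count-min estimates across all sketches in the sliding window and
    /// dividing by the window size.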
    fn get_request_rate(&mut self, key: &SketchKey) -> f64 {
        let count: u32 = self
            .sketches
            .iter()
            .map(|sketch| sketch.estimate(key))
            .sum();
        let rate = count as f64 / self.window_size.as_secs() as f64;
        self.update_highest_rates(key, rate);
        rate
    }

    fn update_highest_rates(&mut self, key: &SketchKey, rate: f64) {
        match key.client_type {
            ClientType::Direct => {
                Self::update_highest_rate(
                    &mut self.highest_rates.direct,
                    key.ip_addr,
                    rate,
                    self.highest_rates.capacity,
                );
            }
            ClientType::ThroughFullnode => {
                Self::update_highest_rate(
                    &mut self.highest_rates.proxied,
                    key.ip_addr,
                    rate,
                    self.highest_rates.capacity,
                );
            }
        }
    }

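    /// Insert or refresh `ip_addr`'s rate in the bounded min-heap. Once the
    /// heap is at `capacity`, the smallest entry is evicted only if the new
    /// rate is larger.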
    fn update_highest_rate(
        rate_heap: &mut BinaryHeap<Reverse<(u64, IpAddr)>>,
        ip_addr: IpAddr,
        rate: f64,
        capacity: usize,
    ) {
        // Remove any previous entry for this IP address so that we
        // can update it with the new rate
        rate_heap.retain(|&Reverse((_, key))| key != ip_addr);

        let rate = rate as u64;
        if rate_heap.len() < capacity {
            rate_heap.push(Reverse((rate, ip_addr)));
        } else if let Some(&Reverse((smallest_score, _))) = rate_heap.peek() {
            if rate > smallest_score {
                rate_heap.pop();
                rate_heap.push(Reverse((rate, ip_addr)));
            }
        }
    }

    pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
        self.highest_rates
            .direct
            .iter()
            .map(|Reverse(v)| v)
            .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
            .copied()
    }

    pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
        self.highest_rates
            .proxied
            .iter()
            .map(|Reverse(v)| v)
            .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
            .copied()
    }

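    /// Advance to the next sketch in the ring and clear it so that it becomes
    /// the active sketch for the new update interval.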
    fn rotate_window(&mut self) {
        self.current_sketch_index = (self.current_sketch_index + 1) % self.sketches.len();
        self.sketches[self.current_sketch_index].clear();
        self.last_reset_time = Instant::now();
    }
}

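/// A single observation of client traffic: the direct connection IP, the
/// originating client IP if the request was proxied through a fullnode,
/// optional error weight information, and the spam weight of the request.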
#[derive(Clone, Debug)]
pub struct TrafficTally {
    pub direct: Option<IpAddr>,
    pub through_fullnode: Option<IpAddr>,
    pub error_info: Option<(Weight, String)>,
    pub spam_weight: Weight,
    pub timestamp: SystemTime,
}

impl TrafficTally {
    pub fn new(
        direct: Option<IpAddr>,
        through_fullnode: Option<IpAddr>,
        error_info: Option<(Weight, String)>,
        spam_weight: Weight,
    ) -> Self {
        Self {
            direct,
            through_fullnode,
            error_info,
            spam_weight,
            timestamp: SystemTime::now(),
        }
    }
}

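/// The outcome of handling a tally: which connection IP and/or proxied client
/// IP, if any, should be added to the blocklists.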
#[derive(Clone, Debug, Default)]
pub struct PolicyResponse {
    pub block_client: Option<IpAddr>,
    pub block_proxied_client: Option<IpAddr>,
}

pub trait Policy {
    // Returns a PolicyResponse indicating, for example, that the connection IP
    // should be added to the blocklist while the proxied client IP should not.
    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse;
    fn policy_config(&self) -> &PolicyConfig;
}

// Non-serializable representation. Note that the inner types are not
// object safe, so we can't use a trait object instead.
pub enum TrafficControlPolicy {
    FreqThreshold(FreqThresholdPolicy),
    NoOp(NoOpPolicy),
    // Test policies below this point
    TestNConnIP(TestNConnIPPolicy),
    TestPanicOnInvocation(TestPanicOnInvocationPolicy),
}

impl Policy for TrafficControlPolicy {
    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
        match self {
            TrafficControlPolicy::NoOp(policy) => policy.handle_tally(tally),
            TrafficControlPolicy::FreqThreshold(policy) => policy.handle_tally(tally),
            TrafficControlPolicy::TestNConnIP(policy) => policy.handle_tally(tally),
            TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.handle_tally(tally),
        }
    }

    fn policy_config(&self) -> &PolicyConfig {
        match self {
            TrafficControlPolicy::NoOp(policy) => policy.policy_config(),
            TrafficControlPolicy::FreqThreshold(policy) => policy.policy_config(),
            TrafficControlPolicy::TestNConnIP(policy) => policy.policy_config(),
            TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.policy_config(),
        }
    }
}

impl TrafficControlPolicy {
    pub async fn from_spam_config(policy_config: PolicyConfig) -> Self {
        Self::from_config(policy_config.clone().spam_policy_type, policy_config).await
    }
    pub async fn from_error_config(policy_config: PolicyConfig) -> Self {
        Self::from_config(policy_config.clone().error_policy_type, policy_config).await
    }
    pub async fn from_config(policy_type: PolicyType, policy_config: PolicyConfig) -> Self {
        match policy_type {
            PolicyType::NoOp => Self::NoOp(NoOpPolicy::new(policy_config)),
            PolicyType::FreqThreshold(freq_threshold_config) => Self::FreqThreshold(
                FreqThresholdPolicy::new(policy_config, freq_threshold_config),
            ),
            PolicyType::TestNConnIP(n) => {
                Self::TestNConnIP(TestNConnIPPolicy::new(policy_config, n).await)
            }
            PolicyType::TestPanicOnInvocation => {
                Self::TestPanicOnInvocation(TestPanicOnInvocationPolicy::new(policy_config))
            }
        }
    }
}
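
// Illustrative usage sketch (not exercised anywhere in this file; the
// threshold and window values below are assumptions chosen only for the
// example, and the calls must run inside an async context):
//
//     let config = PolicyConfig::default();
//     let policy_type = PolicyType::FreqThreshold(FreqThresholdConfig {
//         client_threshold: 10,
//         proxied_client_threshold: 5,
//         window_size_secs: 30,
//         update_interval_secs: 5,
//         ..Default::default()
//     });
//     let mut policy = TrafficControlPolicy::from_config(policy_type, config).await;
//     let response = policy.handle_tally(TrafficTally::new(
//         Some("10.0.0.1".parse().unwrap()),
//         None,
//         None,
//         Weight::one(),
//     ));
//     // A single request is well below 10 req/s over a 30 second window.
//     assert!(response.block_client.is_none());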

////////////// *** Policy definitions *** //////////////

pub struct FreqThresholdPolicy {
    config: PolicyConfig,
    sketch: TrafficSketch,
    client_threshold: u64,
    proxied_client_threshold: u64,
    /// Unique salt to be added to all keys in the sketch. This
    /// ensures that false positives are not correlated across
    /// all nodes at the same time. For IOTA validators for example,
    /// this means that false positives should not prevent the network
    /// from achieving quorum.
    salt: u64,
}

impl FreqThresholdPolicy {
    pub fn new(
        config: PolicyConfig,
        FreqThresholdConfig {
            client_threshold,
            proxied_client_threshold,
            window_size_secs,
            update_interval_secs,
            sketch_capacity,
            sketch_probability,
            sketch_tolerance,
        }: FreqThresholdConfig,
    ) -> Self {
        let sketch = TrafficSketch::new(
            Duration::from_secs(window_size_secs),
            Duration::from_secs(update_interval_secs),
            sketch_capacity,
            sketch_probability,
            sketch_tolerance,
            HIGHEST_RATES_CAPACITY,
        );
        Self {
            config,
            sketch,
            client_threshold,
            proxied_client_threshold,
            salt: rand::random(),
        }
    }

    pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
        self.sketch.highest_direct_rate()
    }

    pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
        self.sketch.highest_proxied_rate()
    }

    pub fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
        let block_client = if let Some(source) = tally.direct {
            let key = SketchKey {
                salt: self.salt,
                ip_addr: source,
                client_type: ClientType::Direct,
            };
            self.sketch.increment_count(&key);
            let req_rate = self.sketch.get_request_rate(&key);
            trace!(
                "FreqThresholdPolicy handling tally -- req_rate: {:?}, client_threshold: {:?}, client: {:?}",
                req_rate, self.client_threshold, source,
            );
            if req_rate >= self.client_threshold as f64 {
                Some(source)
            } else {
                None
            }
        } else {
            None
        };
        let block_proxied_client = if let Some(source) = tally.through_fullnode {
            let key = SketchKey {
                salt: self.salt,
                ip_addr: source,
                client_type: ClientType::ThroughFullnode,
            };
            self.sketch.increment_count(&key);
            if self.sketch.get_request_rate(&key) >= self.proxied_client_threshold as f64 {
                Some(source)
            } else {
                None
            }
        } else {
            None
        };
        PolicyResponse {
            block_client,
            block_proxied_client,
        }
    }

    fn policy_config(&self) -> &PolicyConfig {
        &self.config
    }
}

////////////// *** NoOp and test policies below this point *** //////////////

#[derive(Clone)]
pub struct NoOpPolicy {
    config: PolicyConfig,
}

impl NoOpPolicy {
    pub fn new(config: PolicyConfig) -> Self {
        Self { config }
    }

    fn handle_tally(&mut self, _tally: TrafficTally) -> PolicyResponse {
        PolicyResponse::default()
    }

    fn policy_config(&self) -> &PolicyConfig {
        &self.config
    }
}

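/// Test policy that blocks a direct client once it has been tallied
/// `threshold` times, using an exact per-IP counter that a background task
/// clears periodically.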
#[derive(Clone)]
pub struct TestNConnIPPolicy {
    config: PolicyConfig,
    frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>,
    threshold: u64,
}

impl TestNConnIPPolicy {
    pub async fn new(config: PolicyConfig, threshold: u64) -> Self {
        let frequencies = Arc::new(RwLock::new(HashMap::new()));
        let frequencies_clone = frequencies.clone();
        spawn_monitored_task!(run_clear_frequencies(
            frequencies_clone,
            config.connection_blocklist_ttl_sec * 2,
        ));
        Self {
            config,
            frequencies,
            threshold,
        }
    }

    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
        let client = if let Some(client) = tally.direct {
            client
        } else {
            return PolicyResponse::default();
        };

        // increment the count for the IP
        let mut frequencies = self.frequencies.write();
        let count = frequencies.entry(client).or_insert(0);
        *count += 1;
        PolicyResponse {
            block_client: if *count >= self.threshold {
                Some(client)
            } else {
                None
            },
            block_proxied_client: None,
        }
    }

    fn policy_config(&self) -> &PolicyConfig {
        &self.config
    }
}

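/// Background task that clears the per-IP counters every `window_secs`
/// seconds.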
async fn run_clear_frequencies(frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>, window_secs: u64) {
    loop {
        tokio::time::sleep(tokio::time::Duration::from_secs(window_secs)).await;
        frequencies.write().clear();
    }
}

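/// Test policy that panics if `handle_tally` is ever invoked.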
#[derive(Clone)]
pub struct TestPanicOnInvocationPolicy {
    config: PolicyConfig,
}

impl TestPanicOnInvocationPolicy {
    pub fn new(config: PolicyConfig) -> Self {
        Self { config }
    }

    fn handle_tally(&mut self, _: TrafficTally) -> PolicyResponse {
        panic!("Tally for this policy should never be invoked")
    }

    fn policy_config(&self) -> &PolicyConfig {
        &self.config
    }
}

#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use iota_macros::sim_test;
    use iota_types::traffic_control::{
        DEFAULT_SKETCH_CAPACITY, DEFAULT_SKETCH_PROBABILITY, DEFAULT_SKETCH_TOLERANCE,
    };

    use super::*;

    #[sim_test]
    async fn test_freq_threshold_policy() {
        // Create a freq policy that will block at an average frequency of 2
        // requests per second for proxied connections and 5 requests per second
        // for direct connections, as observed over a 5 second window.
        let mut policy = FreqThresholdPolicy::new(
            PolicyConfig::default(),
            FreqThresholdConfig {
                client_threshold: 5,
                proxied_client_threshold: 2,
                window_size_secs: 5,
                update_interval_secs: 1,
                ..Default::default()
            },
        );
        // alice, bob, and charlie connect from different IPs through the
        // same fullnode, and thus have the same connection IP on the
        // validator but different proxy IPs
        let alice = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };
        let bob = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(4, 3, 2, 1))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };
        let charlie = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };

        // initial 2 tallies for alice, should not block
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_proxied_client, None);
            assert_eq!(response.block_client, None);
        }

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert!(direct_rate < 1);
        assert_eq!(proxied_ip_addr, alice.through_fullnode.unwrap());
        assert!(proxied_rate < 1);

        // meanwhile bob spams 10 requests at once and is blocked: 10 requests
        // over the 5 second window is 2 req/s, which hits the proxied threshold
        for _ in 0..9 {
            let response = policy.handle_tally(bob.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        // highest rates should now show bob
        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, bob.direct.unwrap());
        assert_eq!(direct_rate, 2);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // 2 more tallies for alice. So far we are above 2 tallies per second,
        // but averaged over the 5 second window we are still below the
        // threshold, so alice should not be blocked.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        // bob, however, is still blocked as a proxied client, since his earlier
        // burst has not yet rotated out of the sliding window
        let _ = policy.handle_tally(bob.clone());
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        // the direct rate increased, since alice's requests arrive via the same
        // fullnode connection IP
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 3);
        // bob still holds the highest proxied rate; the heap entry is refreshed
        // in place as his rate changes rather than keeping a monotonic maximum
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // close to the threshold for alice, but still below
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        for i in 0..5 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {i}");
            assert_eq!(response.block_proxied_client, None);
        }

        // should block alice (as a proxied client) now
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        let response = policy.handle_tally(alice.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, alice.through_fullnode);

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 4);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // spam through charlie to block the connection
        for i in 0..2 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {i}");
            assert_eq!(response.block_proxied_client, None);
        }
        // Now we block the connection IP
        let response = policy.handle_tally(charlie.clone());
        assert_eq!(response.block_proxied_client, None);
        assert_eq!(response.block_client, charlie.direct);

        // Ensure that if we wait another second, we are no longer blocked,
        // as the bursty first second has finally rotated out of the sliding
        // window
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        for i in 0..3 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {i}");
            assert_eq!(response.block_proxied_client, None);
        }

        // check that we revert back to a previously seen highest rate once the
        // current top rates decrease
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
        // alice's and bob's rates have now decreased after a 5 second break, but
        // charlie's has not, since charlie has not sent a new request and heap
        // entries are only refreshed when a rate is recomputed
        let _ = policy.handle_tally(alice.clone());
        let _ = policy.handle_tally(bob.clone());
        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 0);
        assert_eq!(proxied_ip_addr, charlie.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 1);
    }


    #[sim_test]
    async fn test_traffic_sketch_mem_estimate() {
        // Get a rough estimate of memory usage for the traffic sketch given
        // certain parameters. The parameters below are the defaults; with them,
        // the memory estimate is roughly 113 MB.
        let window_size = Duration::from_secs(30);
        let update_interval = Duration::from_secs(5);
        let mem_estimate = CountMinSketch32::<IpAddr>::estimate_memory(
            DEFAULT_SKETCH_CAPACITY,
            DEFAULT_SKETCH_PROBABILITY,
            DEFAULT_SKETCH_TOLERANCE,
        )
        .unwrap()
            * (window_size.as_secs() / update_interval.as_secs()) as usize;
        assert!(
            mem_estimate < 128_000_000,
            "Memory estimate {mem_estimate} for traffic sketch exceeds 128MB."
        );
    }
}