Skip to main content

iota_types/
traffic_control.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::path::PathBuf;
6
7use serde::{Deserialize, Serialize, de::Deserializer};
8use serde_with::serde_as;
9
// The defaults below are chosen to loosely cap the memory used by a single
// sketch at roughly 20MB. For reference, see
// https://github.com/jedisct1/rust-count-min-sketch/blob/master/src/lib.rs

/// Value used for `FreqThresholdConfig::sketch_capacity` when it is omitted.
pub const DEFAULT_SKETCH_CAPACITY: usize = 50_000;
/// Value used for `FreqThresholdConfig::sketch_probability` when it is omitted.
pub const DEFAULT_SKETCH_PROBABILITY: f64 = 0.999;
/// Value used for `FreqThresholdConfig::sketch_tolerance` when it is omitted.
pub const DEFAULT_SKETCH_TOLERANCE: f64 = 0.2;
17use rand::distributions::Distribution;
18
/// Fallback for `RemoteFirewallConfig::drain_timeout_secs`, in seconds
/// (five minutes).
const TRAFFIC_SINK_TIMEOUT_SEC: u64 = 5 * 60;
20
21/// The source that should be used to identify the client's
22/// IP address. To be used to configure cases where a node has
23/// infra running in front of the node that is separate from the
24/// protocol, such as a load balancer. Note that this is not the
25/// same as the client type (e.g a direct client vs a proxy client,
26/// as in the case of a fullnode driving requests from many clients).
27///
28/// For x-forwarded-for, the usize parameter is the number of forwarding
29/// hops between the client and the node for requests going your infra
30/// or infra provider. Example:
31///
32/// ```ignore
33///     (client) -> { (global proxy) -> (regional proxy) -> (node) }
34/// ```
35///
36/// where
37///
38/// ```ignore
39///     { <server>, ... }
40/// ```
41///
42/// are controlled by the Node operator / their cloud provider.
43/// In this case, we set:
44///
45/// ```ignore
46/// policy-config:
47///    client-id-source:
48///      x-forwarded-for: 2
49///    ...
50/// ```
51///
52/// NOTE: x-forwarded-for: 0 is a special case value that can be used by Node
53/// operators to discover the number of hops that should be configured. To use:
54///
55/// 1. Set `x-forwarded-for: 0` for the `client-id-source` in the config.
56/// 2. Run the node and query any endpoint (AuthorityServer for validator, or
57///    json rpc for rpc node) from a known IP address.
58/// 3. Search for lines containing `x-forwarded-for` in the logs. The log lines
59///    should contain the contents of the `x-forwarded-for` header, if present,
60///    or a corresponding error if not.
61/// 4. The value for number of hops is derived from any such log line that
62///    contains your known IP address, and is defined as 1 + the number of IP
63///    addresses in the `x-forwarded-for` that occur **after** the known client
64///    IP address. Example:
65///
66/// ```ignore
67///     [<known client IP>] <--- number of hops is 1
68///     ["1.2.3.4", <known client IP>, "5.6.7.8", "9.10.11.12"] <--- number of hops is 3
69/// ```
70#[derive(Clone, Debug, Deserialize, Serialize, Default)]
71#[serde(rename_all = "kebab-case")]
72pub enum ClientIdSource {
73    #[default]
74    SocketAddr,
75    XForwardedFor(usize),
76}
77
78#[derive(Clone, Debug, Deserialize, Serialize)]
79pub struct TrafficControlReconfigParams {
80    pub error_threshold: Option<u64>,
81    pub spam_threshold: Option<u64>,
82    pub dry_run: Option<bool>,
83}
84
85#[derive(Clone, Debug, Deserialize, Serialize)]
86pub struct Weight(f32);
87
88impl Weight {
89    pub fn new(value: f32) -> Result<Self, &'static str> {
90        if (0.0..=1.0).contains(&value) {
91            Ok(Self(value))
92        } else {
93            Err("Weight must be between 0.0 and 1.0")
94        }
95    }
96
97    pub fn one() -> Self {
98        Self(1.0)
99    }
100
101    pub fn zero() -> Self {
102        Self(0.0)
103    }
104
105    pub fn value(&self) -> f32 {
106        self.0
107    }
108
109    pub fn is_sampled(&self) -> bool {
110        let mut rng = rand::thread_rng();
111        let sample = rand::distributions::Uniform::new(0.0, 1.0).sample(&mut rng);
112        sample <= self.value()
113    }
114}
115
116fn validate_sample_rate<'de, D>(deserializer: D) -> Result<Weight, D::Error>
117where
118    D: Deserializer<'de>,
119{
120    let value = f32::deserialize(deserializer)?;
121    Weight::new(value)
122        .map_err(|_| serde::de::Error::custom("spam-sample-rate must be between 0.0 and 1.0"))
123}
124
125impl PartialEq for Weight {
126    fn eq(&self, other: &Self) -> bool {
127        self.value() == other.value()
128    }
129}
130
131#[serde_as]
132#[derive(Clone, Debug, Deserialize, Serialize)]
133#[serde(rename_all = "kebab-case")]
134pub struct RemoteFirewallConfig {
135    pub remote_fw_url: String,
136    pub destination_port: u16,
137    #[serde(default)]
138    pub delegate_spam_blocking: bool,
139    #[serde(default)]
140    pub delegate_error_blocking: bool,
141    #[serde(default = "default_drain_path")]
142    pub drain_path: PathBuf,
143    /// Time in secs, after which no registered ingress traffic
144    /// will trigger dead mans switch to drain any firewalls
145    #[serde(default = "default_drain_timeout")]
146    pub drain_timeout_secs: u64,
147}
148
/// Serde default for `RemoteFirewallConfig::drain_path`.
fn default_drain_path() -> PathBuf {
    "/tmp/drain".into()
}
152
153fn default_drain_timeout() -> u64 {
154    TRAFFIC_SINK_TIMEOUT_SEC
155}
156
157#[serde_as]
158#[derive(Clone, Debug, Deserialize, Serialize)]
159#[serde(rename_all = "kebab-case")]
160pub struct FreqThresholdConfig {
161    #[serde(default = "default_client_threshold")]
162    pub client_threshold: u64,
163    #[serde(default = "default_proxied_client_threshold")]
164    pub proxied_client_threshold: u64,
165    #[serde(default = "default_window_size_secs")]
166    pub window_size_secs: u64,
167    #[serde(default = "default_update_interval_secs")]
168    pub update_interval_secs: u64,
169    #[serde(default = "default_sketch_capacity")]
170    pub sketch_capacity: usize,
171    #[serde(default = "default_sketch_probability")]
172    pub sketch_probability: f64,
173    #[serde(default = "default_sketch_tolerance")]
174    pub sketch_tolerance: f64,
175}
176
177impl Default for FreqThresholdConfig {
178    fn default() -> Self {
179        Self {
180            client_threshold: default_client_threshold(),
181            proxied_client_threshold: default_proxied_client_threshold(),
182            window_size_secs: default_window_size_secs(),
183            update_interval_secs: default_update_interval_secs(),
184            sketch_capacity: default_sketch_capacity(),
185            sketch_probability: default_sketch_probability(),
186            sketch_tolerance: default_sketch_tolerance(),
187        }
188    }
189}
190
/// Serde default for `FreqThresholdConfig::client_threshold`.
///
/// By default only clients with an unreasonably high qps are blocked, as a
/// client could be a single fullnode proxying the majority of traffic from
/// many well-behaved clients in normal operations. If used as a spam policy,
/// all requests would count against this threshold within the window time.
/// In practice operators should always set this value explicitly.
fn default_client_threshold() -> u64 {
    const UNREASONABLY_HIGH_QPS: u64 = 1_000_000;
    UNREASONABLY_HIGH_QPS
}
200
/// Serde default for `FreqThresholdConfig::proxied_client_threshold`.
fn default_proxied_client_threshold() -> u64 {
    const PROXIED_CLIENT_THRESHOLD: u64 = 10;
    PROXIED_CLIENT_THRESHOLD
}
204
/// Serde default for `FreqThresholdConfig::window_size_secs`, in seconds.
fn default_window_size_secs() -> u64 {
    const WINDOW_SIZE_SECS: u64 = 30;
    WINDOW_SIZE_SECS
}
208
/// Serde default for `FreqThresholdConfig::update_interval_secs`, in seconds.
fn default_update_interval_secs() -> u64 {
    const UPDATE_INTERVAL_SECS: u64 = 5;
    UPDATE_INTERVAL_SECS
}
212
213fn default_sketch_capacity() -> usize {
214    DEFAULT_SKETCH_CAPACITY
215}
216
217fn default_sketch_probability() -> f64 {
218    DEFAULT_SKETCH_PROBABILITY
219}
220
221fn default_sketch_tolerance() -> f64 {
222    DEFAULT_SKETCH_TOLERANCE
223}
224
225// Serializable representation of policy types, used in config
226// in order to easily change in tests or to killswitch
227#[derive(Clone, Serialize, Deserialize, Debug, Default)]
228pub enum PolicyType {
229    /// Does nothing
230    #[default]
231    NoOp,
232
233    /// Blocks connection_ip after reaching a tally frequency (tallies per
234    /// second) of `threshold`, as calculated over an average window of
235    /// `window_size_secs` with granularity of `update_interval_secs`
236    #[serde(rename = "freq-threshold", alias = "FreqThreshold")]
237    FreqThreshold(FreqThresholdConfig),
238
239    // Below this point are test policies, and thus should not be used in production
240    /// Simple policy that adds connection_ip to blocklist when the same
241    /// connection_ip is encountered in tally N times. If used in an error
242    /// policy, this would trigger after N errors
243    TestNConnIP(u64),
244    /// Test policy that panics when invoked. To be used as an error policy in
245    /// tests that do not expect request errors in order to verify that the
246    /// error policy is not invoked
247    TestPanicOnInvocation,
248}
249
250#[serde_as]
251#[derive(Clone, Debug, Deserialize, Serialize)]
252#[serde(rename_all = "kebab-case")]
253pub struct PolicyConfig {
254    #[serde(default = "default_client_id_source")]
255    pub client_id_source: ClientIdSource,
256    #[serde(default = "default_connection_blocklist_ttl_sec")]
257    pub connection_blocklist_ttl_sec: u64,
258    #[serde(default)]
259    pub proxy_blocklist_ttl_sec: u64,
260    #[serde(default)]
261    pub spam_policy_type: PolicyType,
262    #[serde(default)]
263    pub error_policy_type: PolicyType,
264    #[serde(default = "default_channel_capacity")]
265    pub channel_capacity: usize,
266    #[serde(
267        default = "default_spam_sample_rate",
268        deserialize_with = "validate_sample_rate"
269    )]
270    /// Note that this sample policy is applied on top of the
271    /// endpoint-specific sample policy (not configurable) which
272    /// weighs endpoints by the relative effort required to serve
273    /// them. Therefore a sample rate of N will yield an actual
274    /// sample rate <= N.
275    pub spam_sample_rate: Weight,
276    #[serde(default = "default_dry_run")]
277    pub dry_run: bool,
278    /// List of String which should all parse to type IPAddr.
279    /// If set, only requests from provided IPs will be allowed,
280    /// and any blocklist related configuration will be ignored.
281    #[serde(default)]
282    pub allow_list: Option<Vec<String>>,
283}
284
285impl Default for PolicyConfig {
286    fn default() -> Self {
287        Self {
288            client_id_source: default_client_id_source(),
289            connection_blocklist_ttl_sec: 0,
290            proxy_blocklist_ttl_sec: 0,
291            spam_policy_type: PolicyType::NoOp,
292            error_policy_type: PolicyType::NoOp,
293            channel_capacity: 100,
294            spam_sample_rate: default_spam_sample_rate(),
295            dry_run: default_dry_run(),
296            allow_list: None,
297        }
298    }
299}
300
301impl PolicyConfig {
302    pub fn default_dos_protection_policy() -> Self {
303        Self {
304            client_id_source: ClientIdSource::SocketAddr,
305            spam_policy_type: PolicyType::FreqThreshold(FreqThresholdConfig {
306                client_threshold: 1000,
307                window_size_secs: 5,
308                update_interval_secs: 1,
309                ..FreqThresholdConfig::default()
310            }),
311            error_policy_type: PolicyType::FreqThreshold(FreqThresholdConfig {
312                client_threshold: 50,
313                window_size_secs: 5,
314                update_interval_secs: 1,
315                ..FreqThresholdConfig::default()
316            }),
317            channel_capacity: 6000,
318            spam_sample_rate: Weight::new(1.0).unwrap(),
319            dry_run: true,
320            ..PolicyConfig::default()
321        }
322    }
323}
324
325pub fn default_client_id_source() -> ClientIdSource {
326    ClientIdSource::SocketAddr
327}
328
/// Serde default for `PolicyConfig::connection_blocklist_ttl_sec`, in seconds.
pub fn default_connection_blocklist_ttl_sec() -> u64 {
    const CONNECTION_BLOCKLIST_TTL_SEC: u64 = 60;
    CONNECTION_BLOCKLIST_TTL_SEC
}
/// Serde default for `PolicyConfig::channel_capacity`.
pub fn default_channel_capacity() -> usize {
    const CHANNEL_CAPACITY: usize = 100;
    CHANNEL_CAPACITY
}
335
/// Serde default for `PolicyConfig::dry_run`: dry-run is on unless configured
/// otherwise.
pub fn default_dry_run() -> bool {
    const DRY_RUN: bool = true;
    DRY_RUN
}
339
340pub fn default_spam_sample_rate() -> Weight {
341    Weight::new(0.2).unwrap()
342}