1use std::path::PathBuf;
6
7use serde::{Deserialize, Serialize, de::Deserializer};
8use serde_with::serde_as;
9
/// Default for [`FreqThresholdConfig::sketch_capacity`] (returned by `default_sketch_capacity`).
pub const DEFAULT_SKETCH_CAPACITY: usize = 50_000;
/// Default for [`FreqThresholdConfig::sketch_probability`] (returned by `default_sketch_probability`).
pub const DEFAULT_SKETCH_PROBABILITY: f64 = 0.999;
/// Default for [`FreqThresholdConfig::sketch_tolerance`] (returned by `default_sketch_tolerance`).
pub const DEFAULT_SKETCH_TOLERANCE: f64 = 0.2;
17use rand::distributions::Distribution;
18
/// Default, in seconds, for [`RemoteFirewallConfig::drain_timeout_secs`]
/// (returned by `default_drain_timeout`).
const TRAFFIC_SINK_TIMEOUT_SEC: u64 = 300;
20
/// Source from which a client identifier is derived.
///
/// Variants are serialized in kebab-case: `socket-addr` and `x-forwarded-for`.
/// The derived `Default` is [`ClientIdSource::SocketAddr`].
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum ClientIdSource {
    /// The default variant. Presumably identifies the client by the peer socket
    /// address of the connection (name-based; confirm at the call site).
    #[default]
    SocketAddr,
    /// Presumably identifies the client from the `X-Forwarded-For` header.
    // NOTE(review): the meaning of the `usize` (e.g. number of trusted proxy
    // hops / index into the header) is not visible in this file — confirm.
    XForwardedFor(usize),
}
77
/// Parameters for reconfiguring traffic control (name-based; confirm at the handler).
///
/// Every field is optional.
// NOTE(review): presumably `None` means "leave this setting unchanged" — confirm
// against the code that consumes these params.
// Unlike the other config structs in this file there is no `rename_all`, so the
// serialized keys are the snake_case field names.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TrafficControlReconfigParams {
    /// Optional new error threshold.
    pub error_threshold: Option<u64>,
    /// Optional new spam threshold.
    pub spam_threshold: Option<u64>,
    /// Optional new dry-run flag.
    pub dry_run: Option<bool>,
}
84
85#[derive(Clone, Debug, Deserialize, Serialize)]
86pub struct Weight(f32);
87
88impl Weight {
89 pub fn new(value: f32) -> Result<Self, &'static str> {
90 if (0.0..=1.0).contains(&value) {
91 Ok(Self(value))
92 } else {
93 Err("Weight must be between 0.0 and 1.0")
94 }
95 }
96
97 pub fn one() -> Self {
98 Self(1.0)
99 }
100
101 pub fn zero() -> Self {
102 Self(0.0)
103 }
104
105 pub fn value(&self) -> f32 {
106 self.0
107 }
108
109 pub fn is_sampled(&self) -> bool {
110 let mut rng = rand::thread_rng();
111 let sample = rand::distributions::Uniform::new(0.0, 1.0).sample(&mut rng);
112 sample <= self.value()
113 }
114}
115
116fn validate_sample_rate<'de, D>(deserializer: D) -> Result<Weight, D::Error>
117where
118 D: Deserializer<'de>,
119{
120 let value = f32::deserialize(deserializer)?;
121 Weight::new(value)
122 .map_err(|_| serde::de::Error::custom("spam-sample-rate must be between 0.0 and 1.0"))
123}
124
125impl PartialEq for Weight {
126 fn eq(&self, other: &Self) -> bool {
127 self.value() == other.value()
128 }
129}
130
/// Configuration for delegating blocking to a remote firewall.
///
/// Keys are kebab-case. `remote-fw-url` and `destination-port` are required;
/// every other field has a serde default.
#[serde_as]
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct RemoteFirewallConfig {
    /// URL of the remote firewall (required).
    pub remote_fw_url: String,
    /// Destination port (required).
    // NOTE(review): presumably the port whose traffic the firewall should block — confirm.
    pub destination_port: u16,
    /// Whether spam blocking is delegated to the remote firewall. Default: `false`.
    #[serde(default)]
    pub delegate_spam_blocking: bool,
    /// Whether error blocking is delegated to the remote firewall. Default: `false`.
    #[serde(default)]
    pub delegate_error_blocking: bool,
    /// Drain path. Default: `/tmp/drain` (from `default_drain_path`).
    // NOTE(review): a fixed path under world-writable /tmp may be worth
    // revisiting if it is used as a trust signal — confirm how it is consumed.
    #[serde(default = "default_drain_path")]
    pub drain_path: PathBuf,
    /// Drain timeout in seconds. Default: `TRAFFIC_SINK_TIMEOUT_SEC` (300).
    #[serde(default = "default_drain_timeout")]
    pub drain_timeout_secs: u64,
}
148
/// Serde default for [`RemoteFirewallConfig::drain_path`]: `/tmp/drain`.
fn default_drain_path() -> PathBuf {
    "/tmp/drain".into()
}
152
/// Serde default for [`RemoteFirewallConfig::drain_timeout_secs`]:
/// `TRAFFIC_SINK_TIMEOUT_SEC` seconds.
fn default_drain_timeout() -> u64 {
    TRAFFIC_SINK_TIMEOUT_SEC
}
156
/// Settings carried by the [`PolicyType::FreqThreshold`] policy.
///
/// Keys are kebab-case (e.g. `client-threshold`). Every field has a serde
/// default, so any subset of keys may be supplied.
#[serde_as]
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct FreqThresholdConfig {
    /// Threshold for clients. Default: `1_000_000`.
    // NOTE(review): presumably a request count per `window_size_secs` — confirm.
    #[serde(default = "default_client_threshold")]
    pub client_threshold: u64,
    /// Threshold for proxied clients (name-based). Default: `10`.
    #[serde(default = "default_proxied_client_threshold")]
    pub proxied_client_threshold: u64,
    /// Window size in seconds. Default: `30`.
    #[serde(default = "default_window_size_secs")]
    pub window_size_secs: u64,
    /// Update interval in seconds. Default: `5`.
    #[serde(default = "default_update_interval_secs")]
    pub update_interval_secs: u64,
    /// Sketch capacity. Default: [`DEFAULT_SKETCH_CAPACITY`].
    #[serde(default = "default_sketch_capacity")]
    pub sketch_capacity: usize,
    /// Sketch probability. Default: [`DEFAULT_SKETCH_PROBABILITY`].
    #[serde(default = "default_sketch_probability")]
    pub sketch_probability: f64,
    /// Sketch tolerance. Default: [`DEFAULT_SKETCH_TOLERANCE`].
    #[serde(default = "default_sketch_tolerance")]
    pub sketch_tolerance: f64,
}
176
177impl Default for FreqThresholdConfig {
178 fn default() -> Self {
179 Self {
180 client_threshold: default_client_threshold(),
181 proxied_client_threshold: default_proxied_client_threshold(),
182 window_size_secs: default_window_size_secs(),
183 update_interval_secs: default_update_interval_secs(),
184 sketch_capacity: default_sketch_capacity(),
185 sketch_probability: default_sketch_probability(),
186 sketch_tolerance: default_sketch_tolerance(),
187 }
188 }
189}
190
/// Serde default for [`FreqThresholdConfig::client_threshold`].
fn default_client_threshold() -> u64 {
    1_000_000
}

/// Serde default for [`FreqThresholdConfig::proxied_client_threshold`].
fn default_proxied_client_threshold() -> u64 {
    10
}
204
/// Serde default for [`FreqThresholdConfig::window_size_secs`], in seconds.
fn default_window_size_secs() -> u64 {
    30
}

/// Serde default for [`FreqThresholdConfig::update_interval_secs`], in seconds.
fn default_update_interval_secs() -> u64 {
    5
}
212
/// Serde default for [`FreqThresholdConfig::sketch_capacity`].
fn default_sketch_capacity() -> usize {
    DEFAULT_SKETCH_CAPACITY
}

/// Serde default for [`FreqThresholdConfig::sketch_probability`].
fn default_sketch_probability() -> f64 {
    DEFAULT_SKETCH_PROBABILITY
}

/// Serde default for [`FreqThresholdConfig::sketch_tolerance`].
fn default_sketch_tolerance() -> f64 {
    DEFAULT_SKETCH_TOLERANCE
}
224
/// Policy selectable for the spam and error slots of [`PolicyConfig`].
///
/// Uses serde's default (externally tagged) enum representation. Only
/// `FreqThreshold` is renamed: it serializes as `freq-threshold`, and
/// `FreqThreshold` is still accepted on input. The other variants use their
/// Rust names.
#[derive(Clone, Serialize, Deserialize, Debug, Default)]
pub enum PolicyType {
    /// The `Default` variant, and hence the serde default for both policy fields
    /// of `PolicyConfig`. Presumably performs no action (name-based).
    #[default]
    NoOp,

    /// Frequency-threshold policy parameterized by [`FreqThresholdConfig`].
    #[serde(rename = "freq-threshold", alias = "FreqThreshold")]
    FreqThreshold(FreqThresholdConfig),

    // NOTE(review): the `Test*` variants look test-only by name; confirm they are
    // not intended for production configuration.
    /// Test policy parameterized by a `u64` whose meaning is not visible here.
    TestNConnIP(u64),
    /// Test policy; presumably panics when invoked (name-based).
    TestPanicOnInvocation,
}
249
/// Traffic-control policy configuration.
///
/// Keys are kebab-case and every field has a serde default. One caveat: the
/// serde default for `connection-blocklist-ttl-sec` is 60, while
/// `PolicyConfig::default()` sets that field to 0.
#[serde_as]
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct PolicyConfig {
    /// How clients are identified. Serde default: [`ClientIdSource::SocketAddr`].
    #[serde(default = "default_client_id_source")]
    pub client_id_source: ClientIdSource,
    /// Connection blocklist TTL in seconds (name-based). Serde default: 60.
    #[serde(default = "default_connection_blocklist_ttl_sec")]
    pub connection_blocklist_ttl_sec: u64,
    /// Proxy blocklist TTL in seconds (name-based). Serde default: 0.
    #[serde(default)]
    pub proxy_blocklist_ttl_sec: u64,
    /// Policy used for spam. Serde default: [`PolicyType::NoOp`].
    #[serde(default)]
    pub spam_policy_type: PolicyType,
    /// Policy used for errors. Serde default: [`PolicyType::NoOp`].
    #[serde(default)]
    pub error_policy_type: PolicyType,
    /// Channel capacity. Serde default: 100.
    // NOTE(review): presumably the bound of an internal message channel — confirm.
    #[serde(default = "default_channel_capacity")]
    pub channel_capacity: usize,
    /// Spam sampling weight.
    ///
    /// When present, the value is validated by `validate_sample_rate` to lie in
    /// `[0.0, 1.0]`. Serde default: 0.2.
    #[serde(
        default = "default_spam_sample_rate",
        deserialize_with = "validate_sample_rate"
    )]
    pub spam_sample_rate: Weight,
    /// Dry-run flag. Serde default: `true`.
    #[serde(default = "default_dry_run")]
    pub dry_run: bool,
    /// Optional allow-list; the entry format is not visible in this file.
    /// Serde default: `None`.
    #[serde(default)]
    pub allow_list: Option<Vec<String>>,
}
284
285impl Default for PolicyConfig {
286 fn default() -> Self {
287 Self {
288 client_id_source: default_client_id_source(),
289 connection_blocklist_ttl_sec: 0,
290 proxy_blocklist_ttl_sec: 0,
291 spam_policy_type: PolicyType::NoOp,
292 error_policy_type: PolicyType::NoOp,
293 channel_capacity: 100,
294 spam_sample_rate: default_spam_sample_rate(),
295 dry_run: default_dry_run(),
296 allow_list: None,
297 }
298 }
299}
300
301impl PolicyConfig {
302 pub fn default_dos_protection_policy() -> Self {
303 Self {
304 client_id_source: ClientIdSource::SocketAddr,
305 spam_policy_type: PolicyType::FreqThreshold(FreqThresholdConfig {
306 client_threshold: 1000,
307 window_size_secs: 5,
308 update_interval_secs: 1,
309 ..FreqThresholdConfig::default()
310 }),
311 error_policy_type: PolicyType::FreqThreshold(FreqThresholdConfig {
312 client_threshold: 50,
313 window_size_secs: 5,
314 update_interval_secs: 1,
315 ..FreqThresholdConfig::default()
316 }),
317 channel_capacity: 6000,
318 spam_sample_rate: Weight::new(1.0).unwrap(),
319 dry_run: true,
320 ..PolicyConfig::default()
321 }
322 }
323}
324
325pub fn default_client_id_source() -> ClientIdSource {
326 ClientIdSource::SocketAddr
327}
328
/// Serde default for [`PolicyConfig::connection_blocklist_ttl_sec`], in seconds.
///
/// Note that `PolicyConfig::default()` does not use this helper; it sets the field to 0.
pub fn default_connection_blocklist_ttl_sec() -> u64 {
    60
}
/// Serde default for [`PolicyConfig::channel_capacity`].
pub fn default_channel_capacity() -> usize {
    100
}
335
/// Serde default for [`PolicyConfig::dry_run`]: dry-run is on unless configured otherwise.
pub fn default_dry_run() -> bool {
    true
}
339
340pub fn default_spam_sample_rate() -> Weight {
341 Weight::new(0.2).unwrap()
342}