iota_core/traffic_controller/
nodefw_test_server.rs1use std::{
6 collections::HashMap,
7 net::SocketAddr,
8 sync::Arc,
9 time::{Duration, SystemTime},
10};
11
12use axum::{
13 Json, Router,
14 extract::State,
15 http::StatusCode,
16 response::IntoResponse,
17 routing::{get, post},
18};
19use tokio::{
20 sync::{Mutex, Notify},
21 task::JoinHandle,
22};
23
24use crate::traffic_controller::nodefw_client::{BlockAddress, BlockAddresses};
25
26#[derive(Clone)]
27struct AppState {
28 blocklist: Arc<Mutex<HashMap<BlockAddress, SystemTime>>>,
30}
31
32pub struct NodeFwTestServer {
33 server_handle: Option<JoinHandle<()>>,
34 shutdown_signal: Arc<Notify>,
35 state: AppState,
36}
37
38impl NodeFwTestServer {
39 pub fn new() -> Self {
40 Self {
41 server_handle: None,
42 shutdown_signal: Arc::new(Notify::new()),
43 state: AppState {
44 blocklist: Arc::new(Mutex::new(HashMap::new())),
45 },
46 }
47 }
48
49 pub async fn start(&mut self, port: u16) {
50 let app_state = self.state.clone();
51 let app = Router::new()
52 .route("/list_addresses", get(Self::list_addresses))
53 .route("/block_addresses", post(Self::block_addresses))
54 .with_state(app_state.clone());
55
56 let addr = SocketAddr::from(([127, 0, 0, 1], port));
57
58 let handle = tokio::spawn(async move {
59 let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
60 axum::serve(listener, app).await.unwrap();
61 });
62
63 tokio::spawn(Self::periodically_remove_expired_addresses(
64 app_state.blocklist.clone(),
65 ));
66
67 self.server_handle = Some(handle);
68 }
69
70 pub async fn list_addresses_rpc(&self) -> Vec<BlockAddress> {
72 let blocklist = self.state.blocklist.lock().await;
73 blocklist.keys().cloned().collect()
74 }
75
76 async fn list_addresses(State(state): State<AppState>) -> impl IntoResponse {
78 let blocklist = state.blocklist.lock().await;
79 let block_addresses = blocklist.keys().cloned().collect();
80 Json(BlockAddresses {
81 addresses: block_addresses,
82 })
83 }
84
85 async fn periodically_remove_expired_addresses(
86 blocklist: Arc<Mutex<HashMap<BlockAddress, SystemTime>>>,
87 ) {
88 loop {
89 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
90 let mut blocklist = blocklist.lock().await;
91 let now = SystemTime::now();
92 blocklist.retain(|_address, expiry| now < *expiry);
93 }
94 }
95
96 async fn block_addresses(
98 State(state): State<AppState>,
99 Json(addresses): Json<BlockAddresses>,
100 ) -> impl IntoResponse {
101 let mut blocklist = state.blocklist.lock().await;
102 for addr in addresses.addresses.iter() {
103 blocklist.insert(
104 addr.clone(),
105 SystemTime::now() + Duration::from_secs(addr.ttl),
106 );
107 }
108 (StatusCode::CREATED, "created")
109 }
110
111 pub async fn stop(&self) {
112 self.shutdown_signal.notify_one();
113 }
114}
115
116impl Default for NodeFwTestServer {
117 fn default() -> Self {
118 Self::new()
119 }
120}