iota_core/traffic_controller/
nodefw_test_server.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::HashMap,
7    net::SocketAddr,
8    sync::Arc,
9    time::{Duration, SystemTime},
10};
11
12use axum::{
13    Json, Router,
14    extract::State,
15    http::StatusCode,
16    response::IntoResponse,
17    routing::{get, post},
18};
19use tokio::{
20    sync::{Mutex, Notify},
21    task::JoinHandle,
22};
23
24use crate::traffic_controller::nodefw_client::{BlockAddress, BlockAddresses};
25
26#[derive(Clone)]
27struct AppState {
28    /// BlockAddress -> expiry time
29    blocklist: Arc<Mutex<HashMap<BlockAddress, SystemTime>>>,
30}
31
32pub struct NodeFwTestServer {
33    server_handle: Option<JoinHandle<()>>,
34    shutdown_signal: Arc<Notify>,
35    state: AppState,
36}
37
38impl NodeFwTestServer {
39    pub fn new() -> Self {
40        Self {
41            server_handle: None,
42            shutdown_signal: Arc::new(Notify::new()),
43            state: AppState {
44                blocklist: Arc::new(Mutex::new(HashMap::new())),
45            },
46        }
47    }
48
49    pub async fn start(&mut self, port: u16) {
50        let app_state = self.state.clone();
51        let app = Router::new()
52            .route("/list_addresses", get(Self::list_addresses))
53            .route("/block_addresses", post(Self::block_addresses))
54            .with_state(app_state.clone());
55
56        let addr = SocketAddr::from(([127, 0, 0, 1], port));
57
58        let handle = tokio::spawn(async move {
59            let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
60            axum::serve(listener, app).await.unwrap();
61        });
62
63        tokio::spawn(Self::periodically_remove_expired_addresses(
64            app_state.blocklist.clone(),
65        ));
66
67        self.server_handle = Some(handle);
68    }
69
70    /// Direct access api for test verification
71    pub async fn list_addresses_rpc(&self) -> Vec<BlockAddress> {
72        let blocklist = self.state.blocklist.lock().await;
73        blocklist.keys().cloned().collect()
74    }
75
76    /// Endpoint handler to list addresses
77    async fn list_addresses(State(state): State<AppState>) -> impl IntoResponse {
78        let blocklist = state.blocklist.lock().await;
79        let block_addresses = blocklist.keys().cloned().collect();
80        Json(BlockAddresses {
81            addresses: block_addresses,
82        })
83    }
84
85    async fn periodically_remove_expired_addresses(
86        blocklist: Arc<Mutex<HashMap<BlockAddress, SystemTime>>>,
87    ) {
88        loop {
89            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
90            let mut blocklist = blocklist.lock().await;
91            let now = SystemTime::now();
92            blocklist.retain(|_address, expiry| now < *expiry);
93        }
94    }
95
96    /// Endpoint handler to block addresses
97    async fn block_addresses(
98        State(state): State<AppState>,
99        Json(addresses): Json<BlockAddresses>,
100    ) -> impl IntoResponse {
101        let mut blocklist = state.blocklist.lock().await;
102        for addr in addresses.addresses.iter() {
103            blocklist.insert(
104                addr.clone(),
105                SystemTime::now() + Duration::from_secs(addr.ttl),
106            );
107        }
108        (StatusCode::CREATED, "created")
109    }
110
111    pub async fn stop(&self) {
112        self.shutdown_signal.notify_one();
113    }
114}
115
116impl Default for NodeFwTestServer {
117    fn default() -> Self {
118        Self::new()
119    }
120}