iota_http/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2025 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{collections::HashMap, sync::Arc, time::Duration};
6
7use connection_handler::OnConnectionClose;
8pub use http;
9use http::{Request, Response};
10use hyper_util::service::TowerToHyperService;
11use io::ServerIo;
12use tokio::task::JoinSet;
13use tokio_rustls::TlsAcceptor;
14use tower::{Service, ServiceBuilder, ServiceExt};
15use tracing::trace;
16
17use self::{body::BoxBody, connection_info::ActiveConnections};
18
19pub mod body;
20mod config;
21mod connection_handler;
22mod connection_info;
23mod fuse;
24mod io;
25mod listener;
26
27pub use config::Config;
28pub use connection_info::{ConnectInfo, ConnectionId, ConnectionInfo, PeerCertificates};
29pub use listener::{Listener, ListenerExt};
30
/// Crate-wide error type: any boxed `std::error::Error` that is safe to move and
/// share across threads.
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// ALPN identifier for HTTP/1.1, in the raw byte form rustls expects.
const ALPN_H1: &[u8] = b"http/1.1";
/// ALPN identifier for HTTP/2, in the raw byte form rustls expects.
const ALPN_H2: &[u8] = b"h2";
36
37#[derive(Default)]
38pub struct Builder {
39    config: Config,
40    tls_config: Option<tokio_rustls::rustls::ServerConfig>,
41}
42
43impl Builder {
44    pub fn new() -> Self {
45        Self::default()
46    }
47
48    pub fn config(mut self, config: Config) -> Self {
49        self.config = config;
50        self
51    }
52
53    pub fn tls_config(mut self, tls_config: tokio_rustls::rustls::ServerConfig) -> Self {
54        self.tls_config = Some(tls_config);
55        self
56    }
57
58    pub fn serve<A, S, ResponseBody>(
59        self,
60        addr: A,
61        service: S,
62    ) -> Result<ServerHandle<std::net::SocketAddr>, BoxError>
63    where
64        A: std::net::ToSocketAddrs,
65        S: Service<
66                Request<BoxBody>,
67                Response = Response<ResponseBody>,
68                Error: Into<BoxError>,
69                Future: Send,
70            > + Clone
71            + Send
72            + 'static,
73        ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
74    {
75        let listener = listener::TcpListenerWithOptions::new(
76            addr,
77            self.config.tcp_nodelay,
78            self.config.tcp_keepalive,
79        )?;
80
81        Self::serve_with_listener(self, listener, service)
82    }
83
84    fn serve_with_listener<L, S, ResponseBody>(
85        self,
86        listener: L,
87        service: S,
88    ) -> Result<ServerHandle<L::Addr>, BoxError>
89    where
90        L: Listener,
91        S: Service<
92                Request<BoxBody>,
93                Response = Response<ResponseBody>,
94                Error: Into<BoxError>,
95                Future: Send,
96            > + Clone
97            + Send
98            + 'static,
99        ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
100    {
101        let local_addr = listener.local_addr()?;
102        let graceful_shutdown_token = tokio_util::sync::CancellationToken::new();
103        let connections = ActiveConnections::default();
104
105        let tls_config = self.tls_config.map(|mut tls| {
106            tls.alpn_protocols.push(ALPN_H2.into());
107            if self.config.accept_http1 {
108                tls.alpn_protocols.push(ALPN_H1.into());
109            }
110            Arc::new(tls)
111        });
112
113        let (watch_sender, watch_receiver) = tokio::sync::watch::channel(());
114        let server = Server {
115            config: self.config,
116            tls_config,
117            listener,
118            local_addr: local_addr.clone(),
119            service: ServiceBuilder::new()
120                .layer(tower::util::BoxCloneService::layer())
121                .map_response(|response: Response<ResponseBody>| response.map(body::boxed))
122                .map_err(Into::into)
123                .service(service),
124            pending_connections: JoinSet::new(),
125            connection_handlers: JoinSet::new(),
126            connections: connections.clone(),
127            graceful_shutdown_token: graceful_shutdown_token.clone(),
128            _watch_receiver: watch_receiver,
129        };
130
131        let handle = ServerHandle(Arc::new(HandleInner {
132            local_addr,
133            connections,
134            graceful_shutdown_token,
135            watch_sender,
136        }));
137
138        tokio::spawn(server.serve());
139
140        Ok(handle)
141    }
142}
143
/// Cheaply clonable handle to a running server, returned by [`Builder::serve`].
///
/// Exposes the bound address and the active-connection table, and lets callers
/// trigger and await shutdown. All clones share the same inner state.
#[derive(Debug)]
pub struct ServerHandle<A = std::net::SocketAddr>(Arc<HandleInner<A>>);

/// State shared by every clone of a [`ServerHandle`].
#[derive(Debug)]
struct HandleInner<A = std::net::SocketAddr> {
    /// The local address of the server.
    local_addr: A,
    /// Registry of active connections; the same registry the server task writes to.
    connections: ActiveConnections<A>,
    /// Cancelling this token asks the server to shut down gracefully.
    graceful_shutdown_token: tokio_util::sync::CancellationToken,
    /// Sender half of a watch channel whose receiver is held by the server task.
    /// `closed()` on this sender resolves once that receiver has been dropped.
    watch_sender: tokio::sync::watch::Sender<()>,
}
155
156impl<A> ServerHandle<A> {
157    /// Returns the local address of the server
158    pub fn local_addr(&self) -> &A {
159        &self.0.local_addr
160    }
161
162    /// Trigger a graceful shutdown of the server, but don't wait till the
163    /// server has completed shutting down
164    pub fn trigger_shutdown(&self) {
165        self.0.graceful_shutdown_token.cancel();
166    }
167
168    /// Completes once the network has been shutdown.
169    ///
170    /// This explicitly *does not* trigger the network to shutdown, see
171    /// `trigger_shutdown` or `shutdown` if you want to trigger shutting
172    /// down the server.
173    pub async fn wait_for_shutdown(&self) {
174        self.0.watch_sender.closed().await
175    }
176
177    /// Triggers a shutdown of the server and waits for it to complete shutting
178    /// down.
179    pub async fn shutdown(&self) {
180        self.trigger_shutdown();
181        self.wait_for_shutdown().await;
182    }
183
184    /// Checks if the Server has been shutdown.
185    pub fn is_shutdown(&self) -> bool {
186        self.0.watch_sender.is_closed()
187    }
188
189    pub fn connections(
190        &self,
191    ) -> std::sync::RwLockReadGuard<'_, HashMap<ConnectionId, ConnectionInfo<A>>> {
192        self.0.connections.read().unwrap()
193    }
194
195    /// Returns the number of active connections the server is handling
196    pub fn number_of_connections(&self) -> usize {
197        self.connections().len()
198    }
199}
200
201impl<A> Clone for ServerHandle<A> {
202    fn clone(&self) -> Self {
203        Self(self.0.clone())
204    }
205}
206
/// Output of a task in `Server::pending_connections`. On success it is the
/// accepted IO (plaintext or TLS-wrapped) plus the peer address. On failure it is
/// the error that ended the TLS handshake or plaintext detection for that socket.
type ConnectingOutput<Io, Addr> = Result<(ServerIo<Io>, Addr), crate::BoxError>;

/// State of a running server, owned by the task spawned in
/// `Builder::serve_with_listener`.
struct Server<L: Listener> {
    /// Server configuration taken from the `Builder`.
    config: Config,
    /// TLS configuration with ALPN identifiers added; `None` means plaintext only.
    tls_config: Option<Arc<tokio_rustls::rustls::ServerConfig>>,

    /// Source of newly accepted sockets.
    listener: L,
    /// Address `listener` is bound to; copied into every request's `ConnectInfo`.
    local_addr: L::Addr,
    /// The user's service with response bodies boxed and errors converted to `BoxError`.
    service: tower::util::BoxCloneService<Request<BoxBody>, Response<BoxBody>, crate::BoxError>,

    /// In-flight connection setup (TLS handshake or plaintext detection).
    pending_connections: JoinSet<ConnectingOutput<L::Io, L::Addr>>,
    /// One task per established connection, serving HTTP on it.
    connection_handlers: JoinSet<()>,
    /// Active-connection registry, shared with every `ServerHandle`.
    connections: ActiveConnections<L::Addr>,
    /// Cancelled to request graceful shutdown; each connection gets a child token.
    graceful_shutdown_token: tokio_util::sync::CancellationToken,
    // Used to signal to a ServerHandle when the server has completed shutting down.
    // This is the only receiver created for the channel, so dropping it makes the
    // handle's `watch_sender.closed()` resolve. Keep this field LAST: struct fields
    // are dropped in declaration order, so the signal fires only after everything
    // above (including the listener) has been dropped.
    _watch_receiver: tokio::sync::watch::Receiver<()>,
}
224
225impl<L> Server<L>
226where
227    L: Listener,
228{
229    async fn serve(mut self) -> Result<(), BoxError> {
230        loop {
231            tokio::select! {
232                _ = self.graceful_shutdown_token.cancelled() => {
233                    trace!("signal received, shutting down");
234                    break;
235                },
236                (io, remote_addr) = self.listener.accept() => {
237                    self.handle_incoming(io, remote_addr);
238                },
239                Some(maybe_connection) = self.pending_connections.join_next() => {
240                    // If a task panics, just propagate it
241                    let (io, remote_addr) = match maybe_connection.unwrap() {
242                        Ok((io, remote_addr)) => {
243                            (io, remote_addr)
244                        }
245                        Err(e) => {
246                            tracing::debug!(error = %e, "error accepting connection");
247                            continue;
248                        }
249                    };
250
251                    trace!("connection accepted");
252                    self.handle_connection(io, remote_addr);
253                },
254                Some(connection_handler_output) = self.connection_handlers.join_next() => {
255                    // If a task panics, just propagate it
256                    let _: () = connection_handler_output.unwrap();
257                },
258            }
259        }
260
261        // Shutting down, wait for all connection handlers to finish
262        self.shutdown().await;
263
264        Ok(())
265    }
266
267    fn handle_incoming(&mut self, io: L::Io, remote_addr: L::Addr) {
268        if let Some(tls) = self.tls_config.clone() {
269            let tls_acceptor = TlsAcceptor::from(tls);
270            let allow_insecure = self.config.allow_insecure;
271            self.pending_connections.spawn(async move {
272                if allow_insecure {
273                    // XXX: If we want to allow for supporting insecure traffic from other types of
274                    // io, we'll need to implement a generic peekable IO type
275                    if let Some(tcp) =
276                        <dyn std::any::Any>::downcast_ref::<tokio::net::TcpStream>(&io)
277                    {
278                        // Determine whether new connection is TLS.
279                        let mut buf = [0; 1];
280                        // `peek` blocks until at least some data is available, so if there is no error then
281                        // it must return the one byte we are requesting.
282                        tcp.peek(&mut buf).await?;
283                        // First byte of a TLS handshake is 0x16, so if it isn't 0x16 then its
284                        // insecure
285                        if buf != [0x16] {
286                            tracing::trace!("accepting insecure connection");
287                            return Ok((ServerIo::new_io(io), remote_addr));
288                        }
289                    } else {
290                        tracing::warn!("'allow_insecure' is configured but io type is not 'tokio::net::TcpStream'");
291                    }
292                }
293
294                tracing::trace!("accepting TLS connection");
295                let io = tls_acceptor.accept(io).await?;
296                Ok((ServerIo::new_tls_io(io), remote_addr))
297            });
298        } else {
299            self.handle_connection(ServerIo::new_io(io), remote_addr);
300        }
301    }
302
303    fn handle_connection(&mut self, io: ServerIo<L::Io>, remote_addr: L::Addr) {
304        let connection_shutdown_token = self.graceful_shutdown_token.child_token();
305        let connection_info = ConnectionInfo::new(
306            remote_addr,
307            io.peer_certs(),
308            connection_shutdown_token.clone(),
309        );
310        let connection_id = connection_info.id();
311        let connect_info = connection_info::ConnectInfo {
312            local_addr: self.local_addr.clone(),
313            remote_addr: connection_info.remote_address().clone(),
314        };
315        let peer_certificates = connection_info.peer_certificates().cloned();
316        let hyper_io = hyper_util::rt::TokioIo::new(io);
317
318        let hyper_svc = TowerToHyperService::new(self.service.clone().map_request(
319            move |mut request: Request<hyper::body::Incoming>| {
320                request.extensions_mut().insert(connect_info.clone());
321                if let Some(peer_certificates) = peer_certificates.clone() {
322                    request.extensions_mut().insert(peer_certificates);
323                }
324
325                request.map(body::boxed)
326            },
327        ));
328
329        self.connections
330            .write()
331            .unwrap()
332            .insert(connection_id, connection_info);
333        let on_connection_close = OnConnectionClose::new(connection_id, self.connections.clone());
334
335        self.connection_handlers
336            .spawn(connection_handler::serve_connection(
337                hyper_io,
338                hyper_svc,
339                self.config.connection_builder(),
340                connection_shutdown_token,
341                self.config.max_connection_age,
342                on_connection_close,
343            ));
344    }
345
346    async fn shutdown(mut self) {
347        // The time we are willing to wait for a connection to get gracefully shutdown
348        // before we attempt to forcefully shutdown all active connections
349        const CONNECTION_SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(1);
350
351        // Just to be careful make sure the token is canceled
352        self.graceful_shutdown_token.cancel();
353
354        // Terminate any in-progress pending connections
355        self.pending_connections.shutdown().await;
356
357        // Wait for all connection handlers to terminate
358        trace!(
359            "waiting for {} connections to close",
360            self.connection_handlers.len()
361        );
362
363        let graceful_shutdown =
364            async { while self.connection_handlers.join_next().await.is_some() {} };
365
366        if tokio::time::timeout(CONNECTION_SHUTDOWN_GRACE_PERIOD, graceful_shutdown)
367            .await
368            .is_err()
369        {
370            tracing::warn!(
371                "Failed to stop all connection handlers in {:?}. Forcing shutdown.",
372                CONNECTION_SHUTDOWN_GRACE_PERIOD
373            );
374            self.connection_handlers.shutdown().await;
375        }
376    }
377}
378
#[cfg(test)]
mod tests {
    use axum::Router;

    use super::*;

    const MESSAGE: &str = "Hello, World!";

    /// Starts a server on an ephemeral localhost port whose single route,
    /// `GET /`, answers with [`MESSAGE`].
    fn start_hello_server() -> ServerHandle {
        let app = Router::new().route("/", axum::routing::get(|| async { MESSAGE }));
        Builder::new().serve(("localhost", 0), app).unwrap()
    }

    /// Sends `GET /` to the server behind `handle` and checks the body is [`MESSAGE`].
    async fn assert_root_responds(handle: &ServerHandle) {
        let url = format!("http://{}", handle.local_addr());
        let body = reqwest::get(url).await.unwrap().bytes().await.unwrap();
        assert_eq!(body, MESSAGE.as_bytes());
    }

    #[tokio::test]
    async fn simple() {
        let handle = start_hello_server();
        assert_root_responds(&handle).await;
    }

    #[tokio::test]
    async fn shutdown() {
        let handle = start_hello_server();
        assert_root_responds(&handle).await;

        // The request we just made should have left exactly one active connection.
        assert_eq!(handle.connections().len(), 1);
        assert!(!handle.is_shutdown());

        handle.shutdown().await;

        assert!(handle.is_shutdown());
        // After shutdown completes, no connections should remain registered.
        assert_eq!(handle.connections().len(), 0);
    }
}