iota_http/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2025 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{collections::HashMap, sync::Arc, time::Duration};
6
7use connection_handler::OnConnectionClose;
8pub use http;
9use http::{Request, Response};
10use hyper_util::service::TowerToHyperService;
11use io::ServerIo;
12use tokio::task::JoinSet;
13use tokio_rustls::TlsAcceptor;
14use tower::{Service, ServiceBuilder, ServiceExt};
15use tracing::trace;
16
17use self::{body::BoxBody, connection_info::ActiveConnections};
18
19pub mod body;
20mod config;
21mod connection_handler;
22mod connection_info;
23mod fuse;
24mod io;
25mod listener;
26
27pub use config::Config;
28pub use connection_info::{ConnectInfo, ConnectionId, ConnectionInfo, PeerCertificates};
29pub use listener::{Listener, ListenerExt};
30
31pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
32/// h2 alpn in plain format for rustls.
33const ALPN_H2: &[u8] = b"h2";
34/// h1 alpn in plain format for rustls.
35const ALPN_H1: &[u8] = b"http/1.1";
36
37#[derive(Default)]
38pub struct Builder {
39    config: Config,
40    tls_config: Option<tokio_rustls::rustls::ServerConfig>,
41}
42
43impl Builder {
44    pub fn new() -> Self {
45        Self::default()
46    }
47
48    pub fn config(mut self, config: Config) -> Self {
49        self.config = config;
50        self
51    }
52
53    pub fn tls_config(mut self, tls_config: tokio_rustls::rustls::ServerConfig) -> Self {
54        self.tls_config = Some(tls_config);
55        self
56    }
57
58    pub fn serve<A, S, ResponseBody>(
59        self,
60        addr: A,
61        service: S,
62    ) -> Result<ServerHandle<std::net::SocketAddr>, BoxError>
63    where
64        A: std::net::ToSocketAddrs,
65        S: Service<
66                Request<BoxBody>,
67                Response = Response<ResponseBody>,
68                Error: Into<BoxError>,
69                Future: Send,
70            > + Clone
71            + Send
72            + 'static,
73        ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
74    {
75        let listener = listener::TcpListenerWithOptions::new(
76            addr,
77            self.config.tcp_nodelay,
78            self.config.tcp_keepalive,
79        )?;
80
81        Self::serve_with_listener(self, listener, service)
82    }
83
84    fn serve_with_listener<L, S, ResponseBody>(
85        self,
86        listener: L,
87        service: S,
88    ) -> Result<ServerHandle<L::Addr>, BoxError>
89    where
90        L: Listener,
91        S: Service<
92                Request<BoxBody>,
93                Response = Response<ResponseBody>,
94                Error: Into<BoxError>,
95                Future: Send,
96            > + Clone
97            + Send
98            + 'static,
99        ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
100    {
101        let local_addr = listener.local_addr()?;
102        let graceful_shutdown_token = tokio_util::sync::CancellationToken::new();
103        let connections = ActiveConnections::default();
104
105        let tls_config = self.tls_config.map(|mut tls| {
106            tls.alpn_protocols.push(ALPN_H2.into());
107            if self.config.accept_http1 {
108                tls.alpn_protocols.push(ALPN_H1.into());
109            }
110            Arc::new(tls)
111        });
112
113        let (watch_sender, watch_receiver) = tokio::sync::watch::channel(());
114        let server = Server {
115            config: self.config,
116            tls_config,
117            listener,
118            local_addr: local_addr.clone(),
119            service: ServiceBuilder::new()
120                .layer(tower::util::BoxCloneService::layer())
121                .map_response(|response: Response<ResponseBody>| response.map(body::boxed))
122                .map_err(Into::into)
123                .service(service),
124            pending_connections: JoinSet::new(),
125            connection_handlers: JoinSet::new(),
126            connections: connections.clone(),
127            graceful_shutdown_token: graceful_shutdown_token.clone(),
128            _watch_receiver: watch_receiver,
129        };
130
131        let handle = ServerHandle(Arc::new(HandleInner {
132            local_addr,
133            connections,
134            graceful_shutdown_token,
135            watch_sender,
136        }));
137
138        tokio::spawn(server.serve());
139
140        Ok(handle)
141    }
142}
143
144#[derive(Debug)]
145pub struct ServerHandle<A = std::net::SocketAddr>(Arc<HandleInner<A>>);
146
147#[derive(Debug)]
148struct HandleInner<A = std::net::SocketAddr> {
149    /// The local address of the server.
150    local_addr: A,
151    connections: ActiveConnections<A>,
152    graceful_shutdown_token: tokio_util::sync::CancellationToken,
153    watch_sender: tokio::sync::watch::Sender<()>,
154}
155
156impl<A> ServerHandle<A> {
157    /// Returns the local address of the server
158    pub fn local_addr(&self) -> &A {
159        &self.0.local_addr
160    }
161
162    /// Trigger a graceful shutdown of the server, but don't wait till the
163    /// server has completed shutting down
164    pub fn trigger_shutdown(&self) {
165        self.0.graceful_shutdown_token.cancel();
166    }
167
168    /// Completes once the network has been shutdown.
169    ///
170    /// This explicitly *does not* trigger the network to shutdown, see
171    /// `trigger_shutdown` or `shutdown` if you want to trigger shutting
172    /// down the server.
173    pub async fn wait_for_shutdown(&self) {
174        self.0.watch_sender.closed().await
175    }
176
177    /// Triggers a shutdown of the server and waits for it to complete shutting
178    /// down.
179    pub async fn shutdown(&self) {
180        self.trigger_shutdown();
181        self.wait_for_shutdown().await;
182    }
183
184    /// Checks if the Server has been shutdown.
185    pub fn is_shutdown(&self) -> bool {
186        self.0.watch_sender.is_closed()
187    }
188
189    pub fn connections(
190        &self,
191    ) -> std::sync::RwLockReadGuard<'_, HashMap<ConnectionId, ConnectionInfo<A>>> {
192        self.0.connections.read().unwrap()
193    }
194
195    /// Returns the number of active connections the server is handling
196    pub fn number_of_connections(&self) -> usize {
197        self.connections().len()
198    }
199}
200
201type ConnectingOutput<Io, Addr> = Result<(ServerIo<Io>, Addr), crate::BoxError>;
202
203struct Server<L: Listener> {
204    config: Config,
205    tls_config: Option<Arc<tokio_rustls::rustls::ServerConfig>>,
206
207    listener: L,
208    local_addr: L::Addr,
209    service: tower::util::BoxCloneService<Request<BoxBody>, Response<BoxBody>, crate::BoxError>,
210
211    pending_connections: JoinSet<ConnectingOutput<L::Io, L::Addr>>,
212    connection_handlers: JoinSet<()>,
213    connections: ActiveConnections<L::Addr>,
214    graceful_shutdown_token: tokio_util::sync::CancellationToken,
215    // Used to signal to a ServerHandle when the server has completed shutting down
216    _watch_receiver: tokio::sync::watch::Receiver<()>,
217}
218
219impl<L> Server<L>
220where
221    L: Listener,
222{
223    async fn serve(mut self) -> Result<(), BoxError> {
224        loop {
225            tokio::select! {
226                _ = self.graceful_shutdown_token.cancelled() => {
227                    trace!("signal received, shutting down");
228                    break;
229                },
230                (io, remote_addr) = self.listener.accept() => {
231                    self.handle_incoming(io, remote_addr);
232                },
233                Some(maybe_connection) = self.pending_connections.join_next() => {
234                    // If a task panics, just propagate it
235                    let (io, remote_addr) = match maybe_connection.unwrap() {
236                        Ok((io, remote_addr)) => {
237                            (io, remote_addr)
238                        }
239                        Err(e) => {
240                            tracing::debug!(error = %e, "error accepting connection");
241                            continue;
242                        }
243                    };
244
245                    trace!("connection accepted");
246                    self.handle_connection(io, remote_addr);
247                },
248                Some(connection_handler_output) = self.connection_handlers.join_next() => {
249                    // If a task panics, just propagate it
250                    let _: () = connection_handler_output.unwrap();
251                },
252            }
253        }
254
255        // Shutting down, wait for all connection handlers to finish
256        self.shutdown().await;
257
258        Ok(())
259    }
260
261    fn handle_incoming(&mut self, io: L::Io, remote_addr: L::Addr) {
262        if let Some(tls) = self.tls_config.clone() {
263            let tls_acceptor = TlsAcceptor::from(tls);
264            let allow_insecure = self.config.allow_insecure;
265            self.pending_connections.spawn(async move {
266                if allow_insecure {
267                    // XXX: If we want to allow for supporting insecure traffic from other types of
268                    // io, we'll need to implement a generic peekable IO type
269                    if let Some(tcp) =
270                        <dyn std::any::Any>::downcast_ref::<tokio::net::TcpStream>(&io)
271                    {
272                        // Determine whether new connection is TLS.
273                        let mut buf = [0; 1];
274                        // `peek` blocks until at least some data is available, so if there is no error then
275                        // it must return the one byte we are requesting.
276                        tcp.peek(&mut buf).await?;
277                        // First byte of a TLS handshake is 0x16, so if it isn't 0x16 then its
278                        // insecure
279                        if buf != [0x16] {
280                            tracing::trace!("accepting insecure connection");
281                            return Ok((ServerIo::new_io(io), remote_addr));
282                        }
283                    } else {
284                        tracing::warn!("'allow_insecure' is configured but io type is not 'tokio::net::TcpStream'");
285                    }
286                }
287
288                tracing::trace!("accepting TLS connection");
289                let io = tls_acceptor.accept(io).await?;
290                Ok((ServerIo::new_tls_io(io), remote_addr))
291            });
292        } else {
293            self.handle_connection(ServerIo::new_io(io), remote_addr);
294        }
295    }
296
297    fn handle_connection(&mut self, io: ServerIo<L::Io>, remote_addr: L::Addr) {
298        let connection_shutdown_token = self.graceful_shutdown_token.child_token();
299        let connection_info = ConnectionInfo::new(
300            remote_addr,
301            io.peer_certs(),
302            connection_shutdown_token.clone(),
303        );
304        let connection_id = connection_info.id();
305        let connect_info = connection_info::ConnectInfo {
306            local_addr: self.local_addr.clone(),
307            remote_addr: connection_info.remote_address().clone(),
308        };
309        let peer_certificates = connection_info.peer_certificates().cloned();
310        let hyper_io = hyper_util::rt::TokioIo::new(io);
311
312        let hyper_svc = TowerToHyperService::new(self.service.clone().map_request(
313            move |mut request: Request<hyper::body::Incoming>| {
314                request.extensions_mut().insert(connect_info.clone());
315                if let Some(peer_certificates) = peer_certificates.clone() {
316                    request.extensions_mut().insert(peer_certificates);
317                }
318
319                request.map(body::boxed)
320            },
321        ));
322
323        self.connections
324            .write()
325            .unwrap()
326            .insert(connection_id, connection_info);
327        let on_connection_close = OnConnectionClose::new(connection_id, self.connections.clone());
328
329        self.connection_handlers
330            .spawn(connection_handler::serve_connection(
331                hyper_io,
332                hyper_svc,
333                self.config.connection_builder(),
334                connection_shutdown_token,
335                self.config.max_connection_age,
336                on_connection_close,
337            ));
338    }
339
340    async fn shutdown(mut self) {
341        // The time we are willing to wait for a connection to get gracefully shutdown
342        // before we attempt to forcefully shutdown all active connections
343        const CONNECTION_SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(1);
344
345        // Just to be careful make sure the token is canceled
346        self.graceful_shutdown_token.cancel();
347
348        // Terminate any in-progress pending connections
349        self.pending_connections.shutdown().await;
350
351        // Wait for all connection handlers to terminate
352        trace!(
353            "waiting for {} connections to close",
354            self.connection_handlers.len()
355        );
356
357        let graceful_shutdown =
358            async { while self.connection_handlers.join_next().await.is_some() {} };
359
360        if tokio::time::timeout(CONNECTION_SHUTDOWN_GRACE_PERIOD, graceful_shutdown)
361            .await
362            .is_err()
363        {
364            tracing::warn!(
365                "Failed to stop all connection handlers in {:?}. Forcing shutdown.",
366                CONNECTION_SHUTDOWN_GRACE_PERIOD
367            );
368            self.connection_handlers.shutdown().await;
369        }
370    }
371}
372
#[cfg(test)]
mod tests {
    use axum::Router;

    use super::*;

    /// Body returned by the test router for `GET /`.
    const MESSAGE: &str = "Hello, World!";

    /// Builds a router that answers `GET /` with [`MESSAGE`].
    fn hello_router() -> Router {
        Router::new().route("/", axum::routing::get(|| async { MESSAGE }))
    }

    /// Starts a server on an ephemeral localhost port and returns its handle
    /// together with the base URL to reach it.
    fn start_server() -> (ServerHandle, String) {
        let server = Builder::new()
            .serve(("localhost", 0), hello_router())
            .unwrap();
        let url = format!("http://{}", server.local_addr());
        (server, url)
    }

    #[tokio::test]
    async fn simple() {
        let (_server, url) = start_server();

        let body = reqwest::get(url).await.unwrap().bytes().await.unwrap();

        assert_eq!(body, MESSAGE.as_bytes());
    }

    #[tokio::test]
    async fn shutdown() {
        let (server, url) = start_server();

        let body = reqwest::get(url).await.unwrap().bytes().await.unwrap();

        // A request was just made, so exactly one connection should be tracked.
        assert_eq!(server.connections().len(), 1);
        assert_eq!(body, MESSAGE.as_bytes());
        assert!(!server.is_shutdown());

        server.shutdown().await;

        assert!(server.is_shutdown());
        // Once shutdown has completed, every connection must have been removed.
        assert_eq!(server.connections().len(), 0);
    }
}