Skip to main content

iota_http/
lib.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2025 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{collections::HashMap, sync::Arc, time::Duration};
6
7use connection_handler::OnConnectionClose;
8pub use http;
9use http::{Request, Response};
10use hyper_util::service::TowerToHyperService;
11use io::ServerIo;
12use tokio::task::JoinSet;
13use tokio_rustls::{TlsAcceptor, rustls};
14use tower::{Service, ServiceBuilder, ServiceExt};
15use tracing::trace;
16
17use self::{body::BoxBody, connection_info::ActiveConnections};
18
19pub mod body;
20mod config;
21mod connection_handler;
22mod connection_info;
23mod fuse;
24mod io;
25mod listener;
26
27pub use config::Config;
28pub use connection_info::{ConnectInfo, ConnectionId, ConnectionInfo, PeerCertificates};
29pub use listener::{Listener, ListenerExt};
30
/// Boxed, thread-safe error type used throughout this crate.
pub(crate) type BoxError = Box<dyn std::error::Error + Sync + Send>;
/// ALPN identifier for HTTP/2, as the raw bytes rustls expects.
const ALPN_H2: &[u8] = "h2".as_bytes();
/// ALPN identifier for HTTP/1.1, as the raw bytes rustls expects.
const ALPN_H1: &[u8] = "http/1.1".as_bytes();
36
37#[derive(Default)]
38pub struct Builder {
39    config: Config,
40    tls_config: Option<rustls::ServerConfig>,
41}
42
43impl Builder {
44    pub fn new() -> Self {
45        Self::default()
46    }
47
48    pub fn config(mut self, config: Config) -> Self {
49        self.config = config;
50        self
51    }
52
53    // Convenience method for configuring TLS with a single server cert
54    //
55    // Attempts to load PEM formatted files for the certificate chain and private
56    // key material from the provided file system paths.
57    pub fn tls_single_cert(
58        self,
59        cert_file: impl AsRef<std::path::Path>,
60        private_key_file: impl AsRef<std::path::Path>,
61    ) -> Result<Self, BoxError> {
62        let tls_config =
63            iota_tls::create_rustls_server_config_from_pem(cert_file, private_key_file)?;
64        Ok(self.tls_config(tls_config))
65    }
66
67    pub fn tls_config(mut self, tls_config: rustls::ServerConfig) -> Self {
68        self.tls_config = Some(tls_config);
69        self
70    }
71
72    pub fn serve<A, S, ResponseBody>(
73        self,
74        addr: A,
75        service: S,
76    ) -> Result<ServerHandle<std::net::SocketAddr>, BoxError>
77    where
78        A: std::net::ToSocketAddrs,
79        S: Service<
80                Request<BoxBody>,
81                Response = Response<ResponseBody>,
82                Error: Into<BoxError>,
83                Future: Send,
84            > + Clone
85            + Send
86            + 'static,
87        ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
88    {
89        let listener = listener::TcpListenerWithOptions::new(
90            addr,
91            self.config.tcp_nodelay,
92            self.config.tcp_keepalive,
93        )?;
94
95        Self::serve_with_listener(self, listener, service)
96    }
97
98    fn serve_with_listener<L, S, ResponseBody>(
99        self,
100        listener: L,
101        service: S,
102    ) -> Result<ServerHandle<L::Addr>, BoxError>
103    where
104        L: Listener,
105        S: Service<
106                Request<BoxBody>,
107                Response = Response<ResponseBody>,
108                Error: Into<BoxError>,
109                Future: Send,
110            > + Clone
111            + Send
112            + 'static,
113        ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
114    {
115        let local_addr = listener.local_addr()?;
116        let graceful_shutdown_token = tokio_util::sync::CancellationToken::new();
117        let connections = ActiveConnections::default();
118
119        let tls_config = self.tls_config.map(|mut tls| {
120            tls.alpn_protocols.push(ALPN_H2.into());
121            if self.config.accept_http1 {
122                tls.alpn_protocols.push(ALPN_H1.into());
123            }
124            Arc::new(tls)
125        });
126
127        let (watch_sender, watch_receiver) = tokio::sync::watch::channel(());
128        let server = Server {
129            config: self.config,
130            tls_config,
131            listener,
132            local_addr: local_addr.clone(),
133            service: ServiceBuilder::new()
134                .layer(tower::util::BoxCloneService::layer())
135                .map_response(|response: Response<ResponseBody>| response.map(body::boxed))
136                .map_err(Into::into)
137                .service(service),
138            pending_connections: JoinSet::new(),
139            connection_handlers: JoinSet::new(),
140            connections: connections.clone(),
141            graceful_shutdown_token: graceful_shutdown_token.clone(),
142            _watch_receiver: watch_receiver,
143        };
144
145        let handle = ServerHandle(Arc::new(HandleInner {
146            local_addr,
147            connections,
148            graceful_shutdown_token,
149            watch_sender,
150        }));
151
152        tokio::spawn(server.serve());
153
154        Ok(handle)
155    }
156}
157
/// Handle to a running server, as returned by [`Builder::serve`].
///
/// The handle wraps its state in an `Arc`, so clones share the same server.
/// It can be used to query the bound address and the active connections, and
/// to shut the server down.
#[derive(Debug)]
pub struct ServerHandle<A = std::net::SocketAddr>(Arc<HandleInner<A>>);

/// State shared by every clone of a [`ServerHandle`].
#[derive(Debug)]
struct HandleInner<A = std::net::SocketAddr> {
    /// The local address of the server.
    local_addr: A,
    /// Table of active connections, shared with the server task.
    connections: ActiveConnections<A>,
    /// Cancelling this token asks the server to shut down gracefully.
    graceful_shutdown_token: tokio_util::sync::CancellationToken,
    /// Its channel reports closed once the server task drops the matching
    /// receiver; this is how completion of shutdown is observed.
    watch_sender: tokio::sync::watch::Sender<()>,
}
169
170impl<A> ServerHandle<A> {
171    /// Returns the local address of the server
172    pub fn local_addr(&self) -> &A {
173        &self.0.local_addr
174    }
175
176    /// Trigger a graceful shutdown of the server, but don't wait till the
177    /// server has completed shutting down
178    pub fn trigger_shutdown(&self) {
179        self.0.graceful_shutdown_token.cancel();
180    }
181
182    /// Completes once the network has been shutdown.
183    ///
184    /// This explicitly *does not* trigger the network to shutdown, see
185    /// `trigger_shutdown` or `shutdown` if you want to trigger shutting
186    /// down the server.
187    pub async fn wait_for_shutdown(&self) {
188        self.0.watch_sender.closed().await
189    }
190
191    /// Triggers a shutdown of the server and waits for it to complete shutting
192    /// down.
193    pub async fn shutdown(&self) {
194        self.trigger_shutdown();
195        self.wait_for_shutdown().await;
196    }
197
198    /// Checks if the Server has been shutdown.
199    pub fn is_shutdown(&self) -> bool {
200        self.0.watch_sender.is_closed()
201    }
202
203    pub fn connections(
204        &self,
205    ) -> std::sync::RwLockReadGuard<'_, HashMap<ConnectionId, ConnectionInfo<A>>> {
206        self.0.connections.read().unwrap()
207    }
208
209    /// Returns the number of active connections the server is handling
210    pub fn number_of_connections(&self) -> usize {
211        self.connections().len()
212    }
213}
214
215impl<A> Clone for ServerHandle<A> {
216    fn clone(&self) -> Self {
217        Self(self.0.clone())
218    }
219}
220
221type ConnectingOutput<Io, Addr> = Result<(ServerIo<Io>, Addr), crate::BoxError>;
222
/// State owned by the spawned server task.
///
/// It holds the listener, the type-erased service, and the task sets for
/// in-flight connection setup and for live connections.
///
/// Fields are dropped in declaration order, so `_watch_receiver` (declared
/// last) is dropped after everything else.
struct Server<L: Listener> {
    config: Config,
    /// `None` means accepted connections are served without TLS.
    tls_config: Option<Arc<rustls::ServerConfig>>,

    listener: L,
    /// Local address of the listener; copied into each request's `ConnectInfo`.
    local_addr: L::Addr,
    /// The user's service with boxed request/response bodies and errors.
    service: tower::util::BoxCloneService<Request<BoxBody>, Response<BoxBody>, crate::BoxError>,

    /// Tasks for connections that are still being set up (TLS detection/handshake).
    pending_connections: JoinSet<ConnectingOutput<L::Io, L::Addr>>,
    /// One task per established connection, serving HTTP on it.
    connection_handlers: JoinSet<()>,
    /// Table of active connections, shared with every `ServerHandle`.
    connections: ActiveConnections<L::Addr>,
    /// Cancelled to request a graceful shutdown; per-connection tokens are its children.
    graceful_shutdown_token: tokio_util::sync::CancellationToken,
    // Used to signal to a ServerHandle when the server has completed shutting down:
    // the handle's watch sender reports closed once this receiver is dropped.
    _watch_receiver: tokio::sync::watch::Receiver<()>,
}
238
239impl<L> Server<L>
240where
241    L: Listener,
242{
243    async fn serve(mut self) -> Result<(), BoxError> {
244        loop {
245            tokio::select! {
246                _ = self.graceful_shutdown_token.cancelled() => {
247                    trace!("signal received, shutting down");
248                    break;
249                },
250                (io, remote_addr) = self.listener.accept() => {
251                    self.handle_incoming(io, remote_addr);
252                },
253                Some(maybe_connection) = self.pending_connections.join_next() => {
254                    // If a task panics, just propagate it
255                    let (io, remote_addr) = match maybe_connection.unwrap() {
256                        Ok((io, remote_addr)) => {
257                            (io, remote_addr)
258                        }
259                        Err(e) => {
260                            tracing::debug!(error = %e, "error accepting connection");
261                            continue;
262                        }
263                    };
264
265                    trace!("connection accepted");
266                    self.handle_connection(io, remote_addr);
267                },
268                Some(connection_handler_output) = self.connection_handlers.join_next() => {
269                    // If a task panics, just propagate it
270                    let _: () = connection_handler_output.unwrap();
271                },
272            }
273        }
274
275        // Shutting down, wait for all connection handlers to finish
276        self.shutdown().await;
277
278        Ok(())
279    }
280
281    fn handle_incoming(&mut self, io: L::Io, remote_addr: L::Addr) {
282        if let Some(tls) = self.tls_config.clone() {
283            let tls_acceptor = TlsAcceptor::from(tls);
284            let allow_insecure = self.config.allow_insecure;
285            self.pending_connections.spawn(async move {
286                if allow_insecure {
287                    // XXX: If we want to allow for supporting insecure traffic from other types of
288                    // io, we'll need to implement a generic peekable IO type
289                    if let Some(tcp) =
290                        <dyn std::any::Any>::downcast_ref::<tokio::net::TcpStream>(&io)
291                    {
292                        // Determine whether new connection is TLS.
293                        let mut buf = [0; 1];
294                        // `peek` blocks until at least some data is available, so if there is no error then
295                        // it must return the one byte we are requesting.
296                        tcp.peek(&mut buf).await?;
297                        // First byte of a TLS handshake is 0x16, so if it isn't 0x16 then its
298                        // insecure
299                        if buf != [0x16] {
300                            tracing::trace!("accepting insecure connection");
301                            return Ok((ServerIo::new_io(io), remote_addr));
302                        }
303                    } else {
304                        tracing::warn!("'allow_insecure' is configured but io type is not 'tokio::net::TcpStream'");
305                    }
306                }
307
308                tracing::trace!("accepting TLS connection");
309                let io = tls_acceptor.accept(io).await?;
310                Ok((ServerIo::new_tls_io(io), remote_addr))
311            });
312        } else {
313            self.handle_connection(ServerIo::new_io(io), remote_addr);
314        }
315    }
316
317    fn handle_connection(&mut self, io: ServerIo<L::Io>, remote_addr: L::Addr) {
318        let connection_shutdown_token = self.graceful_shutdown_token.child_token();
319        let connection_info = ConnectionInfo::new(
320            remote_addr,
321            io.peer_certs(),
322            connection_shutdown_token.clone(),
323        );
324        let connection_id = connection_info.id();
325        let connect_info = connection_info::ConnectInfo {
326            local_addr: self.local_addr.clone(),
327            remote_addr: connection_info.remote_address().clone(),
328        };
329        let peer_certificates = connection_info.peer_certificates().cloned();
330        let hyper_io = hyper_util::rt::TokioIo::new(io);
331
332        let hyper_svc = TowerToHyperService::new(self.service.clone().map_request(
333            move |mut request: Request<hyper::body::Incoming>| {
334                request.extensions_mut().insert(connect_info.clone());
335                if let Some(peer_certificates) = peer_certificates.clone() {
336                    request.extensions_mut().insert(peer_certificates);
337                }
338
339                request.map(body::boxed)
340            },
341        ));
342
343        self.connections
344            .write()
345            .unwrap()
346            .insert(connection_id, connection_info);
347        let on_connection_close = OnConnectionClose::new(connection_id, self.connections.clone());
348
349        self.connection_handlers
350            .spawn(connection_handler::serve_connection(
351                hyper_io,
352                hyper_svc,
353                self.config.connection_builder(),
354                connection_shutdown_token,
355                self.config.max_connection_age,
356                on_connection_close,
357            ));
358    }
359
360    async fn shutdown(mut self) {
361        // The time we are willing to wait for a connection to get gracefully shutdown
362        // before we attempt to forcefully shutdown all active connections
363        const CONNECTION_SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(1);
364
365        // Just to be careful make sure the token is canceled
366        self.graceful_shutdown_token.cancel();
367
368        // Terminate any in-progress pending connections
369        self.pending_connections.shutdown().await;
370
371        // Wait for all connection handlers to terminate
372        trace!(
373            "waiting for {} connections to close",
374            self.connection_handlers.len()
375        );
376
377        let graceful_shutdown =
378            async { while self.connection_handlers.join_next().await.is_some() {} };
379
380        if tokio::time::timeout(CONNECTION_SHUTDOWN_GRACE_PERIOD, graceful_shutdown)
381            .await
382            .is_err()
383        {
384            tracing::warn!(
385                "Failed to stop all connection handlers in {:?}. Forcing shutdown.",
386                CONNECTION_SHUTDOWN_GRACE_PERIOD
387            );
388            self.connection_handlers.shutdown().await;
389        }
390    }
391}
392
#[cfg(test)]
mod tests {
    use axum::Router;

    use super::*;

    /// Body served by the test route.
    const MESSAGE: &str = "Hello, World!";

    /// Router with a single `GET /` route that answers with [`MESSAGE`].
    fn hello_router() -> Router {
        Router::new().route("/", axum::routing::get(|| async { MESSAGE }))
    }

    /// Starts a server for [`hello_router`] on an ephemeral localhost port.
    fn start_server() -> ServerHandle {
        Builder::new()
            .serve(("localhost", 0), hello_router())
            .unwrap()
    }

    /// Issues `GET /` against the server behind `handle` and returns the body.
    async fn fetch_root(handle: &ServerHandle) -> Vec<u8> {
        let url = format!("http://{}", handle.local_addr());
        reqwest::get(url)
            .await
            .unwrap()
            .bytes()
            .await
            .unwrap()
            .to_vec()
    }

    #[tokio::test]
    async fn simple() {
        let handle = start_server();

        let body = fetch_root(&handle).await;

        assert_eq!(body, MESSAGE.as_bytes());
    }

    #[tokio::test]
    async fn shutdown() {
        let handle = start_server();

        let body = fetch_root(&handle).await;

        // The request above should have left exactly one active connection.
        assert_eq!(handle.connections().len(), 1);
        assert_eq!(body, MESSAGE.as_bytes());
        assert!(!handle.is_shutdown());

        handle.shutdown().await;

        assert!(handle.is_shutdown());
        // Once the server has shut down, no connections should remain.
        assert_eq!(handle.connections().len(), 0);
    }
}