1use std::{collections::HashMap, sync::Arc, time::Duration};
6
7use connection_handler::OnConnectionClose;
8pub use http;
9use http::{Request, Response};
10use hyper_util::service::TowerToHyperService;
11use io::ServerIo;
12use tokio::task::JoinSet;
13use tokio_rustls::{TlsAcceptor, rustls};
14use tower::{Service, ServiceBuilder, ServiceExt};
15use tracing::trace;
16
17use self::{body::BoxBody, connection_info::ActiveConnections};
18
19pub mod body;
20mod config;
21mod connection_handler;
22mod connection_info;
23mod fuse;
24mod io;
25mod listener;
26
27pub use config::Config;
28pub use connection_info::{ConnectInfo, ConnectionId, ConnectionInfo, PeerCertificates};
29pub use listener::{Listener, ListenerExt};
30
/// Crate-visible alias for a boxed error that is `Send + Sync`.
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// ALPN protocol identifier for HTTP/2; added to every TLS configuration handed to the server.
const ALPN_H2: &[u8] = b"h2";
/// ALPN protocol identifier for HTTP/1.1; added only when the config accepts HTTP/1.
const ALPN_H1: &[u8] = b"http/1.1";
36
/// Builder used to configure and start a server.
///
/// The derived `Default` yields the default [`Config`] and no TLS configuration.
#[derive(Default)]
pub struct Builder {
    /// Server settings; replaced wholesale by [`Builder::config`].
    config: Config,
    /// TLS settings. When `None`, accepted connections are served without TLS.
    tls_config: Option<rustls::ServerConfig>,
}
42
43impl Builder {
44 pub fn new() -> Self {
45 Self::default()
46 }
47
48 pub fn config(mut self, config: Config) -> Self {
49 self.config = config;
50 self
51 }
52
53 pub fn tls_single_cert(
58 self,
59 cert_file: impl AsRef<std::path::Path>,
60 private_key_file: impl AsRef<std::path::Path>,
61 ) -> Result<Self, BoxError> {
62 let tls_config =
63 iota_tls::create_rustls_server_config_from_pem(cert_file, private_key_file)?;
64 Ok(self.tls_config(tls_config))
65 }
66
67 pub fn tls_config(mut self, tls_config: rustls::ServerConfig) -> Self {
68 self.tls_config = Some(tls_config);
69 self
70 }
71
72 pub fn serve<A, S, ResponseBody>(
73 self,
74 addr: A,
75 service: S,
76 ) -> Result<ServerHandle<std::net::SocketAddr>, BoxError>
77 where
78 A: std::net::ToSocketAddrs,
79 S: Service<
80 Request<BoxBody>,
81 Response = Response<ResponseBody>,
82 Error: Into<BoxError>,
83 Future: Send,
84 > + Clone
85 + Send
86 + 'static,
87 ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
88 {
89 let listener = listener::TcpListenerWithOptions::new(
90 addr,
91 self.config.tcp_nodelay,
92 self.config.tcp_keepalive,
93 )?;
94
95 Self::serve_with_listener(self, listener, service)
96 }
97
98 fn serve_with_listener<L, S, ResponseBody>(
99 self,
100 listener: L,
101 service: S,
102 ) -> Result<ServerHandle<L::Addr>, BoxError>
103 where
104 L: Listener,
105 S: Service<
106 Request<BoxBody>,
107 Response = Response<ResponseBody>,
108 Error: Into<BoxError>,
109 Future: Send,
110 > + Clone
111 + Send
112 + 'static,
113 ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
114 {
115 let local_addr = listener.local_addr()?;
116 let graceful_shutdown_token = tokio_util::sync::CancellationToken::new();
117 let connections = ActiveConnections::default();
118
119 let tls_config = self.tls_config.map(|mut tls| {
120 tls.alpn_protocols.push(ALPN_H2.into());
121 if self.config.accept_http1 {
122 tls.alpn_protocols.push(ALPN_H1.into());
123 }
124 Arc::new(tls)
125 });
126
127 let (watch_sender, watch_receiver) = tokio::sync::watch::channel(());
128 let server = Server {
129 config: self.config,
130 tls_config,
131 listener,
132 local_addr: local_addr.clone(),
133 service: ServiceBuilder::new()
134 .layer(tower::util::BoxCloneService::layer())
135 .map_response(|response: Response<ResponseBody>| response.map(body::boxed))
136 .map_err(Into::into)
137 .service(service),
138 pending_connections: JoinSet::new(),
139 connection_handlers: JoinSet::new(),
140 connections: connections.clone(),
141 graceful_shutdown_token: graceful_shutdown_token.clone(),
142 _watch_receiver: watch_receiver,
143 };
144
145 let handle = ServerHandle(Arc::new(HandleInner {
146 local_addr,
147 connections,
148 graceful_shutdown_token,
149 watch_sender,
150 }));
151
152 tokio::spawn(server.serve());
153
154 Ok(handle)
155 }
156}
157
/// Handle to a running server.
///
/// The state lives behind an `Arc`, so clones of the handle refer to the same server.
#[derive(Debug)]
pub struct ServerHandle<A = std::net::SocketAddr>(Arc<HandleInner<A>>);
160
/// State shared by every clone of a [`ServerHandle`].
#[derive(Debug)]
struct HandleInner<A = std::net::SocketAddr> {
    /// Address reported by the listener when the server was started.
    local_addr: A,
    /// Registry of active connections, shared with the server task.
    connections: ActiveConnections<A>,
    /// Token cancelled to ask the server to shut down.
    graceful_shutdown_token: tokio_util::sync::CancellationToken,
    /// Sending half of a watch channel whose receiver is owned by the server; it is used only
    /// to detect when that receiver has been dropped.
    watch_sender: tokio::sync::watch::Sender<()>,
}
169
170impl<A> ServerHandle<A> {
171 pub fn local_addr(&self) -> &A {
173 &self.0.local_addr
174 }
175
176 pub fn trigger_shutdown(&self) {
179 self.0.graceful_shutdown_token.cancel();
180 }
181
182 pub async fn wait_for_shutdown(&self) {
188 self.0.watch_sender.closed().await
189 }
190
191 pub async fn shutdown(&self) {
194 self.trigger_shutdown();
195 self.wait_for_shutdown().await;
196 }
197
198 pub fn is_shutdown(&self) -> bool {
200 self.0.watch_sender.is_closed()
201 }
202
203 pub fn connections(
204 &self,
205 ) -> std::sync::RwLockReadGuard<'_, HashMap<ConnectionId, ConnectionInfo<A>>> {
206 self.0.connections.read().unwrap()
207 }
208
209 pub fn number_of_connections(&self) -> usize {
211 self.connections().len()
212 }
213}
214
215impl<A> Clone for ServerHandle<A> {
216 fn clone(&self) -> Self {
217 Self(self.0.clone())
218 }
219}
220
/// Output of a pending-connection task: the wrapped IO together with the remote address, or
/// the error that ended connection setup.
type ConnectingOutput<Io, Addr> = Result<(ServerIo<Io>, Addr), crate::BoxError>;
222
/// State owned by the spawned server task: the listener, the type-erased service and the
/// bookkeeping for in-flight connections.
struct Server<L: Listener> {
    /// Server settings supplied through the builder.
    config: Config,
    /// Finalized TLS configuration; `None` means connections are served without TLS.
    tls_config: Option<Arc<rustls::ServerConfig>>,

    /// Source of incoming connections.
    listener: L,
    /// Address reported by the listener when the server was created.
    local_addr: L::Addr,
    /// Boxed, cloneable service; cloned once for every connection that is served.
    service: tower::util::BoxCloneService<Request<BoxBody>, Response<BoxBody>, crate::BoxError>,

    /// Tasks still performing connection setup (TLS detection and handshake).
    pending_connections: JoinSet<ConnectingOutput<L::Io, L::Addr>>,
    /// Tasks serving established connections.
    connection_handlers: JoinSet<()>,
    /// Registry of active connections, shared with the server handle.
    connections: ActiveConnections<L::Addr>,
    /// Cancelled to request shutdown; each connection receives a child of this token.
    graceful_shutdown_token: tokio_util::sync::CancellationToken,
    /// Never read. It is held so that the paired sender in the handle sees the channel close
    /// once this `Server` is dropped.
    _watch_receiver: tokio::sync::watch::Receiver<()>,
}
238
239impl<L> Server<L>
240where
241 L: Listener,
242{
243 async fn serve(mut self) -> Result<(), BoxError> {
244 loop {
245 tokio::select! {
246 _ = self.graceful_shutdown_token.cancelled() => {
247 trace!("signal received, shutting down");
248 break;
249 },
250 (io, remote_addr) = self.listener.accept() => {
251 self.handle_incoming(io, remote_addr);
252 },
253 Some(maybe_connection) = self.pending_connections.join_next() => {
254 let (io, remote_addr) = match maybe_connection.unwrap() {
256 Ok((io, remote_addr)) => {
257 (io, remote_addr)
258 }
259 Err(e) => {
260 tracing::debug!(error = %e, "error accepting connection");
261 continue;
262 }
263 };
264
265 trace!("connection accepted");
266 self.handle_connection(io, remote_addr);
267 },
268 Some(connection_handler_output) = self.connection_handlers.join_next() => {
269 let _: () = connection_handler_output.unwrap();
271 },
272 }
273 }
274
275 self.shutdown().await;
277
278 Ok(())
279 }
280
281 fn handle_incoming(&mut self, io: L::Io, remote_addr: L::Addr) {
282 if let Some(tls) = self.tls_config.clone() {
283 let tls_acceptor = TlsAcceptor::from(tls);
284 let allow_insecure = self.config.allow_insecure;
285 self.pending_connections.spawn(async move {
286 if allow_insecure {
287 if let Some(tcp) =
290 <dyn std::any::Any>::downcast_ref::<tokio::net::TcpStream>(&io)
291 {
292 let mut buf = [0; 1];
294 tcp.peek(&mut buf).await?;
297 if buf != [0x16] {
300 tracing::trace!("accepting insecure connection");
301 return Ok((ServerIo::new_io(io), remote_addr));
302 }
303 } else {
304 tracing::warn!("'allow_insecure' is configured but io type is not 'tokio::net::TcpStream'");
305 }
306 }
307
308 tracing::trace!("accepting TLS connection");
309 let io = tls_acceptor.accept(io).await?;
310 Ok((ServerIo::new_tls_io(io), remote_addr))
311 });
312 } else {
313 self.handle_connection(ServerIo::new_io(io), remote_addr);
314 }
315 }
316
317 fn handle_connection(&mut self, io: ServerIo<L::Io>, remote_addr: L::Addr) {
318 let connection_shutdown_token = self.graceful_shutdown_token.child_token();
319 let connection_info = ConnectionInfo::new(
320 remote_addr,
321 io.peer_certs(),
322 connection_shutdown_token.clone(),
323 );
324 let connection_id = connection_info.id();
325 let connect_info = connection_info::ConnectInfo {
326 local_addr: self.local_addr.clone(),
327 remote_addr: connection_info.remote_address().clone(),
328 };
329 let peer_certificates = connection_info.peer_certificates().cloned();
330 let hyper_io = hyper_util::rt::TokioIo::new(io);
331
332 let hyper_svc = TowerToHyperService::new(self.service.clone().map_request(
333 move |mut request: Request<hyper::body::Incoming>| {
334 request.extensions_mut().insert(connect_info.clone());
335 if let Some(peer_certificates) = peer_certificates.clone() {
336 request.extensions_mut().insert(peer_certificates);
337 }
338
339 request.map(body::boxed)
340 },
341 ));
342
343 self.connections
344 .write()
345 .unwrap()
346 .insert(connection_id, connection_info);
347 let on_connection_close = OnConnectionClose::new(connection_id, self.connections.clone());
348
349 self.connection_handlers
350 .spawn(connection_handler::serve_connection(
351 hyper_io,
352 hyper_svc,
353 self.config.connection_builder(),
354 connection_shutdown_token,
355 self.config.max_connection_age,
356 on_connection_close,
357 ));
358 }
359
360 async fn shutdown(mut self) {
361 const CONNECTION_SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(1);
364
365 self.graceful_shutdown_token.cancel();
367
368 self.pending_connections.shutdown().await;
370
371 trace!(
373 "waiting for {} connections to close",
374 self.connection_handlers.len()
375 );
376
377 let graceful_shutdown =
378 async { while self.connection_handlers.join_next().await.is_some() {} };
379
380 if tokio::time::timeout(CONNECTION_SHUTDOWN_GRACE_PERIOD, graceful_shutdown)
381 .await
382 .is_err()
383 {
384 tracing::warn!(
385 "Failed to stop all connection handlers in {:?}. Forcing shutdown.",
386 CONNECTION_SHUTDOWN_GRACE_PERIOD
387 );
388 self.connection_handlers.shutdown().await;
389 }
390 }
391}
392
#[cfg(test)]
mod tests {
    use axum::Router;

    use super::*;

    /// Body served by the test route.
    const MESSAGE: &str = "Hello, World!";

    /// Starts a server without TLS on an OS-assigned port that answers `GET /` with `MESSAGE`.
    ///
    /// Must be called from inside a Tokio runtime.
    fn start_server() -> ServerHandle {
        let app = Router::new().route("/", axum::routing::get(|| async { MESSAGE }));

        Builder::new().serve(("localhost", 0), app).unwrap()
    }

    /// Performs `GET /` against the server behind `handle` and returns the response body.
    async fn fetch_root(handle: &ServerHandle) -> bytes::Bytes {
        let url = format!("http://{}", handle.local_addr());

        reqwest::get(url).await.unwrap().bytes().await.unwrap()
    }

    /// The server answers a plain HTTP request.
    #[tokio::test]
    async fn simple() {
        let handle = start_server();

        let response = fetch_root(&handle).await;

        assert_eq!(response, MESSAGE.as_bytes());
    }

    /// A connection is registered while in use and gone once shutdown has completed.
    #[tokio::test]
    async fn shutdown() {
        let handle = start_server();

        let response = fetch_root(&handle).await;

        assert_eq!(handle.connections().len(), 1);

        assert_eq!(response, MESSAGE.as_bytes());

        assert!(!handle.is_shutdown());

        handle.shutdown().await;

        assert!(handle.is_shutdown());

        assert_eq!(handle.connections().len(), 0);
    }
}