1use std::{collections::HashMap, sync::Arc, time::Duration};
6
7use connection_handler::OnConnectionClose;
8pub use http;
9use http::{Request, Response};
10use hyper_util::service::TowerToHyperService;
11use io::ServerIo;
12use tokio::task::JoinSet;
13use tokio_rustls::TlsAcceptor;
14use tower::{Service, ServiceBuilder, ServiceExt};
15use tracing::trace;
16
17use self::{body::BoxBody, connection_info::ActiveConnections};
18
19pub mod body;
20mod config;
21mod connection_handler;
22mod connection_info;
23mod fuse;
24mod io;
25mod listener;
26
27pub use config::Config;
28pub use connection_info::{ConnectInfo, ConnectionId, ConnectionInfo, PeerCertificates};
29pub use listener::{Listener, ListenerExt};
30
/// Type-erased, thread-safe error used throughout the server.
pub(crate) type BoxError = Box<dyn std::error::Error + Sync + Send>;
/// ALPN protocol identifier for HTTP/2 over TLS; always advertised when TLS is enabled.
const ALPN_H2: &[u8] = "h2".as_bytes();
/// ALPN protocol identifier for HTTP/1.1; advertised only when `accept_http1` is enabled.
const ALPN_H1: &[u8] = "http/1.1".as_bytes();
36
37#[derive(Default)]
38pub struct Builder {
39 config: Config,
40 tls_config: Option<tokio_rustls::rustls::ServerConfig>,
41}
42
43impl Builder {
44 pub fn new() -> Self {
45 Self::default()
46 }
47
48 pub fn config(mut self, config: Config) -> Self {
49 self.config = config;
50 self
51 }
52
53 pub fn tls_config(mut self, tls_config: tokio_rustls::rustls::ServerConfig) -> Self {
54 self.tls_config = Some(tls_config);
55 self
56 }
57
58 pub fn serve<A, S, ResponseBody>(
59 self,
60 addr: A,
61 service: S,
62 ) -> Result<ServerHandle<std::net::SocketAddr>, BoxError>
63 where
64 A: std::net::ToSocketAddrs,
65 S: Service<
66 Request<BoxBody>,
67 Response = Response<ResponseBody>,
68 Error: Into<BoxError>,
69 Future: Send,
70 > + Clone
71 + Send
72 + 'static,
73 ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
74 {
75 let listener = listener::TcpListenerWithOptions::new(
76 addr,
77 self.config.tcp_nodelay,
78 self.config.tcp_keepalive,
79 )?;
80
81 Self::serve_with_listener(self, listener, service)
82 }
83
84 fn serve_with_listener<L, S, ResponseBody>(
85 self,
86 listener: L,
87 service: S,
88 ) -> Result<ServerHandle<L::Addr>, BoxError>
89 where
90 L: Listener,
91 S: Service<
92 Request<BoxBody>,
93 Response = Response<ResponseBody>,
94 Error: Into<BoxError>,
95 Future: Send,
96 > + Clone
97 + Send
98 + 'static,
99 ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
100 {
101 let local_addr = listener.local_addr()?;
102 let graceful_shutdown_token = tokio_util::sync::CancellationToken::new();
103 let connections = ActiveConnections::default();
104
105 let tls_config = self.tls_config.map(|mut tls| {
106 tls.alpn_protocols.push(ALPN_H2.into());
107 if self.config.accept_http1 {
108 tls.alpn_protocols.push(ALPN_H1.into());
109 }
110 Arc::new(tls)
111 });
112
113 let (watch_sender, watch_receiver) = tokio::sync::watch::channel(());
114 let server = Server {
115 config: self.config,
116 tls_config,
117 listener,
118 local_addr: local_addr.clone(),
119 service: ServiceBuilder::new()
120 .layer(tower::util::BoxCloneService::layer())
121 .map_response(|response: Response<ResponseBody>| response.map(body::boxed))
122 .map_err(Into::into)
123 .service(service),
124 pending_connections: JoinSet::new(),
125 connection_handlers: JoinSet::new(),
126 connections: connections.clone(),
127 graceful_shutdown_token: graceful_shutdown_token.clone(),
128 _watch_receiver: watch_receiver,
129 };
130
131 let handle = ServerHandle(Arc::new(HandleInner {
132 local_addr,
133 connections,
134 graceful_shutdown_token,
135 watch_sender,
136 }));
137
138 tokio::spawn(server.serve());
139
140 Ok(handle)
141 }
142}
143
144#[derive(Debug)]
145pub struct ServerHandle<A = std::net::SocketAddr>(Arc<HandleInner<A>>);
146
147#[derive(Debug)]
148struct HandleInner<A = std::net::SocketAddr> {
149 local_addr: A,
151 connections: ActiveConnections<A>,
152 graceful_shutdown_token: tokio_util::sync::CancellationToken,
153 watch_sender: tokio::sync::watch::Sender<()>,
154}
155
156impl<A> ServerHandle<A> {
157 pub fn local_addr(&self) -> &A {
159 &self.0.local_addr
160 }
161
162 pub fn trigger_shutdown(&self) {
165 self.0.graceful_shutdown_token.cancel();
166 }
167
168 pub async fn wait_for_shutdown(&self) {
174 self.0.watch_sender.closed().await
175 }
176
177 pub async fn shutdown(&self) {
180 self.trigger_shutdown();
181 self.wait_for_shutdown().await;
182 }
183
184 pub fn is_shutdown(&self) -> bool {
186 self.0.watch_sender.is_closed()
187 }
188
189 pub fn connections(
190 &self,
191 ) -> std::sync::RwLockReadGuard<'_, HashMap<ConnectionId, ConnectionInfo<A>>> {
192 self.0.connections.read().unwrap()
193 }
194
195 pub fn number_of_connections(&self) -> usize {
197 self.connections().len()
198 }
199}
200
201impl<A> Clone for ServerHandle<A> {
202 fn clone(&self) -> Self {
203 Self(self.0.clone())
204 }
205}
206
207type ConnectingOutput<Io, Addr> = Result<(ServerIo<Io>, Addr), crate::BoxError>;
208
/// State of a running server, owned by the task spawned in `Builder::serve_with_listener`.
///
/// Field order matters: Rust drops struct fields in declaration order. `_watch_receiver` is
/// declared last, so when the server is dropped (at the end of `Server::shutdown`, which
/// consumes `self`) it is released after the listener and both task sets. As a result,
/// `ServerHandle::wait_for_shutdown` returns only after those have been dropped, because it
/// awaits the matching sender's `closed()`. Keep `_watch_receiver` last.
struct Server<L: Listener> {
    /// Settings supplied to the builder.
    config: Config,
    /// TLS configuration with ALPN protocols filled in; `None` serves plaintext only.
    tls_config: Option<Arc<tokio_rustls::rustls::ServerConfig>>,

    /// Source of incoming connections.
    listener: L,
    /// Local address of `listener`, copied into every connection's `ConnectInfo`.
    local_addr: L::Addr,
    /// User service, type-erased and adapted to `BoxBody` request/response bodies.
    service: tower::util::BoxCloneService<Request<BoxBody>, Response<BoxBody>, crate::BoxError>,

    /// Tasks performing the TLS handshake (or plaintext sniffing) for accepted sockets.
    pending_connections: JoinSet<ConnectingOutput<L::Io, L::Addr>>,
    /// One task per established connection, serving HTTP on it.
    connection_handlers: JoinSet<()>,
    /// Table of live connections, shared with every `ServerHandle`.
    connections: ActiveConnections<L::Addr>,
    /// Parent token; cancelling it stops the accept loop and every connection's child token.
    graceful_shutdown_token: tokio_util::sync::CancellationToken,
    /// Never read; held only so that dropping it signals `ServerHandle::wait_for_shutdown`.
    _watch_receiver: tokio::sync::watch::Receiver<()>,
}
224
225impl<L> Server<L>
226where
227 L: Listener,
228{
229 async fn serve(mut self) -> Result<(), BoxError> {
230 loop {
231 tokio::select! {
232 _ = self.graceful_shutdown_token.cancelled() => {
233 trace!("signal received, shutting down");
234 break;
235 },
236 (io, remote_addr) = self.listener.accept() => {
237 self.handle_incoming(io, remote_addr);
238 },
239 Some(maybe_connection) = self.pending_connections.join_next() => {
240 let (io, remote_addr) = match maybe_connection.unwrap() {
242 Ok((io, remote_addr)) => {
243 (io, remote_addr)
244 }
245 Err(e) => {
246 tracing::debug!(error = %e, "error accepting connection");
247 continue;
248 }
249 };
250
251 trace!("connection accepted");
252 self.handle_connection(io, remote_addr);
253 },
254 Some(connection_handler_output) = self.connection_handlers.join_next() => {
255 let _: () = connection_handler_output.unwrap();
257 },
258 }
259 }
260
261 self.shutdown().await;
263
264 Ok(())
265 }
266
267 fn handle_incoming(&mut self, io: L::Io, remote_addr: L::Addr) {
268 if let Some(tls) = self.tls_config.clone() {
269 let tls_acceptor = TlsAcceptor::from(tls);
270 let allow_insecure = self.config.allow_insecure;
271 self.pending_connections.spawn(async move {
272 if allow_insecure {
273 if let Some(tcp) =
276 <dyn std::any::Any>::downcast_ref::<tokio::net::TcpStream>(&io)
277 {
278 let mut buf = [0; 1];
280 tcp.peek(&mut buf).await?;
283 if buf != [0x16] {
286 tracing::trace!("accepting insecure connection");
287 return Ok((ServerIo::new_io(io), remote_addr));
288 }
289 } else {
290 tracing::warn!("'allow_insecure' is configured but io type is not 'tokio::net::TcpStream'");
291 }
292 }
293
294 tracing::trace!("accepting TLS connection");
295 let io = tls_acceptor.accept(io).await?;
296 Ok((ServerIo::new_tls_io(io), remote_addr))
297 });
298 } else {
299 self.handle_connection(ServerIo::new_io(io), remote_addr);
300 }
301 }
302
303 fn handle_connection(&mut self, io: ServerIo<L::Io>, remote_addr: L::Addr) {
304 let connection_shutdown_token = self.graceful_shutdown_token.child_token();
305 let connection_info = ConnectionInfo::new(
306 remote_addr,
307 io.peer_certs(),
308 connection_shutdown_token.clone(),
309 );
310 let connection_id = connection_info.id();
311 let connect_info = connection_info::ConnectInfo {
312 local_addr: self.local_addr.clone(),
313 remote_addr: connection_info.remote_address().clone(),
314 };
315 let peer_certificates = connection_info.peer_certificates().cloned();
316 let hyper_io = hyper_util::rt::TokioIo::new(io);
317
318 let hyper_svc = TowerToHyperService::new(self.service.clone().map_request(
319 move |mut request: Request<hyper::body::Incoming>| {
320 request.extensions_mut().insert(connect_info.clone());
321 if let Some(peer_certificates) = peer_certificates.clone() {
322 request.extensions_mut().insert(peer_certificates);
323 }
324
325 request.map(body::boxed)
326 },
327 ));
328
329 self.connections
330 .write()
331 .unwrap()
332 .insert(connection_id, connection_info);
333 let on_connection_close = OnConnectionClose::new(connection_id, self.connections.clone());
334
335 self.connection_handlers
336 .spawn(connection_handler::serve_connection(
337 hyper_io,
338 hyper_svc,
339 self.config.connection_builder(),
340 connection_shutdown_token,
341 self.config.max_connection_age,
342 on_connection_close,
343 ));
344 }
345
346 async fn shutdown(mut self) {
347 const CONNECTION_SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(1);
350
351 self.graceful_shutdown_token.cancel();
353
354 self.pending_connections.shutdown().await;
356
357 trace!(
359 "waiting for {} connections to close",
360 self.connection_handlers.len()
361 );
362
363 let graceful_shutdown =
364 async { while self.connection_handlers.join_next().await.is_some() {} };
365
366 if tokio::time::timeout(CONNECTION_SHUTDOWN_GRACE_PERIOD, graceful_shutdown)
367 .await
368 .is_err()
369 {
370 tracing::warn!(
371 "Failed to stop all connection handlers in {:?}. Forcing shutdown.",
372 CONNECTION_SHUTDOWN_GRACE_PERIOD
373 );
374 self.connection_handlers.shutdown().await;
375 }
376 }
377}
378
#[cfg(test)]
mod tests {
    use axum::Router;

    use super::*;

    /// Body returned by the test router.
    const MESSAGE: &str = "Hello, World!";

    /// Router answering `GET /` with [`MESSAGE`].
    fn hello_app() -> Router {
        Router::new().route("/", axum::routing::get(|| async { MESSAGE }))
    }

    #[tokio::test]
    async fn simple() {
        let handle = Builder::new().serve(("localhost", 0), hello_app()).unwrap();
        let url = format!("http://{}", handle.local_addr());

        let body = reqwest::get(url).await.unwrap().bytes().await.unwrap();

        assert_eq!(body, MESSAGE.as_bytes());
    }

    #[tokio::test]
    async fn shutdown() {
        let handle = Builder::new().serve(("localhost", 0), hello_app()).unwrap();
        let url = format!("http://{}", handle.local_addr());

        let body = reqwest::get(url).await.unwrap().bytes().await.unwrap();

        // Checked before any further `.await`: on the current-thread test runtime the server
        // task cannot have processed the client's disconnect yet.
        assert_eq!(handle.connections().len(), 1);
        assert_eq!(body, MESSAGE.as_bytes());
        assert!(!handle.is_shutdown());

        handle.shutdown().await;

        assert!(handle.is_shutdown());
        assert_eq!(handle.connections().len(), 0);
    }
}