1use std::{collections::HashMap, sync::Arc, time::Duration};
6
7use connection_handler::OnConnectionClose;
8pub use http;
9use http::{Request, Response};
10use hyper_util::service::TowerToHyperService;
11use io::ServerIo;
12use tokio::task::JoinSet;
13use tokio_rustls::TlsAcceptor;
14use tower::{Service, ServiceBuilder, ServiceExt};
15use tracing::trace;
16
17use self::{body::BoxBody, connection_info::ActiveConnections};
18
19pub mod body;
20mod config;
21mod connection_handler;
22mod connection_info;
23mod fuse;
24mod io;
25mod listener;
26
27pub use config::Config;
28pub use connection_info::{ConnectInfo, ConnectionId, ConnectionInfo, PeerCertificates};
29pub use listener::{Listener, ListenerExt};
30
/// Type-erased, thread-safe error used for service, body and I/O failures in this crate.
pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// ALPN protocol identifier advertised for HTTP/2 over TLS.
const ALPN_H2: &[u8] = b"h2";

/// ALPN protocol identifier advertised for HTTP/1.1 over TLS.
const ALPN_H1: &[u8] = b"http/1.1";
36
37#[derive(Default)]
38pub struct Builder {
39 config: Config,
40 tls_config: Option<tokio_rustls::rustls::ServerConfig>,
41}
42
43impl Builder {
44 pub fn new() -> Self {
45 Self::default()
46 }
47
48 pub fn config(mut self, config: Config) -> Self {
49 self.config = config;
50 self
51 }
52
53 pub fn tls_config(mut self, tls_config: tokio_rustls::rustls::ServerConfig) -> Self {
54 self.tls_config = Some(tls_config);
55 self
56 }
57
58 pub fn serve<A, S, ResponseBody>(
59 self,
60 addr: A,
61 service: S,
62 ) -> Result<ServerHandle<std::net::SocketAddr>, BoxError>
63 where
64 A: std::net::ToSocketAddrs,
65 S: Service<
66 Request<BoxBody>,
67 Response = Response<ResponseBody>,
68 Error: Into<BoxError>,
69 Future: Send,
70 > + Clone
71 + Send
72 + 'static,
73 ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
74 {
75 let listener = listener::TcpListenerWithOptions::new(
76 addr,
77 self.config.tcp_nodelay,
78 self.config.tcp_keepalive,
79 )?;
80
81 Self::serve_with_listener(self, listener, service)
82 }
83
84 fn serve_with_listener<L, S, ResponseBody>(
85 self,
86 listener: L,
87 service: S,
88 ) -> Result<ServerHandle<L::Addr>, BoxError>
89 where
90 L: Listener,
91 S: Service<
92 Request<BoxBody>,
93 Response = Response<ResponseBody>,
94 Error: Into<BoxError>,
95 Future: Send,
96 > + Clone
97 + Send
98 + 'static,
99 ResponseBody: http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>> + Send + 'static,
100 {
101 let local_addr = listener.local_addr()?;
102 let graceful_shutdown_token = tokio_util::sync::CancellationToken::new();
103 let connections = ActiveConnections::default();
104
105 let tls_config = self.tls_config.map(|mut tls| {
106 tls.alpn_protocols.push(ALPN_H2.into());
107 if self.config.accept_http1 {
108 tls.alpn_protocols.push(ALPN_H1.into());
109 }
110 Arc::new(tls)
111 });
112
113 let (watch_sender, watch_receiver) = tokio::sync::watch::channel(());
114 let server = Server {
115 config: self.config,
116 tls_config,
117 listener,
118 local_addr: local_addr.clone(),
119 service: ServiceBuilder::new()
120 .layer(tower::util::BoxCloneService::layer())
121 .map_response(|response: Response<ResponseBody>| response.map(body::boxed))
122 .map_err(Into::into)
123 .service(service),
124 pending_connections: JoinSet::new(),
125 connection_handlers: JoinSet::new(),
126 connections: connections.clone(),
127 graceful_shutdown_token: graceful_shutdown_token.clone(),
128 _watch_receiver: watch_receiver,
129 };
130
131 let handle = ServerHandle(Arc::new(HandleInner {
132 local_addr,
133 connections,
134 graceful_shutdown_token,
135 watch_sender,
136 }));
137
138 tokio::spawn(server.serve());
139
140 Ok(handle)
141 }
142}
143
144#[derive(Debug)]
145pub struct ServerHandle<A = std::net::SocketAddr>(Arc<HandleInner<A>>);
146
147#[derive(Debug)]
148struct HandleInner<A = std::net::SocketAddr> {
149 local_addr: A,
151 connections: ActiveConnections<A>,
152 graceful_shutdown_token: tokio_util::sync::CancellationToken,
153 watch_sender: tokio::sync::watch::Sender<()>,
154}
155
156impl<A> ServerHandle<A> {
157 pub fn local_addr(&self) -> &A {
159 &self.0.local_addr
160 }
161
162 pub fn trigger_shutdown(&self) {
165 self.0.graceful_shutdown_token.cancel();
166 }
167
168 pub async fn wait_for_shutdown(&self) {
174 self.0.watch_sender.closed().await
175 }
176
177 pub async fn shutdown(&self) {
180 self.trigger_shutdown();
181 self.wait_for_shutdown().await;
182 }
183
184 pub fn is_shutdown(&self) -> bool {
186 self.0.watch_sender.is_closed()
187 }
188
189 pub fn connections(
190 &self,
191 ) -> std::sync::RwLockReadGuard<'_, HashMap<ConnectionId, ConnectionInfo<A>>> {
192 self.0.connections.read().unwrap()
193 }
194
195 pub fn number_of_connections(&self) -> usize {
197 self.connections().len()
198 }
199}
200
201type ConnectingOutput<Io, Addr> = Result<(ServerIo<Io>, Addr), crate::BoxError>;
202
203struct Server<L: Listener> {
204 config: Config,
205 tls_config: Option<Arc<tokio_rustls::rustls::ServerConfig>>,
206
207 listener: L,
208 local_addr: L::Addr,
209 service: tower::util::BoxCloneService<Request<BoxBody>, Response<BoxBody>, crate::BoxError>,
210
211 pending_connections: JoinSet<ConnectingOutput<L::Io, L::Addr>>,
212 connection_handlers: JoinSet<()>,
213 connections: ActiveConnections<L::Addr>,
214 graceful_shutdown_token: tokio_util::sync::CancellationToken,
215 _watch_receiver: tokio::sync::watch::Receiver<()>,
217}
218
219impl<L> Server<L>
220where
221 L: Listener,
222{
223 async fn serve(mut self) -> Result<(), BoxError> {
224 loop {
225 tokio::select! {
226 _ = self.graceful_shutdown_token.cancelled() => {
227 trace!("signal received, shutting down");
228 break;
229 },
230 (io, remote_addr) = self.listener.accept() => {
231 self.handle_incoming(io, remote_addr);
232 },
233 Some(maybe_connection) = self.pending_connections.join_next() => {
234 let (io, remote_addr) = match maybe_connection.unwrap() {
236 Ok((io, remote_addr)) => {
237 (io, remote_addr)
238 }
239 Err(e) => {
240 tracing::debug!(error = %e, "error accepting connection");
241 continue;
242 }
243 };
244
245 trace!("connection accepted");
246 self.handle_connection(io, remote_addr);
247 },
248 Some(connection_handler_output) = self.connection_handlers.join_next() => {
249 let _: () = connection_handler_output.unwrap();
251 },
252 }
253 }
254
255 self.shutdown().await;
257
258 Ok(())
259 }
260
261 fn handle_incoming(&mut self, io: L::Io, remote_addr: L::Addr) {
262 if let Some(tls) = self.tls_config.clone() {
263 let tls_acceptor = TlsAcceptor::from(tls);
264 let allow_insecure = self.config.allow_insecure;
265 self.pending_connections.spawn(async move {
266 if allow_insecure {
267 if let Some(tcp) =
270 <dyn std::any::Any>::downcast_ref::<tokio::net::TcpStream>(&io)
271 {
272 let mut buf = [0; 1];
274 tcp.peek(&mut buf).await?;
277 if buf != [0x16] {
280 tracing::trace!("accepting insecure connection");
281 return Ok((ServerIo::new_io(io), remote_addr));
282 }
283 } else {
284 tracing::warn!("'allow_insecure' is configured but io type is not 'tokio::net::TcpStream'");
285 }
286 }
287
288 tracing::trace!("accepting TLS connection");
289 let io = tls_acceptor.accept(io).await?;
290 Ok((ServerIo::new_tls_io(io), remote_addr))
291 });
292 } else {
293 self.handle_connection(ServerIo::new_io(io), remote_addr);
294 }
295 }
296
297 fn handle_connection(&mut self, io: ServerIo<L::Io>, remote_addr: L::Addr) {
298 let connection_shutdown_token = self.graceful_shutdown_token.child_token();
299 let connection_info = ConnectionInfo::new(
300 remote_addr,
301 io.peer_certs(),
302 connection_shutdown_token.clone(),
303 );
304 let connection_id = connection_info.id();
305 let connect_info = connection_info::ConnectInfo {
306 local_addr: self.local_addr.clone(),
307 remote_addr: connection_info.remote_address().clone(),
308 };
309 let peer_certificates = connection_info.peer_certificates().cloned();
310 let hyper_io = hyper_util::rt::TokioIo::new(io);
311
312 let hyper_svc = TowerToHyperService::new(self.service.clone().map_request(
313 move |mut request: Request<hyper::body::Incoming>| {
314 request.extensions_mut().insert(connect_info.clone());
315 if let Some(peer_certificates) = peer_certificates.clone() {
316 request.extensions_mut().insert(peer_certificates);
317 }
318
319 request.map(body::boxed)
320 },
321 ));
322
323 self.connections
324 .write()
325 .unwrap()
326 .insert(connection_id, connection_info);
327 let on_connection_close = OnConnectionClose::new(connection_id, self.connections.clone());
328
329 self.connection_handlers
330 .spawn(connection_handler::serve_connection(
331 hyper_io,
332 hyper_svc,
333 self.config.connection_builder(),
334 connection_shutdown_token,
335 self.config.max_connection_age,
336 on_connection_close,
337 ));
338 }
339
340 async fn shutdown(mut self) {
341 const CONNECTION_SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(1);
344
345 self.graceful_shutdown_token.cancel();
347
348 self.pending_connections.shutdown().await;
350
351 trace!(
353 "waiting for {} connections to close",
354 self.connection_handlers.len()
355 );
356
357 let graceful_shutdown =
358 async { while self.connection_handlers.join_next().await.is_some() {} };
359
360 if tokio::time::timeout(CONNECTION_SHUTDOWN_GRACE_PERIOD, graceful_shutdown)
361 .await
362 .is_err()
363 {
364 tracing::warn!(
365 "Failed to stop all connection handlers in {:?}. Forcing shutdown.",
366 CONNECTION_SHUTDOWN_GRACE_PERIOD
367 );
368 self.connection_handlers.shutdown().await;
369 }
370 }
371}
372
#[cfg(test)]
mod tests {
    use axum::Router;

    use super::*;

    /// Body returned by the test app for `GET /`.
    const MESSAGE: &str = "Hello, World!";

    /// Starts a server on an ephemeral localhost port that answers `GET /` with [`MESSAGE`].
    fn spawn_hello_server() -> ServerHandle {
        let app = Router::new().route("/", axum::routing::get(|| async { MESSAGE }));
        Builder::new().serve(("localhost", 0), app).unwrap()
    }

    /// Issues `GET /` against the server behind `handle` and returns the response body.
    async fn fetch_root(handle: &ServerHandle) -> Vec<u8> {
        let url = format!("http://{}", handle.local_addr());
        reqwest::get(url).await.unwrap().bytes().await.unwrap().to_vec()
    }

    #[tokio::test]
    async fn simple() {
        let handle = spawn_hello_server();

        let body = fetch_root(&handle).await;

        assert_eq!(body, MESSAGE.as_bytes());
    }

    #[tokio::test]
    async fn shutdown() {
        let handle = spawn_hello_server();

        let body = fetch_root(&handle).await;

        assert_eq!(handle.connections().len(), 1);
        assert_eq!(body, MESSAGE.as_bytes());
        assert!(!handle.is_shutdown());

        handle.shutdown().await;

        assert!(handle.is_shutdown());
        assert_eq!(handle.connections().len(), 0);
    }
}