1use std::{io, sync::Arc};
6
7use axum::{Extension, middleware::AddExtension};
8use axum_server::{
9 accept::Accept,
10 tls_rustls::{RustlsAcceptor, RustlsConfig},
11};
12use fastcrypto::ed25519::Ed25519PublicKey;
13use rustls::pki_types::CertificateDer;
14use tokio::io::{AsyncRead, AsyncWrite};
15use tokio_rustls::server::TlsStream;
16use tower_layer::Layer;
17
18#[derive(Debug, Clone)]
19pub struct TlsConnectionInfo {
20 sni_hostname: Option<Arc<str>>,
21 peer_certificates: Option<Arc<[CertificateDer<'static>]>>,
22 public_key: Option<Ed25519PublicKey>,
23}
24
25impl TlsConnectionInfo {
26 pub fn sni_hostname(&self) -> Option<&str> {
27 self.sni_hostname.as_deref()
28 }
29
30 pub fn peer_certificates(&self) -> Option<&[CertificateDer<'static>]> {
31 self.peer_certificates.as_deref()
32 }
33
34 pub fn public_key(&self) -> Option<&Ed25519PublicKey> {
35 self.public_key.as_ref()
36 }
37}
38
39#[derive(Debug, Clone)]
42pub struct TlsAcceptor {
43 inner: RustlsAcceptor,
44}
45
46impl TlsAcceptor {
47 pub fn new(config: rustls::ServerConfig) -> Self {
48 Self {
49 inner: RustlsAcceptor::new(RustlsConfig::from_config(Arc::new(config))),
50 }
51 }
52}
53
/// A pinned, heap-allocated, `Send` trait-object future that may borrow for
/// `'fut` and resolves to `Out`.
type BoxFuture<'fut, Out> =
    std::pin::Pin<Box<dyn std::future::Future<Output = Out> + Send + 'fut>>;
55
56impl<I, S> Accept<I, S> for TlsAcceptor
57where
58 I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
59 S: Send + 'static,
60{
61 type Stream = TlsStream<I>;
62 type Service = AddExtension<S, TlsConnectionInfo>;
63 type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>;
64
65 fn accept(&self, stream: I, service: S) -> Self::Future {
66 let acceptor = self.inner.clone();
67
68 Box::pin(async move {
69 let (stream, service) = acceptor.accept(stream, service).await?;
70 let server_conn = stream.get_ref().1;
71
72 let public_key = if let Some([peer_certificate, ..]) = server_conn.peer_certificates() {
73 crate::certgen::public_key_from_certificate(peer_certificate).ok()
74 } else {
75 None
76 };
77
78 let tls_connect_info = TlsConnectionInfo {
79 peer_certificates: server_conn.peer_certificates().map(From::from),
80 sni_hostname: server_conn.server_name().map(From::from),
81 public_key,
82 };
83 let service = Extension(tls_connect_info).layer(service);
84
85 Ok((stream, service))
86 })
87 }
88}