iota_tls/
acceptor.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{io, sync::Arc};
6
7use axum::{Extension, middleware::AddExtension};
8use axum_server::{
9    accept::Accept,
10    tls_rustls::{RustlsAcceptor, RustlsConfig},
11};
12use fastcrypto::ed25519::Ed25519PublicKey;
13use rustls::pki_types::CertificateDer;
14use tokio::io::{AsyncRead, AsyncWrite};
15use tokio_rustls::server::TlsStream;
16use tower_layer::Layer;
17
18#[derive(Debug, Clone)]
19pub struct TlsConnectionInfo {
20    sni_hostname: Option<Arc<str>>,
21    peer_certificates: Option<Arc<[CertificateDer<'static>]>>,
22    public_key: Option<Ed25519PublicKey>,
23}
24
25impl TlsConnectionInfo {
26    pub fn sni_hostname(&self) -> Option<&str> {
27        self.sni_hostname.as_deref()
28    }
29
30    pub fn peer_certificates(&self) -> Option<&[CertificateDer<'static>]> {
31        self.peer_certificates.as_deref()
32    }
33
34    pub fn public_key(&self) -> Option<&Ed25519PublicKey> {
35        self.public_key.as_ref()
36    }
37}
38
39/// An `Acceptor` that will provide `TlsConnectionInfo` as an axum `Extension`
40/// for use in handlers.
41#[derive(Debug, Clone)]
42pub struct TlsAcceptor {
43    inner: RustlsAcceptor,
44}
45
46impl TlsAcceptor {
47    pub fn new(config: rustls::ServerConfig) -> Self {
48        Self {
49            inner: RustlsAcceptor::new(RustlsConfig::from_config(Arc::new(config))),
50        }
51    }
52}
53
/// Heap-allocated, pinned, type-erased future that is `Send` and may borrow
/// data for `'a`; used as the acceptor's associated `Future` type.
type BoxFuture<'a, T> =
    std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
55
56impl<I, S> Accept<I, S> for TlsAcceptor
57where
58    I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
59    S: Send + 'static,
60{
61    type Stream = TlsStream<I>;
62    type Service = AddExtension<S, TlsConnectionInfo>;
63    type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>;
64
65    fn accept(&self, stream: I, service: S) -> Self::Future {
66        let acceptor = self.inner.clone();
67
68        Box::pin(async move {
69            let (stream, service) = acceptor.accept(stream, service).await?;
70            let server_conn = stream.get_ref().1;
71
72            let public_key = if let Some([peer_certificate, ..]) = server_conn.peer_certificates() {
73                crate::certgen::public_key_from_certificate(peer_certificate).ok()
74            } else {
75                None
76            };
77
78            let tls_connect_info = TlsConnectionInfo {
79                peer_certificates: server_conn.peer_certificates().map(From::from),
80                sni_hostname: server_conn.server_name().map(From::from),
81                public_key,
82            };
83            let service = Extension(tls_connect_info).layer(service);
84
85            Ok((stream, service))
86        })
87    }
88}