iota_http/
io.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2025 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    io,
7    io::IoSlice,
8    pin::Pin,
9    task::{Context, Poll},
10};
11
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio_rustls::server::TlsStream;
14
15pub(crate) enum ServerIo<IO> {
16    Io(IO),
17    TlsIo(Box<TlsStream<IO>>),
18}
19
20impl<IO> ServerIo<IO> {
21    pub(crate) fn new_io(io: IO) -> Self {
22        Self::Io(io)
23    }
24
25    pub(crate) fn new_tls_io(io: TlsStream<IO>) -> Self {
26        Self::TlsIo(Box::new(io))
27    }
28
29    pub(crate) fn peer_certs(
30        &self,
31    ) -> Option<std::sync::Arc<Vec<tokio_rustls::rustls::pki_types::CertificateDer<'static>>>> {
32        match self {
33            Self::Io(_) => None,
34            Self::TlsIo(io) => {
35                let (_inner, session) = io.get_ref();
36
37                session
38                    .peer_certificates()
39                    .map(|certs| certs.to_owned().into())
40            }
41        }
42    }
43}
44
45impl<IO> AsyncRead for ServerIo<IO>
46where
47    IO: AsyncWrite + AsyncRead + Unpin,
48{
49    fn poll_read(
50        mut self: Pin<&mut Self>,
51        cx: &mut Context<'_>,
52        buf: &mut ReadBuf<'_>,
53    ) -> Poll<io::Result<()>> {
54        match &mut *self {
55            Self::Io(io) => Pin::new(io).poll_read(cx, buf),
56            Self::TlsIo(io) => Pin::new(io).poll_read(cx, buf),
57        }
58    }
59}
60
61impl<IO> AsyncWrite for ServerIo<IO>
62where
63    IO: AsyncWrite + AsyncRead + Unpin,
64{
65    fn poll_write(
66        mut self: Pin<&mut Self>,
67        cx: &mut Context<'_>,
68        buf: &[u8],
69    ) -> Poll<io::Result<usize>> {
70        match &mut *self {
71            Self::Io(io) => Pin::new(io).poll_write(cx, buf),
72            Self::TlsIo(io) => Pin::new(io).poll_write(cx, buf),
73        }
74    }
75
76    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
77        match &mut *self {
78            Self::Io(io) => Pin::new(io).poll_flush(cx),
79            Self::TlsIo(io) => Pin::new(io).poll_flush(cx),
80        }
81    }
82
83    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
84        match &mut *self {
85            Self::Io(io) => Pin::new(io).poll_shutdown(cx),
86            Self::TlsIo(io) => Pin::new(io).poll_shutdown(cx),
87        }
88    }
89
90    fn poll_write_vectored(
91        mut self: Pin<&mut Self>,
92        cx: &mut Context<'_>,
93        bufs: &[IoSlice<'_>],
94    ) -> Poll<Result<usize, io::Error>> {
95        match &mut *self {
96            Self::Io(io) => Pin::new(io).poll_write_vectored(cx, bufs),
97            Self::TlsIo(io) => Pin::new(io).poll_write_vectored(cx, bufs),
98        }
99    }
100
101    fn is_write_vectored(&self) -> bool {
102        match self {
103            Self::Io(io) => io.is_write_vectored(),
104            Self::TlsIo(io) => io.is_write_vectored(),
105        }
106    }
107}