1use std::{
6 io,
7 io::IoSlice,
8 pin::Pin,
9 task::{Context, Poll},
10};
11
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio_rustls::server::TlsStream;
14
/// Server-side connection transport: either a plain stream or a server TLS
/// stream layered on top of the same underlying `IO` type.
///
/// The `AsyncRead`/`AsyncWrite` impls for this type forward every call to
/// whichever variant is active.
pub(crate) enum ServerIo<IO> {
    /// Unencrypted transport, used as-is.
    Io(IO),
    /// Server-side TLS stream wrapping `IO`. It is kept behind a `Box`
    /// (presumably to keep the enum small; TODO confirm the intent).
    TlsIo(Box<TlsStream<IO>>),
}
19
20impl<IO> ServerIo<IO> {
21 pub(crate) fn new_io(io: IO) -> Self {
22 Self::Io(io)
23 }
24
25 pub(crate) fn new_tls_io(io: TlsStream<IO>) -> Self {
26 Self::TlsIo(Box::new(io))
27 }
28
29 pub(crate) fn peer_certs(
30 &self,
31 ) -> Option<std::sync::Arc<Vec<tokio_rustls::rustls::pki_types::CertificateDer<'static>>>> {
32 match self {
33 Self::Io(_) => None,
34 Self::TlsIo(io) => {
35 let (_inner, session) = io.get_ref();
36
37 session
38 .peer_certificates()
39 .map(|certs| certs.to_owned().into())
40 }
41 }
42 }
43}
44
45impl<IO> AsyncRead for ServerIo<IO>
46where
47 IO: AsyncWrite + AsyncRead + Unpin,
48{
49 fn poll_read(
50 mut self: Pin<&mut Self>,
51 cx: &mut Context<'_>,
52 buf: &mut ReadBuf<'_>,
53 ) -> Poll<io::Result<()>> {
54 match &mut *self {
55 Self::Io(io) => Pin::new(io).poll_read(cx, buf),
56 Self::TlsIo(io) => Pin::new(io).poll_read(cx, buf),
57 }
58 }
59}
60
61impl<IO> AsyncWrite for ServerIo<IO>
62where
63 IO: AsyncWrite + AsyncRead + Unpin,
64{
65 fn poll_write(
66 mut self: Pin<&mut Self>,
67 cx: &mut Context<'_>,
68 buf: &[u8],
69 ) -> Poll<io::Result<usize>> {
70 match &mut *self {
71 Self::Io(io) => Pin::new(io).poll_write(cx, buf),
72 Self::TlsIo(io) => Pin::new(io).poll_write(cx, buf),
73 }
74 }
75
76 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
77 match &mut *self {
78 Self::Io(io) => Pin::new(io).poll_flush(cx),
79 Self::TlsIo(io) => Pin::new(io).poll_flush(cx),
80 }
81 }
82
83 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
84 match &mut *self {
85 Self::Io(io) => Pin::new(io).poll_shutdown(cx),
86 Self::TlsIo(io) => Pin::new(io).poll_shutdown(cx),
87 }
88 }
89
90 fn poll_write_vectored(
91 mut self: Pin<&mut Self>,
92 cx: &mut Context<'_>,
93 bufs: &[IoSlice<'_>],
94 ) -> Poll<Result<usize, io::Error>> {
95 match &mut *self {
96 Self::Io(io) => Pin::new(io).poll_write_vectored(cx, bufs),
97 Self::TlsIo(io) => Pin::new(io).poll_write_vectored(cx, bufs),
98 }
99 }
100
101 fn is_write_vectored(&self) -> bool {
102 match self {
103 Self::Io(io) => io.is_write_vectored(),
104 Self::TlsIo(io) => io.is_write_vectored(),
105 }
106 }
107}