iota_http/
listener.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2025 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::time::Duration;
6
7/// Types that can listen for connections.
8pub trait Listener: Send + 'static {
9    /// The listener's IO type.
10    type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static;
11
12    /// The listener's address type.
13    // all these bounds are necessary to add this information in a request extension
14    type Addr: Clone + Send + Sync + 'static;
15
16    /// Accept a new incoming connection to this listener.
17    ///
18    /// If the underlying accept call can return an error, this function must
19    /// take care of logging and retrying.
20    fn accept(&mut self) -> impl std::future::Future<Output = (Self::Io, Self::Addr)> + Send;
21
22    /// Returns the local address that this listener is bound to.
23    fn local_addr(&self) -> std::io::Result<Self::Addr>;
24}
25
26/// Extensions to [`Listener`].
27pub trait ListenerExt: Listener + Sized {
28    /// Run a mutable closure on every accepted `Io`.
29    ///
30    /// # Example
31    ///
32    /// ```
33    /// use iota_http::ListenerExt;
34    /// use tracing::trace;
35    ///
36    /// # async {
37    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
38    ///     .await
39    ///     .unwrap()
40    ///     .tap_io(|tcp_stream| {
41    ///         if let Err(err) = tcp_stream.set_nodelay(true) {
42    ///             trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
43    ///         }
44    ///     });
45    /// # };
46    /// ```
47    fn tap_io<F>(self, tap_fn: F) -> TapIo<Self, F>
48    where
49        F: FnMut(&mut Self::Io) + Send + 'static,
50    {
51        TapIo {
52            listener: self,
53            tap_fn,
54        }
55    }
56}
57
58impl<L: Listener> ListenerExt for L {}
59
60impl Listener for tokio::net::TcpListener {
61    type Io = tokio::net::TcpStream;
62    type Addr = std::net::SocketAddr;
63
64    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
65        loop {
66            match Self::accept(self).await {
67                Ok(tup) => return tup,
68                Err(e) => handle_accept_error(e).await,
69            }
70        }
71    }
72
73    #[inline]
74    fn local_addr(&self) -> std::io::Result<Self::Addr> {
75        Self::local_addr(self)
76    }
77}
78
79#[derive(Debug)]
80pub struct TcpListenerWithOptions {
81    inner: tokio::net::TcpListener,
82    nodelay: bool,
83    keepalive: Option<Duration>,
84}
85
86impl TcpListenerWithOptions {
87    pub fn new<A: std::net::ToSocketAddrs>(
88        addr: A,
89        nodelay: bool,
90        keepalive: Option<Duration>,
91    ) -> Result<Self, crate::BoxError> {
92        let std_listener = std::net::TcpListener::bind(addr)?;
93        std_listener.set_nonblocking(true)?;
94        let listener = tokio::net::TcpListener::from_std(std_listener)?;
95
96        Ok(Self::from_listener(listener, nodelay, keepalive))
97    }
98
99    /// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`.
100    pub fn from_listener(
101        listener: tokio::net::TcpListener,
102        nodelay: bool,
103        keepalive: Option<Duration>,
104    ) -> Self {
105        Self {
106            inner: listener,
107            nodelay,
108            keepalive,
109        }
110    }
111
112    // Consistent with hyper-0.14, this function does not return an error.
113    fn set_accepted_socket_options(&self, stream: &tokio::net::TcpStream) {
114        if self.nodelay {
115            if let Err(e) = stream.set_nodelay(true) {
116                tracing::warn!("error trying to set TCP nodelay: {}", e);
117            }
118        }
119
120        if let Some(timeout) = self.keepalive {
121            let sock_ref = socket2::SockRef::from(&stream);
122            let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout);
123
124            if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) {
125                tracing::warn!("error trying to set TCP keepalive: {}", e);
126            }
127        }
128    }
129}
130
131impl Listener for TcpListenerWithOptions {
132    type Io = tokio::net::TcpStream;
133    type Addr = std::net::SocketAddr;
134
135    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
136        let (io, addr) = Listener::accept(&mut self.inner).await;
137        self.set_accepted_socket_options(&io);
138        (io, addr)
139    }
140
141    #[inline]
142    fn local_addr(&self) -> std::io::Result<Self::Addr> {
143        Listener::local_addr(&self.inner)
144    }
145}
146
147// Uncomment once we update tokio to >=1.41.0
148// #[cfg(unix)]
149// impl Listener for tokio::net::UnixListener {
150//     type Io = tokio::net::UnixStream;
151//     type Addr = std::os::unix::net::SocketAddr;
152
153//     async fn accept(&mut self) -> (Self::Io, Self::Addr) {
154//         loop {
155//             match Self::accept(self).await {
156//                 Ok((io, addr)) => return (io, addr.into()),
157//                 Err(e) => handle_accept_error(e).await,
158//             }
159//         }
160//     }
161
162//     #[inline]
163//     fn local_addr(&self) -> std::io::Result<Self::Addr> {
164//         Self::local_addr(self).map(Into::into)
165//     }
166// }
167
/// Listener adapter produced by [`ListenerExt::tap_io`].
///
/// It forwards to the wrapped listener and runs a callback on every accepted
/// `Io`; see [`ListenerExt::tap_io`] for details.
pub struct TapIo<L, F> {
    /// The wrapped listener.
    listener: L,
    /// Callback invoked with mutable access to each accepted `Io`.
    tap_fn: F,
}
175
176impl<L, F> std::fmt::Debug for TapIo<L, F>
177where
178    L: Listener + std::fmt::Debug,
179{
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        f.debug_struct("TapIo")
182            .field("listener", &self.listener)
183            .finish_non_exhaustive()
184    }
185}
186
187impl<L, F> Listener for TapIo<L, F>
188where
189    L: Listener,
190    F: FnMut(&mut L::Io) + Send + 'static,
191{
192    type Io = L::Io;
193    type Addr = L::Addr;
194
195    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
196        let (mut io, addr) = self.listener.accept().await;
197        (self.tap_fn)(&mut io);
198        (io, addr)
199    }
200
201    fn local_addr(&self) -> std::io::Result<Self::Addr> {
202        self.listener.local_addr()
203    }
204}
205
206async fn handle_accept_error(e: std::io::Error) {
207    if is_connection_error(&e) {
208        return;
209    }
210
211    // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
212    //
213    // > A possible scenario is that the process has hit the max open files
214    // > allowed, and so trying to accept a new connection will fail with
215    // > `EMFILE`. In some cases, it's preferable to just wait for some time, if
216    // > the application will likely close some files (or connections), and try
217    // > to accept the connection again. If this option is `true`, the error
218    // > will be logged at the `error` level, since it is still a big deal,
219    // > and then the listener will sleep for 1 second.
220    //
221    // hyper allowed customizing this but axum does not.
222    tracing::error!("accept error: {e}");
223    tokio::time::sleep(Duration::from_secs(1)).await;
224}
225
/// Returns `true` if `e` has one of the error kinds treated as a
/// per-connection (non-fatal for the listener) accept failure.
fn is_connection_error(e: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    // Kinds that only affect the connection being accepted, or are transient.
    const PER_CONNECTION_KINDS: [ErrorKind; 7] = [
        ErrorKind::ConnectionRefused,
        ErrorKind::ConnectionAborted,
        ErrorKind::ConnectionReset,
        ErrorKind::BrokenPipe,
        ErrorKind::Interrupted,
        ErrorKind::WouldBlock,
        ErrorKind::TimedOut,
    ];

    PER_CONNECTION_KINDS.contains(&e.kind())
}