1use std::time::Duration;
6
7pub trait Listener: Send + 'static {
9 type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static;
11
12 type Addr: Clone + Send + Sync + 'static;
15
16 fn accept(&mut self) -> impl std::future::Future<Output = (Self::Io, Self::Addr)> + Send;
21
22 fn local_addr(&self) -> std::io::Result<Self::Addr>;
24}
25
26pub trait ListenerExt: Listener + Sized {
28 fn tap_io<F>(self, tap_fn: F) -> TapIo<Self, F>
48 where
49 F: FnMut(&mut Self::Io) + Send + 'static,
50 {
51 TapIo {
52 listener: self,
53 tap_fn,
54 }
55 }
56}
57
58impl<L: Listener> ListenerExt for L {}
59
60impl Listener for tokio::net::TcpListener {
61 type Io = tokio::net::TcpStream;
62 type Addr = std::net::SocketAddr;
63
64 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
65 loop {
66 match Self::accept(self).await {
67 Ok(tup) => return tup,
68 Err(e) => handle_accept_error(e).await,
69 }
70 }
71 }
72
73 #[inline]
74 fn local_addr(&self) -> std::io::Result<Self::Addr> {
75 Self::local_addr(self)
76 }
77}
78
79#[derive(Debug)]
80pub struct TcpListenerWithOptions {
81 inner: tokio::net::TcpListener,
82 nodelay: bool,
83 keepalive: Option<Duration>,
84}
85
86impl TcpListenerWithOptions {
87 pub fn new<A: std::net::ToSocketAddrs>(
88 addr: A,
89 nodelay: bool,
90 keepalive: Option<Duration>,
91 ) -> Result<Self, crate::BoxError> {
92 let std_listener = std::net::TcpListener::bind(addr)?;
93 std_listener.set_nonblocking(true)?;
94 let listener = tokio::net::TcpListener::from_std(std_listener)?;
95
96 Ok(Self::from_listener(listener, nodelay, keepalive))
97 }
98
99 pub fn from_listener(
101 listener: tokio::net::TcpListener,
102 nodelay: bool,
103 keepalive: Option<Duration>,
104 ) -> Self {
105 Self {
106 inner: listener,
107 nodelay,
108 keepalive,
109 }
110 }
111
112 fn set_accepted_socket_options(&self, stream: &tokio::net::TcpStream) {
114 if self.nodelay {
115 if let Err(e) = stream.set_nodelay(true) {
116 tracing::warn!("error trying to set TCP nodelay: {}", e);
117 }
118 }
119
120 if let Some(timeout) = self.keepalive {
121 let sock_ref = socket2::SockRef::from(&stream);
122 let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout);
123
124 if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) {
125 tracing::warn!("error trying to set TCP keepalive: {}", e);
126 }
127 }
128 }
129}
130
131impl Listener for TcpListenerWithOptions {
132 type Io = tokio::net::TcpStream;
133 type Addr = std::net::SocketAddr;
134
135 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
136 let (io, addr) = Listener::accept(&mut self.inner).await;
137 self.set_accepted_socket_options(&io);
138 (io, addr)
139 }
140
141 #[inline]
142 fn local_addr(&self) -> std::io::Result<Self::Addr> {
143 Listener::local_addr(&self.inner)
144 }
145}
146
/// A listener adapter that runs a callback on every accepted I/O object.
///
/// Produced by `ListenerExt::tap_io`.
pub struct TapIo<L, F> {
    /// The wrapped listener that connections are accepted from.
    listener: L,
    /// Callback invoked with a mutable reference to each accepted I/O object.
    tap_fn: F,
}
175
176impl<L, F> std::fmt::Debug for TapIo<L, F>
177where
178 L: Listener + std::fmt::Debug,
179{
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 f.debug_struct("TapIo")
182 .field("listener", &self.listener)
183 .finish_non_exhaustive()
184 }
185}
186
187impl<L, F> Listener for TapIo<L, F>
188where
189 L: Listener,
190 F: FnMut(&mut L::Io) + Send + 'static,
191{
192 type Io = L::Io;
193 type Addr = L::Addr;
194
195 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
196 let (mut io, addr) = self.listener.accept().await;
197 (self.tap_fn)(&mut io);
198 (io, addr)
199 }
200
201 fn local_addr(&self) -> std::io::Result<Self::Addr> {
202 self.listener.local_addr()
203 }
204}
205
206async fn handle_accept_error(e: std::io::Error) {
207 if is_connection_error(&e) {
208 return;
209 }
210
211 tracing::error!("accept error: {e}");
223 tokio::time::sleep(Duration::from_secs(1)).await;
224}
225
/// Returns `true` when `e`'s kind is one of `ConnectionRefused`,
/// `ConnectionAborted`, `ConnectionReset`, `BrokenPipe`, `Interrupted`,
/// `WouldBlock` or `TimedOut`, and `false` for every other kind.
///
/// In this file, `handle_accept_error` uses this to decide which accept
/// errors to ignore without logging or backing off.
fn is_connection_error(e: &std::io::Error) -> bool {
    const IGNORED_KINDS: [std::io::ErrorKind; 7] = [
        std::io::ErrorKind::ConnectionRefused,
        std::io::ErrorKind::ConnectionAborted,
        std::io::ErrorKind::ConnectionReset,
        std::io::ErrorKind::BrokenPipe,
        std::io::ErrorKind::Interrupted,
        std::io::ErrorKind::WouldBlock,
        std::io::ErrorKind::TimedOut,
    ];

    IGNORED_KINDS.contains(&e.kind())
}