1use std::{
6 collections::HashMap,
7 fmt,
8 future::Future,
9 io,
10 net::{SocketAddr, ToSocketAddrs},
11 pin::Pin,
12 sync::{Arc, Mutex},
13 task::{self, Poll},
14 time::Instant,
15 vec,
16};
17
18use eyre::{Context, Result, eyre};
19use hyper_util::client::legacy::connect::{HttpConnector, dns::Name};
20use once_cell::sync::OnceCell;
21use tokio::task::JoinHandle;
22use tokio_rustls::rustls::ClientConfig;
23use tonic::transport::{Channel, Endpoint, Uri};
24use tower::Service;
25use tracing::{info, trace};
26
27use crate::{
28 config::Config,
29 multiaddr::{Multiaddr, Protocol, parse_dns, parse_ip4, parse_ip6},
30};
31
32pub async fn connect(address: &Multiaddr, tls_config: Option<ClientConfig>) -> Result<Channel> {
33 let channel = endpoint_from_multiaddr(address, tls_config)?
34 .connect()
35 .await?;
36 Ok(channel)
37}
38
39pub fn connect_lazy(address: &Multiaddr, tls_config: Option<ClientConfig>) -> Result<Channel> {
40 let channel = endpoint_from_multiaddr(address, tls_config)?.connect_lazy();
41 Ok(channel)
42}
43
44pub(crate) async fn connect_with_config(
45 address: &Multiaddr,
46 tls_config: Option<ClientConfig>,
47 config: &Config,
48) -> Result<Channel> {
49 let channel = endpoint_from_multiaddr(address, tls_config)?
50 .apply_config(config)
51 .connect()
52 .await?;
53 Ok(channel)
54}
55
56pub(crate) fn connect_lazy_with_config(
57 address: &Multiaddr,
58 tls_config: Option<ClientConfig>,
59 config: &Config,
60) -> Result<Channel> {
61 let channel = endpoint_from_multiaddr(address, tls_config)?
62 .apply_config(config)
63 .connect_lazy();
64 Ok(channel)
65}
66
67fn endpoint_from_multiaddr(
68 addr: &Multiaddr,
69 tls_config: Option<ClientConfig>,
70) -> Result<MyEndpoint> {
71 let mut iter = addr.iter();
72
73 let channel = match iter.next().ok_or_else(|| eyre!("address is empty"))? {
74 Protocol::Dns(_) => {
75 let (dns_name, tcp_port, http_or_https) = parse_dns(addr)?;
76 let uri = format!("{http_or_https}://{dns_name}:{tcp_port}");
77 MyEndpoint::try_from_uri(uri, tls_config)?
78 }
79 Protocol::Ip4(_) => {
80 let (socket_addr, http_or_https) = parse_ip4(addr)?;
81 let uri = format!("{http_or_https}://{socket_addr}");
82 MyEndpoint::try_from_uri(uri, tls_config)?
83 }
84 Protocol::Ip6(_) => {
85 let (socket_addr, http_or_https) = parse_ip6(addr)?;
86 let uri = format!("{http_or_https}://{socket_addr}");
87 MyEndpoint::try_from_uri(uri, tls_config)?
88 }
89 unsupported => return Err(eyre!("unsupported protocol {unsupported}")),
90 };
91
92 Ok(channel)
93}
94
/// A tonic [`Endpoint`] paired with the optional rustls configuration used when connecting.
struct MyEndpoint {
    /// Target endpoint; connection tunables are applied to it via `apply_config`.
    endpoint: Endpoint,
    /// When `Some`, channels are created through an HTTPS connector built from this config.
    tls_config: Option<ClientConfig>,
}
99
/// Process-wide switch for the caching DNS resolver used by lazily connected channels.
///
/// Initialised on first use in `MyEndpoint::connect_lazy`: `true` when the
/// `DISABLE_CACHING_RESOLVER` environment variable can be read, `false` otherwise.
static DISABLE_CACHING_RESOLVER: OnceCell<bool> = OnceCell::new();
101
102impl MyEndpoint {
103 fn new(endpoint: Endpoint, tls_config: Option<ClientConfig>) -> Self {
104 Self {
105 endpoint,
106 tls_config,
107 }
108 }
109
110 fn try_from_uri(uri: String, tls_config: Option<ClientConfig>) -> Result<Self> {
111 let uri: Uri = uri
112 .parse()
113 .with_context(|| format!("unable to create Uri from '{uri}'"))?;
114 let endpoint = Endpoint::from(uri);
115 Ok(Self::new(endpoint, tls_config))
116 }
117
118 fn apply_config(mut self, config: &Config) -> Self {
119 self.endpoint = apply_config_to_endpoint(config, self.endpoint);
120 self
121 }
122
123 fn connect_lazy(self) -> Channel {
124 let disable_caching_resolver = *DISABLE_CACHING_RESOLVER.get_or_init(|| {
125 let disable_caching_resolver = std::env::var("DISABLE_CACHING_RESOLVER").is_ok();
126 info!("DISABLE_CACHING_RESOLVER: {disable_caching_resolver}");
127 disable_caching_resolver
128 });
129
130 if disable_caching_resolver {
131 let mut http = HttpConnector::new();
132 http.enforce_http(false);
133 http.set_nodelay(true);
134 http.set_keepalive(None);
135 http.set_connect_timeout(None);
136
137 if let Some(tls_config) = self.tls_config {
138 Channel::new(
139 hyper_rustls::HttpsConnectorBuilder::new()
140 .with_tls_config(tls_config)
141 .https_only()
142 .enable_http2()
143 .wrap_connector(http),
144 self.endpoint,
145 )
146 } else {
147 self.endpoint.connect_with_connector_lazy(http)
148 }
149 } else {
150 let mut http = HttpConnector::new_with_resolver(CachingResolver::new());
151 http.enforce_http(false);
152 http.set_nodelay(true);
153 http.set_keepalive(None);
154 http.set_connect_timeout(None);
155
156 if let Some(tls_config) = self.tls_config {
157 let https = hyper_rustls::HttpsConnectorBuilder::new()
158 .with_tls_config(tls_config)
159 .https_only()
160 .enable_http2()
161 .wrap_connector(http);
162 Channel::new(https, self.endpoint)
163 } else {
164 self.endpoint.connect_with_connector_lazy(http)
165 }
166 }
167 }
168
169 async fn connect(self) -> Result<Channel> {
170 if let Some(tls_config) = self.tls_config {
171 let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
172 .with_tls_config(tls_config)
173 .https_only()
174 .enable_http2()
175 .build();
176 Channel::connect(https_connector, self.endpoint)
177 .await
178 .map_err(Into::into)
179 } else {
180 self.endpoint.connect().await.map_err(Into::into)
181 }
182 }
183}
184
185fn apply_config_to_endpoint(config: &Config, mut endpoint: Endpoint) -> Endpoint {
186 if let Some(limit) = config.concurrency_limit_per_connection {
187 endpoint = endpoint.concurrency_limit(limit);
188 }
189
190 if let Some(timeout) = config.request_timeout {
191 endpoint = endpoint.timeout(timeout);
192 }
193
194 if let Some(timeout) = config.connect_timeout {
195 endpoint = endpoint.connect_timeout(timeout);
196 }
197
198 if let Some(tcp_nodelay) = config.tcp_nodelay {
199 endpoint = endpoint.tcp_nodelay(tcp_nodelay);
200 }
201
202 if let Some(http2_keepalive_interval) = config.http2_keepalive_interval {
203 endpoint = endpoint.http2_keep_alive_interval(http2_keepalive_interval);
204 }
205
206 if let Some(http2_keepalive_timeout) = config.http2_keepalive_timeout {
207 endpoint = endpoint.keep_alive_timeout(http2_keepalive_timeout);
208 }
209
210 if let Some((limit, duration)) = config.rate_limit {
211 endpoint = endpoint.rate_limit(limit, duration);
212 }
213
214 endpoint
215 .initial_stream_window_size(config.http2_initial_stream_window_size)
216 .initial_connection_window_size(config.http2_initial_connection_window_size)
217 .tcp_keepalive(config.tcp_keepalive)
218}
219
/// A cached resolution: the instant it was stored and the addresses that were resolved.
type CacheEntry = (Instant, Vec<SocketAddr>);
221
/// DNS resolver service that remembers the addresses it has resolved for each host name.
///
/// Clones share a single cache because the map lives behind an [`Arc`].
#[derive(Clone)]
pub struct CachingResolver {
    /// Resolved addresses keyed by host name, each stamped with the time it was stored.
    cache: Arc<Mutex<HashMap<Name, CacheEntry>>>,
}
227
/// Owning iterator over resolved socket addresses; the response type of [`CachingResolver`].
type SocketAddrs = vec::IntoIter<SocketAddr>;
229
/// Future returned by [`CachingResolver`]'s `call`; completes with the resolved addresses.
pub struct CachingFuture {
    /// Handle to the blocking task that performs the cache lookup or DNS resolution.
    inner: JoinHandle<Result<SocketAddrs, io::Error>>,
}
233
234impl CachingResolver {
235 pub fn new() -> Self {
236 CachingResolver {
237 cache: Arc::new(Mutex::new(HashMap::new())),
238 }
239 }
240}
241
242impl Default for CachingResolver {
243 fn default() -> Self {
244 Self::new()
245 }
246}
247
248impl Service<Name> for CachingResolver {
249 type Response = SocketAddrs;
250 type Error = io::Error;
251 type Future = CachingFuture;
252
253 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
254 Poll::Ready(Ok(()))
255 }
256
257 fn call(&mut self, name: Name) -> Self::Future {
258 let blocking = {
259 let cache = self.cache.clone();
260 tokio::task::spawn_blocking(move || {
261 let entry = cache.lock().unwrap().get(&name).cloned();
262
263 if let Some((when, addrs)) = entry {
264 trace!("cached host={:?}", name.as_str());
265
266 if when.elapsed().as_secs() > 60 {
267 trace!("refreshing cache for host={:?}", name.as_str());
268 tokio::task::spawn_blocking(move || {
270 if let Ok(addrs) = (name.as_str(), 0).to_socket_addrs() {
271 let addrs: Vec<_> = addrs.collect();
272 trace!("updating cached host={:?}", name.as_str());
273 cache.lock().unwrap().insert(name, (Instant::now(), addrs));
274 }
275 });
276 }
277
278 Ok(addrs.into_iter())
279 } else {
280 trace!("resolving host={:?}", name.as_str());
281 match (name.as_str(), 0).to_socket_addrs() {
282 Ok(addrs) => {
283 let addrs: Vec<_> = addrs.collect();
284 cache
285 .lock()
286 .unwrap()
287 .insert(name, (Instant::now(), addrs.clone()));
288 Ok(addrs.into_iter())
289 }
290 res => res,
291 }
292 }
293 })
294 };
295
296 CachingFuture { inner: blocking }
297 }
298}
299
300impl fmt::Debug for CachingResolver {
301 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
302 f.pad("CachingResolver")
303 }
304}
305
306impl Future for CachingFuture {
307 type Output = Result<SocketAddrs, io::Error>;
308
309 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
310 Pin::new(&mut self.inner).poll(cx).map(|res| match res {
311 Ok(Ok(addrs)) => Ok(addrs),
312 Ok(Err(err)) => Err(err),
313 Err(join_err) => {
314 if join_err.is_cancelled() {
315 Err(io::Error::new(io::ErrorKind::Interrupted, join_err))
316 } else {
317 panic!("background task failed: {join_err:?}")
318 }
319 }
320 })
321 }
322}
323
324impl fmt::Debug for CachingFuture {
325 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326 f.pad("CachingFuture")
327 }
328}
329
impl Drop for CachingFuture {
    /// Requests cancellation of the lookup task when the future is dropped.
    fn drop(&mut self) {
        // NOTE(review): tokio documents that a `spawn_blocking` task cannot be aborted once it
        // has started running, so this presumably only helps for tasks still queued — confirm.
        self.inner.abort();
    }
}