iota_network_stack/
client.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::HashMap,
7    fmt,
8    future::Future,
9    io,
10    net::{SocketAddr, ToSocketAddrs},
11    pin::Pin,
12    sync::{Arc, Mutex},
13    task::{self, Poll},
14    time::Instant,
15    vec,
16};
17
18use eyre::{Context, Result, eyre};
19use hyper_util::client::legacy::connect::{HttpConnector, dns::Name};
20use once_cell::sync::OnceCell;
21use tokio::task::JoinHandle;
22use tokio_rustls::rustls::ClientConfig;
23use tonic::transport::{Channel, Endpoint, Uri};
24use tower::Service;
25use tracing::{info, trace};
26
27use crate::{
28    config::Config,
29    multiaddr::{Multiaddr, Protocol, parse_dns, parse_ip4, parse_ip6},
30};
31
32pub async fn connect(address: &Multiaddr, tls_config: Option<ClientConfig>) -> Result<Channel> {
33    let channel = endpoint_from_multiaddr(address, tls_config)?
34        .connect()
35        .await?;
36    Ok(channel)
37}
38
39pub fn connect_lazy(address: &Multiaddr, tls_config: Option<ClientConfig>) -> Result<Channel> {
40    let channel = endpoint_from_multiaddr(address, tls_config)?.connect_lazy();
41    Ok(channel)
42}
43
44pub(crate) async fn connect_with_config(
45    address: &Multiaddr,
46    tls_config: Option<ClientConfig>,
47    config: &Config,
48) -> Result<Channel> {
49    let channel = endpoint_from_multiaddr(address, tls_config)?
50        .apply_config(config)
51        .connect()
52        .await?;
53    Ok(channel)
54}
55
56pub(crate) fn connect_lazy_with_config(
57    address: &Multiaddr,
58    tls_config: Option<ClientConfig>,
59    config: &Config,
60) -> Result<Channel> {
61    let channel = endpoint_from_multiaddr(address, tls_config)?
62        .apply_config(config)
63        .connect_lazy();
64    Ok(channel)
65}
66
67fn endpoint_from_multiaddr(
68    addr: &Multiaddr,
69    tls_config: Option<ClientConfig>,
70) -> Result<MyEndpoint> {
71    let mut iter = addr.iter();
72
73    let channel = match iter.next().ok_or_else(|| eyre!("address is empty"))? {
74        Protocol::Dns(_) => {
75            let (dns_name, tcp_port, http_or_https) = parse_dns(addr)?;
76            let uri = format!("{http_or_https}://{dns_name}:{tcp_port}");
77            MyEndpoint::try_from_uri(uri, tls_config)?
78        }
79        Protocol::Ip4(_) => {
80            let (socket_addr, http_or_https) = parse_ip4(addr)?;
81            let uri = format!("{http_or_https}://{socket_addr}");
82            MyEndpoint::try_from_uri(uri, tls_config)?
83        }
84        Protocol::Ip6(_) => {
85            let (socket_addr, http_or_https) = parse_ip6(addr)?;
86            let uri = format!("{http_or_https}://{socket_addr}");
87            MyEndpoint::try_from_uri(uri, tls_config)?
88        }
89        unsupported => return Err(eyre!("unsupported protocol {unsupported}")),
90    };
91
92    Ok(channel)
93}
94
95struct MyEndpoint {
96    endpoint: Endpoint,
97    tls_config: Option<ClientConfig>,
98}
99
100static DISABLE_CACHING_RESOLVER: OnceCell<bool> = OnceCell::new();
101
102impl MyEndpoint {
103    fn new(endpoint: Endpoint, tls_config: Option<ClientConfig>) -> Self {
104        Self {
105            endpoint,
106            tls_config,
107        }
108    }
109
110    fn try_from_uri(uri: String, tls_config: Option<ClientConfig>) -> Result<Self> {
111        let uri: Uri = uri
112            .parse()
113            .with_context(|| format!("unable to create Uri from '{uri}'"))?;
114        let endpoint = Endpoint::from(uri);
115        Ok(Self::new(endpoint, tls_config))
116    }
117
118    fn apply_config(mut self, config: &Config) -> Self {
119        self.endpoint = apply_config_to_endpoint(config, self.endpoint);
120        self
121    }
122
123    fn connect_lazy(self) -> Channel {
124        let disable_caching_resolver = *DISABLE_CACHING_RESOLVER.get_or_init(|| {
125            let disable_caching_resolver = std::env::var("DISABLE_CACHING_RESOLVER").is_ok();
126            info!("DISABLE_CACHING_RESOLVER: {disable_caching_resolver}");
127            disable_caching_resolver
128        });
129
130        if disable_caching_resolver {
131            if let Some(tls_config) = self.tls_config {
132                self.endpoint.connect_with_connector_lazy(
133                    hyper_rustls::HttpsConnectorBuilder::new()
134                        .with_tls_config(tls_config)
135                        .https_only()
136                        .enable_http2()
137                        .build(),
138                )
139            } else {
140                self.endpoint.connect_lazy()
141            }
142        } else {
143            let mut http = HttpConnector::new_with_resolver(CachingResolver::new());
144            http.enforce_http(false);
145            http.set_nodelay(true);
146            http.set_keepalive(None);
147            http.set_connect_timeout(None);
148
149            if let Some(tls_config) = self.tls_config {
150                let https = hyper_rustls::HttpsConnectorBuilder::new()
151                    .with_tls_config(tls_config)
152                    .https_only()
153                    .enable_http1()
154                    .wrap_connector(http);
155                self.endpoint.connect_with_connector_lazy(https)
156            } else {
157                self.endpoint.connect_with_connector_lazy(http)
158            }
159        }
160    }
161
162    async fn connect(self) -> Result<Channel> {
163        if let Some(tls_config) = self.tls_config {
164            let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
165                .with_tls_config(tls_config)
166                .https_only()
167                .enable_http2()
168                .build();
169            self.endpoint
170                .connect_with_connector(https_connector)
171                .await
172                .map_err(Into::into)
173        } else {
174            self.endpoint.connect().await.map_err(Into::into)
175        }
176    }
177}
178
179fn apply_config_to_endpoint(config: &Config, mut endpoint: Endpoint) -> Endpoint {
180    if let Some(limit) = config.concurrency_limit_per_connection {
181        endpoint = endpoint.concurrency_limit(limit);
182    }
183
184    if let Some(timeout) = config.request_timeout {
185        endpoint = endpoint.timeout(timeout);
186    }
187
188    if let Some(timeout) = config.connect_timeout {
189        endpoint = endpoint.connect_timeout(timeout);
190    }
191
192    if let Some(tcp_nodelay) = config.tcp_nodelay {
193        endpoint = endpoint.tcp_nodelay(tcp_nodelay);
194    }
195
196    if let Some(http2_keepalive_interval) = config.http2_keepalive_interval {
197        endpoint = endpoint.http2_keep_alive_interval(http2_keepalive_interval);
198    }
199
200    if let Some(http2_keepalive_timeout) = config.http2_keepalive_timeout {
201        endpoint = endpoint.keep_alive_timeout(http2_keepalive_timeout);
202    }
203
204    if let Some((limit, duration)) = config.rate_limit {
205        endpoint = endpoint.rate_limit(limit, duration);
206    }
207
208    endpoint
209        .initial_stream_window_size(config.http2_initial_stream_window_size)
210        .initial_connection_window_size(config.http2_initial_connection_window_size)
211        .tcp_keepalive(config.tcp_keepalive)
212}
213
/// A cached lookup: the instant the addresses were stored and the addresses
/// themselves.
type CacheEntry = (Instant, Vec<SocketAddr>);
215
216/// A caching resolver based on hyper_util GaiResolver
217#[derive(Clone)]
218pub struct CachingResolver {
219    cache: Arc<Mutex<HashMap<Name, CacheEntry>>>,
220}
221
/// Owning iterator over resolved socket addresses; the response type of
/// [`CachingResolver`].
type SocketAddrs = vec::IntoIter<SocketAddr>;
223
224pub struct CachingFuture {
225    inner: JoinHandle<Result<SocketAddrs, io::Error>>,
226}
227
228impl CachingResolver {
229    pub fn new() -> Self {
230        CachingResolver {
231            cache: Arc::new(Mutex::new(HashMap::new())),
232        }
233    }
234}
235
236impl Default for CachingResolver {
237    fn default() -> Self {
238        Self::new()
239    }
240}
241
242impl Service<Name> for CachingResolver {
243    type Response = SocketAddrs;
244    type Error = io::Error;
245    type Future = CachingFuture;
246
247    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
248        Poll::Ready(Ok(()))
249    }
250
251    fn call(&mut self, name: Name) -> Self::Future {
252        let blocking = {
253            let cache = self.cache.clone();
254            tokio::task::spawn_blocking(move || {
255                let entry = cache.lock().unwrap().get(&name).cloned();
256
257                if let Some((when, addrs)) = entry {
258                    trace!("cached host={:?}", name.as_str());
259
260                    if when.elapsed().as_secs() > 60 {
261                        trace!("refreshing cache for host={:?}", name.as_str());
262                        // Start a new task to update the cache later.
263                        tokio::task::spawn_blocking(move || {
264                            if let Ok(addrs) = (name.as_str(), 0).to_socket_addrs() {
265                                let addrs: Vec<_> = addrs.collect();
266                                trace!("updating cached host={:?}", name.as_str());
267                                cache.lock().unwrap().insert(name, (Instant::now(), addrs));
268                            }
269                        });
270                    }
271
272                    Ok(addrs.into_iter())
273                } else {
274                    trace!("resolving host={:?}", name.as_str());
275                    match (name.as_str(), 0).to_socket_addrs() {
276                        Ok(addrs) => {
277                            let addrs: Vec<_> = addrs.collect();
278                            cache
279                                .lock()
280                                .unwrap()
281                                .insert(name, (Instant::now(), addrs.clone()));
282                            Ok(addrs.into_iter())
283                        }
284                        res => res,
285                    }
286                }
287            })
288        };
289
290        CachingFuture { inner: blocking }
291    }
292}
293
294impl fmt::Debug for CachingResolver {
295    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296        f.pad("CachingResolver")
297    }
298}
299
300impl Future for CachingFuture {
301    type Output = Result<SocketAddrs, io::Error>;
302
303    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
304        Pin::new(&mut self.inner).poll(cx).map(|res| match res {
305            Ok(Ok(addrs)) => Ok(addrs),
306            Ok(Err(err)) => Err(err),
307            Err(join_err) => {
308                if join_err.is_cancelled() {
309                    Err(io::Error::new(io::ErrorKind::Interrupted, join_err))
310                } else {
311                    panic!("background task failed: {join_err:?}")
312                }
313            }
314        })
315    }
316}
317
318impl fmt::Debug for CachingFuture {
319    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
320        f.pad("CachingFuture")
321    }
322}
323
324impl Drop for CachingFuture {
325    fn drop(&mut self) {
326        self.inner.abort();
327    }
328}