iota_network_stack/
client.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::HashMap,
7    fmt,
8    future::Future,
9    io,
10    net::{SocketAddr, ToSocketAddrs},
11    pin::Pin,
12    sync::{Arc, Mutex},
13    task::{self, Poll},
14    time::Instant,
15    vec,
16};
17
18use eyre::{Context, Result, eyre};
19use hyper_util::client::legacy::connect::{HttpConnector, dns::Name};
20use once_cell::sync::OnceCell;
21use tokio::task::JoinHandle;
22use tonic::transport::{Channel, Endpoint, Uri};
23use tower::Service;
24use tracing::{info, trace};
25
26use crate::{
27    config::Config,
28    multiaddr::{Multiaddr, Protocol, parse_dns, parse_ip4, parse_ip6},
29};
30
31pub async fn connect(address: &Multiaddr) -> Result<Channel> {
32    let channel = endpoint_from_multiaddr(address)?.connect().await?;
33    Ok(channel)
34}
35
36pub fn connect_lazy(address: &Multiaddr) -> Result<Channel> {
37    let channel = endpoint_from_multiaddr(address)?.connect_lazy();
38    Ok(channel)
39}
40
41pub(crate) async fn connect_with_config(address: &Multiaddr, config: &Config) -> Result<Channel> {
42    let channel = endpoint_from_multiaddr(address)?
43        .apply_config(config)
44        .connect()
45        .await?;
46    Ok(channel)
47}
48
49pub(crate) fn connect_lazy_with_config(address: &Multiaddr, config: &Config) -> Result<Channel> {
50    let channel = endpoint_from_multiaddr(address)?
51        .apply_config(config)
52        .connect_lazy();
53    Ok(channel)
54}
55
56fn endpoint_from_multiaddr(addr: &Multiaddr) -> Result<MyEndpoint> {
57    let mut iter = addr.iter();
58
59    let channel = match iter.next().ok_or_else(|| eyre!("address is empty"))? {
60        Protocol::Dns(_) => {
61            let (dns_name, tcp_port, http_or_https) = parse_dns(addr)?;
62            let uri = format!("{http_or_https}://{dns_name}:{tcp_port}");
63            MyEndpoint::try_from_uri(uri)?
64        }
65        Protocol::Ip4(_) => {
66            let (socket_addr, http_or_https) = parse_ip4(addr)?;
67            let uri = format!("{http_or_https}://{socket_addr}");
68            MyEndpoint::try_from_uri(uri)?
69        }
70        Protocol::Ip6(_) => {
71            let (socket_addr, http_or_https) = parse_ip6(addr)?;
72            let uri = format!("{http_or_https}://{socket_addr}");
73            MyEndpoint::try_from_uri(uri)?
74        }
75        unsupported => return Err(eyre!("unsupported protocol {unsupported}")),
76    };
77
78    Ok(channel)
79}
80
81struct MyEndpoint {
82    endpoint: Endpoint,
83}
84
85static DISABLE_CACHING_RESOLVER: OnceCell<bool> = OnceCell::new();
86
87impl MyEndpoint {
88    fn new(endpoint: Endpoint) -> Self {
89        Self { endpoint }
90    }
91
92    fn try_from_uri(uri: String) -> Result<Self> {
93        let uri: Uri = uri
94            .parse()
95            .with_context(|| format!("unable to create Uri from '{uri}'"))?;
96        let endpoint = Endpoint::from(uri);
97        Ok(Self::new(endpoint))
98    }
99
100    fn apply_config(mut self, config: &Config) -> Self {
101        self.endpoint = apply_config_to_endpoint(config, self.endpoint);
102        self
103    }
104
105    fn connect_lazy(self) -> Channel {
106        let disable_caching_resolver = *DISABLE_CACHING_RESOLVER.get_or_init(|| {
107            let disable_caching_resolver = std::env::var("DISABLE_CACHING_RESOLVER").is_ok();
108            info!("DISABLE_CACHING_RESOLVER: {disable_caching_resolver}");
109            disable_caching_resolver
110        });
111
112        if disable_caching_resolver {
113            self.endpoint.connect_lazy()
114        } else {
115            let mut http = HttpConnector::new_with_resolver(CachingResolver::new());
116            http.enforce_http(false);
117            http.set_nodelay(true);
118            http.set_keepalive(None);
119            http.set_connect_timeout(None);
120
121            self.endpoint.connect_with_connector_lazy(http)
122        }
123    }
124
125    async fn connect(self) -> Result<Channel> {
126        self.endpoint.connect().await.map_err(Into::into)
127    }
128}
129
130fn apply_config_to_endpoint(config: &Config, mut endpoint: Endpoint) -> Endpoint {
131    if let Some(limit) = config.concurrency_limit_per_connection {
132        endpoint = endpoint.concurrency_limit(limit);
133    }
134
135    if let Some(timeout) = config.request_timeout {
136        endpoint = endpoint.timeout(timeout);
137    }
138
139    if let Some(timeout) = config.connect_timeout {
140        endpoint = endpoint.connect_timeout(timeout);
141    }
142
143    if let Some(tcp_nodelay) = config.tcp_nodelay {
144        endpoint = endpoint.tcp_nodelay(tcp_nodelay);
145    }
146
147    if let Some(http2_keepalive_interval) = config.http2_keepalive_interval {
148        endpoint = endpoint.http2_keep_alive_interval(http2_keepalive_interval);
149    }
150
151    if let Some(http2_keepalive_timeout) = config.http2_keepalive_timeout {
152        endpoint = endpoint.keep_alive_timeout(http2_keepalive_timeout);
153    }
154
155    if let Some((limit, duration)) = config.rate_limit {
156        endpoint = endpoint.rate_limit(limit, duration);
157    }
158
159    endpoint
160        .initial_stream_window_size(config.http2_initial_stream_window_size)
161        .initial_connection_window_size(config.http2_initial_connection_window_size)
162        .tcp_keepalive(config.tcp_keepalive)
163}
164
165type CacheEntry = (Instant, Vec<SocketAddr>);
166
167/// A caching resolver based on hyper_util GaiResolver
168#[derive(Clone)]
169pub struct CachingResolver {
170    cache: Arc<Mutex<HashMap<Name, CacheEntry>>>,
171}
172
173type SocketAddrs = vec::IntoIter<SocketAddr>;
174
175pub struct CachingFuture {
176    inner: JoinHandle<Result<SocketAddrs, io::Error>>,
177}
178
179impl CachingResolver {
180    pub fn new() -> Self {
181        CachingResolver {
182            cache: Arc::new(Mutex::new(HashMap::new())),
183        }
184    }
185}
186
187impl Default for CachingResolver {
188    fn default() -> Self {
189        Self::new()
190    }
191}
192
193impl Service<Name> for CachingResolver {
194    type Response = SocketAddrs;
195    type Error = io::Error;
196    type Future = CachingFuture;
197
198    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
199        Poll::Ready(Ok(()))
200    }
201
202    fn call(&mut self, name: Name) -> Self::Future {
203        let blocking = {
204            let cache = self.cache.clone();
205            tokio::task::spawn_blocking(move || {
206                let entry = cache.lock().unwrap().get(&name).cloned();
207
208                if let Some((when, addrs)) = entry {
209                    trace!("cached host={:?}", name.as_str());
210
211                    if when.elapsed().as_secs() > 60 {
212                        trace!("refreshing cache for host={:?}", name.as_str());
213                        // Start a new task to update the cache later.
214                        tokio::task::spawn_blocking(move || {
215                            if let Ok(addrs) = (name.as_str(), 0).to_socket_addrs() {
216                                let addrs: Vec<_> = addrs.collect();
217                                trace!("updating cached host={:?}", name.as_str());
218                                cache
219                                    .lock()
220                                    .unwrap()
221                                    .insert(name, (Instant::now(), addrs.clone()));
222                            }
223                        });
224                    }
225
226                    Ok(addrs.into_iter())
227                } else {
228                    trace!("resolving host={:?}", name.as_str());
229                    match (name.as_str(), 0).to_socket_addrs() {
230                        Ok(addrs) => {
231                            let addrs: Vec<_> = addrs.collect();
232                            cache
233                                .lock()
234                                .unwrap()
235                                .insert(name, (Instant::now(), addrs.clone()));
236                            Ok(addrs.into_iter())
237                        }
238                        res => res,
239                    }
240                }
241            })
242        };
243
244        CachingFuture { inner: blocking }
245    }
246}
247
248impl fmt::Debug for CachingResolver {
249    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250        f.pad("CachingResolver")
251    }
252}
253
254impl Future for CachingFuture {
255    type Output = Result<SocketAddrs, io::Error>;
256
257    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
258        Pin::new(&mut self.inner).poll(cx).map(|res| match res {
259            Ok(Ok(addrs)) => Ok(addrs),
260            Ok(Err(err)) => Err(err),
261            Err(join_err) => {
262                if join_err.is_cancelled() {
263                    Err(io::Error::new(io::ErrorKind::Interrupted, join_err))
264                } else {
265                    panic!("background task failed: {:?}", join_err)
266                }
267            }
268        })
269    }
270}
271
272impl fmt::Debug for CachingFuture {
273    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274        f.pad("CachingFuture")
275    }
276}
277
278impl Drop for CachingFuture {
279    fn drop(&mut self) {
280        self.inner.abort();
281    }
282}