iota_network_stack/
client.rs1use std::{
6 collections::HashMap,
7 fmt,
8 future::Future,
9 io,
10 net::{SocketAddr, ToSocketAddrs},
11 pin::Pin,
12 sync::{Arc, Mutex},
13 task::{self, Poll},
14 time::Instant,
15 vec,
16};
17
18use eyre::{Context, Result, eyre};
19use hyper_util::client::legacy::connect::{HttpConnector, dns::Name};
20use once_cell::sync::OnceCell;
21use tokio::task::JoinHandle;
22use tonic::transport::{Channel, Endpoint, Uri};
23use tower::Service;
24use tracing::{info, trace};
25
26use crate::{
27 config::Config,
28 multiaddr::{Multiaddr, Protocol, parse_dns, parse_ip4, parse_ip6},
29};
30
31pub async fn connect(address: &Multiaddr) -> Result<Channel> {
32 let channel = endpoint_from_multiaddr(address)?.connect().await?;
33 Ok(channel)
34}
35
36pub fn connect_lazy(address: &Multiaddr) -> Result<Channel> {
37 let channel = endpoint_from_multiaddr(address)?.connect_lazy();
38 Ok(channel)
39}
40
41pub(crate) async fn connect_with_config(address: &Multiaddr, config: &Config) -> Result<Channel> {
42 let channel = endpoint_from_multiaddr(address)?
43 .apply_config(config)
44 .connect()
45 .await?;
46 Ok(channel)
47}
48
49pub(crate) fn connect_lazy_with_config(address: &Multiaddr, config: &Config) -> Result<Channel> {
50 let channel = endpoint_from_multiaddr(address)?
51 .apply_config(config)
52 .connect_lazy();
53 Ok(channel)
54}
55
56fn endpoint_from_multiaddr(addr: &Multiaddr) -> Result<MyEndpoint> {
57 let mut iter = addr.iter();
58
59 let channel = match iter.next().ok_or_else(|| eyre!("address is empty"))? {
60 Protocol::Dns(_) => {
61 let (dns_name, tcp_port, http_or_https) = parse_dns(addr)?;
62 let uri = format!("{http_or_https}://{dns_name}:{tcp_port}");
63 MyEndpoint::try_from_uri(uri)?
64 }
65 Protocol::Ip4(_) => {
66 let (socket_addr, http_or_https) = parse_ip4(addr)?;
67 let uri = format!("{http_or_https}://{socket_addr}");
68 MyEndpoint::try_from_uri(uri)?
69 }
70 Protocol::Ip6(_) => {
71 let (socket_addr, http_or_https) = parse_ip6(addr)?;
72 let uri = format!("{http_or_https}://{socket_addr}");
73 MyEndpoint::try_from_uri(uri)?
74 }
75 unsupported => return Err(eyre!("unsupported protocol {unsupported}")),
76 };
77
78 Ok(channel)
79}
80
81struct MyEndpoint {
82 endpoint: Endpoint,
83}
84
85static DISABLE_CACHING_RESOLVER: OnceCell<bool> = OnceCell::new();
86
87impl MyEndpoint {
88 fn new(endpoint: Endpoint) -> Self {
89 Self { endpoint }
90 }
91
92 fn try_from_uri(uri: String) -> Result<Self> {
93 let uri: Uri = uri
94 .parse()
95 .with_context(|| format!("unable to create Uri from '{uri}'"))?;
96 let endpoint = Endpoint::from(uri);
97 Ok(Self::new(endpoint))
98 }
99
100 fn apply_config(mut self, config: &Config) -> Self {
101 self.endpoint = apply_config_to_endpoint(config, self.endpoint);
102 self
103 }
104
105 fn connect_lazy(self) -> Channel {
106 let disable_caching_resolver = *DISABLE_CACHING_RESOLVER.get_or_init(|| {
107 let disable_caching_resolver = std::env::var("DISABLE_CACHING_RESOLVER").is_ok();
108 info!("DISABLE_CACHING_RESOLVER: {disable_caching_resolver}");
109 disable_caching_resolver
110 });
111
112 if disable_caching_resolver {
113 self.endpoint.connect_lazy()
114 } else {
115 let mut http = HttpConnector::new_with_resolver(CachingResolver::new());
116 http.enforce_http(false);
117 http.set_nodelay(true);
118 http.set_keepalive(None);
119 http.set_connect_timeout(None);
120
121 self.endpoint.connect_with_connector_lazy(http)
122 }
123 }
124
125 async fn connect(self) -> Result<Channel> {
126 self.endpoint.connect().await.map_err(Into::into)
127 }
128}
129
130fn apply_config_to_endpoint(config: &Config, mut endpoint: Endpoint) -> Endpoint {
131 if let Some(limit) = config.concurrency_limit_per_connection {
132 endpoint = endpoint.concurrency_limit(limit);
133 }
134
135 if let Some(timeout) = config.request_timeout {
136 endpoint = endpoint.timeout(timeout);
137 }
138
139 if let Some(timeout) = config.connect_timeout {
140 endpoint = endpoint.connect_timeout(timeout);
141 }
142
143 if let Some(tcp_nodelay) = config.tcp_nodelay {
144 endpoint = endpoint.tcp_nodelay(tcp_nodelay);
145 }
146
147 if let Some(http2_keepalive_interval) = config.http2_keepalive_interval {
148 endpoint = endpoint.http2_keep_alive_interval(http2_keepalive_interval);
149 }
150
151 if let Some(http2_keepalive_timeout) = config.http2_keepalive_timeout {
152 endpoint = endpoint.keep_alive_timeout(http2_keepalive_timeout);
153 }
154
155 if let Some((limit, duration)) = config.rate_limit {
156 endpoint = endpoint.rate_limit(limit, duration);
157 }
158
159 endpoint
160 .initial_stream_window_size(config.http2_initial_stream_window_size)
161 .initial_connection_window_size(config.http2_initial_connection_window_size)
162 .tcp_keepalive(config.tcp_keepalive)
163}
164
165type CacheEntry = (Instant, Vec<SocketAddr>);
166
167#[derive(Clone)]
169pub struct CachingResolver {
170 cache: Arc<Mutex<HashMap<Name, CacheEntry>>>,
171}
172
173type SocketAddrs = vec::IntoIter<SocketAddr>;
174
175pub struct CachingFuture {
176 inner: JoinHandle<Result<SocketAddrs, io::Error>>,
177}
178
179impl CachingResolver {
180 pub fn new() -> Self {
181 CachingResolver {
182 cache: Arc::new(Mutex::new(HashMap::new())),
183 }
184 }
185}
186
187impl Default for CachingResolver {
188 fn default() -> Self {
189 Self::new()
190 }
191}
192
193impl Service<Name> for CachingResolver {
194 type Response = SocketAddrs;
195 type Error = io::Error;
196 type Future = CachingFuture;
197
198 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
199 Poll::Ready(Ok(()))
200 }
201
202 fn call(&mut self, name: Name) -> Self::Future {
203 let blocking = {
204 let cache = self.cache.clone();
205 tokio::task::spawn_blocking(move || {
206 let entry = cache.lock().unwrap().get(&name).cloned();
207
208 if let Some((when, addrs)) = entry {
209 trace!("cached host={:?}", name.as_str());
210
211 if when.elapsed().as_secs() > 60 {
212 trace!("refreshing cache for host={:?}", name.as_str());
213 tokio::task::spawn_blocking(move || {
215 if let Ok(addrs) = (name.as_str(), 0).to_socket_addrs() {
216 let addrs: Vec<_> = addrs.collect();
217 trace!("updating cached host={:?}", name.as_str());
218 cache
219 .lock()
220 .unwrap()
221 .insert(name, (Instant::now(), addrs.clone()));
222 }
223 });
224 }
225
226 Ok(addrs.into_iter())
227 } else {
228 trace!("resolving host={:?}", name.as_str());
229 match (name.as_str(), 0).to_socket_addrs() {
230 Ok(addrs) => {
231 let addrs: Vec<_> = addrs.collect();
232 cache
233 .lock()
234 .unwrap()
235 .insert(name, (Instant::now(), addrs.clone()));
236 Ok(addrs.into_iter())
237 }
238 res => res,
239 }
240 }
241 })
242 };
243
244 CachingFuture { inner: blocking }
245 }
246}
247
248impl fmt::Debug for CachingResolver {
249 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250 f.pad("CachingResolver")
251 }
252}
253
254impl Future for CachingFuture {
255 type Output = Result<SocketAddrs, io::Error>;
256
257 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
258 Pin::new(&mut self.inner).poll(cx).map(|res| match res {
259 Ok(Ok(addrs)) => Ok(addrs),
260 Ok(Err(err)) => Err(err),
261 Err(join_err) => {
262 if join_err.is_cancelled() {
263 Err(io::Error::new(io::ErrorKind::Interrupted, join_err))
264 } else {
265 panic!("background task failed: {:?}", join_err)
266 }
267 }
268 })
269 }
270}
271
272impl fmt::Debug for CachingFuture {
273 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274 f.pad("CachingFuture")
275 }
276}
277
278impl Drop for CachingFuture {
279 fn drop(&mut self) {
280 self.inner.abort();
281 }
282}