iota_network_stack/
server.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    convert::Infallible,
7    net::SocketAddr,
8    task::{Context, Poll},
9};
10
11use eyre::{Result, eyre};
12use futures::FutureExt;
13use tokio::net::{TcpListener, ToSocketAddrs};
14use tokio_stream::wrappers::TcpListenerStream;
15use tonic::{
16    body::Body,
17    codegen::{
18        BoxFuture,
19        http::{HeaderValue, Request, Response},
20    },
21    server::NamedService,
22    transport::server::Router,
23};
24use tower::{
25    Layer, Service, ServiceBuilder,
26    layer::util::{Identity, Stack},
27    limit::GlobalConcurrencyLimitLayer,
28    load_shed::LoadShedLayer,
29    util::Either,
30};
31use tower_http::{
32    classify::{GrpcErrorsAsFailures, SharedClassifier},
33    propagate_header::PropagateHeaderLayer,
34    set_header::SetRequestHeaderLayer,
35    trace::{DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, TraceLayer},
36};
37
38use crate::{
39    config::Config,
40    metrics::{
41        DefaultMetricsCallbackProvider, GRPC_ENDPOINT_PATH_HEADER, MetricsCallbackProvider,
42        MetricsHandler,
43    },
44    multiaddr::{Multiaddr, Protocol, parse_dns, parse_ip4, parse_ip6},
45};
46
/// Builder for a tonic gRPC server whose services are wrapped in a shared
/// middleware stack (concurrency limiting, load shedding, request-lifetime
/// tracking, path header injection and request metrics).
///
/// `ServerBuilder::from_config` registers the standard gRPC health service on
/// the router; additional services are registered with `add_service`, and
/// `bind` turns the builder into a listening [`Server`].
pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
    // Router with the middleware stack applied; services are added to it.
    router: Router<WrapperService<M>>,
    // Handle for updating the status reported by the health service.
    health_reporter: tonic_health::server::HealthReporter,
}
51
/// Function-pointer type that derives the `GRPC_ENDPOINT_PATH_HEADER` value from
/// an incoming request. Using a nameable `fn` pointer (instead of a closure type)
/// lets the middleware stack be spelled out in `WrapperService`.
type AddPathToHeaderFunction = fn(&Request<Body>) -> Option<HeaderValue>;
53
54type WrapperService<M> = Stack<
55    Stack<
56        PropagateHeaderLayer,
57        Stack<
58            TraceLayer<
59                SharedClassifier<GrpcErrorsAsFailures>,
60                DefaultMakeSpan,
61                MetricsHandler<M>,
62                MetricsHandler<M>,
63                DefaultOnBodyChunk,
64                DefaultOnEos,
65                MetricsHandler<M>,
66            >,
67            Stack<
68                SetRequestHeaderLayer<AddPathToHeaderFunction>,
69                Stack<
70                    RequestLifetimeLayer<M>,
71                    Stack<
72                        Either<LoadShedLayer, Identity>,
73                        Stack<Either<GlobalConcurrencyLimitLayer, Identity>, Identity>,
74                    >,
75                >,
76            >,
77        >,
78    >,
79    Identity,
80>;
81
82impl<M: MetricsCallbackProvider> ServerBuilder<M> {
83    pub fn from_config(config: &Config, metrics_provider: M) -> Self {
84        let mut builder = tonic::transport::server::Server::builder();
85
86        if let Some(limit) = config.concurrency_limit_per_connection {
87            builder = builder.concurrency_limit_per_connection(limit);
88        }
89
90        if let Some(timeout) = config.request_timeout {
91            builder = builder.timeout(timeout);
92        }
93
94        if let Some(tcp_nodelay) = config.tcp_nodelay {
95            builder = builder.tcp_nodelay(tcp_nodelay);
96        }
97
98        let load_shed = config
99            .load_shed
100            .unwrap_or_default()
101            .then_some(tower::load_shed::LoadShedLayer::new());
102
103        let metrics = MetricsHandler::new(metrics_provider.clone());
104
105        let request_metrics = TraceLayer::new_for_grpc()
106            .on_request(metrics.clone())
107            .on_response(metrics.clone())
108            .on_failure(metrics);
109
110        let global_concurrency_limit = config
111            .global_concurrency_limit
112            .map(tower::limit::GlobalConcurrencyLimitLayer::new);
113
114        fn add_path_to_request_header(request: &Request<Body>) -> Option<HeaderValue> {
115            let path = request.uri().path();
116            Some(HeaderValue::from_str(path).unwrap())
117        }
118
119        let layer = ServiceBuilder::new()
120            .option_layer(global_concurrency_limit)
121            .option_layer(load_shed)
122            .layer(RequestLifetimeLayer { metrics_provider })
123            .layer(SetRequestHeaderLayer::overriding(
124                GRPC_ENDPOINT_PATH_HEADER.clone(),
125                add_path_to_request_header as AddPathToHeaderFunction,
126            ))
127            .layer(request_metrics)
128            .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
129            .into_inner();
130
131        let (health_reporter, health_service) = tonic_health::server::health_reporter();
132        let router = builder
133            .initial_stream_window_size(config.http2_initial_stream_window_size)
134            .initial_connection_window_size(config.http2_initial_connection_window_size)
135            .http2_keepalive_interval(config.http2_keepalive_interval)
136            .http2_keepalive_timeout(config.http2_keepalive_timeout)
137            .max_concurrent_streams(config.http2_max_concurrent_streams)
138            .tcp_keepalive(config.tcp_keepalive)
139            .layer(layer)
140            .add_service(health_service);
141
142        Self {
143            router,
144            health_reporter,
145        }
146    }
147
148    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
149        self.health_reporter.clone()
150    }
151
152    /// Add a new service to this Server.
153    pub fn add_service<S>(mut self, svc: S) -> Self
154    where
155        S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
156            + NamedService
157            + Clone
158            + Send
159            + Sync
160            + 'static,
161        S::Future: Send + 'static,
162    {
163        self.router = self.router.add_service(svc);
164        self
165    }
166
167    pub async fn bind(self, addr: &Multiaddr) -> Result<Server> {
168        let mut iter = addr.iter();
169
170        let (tx_cancellation, rx_cancellation) = tokio::sync::oneshot::channel();
171        let rx_cancellation = rx_cancellation.map(|_| ());
172        let (local_addr, server): (Multiaddr, BoxFuture<(), tonic::transport::Error>) =
173            match iter.next().ok_or_else(|| eyre!("malformed addr"))? {
174                Protocol::Dns(_) => {
175                    let (dns_name, tcp_port, _http_or_https) = parse_dns(addr)?;
176                    let (local_addr, incoming) =
177                        tcp_listener_and_update_multiaddr(addr, (dns_name.as_ref(), tcp_port))
178                            .await?;
179                    let server = Box::pin(
180                        self.router
181                            .serve_with_incoming_shutdown(incoming, rx_cancellation),
182                    );
183                    (local_addr, server)
184                }
185                Protocol::Ip4(_) => {
186                    let (socket_addr, _http_or_https) = parse_ip4(addr)?;
187                    let (local_addr, incoming) =
188                        tcp_listener_and_update_multiaddr(addr, socket_addr).await?;
189                    let server = Box::pin(
190                        self.router
191                            .serve_with_incoming_shutdown(incoming, rx_cancellation),
192                    );
193                    (local_addr, server)
194                }
195                Protocol::Ip6(_) => {
196                    let (socket_addr, _http_or_https) = parse_ip6(addr)?;
197                    let (local_addr, incoming) =
198                        tcp_listener_and_update_multiaddr(addr, socket_addr).await?;
199                    let server = Box::pin(
200                        self.router
201                            .serve_with_incoming_shutdown(incoming, rx_cancellation),
202                    );
203                    (local_addr, server)
204                }
205                unsupported => return Err(eyre!("unsupported protocol {unsupported}")),
206            };
207
208        Ok(Server {
209            server,
210            cancel_handle: Some(tx_cancellation),
211            local_addr,
212            health_reporter: self.health_reporter,
213        })
214    }
215}
216
217async fn tcp_listener_and_update_multiaddr<T: ToSocketAddrs>(
218    address: &Multiaddr,
219    socket_addr: T,
220) -> Result<(Multiaddr, TcpListenerStream)> {
221    let (local_addr, incoming) = tcp_listener(socket_addr).await?;
222    let local_addr = update_tcp_port_in_multiaddr(address, local_addr.port());
223    Ok((local_addr, incoming))
224}
225
226async fn tcp_listener<T: ToSocketAddrs>(address: T) -> Result<(SocketAddr, TcpListenerStream)> {
227    let listener = TcpListener::bind(address).await?;
228    let local_addr = listener.local_addr()?;
229    let incoming = TcpListenerStream::new(listener);
230    Ok((local_addr, incoming))
231}
232
/// A gRPC server with a bound listener, produced by `ServerBuilder::bind`.
///
/// No requests are served until [`Server::serve`] is awaited.
pub struct Server {
    // Future that runs the server until shutdown is signalled or the transport fails.
    server: BoxFuture<(), tonic::transport::Error>,
    // Shutdown trigger; `None` once the caller has taken it.
    cancel_handle: Option<tokio::sync::oneshot::Sender<()>>,
    // Listening address with the actually bound TCP port filled in.
    local_addr: Multiaddr,
    // Handle for updating the status reported by the health service.
    health_reporter: tonic_health::server::HealthReporter,
}
239
240impl Server {
241    pub async fn serve(self) -> Result<(), tonic::transport::Error> {
242        self.server.await
243    }
244
245    pub fn local_addr(&self) -> &Multiaddr {
246        &self.local_addr
247    }
248
249    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
250        self.health_reporter.clone()
251    }
252
253    pub fn take_cancel_handle(&mut self) -> Option<tokio::sync::oneshot::Sender<()>> {
254        self.cancel_handle.take()
255    }
256}
257
258fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
259    addr.replace(1, |protocol| {
260        if let Protocol::Tcp(_) = protocol {
261            Some(Protocol::Tcp(port))
262        } else {
263            panic!("expected tcp protocol at index 1");
264        }
265    })
266    .expect("tcp protocol at index 1")
267}
268
#[cfg(test)]
mod test {
    use std::{
        ops::Deref,
        sync::{Arc, Mutex},
        time::Duration,
    };

    use tonic::Code;
    use tonic_health::pb::{HealthCheckRequest, health_client::HealthClient};

    use crate::{Multiaddr, config::Config, metrics::MetricsCallbackProvider};

    /// Documents that a `/unix/...` multiaddr can be built programmatically but
    /// does not survive a round trip through its string form.
    #[test]
    fn document_multiaddr_limitation_for_unix_protocol() {
        // You can construct a multiaddr by hand (i.e. in binary format) just fine.
        let path = "/tmp/foo";
        let addr = Multiaddr::new_internal(multiaddr::multiaddr!(Unix(path), Http));

        // But it doesn't round-trip through the human-readable format.
        let s = addr.to_string();
        assert!(s.parse::<Multiaddr>().is_err());
    }

    /// A successful health check must reach the metrics provider with the
    /// health-check path, HTTP status 200 and gRPC code `Ok`.
    #[tokio::test]
    async fn test_metrics_layer_successful() {
        #[derive(Clone)]
        struct Metrics {
            /// Set to `true` by `on_response` after its assertions pass. If an
            /// assertion in the callback fails, the flag is never set and the
            /// final assert of the test catches it.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::Ok);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        // Port 0 lets the OS pick a free port; `local_addr()` reports the real one.
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    /// A health check for an unknown service must reach the metrics provider
    /// with HTTP status 200 and gRPC code `NotFound`.
    #[tokio::test]
    async fn test_metrics_layer_error() {
        #[derive(Clone)]
        struct Metrics {
            /// Set to `true` by `on_response` after its assertions pass. If an
            /// assertion in the callback fails, the flag is never set and the
            /// final assert of the test catches it.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                // gRPC errors are carried in trailers, so the HTTP status is still 200.
                assert_eq!(status, 200);
                // According to https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
                // code 5 is NOT_FOUND, which is what we expect to get in this case.
                assert_eq!(grpc_status_code, Code::NotFound);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        // Port 0 lets the OS pick a free port; `local_addr()` reports the real one.
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        // Call the health check for a service that doesn't exist; that should
        // give us back an error with code 5 (NOT_FOUND), see
        // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
        // The client-side error is ignored; only the server-side metrics matter here.
        let _ = client
            .check(HealthCheckRequest {
                service: "non-existing-service".to_owned(),
            })
            .await;

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    /// Binds a default server on `address`, performs one health check against
    /// it and shuts it down, asserting that every step succeeds.
    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let mut server = config.server_builder().bind(&address).await.unwrap();
        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn dns() {
        let address: Multiaddr = "/dns/localhost/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip4() {
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip6() {
        let address: Multiaddr = "/ip6/::1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }
}
457
/// Layer that wraps services in [`RequestLifetime`], reporting start/drop
/// events to `metrics_provider`.
#[derive(Clone)]
struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
    // Provider handed (cloned) to every wrapped service.
    metrics_provider: M,
}
462
463impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
464    type Service = RequestLifetime<M, S>;
465
466    fn layer(&self, inner: S) -> Self::Service {
467        RequestLifetime {
468            inner,
469            metrics_provider: self.metrics_provider.clone(),
470            path: None,
471        }
472    }
473}
474
475#[derive(Clone)]
476struct RequestLifetime<M: MetricsCallbackProvider, S> {
477    inner: S,
478    metrics_provider: M,
479    path: Option<String>,
480}
481
482impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
483    for RequestLifetime<M, S>
484where
485    S: Service<Request<RequestBody>>,
486{
487    type Response = S::Response;
488    type Error = S::Error;
489    type Future = S::Future;
490
491    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
492        self.inner.poll_ready(cx)
493    }
494
495    fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
496        if self.path.is_none() {
497            let path = request.uri().path().to_string();
498            self.metrics_provider.on_start(&path);
499            self.path = Some(path);
500        }
501        self.inner.call(request)
502    }
503}
504
505impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
506    fn drop(&mut self) {
507        if let Some(path) = &self.path {
508            self.metrics_provider.on_drop(path)
509        }
510    }
511}