iota_network_stack/
server.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    convert::Infallible,
7    net::SocketAddr,
8    task::{Context, Poll},
9};
10
11use eyre::{Result, eyre};
12use futures::FutureExt;
13use tokio::net::{TcpListener, ToSocketAddrs};
14use tokio_stream::wrappers::TcpListenerStream;
15use tonic::{
16    body::BoxBody,
17    codegen::{
18        BoxFuture,
19        http::{HeaderValue, Request, Response},
20    },
21    server::NamedService,
22    transport::server::Router,
23};
24use tower::{
25    Layer, Service, ServiceBuilder,
26    layer::util::{Identity, Stack},
27    limit::GlobalConcurrencyLimitLayer,
28    load_shed::LoadShedLayer,
29    util::Either,
30};
31use tower_http::{
32    classify::{GrpcErrorsAsFailures, SharedClassifier},
33    propagate_header::PropagateHeaderLayer,
34    set_header::SetRequestHeaderLayer,
35    trace::{DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, TraceLayer},
36};
37
38use crate::{
39    config::Config,
40    metrics::{
41        DefaultMetricsCallbackProvider, GRPC_ENDPOINT_PATH_HEADER, MetricsCallbackProvider,
42        MetricsHandler,
43    },
44    multiaddr::{Multiaddr, Protocol, parse_dns, parse_ip4, parse_ip6},
45};
46
47pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
48    router: Router<WrapperService<M>>,
49    health_reporter: tonic_health::server::HealthReporter,
50}
51
52type AddPathToHeaderFunction = fn(&Request<BoxBody>) -> Option<HeaderValue>;
53
54type WrapperService<M> = Stack<
55    Stack<
56        PropagateHeaderLayer,
57        Stack<
58            TraceLayer<
59                SharedClassifier<GrpcErrorsAsFailures>,
60                DefaultMakeSpan,
61                MetricsHandler<M>,
62                MetricsHandler<M>,
63                DefaultOnBodyChunk,
64                DefaultOnEos,
65                MetricsHandler<M>,
66            >,
67            Stack<
68                SetRequestHeaderLayer<AddPathToHeaderFunction>,
69                Stack<
70                    RequestLifetimeLayer<M>,
71                    Stack<
72                        Either<LoadShedLayer, Identity>,
73                        Stack<Either<GlobalConcurrencyLimitLayer, Identity>, Identity>,
74                    >,
75                >,
76            >,
77        >,
78    >,
79    Identity,
80>;
81
82impl<M: MetricsCallbackProvider> ServerBuilder<M> {
83    pub fn from_config(config: &Config, metrics_provider: M) -> Self {
84        let mut builder = tonic::transport::server::Server::builder();
85
86        if let Some(limit) = config.concurrency_limit_per_connection {
87            builder = builder.concurrency_limit_per_connection(limit);
88        }
89
90        if let Some(timeout) = config.request_timeout {
91            builder = builder.timeout(timeout);
92        }
93
94        if let Some(tcp_nodelay) = config.tcp_nodelay {
95            builder = builder.tcp_nodelay(tcp_nodelay);
96        }
97
98        let load_shed = config
99            .load_shed
100            .unwrap_or_default()
101            .then_some(tower::load_shed::LoadShedLayer::new());
102
103        let metrics = MetricsHandler::new(metrics_provider.clone());
104
105        let request_metrics = TraceLayer::new_for_grpc()
106            .on_request(metrics.clone())
107            .on_response(metrics.clone())
108            .on_failure(metrics);
109
110        let global_concurrency_limit = config
111            .global_concurrency_limit
112            .map(tower::limit::GlobalConcurrencyLimitLayer::new);
113
114        fn add_path_to_request_header(request: &Request<BoxBody>) -> Option<HeaderValue> {
115            let path = request.uri().path();
116            Some(HeaderValue::from_str(path).unwrap())
117        }
118
119        let layer = ServiceBuilder::new()
120            .option_layer(global_concurrency_limit)
121            .option_layer(load_shed)
122            .layer(RequestLifetimeLayer { metrics_provider })
123            .layer(SetRequestHeaderLayer::overriding(
124                GRPC_ENDPOINT_PATH_HEADER.clone(),
125                add_path_to_request_header as AddPathToHeaderFunction,
126            ))
127            .layer(request_metrics)
128            .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
129            .into_inner();
130
131        let (health_reporter, health_service) = tonic_health::server::health_reporter();
132        let router = builder
133            .initial_stream_window_size(config.http2_initial_stream_window_size)
134            .initial_connection_window_size(config.http2_initial_connection_window_size)
135            .http2_keepalive_interval(config.http2_keepalive_interval)
136            .http2_keepalive_timeout(config.http2_keepalive_timeout)
137            .max_concurrent_streams(config.http2_max_concurrent_streams)
138            .tcp_keepalive(config.tcp_keepalive)
139            .layer(layer)
140            .add_service(health_service);
141
142        Self {
143            router,
144            health_reporter,
145        }
146    }
147
148    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
149        self.health_reporter.clone()
150    }
151
152    /// Add a new service to this Server.
153    pub fn add_service<S>(mut self, svc: S) -> Self
154    where
155        S: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Infallible>
156            + NamedService
157            + Clone
158            + Send
159            + 'static,
160        S::Future: Send + 'static,
161    {
162        self.router = self.router.add_service(svc);
163        self
164    }
165
166    pub async fn bind(self, addr: &Multiaddr) -> Result<Server> {
167        let mut iter = addr.iter();
168
169        let (tx_cancellation, rx_cancellation) = tokio::sync::oneshot::channel();
170        let rx_cancellation = rx_cancellation.map(|_| ());
171        let (local_addr, server): (Multiaddr, BoxFuture<(), tonic::transport::Error>) =
172            match iter.next().ok_or_else(|| eyre!("malformed addr"))? {
173                Protocol::Dns(_) => {
174                    let (dns_name, tcp_port, _http_or_https) = parse_dns(addr)?;
175                    let (local_addr, incoming) =
176                        tcp_listener_and_update_multiaddr(addr, (dns_name.as_ref(), tcp_port))
177                            .await?;
178                    let server = Box::pin(
179                        self.router
180                            .serve_with_incoming_shutdown(incoming, rx_cancellation),
181                    );
182                    (local_addr, server)
183                }
184                Protocol::Ip4(_) => {
185                    let (socket_addr, _http_or_https) = parse_ip4(addr)?;
186                    let (local_addr, incoming) =
187                        tcp_listener_and_update_multiaddr(addr, socket_addr).await?;
188                    let server = Box::pin(
189                        self.router
190                            .serve_with_incoming_shutdown(incoming, rx_cancellation),
191                    );
192                    (local_addr, server)
193                }
194                Protocol::Ip6(_) => {
195                    let (socket_addr, _http_or_https) = parse_ip6(addr)?;
196                    let (local_addr, incoming) =
197                        tcp_listener_and_update_multiaddr(addr, socket_addr).await?;
198                    let server = Box::pin(
199                        self.router
200                            .serve_with_incoming_shutdown(incoming, rx_cancellation),
201                    );
202                    (local_addr, server)
203                }
204                unsupported => return Err(eyre!("unsupported protocol {unsupported}")),
205            };
206
207        Ok(Server {
208            server,
209            cancel_handle: Some(tx_cancellation),
210            local_addr,
211            health_reporter: self.health_reporter,
212        })
213    }
214}
215
216async fn tcp_listener_and_update_multiaddr<T: ToSocketAddrs>(
217    address: &Multiaddr,
218    socket_addr: T,
219) -> Result<(Multiaddr, TcpListenerStream)> {
220    let (local_addr, incoming) = tcp_listener(socket_addr).await?;
221    let local_addr = update_tcp_port_in_multiaddr(address, local_addr.port());
222    Ok((local_addr, incoming))
223}
224
225async fn tcp_listener<T: ToSocketAddrs>(address: T) -> Result<(SocketAddr, TcpListenerStream)> {
226    let listener = TcpListener::bind(address).await?;
227    let local_addr = listener.local_addr()?;
228    let incoming = TcpListenerStream::new(listener);
229    Ok((local_addr, incoming))
230}
231
232pub struct Server {
233    server: BoxFuture<(), tonic::transport::Error>,
234    cancel_handle: Option<tokio::sync::oneshot::Sender<()>>,
235    local_addr: Multiaddr,
236    health_reporter: tonic_health::server::HealthReporter,
237}
238
239impl Server {
240    pub async fn serve(self) -> Result<(), tonic::transport::Error> {
241        self.server.await
242    }
243
244    pub fn local_addr(&self) -> &Multiaddr {
245        &self.local_addr
246    }
247
248    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
249        self.health_reporter.clone()
250    }
251
252    pub fn take_cancel_handle(&mut self) -> Option<tokio::sync::oneshot::Sender<()>> {
253        self.cancel_handle.take()
254    }
255}
256
257fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
258    addr.replace(1, |protocol| {
259        if let Protocol::Tcp(_) = protocol {
260            Some(Protocol::Tcp(port))
261        } else {
262            panic!("expected tcp protocol at index 1");
263        }
264    })
265    .expect("tcp protocol at index 1")
266}
267
#[cfg(test)]
mod test {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        time::Duration,
    };

    use tonic::Code;
    use tonic_health::pb::{HealthCheckRequest, health_client::HealthClient};

    use crate::{Multiaddr, config::Config, metrics::MetricsCallbackProvider};

    #[test]
    fn document_multiaddr_limitation_for_unix_protocol() {
        // You can construct a multiaddr by hand (ie binary format) just fine
        let path = "/tmp/foo";
        let addr = Multiaddr::new_internal(multiaddr::multiaddr!(Unix(path), Http));

        // But it doesn't round-trip in the human readable format
        let s = addr.to_string();
        assert!(s.parse::<Multiaddr>().is_err());
    }

    /// A successful health check must reach `on_response` with gRPC status OK.
    #[tokio::test]
    async fn test_metrics_layer_successful() {
        #[derive(Clone)]
        struct Metrics {
            /// Set once the `on_response` callback has been invoked.
            metrics_called: Arc<AtomicBool>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::Ok);
                self.metrics_called.store(true, Ordering::SeqCst);
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(AtomicBool::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(metrics.metrics_called.load(Ordering::SeqCst));
    }

    /// A health check for an unknown service must reach `on_response` with
    /// gRPC status NotFound, and the client must see the same status.
    #[tokio::test]
    async fn test_metrics_layer_error() {
        #[derive(Clone)]
        struct Metrics {
            /// Set once the `on_response` callback has been invoked.
            metrics_called: Arc<AtomicBool>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                // According to https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
                // code 5 is not_found , which is what we expect to get in this case
                assert_eq!(grpc_status_code, Code::NotFound);
                self.metrics_called.store(true, Ordering::SeqCst);
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(AtomicBool::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        // Call the healthcheck for a service that doesn't exist
        // that should give us back an error with code 5 (not_found)
        // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
        let status = client
            .check(HealthCheckRequest {
                service: "non-existing-service".to_owned(),
            })
            .await
            .expect_err("health check for an unregistered service must fail");
        assert_eq!(status.code(), Code::NotFound);

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(metrics.metrics_called.load(Ordering::SeqCst));
    }

    /// Binds a server on `address`, performs one health check against it and
    /// shuts it down again.
    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let mut server = config.server_builder().bind(&address).await.unwrap();
        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn dns() {
        let address: Multiaddr = "/dns/localhost/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip4() {
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip6() {
        let address: Multiaddr = "/ip6/::1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }
}
456
457#[derive(Clone)]
458struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
459    metrics_provider: M,
460}
461
462impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
463    type Service = RequestLifetime<M, S>;
464
465    fn layer(&self, inner: S) -> Self::Service {
466        RequestLifetime {
467            inner,
468            metrics_provider: self.metrics_provider.clone(),
469            path: None,
470        }
471    }
472}
473
474#[derive(Clone)]
475struct RequestLifetime<M: MetricsCallbackProvider, S> {
476    inner: S,
477    metrics_provider: M,
478    path: Option<String>,
479}
480
481impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
482    for RequestLifetime<M, S>
483where
484    S: Service<Request<RequestBody>>,
485{
486    type Response = S::Response;
487    type Error = S::Error;
488    type Future = S::Future;
489
490    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
491        self.inner.poll_ready(cx)
492    }
493
494    fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
495        if self.path.is_none() {
496            let path = request.uri().path().to_string();
497            self.metrics_provider.on_start(&path);
498            self.path = Some(path);
499        }
500        self.inner.call(request)
501    }
502}
503
504impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
505    fn drop(&mut self) {
506        if let Some(path) = &self.path {
507            self.metrics_provider.on_drop(path)
508        }
509    }
510}