iota_network_stack/
server.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    convert::Infallible,
7    task::{Context, Poll},
8};
9
10use eyre::{Result, eyre};
11use tokio_rustls::rustls::ServerConfig;
12use tonic::{
13    body::Body,
14    codegen::http::{HeaderValue, Request, Response},
15    server::NamedService,
16};
17use tower::{Layer, Service, ServiceBuilder};
18use tower_http::{
19    propagate_header::PropagateHeaderLayer, set_header::SetRequestHeaderLayer, trace::TraceLayer,
20};
21
22use crate::{
23    config::Config,
24    metrics::{
25        DefaultMetricsCallbackProvider, GRPC_ENDPOINT_PATH_HEADER, MetricsCallbackProvider,
26        MetricsHandler,
27    },
28    multiaddr::{Multiaddr, Protocol},
29};
30
31pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
32    config: Config,
33    metrics_provider: M,
34    router: tonic::service::Routes,
35    health_reporter: tonic_health::server::HealthReporter,
36}
37
38impl<M: MetricsCallbackProvider> ServerBuilder<M> {
39    pub fn from_config(config: &Config, metrics_provider: M) -> Self {
40        let (health_reporter, health_service) = tonic_health::server::health_reporter();
41        let router = tonic::service::Routes::new(health_service);
42
43        Self {
44            config: config.to_owned(),
45            metrics_provider,
46            router,
47            health_reporter,
48        }
49    }
50
51    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
52        self.health_reporter.clone()
53    }
54
55    /// Add a new service to this Server.
56    pub fn add_service<S>(mut self, svc: S) -> Self
57    where
58        S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
59            + NamedService
60            + Clone
61            + Send
62            + Sync
63            + 'static,
64        S::Future: Send + 'static,
65    {
66        self.router = self.router.add_service(svc);
67        self
68    }
69
70    pub async fn bind(self, addr: &Multiaddr, tls_config: Option<ServerConfig>) -> Result<Server> {
71        let http_config = self
72            .config
73            .http_config()
74            // Temporarily continue allowing clients to connection without TLS even when the server
75            // is configured with a tls_config
76            .allow_insecure(true);
77
78        let request_timeout = self.config.request_timeout;
79        let metrics_provider = self.metrics_provider;
80        let metrics = MetricsHandler::new(metrics_provider.clone());
81        let request_metrics = TraceLayer::new_for_grpc()
82            .on_request(metrics.clone())
83            .on_response(metrics.clone())
84            .on_failure(metrics);
85
86        fn add_path_to_request_header<T>(request: &Request<T>) -> Option<HeaderValue> {
87            let path = request.uri().path();
88            Some(HeaderValue::from_str(path).unwrap())
89        }
90
91        let limiting_layers = ServiceBuilder::new()
92            .option_layer(
93                self.config
94                    .load_shed
95                    .unwrap_or_default()
96                    .then_some(tower::load_shed::LoadShedLayer::new()),
97            )
98            .option_layer(
99                self.config
100                    .global_concurrency_limit
101                    .map(tower::limit::GlobalConcurrencyLimitLayer::new),
102            );
103
104        let route_layers = ServiceBuilder::new()
105            .map_request(|mut request: http::Request<_>| {
106                if let Some(connect_info) = request.extensions().get::<iota_http::ConnectInfo>() {
107                    let tonic_connect_info = tonic::transport::server::TcpConnectInfo {
108                        local_addr: Some(connect_info.local_addr),
109                        remote_addr: Some(connect_info.remote_addr),
110                    };
111                    request.extensions_mut().insert(tonic_connect_info);
112                }
113                request
114            })
115            .layer(RequestLifetimeLayer { metrics_provider })
116            .layer(SetRequestHeaderLayer::overriding(
117                GRPC_ENDPOINT_PATH_HEADER.clone(),
118                add_path_to_request_header,
119            ))
120            .layer(request_metrics)
121            .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
122            .layer_fn(move |service| {
123                crate::grpc_timeout::GrpcTimeout::new(service, request_timeout)
124            });
125
126        let mut builder = iota_http::Builder::new().config(http_config);
127
128        if let Some(tls_config) = tls_config {
129            builder = builder.tls_config(tls_config);
130        }
131
132        let server_handle = builder
133            .serve(
134                addr,
135                limiting_layers.service(self.router.into_axum_router().layer(route_layers)),
136            )
137            .map_err(|e| eyre!(e))?;
138
139        let local_addr = update_tcp_port_in_multiaddr(addr, server_handle.local_addr().port());
140        Ok(Server {
141            server_handle,
142            local_addr,
143            health_reporter: self.health_reporter,
144        })
145    }
146}
147
/// Server name presented over TLS by the public IOTA validator interface.
pub const IOTA_TLS_SERVER_NAME: &str = "iota";
150
151pub struct Server {
152    server_handle: iota_http::ServerHandle,
153    local_addr: Multiaddr,
154    health_reporter: tonic_health::server::HealthReporter,
155}
156
157impl Server {
158    pub async fn serve(self) -> Result<(), tonic::transport::Error> {
159        self.server_handle.wait_for_shutdown().await;
160        Ok(())
161    }
162
163    pub fn trigger_shutdown(&self) {
164        self.server_handle.trigger_shutdown();
165    }
166
167    pub fn local_addr(&self) -> &Multiaddr {
168        &self.local_addr
169    }
170
171    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
172        self.health_reporter.clone()
173    }
174
175    pub fn handle(&self) -> &iota_http::ServerHandle {
176        &self.server_handle
177    }
178}
179
180fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
181    addr.replace(1, |protocol| {
182        if let Protocol::Tcp(_) = protocol {
183            Some(Protocol::Tcp(port))
184        } else {
185            panic!("expected tcp protocol at index 1");
186        }
187    })
188    .expect("tcp protocol at index 1")
189}
190
#[cfg(test)]
mod test {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        time::Duration,
    };

    use tonic::Code;
    use tonic_health::pb::{HealthCheckRequest, health_client::HealthClient};

    use crate::{Multiaddr, config::Config, metrics::MetricsCallbackProvider};

    /// gRPC path of the health-check endpoint exercised by the metrics tests.
    const HEALTH_CHECK_PATH: &str = "/grpc.health.v1.Health/Check";

    /// Metrics provider shared by the metrics-layer tests.
    ///
    /// It asserts on the values passed to the callbacks and records whether
    /// `on_response` ran to completion.
    #[derive(Clone)]
    struct RecordingMetrics {
        /// gRPC status code `on_response` is expected to receive.
        expected_code: Code,
        /// Set once `on_response` has been called and all of its assertions
        /// have passed.
        response_seen: Arc<AtomicBool>,
    }

    impl RecordingMetrics {
        /// Creates a provider that expects responses carrying `expected_code`.
        fn expecting(expected_code: Code) -> Self {
            Self {
                expected_code,
                response_seen: Arc::new(AtomicBool::new(false)),
            }
        }

        /// Whether `on_response` has been called successfully.
        fn response_seen(&self) -> bool {
            self.response_seen.load(Ordering::SeqCst)
        }
    }

    impl MetricsCallbackProvider for RecordingMetrics {
        fn on_request(&self, path: String) {
            assert_eq!(path, HEALTH_CHECK_PATH);
        }

        fn on_response(
            &self,
            path: String,
            _latency: Duration,
            status: u16,
            grpc_status_code: Code,
        ) {
            assert_eq!(path, HEALTH_CHECK_PATH);
            assert_eq!(status, 200);
            assert_eq!(grpc_status_code, self.expected_code);
            // Only flag the call after every assertion above has passed, so a
            // failed assertion here still fails the test's final check.
            self.response_seen.store(true, Ordering::SeqCst);
        }
    }

    #[test]
    fn document_multiaddr_limitation_for_unix_protocol() {
        // You can construct a multiaddr by hand (ie binary format) just fine
        let path = "/tmp/foo";
        let addr = Multiaddr::new_internal(multiaddr::multiaddr!(Unix(path), Http));

        // But it doesn't round-trip in the human readable format
        let s = addr.to_string();
        assert!(s.parse::<Multiaddr>().is_err());
    }

    #[tokio::test]
    async fn test_metrics_layer_successful() {
        let metrics = RecordingMetrics::expecting(Code::Ok);

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address, None)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let channel = config.connect(&address, None).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        server.server_handle.shutdown().await;

        assert!(metrics.response_seen());
    }

    #[tokio::test]
    async fn test_metrics_layer_error() {
        // According to https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
        // code 5 is not_found, which is what we expect to get in this case
        let metrics = RecordingMetrics::expecting(Code::NotFound);

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address, None)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let channel = config.connect(&address, None).await.unwrap();
        let mut client = HealthClient::new(channel);

        // Call the healthcheck for a service that doesn't exist
        // that should give us back an error with code 5 (not_found)
        // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
        let _ = client
            .check(HealthCheckRequest {
                service: "non-existing-service".to_owned(),
            })
            .await;

        server.server_handle.shutdown().await;

        assert!(metrics.response_seen());
    }

    /// Binds a server with the default builder on `address`, performs one
    /// health check against it, and shuts it down.
    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let server = config.server_builder().bind(&address, None).await.unwrap();
        let address = server.local_addr().to_owned();
        let channel = config.connect(&address, None).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        server.server_handle.shutdown().await;
    }

    #[tokio::test]
    async fn dns() {
        let address: Multiaddr = "/dns/localhost/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip4() {
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip6() {
        let address: Multiaddr = "/ip6/::1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }
}
370
371#[derive(Clone)]
372struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
373    metrics_provider: M,
374}
375
376impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
377    type Service = RequestLifetime<M, S>;
378
379    fn layer(&self, inner: S) -> Self::Service {
380        RequestLifetime {
381            inner,
382            metrics_provider: self.metrics_provider.clone(),
383            path: None,
384        }
385    }
386}
387
388#[derive(Clone)]
389struct RequestLifetime<M: MetricsCallbackProvider, S> {
390    inner: S,
391    metrics_provider: M,
392    path: Option<String>,
393}
394
395impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
396    for RequestLifetime<M, S>
397where
398    S: Service<Request<RequestBody>>,
399{
400    type Response = S::Response;
401    type Error = S::Error;
402    type Future = S::Future;
403
404    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
405        self.inner.poll_ready(cx)
406    }
407
408    fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
409        if self.path.is_none() {
410            let path = request.uri().path().to_string();
411            self.metrics_provider.on_start(&path);
412            self.path = Some(path);
413        }
414        self.inner.call(request)
415    }
416}
417
418impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
419    fn drop(&mut self) {
420        if let Some(path) = &self.path {
421            self.metrics_provider.on_drop(path)
422        }
423    }
424}