iota_network_stack/
server.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    convert::Infallible,
7    task::{Context, Poll},
8};
9
10use eyre::{Result, eyre};
11use tokio_rustls::rustls::ServerConfig;
12use tonic::{
13    body::Body,
14    codegen::http::{HeaderValue, Request, Response},
15    server::NamedService,
16};
17use tower::{Layer, Service, ServiceBuilder};
18use tower_http::{
19    propagate_header::PropagateHeaderLayer, set_header::SetRequestHeaderLayer, trace::TraceLayer,
20};
21
22use crate::{
23    config::Config,
24    metrics::{
25        DefaultMetricsCallbackProvider, GRPC_ENDPOINT_PATH_HEADER, MetricsCallbackProvider,
26        MetricsHandler,
27    },
28    multiaddr::{Multiaddr, Protocol},
29};
30
31pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
32    config: Config,
33    metrics_provider: M,
34    router: tonic::service::Routes,
35    health_reporter: tonic_health::server::HealthReporter,
36}
37
38impl<M: MetricsCallbackProvider> ServerBuilder<M> {
39    pub fn from_config(config: &Config, metrics_provider: M) -> Self {
40        let (health_reporter, health_service) = tonic_health::server::health_reporter();
41        let router = tonic::service::Routes::new(health_service);
42
43        Self {
44            config: config.to_owned(),
45            metrics_provider,
46            router,
47            health_reporter,
48        }
49    }
50
51    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
52        self.health_reporter.clone()
53    }
54
55    /// Add a new service to this Server.
56    pub fn add_service<S>(mut self, svc: S) -> Self
57    where
58        S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
59            + NamedService
60            + Clone
61            + Send
62            + Sync
63            + 'static,
64        S::Future: Send + 'static,
65    {
66        self.router = self.router.add_service(svc);
67        self
68    }
69
70    pub async fn bind(self, addr: &Multiaddr, tls_config: Option<ServerConfig>) -> Result<Server> {
71        let http_config = self
72            .config
73            .http_config()
74            // Temporarily continue allowing clients to connection without TLS even when the server
75            // is configured with a tls_config
76            .allow_insecure(true);
77
78        let request_timeout = self.config.request_timeout;
79        let metrics_provider = self.metrics_provider;
80        let metrics = MetricsHandler::new(metrics_provider.clone());
81        let request_metrics = TraceLayer::new_for_grpc()
82            .on_request(metrics.clone())
83            .on_response(metrics.clone())
84            .on_failure(metrics);
85
86        fn add_path_to_request_header<T>(request: &Request<T>) -> Option<HeaderValue> {
87            let path = request.uri().path();
88            Some(HeaderValue::from_str(path).unwrap())
89        }
90
91        let limiting_layers = ServiceBuilder::new()
92            .option_layer(
93                self.config
94                    .load_shed
95                    .unwrap_or_default()
96                    .then_some(tower::load_shed::LoadShedLayer::new()),
97            )
98            .option_layer(
99                self.config
100                    .global_concurrency_limit
101                    .map(tower::limit::GlobalConcurrencyLimitLayer::new),
102            );
103
104        let route_layers = ServiceBuilder::new()
105            .map_request(|mut request: http::Request<_>| {
106                if let Some(connect_info) = request.extensions().get::<iota_http::ConnectInfo>() {
107                    let tonic_connect_info = tonic::transport::server::TcpConnectInfo {
108                        local_addr: Some(connect_info.local_addr),
109                        remote_addr: Some(connect_info.remote_addr),
110                    };
111                    request.extensions_mut().insert(tonic_connect_info);
112                }
113                request
114            })
115            .layer(RequestLifetimeLayer { metrics_provider })
116            .layer(SetRequestHeaderLayer::overriding(
117                GRPC_ENDPOINT_PATH_HEADER.clone(),
118                add_path_to_request_header,
119            ))
120            .layer(request_metrics)
121            .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
122            .layer_fn(move |service| {
123                crate::grpc_timeout::GrpcTimeout::new(service, request_timeout)
124            });
125
126        let mut builder = iota_http::Builder::new().config(http_config);
127
128        let has_tls = tls_config.is_some();
129        if let Some(tls_config) = tls_config {
130            builder = builder.tls_config(tls_config);
131        }
132
133        let server_handle = builder
134            .serve(
135                addr,
136                limiting_layers.service(self.router.into_axum_router().layer(route_layers)),
137            )
138            .map_err(|e| eyre!(e))?;
139
140        let mut local_addr = update_tcp_port_in_multiaddr(addr, server_handle.local_addr().port());
141        if has_tls {
142            local_addr = local_addr.rewrite_http_to_https();
143        }
144        Ok(Server {
145            server_handle,
146            local_addr,
147            health_reporter: self.health_reporter,
148        })
149    }
150}
151
/// Server name used for TLS on the public IOTA validator interface.
pub const IOTA_TLS_SERVER_NAME: &str = "iota";
154
155pub struct Server {
156    server_handle: iota_http::ServerHandle,
157    local_addr: Multiaddr,
158    health_reporter: tonic_health::server::HealthReporter,
159}
160
161impl Server {
162    pub async fn serve(self) -> Result<(), tonic::transport::Error> {
163        self.server_handle.wait_for_shutdown().await;
164        Ok(())
165    }
166
167    pub fn trigger_shutdown(&self) {
168        self.server_handle.trigger_shutdown();
169    }
170
171    pub fn local_addr(&self) -> &Multiaddr {
172        &self.local_addr
173    }
174
175    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
176        self.health_reporter.clone()
177    }
178
179    pub fn handle(&self) -> &iota_http::ServerHandle {
180        &self.server_handle
181    }
182}
183
184fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
185    addr.replace(1, |protocol| {
186        if let Protocol::Tcp(_) = protocol {
187            Some(Protocol::Tcp(port))
188        } else {
189            panic!("expected tcp protocol at index 1");
190        }
191    })
192    .expect("tcp protocol at index 1")
193}
194
#[cfg(test)]
mod test {
    use std::{
        ops::Deref,
        sync::{Arc, Mutex},
        time::Duration,
    };

    use tonic::Code;
    use tonic_health::pb::{HealthCheckRequest, health_client::HealthClient};

    use crate::{Multiaddr, config::Config, metrics::MetricsCallbackProvider};

    #[test]
    fn document_multiaddr_limitation_for_unix_protocol() {
        // You can construct a multiaddr by hand (i.e. binary format) just fine
        let path = "/tmp/foo";
        let addr = Multiaddr::new_internal(multiaddr::multiaddr!(Unix(path), Http));

        // But it doesn't round-trip in the human readable format
        let s = addr.to_string();
        assert!(s.parse::<Multiaddr>().is_err());
    }

    #[tokio::test]
    async fn test_metrics_layer_successful() {
        #[derive(Clone)]
        struct Metrics {
            /// Set to `true` by `on_response` after its assertions pass, so the
            /// test can confirm the response metrics callback actually ran.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::Ok);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address, None)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let channel = config.connect(&address, None).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        server.server_handle.shutdown().await;

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    #[tokio::test]
    async fn test_metrics_layer_error() {
        #[derive(Clone)]
        struct Metrics {
            /// Set to `true` by `on_response` after its assertions pass, so the
            /// test can confirm the response metrics callback actually ran.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                // According to https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
                // code 5 is NOT_FOUND, which is what we expect to get in this case
                assert_eq!(grpc_status_code, Code::NotFound);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address, None)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let channel = config.connect(&address, None).await.unwrap();
        let mut client = HealthClient::new(channel);

        // Call the healthcheck for a service that doesn't exist;
        // that should give us back an error with code 5 (NOT_FOUND)
        // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
        let _ = client
            .check(HealthCheckRequest {
                service: "non-existing-service".to_owned(),
            })
            .await;

        server.server_handle.shutdown().await;

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    /// Binds a server on `address`, performs a health check through a client
    /// connected to the bound address, then shuts the server down.
    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let server = config.server_builder().bind(&address, None).await.unwrap();
        let address = server.local_addr().to_owned();
        let channel = config.connect(&address, None).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        server.server_handle.shutdown().await;
    }

    #[tokio::test]
    async fn dns() {
        let address: Multiaddr = "/dns/localhost/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip4() {
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip6() {
        let address: Multiaddr = "/ip6/::1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }
}
374
375#[derive(Clone)]
376struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
377    metrics_provider: M,
378}
379
380impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
381    type Service = RequestLifetime<M, S>;
382
383    fn layer(&self, inner: S) -> Self::Service {
384        RequestLifetime {
385            inner,
386            metrics_provider: self.metrics_provider.clone(),
387            path: None,
388        }
389    }
390}
391
392#[derive(Clone)]
393struct RequestLifetime<M: MetricsCallbackProvider, S> {
394    inner: S,
395    metrics_provider: M,
396    path: Option<String>,
397}
398
399impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
400    for RequestLifetime<M, S>
401where
402    S: Service<Request<RequestBody>>,
403{
404    type Response = S::Response;
405    type Error = S::Error;
406    type Future = S::Future;
407
408    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
409        self.inner.poll_ready(cx)
410    }
411
412    fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
413        if self.path.is_none() {
414            let path = request.uri().path().to_string();
415            self.metrics_provider.on_start(&path);
416            self.path = Some(path);
417        }
418        self.inner.call(request)
419    }
420}
421
422impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
423    fn drop(&mut self) {
424        if let Some(path) = &self.path {
425            self.metrics_provider.on_drop(path)
426        }
427    }
428}