iota_network_stack/server.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use std::{
    convert::Infallible,
    net::SocketAddr,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use eyre::{Result, eyre};
use futures::{FutureExt, Stream, StreamExt, stream::FuturesUnordered};
use tokio::{
    io::{AsyncRead, AsyncWrite},
    net::{TcpListener, TcpStream, ToSocketAddrs},
};
use tokio_rustls::{TlsAcceptor, rustls::ServerConfig, server::TlsStream};
use tonic::{
    body::Body,
    codegen::{
        BoxFuture,
        http::{HeaderValue, Request, Response},
    },
    server::NamedService,
    transport::server::Router,
};
use tower::{
    Layer, Service, ServiceBuilder,
    layer::util::{Identity, Stack},
    limit::GlobalConcurrencyLimitLayer,
    load_shed::LoadShedLayer,
    util::Either,
};
use tower_http::{
    classify::{GrpcErrorsAsFailures, SharedClassifier},
    propagate_header::PropagateHeaderLayer,
    set_header::SetRequestHeaderLayer,
    trace::{DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, TraceLayer},
};
use tracing::debug;

use crate::{
    config::Config,
    metrics::{
        DefaultMetricsCallbackProvider, GRPC_ENDPOINT_PATH_HEADER, MetricsCallbackProvider,
        MetricsHandler,
    },
    multiaddr::{Multiaddr, Protocol, parse_dns, parse_ip4, parse_ip6},
};

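/// Builder for a gRPC [`Server`]: wraps a tonic [`Router`] in the
/// [`WrapperService`] middleware stack and keeps a handle to the built-in
/// health service.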
pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
    router: Router<WrapperService<M>>,
    health_reporter: tonic_health::server::HealthReporter,
}

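/// Signature of the function used by [`SetRequestHeaderLayer`] to copy the
/// request path into `GRPC_ENDPOINT_PATH_HEADER`.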
type AddPathToHeaderFunction = fn(&Request<Body>) -> Option<HeaderValue>;

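/// The tower middleware stack that [`ServerBuilder::from_config`] applies to
/// every added service: header propagation, request tracing/metrics, path
/// header injection, request-lifetime metrics, and optional load shedding and
/// global concurrency limiting.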
type WrapperService<M> = Stack<
    Stack<
        PropagateHeaderLayer,
        Stack<
            TraceLayer<
                SharedClassifier<GrpcErrorsAsFailures>,
                DefaultMakeSpan,
                MetricsHandler<M>,
                MetricsHandler<M>,
                DefaultOnBodyChunk,
                DefaultOnEos,
                MetricsHandler<M>,
            >,
            Stack<
                SetRequestHeaderLayer<AddPathToHeaderFunction>,
                Stack<
                    RequestLifetimeLayer<M>,
                    Stack<
                        Either<LoadShedLayer, Identity>,
                        Stack<Either<GlobalConcurrencyLimitLayer, Identity>, Identity>,
                    >,
                >,
            >,
        >,
    >,
    Identity,
>;

impl<M: MetricsCallbackProvider> ServerBuilder<M> {
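    /// Builds a `ServerBuilder` from `config`: applies the optional
    /// per-connection and global concurrency limits, load shedding, request
    /// metrics, and path-header propagation layers, and registers the standard
    /// `grpc.health.v1.Health` service.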
    pub fn from_config(config: &Config, metrics_provider: M) -> Self {
        let mut builder = tonic::transport::server::Server::builder();

        if let Some(limit) = config.concurrency_limit_per_connection {
            builder = builder.concurrency_limit_per_connection(limit);
        }

        if let Some(timeout) = config.request_timeout {
            builder = builder.timeout(timeout);
        }

        if let Some(tcp_nodelay) = config.tcp_nodelay {
            builder = builder.tcp_nodelay(tcp_nodelay);
        }

        let load_shed = config
            .load_shed
            .unwrap_or_default()
            .then_some(tower::load_shed::LoadShedLayer::new());

        let metrics = MetricsHandler::new(metrics_provider.clone());

        let request_metrics = TraceLayer::new_for_grpc()
            .on_request(metrics.clone())
            .on_response(metrics.clone())
            .on_failure(metrics);

        let global_concurrency_limit = config
            .global_concurrency_limit
            .map(tower::limit::GlobalConcurrencyLimitLayer::new);

        fn add_path_to_request_header(request: &Request<Body>) -> Option<HeaderValue> {
            let path = request.uri().path();
            Some(HeaderValue::from_str(path).unwrap())
        }

        let layer = ServiceBuilder::new()
            .option_layer(global_concurrency_limit)
            .option_layer(load_shed)
            .layer(RequestLifetimeLayer { metrics_provider })
            .layer(SetRequestHeaderLayer::overriding(
                GRPC_ENDPOINT_PATH_HEADER.clone(),
                add_path_to_request_header as AddPathToHeaderFunction,
            ))
            .layer(request_metrics)
            .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
            .into_inner();

        let (health_reporter, health_service) = tonic_health::server::health_reporter();
        let router = builder
            .initial_stream_window_size(config.http2_initial_stream_window_size)
            .initial_connection_window_size(config.http2_initial_connection_window_size)
            .http2_keepalive_interval(config.http2_keepalive_interval)
            .http2_keepalive_timeout(config.http2_keepalive_timeout)
            .max_concurrent_streams(config.http2_max_concurrent_streams)
            .tcp_keepalive(config.tcp_keepalive)
            .layer(layer)
            .add_service(health_service);

        Self {
            router,
            health_reporter,
        }
    }

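    /// Returns a handle that can be used to update the serving status exposed
    /// by the built-in `grpc.health.v1.Health` service.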
    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
        self.health_reporter.clone()
    }

    /// Add a new service to this Server.
    pub fn add_service<S>(mut self, svc: S) -> Self
    where
        S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
            + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
        S::Future: Send + 'static,
    {
        self.router = self.router.add_service(svc);
        self
    }

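    /// Binds the server to `addr` and returns a [`Server`] that is driven to
    /// completion with [`Server::serve`]. The multiaddr must start with
    /// `/dns`, `/ip4`, or `/ip6`, followed by `/tcp/<port>` and `/http` or
    /// `/https`; other protocols are rejected. If `tls_config` is provided,
    /// each incoming connection is sniffed and served over either TLS or plain
    /// TCP (see [`TcpOrTlsListener`]).
    ///
    /// A minimal usage sketch without TLS, mirroring the tests in this module:
    ///
    /// ```ignore
    /// let config = Config::new();
    /// let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse()?;
    /// let server = config.server_builder().bind(&address, None).await?;
    /// println!("listening on {}", server.local_addr());
    /// server.serve().await?;
    /// ```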
    pub async fn bind(self, addr: &Multiaddr, tls_config: Option<ServerConfig>) -> Result<Server> {
        let mut iter = addr.iter();

        let (tx_cancellation, rx_cancellation) = tokio::sync::oneshot::channel();
        let rx_cancellation = rx_cancellation.map(|_| ());
        let (local_addr, server): (Multiaddr, BoxFuture<(), tonic::transport::Error>) = match iter
            .next()
            .ok_or_else(|| eyre!("malformed addr"))?
        {
            Protocol::Dns(_) => {
                let (dns_name, tcp_port, _http_or_https) = parse_dns(addr)?;
                let (local_addr, incoming) =
                    listen_and_update_multiaddr(addr, (dns_name.to_string(), tcp_port), tls_config)
                        .await?;
                let server = Box::pin(
                    self.router
                        .serve_with_incoming_shutdown(incoming, rx_cancellation),
                );
                (local_addr, server)
            }
            Protocol::Ip4(_) => {
                let (socket_addr, _http_or_https) = parse_ip4(addr)?;
                let (local_addr, incoming) =
                    listen_and_update_multiaddr(addr, socket_addr, tls_config).await?;
                let server = Box::pin(
                    self.router
                        .serve_with_incoming_shutdown(incoming, rx_cancellation),
                );
                (local_addr, server)
            }
            Protocol::Ip6(_) => {
                let (socket_addr, _http_or_https) = parse_ip6(addr)?;
                let (local_addr, incoming) =
                    listen_and_update_multiaddr(addr, socket_addr, tls_config).await?;
                let server = Box::pin(
                    self.router
                        .serve_with_incoming_shutdown(incoming, rx_cancellation),
                );
                (local_addr, server)
            }
            unsupported => return Err(eyre!("unsupported protocol {unsupported}")),
        };

        Ok(Server {
            server,
            cancel_handle: Some(tx_cancellation),
            local_addr,
            health_reporter: self.health_reporter,
        })
    }
}

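/// Binds a [`TcpListener`] on `socket_addr`, rewrites the `/tcp` component of
/// `address` with the port that was actually bound (relevant when binding to
/// port 0), and returns the updated multiaddr together with a stream of
/// accepted connections, each upgraded to TLS when `tls_config` is provided.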
async fn listen_and_update_multiaddr<T: ToSocketAddrs>(
    address: &Multiaddr,
    socket_addr: T,
    tls_config: Option<ServerConfig>,
) -> Result<(
    Multiaddr,
    impl Stream<Item = std::io::Result<TcpOrTlsStream>>,
)> {
    let listener = TcpListener::bind(socket_addr).await?;
    let local_addr = listener.local_addr()?;
    let local_addr = update_tcp_port_in_multiaddr(address, local_addr.port());

    let tls_acceptor = tls_config.map(|tls_config| TlsAcceptor::from(Arc::new(tls_config)));
    let incoming = TcpOrTlsListener::new(listener, tls_acceptor);
    let stream = async_stream::stream! {
        let mut new_connections = FuturesUnordered::new();
        loop {
            tokio::select! {
                result = incoming.accept_raw() => {
                    match result {
                        Ok((stream, addr)) => {
                            new_connections.push(incoming.maybe_upgrade(stream, addr));
                        }
                        Err(e) => yield Err(e),
                    }
                }
                Some(result) = new_connections.next() => {
                    yield result;
                }
            }
        }
    };

    Ok((local_addr, stream))
}

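/// A TCP listener that can serve plain-TCP and TLS clients on the same port.
/// When a [`TlsAcceptor`] is configured, the first byte of each new connection
/// is peeked to decide whether to perform a TLS handshake.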
pub struct TcpOrTlsListener {
    listener: TcpListener,
    tls_acceptor: Option<TlsAcceptor>,
}

impl TcpOrTlsListener {
    fn new(listener: TcpListener, tls_acceptor: Option<TlsAcceptor>) -> Self {
        Self {
            listener,
            tls_acceptor,
        }
    }

    async fn accept_raw(&self) -> std::io::Result<(TcpStream, SocketAddr)> {
        self.listener.accept().await
    }

    async fn maybe_upgrade(
        &self,
        stream: TcpStream,
        addr: SocketAddr,
    ) -> std::io::Result<TcpOrTlsStream> {
        if self.tls_acceptor.is_none() {
            return Ok(TcpOrTlsStream::Tcp(stream, addr));
        }

        // Determine whether the new connection is TLS.
        let mut buf = [0; 1];
        // `peek` waits until data is available (or the connection is closed), so on
        // success `buf` holds the first byte the client sent, or stays zero if the
        // peer sent nothing.
        stream.peek(&mut buf).await?;
        if buf[0] == 0x16 {
            // The first byte of a TLS handshake record is 0x16 (the handshake content type).
            debug!("accepting TLS connection from {addr:?}");
            let stream = self.tls_acceptor.as_ref().unwrap().accept(stream).await?;
            Ok(TcpOrTlsStream::Tls(Box::new(stream), addr))
        } else {
            debug!("accepting TCP connection from {addr:?}");
            Ok(TcpOrTlsStream::Tcp(stream, addr))
        }
    }
}

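/// An accepted connection, either plain TCP or TLS, paired with the peer's
/// socket address.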
pub enum TcpOrTlsStream {
    Tcp(TcpStream, SocketAddr),
    Tls(Box<TlsStream<TcpStream>>, SocketAddr),
}

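// AsyncRead/AsyncWrite are forwarded to whichever inner stream variant is
// active, so hyper/tonic can serve both transports through the same type.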
impl AsyncRead for TcpOrTlsStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf,
    ) -> Poll<std::io::Result<()>> {
        match self.get_mut() {
            TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_read(cx, buf),
            TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_read(cx, buf),
        }
    }
}

impl AsyncWrite for TcpOrTlsStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::result::Result<usize, std::io::Error>> {
        match self.get_mut() {
            TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_write(cx, buf),
            TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_write(cx, buf),
        }
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        match self.get_mut() {
            TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_flush(cx),
            TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        match self.get_mut() {
            TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_shutdown(cx),
            TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_shutdown(cx),
        }
    }
}

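// Expose the local and remote socket addresses to tonic so handlers can
// retrieve them via the `TcpConnectInfo` request extension.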
impl tonic::transport::server::Connected for TcpOrTlsStream {
    type ConnectInfo = tonic::transport::server::TcpConnectInfo;

    fn connect_info(&self) -> Self::ConnectInfo {
        match self {
            TcpOrTlsStream::Tcp(stream, addr) => Self::ConnectInfo {
                local_addr: stream.local_addr().ok(),
                remote_addr: Some(*addr),
            },
            TcpOrTlsStream::Tls(stream, addr) => Self::ConnectInfo {
                local_addr: stream.get_ref().0.local_addr().ok(),
                remote_addr: Some(*addr),
            },
        }
    }
}

/// TLS server name to use for the public IOTA validator interface.
pub const IOTA_TLS_SERVER_NAME: &str = "iota";

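/// A server bound to a local address: holds the future driving the tonic
/// server, a one-shot cancellation handle for shutting the server down, the
/// resolved local multiaddr, and the health reporter.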
pub struct Server {
    server: BoxFuture<(), tonic::transport::Error>,
    cancel_handle: Option<tokio::sync::oneshot::Sender<()>>,
    local_addr: Multiaddr,
    health_reporter: tonic_health::server::HealthReporter,
}

impl Server {
    pub async fn serve(self) -> Result<(), tonic::transport::Error> {
        self.server.await
    }

    pub fn local_addr(&self) -> &Multiaddr {
        &self.local_addr
    }

    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
        self.health_reporter.clone()
    }

    pub fn take_cancel_handle(&mut self) -> Option<tokio::sync::oneshot::Sender<()>> {
        self.cancel_handle.take()
    }
}

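/// Replaces the TCP port at index 1 of `addr` with `port` (the port the
/// listener actually bound to); panics if the component at that index is not
/// `/tcp`.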
fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
    addr.replace(1, |protocol| {
        if let Protocol::Tcp(_) = protocol {
            Some(Protocol::Tcp(port))
        } else {
            panic!("expected tcp protocol at index 1");
        }
    })
    .expect("tcp protocol at index 1")
}

#[cfg(test)]
mod test {
    use std::{
        ops::Deref,
        sync::{Arc, Mutex},
        time::Duration,
    };

    use tonic::Code;
    use tonic_health::pb::{HealthCheckRequest, health_client::HealthClient};

    use crate::{Multiaddr, config::Config, metrics::MetricsCallbackProvider};

    #[test]
    fn document_multiaddr_limitation_for_unix_protocol() {
        // You can construct a multiaddr by hand (i.e., in the binary format) just fine
        let path = "/tmp/foo";
        let addr = Multiaddr::new_internal(multiaddr::multiaddr!(Unix(path), Http));

        // But it doesn't round-trip in the human-readable format
        let s = addr.to_string();
        assert!(s.parse::<Multiaddr>().is_err());
    }

    #[tokio::test]
    async fn test_metrics_layer_successful() {
        #[derive(Clone)]
        struct Metrics {
            /// A flag recording whether the on_response callback has been invoked.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::Ok);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true;
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address, None)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address, None).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    #[tokio::test]
    async fn test_metrics_layer_error() {
        #[derive(Clone)]
        struct Metrics {
            /// A flag recording whether the on_response callback has been invoked.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                // According to https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc,
                // code 5 is NOT_FOUND, which is what we expect to get in this case.
                assert_eq!(grpc_status_code, Code::NotFound);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true;
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address, None)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address, None).await.unwrap();
        let mut client = HealthClient::new(channel);

        // Call the health check for a service that doesn't exist; that should give us
        // back an error with code 5 (NOT_FOUND), see
        // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
        let _ = client
            .check(HealthCheckRequest {
                service: "non-existing-service".to_owned(),
            })
            .await;

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let mut server = config.server_builder().bind(&address, None).await.unwrap();
        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address, None).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn dns() {
        let address: Multiaddr = "/dns/localhost/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip4() {
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip6() {
        let address: Multiaddr = "/ip6/::1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }
}

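/// Tower layer that reports a request's lifetime to the
/// [`MetricsCallbackProvider`]: `on_start` is called when the wrapped service
/// handles its first request and `on_drop` when that service is dropped,
/// allowing in-flight requests to be tracked.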
#[derive(Clone)]
struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
    metrics_provider: M,
}

impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
    type Service = RequestLifetime<M, S>;

    fn layer(&self, inner: S) -> Self::Service {
        RequestLifetime {
            inner,
            metrics_provider: self.metrics_provider.clone(),
            path: None,
        }
    }
}

#[derive(Clone)]
struct RequestLifetime<M: MetricsCallbackProvider, S> {
    inner: S,
    metrics_provider: M,
    path: Option<String>,
}

impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
    for RequestLifetime<M, S>
where
    S: Service<Request<RequestBody>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
        if self.path.is_none() {
            let path = request.uri().path().to_string();
            self.metrics_provider.on_start(&path);
            self.path = Some(path);
        }
        self.inner.call(request)
    }
}

impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
    fn drop(&mut self) {
        if let Some(path) = &self.path {
            self.metrics_provider.on_drop(path)
        }
    }
}