1use std::{
6 convert::Infallible,
7 net::SocketAddr,
8 pin::Pin,
9 sync::Arc,
10 task::{Context, Poll},
11};
12
13use eyre::{Result, eyre};
14use futures::{FutureExt, Stream, StreamExt, stream::FuturesUnordered};
15use tokio::{
16 io::{AsyncRead, AsyncWrite},
17 net::{TcpListener, TcpStream, ToSocketAddrs},
18};
19use tokio_rustls::{TlsAcceptor, rustls::ServerConfig, server::TlsStream};
20use tonic::{
21 body::Body,
22 codegen::{
23 BoxFuture,
24 http::{HeaderValue, Request, Response},
25 },
26 server::NamedService,
27 transport::server::Router,
28};
29use tower::{
30 Layer, Service, ServiceBuilder,
31 layer::util::{Identity, Stack},
32 limit::GlobalConcurrencyLimitLayer,
33 load_shed::LoadShedLayer,
34 util::Either,
35};
36use tower_http::{
37 classify::{GrpcErrorsAsFailures, SharedClassifier},
38 propagate_header::PropagateHeaderLayer,
39 set_header::SetRequestHeaderLayer,
40 trace::{DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, TraceLayer},
41};
42use tracing::debug;
43
44use crate::{
45 config::Config,
46 metrics::{
47 DefaultMetricsCallbackProvider, GRPC_ENDPOINT_PATH_HEADER, MetricsCallbackProvider,
48 MetricsHandler,
49 },
50 multiaddr::{Multiaddr, Protocol, parse_dns, parse_ip4, parse_ip6},
51};
52
/// Builder for a gRPC [`Server`], parameterized over the metrics callback
/// provider used by the tower middleware stack.
pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
    // tonic router carrying the full middleware stack (see `WrapperService`).
    router: Router<WrapperService<M>>,
    // Handle for flipping the serving status of the built-in health service.
    health_reporter: tonic_health::server::HealthReporter,
}
57
/// Function that extracts a request's URI path as a header value, used by
/// `SetRequestHeaderLayer` to tag requests for metrics attribution.
type AddPathToHeaderFunction = fn(&Request<Body>) -> Option<HeaderValue>;
59
/// The concrete tower layer stack produced by `ServerBuilder::from_config`.
///
/// NOTE: this alias must mirror the `ServiceBuilder` layer order in
/// `from_config` exactly (stacks nest innermost-first), otherwise the
/// `Router<WrapperService<M>>` field of `ServerBuilder` will not type-check.
type WrapperService<M> = Stack<
    Stack<
        PropagateHeaderLayer,
        Stack<
            TraceLayer<
                SharedClassifier<GrpcErrorsAsFailures>,
                DefaultMakeSpan,
                MetricsHandler<M>,
                MetricsHandler<M>,
                DefaultOnBodyChunk,
                DefaultOnEos,
                MetricsHandler<M>,
            >,
            Stack<
                SetRequestHeaderLayer<AddPathToHeaderFunction>,
                Stack<
                    RequestLifetimeLayer<M>,
                    Stack<
                        // `Either<…, Identity>` corresponds to the two
                        // `option_layer` calls (load shed / concurrency limit).
                        Either<LoadShedLayer, Identity>,
                        Stack<Either<GlobalConcurrencyLimitLayer, Identity>, Identity>,
                    >,
                >,
            >,
        >,
    >,
    Identity,
>;
87
88impl<M: MetricsCallbackProvider> ServerBuilder<M> {
89 pub fn from_config(config: &Config, metrics_provider: M) -> Self {
90 let mut builder = tonic::transport::server::Server::builder();
91
92 if let Some(limit) = config.concurrency_limit_per_connection {
93 builder = builder.concurrency_limit_per_connection(limit);
94 }
95
96 if let Some(timeout) = config.request_timeout {
97 builder = builder.timeout(timeout);
98 }
99
100 if let Some(tcp_nodelay) = config.tcp_nodelay {
101 builder = builder.tcp_nodelay(tcp_nodelay);
102 }
103
104 let load_shed = config
105 .load_shed
106 .unwrap_or_default()
107 .then_some(tower::load_shed::LoadShedLayer::new());
108
109 let metrics = MetricsHandler::new(metrics_provider.clone());
110
111 let request_metrics = TraceLayer::new_for_grpc()
112 .on_request(metrics.clone())
113 .on_response(metrics.clone())
114 .on_failure(metrics);
115
116 let global_concurrency_limit = config
117 .global_concurrency_limit
118 .map(tower::limit::GlobalConcurrencyLimitLayer::new);
119
120 fn add_path_to_request_header(request: &Request<Body>) -> Option<HeaderValue> {
121 let path = request.uri().path();
122 Some(HeaderValue::from_str(path).unwrap())
123 }
124
125 let layer = ServiceBuilder::new()
126 .option_layer(global_concurrency_limit)
127 .option_layer(load_shed)
128 .layer(RequestLifetimeLayer { metrics_provider })
129 .layer(SetRequestHeaderLayer::overriding(
130 GRPC_ENDPOINT_PATH_HEADER.clone(),
131 add_path_to_request_header as AddPathToHeaderFunction,
132 ))
133 .layer(request_metrics)
134 .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
135 .into_inner();
136
137 let (health_reporter, health_service) = tonic_health::server::health_reporter();
138 let router = builder
139 .initial_stream_window_size(config.http2_initial_stream_window_size)
140 .initial_connection_window_size(config.http2_initial_connection_window_size)
141 .http2_keepalive_interval(config.http2_keepalive_interval)
142 .http2_keepalive_timeout(config.http2_keepalive_timeout)
143 .max_concurrent_streams(config.http2_max_concurrent_streams)
144 .tcp_keepalive(config.tcp_keepalive)
145 .layer(layer)
146 .add_service(health_service);
147
148 Self {
149 router,
150 health_reporter,
151 }
152 }
153
154 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
155 self.health_reporter.clone()
156 }
157
158 pub fn add_service<S>(mut self, svc: S) -> Self
160 where
161 S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
162 + NamedService
163 + Clone
164 + Send
165 + Sync
166 + 'static,
167 S::Future: Send + 'static,
168 {
169 self.router = self.router.add_service(svc);
170 self
171 }
172
173 pub async fn bind(self, addr: &Multiaddr, tls_config: Option<ServerConfig>) -> Result<Server> {
174 let mut iter = addr.iter();
175
176 let (tx_cancellation, rx_cancellation) = tokio::sync::oneshot::channel();
177 let rx_cancellation = rx_cancellation.map(|_| ());
178 let (local_addr, server): (Multiaddr, BoxFuture<(), tonic::transport::Error>) = match iter
179 .next()
180 .ok_or_else(|| eyre!("malformed addr"))?
181 {
182 Protocol::Dns(_) => {
183 let (dns_name, tcp_port, _http_or_https) = parse_dns(addr)?;
184 let (local_addr, incoming) =
185 listen_and_update_multiaddr(addr, (dns_name.to_string(), tcp_port), tls_config)
186 .await?;
187 let server = Box::pin(
188 self.router
189 .serve_with_incoming_shutdown(incoming, rx_cancellation),
190 );
191 (local_addr, server)
192 }
193 Protocol::Ip4(_) => {
194 let (socket_addr, _http_or_https) = parse_ip4(addr)?;
195 let (local_addr, incoming) =
196 listen_and_update_multiaddr(addr, socket_addr, tls_config).await?;
197 let server = Box::pin(
198 self.router
199 .serve_with_incoming_shutdown(incoming, rx_cancellation),
200 );
201 (local_addr, server)
202 }
203 Protocol::Ip6(_) => {
204 let (socket_addr, _http_or_https) = parse_ip6(addr)?;
205 let (local_addr, incoming) =
206 listen_and_update_multiaddr(addr, socket_addr, tls_config).await?;
207 let server = Box::pin(
208 self.router
209 .serve_with_incoming_shutdown(incoming, rx_cancellation),
210 );
211 (local_addr, server)
212 }
213 unsupported => return Err(eyre!("unsupported protocol {unsupported}")),
214 };
215
216 Ok(Server {
217 server,
218 cancel_handle: Some(tx_cancellation),
219 local_addr,
220 health_reporter: self.health_reporter,
221 })
222 }
223}
224
/// Binds a TCP listener on `socket_addr` and returns the multiaddr with its
/// TCP port rewritten to the bound port, together with a stream of accepted
/// (and, when `tls_config` is set, possibly TLS-upgraded) connections.
///
/// The accept loop and the TLS handshakes run concurrently: raw accepts are
/// pushed into a `FuturesUnordered` of `maybe_upgrade` futures so one slow
/// handshake cannot stall new connections.
async fn listen_and_update_multiaddr<T: ToSocketAddrs>(
    address: &Multiaddr,
    socket_addr: T,
    tls_config: Option<ServerConfig>,
) -> Result<(
    Multiaddr,
    impl Stream<Item = std::io::Result<TcpOrTlsStream>>,
)> {
    let listener = TcpListener::bind(socket_addr).await?;
    let local_addr = listener.local_addr()?;
    // Rewrite the port in case the caller bound port 0 (OS-assigned).
    let local_addr = update_tcp_port_in_multiaddr(address, local_addr.port());

    let tls_acceptor = tls_config.map(|tls_config| TlsAcceptor::from(Arc::new(tls_config)));
    let incoming = TcpOrTlsListener::new(listener, tls_acceptor);
    let stream = async_stream::stream! {
        let mut new_connections = FuturesUnordered::new();
        loop {
            tokio::select! {
                // Accept a raw connection and queue its (possible) TLS upgrade.
                result = incoming.accept_raw() => {
                    match result {
                        Ok((stream, addr)) => {
                            new_connections.push(incoming.maybe_upgrade(stream, addr));
                        }
                        // Accept errors are surfaced to the consumer; the loop
                        // keeps running to serve subsequent connections.
                        Err(e) => yield Err(e),
                    }
                }
                // Yield each connection once its upgrade future completes.
                Some(result) = new_connections.next() => {
                    yield result;
                }
            }
        }
    };

    Ok((local_addr, stream))
}
260
/// TCP listener that can optionally upgrade accepted connections to TLS,
/// serving plaintext and TLS clients on the same port.
pub struct TcpOrTlsListener {
    listener: TcpListener,
    // `None` means all connections stay plain TCP.
    tls_acceptor: Option<TlsAcceptor>,
}
265
266impl TcpOrTlsListener {
267 fn new(listener: TcpListener, tls_acceptor: Option<TlsAcceptor>) -> Self {
268 Self {
269 listener,
270 tls_acceptor,
271 }
272 }
273
274 async fn accept_raw(&self) -> std::io::Result<(TcpStream, SocketAddr)> {
275 self.listener.accept().await
276 }
277
278 async fn maybe_upgrade(
279 &self,
280 stream: TcpStream,
281 addr: SocketAddr,
282 ) -> std::io::Result<TcpOrTlsStream> {
283 if self.tls_acceptor.is_none() {
284 return Ok(TcpOrTlsStream::Tcp(stream, addr));
285 }
286
287 let mut buf = [0; 1];
289 stream.peek(&mut buf).await?;
292 if buf[0] == 0x16 {
293 debug!("accepting TLS connection from {addr:?}");
295 let stream = self.tls_acceptor.as_ref().unwrap().accept(stream).await?;
296 Ok(TcpOrTlsStream::Tls(Box::new(stream), addr))
297 } else {
298 debug!("accepting TCP connection from {addr:?}");
299 Ok(TcpOrTlsStream::Tcp(stream, addr))
300 }
301 }
302}
303
/// An accepted connection: either plain TCP or a server-side TLS stream,
/// each paired with the peer's socket address.
pub enum TcpOrTlsStream {
    Tcp(TcpStream, SocketAddr),
    // Boxed to keep the enum small; `TlsStream` is the larger variant.
    Tls(Box<TlsStream<TcpStream>>, SocketAddr),
}
308
309impl AsyncRead for TcpOrTlsStream {
310 fn poll_read(
311 self: Pin<&mut Self>,
312 cx: &mut Context<'_>,
313 buf: &mut tokio::io::ReadBuf,
314 ) -> Poll<std::io::Result<()>> {
315 match self.get_mut() {
316 TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_read(cx, buf),
317 TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_read(cx, buf),
318 }
319 }
320}
321
322impl AsyncWrite for TcpOrTlsStream {
323 fn poll_write(
324 self: Pin<&mut Self>,
325 cx: &mut Context<'_>,
326 buf: &[u8],
327 ) -> Poll<std::result::Result<usize, std::io::Error>> {
328 match self.get_mut() {
329 TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_write(cx, buf),
330 TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_write(cx, buf),
331 }
332 }
333
334 fn poll_flush(
335 self: Pin<&mut Self>,
336 cx: &mut Context<'_>,
337 ) -> Poll<std::result::Result<(), std::io::Error>> {
338 match self.get_mut() {
339 TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_flush(cx),
340 TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_flush(cx),
341 }
342 }
343
344 fn poll_shutdown(
345 self: Pin<&mut Self>,
346 cx: &mut Context<'_>,
347 ) -> Poll<std::result::Result<(), std::io::Error>> {
348 match self.get_mut() {
349 TcpOrTlsStream::Tcp(stream, _) => Pin::new(stream).poll_shutdown(cx),
350 TcpOrTlsStream::Tls(stream, _) => Pin::new(stream).poll_shutdown(cx),
351 }
352 }
353}
354
355impl tonic::transport::server::Connected for TcpOrTlsStream {
356 type ConnectInfo = tonic::transport::server::TcpConnectInfo;
357
358 fn connect_info(&self) -> Self::ConnectInfo {
359 match self {
360 TcpOrTlsStream::Tcp(stream, addr) => Self::ConnectInfo {
361 local_addr: stream.local_addr().ok(),
362 remote_addr: Some(*addr),
363 },
364 TcpOrTlsStream::Tls(stream, addr) => Self::ConnectInfo {
365 local_addr: stream.get_ref().0.local_addr().ok(),
366 remote_addr: Some(*addr),
367 },
368 }
369 }
370}
371
/// SNI server name used for IOTA-internal TLS connections.
pub const IOTA_TLS_SERVER_NAME: &str = "iota";
374
/// A bound gRPC server: a boxed serve future plus its shutdown handle,
/// resolved local address, and health reporter.
pub struct Server {
    // Future that drives the server; completes on error or after shutdown.
    server: BoxFuture<(), tonic::transport::Error>,
    // Sending (or dropping) this triggers graceful shutdown; `take`-able once.
    cancel_handle: Option<tokio::sync::oneshot::Sender<()>>,
    // Multiaddr with the actually-bound TCP port.
    local_addr: Multiaddr,
    health_reporter: tonic_health::server::HealthReporter,
}
381
impl Server {
    /// Drives the server to completion (until shutdown is signalled or a
    /// transport error occurs).
    pub async fn serve(self) -> Result<(), tonic::transport::Error> {
        self.server.await
    }

    /// The multiaddr the server is actually bound to.
    pub fn local_addr(&self) -> &Multiaddr {
        &self.local_addr
    }

    /// Handle for updating the built-in health service's status.
    pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
        self.health_reporter.clone()
    }

    /// Takes the one-shot shutdown sender; returns `None` if already taken.
    pub fn take_cancel_handle(&mut self) -> Option<tokio::sync::oneshot::Sender<()>> {
        self.cancel_handle.take()
    }
}
399
400fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
401 addr.replace(1, |protocol| {
402 if let Protocol::Tcp(_) = protocol {
403 Some(Protocol::Tcp(port))
404 } else {
405 panic!("expected tcp protocol at index 1");
406 }
407 })
408 .expect("tcp protocol at index 1")
409}
410
#[cfg(test)]
mod test {
    use std::{
        ops::Deref,
        sync::{Arc, Mutex},
        time::Duration,
    };

    use tonic::Code;
    use tonic_health::pb::{HealthCheckRequest, health_client::HealthClient};

    use crate::{Multiaddr, config::Config, metrics::MetricsCallbackProvider};

    // Documents a known limitation: a `/unix/<path>/http` multiaddr can be
    // built programmatically but its string form does not round-trip.
    #[test]
    fn document_multiaddr_limitation_for_unix_protocol() {
        let path = "/tmp/foo";
        let addr = Multiaddr::new_internal(multiaddr::multiaddr!(Unix(path), Http));

        let s = addr.to_string();
        assert!(s.parse::<Multiaddr>().is_err());
    }

    // Verifies the metrics layer fires on a successful health check with the
    // expected path, HTTP status, and gRPC code.
    #[tokio::test]
    async fn test_metrics_layer_successful() {
        #[derive(Clone)]
        struct Metrics {
            // Flipped to true once `on_response` has run with the expected values.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::Ok);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        // Port 0: the OS assigns a free port; we read it back via local_addr().
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address, None)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address, None).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        // Graceful shutdown, then confirm the metrics callback observed the call.
        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    // Verifies the metrics layer also fires for a failed health check
    // (unknown service -> gRPC NotFound, HTTP 200 as usual for gRPC).
    #[tokio::test]
    async fn test_metrics_layer_error() {
        #[derive(Clone)]
        struct Metrics {
            // Flipped to true once `on_response` has run with the expected values.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::NotFound);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address, None)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address, None).await.unwrap();
        let mut client = HealthClient::new(channel);

        // Error response expected; we only care that the callback fires.
        let _ = client
            .check(HealthCheckRequest {
                service: "non-existing-service".to_owned(),
            })
            .await;

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    // Shared driver: bind on `address`, run a health check round-trip, shut down.
    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let mut server = config.server_builder().bind(&address, None).await.unwrap();
        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address, None).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn dns() {
        let address: Multiaddr = "/dns/localhost/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip4() {
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip6() {
        let address: Multiaddr = "/ip6/::1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }
}
599
/// Tower layer that wraps a service in [`RequestLifetime`], notifying the
/// metrics provider when a request starts and when its service is dropped.
#[derive(Clone)]
struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
    metrics_provider: M,
}
604
605impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
606 type Service = RequestLifetime<M, S>;
607
608 fn layer(&self, inner: S) -> Self::Service {
609 RequestLifetime {
610 inner,
611 metrics_provider: self.metrics_provider.clone(),
612 path: None,
613 }
614 }
615}
616
/// Service wrapper that reports request lifetime to the metrics provider:
/// `on_start` at the first call, `on_drop` when the wrapper is dropped.
#[derive(Clone)]
struct RequestLifetime<M: MetricsCallbackProvider, S> {
    inner: S,
    metrics_provider: M,
    // Path of the first request seen; `None` until `call` runs once.
    path: Option<String>,
}
623
624impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
625 for RequestLifetime<M, S>
626where
627 S: Service<Request<RequestBody>>,
628{
629 type Response = S::Response;
630 type Error = S::Error;
631 type Future = S::Future;
632
633 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
634 self.inner.poll_ready(cx)
635 }
636
637 fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
638 if self.path.is_none() {
639 let path = request.uri().path().to_string();
640 self.metrics_provider.on_start(&path);
641 self.path = Some(path);
642 }
643 self.inner.call(request)
644 }
645}
646
647impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
648 fn drop(&mut self) {
649 if let Some(path) = &self.path {
650 self.metrics_provider.on_drop(path)
651 }
652 }
653}