1use std::{
6 convert::Infallible,
7 net::SocketAddr,
8 task::{Context, Poll},
9};
10
11use eyre::{Result, eyre};
12use futures::FutureExt;
13use tokio::net::{TcpListener, ToSocketAddrs};
14use tokio_stream::wrappers::TcpListenerStream;
15use tonic::{
16 body::Body,
17 codegen::{
18 BoxFuture,
19 http::{HeaderValue, Request, Response},
20 },
21 server::NamedService,
22 transport::server::Router,
23};
24use tower::{
25 Layer, Service, ServiceBuilder,
26 layer::util::{Identity, Stack},
27 limit::GlobalConcurrencyLimitLayer,
28 load_shed::LoadShedLayer,
29 util::Either,
30};
31use tower_http::{
32 classify::{GrpcErrorsAsFailures, SharedClassifier},
33 propagate_header::PropagateHeaderLayer,
34 set_header::SetRequestHeaderLayer,
35 trace::{DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, TraceLayer},
36};
37
38use crate::{
39 config::Config,
40 metrics::{
41 DefaultMetricsCallbackProvider, GRPC_ENDPOINT_PATH_HEADER, MetricsCallbackProvider,
42 MetricsHandler,
43 },
44 multiaddr::{Multiaddr, Protocol, parse_dns, parse_ip4, parse_ip6},
45};
46
/// Builder for a gRPC [`Server`]: a tonic router with the crate's middleware stack applied
/// and a gRPC health service registered.
pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
    /// Router with the `WrapperService` middleware stack applied; services are added to it.
    router: Router<WrapperService<M>>,
    /// Reporter for the health service registered when the builder was created.
    health_reporter: tonic_health::server::HealthReporter,
}
51
/// Function-pointer type that computes the value of the endpoint-path request header.
/// It is a named `fn` pointer (not a closure) so it can be spelled out in `WrapperService`.
type AddPathToHeaderFunction = fn(&Request<Body>) -> Option<HeaderValue>;
53
/// Concrete type of the middleware stack applied to the router in `ServerBuilder::from_config`.
///
/// This must mirror that `ServiceBuilder` chain exactly. The innermost `Stack` holds the first
/// layer added (the optional global concurrency limit), and the outermost inner `Stack` holds
/// the last one (`PropagateHeaderLayer`). Optional layers appear as `Either<Layer, Identity>`
/// because they are added with `option_layer`. The outer `Stack<_, Identity>` corresponds to
/// calling `.layer(layer)` on a fresh `Server::builder()`.
type WrapperService<M> = Stack<
    Stack<
        PropagateHeaderLayer,
        Stack<
            TraceLayer<
                SharedClassifier<GrpcErrorsAsFailures>,
                DefaultMakeSpan,
                MetricsHandler<M>,
                MetricsHandler<M>,
                DefaultOnBodyChunk,
                DefaultOnEos,
                MetricsHandler<M>,
            >,
            Stack<
                SetRequestHeaderLayer<AddPathToHeaderFunction>,
                Stack<
                    RequestLifetimeLayer<M>,
                    Stack<
                        Either<LoadShedLayer, Identity>,
                        Stack<Either<GlobalConcurrencyLimitLayer, Identity>, Identity>,
                    >,
                >,
            >,
        >,
    >,
    Identity,
>;
81
82impl<M: MetricsCallbackProvider> ServerBuilder<M> {
83 pub fn from_config(config: &Config, metrics_provider: M) -> Self {
84 let mut builder = tonic::transport::server::Server::builder();
85
86 if let Some(limit) = config.concurrency_limit_per_connection {
87 builder = builder.concurrency_limit_per_connection(limit);
88 }
89
90 if let Some(timeout) = config.request_timeout {
91 builder = builder.timeout(timeout);
92 }
93
94 if let Some(tcp_nodelay) = config.tcp_nodelay {
95 builder = builder.tcp_nodelay(tcp_nodelay);
96 }
97
98 let load_shed = config
99 .load_shed
100 .unwrap_or_default()
101 .then_some(tower::load_shed::LoadShedLayer::new());
102
103 let metrics = MetricsHandler::new(metrics_provider.clone());
104
105 let request_metrics = TraceLayer::new_for_grpc()
106 .on_request(metrics.clone())
107 .on_response(metrics.clone())
108 .on_failure(metrics);
109
110 let global_concurrency_limit = config
111 .global_concurrency_limit
112 .map(tower::limit::GlobalConcurrencyLimitLayer::new);
113
114 fn add_path_to_request_header(request: &Request<Body>) -> Option<HeaderValue> {
115 let path = request.uri().path();
116 Some(HeaderValue::from_str(path).unwrap())
117 }
118
119 let layer = ServiceBuilder::new()
120 .option_layer(global_concurrency_limit)
121 .option_layer(load_shed)
122 .layer(RequestLifetimeLayer { metrics_provider })
123 .layer(SetRequestHeaderLayer::overriding(
124 GRPC_ENDPOINT_PATH_HEADER.clone(),
125 add_path_to_request_header as AddPathToHeaderFunction,
126 ))
127 .layer(request_metrics)
128 .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
129 .into_inner();
130
131 let (health_reporter, health_service) = tonic_health::server::health_reporter();
132 let router = builder
133 .initial_stream_window_size(config.http2_initial_stream_window_size)
134 .initial_connection_window_size(config.http2_initial_connection_window_size)
135 .http2_keepalive_interval(config.http2_keepalive_interval)
136 .http2_keepalive_timeout(config.http2_keepalive_timeout)
137 .max_concurrent_streams(config.http2_max_concurrent_streams)
138 .tcp_keepalive(config.tcp_keepalive)
139 .layer(layer)
140 .add_service(health_service);
141
142 Self {
143 router,
144 health_reporter,
145 }
146 }
147
148 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
149 self.health_reporter.clone()
150 }
151
152 pub fn add_service<S>(mut self, svc: S) -> Self
154 where
155 S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
156 + NamedService
157 + Clone
158 + Send
159 + Sync
160 + 'static,
161 S::Future: Send + 'static,
162 {
163 self.router = self.router.add_service(svc);
164 self
165 }
166
167 pub async fn bind(self, addr: &Multiaddr) -> Result<Server> {
168 let mut iter = addr.iter();
169
170 let (tx_cancellation, rx_cancellation) = tokio::sync::oneshot::channel();
171 let rx_cancellation = rx_cancellation.map(|_| ());
172 let (local_addr, server): (Multiaddr, BoxFuture<(), tonic::transport::Error>) =
173 match iter.next().ok_or_else(|| eyre!("malformed addr"))? {
174 Protocol::Dns(_) => {
175 let (dns_name, tcp_port, _http_or_https) = parse_dns(addr)?;
176 let (local_addr, incoming) =
177 tcp_listener_and_update_multiaddr(addr, (dns_name.as_ref(), tcp_port))
178 .await?;
179 let server = Box::pin(
180 self.router
181 .serve_with_incoming_shutdown(incoming, rx_cancellation),
182 );
183 (local_addr, server)
184 }
185 Protocol::Ip4(_) => {
186 let (socket_addr, _http_or_https) = parse_ip4(addr)?;
187 let (local_addr, incoming) =
188 tcp_listener_and_update_multiaddr(addr, socket_addr).await?;
189 let server = Box::pin(
190 self.router
191 .serve_with_incoming_shutdown(incoming, rx_cancellation),
192 );
193 (local_addr, server)
194 }
195 Protocol::Ip6(_) => {
196 let (socket_addr, _http_or_https) = parse_ip6(addr)?;
197 let (local_addr, incoming) =
198 tcp_listener_and_update_multiaddr(addr, socket_addr).await?;
199 let server = Box::pin(
200 self.router
201 .serve_with_incoming_shutdown(incoming, rx_cancellation),
202 );
203 (local_addr, server)
204 }
205 unsupported => return Err(eyre!("unsupported protocol {unsupported}")),
206 };
207
208 Ok(Server {
209 server,
210 cancel_handle: Some(tx_cancellation),
211 local_addr,
212 health_reporter: self.health_reporter,
213 })
214 }
215}
216
217async fn tcp_listener_and_update_multiaddr<T: ToSocketAddrs>(
218 address: &Multiaddr,
219 socket_addr: T,
220) -> Result<(Multiaddr, TcpListenerStream)> {
221 let (local_addr, incoming) = tcp_listener(socket_addr).await?;
222 let local_addr = update_tcp_port_in_multiaddr(address, local_addr.port());
223 Ok((local_addr, incoming))
224}
225
226async fn tcp_listener<T: ToSocketAddrs>(address: T) -> Result<(SocketAddr, TcpListenerStream)> {
227 let listener = TcpListener::bind(address).await?;
228 let local_addr = listener.local_addr()?;
229 let incoming = TcpListenerStream::new(listener);
230 Ok((local_addr, incoming))
231}
232
/// A bound listener together with the future that serves gRPC on it, as returned by
/// `ServerBuilder::bind`.
pub struct Server {
    /// Future that runs the server until the shutdown signal resolves or serving fails.
    server: BoxFuture<(), tonic::transport::Error>,
    /// Sender half of the shutdown signal; `None` once taken via `take_cancel_handle`.
    cancel_handle: Option<tokio::sync::oneshot::Sender<()>>,
    /// Listening address with the TCP port set to the port actually bound.
    local_addr: Multiaddr,
    /// Reporter for the health service registered on this server's router.
    health_reporter: tonic_health::server::HealthReporter,
}
239
240impl Server {
241 pub async fn serve(self) -> Result<(), tonic::transport::Error> {
242 self.server.await
243 }
244
245 pub fn local_addr(&self) -> &Multiaddr {
246 &self.local_addr
247 }
248
249 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
250 self.health_reporter.clone()
251 }
252
253 pub fn take_cancel_handle(&mut self) -> Option<tokio::sync::oneshot::Sender<()>> {
254 self.cancel_handle.take()
255 }
256}
257
258fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
259 addr.replace(1, |protocol| {
260 if let Protocol::Tcp(_) = protocol {
261 Some(Protocol::Tcp(port))
262 } else {
263 panic!("expected tcp protocol at index 1");
264 }
265 })
266 .expect("tcp protocol at index 1")
267}
268
#[cfg(test)]
mod test {
    use std::{
        ops::Deref,
        sync::{Arc, Mutex},
        time::Duration,
    };

    use tonic::Code;
    use tonic_health::pb::{HealthCheckRequest, health_client::HealthClient};

    use crate::{Multiaddr, config::Config, metrics::MetricsCallbackProvider};

    /// A `/unix/...` multiaddr can be built in code, but its string form does not parse back
    /// into a `Multiaddr`.
    #[test]
    fn document_multiaddr_limitation_for_unix_protocol() {
        let path = "/tmp/foo";
        // Build the address directly, bypassing string parsing.
        let addr = Multiaddr::new_internal(multiaddr::multiaddr!(Unix(path), Http));

        // Round-tripping through the textual representation fails.
        let s = addr.to_string();
        assert!(s.parse::<Multiaddr>().is_err());
    }

    /// A successful health check reaches `on_response` with HTTP 200 and `Code::Ok`.
    #[tokio::test]
    async fn test_metrics_layer_successful() {
        #[derive(Clone)]
        struct Metrics {
            /// Set once `on_response` has run and its assertions passed.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::Ok);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    /// A failing health check (unknown service) is surfaced to the client as `NotFound`. It
    /// still reaches `on_response`, with HTTP 200 and `Code::NotFound`.
    #[tokio::test]
    async fn test_metrics_layer_error() {
        #[derive(Clone)]
        struct Metrics {
            /// Set once `on_response` has run and its assertions passed.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                // gRPC reports the failure through its own status code, so the HTTP status is
                // still 200.
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::NotFound);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        // The health service rejects unregistered service names. Check that the client really
        // observes that error, not only that the metrics callback ran.
        let status = client
            .check(HealthCheckRequest {
                service: "non-existing-service".to_owned(),
            })
            .await
            .expect_err("health check for an unregistered service should fail");
        assert_eq!(status.code(), Code::NotFound);

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    /// Binds a server on `address`, performs one successful health check against it, then
    /// shuts it down and checks that serving finished without error.
    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let mut server = config.server_builder().bind(&address).await.unwrap();
        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn dns() {
        let address: Multiaddr = "/dns/localhost/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip4() {
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip6() {
        let address: Multiaddr = "/ip6/::1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }
}
457
/// Tower layer that wraps a service in [`RequestLifetime`]. The wrapper reports
/// `on_start` / `on_drop` to the metrics provider for the first request it handles.
#[derive(Clone)]
struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
    /// Provider cloned into every wrapped service.
    metrics_provider: M,
}
462
463impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
464 type Service = RequestLifetime<M, S>;
465
466 fn layer(&self, inner: S) -> Self::Service {
467 RequestLifetime {
468 inner,
469 metrics_provider: self.metrics_provider.clone(),
470 path: None,
471 }
472 }
473}
474
/// Service wrapper that calls `on_start` with the path of the first request it handles. When
/// the wrapper itself is dropped, its `Drop` impl calls `on_drop` with the same path.
///
/// NOTE(review): the derived `Clone` also copies `path`. A clone of an instance that has
/// already handled a request would therefore report `on_drop` without a matching `on_start`.
/// Confirm that instances are only cloned before their first call.
#[derive(Clone)]
struct RequestLifetime<M: MetricsCallbackProvider, S> {
    /// The wrapped service that requests are forwarded to.
    inner: S,
    /// Receives the `on_start` / `on_drop` callbacks.
    metrics_provider: M,
    /// Path of the first request handled, or `None` if no request has been handled yet.
    path: Option<String>,
}
481
482impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
483 for RequestLifetime<M, S>
484where
485 S: Service<Request<RequestBody>>,
486{
487 type Response = S::Response;
488 type Error = S::Error;
489 type Future = S::Future;
490
491 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
492 self.inner.poll_ready(cx)
493 }
494
495 fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
496 if self.path.is_none() {
497 let path = request.uri().path().to_string();
498 self.metrics_provider.on_start(&path);
499 self.path = Some(path);
500 }
501 self.inner.call(request)
502 }
503}
504
505impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
506 fn drop(&mut self) {
507 if let Some(path) = &self.path {
508 self.metrics_provider.on_drop(path)
509 }
510 }
511}