1use std::{
6 convert::Infallible,
7 net::SocketAddr,
8 task::{Context, Poll},
9};
10
11use eyre::{Result, eyre};
12use futures::FutureExt;
13use tokio::net::{TcpListener, ToSocketAddrs};
14use tokio_stream::wrappers::TcpListenerStream;
15use tonic::{
16 body::BoxBody,
17 codegen::{
18 BoxFuture,
19 http::{HeaderValue, Request, Response},
20 },
21 server::NamedService,
22 transport::server::Router,
23};
24use tower::{
25 Layer, Service, ServiceBuilder,
26 layer::util::{Identity, Stack},
27 limit::GlobalConcurrencyLimitLayer,
28 load_shed::LoadShedLayer,
29 util::Either,
30};
31use tower_http::{
32 classify::{GrpcErrorsAsFailures, SharedClassifier},
33 propagate_header::PropagateHeaderLayer,
34 set_header::SetRequestHeaderLayer,
35 trace::{DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, TraceLayer},
36};
37
38use crate::{
39 config::Config,
40 metrics::{
41 DefaultMetricsCallbackProvider, GRPC_ENDPOINT_PATH_HEADER, MetricsCallbackProvider,
42 MetricsHandler,
43 },
44 multiaddr::{Multiaddr, Protocol, parse_dns, parse_ip4, parse_ip6},
45};
46
47pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
48 router: Router<WrapperService<M>>,
49 health_reporter: tonic_health::server::HealthReporter,
50}
51
52type AddPathToHeaderFunction = fn(&Request<BoxBody>) -> Option<HeaderValue>;
53
54type WrapperService<M> = Stack<
55 Stack<
56 PropagateHeaderLayer,
57 Stack<
58 TraceLayer<
59 SharedClassifier<GrpcErrorsAsFailures>,
60 DefaultMakeSpan,
61 MetricsHandler<M>,
62 MetricsHandler<M>,
63 DefaultOnBodyChunk,
64 DefaultOnEos,
65 MetricsHandler<M>,
66 >,
67 Stack<
68 SetRequestHeaderLayer<AddPathToHeaderFunction>,
69 Stack<
70 RequestLifetimeLayer<M>,
71 Stack<
72 Either<LoadShedLayer, Identity>,
73 Stack<Either<GlobalConcurrencyLimitLayer, Identity>, Identity>,
74 >,
75 >,
76 >,
77 >,
78 >,
79 Identity,
80>;
81
82impl<M: MetricsCallbackProvider> ServerBuilder<M> {
83 pub fn from_config(config: &Config, metrics_provider: M) -> Self {
84 let mut builder = tonic::transport::server::Server::builder();
85
86 if let Some(limit) = config.concurrency_limit_per_connection {
87 builder = builder.concurrency_limit_per_connection(limit);
88 }
89
90 if let Some(timeout) = config.request_timeout {
91 builder = builder.timeout(timeout);
92 }
93
94 if let Some(tcp_nodelay) = config.tcp_nodelay {
95 builder = builder.tcp_nodelay(tcp_nodelay);
96 }
97
98 let load_shed = config
99 .load_shed
100 .unwrap_or_default()
101 .then_some(tower::load_shed::LoadShedLayer::new());
102
103 let metrics = MetricsHandler::new(metrics_provider.clone());
104
105 let request_metrics = TraceLayer::new_for_grpc()
106 .on_request(metrics.clone())
107 .on_response(metrics.clone())
108 .on_failure(metrics);
109
110 let global_concurrency_limit = config
111 .global_concurrency_limit
112 .map(tower::limit::GlobalConcurrencyLimitLayer::new);
113
114 fn add_path_to_request_header(request: &Request<BoxBody>) -> Option<HeaderValue> {
115 let path = request.uri().path();
116 Some(HeaderValue::from_str(path).unwrap())
117 }
118
119 let layer = ServiceBuilder::new()
120 .option_layer(global_concurrency_limit)
121 .option_layer(load_shed)
122 .layer(RequestLifetimeLayer { metrics_provider })
123 .layer(SetRequestHeaderLayer::overriding(
124 GRPC_ENDPOINT_PATH_HEADER.clone(),
125 add_path_to_request_header as AddPathToHeaderFunction,
126 ))
127 .layer(request_metrics)
128 .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
129 .into_inner();
130
131 let (health_reporter, health_service) = tonic_health::server::health_reporter();
132 let router = builder
133 .initial_stream_window_size(config.http2_initial_stream_window_size)
134 .initial_connection_window_size(config.http2_initial_connection_window_size)
135 .http2_keepalive_interval(config.http2_keepalive_interval)
136 .http2_keepalive_timeout(config.http2_keepalive_timeout)
137 .max_concurrent_streams(config.http2_max_concurrent_streams)
138 .tcp_keepalive(config.tcp_keepalive)
139 .layer(layer)
140 .add_service(health_service);
141
142 Self {
143 router,
144 health_reporter,
145 }
146 }
147
148 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
149 self.health_reporter.clone()
150 }
151
152 pub fn add_service<S>(mut self, svc: S) -> Self
154 where
155 S: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Infallible>
156 + NamedService
157 + Clone
158 + Send
159 + 'static,
160 S::Future: Send + 'static,
161 {
162 self.router = self.router.add_service(svc);
163 self
164 }
165
166 pub async fn bind(self, addr: &Multiaddr) -> Result<Server> {
167 let mut iter = addr.iter();
168
169 let (tx_cancellation, rx_cancellation) = tokio::sync::oneshot::channel();
170 let rx_cancellation = rx_cancellation.map(|_| ());
171 let (local_addr, server): (Multiaddr, BoxFuture<(), tonic::transport::Error>) =
172 match iter.next().ok_or_else(|| eyre!("malformed addr"))? {
173 Protocol::Dns(_) => {
174 let (dns_name, tcp_port, _http_or_https) = parse_dns(addr)?;
175 let (local_addr, incoming) =
176 tcp_listener_and_update_multiaddr(addr, (dns_name.as_ref(), tcp_port))
177 .await?;
178 let server = Box::pin(
179 self.router
180 .serve_with_incoming_shutdown(incoming, rx_cancellation),
181 );
182 (local_addr, server)
183 }
184 Protocol::Ip4(_) => {
185 let (socket_addr, _http_or_https) = parse_ip4(addr)?;
186 let (local_addr, incoming) =
187 tcp_listener_and_update_multiaddr(addr, socket_addr).await?;
188 let server = Box::pin(
189 self.router
190 .serve_with_incoming_shutdown(incoming, rx_cancellation),
191 );
192 (local_addr, server)
193 }
194 Protocol::Ip6(_) => {
195 let (socket_addr, _http_or_https) = parse_ip6(addr)?;
196 let (local_addr, incoming) =
197 tcp_listener_and_update_multiaddr(addr, socket_addr).await?;
198 let server = Box::pin(
199 self.router
200 .serve_with_incoming_shutdown(incoming, rx_cancellation),
201 );
202 (local_addr, server)
203 }
204 unsupported => return Err(eyre!("unsupported protocol {unsupported}")),
205 };
206
207 Ok(Server {
208 server,
209 cancel_handle: Some(tx_cancellation),
210 local_addr,
211 health_reporter: self.health_reporter,
212 })
213 }
214}
215
216async fn tcp_listener_and_update_multiaddr<T: ToSocketAddrs>(
217 address: &Multiaddr,
218 socket_addr: T,
219) -> Result<(Multiaddr, TcpListenerStream)> {
220 let (local_addr, incoming) = tcp_listener(socket_addr).await?;
221 let local_addr = update_tcp_port_in_multiaddr(address, local_addr.port());
222 Ok((local_addr, incoming))
223}
224
225async fn tcp_listener<T: ToSocketAddrs>(address: T) -> Result<(SocketAddr, TcpListenerStream)> {
226 let listener = TcpListener::bind(address).await?;
227 let local_addr = listener.local_addr()?;
228 let incoming = TcpListenerStream::new(listener);
229 Ok((local_addr, incoming))
230}
231
232pub struct Server {
233 server: BoxFuture<(), tonic::transport::Error>,
234 cancel_handle: Option<tokio::sync::oneshot::Sender<()>>,
235 local_addr: Multiaddr,
236 health_reporter: tonic_health::server::HealthReporter,
237}
238
239impl Server {
240 pub async fn serve(self) -> Result<(), tonic::transport::Error> {
241 self.server.await
242 }
243
244 pub fn local_addr(&self) -> &Multiaddr {
245 &self.local_addr
246 }
247
248 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
249 self.health_reporter.clone()
250 }
251
252 pub fn take_cancel_handle(&mut self) -> Option<tokio::sync::oneshot::Sender<()>> {
253 self.cancel_handle.take()
254 }
255}
256
257fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
258 addr.replace(1, |protocol| {
259 if let Protocol::Tcp(_) = protocol {
260 Some(Protocol::Tcp(port))
261 } else {
262 panic!("expected tcp protocol at index 1");
263 }
264 })
265 .expect("tcp protocol at index 1")
266}
267
#[cfg(test)]
mod test {
    use std::{
        ops::Deref,
        sync::{Arc, Mutex},
        time::Duration,
    };

    use tonic::Code;
    use tonic_health::pb::{HealthCheckRequest, health_client::HealthClient};

    use crate::{Multiaddr, config::Config, metrics::MetricsCallbackProvider};

    /// A Unix-socket multiaddr can be built internally, but its string form does not parse
    /// back into a `Multiaddr`.
    #[test]
    fn document_multiaddr_limitation_for_unix_protocol() {
        let path = "/tmp/foo";
        let addr = Multiaddr::new_internal(multiaddr::multiaddr!(Unix(path), Http));

        let s = addr.to_string();
        assert!(s.parse::<Multiaddr>().is_err());
    }

    /// A successful health check reaches the metrics callbacks with the endpoint path,
    /// HTTP 200 and `Code::Ok`.
    #[tokio::test]
    async fn test_metrics_layer_successful() {
        #[derive(Clone)]
        struct Metrics {
            // Set by `on_response` so the test can confirm the callback actually ran.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::Ok);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    /// A health check for an unregistered service fails with `NotFound` on the client and
    /// is reported to the metrics callbacks with that gRPC code (still HTTP 200).
    #[tokio::test]
    async fn test_metrics_layer_error() {
        #[derive(Clone)]
        struct Metrics {
            // Set by `on_response` so the test can confirm the callback actually ran.
            metrics_called: Arc<Mutex<bool>>,
        }

        impl MetricsCallbackProvider for Metrics {
            fn on_request(&self, path: String) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
            }

            fn on_response(
                &self,
                path: String,
                _latency: Duration,
                status: u16,
                grpc_status_code: Code,
            ) {
                assert_eq!(path, "/grpc.health.v1.Health/Check");
                assert_eq!(status, 200);
                assert_eq!(grpc_status_code, Code::NotFound);
                let mut m = self.metrics_called.lock().unwrap();
                *m = true
            }
        }

        let metrics = Metrics {
            metrics_called: Arc::new(Mutex::new(false)),
        };

        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let mut server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&address)
            .await
            .unwrap();

        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        // The call is expected to fail; check the client sees the same code the metrics
        // callbacks are asserted to see.
        let status = client
            .check(HealthCheckRequest {
                service: "non-existing-service".to_owned(),
            })
            .await
            .expect_err("health check for an unregistered service should fail");
        assert_eq!(status.code(), Code::NotFound);

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();

        assert!(metrics.metrics_called.lock().unwrap().deref());
    }

    /// Binds a server on `address`, performs one health check against it and shuts it
    /// down cleanly.
    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let mut server = config.server_builder().bind(&address).await.unwrap();
        let address = server.local_addr().to_owned();
        let cancel_handle = server.take_cancel_handle().unwrap();
        let server_handle = tokio::spawn(server.serve());
        let channel = config.connect(&address).await.unwrap();
        let mut client = HealthClient::new(channel);

        client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        cancel_handle.send(()).unwrap();
        server_handle.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn dns() {
        let address: Multiaddr = "/dns/localhost/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip4() {
        let address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }

    #[tokio::test]
    async fn ip6() {
        let address: Multiaddr = "/ip6/::1/tcp/0/http".parse().unwrap();
        test_multiaddr(address).await;
    }
}
456
457#[derive(Clone)]
458struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
459 metrics_provider: M,
460}
461
462impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
463 type Service = RequestLifetime<M, S>;
464
465 fn layer(&self, inner: S) -> Self::Service {
466 RequestLifetime {
467 inner,
468 metrics_provider: self.metrics_provider.clone(),
469 path: None,
470 }
471 }
472}
473
474#[derive(Clone)]
475struct RequestLifetime<M: MetricsCallbackProvider, S> {
476 inner: S,
477 metrics_provider: M,
478 path: Option<String>,
479}
480
481impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
482 for RequestLifetime<M, S>
483where
484 S: Service<Request<RequestBody>>,
485{
486 type Response = S::Response;
487 type Error = S::Error;
488 type Future = S::Future;
489
490 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
491 self.inner.poll_ready(cx)
492 }
493
494 fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
495 if self.path.is_none() {
496 let path = request.uri().path().to_string();
497 self.metrics_provider.on_start(&path);
498 self.path = Some(path);
499 }
500 self.inner.call(request)
501 }
502}
503
504impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
505 fn drop(&mut self) {
506 if let Some(path) = &self.path {
507 self.metrics_provider.on_drop(path)
508 }
509 }
510}