1use std::{
6 convert::Infallible,
7 task::{Context, Poll},
8};
9
10use eyre::{Result, eyre};
11use tokio_rustls::rustls::ServerConfig;
12use tonic::{
13 body::Body,
14 codegen::http::{HeaderValue, Request, Response},
15 server::NamedService,
16};
17use tower::{Layer, Service, ServiceBuilder};
18use tower_http::{
19 propagate_header::PropagateHeaderLayer, set_header::SetRequestHeaderLayer, trace::TraceLayer,
20};
21
22use crate::{
23 config::Config,
24 metrics::{
25 DefaultMetricsCallbackProvider, GRPC_ENDPOINT_PATH_HEADER, MetricsCallbackProvider,
26 MetricsHandler,
27 },
28 multiaddr::{Multiaddr, Protocol},
29};
30
31pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
32 config: Config,
33 metrics_provider: M,
34 router: tonic::service::Routes,
35 health_reporter: tonic_health::server::HealthReporter,
36}
37
38impl<M: MetricsCallbackProvider> ServerBuilder<M> {
39 pub fn from_config(config: &Config, metrics_provider: M) -> Self {
40 let (health_reporter, health_service) = tonic_health::server::health_reporter();
41 let router = tonic::service::Routes::new(health_service);
42
43 Self {
44 config: config.to_owned(),
45 metrics_provider,
46 router,
47 health_reporter,
48 }
49 }
50
51 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
52 self.health_reporter.clone()
53 }
54
55 pub fn add_service<S>(mut self, svc: S) -> Self
57 where
58 S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
59 + NamedService
60 + Clone
61 + Send
62 + Sync
63 + 'static,
64 S::Future: Send + 'static,
65 {
66 self.router = self.router.add_service(svc);
67 self
68 }
69
70 pub async fn bind(self, addr: &Multiaddr, tls_config: Option<ServerConfig>) -> Result<Server> {
71 let http_config = self
72 .config
73 .http_config()
74 .allow_insecure(true);
77
78 let request_timeout = self.config.request_timeout;
79 let metrics_provider = self.metrics_provider;
80 let metrics = MetricsHandler::new(metrics_provider.clone());
81 let request_metrics = TraceLayer::new_for_grpc()
82 .on_request(metrics.clone())
83 .on_response(metrics.clone())
84 .on_failure(metrics);
85
86 fn add_path_to_request_header<T>(request: &Request<T>) -> Option<HeaderValue> {
87 let path = request.uri().path();
88 Some(HeaderValue::from_str(path).unwrap())
89 }
90
91 let limiting_layers = ServiceBuilder::new()
92 .option_layer(
93 self.config
94 .load_shed
95 .unwrap_or_default()
96 .then_some(tower::load_shed::LoadShedLayer::new()),
97 )
98 .option_layer(
99 self.config
100 .global_concurrency_limit
101 .map(tower::limit::GlobalConcurrencyLimitLayer::new),
102 );
103
104 let route_layers = ServiceBuilder::new()
105 .map_request(|mut request: http::Request<_>| {
106 if let Some(connect_info) = request.extensions().get::<iota_http::ConnectInfo>() {
107 let tonic_connect_info = tonic::transport::server::TcpConnectInfo {
108 local_addr: Some(connect_info.local_addr),
109 remote_addr: Some(connect_info.remote_addr),
110 };
111 request.extensions_mut().insert(tonic_connect_info);
112 }
113 request
114 })
115 .layer(RequestLifetimeLayer { metrics_provider })
116 .layer(SetRequestHeaderLayer::overriding(
117 GRPC_ENDPOINT_PATH_HEADER.clone(),
118 add_path_to_request_header,
119 ))
120 .layer(request_metrics)
121 .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
122 .layer_fn(move |service| {
123 crate::grpc_timeout::GrpcTimeout::new(service, request_timeout)
124 });
125
126 let mut builder = iota_http::Builder::new().config(http_config);
127
128 if let Some(tls_config) = tls_config {
129 builder = builder.tls_config(tls_config);
130 }
131
132 let server_handle = builder
133 .serve(
134 addr,
135 limiting_layers.service(self.router.into_axum_router().layer(route_layers)),
136 )
137 .map_err(|e| eyre!(e))?;
138
139 let local_addr = update_tcp_port_in_multiaddr(addr, server_handle.local_addr().port());
140 Ok(Server {
141 server_handle,
142 local_addr,
143 health_reporter: self.health_reporter,
144 })
145 }
146}
147
/// The server name string `"iota"`.
///
/// NOTE(review): by its name this is the TLS server name used when establishing or verifying
/// TLS connections; it is not referenced elsewhere in this file, so confirm at its use sites.
pub const IOTA_TLS_SERVER_NAME: &str = "iota";
150
151pub struct Server {
152 server_handle: iota_http::ServerHandle,
153 local_addr: Multiaddr,
154 health_reporter: tonic_health::server::HealthReporter,
155}
156
157impl Server {
158 pub async fn serve(self) -> Result<(), tonic::transport::Error> {
159 self.server_handle.wait_for_shutdown().await;
160 Ok(())
161 }
162
163 pub fn trigger_shutdown(&self) {
164 self.server_handle.trigger_shutdown();
165 }
166
167 pub fn local_addr(&self) -> &Multiaddr {
168 &self.local_addr
169 }
170
171 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
172 self.health_reporter.clone()
173 }
174
175 pub fn handle(&self) -> &iota_http::ServerHandle {
176 &self.server_handle
177 }
178}
179
180fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
181 addr.replace(1, |protocol| {
182 if let Protocol::Tcp(_) = protocol {
183 Some(Protocol::Tcp(port))
184 } else {
185 panic!("expected tcp protocol at index 1");
186 }
187 })
188 .expect("tcp protocol at index 1")
189}
190
#[cfg(test)]
mod test {
    use std::{
        sync::{Arc, Mutex},
        time::Duration,
    };

    use tonic::Code;
    use tonic_health::pb::{HealthCheckRequest, health_client::HealthClient};

    use crate::{Multiaddr, config::Config, metrics::MetricsCallbackProvider};

    /// Path of the `Check` RPC on the tonic-health service.
    const HEALTH_CHECK_PATH: &str = "/grpc.health.v1.Health/Check";

    /// Metrics provider asserting that each request targets the health `Check` RPC and finishes
    /// with HTTP 200 and `expected_code`; it records that `on_response` was reached.
    #[derive(Clone)]
    struct ExpectingMetrics {
        expected_code: Code,
        metrics_called: Arc<Mutex<bool>>,
    }

    impl ExpectingMetrics {
        fn new(expected_code: Code) -> Self {
            Self {
                expected_code,
                metrics_called: Arc::new(Mutex::new(false)),
            }
        }

        fn was_called(&self) -> bool {
            *self.metrics_called.lock().unwrap()
        }
    }

    impl MetricsCallbackProvider for ExpectingMetrics {
        fn on_request(&self, path: String) {
            assert_eq!(path, HEALTH_CHECK_PATH);
        }

        fn on_response(
            &self,
            path: String,
            _latency: Duration,
            status: u16,
            grpc_status_code: Code,
        ) {
            assert_eq!(path, HEALTH_CHECK_PATH);
            assert_eq!(status, 200);
            assert_eq!(grpc_status_code, self.expected_code);
            *self.metrics_called.lock().unwrap() = true;
        }
    }

    #[test]
    fn document_multiaddr_limitation_for_unix_protocol() {
        // A unix-socket multiaddr can be constructed, but its string form does not parse back.
        let socket_path = "/tmp/foo";
        let unix_addr = Multiaddr::new_internal(multiaddr::multiaddr!(Unix(socket_path), Http));

        let rendered = unix_addr.to_string();
        assert!(rendered.parse::<Multiaddr>().is_err());
    }

    #[tokio::test]
    async fn test_metrics_layer_successful() {
        let metrics = ExpectingMetrics::new(Code::Ok);

        let listen_address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();
        let server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&listen_address, None)
            .await
            .unwrap();

        let bound_address = server.local_addr().to_owned();
        let channel = config.connect(&bound_address, None).await.unwrap();
        let mut health_client = HealthClient::new(channel);

        health_client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        server.server_handle.shutdown().await;

        assert!(metrics.was_called());
    }

    #[tokio::test]
    async fn test_metrics_layer_error() {
        let metrics = ExpectingMetrics::new(Code::NotFound);

        let listen_address: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();
        let server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&listen_address, None)
            .await
            .unwrap();

        let bound_address = server.local_addr().to_owned();
        let channel = config.connect(&bound_address, None).await.unwrap();
        let mut health_client = HealthClient::new(channel);

        // The call is expected to fail with NotFound; only the reported metrics matter here.
        let _ = health_client
            .check(HealthCheckRequest {
                service: "non-existing-service".to_owned(),
            })
            .await;

        server.server_handle.shutdown().await;

        assert!(metrics.was_called());
    }

    /// Serves on `address`, performs one successful health check against the bound address,
    /// then shuts the server down.
    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let server = config.server_builder().bind(&address, None).await.unwrap();

        let bound_address = server.local_addr().to_owned();
        let channel = config.connect(&bound_address, None).await.unwrap();
        let mut health_client = HealthClient::new(channel);

        health_client
            .check(HealthCheckRequest {
                service: "".to_owned(),
            })
            .await
            .unwrap();

        server.server_handle.shutdown().await;
    }

    #[tokio::test]
    async fn dns() {
        let addr: Multiaddr = "/dns/localhost/tcp/0/http".parse().unwrap();
        test_multiaddr(addr).await;
    }

    #[tokio::test]
    async fn ip4() {
        let addr: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        test_multiaddr(addr).await;
    }

    #[tokio::test]
    async fn ip6() {
        let addr: Multiaddr = "/ip6/::1/tcp/0/http".parse().unwrap();
        test_multiaddr(addr).await;
    }
}
370
371#[derive(Clone)]
372struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
373 metrics_provider: M,
374}
375
376impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
377 type Service = RequestLifetime<M, S>;
378
379 fn layer(&self, inner: S) -> Self::Service {
380 RequestLifetime {
381 inner,
382 metrics_provider: self.metrics_provider.clone(),
383 path: None,
384 }
385 }
386}
387
388#[derive(Clone)]
389struct RequestLifetime<M: MetricsCallbackProvider, S> {
390 inner: S,
391 metrics_provider: M,
392 path: Option<String>,
393}
394
395impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
396 for RequestLifetime<M, S>
397where
398 S: Service<Request<RequestBody>>,
399{
400 type Response = S::Response;
401 type Error = S::Error;
402 type Future = S::Future;
403
404 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
405 self.inner.poll_ready(cx)
406 }
407
408 fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
409 if self.path.is_none() {
410 let path = request.uri().path().to_string();
411 self.metrics_provider.on_start(&path);
412 self.path = Some(path);
413 }
414 self.inner.call(request)
415 }
416}
417
418impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
419 fn drop(&mut self) {
420 if let Some(path) = &self.path {
421 self.metrics_provider.on_drop(path)
422 }
423 }
424}