1use std::{
6 convert::Infallible,
7 task::{Context, Poll},
8};
9
10use eyre::{Result, eyre};
11use tokio_rustls::rustls::ServerConfig;
12use tonic::{
13 body::Body,
14 codegen::http::{HeaderValue, Request, Response},
15 server::NamedService,
16};
17use tower::{Layer, Service, ServiceBuilder};
18use tower_http::{
19 propagate_header::PropagateHeaderLayer, set_header::SetRequestHeaderLayer, trace::TraceLayer,
20};
21
22use crate::{
23 config::Config,
24 metrics::{
25 DefaultMetricsCallbackProvider, GRPC_ENDPOINT_PATH_HEADER, MetricsCallbackProvider,
26 MetricsHandler,
27 },
28 multiaddr::{Multiaddr, Protocol},
29};
30
31pub struct ServerBuilder<M: MetricsCallbackProvider = DefaultMetricsCallbackProvider> {
32 config: Config,
33 metrics_provider: M,
34 router: tonic::service::Routes,
35 health_reporter: tonic_health::server::HealthReporter,
36}
37
38impl<M: MetricsCallbackProvider> ServerBuilder<M> {
39 pub fn from_config(config: &Config, metrics_provider: M) -> Self {
40 let (health_reporter, health_service) = tonic_health::server::health_reporter();
41 let router = tonic::service::Routes::new(health_service);
42
43 Self {
44 config: config.to_owned(),
45 metrics_provider,
46 router,
47 health_reporter,
48 }
49 }
50
51 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
52 self.health_reporter.clone()
53 }
54
55 pub fn add_service<S>(mut self, svc: S) -> Self
57 where
58 S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
59 + NamedService
60 + Clone
61 + Send
62 + Sync
63 + 'static,
64 S::Future: Send + 'static,
65 {
66 self.router = self.router.add_service(svc);
67 self
68 }
69
70 pub async fn bind(self, addr: &Multiaddr, tls_config: Option<ServerConfig>) -> Result<Server> {
71 let http_config = self
72 .config
73 .http_config()
74 .allow_insecure(true);
77
78 let request_timeout = self.config.request_timeout;
79 let metrics_provider = self.metrics_provider;
80 let metrics = MetricsHandler::new(metrics_provider.clone());
81 let request_metrics = TraceLayer::new_for_grpc()
82 .on_request(metrics.clone())
83 .on_response(metrics.clone())
84 .on_failure(metrics);
85
86 fn add_path_to_request_header<T>(request: &Request<T>) -> Option<HeaderValue> {
87 let path = request.uri().path();
88 Some(HeaderValue::from_str(path).unwrap())
89 }
90
91 let limiting_layers = ServiceBuilder::new()
92 .option_layer(
93 self.config
94 .load_shed
95 .unwrap_or_default()
96 .then_some(tower::load_shed::LoadShedLayer::new()),
97 )
98 .option_layer(
99 self.config
100 .global_concurrency_limit
101 .map(tower::limit::GlobalConcurrencyLimitLayer::new),
102 );
103
104 let route_layers = ServiceBuilder::new()
105 .map_request(|mut request: http::Request<_>| {
106 if let Some(connect_info) = request.extensions().get::<iota_http::ConnectInfo>() {
107 let tonic_connect_info = tonic::transport::server::TcpConnectInfo {
108 local_addr: Some(connect_info.local_addr),
109 remote_addr: Some(connect_info.remote_addr),
110 };
111 request.extensions_mut().insert(tonic_connect_info);
112 }
113 request
114 })
115 .layer(RequestLifetimeLayer { metrics_provider })
116 .layer(SetRequestHeaderLayer::overriding(
117 GRPC_ENDPOINT_PATH_HEADER.clone(),
118 add_path_to_request_header,
119 ))
120 .layer(request_metrics)
121 .layer(PropagateHeaderLayer::new(GRPC_ENDPOINT_PATH_HEADER.clone()))
122 .layer_fn(move |service| {
123 crate::grpc_timeout::GrpcTimeout::new(service, request_timeout)
124 });
125
126 let mut builder = iota_http::Builder::new().config(http_config);
127
128 let has_tls = tls_config.is_some();
129 if let Some(tls_config) = tls_config {
130 builder = builder.tls_config(tls_config);
131 }
132
133 let server_handle = builder
134 .serve(
135 addr,
136 limiting_layers.service(self.router.into_axum_router().layer(route_layers)),
137 )
138 .map_err(|e| eyre!(e))?;
139
140 let mut local_addr = update_tcp_port_in_multiaddr(addr, server_handle.local_addr().port());
141 if has_tls {
142 local_addr = local_addr.rewrite_http_to_https();
143 }
144 Ok(Server {
145 server_handle,
146 local_addr,
147 health_reporter: self.health_reporter,
148 })
149 }
150}
151
/// Server name string for TLS.
///
/// NOTE(review): presumably the name that certificates are issued for and
/// verified against — confirm at the usage sites.
pub const IOTA_TLS_SERVER_NAME: &str = "iota";
154
155pub struct Server {
156 server_handle: iota_http::ServerHandle,
157 local_addr: Multiaddr,
158 health_reporter: tonic_health::server::HealthReporter,
159}
160
161impl Server {
162 pub async fn serve(self) -> Result<(), tonic::transport::Error> {
163 self.server_handle.wait_for_shutdown().await;
164 Ok(())
165 }
166
167 pub fn trigger_shutdown(&self) {
168 self.server_handle.trigger_shutdown();
169 }
170
171 pub fn local_addr(&self) -> &Multiaddr {
172 &self.local_addr
173 }
174
175 pub fn health_reporter(&self) -> tonic_health::server::HealthReporter {
176 self.health_reporter.clone()
177 }
178
179 pub fn handle(&self) -> &iota_http::ServerHandle {
180 &self.server_handle
181 }
182}
183
184fn update_tcp_port_in_multiaddr(addr: &Multiaddr, port: u16) -> Multiaddr {
185 addr.replace(1, |protocol| {
186 if let Protocol::Tcp(_) = protocol {
187 Some(Protocol::Tcp(port))
188 } else {
189 panic!("expected tcp protocol at index 1");
190 }
191 })
192 .expect("tcp protocol at index 1")
193}
194
#[cfg(test)]
mod test {
    use std::{
        sync::{Arc, Mutex},
        time::Duration,
    };

    use tonic::Code;
    use tonic_health::pb::{HealthCheckRequest, health_client::HealthClient};

    use crate::{Multiaddr, config::Config, metrics::MetricsCallbackProvider};

    /// Full gRPC path of the health-check method called by the metrics tests.
    const HEALTH_CHECK_PATH: &str = "/grpc.health.v1.Health/Check";

    /// Metrics provider shared by the metrics-layer tests.
    ///
    /// It asserts that requests target [`HEALTH_CHECK_PATH`]. Responses must
    /// carry HTTP status 200 and gRPC status `expected_code`. The flag is set
    /// only after every assertion in `on_response` has passed, so any failed
    /// assertion leaves it `false`, and the test detects that.
    #[derive(Clone)]
    struct ExpectingMetrics {
        expected_code: Code,
        metrics_called: Arc<Mutex<bool>>,
    }

    impl ExpectingMetrics {
        fn new(expected_code: Code) -> Self {
            Self {
                expected_code,
                metrics_called: Arc::new(Mutex::new(false)),
            }
        }

        /// Whether `on_response` ran to completion.
        fn was_called(&self) -> bool {
            *self.metrics_called.lock().unwrap()
        }
    }

    impl MetricsCallbackProvider for ExpectingMetrics {
        fn on_request(&self, path: String) {
            assert_eq!(path, HEALTH_CHECK_PATH);
        }

        fn on_response(
            &self,
            path: String,
            _latency: Duration,
            status: u16,
            grpc_status_code: Code,
        ) {
            assert_eq!(path, HEALTH_CHECK_PATH);
            assert_eq!(status, 200);
            assert_eq!(grpc_status_code, self.expected_code);
            *self.metrics_called.lock().unwrap() = true;
        }
    }

    #[test]
    fn document_multiaddr_limitation_for_unix_protocol() {
        // A Unix-socket multiaddr can be rendered to a string, but that string
        // does not parse back into a `Multiaddr`.
        let path = "/tmp/foo";
        let unix_addr = Multiaddr::new_internal(multiaddr::multiaddr!(Unix(path), Http));

        let rendered = unix_addr.to_string();
        assert!(rendered.parse::<Multiaddr>().is_err());
    }

    #[tokio::test]
    async fn test_metrics_layer_successful() {
        let metrics = ExpectingMetrics::new(Code::Ok);
        let listen_addr: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&listen_addr, None)
            .await
            .unwrap();

        let bound_addr = server.local_addr().to_owned();
        let mut client = HealthClient::new(config.connect(&bound_addr, None).await.unwrap());

        let request = HealthCheckRequest {
            service: "".to_owned(),
        };
        client.check(request).await.unwrap();

        server.server_handle.shutdown().await;

        assert!(metrics.was_called());
    }

    #[tokio::test]
    async fn test_metrics_layer_error() {
        let metrics = ExpectingMetrics::new(Code::NotFound);
        let listen_addr: Multiaddr = "/ip4/127.0.0.1/tcp/0/http".parse().unwrap();
        let config = Config::new();

        let server = config
            .server_builder_with_metrics(metrics.clone())
            .bind(&listen_addr, None)
            .await
            .unwrap();

        let bound_addr = server.local_addr().to_owned();
        let mut client = HealthClient::new(config.connect(&bound_addr, None).await.unwrap());

        // This service name is not registered, so the call is expected to fail
        // with NotFound; only the metrics callbacks are checked.
        let request = HealthCheckRequest {
            service: "non-existing-service".to_owned(),
        };
        let _ = client.check(request).await;

        server.server_handle.shutdown().await;

        assert!(metrics.was_called());
    }

    /// Binds a server on `address`, then runs a health check against the
    /// address it actually bound to.
    async fn test_multiaddr(address: Multiaddr) {
        let config = Config::new();
        let server = config.server_builder().bind(&address, None).await.unwrap();
        let bound_addr = server.local_addr().to_owned();
        let channel = config.connect(&bound_addr, None).await.unwrap();
        let mut client = HealthClient::new(channel);

        let request = HealthCheckRequest {
            service: "".to_owned(),
        };
        client.check(request).await.unwrap();

        server.server_handle.shutdown().await;
    }

    #[tokio::test]
    async fn dns() {
        test_multiaddr("/dns/localhost/tcp/0/http".parse().unwrap()).await;
    }

    #[tokio::test]
    async fn ip4() {
        test_multiaddr("/ip4/127.0.0.1/tcp/0/http".parse().unwrap()).await;
    }

    #[tokio::test]
    async fn ip6() {
        test_multiaddr("/ip6/::1/tcp/0/http".parse().unwrap()).await;
    }
}
374
375#[derive(Clone)]
376struct RequestLifetimeLayer<M: MetricsCallbackProvider> {
377 metrics_provider: M,
378}
379
380impl<M: MetricsCallbackProvider, S> Layer<S> for RequestLifetimeLayer<M> {
381 type Service = RequestLifetime<M, S>;
382
383 fn layer(&self, inner: S) -> Self::Service {
384 RequestLifetime {
385 inner,
386 metrics_provider: self.metrics_provider.clone(),
387 path: None,
388 }
389 }
390}
391
392#[derive(Clone)]
393struct RequestLifetime<M: MetricsCallbackProvider, S> {
394 inner: S,
395 metrics_provider: M,
396 path: Option<String>,
397}
398
399impl<M: MetricsCallbackProvider, S, RequestBody> Service<Request<RequestBody>>
400 for RequestLifetime<M, S>
401where
402 S: Service<Request<RequestBody>>,
403{
404 type Response = S::Response;
405 type Error = S::Error;
406 type Future = S::Future;
407
408 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
409 self.inner.poll_ready(cx)
410 }
411
412 fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
413 if self.path.is_none() {
414 let path = request.uri().path().to_string();
415 self.metrics_provider.on_start(&path);
416 self.path = Some(path);
417 }
418 self.inner.call(request)
419 }
420}
421
422impl<M: MetricsCallbackProvider, S> Drop for RequestLifetime<M, S> {
423 fn drop(&mut self) {
424 if let Some(path) = &self.path {
425 self.metrics_provider.on_drop(path)
426 }
427 }
428}