iota_network_stack/
grpc_timeout.rs1use std::{
9 future::Future,
10 pin::Pin,
11 task::{Context, Poll, ready},
12 time::Duration,
13};
14
15use http::{HeaderMap, HeaderValue, Request, Response};
16use pin_project_lite::pin_project;
17use tokio::time::Sleep;
18use tonic::Status;
19use tower::Service;
20
/// Request header in which a gRPC client transmits its deadline.
const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout";

/// Middleware that bounds how long the wrapped service may take to answer a request.
///
/// The effective deadline for a request is the shorter of the client-supplied
/// `grpc-timeout` header (if present and well-formed) and the optional
/// server-side limit configured here.
#[derive(Debug, Clone)]
pub struct GrpcTimeout<S> {
    /// The wrapped service.
    inner: S,
    /// Server-wide upper bound on request duration; `None` means no server limit.
    server_timeout: Option<Duration>,
}

impl<S> GrpcTimeout<S> {
    /// Wraps `inner`, optionally capping every request at `server_timeout`.
    pub fn new(inner: S, server_timeout: Option<Duration>) -> Self {
        GrpcTimeout {
            inner,
            server_timeout,
        }
    }
}
37
38impl<S, RequestBody, ResponseBody> Service<Request<RequestBody>> for GrpcTimeout<S>
39where
40 S: Service<Request<RequestBody>, Response = Response<ResponseBody>>,
41{
42 type Response = Response<MaybeEmptyBody<ResponseBody>>;
43 type Error = S::Error;
44 type Future = ResponseFuture<S::Future>;
45
46 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
47 self.inner.poll_ready(cx).map_err(Into::into)
48 }
49
50 fn call(&mut self, req: Request<RequestBody>) -> Self::Future {
51 let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| {
52 tracing::trace!("Error parsing `grpc-timeout` header {:?}", e);
53 None
54 });
55
56 let timeout_duration = match (client_timeout, self.server_timeout) {
58 (None, None) => None,
59 (Some(dur), None) => Some(dur),
60 (None, Some(dur)) => Some(dur),
61 (Some(header), Some(server)) => {
62 let shorter_duration = std::cmp::min(header, server);
63 Some(shorter_duration)
64 }
65 };
66
67 ResponseFuture {
68 inner: self.inner.call(req),
69 sleep: timeout_duration.map(tokio::time::sleep),
70 }
71 }
72}
73
74pin_project! {
75 pub struct ResponseFuture<F> {
76 #[pin]
77 inner: F,
78 #[pin]
79 sleep: Option<Sleep>,
80 }
81}
82
83impl<F, ResponseBody, E> Future for ResponseFuture<F>
84where
85 F: Future<Output = Result<Response<ResponseBody>, E>>,
86{
87 type Output = Result<Response<MaybeEmptyBody<ResponseBody>>, E>;
88
89 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
90 let this = self.project();
91
92 if let Poll::Ready(result) = this.inner.poll(cx) {
93 return Poll::Ready(result.map(|response| response.map(MaybeEmptyBody::full)));
94 }
95
96 if let Some(sleep) = this.sleep.as_pin_mut() {
97 ready!(sleep.poll(cx));
98 let response = Status::deadline_exceeded("Timeout expired")
99 .into_http()
100 .map(|_| MaybeEmptyBody::empty());
101 return Poll::Ready(Ok(response));
102 }
103
104 Poll::Pending
105 }
106}
107
108pin_project! {
109 pub struct MaybeEmptyBody<B> {
110 #[pin]
111 inner: Option<B>,
112 }
113}
114
115impl<B> MaybeEmptyBody<B> {
116 fn full(inner: B) -> Self {
117 Self { inner: Some(inner) }
118 }
119
120 fn empty() -> Self {
121 Self { inner: None }
122 }
123}
124
125impl<B> http_body::Body for MaybeEmptyBody<B>
126where
127 B: http_body::Body + Send,
128{
129 type Data = B::Data;
130 type Error = B::Error;
131
132 fn poll_frame(
133 self: Pin<&mut Self>,
134 cx: &mut Context<'_>,
135 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
136 match self.project().inner.as_pin_mut() {
137 Some(b) => b.poll_frame(cx),
138 None => Poll::Ready(None),
139 }
140 }
141
142 fn is_end_stream(&self) -> bool {
143 match &self.inner {
144 Some(b) => b.is_end_stream(),
145 None => true,
146 }
147 }
148
149 fn size_hint(&self) -> http_body::SizeHint {
150 match &self.inner {
151 Some(body) => body.size_hint(),
152 None => http_body::SizeHint::with_exact(0),
153 }
154 }
155}
156
157const SECONDS_IN_HOUR: u64 = 60 * 60;
158const SECONDS_IN_MINUTE: u64 = 60;
159
160fn try_parse_grpc_timeout(
165 headers: &HeaderMap<HeaderValue>,
166) -> Result<Option<Duration>, &HeaderValue> {
167 let Some(val) = headers.get(GRPC_TIMEOUT_HEADER) else {
168 return Ok(None);
169 };
170
171 let (timeout_value, timeout_unit) = val
172 .to_str()
173 .map_err(|_| val)
174 .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })?
175 .split_at(val.len() - 1);
181
182 if timeout_value.len() > 8 {
185 return Err(val);
186 }
187
188 let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?;
189
190 let duration = match timeout_unit {
191 "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR),
193 "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE),
195 "S" => Duration::from_secs(timeout_value),
197 "m" => Duration::from_millis(timeout_value),
199 "u" => Duration::from_micros(timeout_value),
201 "n" => Duration::from_nanos(timeout_value),
203 _ => return Err(val),
204 };
205
206 Ok(Some(duration))
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header map with `val` as the `grpc-timeout` header (no header
    /// when `val` is `None`) and parses it.
    ///
    /// Any rejected value is cloned so the result does not borrow the
    /// temporary map.
    fn setup_map_try_parse(val: Option<&str>) -> Result<Option<Duration>, HeaderValue> {
        let mut hm = HeaderMap::new();
        if let Some(v) = val {
            let hv = HeaderValue::from_str(v).unwrap();
            hm.insert(GRPC_TIMEOUT_HEADER, hv);
        }

        try_parse_grpc_timeout(&hm).map_err(|e| e.clone())
    }

    #[test]
    fn test_hours() {
        let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_minutes() {
        let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(60), parsed_duration);
    }

    #[test]
    fn test_seconds() {
        let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(42), parsed_duration);
    }

    #[test]
    fn test_milliseconds() {
        let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap();
        assert_eq!(Duration::from_millis(13), parsed_duration);
    }

    #[test]
    fn test_microseconds() {
        let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap();
        assert_eq!(Duration::from_micros(2), parsed_duration);
    }

    #[test]
    fn test_nanoseconds() {
        let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap();
        assert_eq!(Duration::from_nanos(82), parsed_duration);
    }

    #[test]
    fn test_header_not_present() {
        let parsed_duration = setup_map_try_parse(None).unwrap();
        assert!(parsed_duration.is_none());
    }

    #[test]
    #[should_panic(expected = "82f")]
    fn test_invalid_unit() {
        setup_map_try_parse(Some("82f")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "123456789H")]
    fn test_too_many_digits() {
        setup_map_try_parse(Some("123456789H")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "oneH")]
    fn test_invalid_digits() {
        setup_map_try_parse(Some("oneH")).unwrap().unwrap();
    }

    /// Exactly 8 digits is the maximum the spec allows and must still parse.
    #[test]
    fn test_max_digits() {
        let parsed_duration = setup_map_try_parse(Some("99999999H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(99_999_999 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_zero_value() {
        let parsed_duration = setup_map_try_parse(Some("0S")).unwrap().unwrap();
        assert_eq!(Duration::ZERO, parsed_duration);
    }

    #[test]
    fn test_empty_value() {
        assert!(setup_map_try_parse(Some("")).is_err());
    }

    /// A unit with no preceding number is malformed.
    #[test]
    fn test_missing_value() {
        assert!(setup_map_try_parse(Some("S")).is_err());
    }
}