iota_network_stack/
grpc_timeout.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2025 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4//
5// Ported from `tonic` crate
6// SPDX-License-Identifier: MIT
7
8use std::{
9    future::Future,
10    pin::Pin,
11    task::{Context, Poll, ready},
12    time::Duration,
13};
14
15use http::{HeaderMap, HeaderValue, Request, Response};
16use pin_project_lite::pin_project;
17use tokio::time::Sleep;
18use tonic::Status;
19use tower::Service;
20
/// Name of the request header carrying the client-requested gRPC timeout.
const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout";
22
/// Tower middleware that enforces a per-request deadline on gRPC calls.
///
/// The effective deadline is the shorter of the optional server-side timeout
/// stored here and the client-supplied `grpc-timeout` request header. If it
/// elapses before the wrapped service responds, a `DEADLINE_EXCEEDED` status
/// response is produced instead.
#[derive(Clone, Debug)]
pub struct GrpcTimeout<S> {
    /// The wrapped service.
    inner: S,
    /// Upper bound applied to every request, if configured.
    server_timeout: Option<Duration>,
}

impl<S> GrpcTimeout<S> {
    /// Wraps `inner`, optionally capping every request at `server_timeout`.
    pub fn new(inner: S, server_timeout: Option<Duration>) -> Self {
        GrpcTimeout {
            server_timeout,
            inner,
        }
    }
}
37
38impl<S, RequestBody, ResponseBody> Service<Request<RequestBody>> for GrpcTimeout<S>
39where
40    S: Service<Request<RequestBody>, Response = Response<ResponseBody>>,
41{
42    type Response = Response<MaybeEmptyBody<ResponseBody>>;
43    type Error = S::Error;
44    type Future = ResponseFuture<S::Future>;
45
46    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
47        self.inner.poll_ready(cx).map_err(Into::into)
48    }
49
50    fn call(&mut self, req: Request<RequestBody>) -> Self::Future {
51        let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| {
52            tracing::trace!("Error parsing `grpc-timeout` header {:?}", e);
53            None
54        });
55
56        // Use the shorter of the two durations, if either are set
57        let timeout_duration = match (client_timeout, self.server_timeout) {
58            (None, None) => None,
59            (Some(dur), None) => Some(dur),
60            (None, Some(dur)) => Some(dur),
61            (Some(header), Some(server)) => {
62                let shorter_duration = std::cmp::min(header, server);
63                Some(shorter_duration)
64            }
65        };
66
67        ResponseFuture {
68            inner: self.inner.call(req),
69            sleep: timeout_duration.map(tokio::time::sleep),
70        }
71    }
72}
73
pin_project! {
    /// Future returned by [`GrpcTimeout`] when a request is dispatched.
    ///
    /// Resolves to the inner service's response (its body wrapped as a full
    /// `MaybeEmptyBody`) or, if the deadline timer fires first, to a
    /// `DEADLINE_EXCEEDED` status response with an empty body.
    pub struct ResponseFuture<F> {
        // Future produced by the wrapped service's `call`.
        #[pin]
        inner: F,
        // Deadline timer; `None` when neither the client header nor the
        // server configuration set a timeout.
        #[pin]
        sleep: Option<Sleep>,
    }
}
82
83impl<F, ResponseBody, E> Future for ResponseFuture<F>
84where
85    F: Future<Output = Result<Response<ResponseBody>, E>>,
86{
87    type Output = Result<Response<MaybeEmptyBody<ResponseBody>>, E>;
88
89    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
90        let this = self.project();
91
92        if let Poll::Ready(result) = this.inner.poll(cx) {
93            return Poll::Ready(result.map(|response| response.map(MaybeEmptyBody::full)));
94        }
95
96        if let Some(sleep) = this.sleep.as_pin_mut() {
97            ready!(sleep.poll(cx));
98            let response = Status::deadline_exceeded("Timeout expired")
99                .into_http()
100                .map(|_| MaybeEmptyBody::empty());
101            return Poll::Ready(Ok(response));
102        }
103
104        Poll::Pending
105    }
106}
107
pin_project! {
    /// Response body that is either the inner service's body or empty.
    ///
    /// The empty form is used for the synthesized timeout response.
    pub struct MaybeEmptyBody<B> {
        // `Some` wraps a real body; `None` behaves as a body with no frames.
        #[pin]
        inner: Option<B>,
    }
}
114
115impl<B> MaybeEmptyBody<B> {
116    fn full(inner: B) -> Self {
117        Self { inner: Some(inner) }
118    }
119
120    fn empty() -> Self {
121        Self { inner: None }
122    }
123}
124
125impl<B> http_body::Body for MaybeEmptyBody<B>
126where
127    B: http_body::Body + Send,
128{
129    type Data = B::Data;
130    type Error = B::Error;
131
132    fn poll_frame(
133        self: Pin<&mut Self>,
134        cx: &mut Context<'_>,
135    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
136        match self.project().inner.as_pin_mut() {
137            Some(b) => b.poll_frame(cx),
138            None => Poll::Ready(None),
139        }
140    }
141
142    fn is_end_stream(&self) -> bool {
143        match &self.inner {
144            Some(b) => b.is_end_stream(),
145            None => true,
146        }
147    }
148
149    fn size_hint(&self) -> http_body::SizeHint {
150        match &self.inner {
151            Some(body) => body.size_hint(),
152            None => http_body::SizeHint::with_exact(0),
153        }
154    }
155}
156
/// Seconds in one minute; used to convert `grpc-timeout` values in minutes.
const SECONDS_IN_MINUTE: u64 = 60;
/// Seconds in one hour; used to convert `grpc-timeout` values in hours.
const SECONDS_IN_HOUR: u64 = 60 * SECONDS_IN_MINUTE;
159
160/// Tries to parse the `grpc-timeout` header if it is present. If we fail to
161/// parse, returns the value we attempted to parse.
162///
163/// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md).
164fn try_parse_grpc_timeout(
165    headers: &HeaderMap<HeaderValue>,
166) -> Result<Option<Duration>, &HeaderValue> {
167    let Some(val) = headers.get(GRPC_TIMEOUT_HEADER) else {
168        return Ok(None);
169    };
170
171    let (timeout_value, timeout_unit) = val
172        .to_str()
173        .map_err(|_| val)
174        .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })?
175        // `HeaderValue::to_str` only returns `Ok` if the header contains ASCII so this
176        // `split_at` will never panic from trying to split in the middle of a character.
177        // See https://docs.rs/http/0.2.4/http/header/struct.HeaderValue.html#method.to_str
178        //
179        // `len - 1` also wont panic since we just checked `s.is_empty`.
180        .split_at(val.len() - 1);
181
182    // gRPC spec specifies `TimeoutValue` will be at most 8 digits
183    // Caping this at 8 digits also prevents integer overflow from ever occurring
184    if timeout_value.len() > 8 {
185        return Err(val);
186    }
187
188    let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?;
189
190    let duration = match timeout_unit {
191        // Hours
192        "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR),
193        // Minutes
194        "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE),
195        // Seconds
196        "S" => Duration::from_secs(timeout_value),
197        // Milliseconds
198        "m" => Duration::from_millis(timeout_value),
199        // Microseconds
200        "u" => Duration::from_micros(timeout_value),
201        // Nanoseconds
202        "n" => Duration::from_nanos(timeout_value),
203        _ => return Err(val),
204    };
205
206    Ok(Some(duration))
207}
208
// Unit tests for `try_parse_grpc_timeout`: one test per valid unit, plus
// absent-header and malformed-header cases.
#[cfg(test)]
mod tests {
    use super::*;

    // Helper function to reduce the boilerplate of our test cases: builds a
    // header map that contains `grpc-timeout: <val>` when `val` is `Some`, then
    // parses it. The error is cloned so the result does not borrow the
    // temporary map.
    fn setup_map_try_parse(val: Option<&str>) -> Result<Option<Duration>, HeaderValue> {
        let mut hm = HeaderMap::new();
        if let Some(v) = val {
            let hv = HeaderValue::from_str(v).unwrap();
            hm.insert(GRPC_TIMEOUT_HEADER, hv);
        };

        try_parse_grpc_timeout(&hm).map_err(|e| e.clone())
    }

    #[test]
    fn test_hours() {
        let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_minutes() {
        let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(60), parsed_duration);
    }

    #[test]
    fn test_seconds() {
        let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(42), parsed_duration);
    }

    #[test]
    fn test_milliseconds() {
        let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap();
        assert_eq!(Duration::from_millis(13), parsed_duration);
    }

    #[test]
    fn test_microseconds() {
        let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap();
        assert_eq!(Duration::from_micros(2), parsed_duration);
    }

    #[test]
    fn test_nanoseconds() {
        let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap();
        assert_eq!(Duration::from_nanos(82), parsed_duration);
    }

    #[test]
    fn test_header_not_present() {
        let parsed_duration = setup_map_try_parse(None).unwrap();
        assert!(parsed_duration.is_none());
    }

    // In the malformed-header tests below, the panic message from `unwrap` on
    // the `Err` includes the `Debug` form of the rejected header value. The
    // `expected` strings check that the offending value is what gets reported.
    #[test]
    #[should_panic(expected = "82f")]
    fn test_invalid_unit() {
        // "f" is not a valid TimeoutUnit
        setup_map_try_parse(Some("82f")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "123456789H")]
    fn test_too_many_digits() {
        // gRPC spec states TimeoutValue will be at most 8 digits
        setup_map_try_parse(Some("123456789H")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "oneH")]
    fn test_invalid_digits() {
        // "one" is not a numeric TimeoutValue; the spec requires ASCII digits
        setup_map_try_parse(Some("oneH")).unwrap().unwrap();
    }
}