iota_graphql_rpc/server/
version.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use axum::{
6    extract::{Path, Request, State},
7    http::{HeaderName, HeaderValue, StatusCode},
8    middleware::Next,
9    response::{IntoResponse, Response},
10};
11use axum_extra::headers;
12
13use crate::{
14    config::Version,
15    error::{code, graphql_error_response},
16};
17
18pub(crate) static VERSION_HEADER: HeaderName = HeaderName::from_static("x-iota-rpc-version");
19
20#[expect(unused)]
21pub(crate) struct IotaRpcVersion(Vec<u8>, Vec<Vec<u8>>);
22const NAMED_VERSIONS: [&str; 3] = ["beta", "legacy", "stable"];
23
24impl headers::Header for IotaRpcVersion {
25    fn name() -> &'static HeaderName {
26        &VERSION_HEADER
27    }
28
29    fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
30    where
31        I: Iterator<Item = &'i HeaderValue>,
32    {
33        let mut values = values.map(|v| v.as_bytes().to_owned());
34        let Some(value) = values.next() else {
35            // No values for this header -- it doesn't exist.
36            return Err(headers::Error::invalid());
37        };
38
39        // Extract the header values as bytes.  Distinguish the first value as we expect
40        // there to be just one under normal operation.  Do not attempt to parse
41        // the value, as a header parsing failure produces a generic error.
42        Ok(IotaRpcVersion(value, values.collect()))
43    }
44
45    fn encode<E: Extend<HeaderValue>>(&self, _values: &mut E) {
46        unimplemented!()
47    }
48}
49
50/// Middleware to check for the existence of a version constraint in the request
51/// header, and confirm that this instance of the RPC matches that version
52/// constraint.  Each RPC instance only supports one version of the RPC
53/// software, and it is the responsibility of the load balancer to make sure
54/// version constraints are met.
55pub(crate) async fn check_version_middleware(
56    version: Option<Path<String>>,
57    State(service_version): State<Version>,
58    request: Request,
59    next: Next,
60) -> Response {
61    let Some(Path(version)) = version else {
62        return next.run(request).await;
63    };
64
65    if NAMED_VERSIONS.contains(&version.as_str()) || version.is_empty() {
66        return next.run(request).await;
67    }
68    let Some((year, month)) = parse_version(&version) else {
69        return (
70                StatusCode::BAD_REQUEST,
71                graphql_error_response(
72                    code::BAD_REQUEST,
73                    format!(
74                        "Failed to parse version path: {version}. Expected either a `beta | legacy | stable` \
75                    version or <YEAR>.<MONTH> version.",
76                    ),
77                ),
78            )
79                .into_response();
80    };
81
82    if year != service_version.year || month != service_version.month {
83        return (
84            StatusCode::MISDIRECTED_REQUEST,
85            graphql_error_response(
86                code::INTERNAL_SERVER_ERROR,
87                format!("Version '{version}' not supported."),
88            ),
89        )
90            .into_response();
91    }
92    next.run(request).await
93}
94
95/// Mark every outgoing response with a header indicating the precise version of
96/// the RPC that was used (including the patch version and sha).
97pub(crate) async fn set_version_middleware(
98    State(version): State<Version>,
99    request: Request,
100    next: Next,
101) -> Response {
102    let mut response = next.run(request).await;
103    let headers = response.headers_mut();
104    headers.insert(
105        VERSION_HEADER.clone(),
106        HeaderValue::from_static(version.full),
107    );
108    response
109}
110
/// Splits a `<YEAR>.<MONTH>` version string into its year and month parts.
///
/// Returns `None` unless the string has exactly two `.`-separated components,
/// both non-empty and made up solely of ASCII digits. The digits are not
/// converted to numbers, so `"2024.01"` yields `("2024", "01")`.
fn parse_version(version: &str) -> Option<(&str, &str)> {
    let (year, month) = version.split_once('.')?;

    // A third component would leave a '.' in `month`. That fails the
    // digits-only check, so "exactly two components" needs no separate test.
    let is_number = |part: &str| !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit());

    (is_number(year) && is_number(month)).then_some((year, month))
}
129
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use axum::{Router, body::Body, middleware, routing::get};
    use expect_test::expect;
    use http_body_util::BodyExt;
    use tokio_util::sync::CancellationToken;
    use tower::ServiceExt;

    use super::*;
    use crate::{
        config::{ConnectionConfig, ServiceConfig, Version},
        metrics::Metrics,
        server::builder::AppState,
    };

    /// Metrics registered against the default registry of a newly started
    /// Prometheus server.
    ///
    /// Every test builds its own service and therefore its own server. Binding
    /// to port 0 on the loopback interface lets the OS pick a free port. A
    /// fixed port would make concurrently running tests, or any other process
    /// already using that port, contend for the same address.
    fn metrics() -> Metrics {
        let binding_address: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let registry = iota_metrics::start_prometheus_server(binding_address).default_registry();
        Metrics::new(&registry)
    }

    /// Router with versioned and unversioned routes, wrapped in both version
    /// middlewares configured with `Version::for_testing()`.
    fn service() -> Router {
        let state = AppState::new(
            ConnectionConfig::default(),
            ServiceConfig::default(),
            metrics(),
            CancellationToken::new(),
            Version::for_testing(),
        );

        Router::new()
            .route("/", get(|| async { "Hello, Versioning!" }))
            .route("/:version", get(|| async { "Hello, Versioning!" }))
            .route("/graphql", get(|| async { "Hello, Versioning!" }))
            .route("/graphql/:version", get(|| async { "Hello, Versioning!" }))
            .layer(middleware::from_fn_with_state(
                state.version,
                check_version_middleware,
            ))
            .layer(middleware::from_fn_with_state(
                state.version,
                set_version_middleware,
            ))
    }

    /// `GET /graphql` with no version in the path.
    fn graphql_request() -> Request<Body> {
        Request::builder()
            .uri("/graphql")
            .body(Body::empty())
            .unwrap()
    }

    /// `GET /` with no version in the path.
    fn plain_request() -> Request<Body> {
        Request::builder().uri("/").body(Body::empty()).unwrap()
    }

    /// `GET /graphql/{version}`, or `GET /` when `version` is empty.
    fn version_request(version: &str) -> Request<Body> {
        if version.is_empty() {
            return plain_request();
        }
        Request::builder()
            .uri(format!("/graphql/{version}"))
            .body(Body::empty())
            .unwrap()
    }

    /// Reads the response body as JSON and pretty-prints it for snapshotting.
    async fn response_body(response: Response) -> String {
        let bytes = response.into_body().collect().await.unwrap();
        let value: serde_json::Value = serde_json::from_slice(bytes.to_bytes().as_ref()).unwrap();
        serde_json::to_string_pretty(&value).unwrap()
    }

    #[tokio::test]
    async fn successful() {
        let version = Version::for_testing();
        let major_version = format!("{}.{}", version.year, version.month);
        let service = service();
        let response = service
            .oneshot(version_request(&major_version))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(&VERSION_HEADER),
            Some(&HeaderValue::from_static(version.full))
        );
    }

    #[tokio::test]
    async fn default_graphql_route() {
        let version = Version::for_testing();
        let service = service();
        let response = service.oneshot(graphql_request()).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(&VERSION_HEADER),
            Some(&HeaderValue::from_static(version.full))
        );
    }

    #[tokio::test]
    async fn named_version() {
        let version = Version::for_testing();
        let service = service();
        for named_version in NAMED_VERSIONS {
            let response = service
                .clone()
                .oneshot(version_request(named_version))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
            assert_eq!(
                response.headers().get(&VERSION_HEADER),
                Some(&HeaderValue::from_static(version.full))
            );
        }
    }

    #[tokio::test]
    async fn default_version() {
        let version = Version::for_testing();
        let service = service();
        let response = service.oneshot(plain_request()).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(&VERSION_HEADER),
            Some(&HeaderValue::from_static(version.full))
        );
    }

    #[tokio::test]
    async fn wrong_path() {
        // An empty version falls back to the plain `/` route (see `version_request`).
        let version = Version::for_testing();
        let service = service();
        let response = service.oneshot(version_request("")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(&VERSION_HEADER),
            Some(&HeaderValue::from_static(version.full))
        );
    }

    #[tokio::test]
    async fn incompatible_version() {
        let version = Version::for_testing();
        let service = service();
        let response = service.oneshot(version_request("0.0")).await.unwrap();

        assert_eq!(response.status(), StatusCode::MISDIRECTED_REQUEST);
        assert_eq!(
            response.headers().get(&VERSION_HEADER),
            Some(&HeaderValue::from_static(version.full))
        );

        let expect = expect![[r#"
            {
              "data": null,
              "errors": [
                {
                  "message": "Version '0.0' not supported.",
                  "extensions": {
                    "code": "INTERNAL_SERVER_ERROR"
                  }
                }
              ]
            }"#]];
        expect.assert_eq(&response_body(response).await);
    }

    #[tokio::test]
    async fn not_a_version() {
        let version = Version::for_testing();
        let service = service();
        let response = service
            .oneshot(version_request("not-a-version"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            response.headers().get(&VERSION_HEADER),
            Some(&HeaderValue::from_static(version.full))
        );

        let expect = expect![[r#"
            {
              "data": null,
              "errors": [
                {
                  "message": "Failed to parse version path: not-a-version. Expected either a `beta | legacy | stable` version or <YEAR>.<MONTH> version.",
                  "extensions": {
                    "code": "BAD_REQUEST"
                  }
                }
              ]
            }"#]];
        expect.assert_eq(&response_body(response).await);
    }
}