1use axum::{
6 extract::{Path, Request, State},
7 http::{HeaderName, HeaderValue, StatusCode},
8 middleware::Next,
9 response::{IntoResponse, Response},
10};
11use axum_extra::headers;
12
13use crate::{
14 config::Version,
15 error::{code, graphql_error_response},
16};
17
18pub(crate) static VERSION_HEADER: HeaderName = HeaderName::from_static("x-iota-rpc-version");
19
/// Typed representation of the `x-iota-rpc-version` request header.
///
/// The raw bytes of the header values are stored as-is; no validation of
/// their contents happens at this level.
#[expect(unused)]
pub(crate) struct IotaRpcVersion(
    /// Bytes of the first header value that was supplied.
    Vec<u8>,
    /// Bytes of every further header value, in the order they were supplied.
    Vec<Vec<u8>>,
);
/// Symbolic version names that are accepted in the version path segment
/// without being compared against the service's `<YEAR>.<MONTH>` version.
const NAMED_VERSIONS: [&str; 3] = [
    "beta",
    "legacy",
    "stable",
];
23
24impl headers::Header for IotaRpcVersion {
25 fn name() -> &'static HeaderName {
26 &VERSION_HEADER
27 }
28
29 fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
30 where
31 I: Iterator<Item = &'i HeaderValue>,
32 {
33 let mut values = values.map(|v| v.as_bytes().to_owned());
34 let Some(value) = values.next() else {
35 return Err(headers::Error::invalid());
37 };
38
39 Ok(IotaRpcVersion(value, values.collect()))
43 }
44
45 fn encode<E: Extend<HeaderValue>>(&self, _values: &mut E) {
46 unimplemented!()
47 }
48}
49
50pub(crate) async fn check_version_middleware(
56 version: Option<Path<String>>,
57 State(service_version): State<Version>,
58 request: Request,
59 next: Next,
60) -> Response {
61 let Some(Path(version)) = version else {
62 return next.run(request).await;
63 };
64
65 if NAMED_VERSIONS.contains(&version.as_str()) || version.is_empty() {
66 return next.run(request).await;
67 }
68 let Some((year, month)) = parse_version(&version) else {
69 return (
70 StatusCode::BAD_REQUEST,
71 graphql_error_response(
72 code::BAD_REQUEST,
73 format!(
74 "Failed to parse version path: {version}. Expected either a `beta | legacy | stable` \
75 version or <YEAR>.<MONTH> version.",
76 ),
77 ),
78 )
79 .into_response();
80 };
81
82 if year != service_version.year || month != service_version.month {
83 return (
84 StatusCode::MISDIRECTED_REQUEST,
85 graphql_error_response(
86 code::INTERNAL_SERVER_ERROR,
87 format!("Version '{version}' not supported."),
88 ),
89 )
90 .into_response();
91 }
92 next.run(request).await
93}
94
95pub(crate) async fn set_version_middleware(
98 State(version): State<Version>,
99 request: Request,
100 next: Next,
101) -> Response {
102 let mut response = next.run(request).await;
103 let headers = response.headers_mut();
104 headers.insert(
105 VERSION_HEADER.clone(),
106 HeaderValue::from_static(version.full),
107 );
108 response
109}
110
/// Splits a `<YEAR>.<MONTH>` version string into its two components.
///
/// Returns `Some((year, month))` only when the input consists of exactly two
/// `.`-separated parts, each non-empty and made up solely of ASCII digits.
/// Any other input yields `None`.
fn parse_version(version: &str) -> Option<(&str, &str)> {
    /// A component is valid when it is non-empty and all ASCII digits.
    fn is_numeric(component: &str) -> bool {
        !component.is_empty() && component.bytes().all(|b| b.is_ascii_digit())
    }

    // Splitting on the first `.` only: any additional component leaves a `.`
    // inside `month`, which the digit check below rejects.
    let (year, month) = version.split_once('.')?;

    (is_numeric(year) && is_numeric(month)).then_some((year, month))
}
129
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use axum::{Router, body::Body, middleware, routing::get};
    use expect_test::expect;
    use http_body_util::BodyExt;
    use iota_metrics;
    use tokio_util::sync::CancellationToken;
    use tower::ServiceExt;

    use super::*;
    use crate::{
        config::{ConnectionConfig, ServiceConfig, Version},
        metrics::Metrics,
        server::builder::AppState,
    };

    /// Builds a `Metrics` instance backed by the default registry of a
    /// prometheus server started on `0.0.0.0:9185`.
    fn metrics() -> Metrics {
        let binding_address: SocketAddr = "0.0.0.0:9185".parse().unwrap();
        let registry = iota_metrics::start_prometheus_server(binding_address).default_registry();
        Metrics::new(&registry)
    }

    /// Builds a router with plain and versioned routes, wrapped in both the
    /// version-checking and the version-setting middleware.
    fn service() -> Router {
        let version = Version::for_testing();
        let metrics = metrics();
        let cancellation_token = CancellationToken::new();
        let connection_config = ConnectionConfig::default();
        let service_config = ServiceConfig::default();
        let state = AppState::new(
            connection_config.clone(),
            service_config.clone(),
            metrics.clone(),
            cancellation_token.clone(),
            version,
        );

        Router::new()
            .route("/", get(|| async { "Hello, Versioning!" }))
            .route("/:version", get(|| async { "Hello, Versioning!" }))
            .route("/graphql", get(|| async { "Hello, Versioning!" }))
            .route("/graphql/:version", get(|| async { "Hello, Versioning!" }))
            .layer(middleware::from_fn_with_state(
                state.version,
                check_version_middleware,
            ))
            .layer(middleware::from_fn_with_state(
                state.version,
                set_version_middleware,
            ))
    }

    /// Request against the unversioned `/graphql` route.
    fn graphql_request() -> Request<Body> {
        Request::builder()
            .uri("/graphql")
            .body(Body::empty())
            .unwrap()
    }

    /// Request against the root route `/`.
    fn plain_request() -> Request<Body> {
        Request::builder().uri("/").body(Body::empty()).unwrap()
    }

    /// Request against `/graphql/<version>`; an empty `version` falls back to
    /// the root route.
    fn version_request(version: &str) -> Request<Body> {
        if version.is_empty() {
            return plain_request();
        }
        Request::builder()
            .uri(format!("/graphql/{}", version))
            .body(Body::empty())
            .unwrap()
    }

    /// Collects the response body and re-renders it as pretty-printed JSON.
    async fn response_body(response: Response) -> String {
        let bytes = response.into_body().collect().await.unwrap();
        let value: serde_json::Value = serde_json::from_slice(bytes.to_bytes().as_ref()).unwrap();
        serde_json::to_string_pretty(&value).unwrap()
    }

    #[tokio::test]
    async fn successful() {
        let version = Version::for_testing();
        let major_version = format!("{}.{}", version.year, version.month);
        let service = service();
        let response = service
            .oneshot(version_request(&major_version))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(&VERSION_HEADER),
            Some(&HeaderValue::from_static(version.full))
        );
    }

    #[tokio::test]
    async fn default_graphql_route() {
        let version = Version::for_testing();
        let service = service();
        let response = service.oneshot(graphql_request()).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(&VERSION_HEADER),
            Some(&HeaderValue::from_static(version.full))
        );
    }

    #[tokio::test]
    async fn named_version() {
        let version = Version::for_testing();
        let service = service();
        for named_version in NAMED_VERSIONS {
            let response = service
                .clone()
                .oneshot(version_request(named_version))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
            assert_eq!(
                response.headers().get(&VERSION_HEADER),
                Some(&HeaderValue::from_static(version.full))
            );
        }
    }

    #[tokio::test]
    async fn default_version() {
        let version = Version::for_testing();
        let service = service();
        let response = service.oneshot(plain_request()).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(&VERSION_HEADER),
            Some(&HeaderValue::from_static(version.full))
        );
    }

    #[tokio::test]
    async fn wrong_path() {
        let version = Version::for_testing();
        let service = service();
        let response = service.oneshot(version_request("")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(&VERSION_HEADER),
            Some(&HeaderValue::from_static(version.full))
        );
    }

    #[tokio::test]
    async fn incompatible_version() {
        let version = Version::for_testing();
        let service = service();
        let response = service.oneshot(version_request("0.0")).await.unwrap();

        assert_eq!(response.status(), StatusCode::MISDIRECTED_REQUEST);
        assert_eq!(
            response.headers().get(&VERSION_HEADER),
            Some(&HeaderValue::from_static(version.full))
        );

        let expect = expect![[r#"
            {
              "data": null,
              "errors": [
                {
                  "message": "Version '0.0' not supported.",
                  "extensions": {
                    "code": "INTERNAL_SERVER_ERROR"
                  }
                }
              ]
            }"#]];
        expect.assert_eq(&response_body(response).await);
    }

    #[tokio::test]
    async fn not_a_version() {
        let version = Version::for_testing();
        let service = service();
        let response = service
            .oneshot(version_request("not-a-version"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            response.headers().get(&VERSION_HEADER),
            Some(&HeaderValue::from_static(version.full))
        );

        let expect = expect![[r#"
            {
              "data": null,
              "errors": [
                {
                  "message": "Failed to parse version path: not-a-version. Expected either a `beta | legacy | stable` version or <YEAR>.<MONTH> version.",
                  "extensions": {
                    "code": "BAD_REQUEST"
                  }
                }
              ]
            }"#]];
        expect.assert_eq(&response_body(response).await);
    }
}