iota_rest_api/
accept.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use axum::http::{self, HeaderMap, header};
6use mime::Mime;
7
// TODO: look into signaling the expected type through the media type itself,
// since BCS doesn't include type information, e.g.
// "application/x.iota.<type>+bcs"
/// Media type a client lists in its `Accept` header to select
/// [`AcceptFormat::Bcs`].
pub const APPLICATION_BCS: &str = "application/bcs";
12
/// `Accept` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.2)
///
/// Holds the media types parsed from the request's `Accept` header(s). When
/// produced by the `FromRequestParts` extractor, the list is ordered by
/// descending `q` weight, and entries with the same weight keep the order in
/// which they appeared in the request.
#[derive(Debug, Clone)]
pub struct Accept(pub Vec<Mime>);
16
17fn parse_accept(headers: &HeaderMap) -> Vec<Mime> {
18    let mut items = headers
19        .get_all(header::ACCEPT)
20        .iter()
21        .filter_map(|hval| hval.to_str().ok())
22        .flat_map(|s| s.split(',').map(str::trim))
23        .filter_map(|item| {
24            let mime: Mime = item.parse().ok()?;
25            let q = mime
26                .get_param("q")
27                .and_then(|value| Some((value.as_str().parse::<f32>().ok()? * 1000.0) as i32))
28                .unwrap_or(1000);
29            Some((mime, q))
30        })
31        .collect::<Vec<_>>();
32    items.sort_by(|(_, qa), (_, qb)| qb.cmp(qa));
33    items.into_iter().map(|(mime, _)| mime).collect()
34}
35
36#[axum::async_trait]
37impl<S> axum::extract::FromRequestParts<S> for Accept
38where
39    S: Send + Sync,
40{
41    type Rejection = std::convert::Infallible;
42
43    async fn from_request_parts(
44        parts: &mut http::request::Parts,
45        _: &S,
46    ) -> Result<Self, Self::Rejection> {
47        Ok(Self(parse_accept(&parts.headers)))
48    }
49}
50
/// Response encoding negotiated from the request's `Accept` header.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum AcceptFormat {
    /// JSON; also the fallback when no listed media type is recognized.
    Json,
    /// BCS, selected by the `application/bcs` media type.
    Bcs,
}
56
57#[axum::async_trait]
58impl<S> axum::extract::FromRequestParts<S> for AcceptFormat
59where
60    S: Send + Sync,
61{
62    type Rejection = std::convert::Infallible;
63
64    async fn from_request_parts(
65        parts: &mut http::request::Parts,
66        s: &S,
67    ) -> Result<Self, Self::Rejection> {
68        let accept = Accept::from_request_parts(parts, s).await?;
69
70        for mime in accept.0 {
71            let essence = mime.essence_str();
72
73            if essence == mime::APPLICATION_JSON.essence_str() {
74                return Ok(Self::Json);
75            } else if essence == APPLICATION_BCS {
76                return Ok(Self::Bcs);
77            }
78        }
79
80        Ok(Self::Json)
81    }
82}
83
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use axum::{extract::FromRequest, http::Request};
    use http::header;

    use super::*;

    /// Builds an empty-bodied request carrying `value` as its `Accept` header.
    fn request_with_accept(value: &str) -> Request<axum::body::Body> {
        Request::builder()
            .header(header::ACCEPT, value)
            .body(axum::body::Body::empty())
            .unwrap()
    }

    /// Entries come back ordered by descending `q`, ties in header order.
    #[tokio::test]
    async fn test_accept() {
        let request = request_with_accept(
            "text/html, text/yaml;q=0.5, application/xhtml+xml, application/xml;q=0.9, */*;q=0.1",
        );
        let Accept(parsed) = Accept::from_request(request, &()).await.unwrap();

        let expected: Vec<Mime> = [
            "text/html",
            "application/xhtml+xml",
            "application/xml;q=0.9",
            "text/yaml;q=0.5",
            "*/*;q=0.1",
        ]
        .iter()
        .map(|raw| Mime::from_str(raw).unwrap())
        .collect();

        assert_eq!(parsed, expected);
    }

    /// A recognized media type wins over a wildcard; a lone wildcard falls
    /// back to JSON.
    #[tokio::test]
    async fn test_accept_format() {
        let with_bcs = request_with_accept("*/*, application/bcs");
        let format = AcceptFormat::from_request(with_bcs, &()).await.unwrap();
        assert_eq!(format, AcceptFormat::Bcs);

        let wildcard_only = request_with_accept("*/*");
        let format = AcceptFormat::from_request(wildcard_only, &()).await.unwrap();
        assert_eq!(format, AcceptFormat::Json);
    }
}