iota_rest_api/
accept.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use axum::http::{self, HeaderMap, header};
6use mime::Mime;
7
// TODO: BCS payloads carry no type information; consider signalling the
// concrete payload type in the media type itself, e.g.
// "application/x.iota.<type>+bcs".
/// The `application/bcs` media type, used to request BCS-encoded responses.
pub const APPLICATION_BCS: &'static str = "application/bcs";
12
13/// `Accept` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.2)
14#[derive(Debug, Clone)]
15pub struct Accept(pub Vec<Mime>);
16
17fn parse_accept(headers: &HeaderMap) -> Vec<Mime> {
18    let mut items = headers
19        .get_all(header::ACCEPT)
20        .iter()
21        .filter_map(|hval| hval.to_str().ok())
22        .flat_map(|s| s.split(',').map(str::trim))
23        .filter_map(|item| {
24            let mime: Mime = item.parse().ok()?;
25            let q = mime
26                .get_param("q")
27                .and_then(|value| Some((value.as_str().parse::<f32>().ok()? * 1000.0) as i32))
28                .unwrap_or(1000);
29            Some((mime, q))
30        })
31        .collect::<Vec<_>>();
32    items.sort_by(|(_, qa), (_, qb)| qb.cmp(qa));
33    items.into_iter().map(|(mime, _)| mime).collect()
34}
35
36impl<S> axum::extract::FromRequestParts<S> for Accept
37where
38    S: Send + Sync,
39{
40    type Rejection = std::convert::Infallible;
41
42    async fn from_request_parts(
43        parts: &mut http::request::Parts,
44        _: &S,
45    ) -> Result<Self, Self::Rejection> {
46        Ok(Self(parse_accept(&parts.headers)))
47    }
48}
49
/// Response encoding negotiated from the request's `Accept` header.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AcceptFormat {
    /// `application/json`; also the fallback when no supported type is listed.
    Json,
    /// `application/bcs` (see `APPLICATION_BCS`).
    Bcs,
}
55
56impl<S> axum::extract::FromRequestParts<S> for AcceptFormat
57where
58    S: Send + Sync,
59{
60    type Rejection = std::convert::Infallible;
61
62    async fn from_request_parts(
63        parts: &mut http::request::Parts,
64        s: &S,
65    ) -> Result<Self, Self::Rejection> {
66        let accept = Accept::from_request_parts(parts, s).await?;
67
68        for mime in accept.0 {
69            let essence = mime.essence_str();
70
71            if essence == mime::APPLICATION_JSON.essence_str() {
72                return Ok(Self::Json);
73            } else if essence == APPLICATION_BCS {
74                return Ok(Self::Bcs);
75            }
76        }
77
78        Ok(Self::Json)
79    }
80}
81
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use axum::{extract::FromRequest, http::Request};
    use http::header;

    use super::*;

    /// Builds a body-less request whose `Accept` header is `accept`.
    fn request_accepting(accept: &str) -> Request<axum::body::Body> {
        Request::builder()
            .header(header::ACCEPT, accept)
            .body(axum::body::Body::empty())
            .unwrap()
    }

    /// Parses each string into a `Mime`, panicking on malformed input.
    fn mimes(raw: &[&str]) -> Vec<Mime> {
        raw.iter().map(|s| Mime::from_str(s).unwrap()).collect()
    }

    #[tokio::test]
    async fn test_accept() {
        let req = request_accepting(
            "text/html, text/yaml;q=0.5, application/xhtml+xml, application/xml;q=0.9, */*;q=0.1",
        );
        let Accept(parsed) = Accept::from_request(req, &()).await.unwrap();
        // Highest quality first; equal-quality entries keep their header order.
        assert_eq!(
            parsed,
            mimes(&[
                "text/html",
                "application/xhtml+xml",
                "application/xml;q=0.9",
                "text/yaml;q=0.5",
                "*/*;q=0.1",
            ])
        );
    }

    #[tokio::test]
    async fn test_accept_format() {
        // The wildcard matches neither supported type, so the later
        // `application/bcs` entry decides the format.
        let format = AcceptFormat::from_request(request_accepting("*/*, application/bcs"), &())
            .await
            .unwrap();
        assert_eq!(format, AcceptFormat::Bcs);

        // Nothing supported is listed, so the JSON fallback applies.
        let format = AcceptFormat::from_request(request_accepting("*/*"), &())
            .await
            .unwrap();
        assert_eq!(format, AcceptFormat::Json);
    }
}