iota_proxy/
middleware.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::sync::Arc;
6
7use axum::{
8    async_trait,
9    body::{Body, Bytes},
10    extract::{Extension, FromRequest},
11    http::{Request, StatusCode},
12    middleware::Next,
13    response::Response,
14};
15use axum_extra::{
16    headers::{ContentLength, ContentType},
17    typed_header::TypedHeader,
18};
19use bytes::Buf;
20use hyper::header::CONTENT_ENCODING;
21use iota_tls::TlsConnectionInfo;
22use once_cell::sync::Lazy;
23use prometheus::{CounterVec, proto::MetricFamily, register_counter_vec};
24use tracing::error;
25
26use crate::{consumer::ProtobufDecoder, peers::IotaNodeProvider};
27
28static MIDDLEWARE_OPS: Lazy<CounterVec> = Lazy::new(|| {
29    register_counter_vec!(
30        "middleware_operations",
31        "Operations counters and status for axum middleware.",
32        &["operation", "status"]
33    )
34    .unwrap()
35});
36
37static MIDDLEWARE_HEADERS: Lazy<CounterVec> = Lazy::new(|| {
38    register_counter_vec!(
39        "middleware_headers",
40        "Operations counters and status for axum middleware.",
41        &["header", "value"]
42    )
43    .unwrap()
44});
45
46/// we expect iota-node to send us an http header content-length encoding.
47pub async fn expect_content_length(
48    TypedHeader(content_length): TypedHeader<ContentLength>,
49    request: Request<Body>,
50    next: Next,
51) -> Result<Response, (StatusCode, &'static str)> {
52    MIDDLEWARE_HEADERS.with_label_values(&["content-length", &format!("{}", content_length.0)]);
53    Ok(next.run(request).await)
54}
55
56/// we expect iota-node to send us an http header content-type encoding.
57pub async fn expect_iota_proxy_header(
58    TypedHeader(content_type): TypedHeader<ContentType>,
59    request: Request<Body>,
60    next: Next,
61) -> Result<Response, (StatusCode, &'static str)> {
62    match format!("{content_type}").as_str() {
63        prometheus::PROTOBUF_FORMAT => Ok(next.run(request).await),
64        ct => {
65            error!("invalid content-type; {ct}");
66            MIDDLEWARE_OPS
67                .with_label_values(&["expect_iota_proxy_header", "invalid-content-type"])
68                .inc();
69            Err((StatusCode::BAD_REQUEST, "invalid content-type header"))
70        }
71    }
72}
73
74/// we expect that calling iota-nodes are known on the blockchain and we enforce
75/// their pub key tls creds here
76pub async fn expect_valid_public_key(
77    Extension(allower): Extension<Arc<IotaNodeProvider>>,
78    Extension(tls_connect_info): Extension<TlsConnectionInfo>,
79    mut request: Request<Body>,
80    next: Next,
81) -> Result<Response, (StatusCode, &'static str)> {
82    let Some(public_key) = tls_connect_info.public_key() else {
83        error!("unable to obtain public key from connecting client");
84        MIDDLEWARE_OPS
85            .with_label_values(&["expect_valid_public_key", "missing-public-key"])
86            .inc();
87        return Err((StatusCode::FORBIDDEN, "unknown clients are not allowed"));
88    };
89    let Some(peer) = allower.get(public_key) else {
90        error!("node with unknown pub key tried to connect {}", public_key);
91        MIDDLEWARE_OPS
92            .with_label_values(&[
93                "expect_valid_public_key",
94                "unknown-validator-connection-attempt",
95            ])
96            .inc();
97        return Err((StatusCode::FORBIDDEN, "unknown clients are not allowed"));
98    };
99    request.extensions_mut().insert(peer);
100    Ok(next.run(request).await)
101}
102
103// extractor that shows how to consume the request body upfront
104#[derive(Debug)]
105pub struct LenDelimProtobuf(pub Vec<MetricFamily>);
106
107#[async_trait]
108impl<S> FromRequest<S> for LenDelimProtobuf
109where
110    S: Send + Sync,
111{
112    type Rejection = (StatusCode, String);
113
114    async fn from_request(
115        req: Request<axum::body::Body>,
116        state: &S,
117    ) -> Result<Self, Self::Rejection> {
118        let should_be_snappy = req
119            .headers()
120            .get(CONTENT_ENCODING)
121            .map(|v| v.as_bytes() == b"snappy")
122            .unwrap_or(false);
123
124        let body = Bytes::from_request(req, state).await.map_err(|e| {
125            let msg = format!("error extracting bytes; {e}");
126            error!(msg);
127            MIDDLEWARE_OPS
128                .with_label_values(&["LenDelimProtobuf_from_request", "unable-to-extract-bytes"])
129                .inc();
130            (e.status(), msg)
131        })?;
132
133        let intermediate = if should_be_snappy {
134            let mut s = snap::raw::Decoder::new();
135            let decompressed = s.decompress_vec(&body).map_err(|e| {
136                let msg = format!("unable to decode snappy encoded protobufs; {e}");
137                error!(msg);
138                MIDDLEWARE_OPS
139                    .with_label_values(&[
140                        "LenDelimProtobuf_decompress_vec",
141                        "unable-to-decode-snappy",
142                    ])
143                    .inc();
144                (StatusCode::BAD_REQUEST, msg)
145            })?;
146            Bytes::from(decompressed).reader()
147        } else {
148            body.reader()
149        };
150
151        let mut decoder = ProtobufDecoder::new(intermediate);
152        let decoded = decoder.parse::<MetricFamily>().map_err(|e| {
153            let msg = format!("unable to decode len delimited protobufs; {e}");
154            error!(msg);
155            MIDDLEWARE_OPS
156                .with_label_values(&[
157                    "LenDelimProtobuf_from_request",
158                    "unable-to-decode-protobufs",
159                ])
160                .inc();
161            (StatusCode::BAD_REQUEST, msg)
162        })?;
163        Ok(Self(decoded))
164    }
165}