iota_proxy/
middleware.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::sync::Arc;
6
7use axum::{
8    body::{Body, Bytes},
9    extract::{Extension, FromRequest},
10    http::{Request, StatusCode},
11    middleware::Next,
12    response::Response,
13};
14use axum_extra::{
15    headers::{ContentLength, ContentType},
16    typed_header::TypedHeader,
17};
18use bytes::Buf;
19use hyper::header::CONTENT_ENCODING;
20use iota_tls::TlsConnectionInfo;
21use once_cell::sync::Lazy;
22use prometheus::{CounterVec, proto::MetricFamily, register_counter_vec};
23use tracing::error;
24
25use crate::{consumer::ProtobufDecoder, peers::IotaNodeProvider};
26
27static MIDDLEWARE_OPS: Lazy<CounterVec> = Lazy::new(|| {
28    register_counter_vec!(
29        "middleware_operations",
30        "Operations counters and status for axum middleware.",
31        &["operation", "status"]
32    )
33    .unwrap()
34});
35
36static MIDDLEWARE_HEADERS: Lazy<CounterVec> = Lazy::new(|| {
37    register_counter_vec!(
38        "middleware_headers",
39        "Operations counters and status for axum middleware.",
40        &["header", "value"]
41    )
42    .unwrap()
43});
44
45/// we expect iota-node to send us an http header content-length encoding.
46pub async fn expect_content_length(
47    TypedHeader(content_length): TypedHeader<ContentLength>,
48    request: Request<Body>,
49    next: Next,
50) -> Result<Response, (StatusCode, &'static str)> {
51    MIDDLEWARE_HEADERS.with_label_values(&["content-length", &format!("{}", content_length.0)]);
52    Ok(next.run(request).await)
53}
54
55/// we expect iota-node to send us an http header content-type encoding.
56pub async fn expect_iota_proxy_header(
57    TypedHeader(content_type): TypedHeader<ContentType>,
58    request: Request<Body>,
59    next: Next,
60) -> Result<Response, (StatusCode, &'static str)> {
61    match format!("{content_type}").as_str() {
62        prometheus::PROTOBUF_FORMAT => Ok(next.run(request).await),
63        ct => {
64            error!("invalid content-type; {ct}");
65            MIDDLEWARE_OPS
66                .with_label_values(&["expect_iota_proxy_header", "invalid-content-type"])
67                .inc();
68            Err((StatusCode::BAD_REQUEST, "invalid content-type header"))
69        }
70    }
71}
72
73/// we expect that calling iota-nodes are known on the blockchain and we enforce
74/// their pub key tls creds here
75pub async fn expect_valid_public_key(
76    Extension(allower): Extension<Arc<IotaNodeProvider>>,
77    Extension(tls_connect_info): Extension<TlsConnectionInfo>,
78    mut request: Request<Body>,
79    next: Next,
80) -> Result<Response, (StatusCode, &'static str)> {
81    let Some(public_key) = tls_connect_info.public_key() else {
82        error!("unable to obtain public key from connecting client");
83        MIDDLEWARE_OPS
84            .with_label_values(&["expect_valid_public_key", "missing-public-key"])
85            .inc();
86        return Err((StatusCode::FORBIDDEN, "unknown clients are not allowed"));
87    };
88    let Some(peer) = allower.get(public_key) else {
89        error!("node with unknown pub key tried to connect {}", public_key);
90        MIDDLEWARE_OPS
91            .with_label_values(&[
92                "expect_valid_public_key",
93                "unknown-validator-connection-attempt",
94            ])
95            .inc();
96        return Err((StatusCode::FORBIDDEN, "unknown clients are not allowed"));
97    };
98    request.extensions_mut().insert(peer);
99    Ok(next.run(request).await)
100}
101
102// extractor that shows how to consume the request body upfront
103#[derive(Debug)]
104pub struct LenDelimProtobuf(pub Vec<MetricFamily>);
105
106impl<S> FromRequest<S> for LenDelimProtobuf
107where
108    S: Send + Sync,
109{
110    type Rejection = (StatusCode, String);
111
112    async fn from_request(
113        req: Request<axum::body::Body>,
114        state: &S,
115    ) -> Result<Self, Self::Rejection> {
116        let should_be_snappy = req
117            .headers()
118            .get(CONTENT_ENCODING)
119            .map(|v| v.as_bytes() == b"snappy")
120            .unwrap_or(false);
121
122        let body = Bytes::from_request(req, state).await.map_err(|e| {
123            let msg = format!("error extracting bytes; {e}");
124            error!(msg);
125            MIDDLEWARE_OPS
126                .with_label_values(&["LenDelimProtobuf_from_request", "unable-to-extract-bytes"])
127                .inc();
128            (e.status(), msg)
129        })?;
130
131        let intermediate = if should_be_snappy {
132            let mut s = snap::raw::Decoder::new();
133            let decompressed = s.decompress_vec(&body).map_err(|e| {
134                let msg = format!("unable to decode snappy encoded protobufs; {e}");
135                error!(msg);
136                MIDDLEWARE_OPS
137                    .with_label_values(&[
138                        "LenDelimProtobuf_decompress_vec",
139                        "unable-to-decode-snappy",
140                    ])
141                    .inc();
142                (StatusCode::BAD_REQUEST, msg)
143            })?;
144            Bytes::from(decompressed).reader()
145        } else {
146            body.reader()
147        };
148
149        let mut decoder = ProtobufDecoder::new(intermediate);
150        let decoded = decoder.parse::<MetricFamily>().map_err(|e| {
151            let msg = format!("unable to decode len delimited protobufs; {e}");
152            error!(msg);
153            MIDDLEWARE_OPS
154                .with_label_values(&[
155                    "LenDelimProtobuf_from_request",
156                    "unable-to-decode-protobufs",
157                ])
158                .inc();
159            (StatusCode::BAD_REQUEST, msg)
160        })?;
161        Ok(Self(decoded))
162    }
163}