1use std::sync::Arc;
6
7use axum::{
8 async_trait,
9 body::{Body, Bytes},
10 extract::{Extension, FromRequest},
11 http::{Request, StatusCode},
12 middleware::Next,
13 response::Response,
14};
15use axum_extra::{
16 headers::{ContentLength, ContentType},
17 typed_header::TypedHeader,
18};
19use bytes::Buf;
20use hyper::header::CONTENT_ENCODING;
21use iota_tls::TlsConnectionInfo;
22use once_cell::sync::Lazy;
23use prometheus::{CounterVec, proto::MetricFamily, register_counter_vec};
24use tracing::error;
25
26use crate::{consumer::ProtobufDecoder, peers::IotaNodeProvider};
27
28static MIDDLEWARE_OPS: Lazy<CounterVec> = Lazy::new(|| {
29 register_counter_vec!(
30 "middleware_operations",
31 "Operations counters and status for axum middleware.",
32 &["operation", "status"]
33 )
34 .unwrap()
35});
36
37static MIDDLEWARE_HEADERS: Lazy<CounterVec> = Lazy::new(|| {
38 register_counter_vec!(
39 "middleware_headers",
40 "Operations counters and status for axum middleware.",
41 &["header", "value"]
42 )
43 .unwrap()
44});
45
46pub async fn expect_content_length(
48 TypedHeader(content_length): TypedHeader<ContentLength>,
49 request: Request<Body>,
50 next: Next,
51) -> Result<Response, (StatusCode, &'static str)> {
52 MIDDLEWARE_HEADERS.with_label_values(&["content-length", &format!("{}", content_length.0)]);
53 Ok(next.run(request).await)
54}
55
56pub async fn expect_iota_proxy_header(
58 TypedHeader(content_type): TypedHeader<ContentType>,
59 request: Request<Body>,
60 next: Next,
61) -> Result<Response, (StatusCode, &'static str)> {
62 match format!("{content_type}").as_str() {
63 prometheus::PROTOBUF_FORMAT => Ok(next.run(request).await),
64 ct => {
65 error!("invalid content-type; {ct}");
66 MIDDLEWARE_OPS
67 .with_label_values(&["expect_iota_proxy_header", "invalid-content-type"])
68 .inc();
69 Err((StatusCode::BAD_REQUEST, "invalid content-type header"))
70 }
71 }
72}
73
74pub async fn expect_valid_public_key(
77 Extension(allower): Extension<Arc<IotaNodeProvider>>,
78 Extension(tls_connect_info): Extension<TlsConnectionInfo>,
79 mut request: Request<Body>,
80 next: Next,
81) -> Result<Response, (StatusCode, &'static str)> {
82 let Some(public_key) = tls_connect_info.public_key() else {
83 error!("unable to obtain public key from connecting client");
84 MIDDLEWARE_OPS
85 .with_label_values(&["expect_valid_public_key", "missing-public-key"])
86 .inc();
87 return Err((StatusCode::FORBIDDEN, "unknown clients are not allowed"));
88 };
89 let Some(peer) = allower.get(public_key) else {
90 error!("node with unknown pub key tried to connect {}", public_key);
91 MIDDLEWARE_OPS
92 .with_label_values(&[
93 "expect_valid_public_key",
94 "unknown-validator-connection-attempt",
95 ])
96 .inc();
97 return Err((StatusCode::FORBIDDEN, "unknown clients are not allowed"));
98 };
99 request.extensions_mut().insert(peer);
100 Ok(next.run(request).await)
101}
102
/// Extractor holding the `MetricFamily` messages decoded from a request body.
///
/// The `FromRequest` implementation in this file reads the body, snappy-decodes
/// it when `Content-Encoding` is `snappy`, and parses the result with
/// `ProtobufDecoder`; the decoded messages are exposed as the public field.
#[derive(Debug)]
pub struct LenDelimProtobuf(pub Vec<MetricFamily>);
106
107#[async_trait]
108impl<S> FromRequest<S> for LenDelimProtobuf
109where
110 S: Send + Sync,
111{
112 type Rejection = (StatusCode, String);
113
114 async fn from_request(
115 req: Request<axum::body::Body>,
116 state: &S,
117 ) -> Result<Self, Self::Rejection> {
118 let should_be_snappy = req
119 .headers()
120 .get(CONTENT_ENCODING)
121 .map(|v| v.as_bytes() == b"snappy")
122 .unwrap_or(false);
123
124 let body = Bytes::from_request(req, state).await.map_err(|e| {
125 let msg = format!("error extracting bytes; {e}");
126 error!(msg);
127 MIDDLEWARE_OPS
128 .with_label_values(&["LenDelimProtobuf_from_request", "unable-to-extract-bytes"])
129 .inc();
130 (e.status(), msg)
131 })?;
132
133 let intermediate = if should_be_snappy {
134 let mut s = snap::raw::Decoder::new();
135 let decompressed = s.decompress_vec(&body).map_err(|e| {
136 let msg = format!("unable to decode snappy encoded protobufs; {e}");
137 error!(msg);
138 MIDDLEWARE_OPS
139 .with_label_values(&[
140 "LenDelimProtobuf_decompress_vec",
141 "unable-to-decode-snappy",
142 ])
143 .inc();
144 (StatusCode::BAD_REQUEST, msg)
145 })?;
146 Bytes::from(decompressed).reader()
147 } else {
148 body.reader()
149 };
150
151 let mut decoder = ProtobufDecoder::new(intermediate);
152 let decoded = decoder.parse::<MetricFamily>().map_err(|e| {
153 let msg = format!("unable to decode len delimited protobufs; {e}");
154 error!(msg);
155 MIDDLEWARE_OPS
156 .with_label_values(&[
157 "LenDelimProtobuf_from_request",
158 "unable-to-decode-protobufs",
159 ])
160 .inc();
161 (StatusCode::BAD_REQUEST, msg)
162 })?;
163 Ok(Self(decoded))
164 }
165}