1use std::sync::Arc;
6
7use axum::{
8 body::{Body, Bytes},
9 extract::{Extension, FromRequest},
10 http::{Request, StatusCode},
11 middleware::Next,
12 response::Response,
13};
14use axum_extra::{
15 headers::{ContentLength, ContentType},
16 typed_header::TypedHeader,
17};
18use bytes::Buf;
19use hyper::header::CONTENT_ENCODING;
20use iota_tls::TlsConnectionInfo;
21use once_cell::sync::Lazy;
22use prometheus::{CounterVec, proto::MetricFamily, register_counter_vec};
23use tracing::error;
24
25use crate::{consumer::ProtobufDecoder, peers::IotaNodeProvider};
26
27static MIDDLEWARE_OPS: Lazy<CounterVec> = Lazy::new(|| {
28 register_counter_vec!(
29 "middleware_operations",
30 "Operations counters and status for axum middleware.",
31 &["operation", "status"]
32 )
33 .unwrap()
34});
35
36static MIDDLEWARE_HEADERS: Lazy<CounterVec> = Lazy::new(|| {
37 register_counter_vec!(
38 "middleware_headers",
39 "Operations counters and status for axum middleware.",
40 &["header", "value"]
41 )
42 .unwrap()
43});
44
45pub async fn expect_content_length(
47 TypedHeader(content_length): TypedHeader<ContentLength>,
48 request: Request<Body>,
49 next: Next,
50) -> Result<Response, (StatusCode, &'static str)> {
51 MIDDLEWARE_HEADERS.with_label_values(&["content-length", &format!("{}", content_length.0)]);
52 Ok(next.run(request).await)
53}
54
55pub async fn expect_iota_proxy_header(
57 TypedHeader(content_type): TypedHeader<ContentType>,
58 request: Request<Body>,
59 next: Next,
60) -> Result<Response, (StatusCode, &'static str)> {
61 match format!("{content_type}").as_str() {
62 prometheus::PROTOBUF_FORMAT => Ok(next.run(request).await),
63 ct => {
64 error!("invalid content-type; {ct}");
65 MIDDLEWARE_OPS
66 .with_label_values(&["expect_iota_proxy_header", "invalid-content-type"])
67 .inc();
68 Err((StatusCode::BAD_REQUEST, "invalid content-type header"))
69 }
70 }
71}
72
73pub async fn expect_valid_public_key(
76 Extension(allower): Extension<Arc<IotaNodeProvider>>,
77 Extension(tls_connect_info): Extension<TlsConnectionInfo>,
78 mut request: Request<Body>,
79 next: Next,
80) -> Result<Response, (StatusCode, &'static str)> {
81 let Some(public_key) = tls_connect_info.public_key() else {
82 error!("unable to obtain public key from connecting client");
83 MIDDLEWARE_OPS
84 .with_label_values(&["expect_valid_public_key", "missing-public-key"])
85 .inc();
86 return Err((StatusCode::FORBIDDEN, "unknown clients are not allowed"));
87 };
88 let Some(peer) = allower.get(public_key) else {
89 error!("node with unknown pub key tried to connect {}", public_key);
90 MIDDLEWARE_OPS
91 .with_label_values(&[
92 "expect_valid_public_key",
93 "unknown-validator-connection-attempt",
94 ])
95 .inc();
96 return Err((StatusCode::FORBIDDEN, "unknown clients are not allowed"));
97 };
98 request.extensions_mut().insert(peer);
99 Ok(next.run(request).await)
100}
101
102#[derive(Debug)]
104pub struct LenDelimProtobuf(pub Vec<MetricFamily>);
105
106impl<S> FromRequest<S> for LenDelimProtobuf
107where
108 S: Send + Sync,
109{
110 type Rejection = (StatusCode, String);
111
112 async fn from_request(
113 req: Request<axum::body::Body>,
114 state: &S,
115 ) -> Result<Self, Self::Rejection> {
116 let should_be_snappy = req
117 .headers()
118 .get(CONTENT_ENCODING)
119 .map(|v| v.as_bytes() == b"snappy")
120 .unwrap_or(false);
121
122 let body = Bytes::from_request(req, state).await.map_err(|e| {
123 let msg = format!("error extracting bytes; {e}");
124 error!(msg);
125 MIDDLEWARE_OPS
126 .with_label_values(&["LenDelimProtobuf_from_request", "unable-to-extract-bytes"])
127 .inc();
128 (e.status(), msg)
129 })?;
130
131 let intermediate = if should_be_snappy {
132 let mut s = snap::raw::Decoder::new();
133 let decompressed = s.decompress_vec(&body).map_err(|e| {
134 let msg = format!("unable to decode snappy encoded protobufs; {e}");
135 error!(msg);
136 MIDDLEWARE_OPS
137 .with_label_values(&[
138 "LenDelimProtobuf_decompress_vec",
139 "unable-to-decode-snappy",
140 ])
141 .inc();
142 (StatusCode::BAD_REQUEST, msg)
143 })?;
144 Bytes::from(decompressed).reader()
145 } else {
146 body.reader()
147 };
148
149 let mut decoder = ProtobufDecoder::new(intermediate);
150 let decoded = decoder.parse::<MetricFamily>().map_err(|e| {
151 let msg = format!("unable to decode len delimited protobufs; {e}");
152 error!(msg);
153 MIDDLEWARE_OPS
154 .with_label_values(&[
155 "LenDelimProtobuf_from_request",
156 "unable-to-decode-protobufs",
157 ])
158 .inc();
159 (StatusCode::BAD_REQUEST, msg)
160 })?;
161 Ok(Self(decoded))
162 }
163}