iota_json_rpc/
axum_router.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    net::{IpAddr, SocketAddr},
7    sync::Arc,
8    time::SystemTime,
9};
10
11use axum::{
12    extract::{ConnectInfo, Json, State},
13    response::Response,
14};
15use hyper::{HeaderMap, header::HeaderValue};
16use iota_core::traffic_controller::{
17    TrafficController, metrics::TrafficControllerMetrics, policies::TrafficTally,
18};
19use iota_json_rpc_api::{
20    CLIENT_TARGET_API_VERSION_HEADER, TRANSACTION_EXECUTION_CLIENT_ERROR_CODE,
21};
22use iota_types::traffic_control::{ClientIdSource, PolicyConfig, RemoteFirewallConfig, Weight};
23use jsonrpsee::{
24    BoundedSubscriptions, ConnectionId, Extensions, MethodCallback, MethodKind, MethodResponse,
25    core::server::{Methods, helpers::MethodSink},
26    server::RandomIntegerIdProvider,
27    types::{
28        ErrorObject, Id, InvalidRequest, Params, Request,
29        error::{BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, ErrorCode},
30    },
31};
32use serde_json::value::RawValue;
33use tracing::error;
34
35use crate::{
36    logger::{Logger, TransportProtocol},
37    routing_layer::RpcRouter,
38};
39
/// Response size limit handed to method callbacks and to the WebSocket
/// method sink: 2^31 (the same value as `2 << 30`), presumably in bytes.
pub const MAX_RESPONSE_SIZE: u32 = 1 << 31;
/// Error message sent when the traffic controller rejects a client.
const TOO_MANY_REQUESTS_MSG: &str = "Too many requests";
42
43#[derive(Clone, Debug)]
44pub struct JsonRpcService<L> {
45    logger: L,
46
47    id_provider: Arc<RandomIntegerIdProvider>,
48
49    /// Registered server methods.
50    methods: Methods,
51    extensions: Extensions,
52    rpc_router: RpcRouter,
53    traffic_controller: Option<Arc<TrafficController>>,
54    client_id_source: Option<ClientIdSource>,
55}
56
57impl<L> JsonRpcService<L> {
58    pub fn new(
59        methods: Methods,
60        rpc_router: RpcRouter,
61        logger: L,
62        remote_fw_config: Option<RemoteFirewallConfig>,
63        policy_config: Option<PolicyConfig>,
64        traffic_controller_metrics: TrafficControllerMetrics,
65        extensions: Extensions,
66    ) -> Self {
67        Self {
68            methods,
69            rpc_router,
70            logger,
71            extensions,
72            id_provider: Arc::new(RandomIntegerIdProvider),
73            traffic_controller: policy_config.clone().map(|policy| {
74                Arc::new(TrafficController::spawn(
75                    policy,
76                    traffic_controller_metrics,
77                    remote_fw_config,
78                ))
79            }),
80            client_id_source: policy_config.map(|policy| policy.client_id_source),
81        }
82    }
83}
84
85impl<L: Logger> JsonRpcService<L> {
86    fn call_data(&self) -> CallData<'_, L> {
87        CallData {
88            logger: &self.logger,
89            methods: &self.methods,
90            rpc_router: &self.rpc_router,
91            extensions: &self.extensions,
92            max_response_body_size: MAX_RESPONSE_SIZE,
93            request_start: self.logger.on_request(TransportProtocol::Http),
94        }
95    }
96
97    fn ws_call_data<'c, 'a: 'c, 'b: 'c>(
98        &'a self,
99        bounded_subscriptions: BoundedSubscriptions,
100        sink: &'b MethodSink,
101    ) -> ws::WsCallData<'c, L> {
102        ws::WsCallData {
103            logger: &self.logger,
104            methods: &self.methods,
105            extensions: &self.extensions,
106            max_response_body_size: MAX_RESPONSE_SIZE,
107            request_start: self.logger.on_request(TransportProtocol::Http),
108            bounded_subscriptions,
109            id_provider: &*self.id_provider,
110            sink,
111        }
112    }
113}
114
115/// Create a response body.
116fn from_template<S: Into<axum::body::Body>>(
117    status: hyper::StatusCode,
118    body: S,
119    content_type: &'static str,
120) -> Response {
121    Response::builder()
122        .status(status)
123        .header(
124            "content-type",
125            hyper::header::HeaderValue::from_static(content_type),
126        )
127        .body(body.into())
128        // Parsing `StatusCode` and `HeaderValue` is infalliable but
129        // parsing body content is not.
130        .expect("Unable to parse response body for type conversion")
131}
132
133/// Create a valid JSON response.
134pub(crate) fn ok_response(body: String) -> Response {
135    const JSON: &str = "application/json; charset=utf-8";
136    from_template(hyper::StatusCode::OK, body, JSON)
137}
138
139pub async fn json_rpc_handler<L: Logger>(
140    ConnectInfo(client_addr): ConnectInfo<SocketAddr>,
141    State(service): State<JsonRpcService<L>>,
142    headers: HeaderMap,
143    Json(raw_request): Json<Box<RawValue>>,
144) -> impl axum::response::IntoResponse {
145    let headers_clone = headers.clone();
146    // Get version from header.
147    let api_version = headers
148        .get(CLIENT_TARGET_API_VERSION_HEADER)
149        .and_then(|h| h.to_str().ok());
150    let response = process_raw_request(
151        &service,
152        api_version,
153        raw_request.get(),
154        client_addr,
155        headers_clone,
156    )
157    .await;
158
159    ok_response(response.into_result())
160}
161
162async fn process_raw_request<L: Logger>(
163    service: &JsonRpcService<L>,
164    api_version: Option<&str>,
165    raw_request: &str,
166    client_addr: SocketAddr,
167    headers: HeaderMap,
168) -> MethodResponse {
169    let client = match service.client_id_source {
170        Some(ClientIdSource::SocketAddr) => Some(client_addr.ip()),
171        Some(ClientIdSource::XForwardedFor(num_hops)) => {
172            let do_header_parse = |header: &HeaderValue| match header.to_str() {
173                Ok(header_val) => {
174                    let header_contents = header_val.split(',').map(str::trim).collect::<Vec<_>>();
175                    if num_hops == 0 {
176                        error!(
177                            "x-forwarded-for: 0 specified. x-forwarded-for contents: {:?}. Please assign nonzero value for \
178                                number of hops here, or use `socket-addr` client-id-source type if requests are not being proxied \
179                                to this node. Skipping traffic controller request handling.",
180                            header_contents,
181                        );
182                        return None;
183                    }
184                    let contents_len = header_contents.len();
185                    let Some(client_ip) = header_contents.get(contents_len - num_hops) else {
186                        error!(
187                            "x-forwarded-for header value of {:?} contains {} values, but {} hops were specified. \
188                                Expected {} values. Skipping traffic controller request handling.",
189                            header_contents,
190                            contents_len,
191                            num_hops,
192                            num_hops + 1,
193                        );
194                        return None;
195                    };
196                    client_ip.parse::<IpAddr>().ok().or_else(|| {
197                        client_ip.parse::<SocketAddr>().ok().map(|socket_addr| socket_addr.ip()).or_else(|| {
198                                error!(
199                                    "Failed to parse x-forwarded-for header value of {:?} to ip address or socket. \
200                                    Please ensure that your proxy is configured to resolve client domains to an \
201                                    IP address before writing header",
202                                    client_ip,
203                                );
204                                None
205                            })
206                        })
207                }
208                Err(e) => {
209                    error!("Invalid UTF-8 in x-forwarded-for header: {:?}", e);
210                    None
211                }
212            };
213            if let Some(header) = headers.get("x-forwarded-for") {
214                do_header_parse(header)
215            } else if let Some(header) = headers.get("X-Forwarded-For") {
216                do_header_parse(header)
217            } else {
218                error!(
219                    "x-forwarded-for header not present for request despite node configuring x-forwarded-for tracking type"
220                );
221                None
222            }
223        }
224        None => None,
225    };
226    if let Ok(request) = serde_json::from_str::<Request>(raw_request) {
227        // check if either IP is blocked, in which case return early
228        if let Some(traffic_controller) = &service.traffic_controller {
229            if let Err(blocked_response) =
230                handle_traffic_req(traffic_controller.clone(), &client).await
231            {
232                return blocked_response;
233            }
234        }
235
236        // handle response tallying
237        let response = process_request(request, api_version, service.call_data()).await;
238        if let Some(traffic_controller) = &service.traffic_controller {
239            handle_traffic_resp(traffic_controller.clone(), client, &response);
240        }
241
242        response
243    } else if let Ok(_batch) = serde_json::from_str::<Vec<&RawValue>>(raw_request) {
244        MethodResponse::error(
245            Id::Null,
246            ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, None),
247        )
248    } else {
249        let (id, code) = prepare_error(raw_request);
250        MethodResponse::error(id, ErrorObject::from(code))
251    }
252}
253
254async fn handle_traffic_req(
255    traffic_controller: Arc<TrafficController>,
256    client: &Option<IpAddr>,
257) -> Result<(), MethodResponse> {
258    if !traffic_controller.check(client, &None).await {
259        // Entity in blocklist
260        let err_obj =
261            ErrorObject::borrowed(ErrorCode::ServerIsBusy.code(), TOO_MANY_REQUESTS_MSG, None);
262        Err(MethodResponse::error(Id::Null, err_obj))
263    } else {
264        Ok(())
265    }
266}
267
268fn handle_traffic_resp(
269    traffic_controller: Arc<TrafficController>,
270    client: Option<IpAddr>,
271    response: &MethodResponse,
272) {
273    let error = response.as_error_code().map(ErrorCode::from);
274    traffic_controller.tally(TrafficTally {
275        direct: client,
276        through_fullnode: None,
277        error_weight: error.map(normalize).unwrap_or(Weight::zero()),
278        // For now, count everything as spam with equal weight
279        // on the rpc node side, including gas-charging endpoints
280        // such as `iota_executeTransactionBlock`, as this can enable
281        // node operators who wish to rate limit their transaction
282        // traffic and incentivize high volume clients to choose a
283        // suitable rpc provider (or run their own). Later we may want
284        // to provide a weight distribution based on the method being called.
285        spam_weight: Weight::one(),
286        timestamp: SystemTime::now(),
287    });
288}
289
290// TODO: refine error matching here
291fn normalize(err: ErrorCode) -> Weight {
292    match err {
293        ErrorCode::InvalidRequest | ErrorCode::InvalidParams => Weight::one(),
294        // e.g. invalid client signature
295        ErrorCode::ServerError(i) if i == TRANSACTION_EXECUTION_CLIENT_ERROR_CODE => Weight::one(),
296        _ => Weight::zero(),
297    }
298}
299
300async fn process_request<L: Logger>(
301    req: Request<'_>,
302    api_version: Option<&str>,
303    call: CallData<'_, L>,
304) -> MethodResponse {
305    let CallData {
306        methods,
307        rpc_router,
308        logger,
309        extensions,
310        max_response_body_size,
311        request_start,
312    } = call;
313    let conn_id = ConnectionId(0); // unused
314
315    let name = rpc_router.route(&req.method, api_version);
316    let params = Params::new(req.params.as_ref().map(|params| params.get()));
317
318    let id = req.id;
319
320    let response = match methods.method_with_name(name) {
321        None => {
322            logger.on_call(
323                name,
324                params.clone(),
325                MethodKind::NotFound,
326                TransportProtocol::Http,
327            );
328            MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound))
329        }
330        Some((name, method)) => match method {
331            MethodCallback::Sync(callback) => {
332                logger.on_call(
333                    name,
334                    params.clone(),
335                    MethodKind::MethodCall,
336                    TransportProtocol::Http,
337                );
338                (callback)(
339                    id,
340                    params,
341                    max_response_body_size as usize,
342                    extensions.clone(),
343                )
344            }
345            MethodCallback::Async(callback) => {
346                logger.on_call(
347                    name,
348                    params.clone(),
349                    MethodKind::MethodCall,
350                    TransportProtocol::Http,
351                );
352
353                let id = id.into_owned();
354                let params = params.into_owned();
355
356                (callback)(
357                    id,
358                    params,
359                    conn_id,
360                    max_response_body_size as usize,
361                    extensions.clone(),
362                )
363                .await
364            }
365            MethodCallback::Subscription(_) | MethodCallback::Unsubscription(_) => {
366                logger.on_call(
367                    name,
368                    params.clone(),
369                    MethodKind::NotFound,
370                    TransportProtocol::Http,
371                );
372                // Subscriptions not supported on HTTP
373                MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError))
374            }
375        },
376    };
377
378    logger.on_result(
379        name,
380        response.is_success(),
381        response.as_error_code(),
382        request_start,
383        TransportProtocol::Http,
384    );
385    response
386}
387
388/// Figure out if this is a sufficiently complete request that we can extract an
389/// [`Id`] out of, or just plain unparsable garbage.
390pub fn prepare_error(data: &str) -> (Id<'_>, ErrorCode) {
391    match serde_json::from_str::<InvalidRequest>(data) {
392        Ok(InvalidRequest { id }) => (id, ErrorCode::InvalidRequest),
393        Err(_) => (Id::Null, ErrorCode::ParseError),
394    }
395}
396
397#[derive(Debug, Clone)]
398pub(crate) struct CallData<'a, L: Logger> {
399    logger: &'a L,
400    methods: &'a Methods,
401    rpc_router: &'a RpcRouter,
402    extensions: &'a Extensions,
403    max_response_body_size: u32,
404    request_start: L::Instant,
405}
406
407pub mod ws {
408    use axum::{
409        extract::{
410            WebSocketUpgrade,
411            ws::{Message, WebSocket},
412        },
413        response::Response,
414    };
415    use jsonrpsee::{
416        SubscriptionState, core::server::helpers::MethodSink, server::IdProvider,
417        types::error::reject_too_many_subscriptions,
418    };
419    use tokio::sync::mpsc;
420
421    use super::*;
422
423    const MAX_WS_MESSAGE_BUFFER: usize = 100;
424
425    #[derive(Debug, Clone)]
426    pub(crate) struct WsCallData<'a, L: Logger> {
427        pub bounded_subscriptions: BoundedSubscriptions,
428        pub id_provider: &'a dyn IdProvider,
429        pub methods: &'a Methods,
430        pub extensions: &'a Extensions,
431        pub max_response_body_size: u32,
432        pub sink: &'a MethodSink,
433        pub logger: &'a L,
434        pub request_start: L::Instant,
435    }
436
437    // A WebSocket handler that echos any message it receives.
438    //
439    // This one we'll be integration testing so it can be written in the regular
440    // way.
441    pub async fn ws_json_rpc_upgrade<L: Logger>(
442        ws: WebSocketUpgrade,
443        State(service): State<JsonRpcService<L>>,
444    ) -> Response {
445        ws.on_upgrade(|ws| ws_json_rpc_handler(ws, service))
446    }
447
448    async fn ws_json_rpc_handler<L: Logger>(mut socket: WebSocket, service: JsonRpcService<L>) {
449        let (tx, mut rx) = mpsc::channel::<String>(MAX_WS_MESSAGE_BUFFER);
450        let sink = MethodSink::new_with_limit(tx, MAX_RESPONSE_SIZE);
451        let bounded_subscriptions = BoundedSubscriptions::new(100);
452
453        loop {
454            tokio::select! {
455                maybe_message = socket.recv() => {
456                    if let Some(Ok(message)) = maybe_message {
457                        if let Message::Text(msg) = message {
458                            let response =
459                                process_raw_request(&service, &msg, bounded_subscriptions.clone(), &sink).await;
460                            if let Some(response) = response {
461                                sink.send(response.into_result()).await.ok();
462                            }
463                        }
464                    } else {
465                        break;
466                    }
467                },
468                Some(response) = rx.recv() => {
469                    if socket.send(Message::Text(response)).await.is_err() {
470                        break;
471                    }
472                },
473            }
474        }
475    }
476
477    async fn process_raw_request<L: Logger>(
478        service: &JsonRpcService<L>,
479        raw_request: &str,
480        bounded_subscriptions: BoundedSubscriptions,
481        sink: &MethodSink,
482    ) -> Option<MethodResponse> {
483        if let Ok(request) = serde_json::from_str::<Request>(raw_request) {
484            process_request(request, service.ws_call_data(bounded_subscriptions, sink)).await
485        } else if let Ok(_batch) = serde_json::from_str::<Vec<&RawValue>>(raw_request) {
486            Some(MethodResponse::error(
487                Id::Null,
488                ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, None),
489            ))
490        } else {
491            let (id, code) = prepare_error(raw_request);
492            Some(MethodResponse::error(id, ErrorObject::from(code)))
493        }
494    }
495
496    async fn process_request<L: Logger>(
497        req: Request<'_>,
498        call: WsCallData<'_, L>,
499    ) -> Option<MethodResponse> {
500        let WsCallData {
501            methods,
502            logger,
503            extensions,
504            max_response_body_size,
505            request_start,
506            bounded_subscriptions,
507            id_provider,
508            sink,
509        } = call;
510        let conn_id = ConnectionId(0); // unused
511
512        let params = Params::new(req.params.as_ref().map(|params| params.get()));
513        let name = &req.method;
514        let id = req.id;
515
516        let response = match methods.method_with_name(name) {
517            None => {
518                logger.on_call(
519                    name,
520                    params.clone(),
521                    MethodKind::NotFound,
522                    TransportProtocol::Http,
523                );
524                Some(MethodResponse::error(
525                    id,
526                    ErrorObject::from(ErrorCode::MethodNotFound),
527                ))
528            }
529            Some((name, method)) => match method {
530                MethodCallback::Sync(callback) => {
531                    logger.on_call(
532                        name,
533                        params.clone(),
534                        MethodKind::MethodCall,
535                        TransportProtocol::Http,
536                    );
537                    tracing::info!("calling {name} sync");
538                    Some((callback)(
539                        id,
540                        params,
541                        max_response_body_size as usize,
542                        extensions.clone(),
543                    ))
544                }
545                MethodCallback::Async(callback) => {
546                    logger.on_call(
547                        name,
548                        params.clone(),
549                        MethodKind::MethodCall,
550                        TransportProtocol::Http,
551                    );
552
553                    let id = id.into_owned();
554                    let params = params.into_owned();
555
556                    tracing::info!("calling {name} async");
557                    Some(
558                        (callback)(
559                            id,
560                            params,
561                            conn_id,
562                            max_response_body_size as usize,
563                            extensions.clone(),
564                        )
565                        .await,
566                    )
567                }
568
569                MethodCallback::Subscription(callback) => {
570                    logger.on_call(
571                        name,
572                        params.clone(),
573                        MethodKind::Subscription,
574                        TransportProtocol::WebSocket,
575                    );
576                    if let Some(subscription_permit) = bounded_subscriptions.acquire() {
577                        let conn_state = SubscriptionState {
578                            conn_id,
579                            subscription_permit,
580                            id_provider,
581                        };
582                        (callback)(
583                            id.clone(),
584                            params,
585                            sink.clone(),
586                            conn_state,
587                            extensions.clone(),
588                        )
589                        .await;
590                        None
591                    } else {
592                        Some(MethodResponse::error(
593                            id,
594                            reject_too_many_subscriptions(bounded_subscriptions.max()),
595                        ))
596                    }
597                }
598
599                MethodCallback::Unsubscription(callback) => {
600                    logger.on_call(
601                        name,
602                        params.clone(),
603                        MethodKind::Unsubscription,
604                        TransportProtocol::WebSocket,
605                    );
606
607                    Some(callback(
608                        id,
609                        params,
610                        conn_id,
611                        max_response_body_size as usize,
612                        extensions.clone(),
613                    ))
614                }
615            },
616        };
617
618        if let Some(response) = &response {
619            logger.on_result(
620                name,
621                response.is_success(),
622                response.as_error_code(),
623                request_start,
624                TransportProtocol::WebSocket,
625            );
626        }
627        response
628    }
629}