1use std::{
6 net::{IpAddr, SocketAddr},
7 sync::Arc,
8 time::SystemTime,
9};
10
11use axum::{
12 extract::{ConnectInfo, Json, State},
13 response::Response,
14};
15use hyper::{HeaderMap, header::HeaderValue};
16use iota_core::traffic_controller::{
17 TrafficController, metrics::TrafficControllerMetrics, policies::TrafficTally,
18};
19use iota_json_rpc_api::{
20 CLIENT_TARGET_API_VERSION_HEADER, TRANSACTION_EXECUTION_CLIENT_ERROR_CODE,
21};
22use iota_types::traffic_control::{ClientIdSource, PolicyConfig, RemoteFirewallConfig, Weight};
23use jsonrpsee::{
24 BoundedSubscriptions, ConnectionId, Extensions, MethodCallback, MethodKind, MethodResponse,
25 core::server::{Methods, helpers::MethodSink},
26 server::RandomIntegerIdProvider,
27 types::{
28 ErrorObject, Id, InvalidRequest, Params, Request,
29 error::{BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, ErrorCode},
30 },
31};
32use serde_json::value::RawValue;
33use tracing::error;
34
35use crate::{
36 logger::{Logger, TransportProtocol},
37 routing_layer::RpcRouter,
38};
39
/// Response size limit handed to every method callback and to the WebSocket
/// `MethodSink`: 2^31, i.e. 2 GiB when counted in bytes.
pub const MAX_RESPONSE_SIZE: u32 = 1 << 31;
/// Message attached to the `ServerIsBusy` error returned when traffic control
/// rejects a request.
const TOO_MANY_REQUESTS_MSG: &str = "Too many requests";
42
/// Shared state behind the HTTP and WebSocket JSON-RPC axum routes.
///
/// Cloned into each handler via axum's `State` extractor.
#[derive(Clone, Debug)]
pub struct JsonRpcService<L> {
    /// Logger hook notified on every request (`on_request`), call (`on_call`)
    /// and result (`on_result`).
    logger: L,

    /// Generates subscription ids; handed to WebSocket subscriptions through
    /// `SubscriptionState`.
    id_provider: Arc<RandomIntegerIdProvider>,

    /// Registered RPC methods, looked up by (routed) method name.
    methods: Methods,
    /// Extensions cloned into every method invocation.
    extensions: Extensions,
    /// Maps the requested method name and client API version to the name that
    /// is looked up in `methods` (HTTP path only).
    rpc_router: RpcRouter,
    /// Spawned only when a traffic-control policy was configured.
    traffic_controller: Option<Arc<TrafficController>>,
    /// How the client IP is derived for traffic control; taken from the same
    /// policy, so it is `None` exactly when `traffic_controller` is `None`.
    client_id_source: Option<ClientIdSource>,
}
56
57impl<L> JsonRpcService<L> {
58 pub fn new(
59 methods: Methods,
60 rpc_router: RpcRouter,
61 logger: L,
62 remote_fw_config: Option<RemoteFirewallConfig>,
63 policy_config: Option<PolicyConfig>,
64 traffic_controller_metrics: TrafficControllerMetrics,
65 extensions: Extensions,
66 ) -> Self {
67 Self {
68 methods,
69 rpc_router,
70 logger,
71 extensions,
72 id_provider: Arc::new(RandomIntegerIdProvider),
73 traffic_controller: policy_config.clone().map(|policy| {
74 Arc::new(TrafficController::spawn(
75 policy,
76 traffic_controller_metrics,
77 remote_fw_config,
78 ))
79 }),
80 client_id_source: policy_config.map(|policy| policy.client_id_source),
81 }
82 }
83}
84
85impl<L: Logger> JsonRpcService<L> {
86 fn call_data(&self) -> CallData<'_, L> {
87 CallData {
88 logger: &self.logger,
89 methods: &self.methods,
90 rpc_router: &self.rpc_router,
91 extensions: &self.extensions,
92 max_response_body_size: MAX_RESPONSE_SIZE,
93 request_start: self.logger.on_request(TransportProtocol::Http),
94 }
95 }
96
97 fn ws_call_data<'c, 'a: 'c, 'b: 'c>(
98 &'a self,
99 bounded_subscriptions: BoundedSubscriptions,
100 sink: &'b MethodSink,
101 ) -> ws::WsCallData<'c, L> {
102 ws::WsCallData {
103 logger: &self.logger,
104 methods: &self.methods,
105 extensions: &self.extensions,
106 max_response_body_size: MAX_RESPONSE_SIZE,
107 request_start: self.logger.on_request(TransportProtocol::Http),
108 bounded_subscriptions,
109 id_provider: &*self.id_provider,
110 sink,
111 }
112 }
113}
114
115fn from_template<S: Into<axum::body::Body>>(
117 status: hyper::StatusCode,
118 body: S,
119 content_type: &'static str,
120) -> Response {
121 Response::builder()
122 .status(status)
123 .header(
124 "content-type",
125 hyper::header::HeaderValue::from_static(content_type),
126 )
127 .body(body.into())
128 .expect("Unable to parse response body for type conversion")
131}
132
133pub(crate) fn ok_response(body: String) -> Response {
135 const JSON: &str = "application/json; charset=utf-8";
136 from_template(hyper::StatusCode::OK, body, JSON)
137}
138
139pub async fn json_rpc_handler<L: Logger>(
140 ConnectInfo(client_addr): ConnectInfo<SocketAddr>,
141 State(service): State<JsonRpcService<L>>,
142 headers: HeaderMap,
143 Json(raw_request): Json<Box<RawValue>>,
144) -> impl axum::response::IntoResponse {
145 let headers_clone = headers.clone();
146 let api_version = headers
148 .get(CLIENT_TARGET_API_VERSION_HEADER)
149 .and_then(|h| h.to_str().ok());
150 let response = process_raw_request(
151 &service,
152 api_version,
153 raw_request.get(),
154 client_addr,
155 headers_clone,
156 )
157 .await;
158
159 ok_response(response.into_result())
160}
161
162async fn process_raw_request<L: Logger>(
163 service: &JsonRpcService<L>,
164 api_version: Option<&str>,
165 raw_request: &str,
166 client_addr: SocketAddr,
167 headers: HeaderMap,
168) -> MethodResponse {
169 let client = match service.client_id_source {
170 Some(ClientIdSource::SocketAddr) => Some(client_addr.ip()),
171 Some(ClientIdSource::XForwardedFor(num_hops)) => {
172 let do_header_parse = |header: &HeaderValue| match header.to_str() {
173 Ok(header_val) => {
174 let header_contents = header_val.split(',').map(str::trim).collect::<Vec<_>>();
175 if num_hops == 0 {
176 error!(
177 "x-forwarded-for: 0 specified. x-forwarded-for contents: {:?}. Please assign nonzero value for \
178 number of hops here, or use `socket-addr` client-id-source type if requests are not being proxied \
179 to this node. Skipping traffic controller request handling.",
180 header_contents,
181 );
182 return None;
183 }
184 let contents_len = header_contents.len();
185 let Some(client_ip) = header_contents.get(contents_len - num_hops) else {
186 error!(
187 "x-forwarded-for header value of {:?} contains {} values, but {} hops were specified. \
188 Expected {} values. Skipping traffic controller request handling.",
189 header_contents,
190 contents_len,
191 num_hops,
192 num_hops + 1,
193 );
194 return None;
195 };
196 client_ip.parse::<IpAddr>().ok().or_else(|| {
197 client_ip.parse::<SocketAddr>().ok().map(|socket_addr| socket_addr.ip()).or_else(|| {
198 error!(
199 "Failed to parse x-forwarded-for header value of {:?} to ip address or socket. \
200 Please ensure that your proxy is configured to resolve client domains to an \
201 IP address before writing header",
202 client_ip,
203 );
204 None
205 })
206 })
207 }
208 Err(e) => {
209 error!("Invalid UTF-8 in x-forwarded-for header: {:?}", e);
210 None
211 }
212 };
213 if let Some(header) = headers.get("x-forwarded-for") {
214 do_header_parse(header)
215 } else if let Some(header) = headers.get("X-Forwarded-For") {
216 do_header_parse(header)
217 } else {
218 error!(
219 "x-forwarded-for header not present for request despite node configuring x-forwarded-for tracking type"
220 );
221 None
222 }
223 }
224 None => None,
225 };
226 if let Ok(request) = serde_json::from_str::<Request>(raw_request) {
227 if let Some(traffic_controller) = &service.traffic_controller {
229 if let Err(blocked_response) =
230 handle_traffic_req(traffic_controller.clone(), &client).await
231 {
232 return blocked_response;
233 }
234 }
235
236 let response = process_request(request, api_version, service.call_data()).await;
238 if let Some(traffic_controller) = &service.traffic_controller {
239 handle_traffic_resp(traffic_controller.clone(), client, &response);
240 }
241
242 response
243 } else if let Ok(_batch) = serde_json::from_str::<Vec<&RawValue>>(raw_request) {
244 MethodResponse::error(
245 Id::Null,
246 ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, None),
247 )
248 } else {
249 let (id, code) = prepare_error(raw_request);
250 MethodResponse::error(id, ErrorObject::from(code))
251 }
252}
253
254async fn handle_traffic_req(
255 traffic_controller: Arc<TrafficController>,
256 client: &Option<IpAddr>,
257) -> Result<(), MethodResponse> {
258 if !traffic_controller.check(client, &None).await {
259 let err_obj =
261 ErrorObject::borrowed(ErrorCode::ServerIsBusy.code(), TOO_MANY_REQUESTS_MSG, None);
262 Err(MethodResponse::error(Id::Null, err_obj))
263 } else {
264 Ok(())
265 }
266}
267
268fn handle_traffic_resp(
269 traffic_controller: Arc<TrafficController>,
270 client: Option<IpAddr>,
271 response: &MethodResponse,
272) {
273 let error = response.as_error_code().map(ErrorCode::from);
274 traffic_controller.tally(TrafficTally {
275 direct: client,
276 through_fullnode: None,
277 error_weight: error.map(normalize).unwrap_or(Weight::zero()),
278 spam_weight: Weight::one(),
286 timestamp: SystemTime::now(),
287 });
288}
289
290fn normalize(err: ErrorCode) -> Weight {
292 match err {
293 ErrorCode::InvalidRequest | ErrorCode::InvalidParams => Weight::one(),
294 ErrorCode::ServerError(i) if i == TRANSACTION_EXECUTION_CLIENT_ERROR_CODE => Weight::one(),
296 _ => Weight::zero(),
297 }
298}
299
300async fn process_request<L: Logger>(
301 req: Request<'_>,
302 api_version: Option<&str>,
303 call: CallData<'_, L>,
304) -> MethodResponse {
305 let CallData {
306 methods,
307 rpc_router,
308 logger,
309 extensions,
310 max_response_body_size,
311 request_start,
312 } = call;
313 let conn_id = ConnectionId(0); let name = rpc_router.route(&req.method, api_version);
316 let params = Params::new(req.params.as_ref().map(|params| params.get()));
317
318 let id = req.id;
319
320 let response = match methods.method_with_name(name) {
321 None => {
322 logger.on_call(
323 name,
324 params.clone(),
325 MethodKind::NotFound,
326 TransportProtocol::Http,
327 );
328 MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound))
329 }
330 Some((name, method)) => match method {
331 MethodCallback::Sync(callback) => {
332 logger.on_call(
333 name,
334 params.clone(),
335 MethodKind::MethodCall,
336 TransportProtocol::Http,
337 );
338 (callback)(
339 id,
340 params,
341 max_response_body_size as usize,
342 extensions.clone(),
343 )
344 }
345 MethodCallback::Async(callback) => {
346 logger.on_call(
347 name,
348 params.clone(),
349 MethodKind::MethodCall,
350 TransportProtocol::Http,
351 );
352
353 let id = id.into_owned();
354 let params = params.into_owned();
355
356 (callback)(
357 id,
358 params,
359 conn_id,
360 max_response_body_size as usize,
361 extensions.clone(),
362 )
363 .await
364 }
365 MethodCallback::Subscription(_) | MethodCallback::Unsubscription(_) => {
366 logger.on_call(
367 name,
368 params.clone(),
369 MethodKind::NotFound,
370 TransportProtocol::Http,
371 );
372 MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError))
374 }
375 },
376 };
377
378 logger.on_result(
379 name,
380 response.is_success(),
381 response.as_error_code(),
382 request_start,
383 TransportProtocol::Http,
384 );
385 response
386}
387
388pub fn prepare_error(data: &str) -> (Id<'_>, ErrorCode) {
391 match serde_json::from_str::<InvalidRequest>(data) {
392 Ok(InvalidRequest { id }) => (id, ErrorCode::InvalidRequest),
393 Err(_) => (Id::Null, ErrorCode::ParseError),
394 }
395}
396
/// Per-request state built by `JsonRpcService::call_data` and consumed by the
/// HTTP `process_request`.
#[derive(Debug, Clone)]
pub(crate) struct CallData<'a, L: Logger> {
    /// Logger notified of the call (`on_call`) and its result (`on_result`).
    logger: &'a L,
    /// Registered methods to dispatch to.
    methods: &'a Methods,
    /// Resolves the requested method name for the client's API version.
    rpc_router: &'a RpcRouter,
    /// Extensions cloned into each method invocation.
    extensions: &'a Extensions,
    /// Response size limit passed to method callbacks (`MAX_RESPONSE_SIZE`
    /// when built by `call_data`).
    max_response_body_size: u32,
    /// Value returned by `Logger::on_request`, handed back to `on_result`.
    request_start: L::Instant,
}
406
407pub mod ws {
408 use axum::{
409 extract::{
410 WebSocketUpgrade,
411 ws::{Message, WebSocket},
412 },
413 response::Response,
414 };
415 use jsonrpsee::{
416 SubscriptionState, core::server::helpers::MethodSink, server::IdProvider,
417 types::error::reject_too_many_subscriptions,
418 };
419 use tokio::sync::mpsc;
420
421 use super::*;
422
423 const MAX_WS_MESSAGE_BUFFER: usize = 100;
424
425 #[derive(Debug, Clone)]
426 pub(crate) struct WsCallData<'a, L: Logger> {
427 pub bounded_subscriptions: BoundedSubscriptions,
428 pub id_provider: &'a dyn IdProvider,
429 pub methods: &'a Methods,
430 pub extensions: &'a Extensions,
431 pub max_response_body_size: u32,
432 pub sink: &'a MethodSink,
433 pub logger: &'a L,
434 pub request_start: L::Instant,
435 }
436
437 pub async fn ws_json_rpc_upgrade<L: Logger>(
442 ws: WebSocketUpgrade,
443 State(service): State<JsonRpcService<L>>,
444 ) -> Response {
445 ws.on_upgrade(|ws| ws_json_rpc_handler(ws, service))
446 }
447
448 async fn ws_json_rpc_handler<L: Logger>(mut socket: WebSocket, service: JsonRpcService<L>) {
449 let (tx, mut rx) = mpsc::channel::<String>(MAX_WS_MESSAGE_BUFFER);
450 let sink = MethodSink::new_with_limit(tx, MAX_RESPONSE_SIZE);
451 let bounded_subscriptions = BoundedSubscriptions::new(100);
452
453 loop {
454 tokio::select! {
455 maybe_message = socket.recv() => {
456 if let Some(Ok(message)) = maybe_message {
457 if let Message::Text(msg) = message {
458 let response =
459 process_raw_request(&service, &msg, bounded_subscriptions.clone(), &sink).await;
460 if let Some(response) = response {
461 sink.send(response.into_result()).await.ok();
462 }
463 }
464 } else {
465 break;
466 }
467 },
468 Some(response) = rx.recv() => {
469 if socket.send(Message::Text(response)).await.is_err() {
470 break;
471 }
472 },
473 }
474 }
475 }
476
477 async fn process_raw_request<L: Logger>(
478 service: &JsonRpcService<L>,
479 raw_request: &str,
480 bounded_subscriptions: BoundedSubscriptions,
481 sink: &MethodSink,
482 ) -> Option<MethodResponse> {
483 if let Ok(request) = serde_json::from_str::<Request>(raw_request) {
484 process_request(request, service.ws_call_data(bounded_subscriptions, sink)).await
485 } else if let Ok(_batch) = serde_json::from_str::<Vec<&RawValue>>(raw_request) {
486 Some(MethodResponse::error(
487 Id::Null,
488 ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, None),
489 ))
490 } else {
491 let (id, code) = prepare_error(raw_request);
492 Some(MethodResponse::error(id, ErrorObject::from(code)))
493 }
494 }
495
496 async fn process_request<L: Logger>(
497 req: Request<'_>,
498 call: WsCallData<'_, L>,
499 ) -> Option<MethodResponse> {
500 let WsCallData {
501 methods,
502 logger,
503 extensions,
504 max_response_body_size,
505 request_start,
506 bounded_subscriptions,
507 id_provider,
508 sink,
509 } = call;
510 let conn_id = ConnectionId(0); let params = Params::new(req.params.as_ref().map(|params| params.get()));
513 let name = &req.method;
514 let id = req.id;
515
516 let response = match methods.method_with_name(name) {
517 None => {
518 logger.on_call(
519 name,
520 params.clone(),
521 MethodKind::NotFound,
522 TransportProtocol::Http,
523 );
524 Some(MethodResponse::error(
525 id,
526 ErrorObject::from(ErrorCode::MethodNotFound),
527 ))
528 }
529 Some((name, method)) => match method {
530 MethodCallback::Sync(callback) => {
531 logger.on_call(
532 name,
533 params.clone(),
534 MethodKind::MethodCall,
535 TransportProtocol::Http,
536 );
537 tracing::info!("calling {name} sync");
538 Some((callback)(
539 id,
540 params,
541 max_response_body_size as usize,
542 extensions.clone(),
543 ))
544 }
545 MethodCallback::Async(callback) => {
546 logger.on_call(
547 name,
548 params.clone(),
549 MethodKind::MethodCall,
550 TransportProtocol::Http,
551 );
552
553 let id = id.into_owned();
554 let params = params.into_owned();
555
556 tracing::info!("calling {name} async");
557 Some(
558 (callback)(
559 id,
560 params,
561 conn_id,
562 max_response_body_size as usize,
563 extensions.clone(),
564 )
565 .await,
566 )
567 }
568
569 MethodCallback::Subscription(callback) => {
570 logger.on_call(
571 name,
572 params.clone(),
573 MethodKind::Subscription,
574 TransportProtocol::WebSocket,
575 );
576 if let Some(subscription_permit) = bounded_subscriptions.acquire() {
577 let conn_state = SubscriptionState {
578 conn_id,
579 subscription_permit,
580 id_provider,
581 };
582 (callback)(
583 id.clone(),
584 params,
585 sink.clone(),
586 conn_state,
587 extensions.clone(),
588 )
589 .await;
590 None
591 } else {
592 Some(MethodResponse::error(
593 id,
594 reject_too_many_subscriptions(bounded_subscriptions.max()),
595 ))
596 }
597 }
598
599 MethodCallback::Unsubscription(callback) => {
600 logger.on_call(
601 name,
602 params.clone(),
603 MethodKind::Unsubscription,
604 TransportProtocol::WebSocket,
605 );
606
607 Some(callback(
608 id,
609 params,
610 conn_id,
611 max_response_body_size as usize,
612 extensions.clone(),
613 ))
614 }
615 },
616 };
617
618 if let Some(response) = &response {
619 logger.on_result(
620 name,
621 response.is_success(),
622 response.as_error_code(),
623 request_start,
624 TransportProtocol::WebSocket,
625 );
626 }
627 response
628 }
629}