1use std::{
6 collections::{HashMap, HashSet, VecDeque},
7 mem,
8 net::SocketAddr,
9 sync::{Arc, Mutex},
10 time::Instant,
11};
12
13use async_graphql::{
14 Name, Pos, Positioned, Response, ServerError, ServerResult, Variables,
15 extensions::{Extension, ExtensionContext, ExtensionFactory, NextParseQuery, NextRequest},
16 parser::types::{
17 DocumentOperations, ExecutableDocument, Field, FragmentDefinition, OperationDefinition,
18 Selection,
19 },
20 value,
21};
22use async_graphql_value::{ConstValue, Value};
23use async_trait::async_trait;
24use axum::http::HeaderName;
25use iota_graphql_rpc_headers::LIMITS_HEADER;
26use serde::Serialize;
27use tracing::{error, info};
28use uuid::Uuid;
29
30use crate::{
31 config::{Limits, ServiceConfig},
32 error::{code, graphql_error, graphql_error_at_pos},
33 metrics::Metrics,
34};
35
/// Field whose arguments are charged against the transaction payload budget (execution).
const EXECUTE_TX_BLOCK: &str = "executeTransactionBlock";
/// Field whose arguments are charged against the transaction payload budget (dry run).
const DRY_RUN_TX_BLOCK: &str = "dryRunTransactionBlock";
38
/// Size of the whole incoming request, in bytes, made available to the limits checker through
/// the request context data.
#[derive(Debug, Clone, Copy)]
pub(crate) struct PayloadSize(pub u64);
44
/// Fields of a connection whose results fan out to one entry per page element; the output-node
/// estimate multiplies by the page size when it descends into one of these.
pub(crate) const CONNECTION_FIELDS: [&str; 2] = ["edges", "nodes"];
46
/// Extension factory for the query limits checker: each call to `create` returns a new
/// `QueryLimitsCheckerExt` with no usage recorded.
pub(crate) struct QueryLimitsChecker;
50
51#[derive(Debug, Default)]
52struct QueryLimitsCheckerExt {
53 usage: Mutex<Option<Usage>>,
54}
55
/// Marker type: when present in the request context data, `parse_query` keeps the computed
/// `Usage` so that it is returned in the response's `usage` extension. `ShowUsage::name`
/// returns `LIMITS_HEADER`.
pub(crate) struct ShowUsage;
58
59struct Reporter<'a> {
61 limits: &'a Limits,
62 query_id: &'a Uuid,
63 session_id: &'a SocketAddr,
64}
65
66impl<'a> Reporter<'a> {
67 fn new(ctx: &'a ExtensionContext<'a>) -> Self {
68 let cfg: &ServiceConfig = ctx.data_unchecked();
69 Self {
70 limits: &cfg.limits,
71 query_id: ctx.data_unchecked(),
72 session_id: ctx.data_unchecked(),
73 }
74 }
75
76 fn fragment_not_found_error(&self, name: &Name, pos: Pos) -> ServerError {
79 self.graphql_error_at_pos(
80 code::BAD_USER_INPUT,
81 format!("Fragment {name} referred to but not found in document"),
82 pos,
83 )
84 }
85
86 fn output_node_error(&self) -> ServerError {
88 self.graphql_error(
89 code::BAD_USER_INPUT,
90 format!(
91 "Estimated output nodes exceeds {}",
92 self.limits.max_output_nodes
93 ),
94 )
95 }
96
97 fn payload_size_error(&self, message: &str) -> ServerError {
99 self.graphql_error(
100 code::BAD_USER_INPUT,
101 format!(
102 "{message}. Requests are limited to {max_tx_payload} bytes or fewer on \
103 transaction payloads (all inputs to executeTransactionBlock or \
104 dryRunTransactionBlock) and the rest of the request (the query part) must be \
105 {max_query_payload} bytes or fewer.",
106 max_tx_payload = self.limits.max_tx_payload_size,
107 max_query_payload = self.limits.max_query_payload_size,
108 ),
109 )
110 }
111
112 fn graphql_error(&self, code: &str, message: String) -> ServerError {
114 self.log_error(code, &message);
115 graphql_error(code, message)
116 }
117
118 fn graphql_error_at_pos(&self, code: &str, message: String, pos: Pos) -> ServerError {
121 self.log_error(code, &message);
122 graphql_error_at_pos(code, message, pos)
123 }
124
125 fn log_error(&self, error_code: &str, message: &str) {
127 if error_code == code::INTERNAL_SERVER_ERROR {
128 error!(
129 query_id = %self.query_id,
130 session_id = %self.session_id,
131 error_code,
132 "Internal error while checking limits: {message}",
133 );
134 } else {
135 info!(
136 query_id = %self.query_id,
137 session_id = %self.session_id,
138 error_code,
139 "Limits error: {message}",
140 );
141 }
142 }
143}
144
145struct LimitsTraversal<'a> {
149 fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
151 variables: &'a Variables,
152
153 reporter: &'a Reporter<'a>,
155 payload_size: u64,
157 tx_variables_used: HashSet<&'a Name>,
162
163 tx_payload_budget: u32,
165 input_budget: u32,
166 output_budget: u32,
167 depth_seen: u32,
168}
169
170#[derive(Clone, Debug, Default, Serialize)]
171#[serde(rename_all = "camelCase")]
172struct Usage {
173 input_nodes: u32,
174 output_nodes: u32,
175 depth: u32,
176 variables: u32,
177 fragments: u32,
178 query_payload: u32,
179}
180
impl ShowUsage {
    /// Returns the header associated with usage reporting (`LIMITS_HEADER`).
    pub(crate) fn name() -> &'static HeaderName {
        &LIMITS_HEADER
    }
}
186
187impl<'a> LimitsTraversal<'a> {
188 fn new(
189 PayloadSize(payload_size): PayloadSize,
190 reporter: &'a Reporter<'a>,
191 fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
192 variables: &'a Variables,
193 ) -> Self {
194 Self {
195 fragments,
196 variables,
197 payload_size,
198 reporter,
199 input_budget: reporter.limits.max_query_nodes,
200 output_budget: reporter.limits.max_output_nodes,
201 tx_payload_budget: reporter.limits.max_tx_payload_size,
202 tx_variables_used: HashSet::new(),
203 depth_seen: 0,
204 }
205 }
206
207 fn check_document(&mut self, doc: &'a ExecutableDocument) -> ServerResult<()> {
209 for (_name, op) in doc.operations.iter() {
213 self.check_input_limits(op)?;
214 }
215 for (_name, op) in doc.operations.iter() {
218 self.check_tx_payload(op)?;
219 }
220 let limits = self.reporter.limits;
223 let tx_payload_size = (limits.max_tx_payload_size - self.tx_payload_budget) as u64;
224 let query_payload_size = self.payload_size - tx_payload_size;
225 if query_payload_size > limits.max_query_payload_size as u64 {
226 let message = format!("Query part too large: {query_payload_size} bytes");
227 return Err(self.reporter.payload_size_error(&message));
228 }
229 for (_name, op) in doc.operations.iter() {
232 self.check_output_limits(op)?;
233 }
234
235 Ok(())
236 }
237
238 fn check_input_limits(&mut self, op: &Positioned<OperationDefinition>) -> ServerResult<()> {
240 let limits = self.reporter.limits;
241
242 let mut next_level = vec![];
243 let mut curr_level = vec![];
244 let mut depth_budget = limits.max_query_depth;
245
246 next_level.extend(&op.node.selection_set.node.items);
247 while let Some(next) = next_level.first() {
248 if depth_budget == 0 {
249 return Err(self.reporter.graphql_error_at_pos(
250 code::BAD_USER_INPUT,
251 format!("Query nesting is over {}", limits.max_query_depth),
252 next.pos,
253 ));
254 } else {
255 depth_budget -= 1;
256 }
257
258 mem::swap(&mut next_level, &mut curr_level);
259
260 for selection in curr_level.drain(..) {
261 if self.input_budget == 0 {
262 return Err(self.reporter.graphql_error_at_pos(
263 code::BAD_USER_INPUT,
264 format!("Query has over {} nodes", limits.max_query_nodes),
265 selection.pos,
266 ));
267 } else {
268 self.input_budget -= 1;
269 }
270
271 match &selection.node {
272 Selection::Field(f) => {
273 next_level.extend(&f.node.selection_set.node.items);
274 }
275
276 Selection::InlineFragment(f) => {
277 next_level.extend(&f.node.selection_set.node.items);
278 }
279
280 Selection::FragmentSpread(fs) => {
281 let name = &fs.node.fragment_name.node;
282 let def = self
283 .fragments
284 .get(name)
285 .ok_or_else(|| self.reporter.fragment_not_found_error(name, fs.pos))?;
286
287 next_level.extend(&def.node.selection_set.node.items);
288 }
289 }
290 }
291 }
292
293 self.depth_seen = self.depth_seen.max(limits.max_query_depth - depth_budget);
294 Ok(())
295 }
296
297 fn check_tx_payload(&mut self, op: &'a Positioned<OperationDefinition>) -> ServerResult<()> {
305 let tx_arg_values = op
306 .node
307 .selection_set
308 .node
309 .items
310 .iter()
311 .flat_map(|selection| {
312 TxArgValueIter::new(&selection.node, self.fragments, self.reporter)
313 });
314 for value in tx_arg_values {
315 let cost = value?.cost(self.variables, &mut self.tx_variables_used);
316 if cost > self.tx_payload_budget as usize {
317 self.tx_payload_budget = 0;
319 return Err(self.tx_payload_size_error());
320 } else {
321 self.tx_payload_budget -= cost as u32;
322 }
323 }
324 Ok(())
325 }
326
327 fn check_output_limits(&mut self, op: &Positioned<OperationDefinition>) -> ServerResult<()> {
334 for selection in &op.node.selection_set.node.items {
335 self.traverse_selection_for_output(selection, 1, None)?;
336 }
337 Ok(())
338 }
339
340 fn traverse_selection_for_output(
349 &mut self,
350 selection: &Positioned<Selection>,
351 multiplicity: u32,
352 page_size: Option<u32>,
353 ) -> ServerResult<()> {
354 match &selection.node {
355 Selection::Field(f) => {
356 if multiplicity > self.output_budget {
357 return Err(self.output_node_error());
358 } else {
359 self.output_budget -= multiplicity;
360 }
361
362 let name = &f.node.name.node;
367 let multiplicity = 'm: {
368 if !CONNECTION_FIELDS.contains(&name.as_str()) {
369 break 'm multiplicity;
370 }
371
372 let Some(page_size) = page_size else {
373 break 'm multiplicity;
374 };
375
376 multiplicity
377 .checked_mul(page_size)
378 .ok_or_else(|| self.output_node_error())?
379 };
380
381 let page_size = self.connection_page_size(f)?;
382 for selection in &f.node.selection_set.node.items {
383 self.traverse_selection_for_output(selection, multiplicity, page_size)?;
384 }
385 }
386
387 Selection::InlineFragment(f) => {
389 for selection in &f.node.selection_set.node.items {
390 self.traverse_selection_for_output(selection, multiplicity, page_size)?;
391 }
392 }
393
394 Selection::FragmentSpread(fs) => {
395 let name = &fs.node.fragment_name.node;
396 let def = self
397 .fragments
398 .get(name)
399 .ok_or_else(|| self.reporter.fragment_not_found_error(name, fs.pos))?;
400
401 for selection in &def.node.selection_set.node.items {
402 self.traverse_selection_for_output(selection, multiplicity, page_size)?;
403 }
404 }
405 }
406
407 Ok(())
408 }
409
410 fn tx_payload_size_error(&mut self) -> ServerError {
412 self.reporter
413 .payload_size_error("Transaction payload too large")
414 }
415
416 fn connection_page_size(&mut self, f: &Positioned<Field>) -> ServerResult<Option<u32>> {
420 if !self.is_connection(f) {
421 return Ok(None);
422 }
423
424 let first = f.node.get_argument("first");
425 let last = f.node.get_argument("last");
426
427 let page_size = match (self.resolve_u64(first), self.resolve_u64(last)) {
428 (Some(f), Some(l)) => f.max(l),
429 (Some(p), _) | (_, Some(p)) => p,
430 (None, None) => self.reporter.limits.default_page_size as u64,
431 };
432
433 Ok(Some(
434 page_size.try_into().map_err(|_| self.output_node_error())?,
435 ))
436 }
437
438 fn is_connection(&self, f: &Positioned<Field>) -> bool {
443 f.node
444 .selection_set
445 .node
446 .items
447 .iter()
448 .any(|s| self.has_connection_fields(s))
449 }
450
451 fn has_connection_fields(&self, s: &Positioned<Selection>) -> bool {
456 match &s.node {
457 Selection::Field(f) => {
458 let name = &f.node.name.node;
459 CONNECTION_FIELDS.contains(&name.as_str())
460 }
461
462 Selection::InlineFragment(f) => f
463 .node
464 .selection_set
465 .node
466 .items
467 .iter()
468 .any(|s| self.has_connection_fields(s)),
469
470 Selection::FragmentSpread(fs) => {
471 let name = &fs.node.fragment_name.node;
472 let Some(def) = self.fragments.get(name) else {
473 return false;
474 };
475
476 def.node
477 .selection_set
478 .node
479 .items
480 .iter()
481 .any(|s| self.has_connection_fields(s))
482 }
483 }
484 }
485
486 fn resolve_u64(&self, value: Option<&Positioned<Value>>) -> Option<u64> {
489 match &value?.node {
490 Value::Number(num) => num,
491
492 Value::Variable(var) => {
493 if let ConstValue::Number(num) = self.variables.get(var)? {
494 num
495 } else {
496 return None;
497 }
498 }
499
500 _ => return None,
501 }
502 .as_u64()
503 }
504
505 fn output_node_error(&mut self) -> ServerError {
511 self.output_budget = 0;
512 self.reporter.output_node_error()
513 }
514
515 fn finish(self, query_payload: u32) -> Usage {
517 let limits = self.reporter.limits;
518 Usage {
519 input_nodes: limits.max_query_nodes - self.input_budget,
520 output_nodes: limits.max_output_nodes - self.output_budget,
521 depth: self.depth_seen,
522 variables: self.variables.len() as u32,
523 fragments: self.fragments.len() as u32,
524 query_payload,
525 }
526 }
527}
528
529#[derive(Debug)]
530enum ParsedValue<'a> {
531 GraphQL(&'a Value),
533 Resolved(&'a ConstValue),
535}
536
/// How many times a variable occurred in a value, and what one occurrence costs.
struct VariableUsage {
    /// Cost charged for a single occurrence of the variable's value.
    cost: usize,
    /// Number of occurrences counted so far.
    count: usize,
}

impl VariableUsage {
    /// Usage of a variable seen exactly once, costing `cost` per occurrence.
    fn new(cost: usize) -> Self {
        VariableUsage { count: 1, cost }
    }

    /// Records one more occurrence of the variable.
    fn increase_count(&mut self) {
        self.count += 1;
    }
}
553
554#[derive(Default)]
556struct ValueCostReport<'a> {
557 gross: usize,
559 variables_used: HashMap<&'a Name, VariableUsage>,
561}
562
563impl<'a> ValueCostReport<'a> {
564 fn new(cost: usize) -> Self {
565 Self {
566 gross: cost,
567 variables_used: HashMap::new(),
568 }
569 }
570
571 fn add_variable(&mut self, name: &'a Name, cost: usize) {
572 self.variables_used
573 .entry(name)
574 .and_modify(|usage| usage.increase_count())
575 .or_insert_with(|| VariableUsage::new(cost));
576 }
577
578 fn merge_report(&mut self, other: ValueCostReport<'a>) {
579 self.gross += other.gross;
580 for (name, usage) in other.variables_used {
581 self.variables_used
582 .entry(name)
583 .and_modify(|existing| {
584 existing.count += usage.count;
585 })
586 .or_insert(usage);
587 }
588 }
589}
590
591struct TxArgValue<'a>(ParsedValue<'a>);
593
594impl<'a> TxArgValue<'a> {
595 fn cost(
597 self,
598 document_variables: &'a Variables,
599 used_variables: &mut HashSet<&'a Name>,
600 ) -> usize {
601 let cost_report = self.cost_report(document_variables);
602 let mut net_cost = cost_report.gross;
603 for (name, VariableUsage { cost, count }) in &cost_report.variables_used {
604 if used_variables.insert(name) {
605 net_cost -= (count - 1) * cost
607 } else {
608 net_cost -= count * cost
610 }
611 }
612 net_cost
613 }
614
615 fn cost_report(self, variables: &'a Variables) -> ValueCostReport<'a> {
617 use ConstValue as CV;
618 use ParsedValue::{GraphQL, Resolved};
619 use Value as V;
620
621 match self.0 {
622 GraphQL(V::String(s)) | Resolved(CV::String(s)) => {
623 ValueCostReport::new(s.len() + 2)
625 }
626 GraphQL(V::List(vs)) => {
627 let mut cost = ValueCostReport::new(vs.len().saturating_sub(1) + 2);
630 for value in vs {
631 cost.merge_report(Self(ParsedValue::GraphQL(value)).cost_report(variables));
632 }
633 cost
634 }
635 Resolved(CV::List(vs)) => {
636 let mut cost = ValueCostReport::new(vs.len().saturating_sub(1) + 2);
638 for value in vs {
639 cost.merge_report(Self(ParsedValue::Resolved(value)).cost_report(variables));
640 }
641 cost
642 }
643 GraphQL(V::Variable(name)) => {
644 if let Some(value) = variables.get(name) {
645 let mut cost = Self(ParsedValue::Resolved(value)).cost_report(variables);
646 cost.add_variable(name, cost.gross);
647 return cost;
648 }
649 Default::default()
650 }
651 _ => {
652 Default::default()
659 }
660 }
661 }
662}
663
664struct TxArgValueIter<'a> {
669 selections: VecDeque<&'a Selection>,
670 arguments: VecDeque<TxArgValue<'a>>,
671 fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
672 reporter: &'a Reporter<'a>,
673}
674
675impl<'a> TxArgValueIter<'a> {
676 fn new(
677 selection: &'a Selection,
678 fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
679 reporter: &'a Reporter<'a>,
680 ) -> Self {
681 Self {
682 selections: VecDeque::from([selection]),
683 arguments: VecDeque::new(),
684 fragments,
685 reporter,
686 }
687 }
688}
689
690impl<'a> Iterator for TxArgValueIter<'a> {
691 type Item = Result<TxArgValue<'a>, ServerError>;
692
693 fn next(&mut self) -> Option<Self::Item> {
694 if let Some(value) = self.arguments.pop_front() {
695 return Some(Ok(value));
696 }
697 let selection = self.selections.pop_front()?;
698 match selection {
699 Selection::Field(f) => {
700 let name = &f.node.name.node;
701 if name == DRY_RUN_TX_BLOCK || name == EXECUTE_TX_BLOCK {
702 for (_name, value) in &f.node.arguments {
703 self.arguments
704 .push_back(TxArgValue(ParsedValue::GraphQL(&value.node)));
705 }
706 }
707 }
708 Selection::InlineFragment(f) => {
709 self.selections
710 .extend(f.node.selection_set.node.items.iter().map(|s| &s.node));
711 }
712
713 Selection::FragmentSpread(fs) => {
714 let name = &fs.node.fragment_name.node;
715 let Some(def) = self.fragments.get(name) else {
716 return Some(Err(self.reporter.fragment_not_found_error(name, fs.pos)));
717 };
718 self.selections
719 .extend(def.node.selection_set.node.items.iter().map(|s| &s.node));
720 }
721 }
722 self.next()
723 }
724}
725
726impl Usage {
727 fn report(&self, metrics: &Metrics) {
728 metrics
729 .request_metrics
730 .input_nodes
731 .observe(self.input_nodes as f64);
732 metrics
733 .request_metrics
734 .output_nodes
735 .observe(self.output_nodes as f64);
736 metrics
737 .request_metrics
738 .query_depth
739 .observe(self.depth as f64);
740 metrics
741 .request_metrics
742 .query_payload_size
743 .observe(self.query_payload as f64);
744 }
745}
746
747impl ExtensionFactory for QueryLimitsChecker {
748 fn create(&self) -> Arc<dyn Extension> {
749 Arc::new(QueryLimitsCheckerExt {
750 usage: Mutex::new(None),
751 })
752 }
753}
754
755#[async_trait]
756impl Extension for QueryLimitsCheckerExt {
757 async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
758 let resp = next.run(ctx).await;
759 let usage = self.usage.lock().unwrap().take();
760 if let Some(usage) = usage {
761 resp.extension("usage", value!(usage))
762 } else {
763 resp
764 }
765 }
766
767 async fn parse_query(
770 &self,
771 ctx: &ExtensionContext<'_>,
772 query: &str,
773 variables: &Variables,
774 next: NextParseQuery<'_>,
775 ) -> ServerResult<ExecutableDocument> {
776 let metrics: &Metrics = ctx.data_unchecked();
777 let payload_size: &PayloadSize = ctx.data_unchecked();
778 let reporter = Reporter::new(ctx);
779 let instant = Instant::now();
780
781 let max_payload_size = reporter.limits.max_query_payload_size as u64
783 + reporter.limits.max_tx_payload_size as u64;
784
785 if payload_size.0 > max_payload_size {
786 let message = format!("Overall request too large: {} bytes", payload_size.0);
787 return Err(reporter.payload_size_error(&message));
788 }
789
790 let doc = next.run(ctx, query, variables).await?;
792
793 if let DocumentOperations::Single(op) = &doc.operations {
797 if let [field] = &op.node.selection_set.node.items[..] {
798 if let Selection::Field(f) = &field.node {
799 if f.node.name.node == "__schema" {
800 return Ok(doc);
801 }
802 }
803 }
804 }
805
806 let mut traversal =
807 LimitsTraversal::new(*payload_size, &reporter, &doc.fragments, variables);
808 let res = traversal.check_document(&doc);
809 let usage = traversal.finish(query.len() as u32);
810 metrics.query_validation_latency(instant.elapsed());
811 usage.report(metrics);
812
813 res.map(|()| {
814 if ctx.data_opt::<ShowUsage>().is_some() {
815 *self.usage.lock().unwrap() = Some(usage);
816 }
817
818 doc
819 })
820 }
821}