iota_graphql_rpc/extensions/
query_limits_checker.rs1use std::{
6 collections::HashMap,
7 mem,
8 net::SocketAddr,
9 sync::{Arc, Mutex},
10 time::Instant,
11};
12
13use async_graphql::{
14 Name, Positioned, Response, ServerError, ServerResult, Variables,
15 extensions::{Extension, ExtensionContext, ExtensionFactory, NextParseQuery, NextRequest},
16 parser::types::{
17 DocumentOperations, ExecutableDocument, Field, FragmentDefinition, OperationDefinition,
18 Selection,
19 },
20 value,
21};
22use async_graphql_value::{ConstValue, Value};
23use async_trait::async_trait;
24use axum::http::HeaderName;
25use iota_graphql_rpc_headers::LIMITS_HEADER;
26use serde::Serialize;
27use tracing::info;
28use uuid::Uuid;
29
30use crate::{
31 config::{Limits, ServiceConfig},
32 error::{code, graphql_error, graphql_error_at_pos},
33 metrics::Metrics,
34};
35
/// Names of the fields on a connection type that repeat once per item in a page.
///
/// When the output-node estimate descends into one of these fields beneath a connection, the
/// multiplicity of everything below it is scaled by that connection's page size.
pub(crate) const CONNECTION_FIELDS: [&str; 2] = ["edges", "nodes"];
/// Extension factory for the query limits checker.
///
/// Every call to `create` yields a fresh [`QueryLimitsCheckerExt`] with no usage recorded.
pub(crate) struct QueryLimitsChecker;
42#[derive(Debug, Default)]
43struct QueryLimitsCheckerExt {
44 usage: Mutex<Option<Usage>>,
45}
46
/// Marker placed in the request context data to ask for usage to be reported.
///
/// When it is present and the limits check passes, the usage figures are returned in the
/// response's `usage` extension. `ShowUsage::name()` gives the associated header name.
/// NOTE(review): presumably it is inserted when a client sends that header; confirm at the
/// insertion site.
pub(crate) struct ShowUsage;
50struct LimitsTraversal<'a> {
54 fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
56 variables: &'a Variables,
57
58 default_page_size: u32,
60 max_input_nodes: u32,
61 max_output_nodes: u32,
62 max_depth: u32,
63
64 input_budget: u32,
66 output_budget: u32,
67 depth_seen: u32,
68}
69
70#[derive(Clone, Debug, Default, Serialize)]
71#[serde(rename_all = "camelCase")]
72struct Usage {
73 input_nodes: u32,
74 output_nodes: u32,
75 depth: u32,
76 variables: u32,
77 fragments: u32,
78 query_payload: u32,
79}
80
81impl ShowUsage {
82 pub(crate) fn name() -> &'static HeaderName {
83 &LIMITS_HEADER
84 }
85}
86
87impl<'a> LimitsTraversal<'a> {
88 fn new(
89 limits: &Limits,
90 fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
91 variables: &'a Variables,
92 ) -> Self {
93 Self {
94 fragments,
95 variables,
96 default_page_size: limits.default_page_size,
97 max_input_nodes: limits.max_query_nodes,
98 max_output_nodes: limits.max_output_nodes,
99 max_depth: limits.max_query_depth,
100 input_budget: limits.max_query_nodes,
101 output_budget: limits.max_output_nodes,
102 depth_seen: 0,
103 }
104 }
105
106 fn check_document(&mut self, doc: &ExecutableDocument) -> ServerResult<()> {
108 for (_name, op) in doc.operations.iter() {
109 self.check_input_limits(op)?;
110 self.check_output_limits(op)?;
111 }
112 Ok(())
113 }
114
115 fn check_input_limits(&mut self, op: &Positioned<OperationDefinition>) -> ServerResult<()> {
117 let mut next_level = vec![];
118 let mut curr_level = vec![];
119 let mut depth_budget = self.max_depth;
120
121 next_level.extend(&op.node.selection_set.node.items);
122 while let Some(next) = next_level.first() {
123 if depth_budget == 0 {
124 return Err(graphql_error_at_pos(
125 code::BAD_USER_INPUT,
126 format!("Query nesting is over {}", self.max_depth),
127 next.pos,
128 ));
129 } else {
130 depth_budget -= 1;
131 }
132
133 mem::swap(&mut next_level, &mut curr_level);
134
135 for selection in curr_level.drain(..) {
136 if self.input_budget == 0 {
137 return Err(graphql_error_at_pos(
138 code::BAD_USER_INPUT,
139 format!("Query has over {} nodes", self.max_input_nodes),
140 selection.pos,
141 ));
142 } else {
143 self.input_budget -= 1;
144 }
145
146 match &selection.node {
147 Selection::Field(f) => {
148 next_level.extend(&f.node.selection_set.node.items);
149 }
150
151 Selection::InlineFragment(f) => {
152 next_level.extend(&f.node.selection_set.node.items);
153 }
154
155 Selection::FragmentSpread(fs) => {
156 let name = &fs.node.fragment_name.node;
157 let def = self.fragments.get(name).ok_or_else(|| {
158 graphql_error_at_pos(
159 code::INTERNAL_SERVER_ERROR,
160 format!("Fragment {name} referred to but not found in document"),
161 fs.pos,
162 )
163 })?;
164
165 next_level.extend(&def.node.selection_set.node.items);
166 }
167 }
168 }
169 }
170
171 self.depth_seen = self.depth_seen.max(self.max_depth - depth_budget);
172 Ok(())
173 }
174
175 fn check_output_limits(&mut self, op: &Positioned<OperationDefinition>) -> ServerResult<()> {
182 for selection in &op.node.selection_set.node.items {
183 self.traverse_selection_for_output(selection, 1, None)?;
184 }
185 Ok(())
186 }
187
188 fn traverse_selection_for_output(
197 &mut self,
198 selection: &Positioned<Selection>,
199 multiplicity: u32,
200 page_size: Option<u32>,
201 ) -> ServerResult<()> {
202 match &selection.node {
203 Selection::Field(f) => {
204 if multiplicity > self.output_budget {
205 return Err(self.output_node_error());
206 } else {
207 self.output_budget -= multiplicity;
208 }
209
210 let name = &f.node.name.node;
215 let multiplicity = 'm: {
216 if !CONNECTION_FIELDS.contains(&name.as_str()) {
217 break 'm multiplicity;
218 }
219
220 let Some(page_size) = page_size else {
221 break 'm multiplicity;
222 };
223
224 multiplicity
225 .checked_mul(page_size)
226 .ok_or_else(|| self.output_node_error())?
227 };
228
229 let page_size = self.connection_page_size(f)?;
230 for selection in &f.node.selection_set.node.items {
231 self.traverse_selection_for_output(selection, multiplicity, page_size)?;
232 }
233 }
234
235 Selection::InlineFragment(f) => {
237 for selection in f.node.selection_set.node.items.iter() {
238 self.traverse_selection_for_output(selection, multiplicity, page_size)?;
239 }
240 }
241
242 Selection::FragmentSpread(fs) => {
243 let name = &fs.node.fragment_name.node;
244 let def = self.fragments.get(name).ok_or_else(|| {
245 graphql_error_at_pos(
246 code::INTERNAL_SERVER_ERROR,
247 format!("Fragment {name} referred to but not found in document"),
248 fs.pos,
249 )
250 })?;
251
252 for selection in def.node.selection_set.node.items.iter() {
253 self.traverse_selection_for_output(selection, multiplicity, page_size)?;
254 }
255 }
256 }
257
258 Ok(())
259 }
260
261 fn connection_page_size(&mut self, f: &Positioned<Field>) -> ServerResult<Option<u32>> {
265 if !self.is_connection(f) {
266 return Ok(None);
267 }
268
269 let first = f.node.get_argument("first");
270 let last = f.node.get_argument("last");
271
272 let page_size = match (self.resolve_u64(first), self.resolve_u64(last)) {
273 (Some(f), Some(l)) => f.max(l),
274 (Some(p), _) | (_, Some(p)) => p,
275 (None, None) => self.default_page_size as u64,
276 };
277
278 Ok(Some(
279 page_size.try_into().map_err(|_| self.output_node_error())?,
280 ))
281 }
282
283 fn is_connection(&self, f: &Positioned<Field>) -> bool {
288 f.node
289 .selection_set
290 .node
291 .items
292 .iter()
293 .any(|s| self.has_connection_fields(s))
294 }
295
296 fn has_connection_fields(&self, s: &Positioned<Selection>) -> bool {
301 match &s.node {
302 Selection::Field(f) => {
303 let name = &f.node.name.node;
304 CONNECTION_FIELDS.contains(&name.as_str())
305 }
306
307 Selection::InlineFragment(f) => f
308 .node
309 .selection_set
310 .node
311 .items
312 .iter()
313 .any(|s| self.has_connection_fields(s)),
314
315 Selection::FragmentSpread(fs) => {
316 let name = &fs.node.fragment_name.node;
317 let Some(def) = self.fragments.get(name) else {
318 return false;
319 };
320
321 def.node
322 .selection_set
323 .node
324 .items
325 .iter()
326 .any(|s| self.has_connection_fields(s))
327 }
328 }
329 }
330
331 fn resolve_u64(&self, value: Option<&Positioned<Value>>) -> Option<u64> {
334 match &value?.node {
335 Value::Number(num) => num,
336
337 Value::Variable(var) => {
338 if let ConstValue::Number(num) = self.variables.get(var)? {
339 num
340 } else {
341 return None;
342 }
343 }
344
345 _ => return None,
346 }
347 .as_u64()
348 }
349
350 fn output_node_error(&mut self) -> ServerError {
356 self.output_budget = 0;
357 graphql_error(
358 code::BAD_USER_INPUT,
359 format!("Estimated output nodes exceeds {}", self.max_output_nodes),
360 )
361 }
362
363 fn finish(self, query_payload: u32) -> Usage {
365 Usage {
366 input_nodes: self.max_input_nodes - self.input_budget,
367 output_nodes: self.max_output_nodes - self.output_budget,
368 depth: self.depth_seen,
369 variables: self.variables.len() as u32,
370 fragments: self.fragments.len() as u32,
371 query_payload,
372 }
373 }
374}
375
376impl Usage {
377 fn report(&self, metrics: &Metrics) {
378 metrics
379 .request_metrics
380 .input_nodes
381 .observe(self.input_nodes as f64);
382 metrics
383 .request_metrics
384 .output_nodes
385 .observe(self.output_nodes as f64);
386 metrics
387 .request_metrics
388 .query_depth
389 .observe(self.depth as f64);
390 metrics
391 .request_metrics
392 .query_payload_size
393 .observe(self.query_payload as f64);
394 }
395}
396
397impl ExtensionFactory for QueryLimitsChecker {
398 fn create(&self) -> Arc<dyn Extension> {
399 Arc::new(QueryLimitsCheckerExt {
400 usage: Mutex::new(None),
401 })
402 }
403}
404
405#[async_trait]
406impl Extension for QueryLimitsCheckerExt {
407 async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
408 let resp = next.run(ctx).await;
409 let usage = self.usage.lock().unwrap().take();
410 if let Some(usage) = usage {
411 resp.extension("usage", value!(usage))
412 } else {
413 resp
414 }
415 }
416
417 async fn parse_query(
420 &self,
421 ctx: &ExtensionContext<'_>,
422 query: &str,
423 variables: &Variables,
424 next: NextParseQuery<'_>,
425 ) -> ServerResult<ExecutableDocument> {
426 let query_id: &Uuid = ctx.data_unchecked();
427 let session_id: &SocketAddr = ctx.data_unchecked();
428 let metrics: &Metrics = ctx.data_unchecked();
429 let cfg: &ServiceConfig = ctx.data_unchecked();
430 let instant = Instant::now();
431
432 if query.len() > cfg.limits.max_query_payload_size as usize {
433 metrics
434 .request_metrics
435 .query_payload_too_large_size
436 .observe(query.len() as f64);
437 info!(
438 query_id = %query_id,
439 session_id = %session_id,
440 error_code = code::BAD_USER_INPUT,
441 "Query payload is too large: {}",
442 query.len()
443 );
444
445 return Err(graphql_error(
446 code::BAD_USER_INPUT,
447 format!(
448 "Query payload is too large. The maximum allowed is {} bytes",
449 cfg.limits.max_query_payload_size
450 ),
451 ));
452 }
453
454 let doc = next.run(ctx, query, variables).await?;
456
457 if let DocumentOperations::Single(op) = &doc.operations {
461 if let [field] = &op.node.selection_set.node.items[..] {
462 if let Selection::Field(f) = &field.node {
463 if f.node.name.node == "__schema" {
464 return Ok(doc);
465 }
466 }
467 }
468 }
469
470 let mut traversal = LimitsTraversal::new(&cfg.limits, &doc.fragments, variables);
471 let res = traversal.check_document(&doc);
472 let usage = traversal.finish(query.len() as u32);
473 metrics.query_validation_latency(instant.elapsed());
474 usage.report(metrics);
475
476 res.map(|()| {
477 if ctx.data_opt::<ShowUsage>().is_some() {
478 *self.usage.lock().unwrap() = Some(usage);
479 }
480
481 doc
482 })
483 }
484}