iota_graphql_rpc/extensions/
query_limits_checker.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::HashMap,
7    mem,
8    net::SocketAddr,
9    sync::{Arc, Mutex},
10    time::Instant,
11};
12
13use async_graphql::{
14    Name, Positioned, Response, ServerError, ServerResult, Variables,
15    extensions::{Extension, ExtensionContext, ExtensionFactory, NextParseQuery, NextRequest},
16    parser::types::{
17        DocumentOperations, ExecutableDocument, Field, FragmentDefinition, OperationDefinition,
18        Selection,
19    },
20    value,
21};
22use async_graphql_value::{ConstValue, Value};
23use async_trait::async_trait;
24use axum::http::HeaderName;
25use iota_graphql_rpc_headers::LIMITS_HEADER;
26use serde::Serialize;
27use tracing::info;
28use uuid::Uuid;
29
30use crate::{
31    config::{Limits, ServiceConfig},
32    error::{code, graphql_error, graphql_error_at_pos},
33    metrics::Metrics,
34};
35
/// Field names that identify a connection. A field whose selection set
/// (directly, or through fragments) requests one of these is treated as a
/// connection by the limits traversal, and these fields themselves multiply
/// the output estimate by the enclosing connection's page size.
pub(crate) const CONNECTION_FIELDS: [&str; 2] = ["edges", "nodes"];
37
/// Extension factory for adding checks that the query is within configurable
/// limits.
///
/// Its `ExtensionFactory::create` implementation builds a
/// `QueryLimitsCheckerExt` that starts with no recorded usage.
pub(crate) struct QueryLimitsChecker;
41
/// The extension created by `QueryLimitsChecker`.
#[derive(Debug, Default)]
struct QueryLimitsCheckerExt {
    /// Usage recorded by `parse_query` (only when `ShowUsage` is present in
    /// the request context) and taken out again by `request`, which attaches
    /// it to the response. `None` when there is nothing to attach.
    usage: Mutex<Option<Usage>>,
}
46
/// Only display usage information if this header was in the request.
///
/// This is a marker type: `parse_query` looks it up in the request context
/// data and only stores the computed usage for the response if it is found.
/// The name of the corresponding header is given by `ShowUsage::name`.
pub(crate) struct ShowUsage;
49
/// State for traversing a document to check for limits. Holds on to
/// environments for looking up variables and fragments, limits, and the
/// remainder of the limit that can be used.
///
/// One traversal is shared by every operation in a document, so the node
/// budgets are consumed cumulatively across operations.
struct LimitsTraversal<'a> {
    // Environments for resolving lookups in the document

    // Fragment definitions by name, used to expand fragment spreads.
    fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
    // Request variables, used to resolve `first`/`last` arguments that are
    // supplied as variables.
    variables: &'a Variables,

    // Relevant limits from the service configuration

    // Page size assumed for a connection that sets neither `first` nor `last`.
    default_page_size: u32,
    // Maximum number of selections that may be visited by the input check.
    max_input_nodes: u32,
    // Maximum estimated number of output nodes.
    max_output_nodes: u32,
    // Maximum number of nesting levels in a single operation.
    max_depth: u32,

    // Remaining budget for the traversal

    // Starts at `max_input_nodes`, decremented once per visited selection.
    input_budget: u32,
    // Starts at `max_output_nodes`, decremented by a field's multiplicity.
    output_budget: u32,
    // Deepest nesting recorded so far across the operations checked.
    depth_seen: u32,
}
69
/// Summary of how much of each limit a request consumed, as produced by
/// `LimitsTraversal::finish`. Serialized with camelCase keys.
#[derive(Clone, Debug, Default, Serialize)]
#[serde(rename_all = "camelCase")]
struct Usage {
    /// Input node budget that was spent (limit minus remaining budget).
    input_nodes: u32,
    /// Output node budget that was spent (limit minus remaining budget).
    output_nodes: u32,
    /// Deepest nesting level recorded by the traversal.
    depth: u32,
    /// Number of variables supplied with the request.
    variables: u32,
    /// Number of fragment definitions in the document.
    fragments: u32,
    /// Length of the query string, in bytes.
    query_payload: u32,
}
80
81impl ShowUsage {
82    pub(crate) fn name() -> &'static HeaderName {
83        &LIMITS_HEADER
84    }
85}
86
87impl<'a> LimitsTraversal<'a> {
88    fn new(
89        limits: &Limits,
90        fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
91        variables: &'a Variables,
92    ) -> Self {
93        Self {
94            fragments,
95            variables,
96            default_page_size: limits.default_page_size,
97            max_input_nodes: limits.max_query_nodes,
98            max_output_nodes: limits.max_output_nodes,
99            max_depth: limits.max_query_depth,
100            input_budget: limits.max_query_nodes,
101            output_budget: limits.max_output_nodes,
102            depth_seen: 0,
103        }
104    }
105
106    /// Main entrypoint for checking all limits.
107    fn check_document(&mut self, doc: &ExecutableDocument) -> ServerResult<()> {
108        for (_name, op) in doc.operations.iter() {
109            self.check_input_limits(op)?;
110            self.check_output_limits(op)?;
111        }
112        Ok(())
113    }
114
115    /// Test that the operation meets input limits (number of nodes and depth).
116    fn check_input_limits(&mut self, op: &Positioned<OperationDefinition>) -> ServerResult<()> {
117        let mut next_level = vec![];
118        let mut curr_level = vec![];
119        let mut depth_budget = self.max_depth;
120
121        next_level.extend(&op.node.selection_set.node.items);
122        while let Some(next) = next_level.first() {
123            if depth_budget == 0 {
124                return Err(graphql_error_at_pos(
125                    code::BAD_USER_INPUT,
126                    format!("Query nesting is over {}", self.max_depth),
127                    next.pos,
128                ));
129            } else {
130                depth_budget -= 1;
131            }
132
133            mem::swap(&mut next_level, &mut curr_level);
134
135            for selection in curr_level.drain(..) {
136                if self.input_budget == 0 {
137                    return Err(graphql_error_at_pos(
138                        code::BAD_USER_INPUT,
139                        format!("Query has over {} nodes", self.max_input_nodes),
140                        selection.pos,
141                    ));
142                } else {
143                    self.input_budget -= 1;
144                }
145
146                match &selection.node {
147                    Selection::Field(f) => {
148                        next_level.extend(&f.node.selection_set.node.items);
149                    }
150
151                    Selection::InlineFragment(f) => {
152                        next_level.extend(&f.node.selection_set.node.items);
153                    }
154
155                    Selection::FragmentSpread(fs) => {
156                        let name = &fs.node.fragment_name.node;
157                        let def = self.fragments.get(name).ok_or_else(|| {
158                            graphql_error_at_pos(
159                                code::INTERNAL_SERVER_ERROR,
160                                format!("Fragment {name} referred to but not found in document"),
161                                fs.pos,
162                            )
163                        })?;
164
165                        next_level.extend(&def.node.selection_set.node.items);
166                    }
167                }
168            }
169        }
170
171        self.depth_seen = self.depth_seen.max(self.max_depth - depth_budget);
172        Ok(())
173    }
174
175    /// Check that the operation's output node estimate will not exceed the
176    /// service's limit.
177    ///
178    /// This check must be done after the input limit check, because it relies
179    /// on the query depth being bounded to protect it from recursing too
180    /// deeply.
181    fn check_output_limits(&mut self, op: &Positioned<OperationDefinition>) -> ServerResult<()> {
182        for selection in &op.node.selection_set.node.items {
183            self.traverse_selection_for_output(selection, 1, None)?;
184        }
185        Ok(())
186    }
187
188    /// Account for the estimated output size of this selection and its
189    /// children.
190    ///
191    /// `multiplicity` is the number of times this selection will be output, on
192    /// account of being nested within paginated ancestors.
193    ///
194    /// If this field is inside a connection, but not inside one of its fields,
195    /// `page_size` is the size of the connection's page.
196    fn traverse_selection_for_output(
197        &mut self,
198        selection: &Positioned<Selection>,
199        multiplicity: u32,
200        page_size: Option<u32>,
201    ) -> ServerResult<()> {
202        match &selection.node {
203            Selection::Field(f) => {
204                if multiplicity > self.output_budget {
205                    return Err(self.output_node_error());
206                } else {
207                    self.output_budget -= multiplicity;
208                }
209
210                // If the field being traversed is a connection field, increase multiplicity by
211                // a factor of page size. This operation can fail due to
212                // overflow, which will be treated as a limits check failure,
213                // even if the resulting value does not get used for anything.
214                let name = &f.node.name.node;
215                let multiplicity = 'm: {
216                    if !CONNECTION_FIELDS.contains(&name.as_str()) {
217                        break 'm multiplicity;
218                    }
219
220                    let Some(page_size) = page_size else {
221                        break 'm multiplicity;
222                    };
223
224                    multiplicity
225                        .checked_mul(page_size)
226                        .ok_or_else(|| self.output_node_error())?
227                };
228
229                let page_size = self.connection_page_size(f)?;
230                for selection in &f.node.selection_set.node.items {
231                    self.traverse_selection_for_output(selection, multiplicity, page_size)?;
232                }
233            }
234
235            // Just recurse through fragments, because they are inlined into their "call site".
236            Selection::InlineFragment(f) => {
237                for selection in f.node.selection_set.node.items.iter() {
238                    self.traverse_selection_for_output(selection, multiplicity, page_size)?;
239                }
240            }
241
242            Selection::FragmentSpread(fs) => {
243                let name = &fs.node.fragment_name.node;
244                let def = self.fragments.get(name).ok_or_else(|| {
245                    graphql_error_at_pos(
246                        code::INTERNAL_SERVER_ERROR,
247                        format!("Fragment {name} referred to but not found in document"),
248                        fs.pos,
249                    )
250                })?;
251
252                for selection in def.node.selection_set.node.items.iter() {
253                    self.traverse_selection_for_output(selection, multiplicity, page_size)?;
254                }
255            }
256        }
257
258        Ok(())
259    }
260
261    /// If the field `f` is a connection, extract its page size, otherwise
262    /// return `None`. Returns an error if the page size cannot be
263    /// represented as a `u32`.
264    fn connection_page_size(&mut self, f: &Positioned<Field>) -> ServerResult<Option<u32>> {
265        if !self.is_connection(f) {
266            return Ok(None);
267        }
268
269        let first = f.node.get_argument("first");
270        let last = f.node.get_argument("last");
271
272        let page_size = match (self.resolve_u64(first), self.resolve_u64(last)) {
273            (Some(f), Some(l)) => f.max(l),
274            (Some(p), _) | (_, Some(p)) => p,
275            (None, None) => self.default_page_size as u64,
276        };
277
278        Ok(Some(
279            page_size.try_into().map_err(|_| self.output_node_error())?,
280        ))
281    }
282
283    /// Checks if the given field corresponds to a connection based on whether
284    /// it contains a selection for `edges` or `nodes`. That selection could
285    /// be immediately in that field's selection set, or nested within a
286    /// fragment or inline fragment spread.
287    fn is_connection(&self, f: &Positioned<Field>) -> bool {
288        f.node
289            .selection_set
290            .node
291            .items
292            .iter()
293            .any(|s| self.has_connection_fields(s))
294    }
295
296    /// Look for fields that suggest the container for this selection is a
297    /// connection. Recurses through fragment and inline fragment
298    /// applications, but does not look recursively through fields, as only
299    /// the fields requested from the immediate parent are relevant.
300    fn has_connection_fields(&self, s: &Positioned<Selection>) -> bool {
301        match &s.node {
302            Selection::Field(f) => {
303                let name = &f.node.name.node;
304                CONNECTION_FIELDS.contains(&name.as_str())
305            }
306
307            Selection::InlineFragment(f) => f
308                .node
309                .selection_set
310                .node
311                .items
312                .iter()
313                .any(|s| self.has_connection_fields(s)),
314
315            Selection::FragmentSpread(fs) => {
316                let name = &fs.node.fragment_name.node;
317                let Some(def) = self.fragments.get(name) else {
318                    return false;
319                };
320
321                def.node
322                    .selection_set
323                    .node
324                    .items
325                    .iter()
326                    .any(|s| self.has_connection_fields(s))
327            }
328        }
329    }
330
331    /// Translate a GraphQL value into a u64, if possible, resolving variables
332    /// if necessary.
333    fn resolve_u64(&self, value: Option<&Positioned<Value>>) -> Option<u64> {
334        match &value?.node {
335            Value::Number(num) => num,
336
337            Value::Variable(var) => {
338                if let ConstValue::Number(num) = self.variables.get(var)? {
339                    num
340                } else {
341                    return None;
342                }
343            }
344
345            _ => return None,
346        }
347        .as_u64()
348    }
349
350    /// Error returned if output node estimate exceeds limit. Also sets the
351    /// output budget to zero, to indicate that it has been spent (This is
352    /// done because unlike other budgets, the output budget is not
353    /// decremented one unit at a time, so we can have hit the limit previously
354    /// but still have budget left over).
355    fn output_node_error(&mut self) -> ServerError {
356        self.output_budget = 0;
357        graphql_error(
358            code::BAD_USER_INPUT,
359            format!("Estimated output nodes exceeds {}", self.max_output_nodes),
360        )
361    }
362
363    /// Finish the traversal and report its usage.
364    fn finish(self, query_payload: u32) -> Usage {
365        Usage {
366            input_nodes: self.max_input_nodes - self.input_budget,
367            output_nodes: self.max_output_nodes - self.output_budget,
368            depth: self.depth_seen,
369            variables: self.variables.len() as u32,
370            fragments: self.fragments.len() as u32,
371            query_payload,
372        }
373    }
374}
375
376impl Usage {
377    fn report(&self, metrics: &Metrics) {
378        metrics
379            .request_metrics
380            .input_nodes
381            .observe(self.input_nodes as f64);
382        metrics
383            .request_metrics
384            .output_nodes
385            .observe(self.output_nodes as f64);
386        metrics
387            .request_metrics
388            .query_depth
389            .observe(self.depth as f64);
390        metrics
391            .request_metrics
392            .query_payload_size
393            .observe(self.query_payload as f64);
394    }
395}
396
397impl ExtensionFactory for QueryLimitsChecker {
398    fn create(&self) -> Arc<dyn Extension> {
399        Arc::new(QueryLimitsCheckerExt {
400            usage: Mutex::new(None),
401        })
402    }
403}
404
405#[async_trait]
406impl Extension for QueryLimitsCheckerExt {
407    async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
408        let resp = next.run(ctx).await;
409        let usage = self.usage.lock().unwrap().take();
410        if let Some(usage) = usage {
411            resp.extension("usage", value!(usage))
412        } else {
413            resp
414        }
415    }
416
417    /// Validates the query against the limits set in the service config
418    /// If the limits are hit, the operation terminates early
419    async fn parse_query(
420        &self,
421        ctx: &ExtensionContext<'_>,
422        query: &str,
423        variables: &Variables,
424        next: NextParseQuery<'_>,
425    ) -> ServerResult<ExecutableDocument> {
426        let query_id: &Uuid = ctx.data_unchecked();
427        let session_id: &SocketAddr = ctx.data_unchecked();
428        let metrics: &Metrics = ctx.data_unchecked();
429        let cfg: &ServiceConfig = ctx.data_unchecked();
430        let instant = Instant::now();
431
432        if query.len() > cfg.limits.max_query_payload_size as usize {
433            metrics
434                .request_metrics
435                .query_payload_too_large_size
436                .observe(query.len() as f64);
437            info!(
438                query_id = %query_id,
439                session_id = %session_id,
440                error_code = code::BAD_USER_INPUT,
441                "Query payload is too large: {}",
442                query.len()
443            );
444
445            return Err(graphql_error(
446                code::BAD_USER_INPUT,
447                format!(
448                    "Query payload is too large. The maximum allowed is {} bytes",
449                    cfg.limits.max_query_payload_size
450                ),
451            ));
452        }
453
454        // Document layout of the query
455        let doc = next.run(ctx, query, variables).await?;
456
457        // If the query is pure introspection, we don't need to check the limits. Pure
458        // introspection queries are queries that only have one operation with
459        // one field and that field is a `__schema` query
460        if let DocumentOperations::Single(op) = &doc.operations {
461            if let [field] = &op.node.selection_set.node.items[..] {
462                if let Selection::Field(f) = &field.node {
463                    if f.node.name.node == "__schema" {
464                        return Ok(doc);
465                    }
466                }
467            }
468        }
469
470        let mut traversal = LimitsTraversal::new(&cfg.limits, &doc.fragments, variables);
471        let res = traversal.check_document(&doc);
472        let usage = traversal.finish(query.len() as u32);
473        metrics.query_validation_latency(instant.elapsed());
474        usage.report(metrics);
475
476        res.map(|()| {
477            if ctx.data_opt::<ShowUsage>().is_some() {
478                *self.usage.lock().unwrap() = Some(usage);
479            }
480
481            doc
482        })
483    }
484}