iota_graphql_rpc/extensions/query_limits_checker.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::{HashMap, HashSet, VecDeque},
7    mem,
8    net::SocketAddr,
9    sync::{Arc, Mutex},
10    time::Instant,
11};
12
13use async_graphql::{
14    Name, Pos, Positioned, Response, ServerError, ServerResult, Variables,
15    extensions::{Extension, ExtensionContext, ExtensionFactory, NextParseQuery, NextRequest},
16    parser::types::{
17        DocumentOperations, ExecutableDocument, Field, FragmentDefinition, OperationDefinition,
18        Selection,
19    },
20    value,
21};
22use async_graphql_value::{ConstValue, Value};
23use async_trait::async_trait;
24use axum::http::HeaderName;
25use iota_graphql_rpc_headers::LIMITS_HEADER;
26use serde::Serialize;
27use tracing::{error, info};
28use uuid::Uuid;
29
30use crate::{
31    config::{Limits, ServiceConfig},
32    error::{code, graphql_error, graphql_error_at_pos},
33    metrics::Metrics,
34};
35
/// Name of the top-level field that dry-runs a transaction block. The values
/// of its arguments are charged against the transaction payload limit.
const DRY_RUN_TX_BLOCK: &str = "dryRunTransactionBlock";
/// Name of the top-level field that executes a transaction block. The values
/// of its arguments are charged against the transaction payload limit.
const EXECUTE_TX_BLOCK: &str = "executeTransactionBlock";

/// The size of the query payload in bytes.
///
/// This is resolved from the `Content-Length` request header.
#[derive(Clone, Copy, Debug)]
pub(crate) struct PayloadSize(pub u64);

/// Names of the fields whose presence in a field's selection set marks that
/// field as a connection (a paginated field whose page size multiplies the
/// output estimate of its children).
pub(crate) const CONNECTION_FIELDS: [&str; 2] = ["edges", "nodes"];
46
47/// Extension factory for adding checks that the query is within configurable
48/// limits.
49pub(crate) struct QueryLimitsChecker;
50
51#[derive(Debug, Default)]
52struct QueryLimitsCheckerExt {
53    usage: Mutex<Option<Usage>>,
54}
55
56/// Only display usage information if this header was in the request.
57pub(crate) struct ShowUsage;
58
59/// Builds error messages and reports them to tracing.
60struct Reporter<'a> {
61    limits: &'a Limits,
62    query_id: &'a Uuid,
63    session_id: &'a SocketAddr,
64}
65
66impl<'a> Reporter<'a> {
67    fn new(ctx: &'a ExtensionContext<'a>) -> Self {
68        let cfg: &ServiceConfig = ctx.data_unchecked();
69        Self {
70            limits: &cfg.limits,
71            query_id: ctx.data_unchecked(),
72            session_id: ctx.data_unchecked(),
73        }
74    }
75
76    /// Error returned if a fragment is referred to but not found in the
77    /// document.
78    fn fragment_not_found_error(&self, name: &Name, pos: Pos) -> ServerError {
79        self.graphql_error_at_pos(
80            code::BAD_USER_INPUT,
81            format!("Fragment {name} referred to but not found in document"),
82            pos,
83        )
84    }
85
86    /// Error returned if output node estimate exceeds limit.
87    fn output_node_error(&self) -> ServerError {
88        self.graphql_error(
89            code::BAD_USER_INPUT,
90            format!(
91                "Estimated output nodes exceeds {}",
92                self.limits.max_output_nodes
93            ),
94        )
95    }
96
97    /// Error returned if the payload size exceeds the limit.
98    fn payload_size_error(&self, message: &str) -> ServerError {
99        self.graphql_error(
100            code::BAD_USER_INPUT,
101            format!(
102                "{message}. Requests are limited to {max_tx_payload} bytes or fewer on \
103                 transaction payloads (all inputs to executeTransactionBlock or \
104                 dryRunTransactionBlock) and the rest of the request (the query part) must be \
105                 {max_query_payload} bytes or fewer.",
106                max_tx_payload = self.limits.max_tx_payload_size,
107                max_query_payload = self.limits.max_query_payload_size,
108            ),
109        )
110    }
111
112    /// Build a GraphQL Server Error and also log it.
113    fn graphql_error(&self, code: &str, message: String) -> ServerError {
114        self.log_error(code, &message);
115        graphql_error(code, message)
116    }
117
118    /// Like `graphql_error` but for an error at a specific position in the
119    /// query.
120    fn graphql_error_at_pos(&self, code: &str, message: String, pos: Pos) -> ServerError {
121        self.log_error(code, &message);
122        graphql_error_at_pos(code, message, pos)
123    }
124
125    /// Log an error (used before returning an error response.
126    fn log_error(&self, error_code: &str, message: &str) {
127        if error_code == code::INTERNAL_SERVER_ERROR {
128            error!(
129                query_id = %self.query_id,
130                session_id = %self.session_id,
131                error_code,
132                "Internal error while checking limits: {message}",
133            );
134        } else {
135            info!(
136                query_id = %self.query_id,
137                session_id = %self.session_id,
138                error_code,
139                "Limits error: {message}",
140            );
141        }
142    }
143}
144
145/// State for traversing a document to check for limits. Holds on to
146/// environments for looking up variables and fragments, limits, and the
147/// remainder of the limit that can be used.
148struct LimitsTraversal<'a> {
149    /// Environments for resolving lookups in the document
150    fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
151    variables: &'a Variables,
152
153    /// Creates and traces errors
154    reporter: &'a Reporter<'a>,
155    /// Raw size of the request
156    payload_size: u64,
157    /// Variables that are used in transaction execution and dry-runs.
158    ///
159    /// If these variables are used multiple times, the size of their
160    /// contents should not be double-counted.
161    tx_variables_used: HashSet<&'a Name>,
162
163    /// Remaining budget for the traversal
164    tx_payload_budget: u32,
165    input_budget: u32,
166    output_budget: u32,
167    depth_seen: u32,
168}
169
170#[derive(Clone, Debug, Default, Serialize)]
171#[serde(rename_all = "camelCase")]
172struct Usage {
173    input_nodes: u32,
174    output_nodes: u32,
175    depth: u32,
176    variables: u32,
177    fragments: u32,
178    query_payload: u32,
179}
180
181impl ShowUsage {
182    pub(crate) fn name() -> &'static HeaderName {
183        &LIMITS_HEADER
184    }
185}
186
187impl<'a> LimitsTraversal<'a> {
188    fn new(
189        PayloadSize(payload_size): PayloadSize,
190        reporter: &'a Reporter<'a>,
191        fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
192        variables: &'a Variables,
193    ) -> Self {
194        Self {
195            fragments,
196            variables,
197            payload_size,
198            reporter,
199            input_budget: reporter.limits.max_query_nodes,
200            output_budget: reporter.limits.max_output_nodes,
201            tx_payload_budget: reporter.limits.max_tx_payload_size,
202            tx_variables_used: HashSet::new(),
203            depth_seen: 0,
204        }
205    }
206
207    /// Main entrypoint for checking all limits.
208    fn check_document(&mut self, doc: &'a ExecutableDocument) -> ServerResult<()> {
209        // First, check the size of the query inputs. This is done using a non-recursive
210        // algorithm in case the input has too many nodes or is too deep. This
211        // allows subsequent checks to be implemented recursively.
212        for (_name, op) in doc.operations.iter() {
213            self.check_input_limits(op)?;
214        }
215        // Then gather inputs to transaction execution and dry-run nodes, and make sure
216        // these are within budget, cumulatively.
217        for (_name, op) in doc.operations.iter() {
218            self.check_tx_payload(op)?;
219        }
220        // Next, with the transaction payloads accounted for, ensure the remaining query
221        // is within the size limit.
222        let limits = self.reporter.limits;
223        let tx_payload_size = (limits.max_tx_payload_size - self.tx_payload_budget) as u64;
224        let query_payload_size = self.payload_size - tx_payload_size;
225        if query_payload_size > limits.max_query_payload_size as u64 {
226            let message = format!("Query part too large: {query_payload_size} bytes");
227            return Err(self.reporter.payload_size_error(&message));
228        }
229        // Finally, run output node estimation, to check that the output won't contain
230        // too many nodes, in the worst case.
231        for (_name, op) in doc.operations.iter() {
232            self.check_output_limits(op)?;
233        }
234
235        Ok(())
236    }
237
238    /// Test that the operation meets input limits (number of nodes and depth).
239    fn check_input_limits(&mut self, op: &Positioned<OperationDefinition>) -> ServerResult<()> {
240        let limits = self.reporter.limits;
241
242        let mut next_level = vec![];
243        let mut curr_level = vec![];
244        let mut depth_budget = limits.max_query_depth;
245
246        next_level.extend(&op.node.selection_set.node.items);
247        while let Some(next) = next_level.first() {
248            if depth_budget == 0 {
249                return Err(self.reporter.graphql_error_at_pos(
250                    code::BAD_USER_INPUT,
251                    format!("Query nesting is over {}", limits.max_query_depth),
252                    next.pos,
253                ));
254            } else {
255                depth_budget -= 1;
256            }
257
258            mem::swap(&mut next_level, &mut curr_level);
259
260            for selection in curr_level.drain(..) {
261                if self.input_budget == 0 {
262                    return Err(self.reporter.graphql_error_at_pos(
263                        code::BAD_USER_INPUT,
264                        format!("Query has over {} nodes", limits.max_query_nodes),
265                        selection.pos,
266                    ));
267                } else {
268                    self.input_budget -= 1;
269                }
270
271                match &selection.node {
272                    Selection::Field(f) => {
273                        next_level.extend(&f.node.selection_set.node.items);
274                    }
275
276                    Selection::InlineFragment(f) => {
277                        next_level.extend(&f.node.selection_set.node.items);
278                    }
279
280                    Selection::FragmentSpread(fs) => {
281                        let name = &fs.node.fragment_name.node;
282                        let def = self
283                            .fragments
284                            .get(name)
285                            .ok_or_else(|| self.reporter.fragment_not_found_error(name, fs.pos))?;
286
287                        next_level.extend(&def.node.selection_set.node.items);
288                    }
289                }
290            }
291        }
292
293        self.depth_seen = self.depth_seen.max(limits.max_query_depth - depth_budget);
294        Ok(())
295    }
296
297    /// Test that inputs to `executeTransactionBlock` and
298    /// `dryRunTransactionBlock` take up less space than the service's
299    /// transaction payload limit, cumulatively.
300    ///
301    /// This check must be done after the input limit check, because it relies
302    /// on the query depth being bounded to protect it from recursing too
303    /// deeply.
304    fn check_tx_payload(&mut self, op: &'a Positioned<OperationDefinition>) -> ServerResult<()> {
305        let tx_arg_values = op
306            .node
307            .selection_set
308            .node
309            .items
310            .iter()
311            .flat_map(|selection| {
312                TxArgValueIter::new(&selection.node, self.fragments, self.reporter)
313            });
314        for value in tx_arg_values {
315            let cost = value?.cost(self.variables, &mut self.tx_variables_used);
316            if cost > self.tx_payload_budget as usize {
317                // Ensure that the caller is aware that the budget is spent.
318                self.tx_payload_budget = 0;
319                return Err(self.tx_payload_size_error());
320            } else {
321                self.tx_payload_budget -= cost as u32;
322            }
323        }
324        Ok(())
325    }
326
327    /// Check that the operation's output node estimate will not exceed the
328    /// service's limit.
329    ///
330    /// This check must be done after the input limit check, because it relies
331    /// on the query depth being bounded to protect it from recursing too
332    /// deeply.
333    fn check_output_limits(&mut self, op: &Positioned<OperationDefinition>) -> ServerResult<()> {
334        for selection in &op.node.selection_set.node.items {
335            self.traverse_selection_for_output(selection, 1, None)?;
336        }
337        Ok(())
338    }
339
340    /// Account for the estimated output size of this selection and its
341    /// children.
342    ///
343    /// `multiplicity` is the number of times this selection will be output, on
344    /// account of being nested within paginated ancestors.
345    ///
346    /// If this field is inside a connection, but not inside one of its fields,
347    /// `page_size` is the size of the connection's page.
348    fn traverse_selection_for_output(
349        &mut self,
350        selection: &Positioned<Selection>,
351        multiplicity: u32,
352        page_size: Option<u32>,
353    ) -> ServerResult<()> {
354        match &selection.node {
355            Selection::Field(f) => {
356                if multiplicity > self.output_budget {
357                    return Err(self.output_node_error());
358                } else {
359                    self.output_budget -= multiplicity;
360                }
361
362                // If the field being traversed is a connection field, increase multiplicity by
363                // a factor of page size. This operation can fail due to
364                // overflow, which will be treated as a limits check failure,
365                // even if the resulting value does not get used for anything.
366                let name = &f.node.name.node;
367                let multiplicity = 'm: {
368                    if !CONNECTION_FIELDS.contains(&name.as_str()) {
369                        break 'm multiplicity;
370                    }
371
372                    let Some(page_size) = page_size else {
373                        break 'm multiplicity;
374                    };
375
376                    multiplicity
377                        .checked_mul(page_size)
378                        .ok_or_else(|| self.output_node_error())?
379                };
380
381                let page_size = self.connection_page_size(f)?;
382                for selection in &f.node.selection_set.node.items {
383                    self.traverse_selection_for_output(selection, multiplicity, page_size)?;
384                }
385            }
386
387            // Just recurse through fragments, because they are inlined into their "call site".
388            Selection::InlineFragment(f) => {
389                for selection in &f.node.selection_set.node.items {
390                    self.traverse_selection_for_output(selection, multiplicity, page_size)?;
391                }
392            }
393
394            Selection::FragmentSpread(fs) => {
395                let name = &fs.node.fragment_name.node;
396                let def = self
397                    .fragments
398                    .get(name)
399                    .ok_or_else(|| self.reporter.fragment_not_found_error(name, fs.pos))?;
400
401                for selection in &def.node.selection_set.node.items {
402                    self.traverse_selection_for_output(selection, multiplicity, page_size)?;
403                }
404            }
405        }
406
407        Ok(())
408    }
409
410    /// Error returned if transaction payloads exceed limit.
411    fn tx_payload_size_error(&mut self) -> ServerError {
412        self.reporter
413            .payload_size_error("Transaction payload too large")
414    }
415
416    /// If the field `f` is a connection, extract its page size, otherwise
417    /// return `None`. Returns an error if the page size cannot be
418    /// represented as a `u32`.
419    fn connection_page_size(&mut self, f: &Positioned<Field>) -> ServerResult<Option<u32>> {
420        if !self.is_connection(f) {
421            return Ok(None);
422        }
423
424        let first = f.node.get_argument("first");
425        let last = f.node.get_argument("last");
426
427        let page_size = match (self.resolve_u64(first), self.resolve_u64(last)) {
428            (Some(f), Some(l)) => f.max(l),
429            (Some(p), _) | (_, Some(p)) => p,
430            (None, None) => self.reporter.limits.default_page_size as u64,
431        };
432
433        Ok(Some(
434            page_size.try_into().map_err(|_| self.output_node_error())?,
435        ))
436    }
437
438    /// Checks if the given field corresponds to a connection based on whether
439    /// it contains a selection for `edges` or `nodes`. That selection could
440    /// be immediately in that field's selection set, or nested within a
441    /// fragment or inline fragment spread.
442    fn is_connection(&self, f: &Positioned<Field>) -> bool {
443        f.node
444            .selection_set
445            .node
446            .items
447            .iter()
448            .any(|s| self.has_connection_fields(s))
449    }
450
451    /// Look for fields that suggest the container for this selection is a
452    /// connection. Recurses through fragment and inline fragment
453    /// applications, but does not look recursively through fields, as only
454    /// the fields requested from the immediate parent are relevant.
455    fn has_connection_fields(&self, s: &Positioned<Selection>) -> bool {
456        match &s.node {
457            Selection::Field(f) => {
458                let name = &f.node.name.node;
459                CONNECTION_FIELDS.contains(&name.as_str())
460            }
461
462            Selection::InlineFragment(f) => f
463                .node
464                .selection_set
465                .node
466                .items
467                .iter()
468                .any(|s| self.has_connection_fields(s)),
469
470            Selection::FragmentSpread(fs) => {
471                let name = &fs.node.fragment_name.node;
472                let Some(def) = self.fragments.get(name) else {
473                    return false;
474                };
475
476                def.node
477                    .selection_set
478                    .node
479                    .items
480                    .iter()
481                    .any(|s| self.has_connection_fields(s))
482            }
483        }
484    }
485
486    /// Translate a GraphQL value into a u64, if possible, resolving variables
487    /// if necessary.
488    fn resolve_u64(&self, value: Option<&Positioned<Value>>) -> Option<u64> {
489        match &value?.node {
490            Value::Number(num) => num,
491
492            Value::Variable(var) => {
493                if let ConstValue::Number(num) = self.variables.get(var)? {
494                    num
495                } else {
496                    return None;
497                }
498            }
499
500            _ => return None,
501        }
502        .as_u64()
503    }
504
505    /// Error returned if output node estimate exceeds limit. Also sets the
506    /// output budget to zero, to indicate that it has been spent (This is
507    /// done because unlike other budgets, the output budget is not
508    /// decremented one unit at a time, so we can have hit the limit previously
509    /// but still have budget left over).
510    fn output_node_error(&mut self) -> ServerError {
511        self.output_budget = 0;
512        self.reporter.output_node_error()
513    }
514
515    /// Finish the traversal and report its usage.
516    fn finish(self, query_payload: u32) -> Usage {
517        let limits = self.reporter.limits;
518        Usage {
519            input_nodes: limits.max_query_nodes - self.input_budget,
520            output_nodes: limits.max_output_nodes - self.output_budget,
521            depth: self.depth_seen,
522            variables: self.variables.len() as u32,
523            fragments: self.fragments.len() as u32,
524            query_payload,
525        }
526    }
527}
528
529#[derive(Debug)]
530enum ParsedValue<'a> {
531    /// An unresolved GraphQL [`Value`].
532    GraphQL(&'a Value),
533    /// A resolved value in a GraphQL query.
534    Resolved(&'a ConstValue),
535}
536
/// How often one variable is referenced within a parsed value, and what a
/// single reference costs.
struct VariableUsage {
    /// Cost charged for one reference to the variable.
    cost: usize,
    /// Number of references to the variable seen so far.
    count: usize,
}

impl VariableUsage {
    /// Usage for the first reference to a variable whose value costs `cost`.
    fn new(cost: usize) -> Self {
        VariableUsage { count: 1, cost }
    }

    /// Record one more reference to the same variable.
    fn increase_count(&mut self) {
        self.count += 1;
    }
}
553
554/// Cost report for a parsed value.
555#[derive(Default)]
556struct ValueCostReport<'a> {
557    /// Total cost before deducting reused variables.
558    gross: usize,
559    /// Variable usage while parsing the value.
560    variables_used: HashMap<&'a Name, VariableUsage>,
561}
562
563impl<'a> ValueCostReport<'a> {
564    fn new(cost: usize) -> Self {
565        Self {
566            gross: cost,
567            variables_used: HashMap::new(),
568        }
569    }
570
571    fn add_variable(&mut self, name: &'a Name, cost: usize) {
572        self.variables_used
573            .entry(name)
574            .and_modify(|usage| usage.increase_count())
575            .or_insert_with(|| VariableUsage::new(cost));
576    }
577
578    fn merge_report(&mut self, other: ValueCostReport<'a>) {
579        self.gross += other.gross;
580        for (name, usage) in other.variables_used {
581            self.variables_used
582                .entry(name)
583                .and_modify(|existing| {
584                    existing.count += usage.count;
585                })
586                .or_insert(usage);
587        }
588    }
589}
590
591/// A parsed transaction argument value.
592struct TxArgValue<'a>(ParsedValue<'a>);
593
594impl<'a> TxArgValue<'a> {
595    /// Cost of the value after deducting reused variables
596    fn cost(
597        self,
598        document_variables: &'a Variables,
599        used_variables: &mut HashSet<&'a Name>,
600    ) -> usize {
601        let cost_report = self.cost_report(document_variables);
602        let mut net_cost = cost_report.gross;
603        for (name, VariableUsage { cost, count }) in &cost_report.variables_used {
604            if used_variables.insert(name) {
605                // variable not used before, deduct only repeated uses
606                net_cost -= (count - 1) * cost
607            } else {
608                // variable already used, deduct all uses
609                net_cost -= count * cost
610            }
611        }
612        net_cost
613    }
614
615    /// Evaluate the cost report for transaction argument values.
616    fn cost_report(self, variables: &'a Variables) -> ValueCostReport<'a> {
617        use ConstValue as CV;
618        use ParsedValue::{GraphQL, Resolved};
619        use Value as V;
620
621        match self.0 {
622            GraphQL(V::String(s)) | Resolved(CV::String(s)) => {
623                // Pay for the string, plus the quotes around it.
624                ValueCostReport::new(s.len() + 2)
625            }
626            GraphQL(V::List(vs)) => {
627                // Pay for the opening and closing brackets and every comma up-front so that
628                // deeply nested lists are not free.
629                let mut cost = ValueCostReport::new(vs.len().saturating_sub(1) + 2);
630                for value in vs {
631                    cost.merge_report(Self(ParsedValue::GraphQL(value)).cost_report(variables));
632                }
633                cost
634            }
635            Resolved(CV::List(vs)) => {
636                // Follows the `GraphQL` list evaluation.
637                let mut cost = ValueCostReport::new(vs.len().saturating_sub(1) + 2);
638                for value in vs {
639                    cost.merge_report(Self(ParsedValue::Resolved(value)).cost_report(variables));
640                }
641                cost
642            }
643            GraphQL(V::Variable(name)) => {
644                if let Some(value) = variables.get(name) {
645                    let mut cost = Self(ParsedValue::Resolved(value)).cost_report(variables);
646                    cost.add_variable(name, cost.gross);
647                    return cost;
648                }
649                Default::default()
650            }
651            _ => {
652                // Transaction payloads cannot be any of these types.
653                //
654                // From a limits perspective, it is safe to ignore these
655                // values here, because they will still
656                // be counted as part of the query payload (and so are still
657                // subject to a limit).
658                Default::default()
659            }
660        }
661    }
662}
663
664/// Iterator over transaction argument values.
665///
666/// Traverses a selection on the `dryRunTransactionBlock` or
667/// `executeTransactionBlock` query paths to find all argument values.
668struct TxArgValueIter<'a> {
669    selections: VecDeque<&'a Selection>,
670    arguments: VecDeque<TxArgValue<'a>>,
671    fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
672    reporter: &'a Reporter<'a>,
673}
674
675impl<'a> TxArgValueIter<'a> {
676    fn new(
677        selection: &'a Selection,
678        fragments: &'a HashMap<Name, Positioned<FragmentDefinition>>,
679        reporter: &'a Reporter<'a>,
680    ) -> Self {
681        Self {
682            selections: VecDeque::from([selection]),
683            arguments: VecDeque::new(),
684            fragments,
685            reporter,
686        }
687    }
688}
689
690impl<'a> Iterator for TxArgValueIter<'a> {
691    type Item = Result<TxArgValue<'a>, ServerError>;
692
693    fn next(&mut self) -> Option<Self::Item> {
694        if let Some(value) = self.arguments.pop_front() {
695            return Some(Ok(value));
696        }
697        let selection = self.selections.pop_front()?;
698        match selection {
699            Selection::Field(f) => {
700                let name = &f.node.name.node;
701                if name == DRY_RUN_TX_BLOCK || name == EXECUTE_TX_BLOCK {
702                    for (_name, value) in &f.node.arguments {
703                        self.arguments
704                            .push_back(TxArgValue(ParsedValue::GraphQL(&value.node)));
705                    }
706                }
707            }
708            Selection::InlineFragment(f) => {
709                self.selections
710                    .extend(f.node.selection_set.node.items.iter().map(|s| &s.node));
711            }
712
713            Selection::FragmentSpread(fs) => {
714                let name = &fs.node.fragment_name.node;
715                let Some(def) = self.fragments.get(name) else {
716                    return Some(Err(self.reporter.fragment_not_found_error(name, fs.pos)));
717                };
718                self.selections
719                    .extend(def.node.selection_set.node.items.iter().map(|s| &s.node));
720            }
721        }
722        self.next()
723    }
724}
725
726impl Usage {
727    fn report(&self, metrics: &Metrics) {
728        metrics
729            .request_metrics
730            .input_nodes
731            .observe(self.input_nodes as f64);
732        metrics
733            .request_metrics
734            .output_nodes
735            .observe(self.output_nodes as f64);
736        metrics
737            .request_metrics
738            .query_depth
739            .observe(self.depth as f64);
740        metrics
741            .request_metrics
742            .query_payload_size
743            .observe(self.query_payload as f64);
744    }
745}
746
747impl ExtensionFactory for QueryLimitsChecker {
748    fn create(&self) -> Arc<dyn Extension> {
749        Arc::new(QueryLimitsCheckerExt {
750            usage: Mutex::new(None),
751        })
752    }
753}
754
755#[async_trait]
756impl Extension for QueryLimitsCheckerExt {
757    async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
758        let resp = next.run(ctx).await;
759        let usage = self.usage.lock().unwrap().take();
760        if let Some(usage) = usage {
761            resp.extension("usage", value!(usage))
762        } else {
763            resp
764        }
765    }
766
767    /// Validates the query against the limits set in the service config
768    /// If the limits are hit, the operation terminates early
769    async fn parse_query(
770        &self,
771        ctx: &ExtensionContext<'_>,
772        query: &str,
773        variables: &Variables,
774        next: NextParseQuery<'_>,
775    ) -> ServerResult<ExecutableDocument> {
776        let metrics: &Metrics = ctx.data_unchecked();
777        let payload_size: &PayloadSize = ctx.data_unchecked();
778        let reporter = Reporter::new(ctx);
779        let instant = Instant::now();
780
781        // Make sure the request meets a basic size limit before trying to parse it.
782        let max_payload_size = reporter.limits.max_query_payload_size as u64
783            + reporter.limits.max_tx_payload_size as u64;
784
785        if payload_size.0 > max_payload_size {
786            let message = format!("Overall request too large: {} bytes", payload_size.0);
787            return Err(reporter.payload_size_error(&message));
788        }
789
790        // Document layout of the query
791        let doc = next.run(ctx, query, variables).await?;
792
793        // If the query is pure introspection, we don't need to check the limits. Pure
794        // introspection queries are queries that only have one operation with
795        // one field and that field is a `__schema` query
796        if let DocumentOperations::Single(op) = &doc.operations {
797            if let [field] = &op.node.selection_set.node.items[..] {
798                if let Selection::Field(f) = &field.node {
799                    if f.node.name.node == "__schema" {
800                        return Ok(doc);
801                    }
802                }
803            }
804        }
805
806        let mut traversal =
807            LimitsTraversal::new(*payload_size, &reporter, &doc.fragments, variables);
808        let res = traversal.check_document(&doc);
809        let usage = traversal.finish(query.len() as u32);
810        metrics.query_validation_latency(instant.elapsed());
811        usage.report(metrics);
812
813        res.map(|()| {
814            if ctx.data_opt::<ShowUsage>().is_some() {
815                *self.usage.lock().unwrap() = Some(usage);
816            }
817
818            doc
819        })
820    }
821}