// iota_graphql_rpc/extensions/directive_checker.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

5use std::sync::Arc;
6
7use async_graphql::{
8    Positioned, ServerResult,
9    extensions::{Extension, ExtensionContext, ExtensionFactory, NextParseQuery},
10    parser::types::{Directive, ExecutableDocument, Selection},
11};
12use async_graphql_value::Variables;
13use async_trait::async_trait;
14
15use crate::error::{code, graphql_error_at_pos};
16
17const ALLOWED_DIRECTIVES: [&str; 2] = ["include", "skip"];
18
19/// Extension factory to add a check that all the directives used in the query
20/// are accepted and understood by the service.
21pub(crate) struct DirectiveChecker;
22
23struct DirectiveCheckerExt;
24
25impl ExtensionFactory for DirectiveChecker {
26    fn create(&self) -> Arc<dyn Extension> {
27        Arc::new(DirectiveCheckerExt)
28    }
29}
30
31#[async_trait]
32impl Extension for DirectiveCheckerExt {
33    async fn parse_query(
34        &self,
35        ctx: &ExtensionContext<'_>,
36        query: &str,
37        variables: &Variables,
38        next: NextParseQuery<'_>,
39    ) -> ServerResult<ExecutableDocument> {
40        let doc = next.run(ctx, query, variables).await?;
41
42        let mut selection_sets = vec![];
43        for fragment in doc.fragments.values() {
44            check_directives(&fragment.node.directives)?;
45            selection_sets.push(&fragment.node.selection_set);
46        }
47
48        for (_name, op) in doc.operations.iter() {
49            check_directives(&op.node.directives)?;
50
51            for var in &op.node.variable_definitions {
52                check_directives(&var.node.directives)?;
53            }
54
55            selection_sets.push(&op.node.selection_set);
56        }
57
58        while let Some(selection_set) = selection_sets.pop() {
59            for selection in &selection_set.node.items {
60                match &selection.node {
61                    Selection::Field(field) => {
62                        check_directives(&field.node.directives)?;
63                        selection_sets.push(&field.node.selection_set);
64                    }
65                    Selection::FragmentSpread(spread) => {
66                        check_directives(&spread.node.directives)?;
67                    }
68                    Selection::InlineFragment(fragment) => {
69                        check_directives(&fragment.node.directives)?;
70                        selection_sets.push(&fragment.node.selection_set);
71                    }
72                }
73            }
74        }
75
76        Ok(doc)
77    }
78}
79
80fn check_directives(directives: &[Positioned<Directive>]) -> ServerResult<()> {
81    for directive in directives {
82        let name = &directive.node.name.node;
83        if !ALLOWED_DIRECTIVES.contains(&name.as_str()) {
84            let supported: Vec<_> = ALLOWED_DIRECTIVES
85                .iter()
86                .map(|s| format!("`@{s}`"))
87                .collect();
88
89            return Err(graphql_error_at_pos(
90                code::BAD_USER_INPUT,
91                format!(
92                    "Directive `@{name}` is not supported. Supported directives are {}",
93                    supported.join(", "),
94                ),
95                directive.pos,
96            ));
97        }
98    }
99    Ok(())
100}