iota_graphql_rpc/extensions/
directive_checker.rs1use std::sync::Arc;
6
7use async_graphql::{
8 Positioned, ServerResult,
9 extensions::{Extension, ExtensionContext, ExtensionFactory, NextParseQuery},
10 parser::types::{Directive, ExecutableDocument, Selection},
11};
12use async_graphql_value::Variables;
13use async_trait::async_trait;
14
15use crate::error::{code, graphql_error_at_pos};
16
/// Names (without the leading `@`) of the only directives a query may use. Any directive whose
/// name is not in this list causes the query to be rejected.
const ALLOWED_DIRECTIVES: [&str; 2] = [
    "include", // `@include`
    "skip",    // `@skip`
];
18
/// Extension factory that installs a check rejecting any query which uses a directive not
/// listed in `ALLOWED_DIRECTIVES`.
///
/// It is a stateless unit struct. The common traits are derived so it can be logged, copied and
/// default-constructed like any other plain value.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct DirectiveChecker;

/// The extension instance handed out by `DirectiveChecker`. It carries no state.
#[derive(Clone, Copy, Debug, Default)]
struct DirectiveCheckerExt;
24
25impl ExtensionFactory for DirectiveChecker {
26 fn create(&self) -> Arc<dyn Extension> {
27 Arc::new(DirectiveCheckerExt)
28 }
29}
30
31#[async_trait]
32impl Extension for DirectiveCheckerExt {
33 async fn parse_query(
34 &self,
35 ctx: &ExtensionContext<'_>,
36 query: &str,
37 variables: &Variables,
38 next: NextParseQuery<'_>,
39 ) -> ServerResult<ExecutableDocument> {
40 let doc = next.run(ctx, query, variables).await?;
41
42 let mut selection_sets = vec![];
43 for fragment in doc.fragments.values() {
44 check_directives(&fragment.node.directives)?;
45 selection_sets.push(&fragment.node.selection_set);
46 }
47
48 for (_name, op) in doc.operations.iter() {
49 check_directives(&op.node.directives)?;
50
51 for var in &op.node.variable_definitions {
52 check_directives(&var.node.directives)?;
53 }
54
55 selection_sets.push(&op.node.selection_set);
56 }
57
58 while let Some(selection_set) = selection_sets.pop() {
59 for selection in &selection_set.node.items {
60 match &selection.node {
61 Selection::Field(field) => {
62 check_directives(&field.node.directives)?;
63 selection_sets.push(&field.node.selection_set);
64 }
65 Selection::FragmentSpread(spread) => {
66 check_directives(&spread.node.directives)?;
67 }
68 Selection::InlineFragment(fragment) => {
69 check_directives(&fragment.node.directives)?;
70 selection_sets.push(&fragment.node.selection_set);
71 }
72 }
73 }
74 }
75
76 Ok(doc)
77 }
78}
79
80fn check_directives(directives: &[Positioned<Directive>]) -> ServerResult<()> {
81 for directive in directives {
82 let name = &directive.node.name.node;
83 if !ALLOWED_DIRECTIVES.contains(&name.as_str()) {
84 let supported: Vec<_> = ALLOWED_DIRECTIVES
85 .iter()
86 .map(|s| format!("`@{s}`"))
87 .collect();
88
89 return Err(graphql_error_at_pos(
90 code::BAD_USER_INPUT,
91 format!(
92 "Directive `@{name}` is not supported. Supported directives are {}",
93 supported.join(", "),
94 ),
95 directive.pos,
96 ));
97 }
98 }
99 Ok(())
100}