iota_graphql_rpc/extensions/
feature_gate.rs1use std::sync::Arc;
6
7use async_graphql::{
8 ServerError, ServerResult, Value,
9 extensions::{Extension, ExtensionContext, ExtensionFactory, NextResolve, ResolveInfo},
10};
11use async_trait::async_trait;
12
13use crate::{
14 config::ServiceConfig,
15 error::{code, graphql_error},
16 functional_group::functional_group,
17};
18
19pub(crate) struct FeatureGate;
20
21impl ExtensionFactory for FeatureGate {
22 fn create(&self) -> Arc<dyn Extension> {
23 Arc::new(FeatureGate)
24 }
25}
26
27#[async_trait]
28impl Extension for FeatureGate {
29 async fn resolve(
30 &self,
31 ctx: &ExtensionContext<'_>,
32 info: ResolveInfo<'_>,
33 next: NextResolve<'_>,
34 ) -> ServerResult<Option<Value>> {
35 let ResolveInfo {
36 parent_type,
37 name,
38 is_for_introspection,
39 ..
40 } = &info;
41
42 let ServiceConfig {
43 disabled_features, ..
44 } = ctx.data().map_err(|_| {
45 graphql_error(
46 code::INTERNAL_SERVER_ERROR,
47 "Unable to fetch service configuration",
48 )
49 })?;
50
51 if let Some(group) = functional_group(parent_type, name) {
57 if disabled_features.contains(&group) {
58 return if *is_for_introspection {
59 Ok(None)
60 } else {
61 Err(ServerError::new(
62 format!(
63 "Cannot query field \"{name}\" on type \"{parent_type}\". \
64 Feature {} is disabled.",
65 group.name(),
66 ),
67 None,
71 ))
72 };
73 }
74 }
75
76 next.run(ctx, info).await
77 }
78}
79
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use async_graphql::{EmptySubscription, Schema};
    use expect_test::expect;

    use super::*;
    use crate::{functional_group::FunctionalGroup, mutation::Mutation, types::query::Query};

    /// Query touching `protocolConfig`, which the disabled-field test shows is gated
    /// by `FunctionalGroup::SystemState`.
    const PROTOCOL_CONFIG_QUERY: &str =
        "{ protocolConfig(protocolVersion: 1) { protocolVersion } }";

    /// Builds a schema with `config` registered as context data and the
    /// `FeatureGate` extension installed.
    fn gated_schema(config: ServiceConfig) -> Schema<Query, Mutation, EmptySubscription> {
        Schema::build(Query, Mutation, EmptySubscription)
            .data(config)
            .extension(FeatureGate)
            .finish()
    }

    #[tokio::test]
    // NOTE(review): presumably panics because, once past the gate, resolving
    // `protocolConfig` needs context data this bare schema does not register — confirm.
    #[should_panic]
    async fn test_accessing_an_enabled_field() {
        gated_schema(ServiceConfig::default())
            .execute(PROTOCOL_CONFIG_QUERY)
            .await;
    }

    #[tokio::test]
    async fn test_accessing_a_disabled_field() {
        let config = ServiceConfig {
            disabled_features: BTreeSet::from_iter([FunctionalGroup::SystemState]),
            ..Default::default()
        };

        let server_errors = gated_schema(config)
            .execute(PROTOCOL_CONFIG_QUERY)
            .await
            .into_result()
            .unwrap_err();
        let messages: Vec<_> = server_errors.into_iter().map(|err| err.message).collect();

        expect![[r#"
            [
                "Cannot query field \"protocolConfig\" on type \"Query\". Feature \"system-state\" is disabled.",
            ]"#]]
        .assert_eq(&format!("{messages:#?}"));
    }
}