iota_graphql_rpc/extensions/feature_gate.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::sync::Arc;
6
7use async_graphql::{
8    ServerError, ServerResult, Value,
9    extensions::{Extension, ExtensionContext, ExtensionFactory, NextResolve, ResolveInfo},
10};
11use async_trait::async_trait;
12
13use crate::{
14    config::ServiceConfig,
15    error::{code, graphql_error},
16    functional_group::functional_group,
17};
18
19pub(crate) struct FeatureGate;
20
21impl ExtensionFactory for FeatureGate {
22    fn create(&self) -> Arc<dyn Extension> {
23        Arc::new(FeatureGate)
24    }
25}
26
27#[async_trait]
28impl Extension for FeatureGate {
29    async fn resolve(
30        &self,
31        ctx: &ExtensionContext<'_>,
32        info: ResolveInfo<'_>,
33        next: NextResolve<'_>,
34    ) -> ServerResult<Option<Value>> {
35        let ResolveInfo {
36            parent_type,
37            name,
38            is_for_introspection,
39            ..
40        } = &info;
41
42        let ServiceConfig {
43            disabled_features, ..
44        } = ctx.data().map_err(|_| {
45            graphql_error(
46                code::INTERNAL_SERVER_ERROR,
47                "Unable to fetch service configuration",
48            )
49        })?;
50
51        // TODO: Is there a way to set `is_visible` on `MetaField` and `MetaType` in a
52        // generic way after building the schema? (to a function which reads the
53        // `ServiceConfig` from the `Context`). This is (probably) required to
54        // hide disabled types and interfaces in the schema.
55
56        if let Some(group) = functional_group(parent_type, name) {
57            if disabled_features.contains(&group) {
58                return if *is_for_introspection {
59                    Ok(None)
60                } else {
61                    Err(ServerError::new(
62                        format!(
63                            "Cannot query field \"{name}\" on type \"{parent_type}\". \
64                             Feature {} is disabled.",
65                            group.name(),
66                        ),
67                        // TODO: Fork `async-graphl` to add field position information to
68                        // `ResolveInfo`, so the error can take advantage of it.  Similarly for
69                        // utilising the `path_node` to set the error path.
70                        None,
71                    ))
72                };
73            }
74        }
75
76        next.run(ctx, info).await
77    }
78}
79
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use async_graphql::{EmptySubscription, Schema};
    use expect_test::expect;

    use super::*;
    use crate::{functional_group::FunctionalGroup, mutation::Mutation, types::query::Query};

    /// Query hitting `protocolConfig`, which the disabled-field test below expects to be
    /// gated by `FunctionalGroup::SystemState`.
    const PROTOCOL_CONFIG_QUERY: &str =
        "{ protocolConfig(protocolVersion: 1) { protocolVersion } }";

    /// Build the full schema with `config` as context data and `FeatureGate` installed.
    fn gated_schema(config: ServiceConfig) -> Schema<Query, Mutation, EmptySubscription> {
        Schema::build(Query, Mutation, EmptySubscription)
            .data(config)
            .extension(FeatureGate)
            .finish()
    }

    #[tokio::test]
    #[should_panic] // because it tries to access the data provider, which isn't there
    async fn test_accessing_an_enabled_field() {
        gated_schema(ServiceConfig::default())
            .execute(PROTOCOL_CONFIG_QUERY)
            .await;
    }

    #[tokio::test]
    async fn test_accessing_a_disabled_field() {
        let config = ServiceConfig {
            disabled_features: BTreeSet::from([FunctionalGroup::SystemState]),
            ..Default::default()
        };

        let response = gated_schema(config).execute(PROTOCOL_CONFIG_QUERY).await;
        let errs: Vec<_> = response
            .into_result()
            .unwrap_err()
            .into_iter()
            .map(|e| e.message)
            .collect();

        let expected = expect![[r#"
            [
                "Cannot query field \"protocolConfig\" on type \"Query\". Feature \"system-state\" is disabled.",
            ]"#]];
        expected.assert_eq(&format!("{errs:#?}"));
    }
}