iota_graphql_rpc/server/builder.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use std::{
    any::Any,
    convert::Infallible,
    net::{SocketAddr, TcpStream},
    sync::Arc,
    time::{Duration, Instant},
};

use async_graphql::{
    EmptySubscription, Schema, SchemaBuilder,
    extensions::{ApolloTracing, ExtensionFactory, Tracing},
};
use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
use axum::{
    Extension, Router,
    body::Body,
    extract::{ConnectInfo, FromRef, Query as AxumQuery, State},
    http::{HeaderMap, StatusCode},
    middleware::{self},
    response::IntoResponse,
    routing::{MethodRouter, Route, get, post},
};
use chrono::Utc;
use http::{HeaderValue, Method, Request};
use iota_graphql_rpc_headers::LIMITS_HEADER;
use iota_indexer::db::{get_pool_connection, setup_postgres::check_db_migration_consistency};
use iota_metrics::spawn_monitored_task;
use iota_network_stack::callback::{CallbackLayer, MakeCallbackHandler, ResponseHandler};
use iota_package_resolver::{PackageStoreWithLruCache, Resolver};
use iota_sdk::IotaClientBuilder;
use tokio::{join, net::TcpListener, sync::OnceCell};
use tokio_util::sync::CancellationToken;
use tower::{Layer, Service};
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info, warn};
use uuid::Uuid;

use crate::{
    config::{
        ConnectionConfig, MAX_CONCURRENT_REQUESTS, RPC_TIMEOUT_ERR_SLEEP_RETRY_PERIOD,
        ServerConfig, ServiceConfig, Version,
    },
    context_data::db_data_provider::PgManager,
    data::{
        DataLoader, Db,
        package_resolver::{DbPackageStore, PackageResolver},
    },
    error::Error,
    extensions::{
        directive_checker::DirectiveChecker,
        feature_gate::FeatureGate,
        logger::Logger,
        query_limits_checker::{QueryLimitsChecker, ShowUsage},
        timeout::Timeout,
    },
    metrics::Metrics,
    mutation::Mutation,
    server::{
        exchange_rates_task::TriggerExchangeRatesTask,
        system_package_task::SystemPackageTask,
        version::{check_version_middleware, set_version_middleware},
        watermark_task::{Watermark, WatermarkLock, WatermarkTask},
    },
    types::{
        datatype::IMoveDatatype,
        move_object::IMoveObject,
        object::IObject,
        owner::IOwner,
        query::{IotaGraphQLSchema, Query},
    },
};

/// The default allowed maximum lag between the current timestamp and the
/// checkpoint timestamp.
const DEFAULT_MAX_CHECKPOINT_LAG: Duration = Duration::from_secs(300);

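// Illustrative startup sketch (not part of this module's API surface): a
// typical production caller wires the pieces below together roughly as
// follows, assuming a `ServerConfig` has already been loaded and a `VERSION`
// constant is defined by the binary:
//
//     let cancellation_token = CancellationToken::new();
//     let server = ServerBuilder::from_config(&config, &VERSION, cancellation_token.clone())
//         .await?
//         .build()?;
//     // Trigger `cancellation_token.cancel()` elsewhere (e.g. on SIGTERM) to
//     // shut the service down gracefully; `run` returns once all background
//     // tasks have finished.
//     server.run().await?;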
pub(crate) struct Server {
    router: Router,
    address: SocketAddr,
    watermark_task: WatermarkTask,
    system_package_task: SystemPackageTask,
    trigger_exchange_rates_task: TriggerExchangeRatesTask,
    state: AppState,
    db_reader: Db,
}

impl Server {
    /// Start the GraphQL service and the background tasks it depends on.
    /// When a cancellation signal is received, this method waits for all
    /// tasks to complete before returning.
    pub async fn run(mut self) -> Result<(), Error> {
        get_or_init_server_start_time().await;

        {
            // Compatibility check
            info!("Starting compatibility check");
            let mut connection = get_pool_connection(&self.db_reader.inner.get_pool())?;
            check_db_migration_consistency(&mut connection)?;
            info!("Compatibility check passed");
        }

        // A handle that spawns a background task to periodically update the
        // `Watermark`, which consists of the checkpoint upper bound and current
        // epoch.
        let watermark_task = {
            info!("Starting watermark update task");
            spawn_monitored_task!(async move {
                self.watermark_task.run().await;
            })
        };

        // A handle that spawns a background task to evict system packages on epoch
        // changes.
        let system_package_task = {
            info!("Starting system package task");
            spawn_monitored_task!(async move {
                self.system_package_task.run().await;
            })
        };

        let trigger_exchange_rates_task = {
            info!("Starting trigger exchange rates task");
            spawn_monitored_task!(async move {
                self.trigger_exchange_rates_task.run().await;
            })
        };

        let server_task = {
            info!("Starting graphql service");
            let cancellation_token = self.state.cancellation_token.clone();
            let address = self.address;
            let router = self.router;
            spawn_monitored_task!(async move {
                axum::serve(
                    TcpListener::bind(address)
                        .await
                        .map_err(|e| Error::Internal(format!("listener bind failed: {e}")))?,
                    router.into_make_service_with_connect_info::<SocketAddr>(),
                )
                .with_graceful_shutdown(async move {
                    cancellation_token.cancelled().await;
                    info!("Shutdown signal received, terminating graphql service");
                })
                .await
                .map_err(|e| Error::Internal(format!("Server run failed: {e}")))
            })
        };

        // Wait for all tasks to complete. This ensures that the service doesn't fully
        // shut down until all tasks and the server have completed their
        // shutdown processes.
        let _ = join!(
            watermark_task,
            system_package_task,
            trigger_exchange_rates_task,
            server_task
        );

        Ok(())
    }
}

pub(crate) struct ServerBuilder {
    state: AppState,
    schema: SchemaBuilder<Query, Mutation, EmptySubscription>,
    router: Option<Router>,
    db_reader: Option<Db>,
    resolver: Option<PackageResolver>,
}

#[derive(Clone)]
pub(crate) struct AppState {
    connection: ConnectionConfig,
    service: ServiceConfig,
    metrics: Metrics,
    cancellation_token: CancellationToken,
    pub version: Version,
}

impl AppState {
    pub(crate) fn new(
        connection: ConnectionConfig,
        service: ServiceConfig,
        metrics: Metrics,
        cancellation_token: CancellationToken,
        version: Version,
    ) -> Self {
        Self {
            connection,
            service,
            metrics,
            cancellation_token,
            version,
        }
    }
}

impl FromRef<AppState> for ConnectionConfig {
    fn from_ref(app_state: &AppState) -> ConnectionConfig {
        app_state.connection.clone()
    }
}

impl FromRef<AppState> for Metrics {
    fn from_ref(app_state: &AppState) -> Metrics {
        app_state.metrics.clone()
    }
}
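
// Illustrative note (an assumption about intended use, not an exhaustive
// list): the `FromRef` impls above let axum handlers that are registered with
// `State<AppState>` extract only the sub-state they need, as the health
// endpoints below do:
//
//     async fn db_health_check(State(connection): State<ConnectionConfig>) -> StatusCode { /* .. */ }
//
// Additional routes added via `ServerBuilder::route` can follow the same
// pattern for other fields by providing a corresponding `FromRef` impl.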

impl ServerBuilder {
    pub fn new(state: AppState) -> Self {
        Self {
            state,
            schema: schema_builder(),
            router: None,
            db_reader: None,
            resolver: None,
        }
    }

    pub fn address(&self) -> String {
        format!(
            "{}:{}",
            self.state.connection.host, self.state.connection.port
        )
    }

    pub fn context_data(mut self, context_data: impl Any + Send + Sync) -> Self {
        self.schema = self.schema.data(context_data);
        self
    }

    pub fn extension(mut self, extension: impl ExtensionFactory) -> Self {
        self.schema = self.schema.extension(extension);
        self
    }

    fn build_schema(self) -> Schema<Query, Mutation, EmptySubscription> {
        self.schema.finish()
    }

    /// Prepares the components of the server to be run. Finalizes the GraphQL
    /// schema, and expects the `Db`, package resolver, and `Router` to have
    /// been initialized.
    fn build_components(
        self,
    ) -> (
        String,
        Schema<Query, Mutation, EmptySubscription>,
        Db,
        PackageResolver,
        Router,
    ) {
        let address = self.address();
        let ServerBuilder {
            state: _,
            schema,
            db_reader,
            resolver,
            router,
        } = self;
        (
            address,
            schema.finish(),
            db_reader.expect("DB reader not initialized"),
            resolver.expect("Package resolver not initialized"),
            router.expect("Router not initialized"),
        )
    }

    fn init_router(&mut self) {
        if self.router.is_none() {
            let router: Router = Router::new()
                .route("/", post(graphql_handler))
                .route("/{version}", post(graphql_handler))
                .route("/graphql", post(graphql_handler))
                .route("/graphql/{version}", post(graphql_handler))
                .route("/health", get(health_check))
                .route("/graphql/health", get(health_check))
                .route("/graphql/{version}/health", get(health_check))
                .with_state(self.state.clone())
                .route_layer(CallbackLayer::new(MetricsMakeCallbackHandler {
                    metrics: self.state.metrics.clone(),
                }));
            self.router = Some(router);
        }
    }

    pub fn route(mut self, path: &str, method_handler: MethodRouter) -> Self {
        self.init_router();
        self.router = self.router.map(|router| router.route(path, method_handler));
        self
    }

    pub fn layer<L>(mut self, layer: L) -> Self
    where
        L: Layer<Route> + Clone + Send + Sync + 'static,
        L::Service: Service<Request<Body>> + Clone + Send + Sync + 'static,
        <L::Service as Service<Request<Body>>>::Response: IntoResponse + 'static,
        <L::Service as Service<Request<Body>>>::Error: Into<Infallible> + 'static,
        <L::Service as Service<Request<Body>>>::Future: Send + 'static,
    {
        self.init_router();
        self.router = self.router.map(|router| router.layer(layer));
        self
    }
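
    // Illustrative sketch (hypothetical route and layer, not part of this
    // module): callers can bolt extra endpoints and tower layers onto the
    // default router before building, e.g.
    //
    //     let builder = builder
    //         .route("/custom-health", get(|| async { "ok" }))
    //         .layer(tower_http::cors::CorsLayer::permissive());
    //
    // Both `route` and `layer` call `init_router` first, so the default
    // GraphQL and health routes are always present.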

    fn cors() -> Result<CorsLayer, Error> {
        let acl = match std::env::var("ACCESS_CONTROL_ALLOW_ORIGIN") {
            Ok(value) => {
                let allow_hosts = value
                    .split(',')
                    .map(HeaderValue::from_str)
                    .collect::<Result<Vec<_>, _>>()
                    .map_err(|_| {
                        Error::Internal(
                            "Cannot resolve access control origin env variable".to_string(),
                        )
                    })?;
                AllowOrigin::list(allow_hosts)
            }
            _ => AllowOrigin::any(),
        };
        info!("Access control allow origin set to: {acl:?}");

        let cors = CorsLayer::new()
            // Allow `POST` when accessing the resource
            .allow_methods([Method::POST])
            // Allow requests from the configured origins (or any origin if unset)
            .allow_origin(acl)
            .allow_headers([hyper::header::CONTENT_TYPE, LIMITS_HEADER.clone()]);
        Ok(cors)
    }
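
    // For example (illustrative origin values): setting
    //
    //     ACCESS_CONTROL_ALLOW_ORIGIN=https://app.example.com,https://admin.example.com
    //
    // restricts browsers to those two origins, while leaving the variable
    // unset falls back to `AllowOrigin::any()`.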

    /// Consumes the `ServerBuilder` to create a `Server` that can be run.
    pub fn build(self) -> Result<Server, Error> {
        let state = self.state.clone();
        let (address, schema, db_reader, resolver, router) = self.build_components();

        // Initialize the watermark background task struct.
        let watermark_task = WatermarkTask::new(
            db_reader.clone(),
            state.metrics.clone(),
            std::time::Duration::from_millis(state.service.background_tasks.watermark_update_ms),
            state.cancellation_token.clone(),
        );

        let system_package_task = SystemPackageTask::new(
            resolver,
            watermark_task.epoch_receiver(),
            state.cancellation_token.clone(),
        );

        let trigger_exchange_rates_task = TriggerExchangeRatesTask::new(
            db_reader.clone(),
            watermark_task.epoch_receiver(),
            state.cancellation_token.clone(),
        );

        let router = router
            .route_layer(middleware::from_fn_with_state(
                state.version,
                set_version_middleware,
            ))
            .route_layer(middleware::from_fn_with_state(
                state.version,
                check_version_middleware,
            ))
            .layer(axum::extract::Extension(schema))
            .layer(axum::extract::Extension(watermark_task.lock()))
            .layer(Self::cors()?);

        Ok(Server {
            router,
            address: address
                .parse()
                .map_err(|_| Error::Internal(format!("Failed to parse address {address}")))?,
            watermark_task,
            system_package_task,
            trigger_exchange_rates_task,
            state,
            db_reader,
        })
    }

    /// Instantiate a `ServerBuilder` from a `ServerConfig`, typically called
    /// when building the graphql service for production usage.
    pub async fn from_config(
        config: &ServerConfig,
        version: &Version,
        cancellation_token: CancellationToken,
    ) -> Result<Self, Error> {
        // PROMETHEUS
        let prom_addr: SocketAddr = format!(
            "{}:{}",
            config.connection.prom_url, config.connection.prom_port
        )
        .parse()
        .map_err(|_| {
            Error::Internal(format!(
                "Failed to parse url {}, port {} into socket address",
                config.connection.prom_url, config.connection.prom_port
            ))
        })?;

        let registry_service = iota_metrics::start_prometheus_server(prom_addr);
        info!("Starting Prometheus HTTP endpoint at {}", prom_addr);
        let registry = registry_service.default_registry();
        registry
            .register(iota_metrics::uptime_metric(
                "graphql",
                version.full,
                "unknown",
            ))
            .unwrap();

        // METRICS
        let metrics = Metrics::new(&registry);
        let state = AppState::new(
            config.connection.clone(),
            config.service.clone(),
            metrics.clone(),
            cancellation_token,
            *version,
        );
        let mut builder = ServerBuilder::new(state);

        let iota_names_config = config.service.iota_names.clone();
        let zklogin_config = config.service.zklogin.clone();
        let reader = PgManager::reader_with_config(
            config.connection.db_url.clone(),
            config.connection.db_pool_size,
            // Bound each statement in a request with the overall request timeout, to bound DB
            // utilisation (in the worst case we will use 2x the request timeout time in DB wall
            // time).
            config.service.limits.request_timeout_ms.into(),
        )
        .map_err(|e| Error::Internal(format!("Failed to create pg connection pool: {e}")))?;

        // DB
        let db = Db::new(
            reader.clone(),
            config.service.limits.clone(),
            metrics.clone(),
        );
        let loader = DataLoader::new(db.clone());
        let pg_conn_pool = PgManager::new(reader.clone());
        let package_store = DbPackageStore::new(loader.clone());
        let resolver = Arc::new(Resolver::new_with_limits(
            PackageStoreWithLruCache::new(package_store),
            config.service.limits.package_resolver_limits(),
        ));

        builder.db_reader = Some(db.clone());
        builder.resolver = Some(resolver.clone());

        // SDK for talking to fullnode. Used for executing transactions only
        // TODO: fail fast if no url, once we enable mutations fully
        let iota_sdk_client = if let Some(url) = &config.tx_exec_full_node.node_rpc_url {
            Some(
                IotaClientBuilder::default()
                    .request_timeout(RPC_TIMEOUT_ERR_SLEEP_RETRY_PERIOD)
                    .max_concurrent_requests(MAX_CONCURRENT_REQUESTS)
                    .build(url)
                    .await
                    .map_err(|e| {
                        Error::Internal(format!(
                            "Failed to connect to fullnode {e}. Is the node server running?"
                        ))
                    })?,
            )
        } else {
            warn!(
                "No fullnode url found in config. `dryRunTransactionBlock` and `executeTransactionBlock` will not work"
            );
            None
        };

        builder = builder
            .context_data(config.service.clone())
            .context_data(loader)
            .context_data(db)
            .context_data(pg_conn_pool)
            .context_data(resolver)
            .context_data(iota_sdk_client)
            .context_data(iota_names_config)
            .context_data(zklogin_config)
            .context_data(metrics.clone())
            .context_data(config.clone());

        if config.internal_features.feature_gate {
            builder = builder.extension(FeatureGate);
        }

        if config.internal_features.logger {
            builder = builder.extension(Logger::default());
        }

        if config.internal_features.query_limits_checker {
            builder = builder.extension(QueryLimitsChecker);
        }

        if config.internal_features.directive_checker {
            builder = builder.extension(DirectiveChecker);
        }

        if config.internal_features.query_timeout {
            builder = builder.extension(Timeout);
        }

        if config.internal_features.tracing {
            builder = builder.extension(Tracing);
        }

        if config.internal_features.apollo_tracing {
            builder = builder.extension(ApolloTracing);
        }

        // TODO: uncomment once impl
        // if config.internal_features.open_telemetry { }

        Ok(builder)
    }
}

fn schema_builder() -> SchemaBuilder<Query, Mutation, EmptySubscription> {
    async_graphql::Schema::build(Query, Mutation, EmptySubscription)
        .register_output_type::<IMoveObject>()
        .register_output_type::<IObject>()
        .register_output_type::<IOwner>()
        .register_output_type::<IMoveDatatype>()
}

/// Return the string representation of the schema used by this server.
pub fn export_schema() -> String {
    schema_builder().finish().sdl()
}
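
// Illustrative usage (hypothetical output file name): a CLI subcommand can
// dump the SDL for offline tooling, e.g.
//
//     std::fs::write("schema.graphql", export_schema())?;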

/// Entry point for graphql requests. Each request is stamped with a unique ID,
/// a `ShowUsage` flag if set in the request headers, and the watermark as set
/// by the background task.
async fn graphql_handler(
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    schema: Extension<IotaGraphQLSchema>,
    Extension(watermark_lock): Extension<WatermarkLock>,
    headers: HeaderMap,
    req: GraphQLRequest,
) -> (axum::http::Extensions, GraphQLResponse) {
    let mut req = req.into_inner();
    req.data.insert(Uuid::new_v4());
    if headers.contains_key(ShowUsage::name()) {
        req.data.insert(ShowUsage)
    }
    // Capture the IP address of the client
    // Note: if a load balancer is used it must be configured to forward the client
    // IP address
    req.data.insert(addr);

    req.data.insert(Watermark::new(watermark_lock).await);

    let result = schema.execute(req).await;

    // If there are errors, insert them as an extension so that the Metrics callback
    // handler can pull it out later.
    let mut extensions = axum::http::Extensions::new();
    if result.is_err() {
        extensions.insert(GraphqlErrors(std::sync::Arc::new(result.errors.clone())));
    };
    (extensions, result.into())
}
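
// Illustrative request (not tied to a specific deployment): with the default
// routes above, a client can POST a query to `/graphql` (or the versioned
// `/graphql/{version}` path), e.g.
//
//     curl -X POST http://localhost:8000/graphql \
//         -H 'Content-Type: application/json' \
//         -d '{"query": "{ chainIdentifier }"}'
//
// The host and port here are placeholders; the actual bind address comes from
// `ConnectionConfig`.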

#[derive(Clone)]
struct MetricsMakeCallbackHandler {
    metrics: Metrics,
}

impl MakeCallbackHandler for MetricsMakeCallbackHandler {
    type Handler = MetricsCallbackHandler;

    fn make_handler(&self, _request: &http::request::Parts) -> Self::Handler {
        let start = Instant::now();
        let metrics = self.metrics.clone();

        metrics.request_metrics.inflight_requests.inc();
        metrics.inc_num_queries();

        MetricsCallbackHandler { metrics, start }
    }
}

struct MetricsCallbackHandler {
    metrics: Metrics,
    start: Instant,
}

impl ResponseHandler for MetricsCallbackHandler {
    fn on_response(self, response: &http::response::Parts) {
        if let Some(errors) = response.extensions.get::<GraphqlErrors>() {
            self.metrics.inc_errors(&errors.0);
        }
    }

    fn on_error<E>(self, _error: &E) {
        // Do nothing if the whole service errored
        //
        // in Axum this isn't possible since all services are required to have
        // an error type of Infallible
    }
}

impl Drop for MetricsCallbackHandler {
    fn drop(&mut self) {
        self.metrics.query_latency(self.start.elapsed());
        self.metrics.request_metrics.inflight_requests.dec();
    }
}

#[derive(Debug, Clone)]
struct GraphqlErrors(std::sync::Arc<Vec<async_graphql::ServerError>>);

/// Connect via a `TcpStream` to the DB host to check that it is reachable.
async fn db_health_check(State(connection): State<ConnectionConfig>) -> StatusCode {
    let Ok(url) = reqwest::Url::parse(connection.db_url.as_str()) else {
        return StatusCode::INTERNAL_SERVER_ERROR;
    };

    let Some(host) = url.host_str() else {
        return StatusCode::INTERNAL_SERVER_ERROR;
    };

    let tcp_url = if let Some(port) = url.port() {
        format!("{host}:{port}")
    } else {
        host.to_string()
    };

    if TcpStream::connect(tcp_url).is_err() {
        StatusCode::INTERNAL_SERVER_ERROR
    } else {
        StatusCode::OK
    }
}

#[derive(serde::Deserialize)]
struct HealthParam {
    max_checkpoint_lag_ms: Option<u64>,
}

/// Endpoint for querying the health of the service.
/// It returns 500 for any internal error, including failure to connect to the
/// DB, and 504 if the latest checkpoint timestamp lags behind the current
/// timestamp by more than the `max_checkpoint_lag_ms` query parameter
/// (or `DEFAULT_MAX_CHECKPOINT_LAG` if the parameter is not provided).
async fn health_check(
    State(connection): State<ConnectionConfig>,
    Extension(watermark_lock): Extension<WatermarkLock>,
    AxumQuery(query_params): AxumQuery<HealthParam>,
) -> StatusCode {
    let db_health_check = db_health_check(axum::extract::State(connection)).await;
    if db_health_check != StatusCode::OK {
        return db_health_check;
    }

    let max_checkpoint_lag_ms = query_params
        .max_checkpoint_lag_ms
        .map(Duration::from_millis)
        .unwrap_or_else(|| DEFAULT_MAX_CHECKPOINT_LAG);

    let checkpoint_timestamp =
        Duration::from_millis(watermark_lock.read().await.checkpoint_timestamp_ms);

    let now_millis = Utc::now().timestamp_millis();

    // Check for negative timestamp or conversion failure
    let now: Duration = match u64::try_from(now_millis) {
        Ok(val) => Duration::from_millis(val),
        Err(_) => return StatusCode::INTERNAL_SERVER_ERROR,
    };
    // Use a saturating difference so a checkpoint timestamp slightly ahead of
    // the local clock cannot cause an underflow panic.
    if now.saturating_sub(checkpoint_timestamp) > max_checkpoint_lag_ms {
        return StatusCode::GATEWAY_TIMEOUT;
    }

    db_health_check
}
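
// Illustrative requests (host and port are placeholders; the real bind address
// comes from `ConnectionConfig`):
//
//     # Use the default staleness threshold (DEFAULT_MAX_CHECKPOINT_LAG).
//     curl -i http://localhost:8000/health
//
//     # Fail with 504 if the latest checkpoint is more than 10s behind.
//     curl -i 'http://localhost:8000/health?max_checkpoint_lag_ms=10000'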

// One server per proc, so this is okay
async fn get_or_init_server_start_time() -> &'static Instant {
    static ONCE: OnceCell<Instant> = OnceCell::const_new();
    ONCE.get_or_init(|| async move { Instant::now() }).await
}

pub mod tests {
    use std::{sync::Arc, time::Duration};

    use async_graphql::{
        Response,
        extensions::{Extension, ExtensionContext, NextExecute},
    };
    use iota_sdk::{IotaClient, wallet_context::WalletContext};
    use iota_types::transaction::TransactionData;
    use uuid::Uuid;

    use super::*;
    use crate::{
        config::{ConnectionConfig, Limits, ServiceConfig, Version},
        context_data::db_data_provider::PgManager,
        extensions::{query_limits_checker::QueryLimitsChecker, timeout::Timeout},
    };

    /// Prepares a schema for tests dealing with extensions. Returns a
    /// `ServerBuilder` that can be further extended with `context_data` and
    /// `extension` for testing.
    fn prep_schema(
        connection_config: Option<ConnectionConfig>,
        service_config: Option<ServiceConfig>,
    ) -> ServerBuilder {
        let connection_config = connection_config.unwrap_or_default();
        let service_config = service_config.unwrap_or_default();

        let reader = PgManager::reader_with_config(
            connection_config.db_url.clone(),
            connection_config.db_pool_size,
            service_config.limits.request_timeout_ms.into(),
        )
        .expect("Failed to create pg connection pool");

        let version = Version::for_testing();
        let metrics = metrics();
        let db = Db::new(
            reader.clone(),
            service_config.limits.clone(),
            metrics.clone(),
        );
        let pg_conn_pool = PgManager::new(reader);
        let cancellation_token = CancellationToken::new();
        let watermark = Watermark {
            checkpoint: 1,
            checkpoint_timestamp_ms: 1,
            epoch: 0,
        };
        let state = AppState::new(
            connection_config.clone(),
            service_config.clone(),
            metrics.clone(),
            cancellation_token.clone(),
            version,
        );
        ServerBuilder::new(state)
            .context_data(db)
            .context_data(pg_conn_pool)
            .context_data(service_config)
            .context_data(query_id())
            .context_data(ip_address())
            .context_data(watermark)
            .context_data(metrics)
    }

    fn metrics() -> Metrics {
        let binding_address: SocketAddr = "0.0.0.0:9185".parse().unwrap();
        let registry = iota_metrics::start_prometheus_server(binding_address).default_registry();
        Metrics::new(&registry)
    }

    fn ip_address() -> SocketAddr {
        let binding_address: SocketAddr = "0.0.0.0:51515".parse().unwrap();
        binding_address
    }

    fn query_id() -> Uuid {
        Uuid::new_v4()
    }

    pub async fn test_timeout_impl(wallet: &WalletContext) {
        struct TimedExecuteExt {
            pub min_req_delay: Duration,
        }

        impl ExtensionFactory for TimedExecuteExt {
            fn create(&self) -> Arc<dyn Extension> {
                Arc::new(TimedExecuteExt {
                    min_req_delay: self.min_req_delay,
                })
            }
        }

        #[async_trait::async_trait]
        impl Extension for TimedExecuteExt {
            async fn execute(
                &self,
                ctx: &ExtensionContext<'_>,
                operation_name: Option<&str>,
                next: NextExecute<'_>,
            ) -> Response {
                tokio::time::sleep(self.min_req_delay).await;
                next.run(ctx, operation_name).await
            }
        }

        async fn test_timeout(
            delay: Duration,
            timeout: Duration,
            query: &str,
            iota_client: &IotaClient,
        ) -> Response {
            let mut cfg = ServiceConfig::default();
            cfg.limits.request_timeout_ms = timeout.as_millis() as u32;
            cfg.limits.mutation_timeout_ms = timeout.as_millis() as u32;

            let schema = prep_schema(None, Some(cfg))
                .context_data(Some(iota_client.clone()))
                .extension(Timeout)
                .extension(TimedExecuteExt {
                    min_req_delay: delay,
                })
                .build_schema();

            schema.execute(query).await
        }

        let query = "{ chainIdentifier }";
        let timeout = Duration::from_millis(1000);
        let delay = Duration::from_millis(100);
        let iota_client = wallet.get_client().await.unwrap();

        test_timeout(delay, timeout, query, &iota_client)
            .await
            .into_result()
            .expect("Should complete successfully");

        // Should timeout
        let errs: Vec<_> = test_timeout(delay, delay, query, &iota_client)
            .await
            .into_result()
            .unwrap_err()
            .into_iter()
            .map(|e| e.message)
            .collect();
        let exp = format!("Query request timed out. Limit: {}s", delay.as_secs_f32());
        assert_eq!(errs, vec![exp]);

        // Should timeout for mutation
        // Create a transaction and sign it, and use the tx_bytes + signatures for the
        // GraphQL executeTransactionBlock mutation call.
        let addresses = wallet.get_addresses();
        let gas = wallet
            .get_one_gas_object_owned_by_address(addresses[0])
            .await
            .unwrap();
        let tx_data = TransactionData::new_transfer_iota(
            addresses[1],
            addresses[0],
            Some(1000),
            gas.unwrap(),
            1_000_000,
            wallet.get_reference_gas_price().await.unwrap(),
        );

        let tx = wallet.sign_transaction(&tx_data);
        let (tx_bytes, signatures) = tx.to_tx_bytes_and_signatures();

        let signature_base64 = &signatures[0];
        let query = format!(
            r#"
            mutation {{
              executeTransactionBlock(txBytes: "{}", signatures: "{}") {{
                effects {{
                  status
                }}
              }}
            }}"#,
            tx_bytes.encoded(),
            signature_base64.encoded()
        );
        let errs: Vec<_> = test_timeout(delay, delay, &query, &iota_client)
            .await
            .into_result()
            .unwrap_err()
            .into_iter()
            .map(|e| e.message)
            .collect();
        let exp = format!(
            "Mutation request timed out. Limit: {}s",
            delay.as_secs_f32()
        );
        assert_eq!(errs, vec![exp]);
    }

    pub async fn test_query_depth_limit_impl() {
        async fn exec_query_depth_limit(depth: u32, query: &str) -> Response {
            let service_config = ServiceConfig {
                limits: Limits {
                    max_query_depth: depth,
                    ..Default::default()
                },
                ..Default::default()
            };

            let schema = prep_schema(None, Some(service_config))
                .extension(QueryLimitsChecker)
                .build_schema();
            schema.execute(query).await
        }

        exec_query_depth_limit(1, "{ chainIdentifier }")
            .await
            .into_result()
            .expect("Should complete successfully");

        exec_query_depth_limit(
            5,
            "{ chainIdentifier protocolConfig { configs { value key }} }",
        )
        .await
        .into_result()
        .expect("Should complete successfully");

        // Should fail
        let errs: Vec<_> = exec_query_depth_limit(0, "{ chainIdentifier }")
            .await
            .into_result()
            .unwrap_err()
            .into_iter()
            .map(|e| e.message)
            .collect();

        assert_eq!(errs, vec!["Query nesting is over 0".to_string()]);
        let errs: Vec<_> = exec_query_depth_limit(
            2,
            "{ chainIdentifier protocolConfig { configs { value key }} }",
        )
        .await
        .into_result()
        .unwrap_err()
        .into_iter()
        .map(|e| e.message)
        .collect();
        assert_eq!(errs, vec!["Query nesting is over 2".to_string()]);
    }

    pub async fn test_query_node_limit_impl() {
        async fn exec_query_node_limit(nodes: u32, query: &str) -> Response {
            let service_config = ServiceConfig {
                limits: Limits {
                    max_query_nodes: nodes,
                    ..Default::default()
                },
                ..Default::default()
            };

            let schema = prep_schema(None, Some(service_config))
                .extension(QueryLimitsChecker)
                .build_schema();
            schema.execute(query).await
        }

        exec_query_node_limit(1, "{ chainIdentifier }")
            .await
            .into_result()
            .expect("Should complete successfully");

        exec_query_node_limit(
            5,
            "{ chainIdentifier protocolConfig { configs { value key }} }",
        )
        .await
        .into_result()
        .expect("Should complete successfully");

        // Should fail
        let err: Vec<_> = exec_query_node_limit(0, "{ chainIdentifier }")
            .await
            .into_result()
            .unwrap_err()
            .into_iter()
            .map(|e| e.message)
            .collect();
        assert_eq!(err, vec!["Query has over 0 nodes".to_string()]);

        let err: Vec<_> = exec_query_node_limit(
            4,
            "{ chainIdentifier protocolConfig { configs { value key }} }",
        )
        .await
        .into_result()
        .unwrap_err()
        .into_iter()
        .map(|e| e.message)
        .collect();
        assert_eq!(err, vec!["Query has over 4 nodes".to_string()]);
    }

    pub async fn test_query_default_page_limit_impl(connection_config: ConnectionConfig) {
        let service_config = ServiceConfig {
            limits: Limits {
                default_page_size: 1,
                ..Default::default()
            },
            ..Default::default()
        };
        let schema = prep_schema(Some(connection_config), Some(service_config)).build_schema();

        let resp = schema
            .execute("{ checkpoints { nodes { sequenceNumber } } }")
            .await;
        let data = resp.data.clone().into_json().unwrap();
        let checkpoints = data
            .get("checkpoints")
            .unwrap()
            .get("nodes")
            .unwrap()
            .as_array()
            .unwrap();
        assert_eq!(
            checkpoints.len(),
            1,
            "Checkpoints should have exactly one element"
        );

        let resp = schema
            .execute("{ checkpoints(first: 2) { nodes { sequenceNumber } } }")
            .await;
        let data = resp.data.clone().into_json().unwrap();
        let checkpoints = data
            .get("checkpoints")
            .unwrap()
            .get("nodes")
            .unwrap()
            .as_array()
            .unwrap();
        assert_eq!(
            checkpoints.len(),
            2,
            "Checkpoints should return two elements"
        );
    }

    pub async fn test_query_max_page_limit_impl() {
        let schema = prep_schema(None, None).build_schema();

        schema
            .execute("{ objects(first: 1) { nodes { version } } }")
            .await
            .into_result()
            .expect("Should complete successfully");

        // Should fail
        let err: Vec<_> = schema
            .execute("{ objects(first: 51) { nodes { version } } }")
            .await
            .into_result()
            .unwrap_err()
            .into_iter()
            .map(|e| e.message)
            .collect();
        assert_eq!(
            err,
            vec!["Connection's page size of 51 exceeds max of 50".to_string()]
        );
    }

    pub async fn test_query_complexity_metrics_impl() {
        let server_builder = prep_schema(None, None);
        let metrics = server_builder.state.metrics.clone();
        let schema = server_builder
            .extension(QueryLimitsChecker) // QueryLimitsChecker is where we actually set the metrics
            .build_schema();

        schema
            .execute("{ chainIdentifier }")
            .await
            .into_result()
            .expect("Should complete successfully");

        let req_metrics = metrics.request_metrics;
        assert_eq!(req_metrics.input_nodes.get_sample_count(), 1);
        assert_eq!(req_metrics.output_nodes.get_sample_count(), 1);
        assert_eq!(req_metrics.query_depth.get_sample_count(), 1);
        assert_eq!(req_metrics.input_nodes.get_sample_sum(), 1.);
        assert_eq!(req_metrics.output_nodes.get_sample_sum(), 1.);
        assert_eq!(req_metrics.query_depth.get_sample_sum(), 1.);

        schema
            .execute("{ chainIdentifier protocolConfig { configs { value key }} }")
            .await
            .into_result()
            .expect("Should complete successfully");

        assert_eq!(req_metrics.input_nodes.get_sample_count(), 2);
        assert_eq!(req_metrics.output_nodes.get_sample_count(), 2);
        assert_eq!(req_metrics.query_depth.get_sample_count(), 2);
        assert_eq!(req_metrics.input_nodes.get_sample_sum(), 2. + 4.);
        assert_eq!(req_metrics.output_nodes.get_sample_sum(), 2. + 4.);
        assert_eq!(req_metrics.query_depth.get_sample_sum(), 1. + 3.);
    }

    pub async fn test_health_check_impl() {
        let server_builder = prep_schema(None, None);
        let url = format!(
            "http://{}:{}/health",
            server_builder.state.connection.host, server_builder.state.connection.port
        );
        server_builder.build_schema();

        let resp = reqwest::get(&url).await.unwrap();
        assert_eq!(resp.status(), reqwest::StatusCode::OK);

        let url_with_param = format!("{url}?max_checkpoint_lag_ms=1");
        let resp = reqwest::get(&url_with_param).await.unwrap();
        assert_eq!(resp.status(), reqwest::StatusCode::GATEWAY_TIMEOUT);
    }
}