iota_graphql_rpc/server/builder.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use std::{
    any::Any,
    convert::Infallible,
    net::{SocketAddr, TcpStream},
    sync::Arc,
    time::{Duration, Instant},
};

use async_graphql::{
    EmptySubscription, Schema, SchemaBuilder,
    extensions::{ApolloTracing, ExtensionFactory, Tracing},
};
use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
use axum::{
    Extension, Router,
    body::Body,
    extract::{ConnectInfo, FromRef, Query as AxumQuery, State},
    http::{HeaderMap, StatusCode},
    middleware::{self},
    response::IntoResponse,
    routing::{MethodRouter, Route, get, post},
};
use chrono::Utc;
use http::{HeaderValue, Method, Request};
use iota_graphql_rpc_headers::LIMITS_HEADER;
use iota_metrics::spawn_monitored_task;
use iota_network_stack::callback::{CallbackLayer, MakeCallbackHandler, ResponseHandler};
use iota_package_resolver::{PackageStoreWithLruCache, Resolver};
use iota_sdk::IotaClientBuilder;
use tokio::{join, net::TcpListener, sync::OnceCell};
use tokio_util::sync::CancellationToken;
use tower::{Layer, Service};
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info, warn};
use uuid::Uuid;

use crate::{
    config::{
        ConnectionConfig, MAX_CONCURRENT_REQUESTS, RPC_TIMEOUT_ERR_SLEEP_RETRY_PERIOD,
        ServerConfig, ServiceConfig, Version,
    },
    context_data::db_data_provider::PgManager,
    data::{
        DataLoader, Db,
        package_resolver::{DbPackageStore, PackageResolver},
    },
    error::Error,
    extensions::{
        directive_checker::DirectiveChecker,
        feature_gate::FeatureGate,
        logger::Logger,
        query_limits_checker::{QueryLimitsChecker, ShowUsage},
        timeout::Timeout,
    },
    metrics::Metrics,
    mutation::Mutation,
    server::{
        compatibility_check::check_all_tables,
        exchange_rates_task::TriggerExchangeRatesTask,
        system_package_task::SystemPackageTask,
        version::{check_version_middleware, set_version_middleware},
        watermark_task::{Watermark, WatermarkLock, WatermarkTask},
    },
    types::{
        datatype::IMoveDatatype,
        move_object::IMoveObject,
        object::IObject,
        owner::IOwner,
        query::{IotaGraphQLSchema, Query},
    },
};

/// The default allowed maximum lag between the current timestamp and the
/// checkpoint timestamp.
const DEFAULT_MAX_CHECKPOINT_LAG: Duration = Duration::from_secs(300);

pub(crate) struct Server {
    router: Router,
    address: SocketAddr,
    watermark_task: WatermarkTask,
    system_package_task: SystemPackageTask,
    trigger_exchange_rates_task: TriggerExchangeRatesTask,
    state: AppState,
    db_reader: Db,
}

impl Server {
    /// Start the GraphQL service and any background tasks it depends on.
    /// When a cancellation signal is received, the method waits for all
    /// tasks to complete before returning.
    pub async fn run(mut self) -> Result<(), Error> {
        get_or_init_server_start_time().await;

        // Compatibility check
        info!("Starting compatibility check");
        check_all_tables(&self.db_reader).await?;
        info!("Compatibility check passed");

        // Spawn a background task that periodically updates the `Watermark`,
        // which consists of the checkpoint upper bound and the current epoch.
        let watermark_task = {
            info!("Starting watermark update task");
            spawn_monitored_task!(async move {
                self.watermark_task.run().await;
            })
        };

        // Spawn a background task that evicts system packages on epoch changes.
        let system_package_task = {
            info!("Starting system package task");
            spawn_monitored_task!(async move {
                self.system_package_task.run().await;
            })
        };

        let trigger_exchange_rates_task = {
            info!("Starting trigger exchange rates task");
            spawn_monitored_task!(async move {
                self.trigger_exchange_rates_task.run().await;
            })
        };

        let server_task = {
            info!("Starting graphql service");
            let cancellation_token = self.state.cancellation_token.clone();
            let address = self.address;
            let router = self.router;
            spawn_monitored_task!(async move {
                axum::serve(
                    TcpListener::bind(address)
                        .await
                        .map_err(|e| Error::Internal(format!("listener bind failed: {}", e)))?,
                    router.into_make_service_with_connect_info::<SocketAddr>(),
                )
                .with_graceful_shutdown(async move {
                    cancellation_token.cancelled().await;
                    info!("Shutdown signal received, terminating graphql service");
                })
                .await
                .map_err(|e| Error::Internal(format!("Server run failed: {}", e)))
            })
        };

        // Wait for all tasks to complete. This ensures that the service doesn't fully
        // shut down until all tasks and the server have completed their
        // shutdown processes.
        let _ = join!(
            watermark_task,
            system_package_task,
            trigger_exchange_rates_task,
            server_task
        );

        Ok(())
    }
}

pub(crate) struct ServerBuilder {
    state: AppState,
    schema: SchemaBuilder<Query, Mutation, EmptySubscription>,
    router: Option<Router>,
    db_reader: Option<Db>,
    resolver: Option<PackageResolver>,
}

#[derive(Clone)]
pub(crate) struct AppState {
    connection: ConnectionConfig,
    service: ServiceConfig,
    metrics: Metrics,
    cancellation_token: CancellationToken,
    pub version: Version,
}

impl AppState {
    pub(crate) fn new(
        connection: ConnectionConfig,
        service: ServiceConfig,
        metrics: Metrics,
        cancellation_token: CancellationToken,
        version: Version,
    ) -> Self {
        Self {
            connection,
            service,
            metrics,
            cancellation_token,
            version,
        }
    }
}

impl FromRef<AppState> for ConnectionConfig {
    fn from_ref(app_state: &AppState) -> ConnectionConfig {
        app_state.connection.clone()
    }
}

impl FromRef<AppState> for Metrics {
    fn from_ref(app_state: &AppState) -> Metrics {
        app_state.metrics.clone()
    }
}

impl ServerBuilder {
    pub fn new(state: AppState) -> Self {
        Self {
            state,
            schema: schema_builder(),
            router: None,
            db_reader: None,
            resolver: None,
        }
    }

    pub fn address(&self) -> String {
        format!(
            "{}:{}",
            self.state.connection.host, self.state.connection.port
        )
    }

    pub fn context_data(mut self, context_data: impl Any + Send + Sync) -> Self {
        self.schema = self.schema.data(context_data);
        self
    }

    pub fn extension(mut self, extension: impl ExtensionFactory) -> Self {
        self.schema = self.schema.extension(extension);
        self
    }

    fn build_schema(self) -> Schema<Query, Mutation, EmptySubscription> {
        self.schema.finish()
    }

    /// Prepares the components of the server to be run. Finalizes the GraphQL
    /// schema, and expects the `Db`, package resolver, and `Router` to have
    /// been initialized.
    fn build_components(
        self,
    ) -> (
        String,
        Schema<Query, Mutation, EmptySubscription>,
        Db,
        PackageResolver,
        Router,
    ) {
        let address = self.address();
        let ServerBuilder {
            state: _,
            schema,
            db_reader,
            resolver,
            router,
        } = self;
        (
            address,
            schema.finish(),
            db_reader.expect("DB reader not initialized"),
            resolver.expect("Package resolver not initialized"),
            router.expect("Router not initialized"),
        )
    }

    fn init_router(&mut self) {
        if self.router.is_none() {
            let router: Router = Router::new()
                .route("/", post(graphql_handler))
                .route("/:version", post(graphql_handler))
                .route("/graphql", post(graphql_handler))
                .route("/graphql/:version", post(graphql_handler))
                .route("/health", get(health_check))
                .route("/graphql/health", get(health_check))
                .route("/graphql/:version/health", get(health_check))
                .with_state(self.state.clone())
                .route_layer(CallbackLayer::new(MetricsMakeCallbackHandler {
                    metrics: self.state.metrics.clone(),
                }));
            self.router = Some(router);
        }
    }
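
    // As a concrete illustration of the routes registered above (the version
    // segment shown is an assumption for illustration; the real value comes
    // from this service's `Version`):
    //
    //     POST /graphql                -> graphql_handler
    //     POST /graphql/v1.2.3         -> graphql_handler
    //     GET  /graphql/v1.2.3/health  -> health_check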

    pub fn route(mut self, path: &str, method_handler: MethodRouter) -> Self {
        self.init_router();
        self.router = self.router.map(|router| router.route(path, method_handler));
        self
    }

    pub fn layer<L>(mut self, layer: L) -> Self
    where
        L: Layer<Route> + Clone + Send + 'static,
        L::Service: Service<Request<Body>> + Clone + Send + 'static,
        <L::Service as Service<Request<Body>>>::Response: IntoResponse + 'static,
        <L::Service as Service<Request<Body>>>::Error: Into<Infallible> + 'static,
        <L::Service as Service<Request<Body>>>::Future: Send + 'static,
    {
        self.init_router();
        self.router = self.router.map(|router| router.layer(layer));
        self
    }

    fn cors() -> Result<CorsLayer, Error> {
        let acl = match std::env::var("ACCESS_CONTROL_ALLOW_ORIGIN") {
            Ok(value) => {
                let allow_hosts = value
                    .split(',')
                    .map(HeaderValue::from_str)
                    .collect::<Result<Vec<_>, _>>()
                    .map_err(|_| {
                        Error::Internal(
                            "Cannot resolve access control origin env variable".to_string(),
                        )
                    })?;
                AllowOrigin::list(allow_hosts)
            }
            _ => AllowOrigin::any(),
        };
        info!("Access control allow origin set to: {acl:?}");

        let cors = CorsLayer::new()
            // Allow `POST` when accessing the resource
            .allow_methods([Method::POST])
            // Allow requests from the configured origins (any origin if the
            // env variable is unset)
            .allow_origin(acl)
            .allow_headers([hyper::header::CONTENT_TYPE, LIMITS_HEADER.clone()]);
        Ok(cors)
    }
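
    // For example, setting the environment variable below (the origins are
    // illustrative assumptions, not defaults) makes `Self::cors()` allow only
    // those two origins, while leaving it unset falls back to `AllowOrigin::any()`:
    //
    //     ACCESS_CONTROL_ALLOW_ORIGIN="https://studio.example.org,https://explorer.example.org"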

    /// Consumes the `ServerBuilder` to create a `Server` that can be run.
    pub fn build(self) -> Result<Server, Error> {
        let state = self.state.clone();
        let (address, schema, db_reader, resolver, router) = self.build_components();

        // Initialize the watermark background task struct.
        let watermark_task = WatermarkTask::new(
            db_reader.clone(),
            state.metrics.clone(),
            std::time::Duration::from_millis(state.service.background_tasks.watermark_update_ms),
            state.cancellation_token.clone(),
        );

        let system_package_task = SystemPackageTask::new(
            resolver,
            watermark_task.epoch_receiver(),
            state.cancellation_token.clone(),
        );

        let trigger_exchange_rates_task = TriggerExchangeRatesTask::new(
            db_reader.clone(),
            watermark_task.epoch_receiver(),
            state.cancellation_token.clone(),
        );

        let router = router
            .route_layer(middleware::from_fn_with_state(
                state.version,
                set_version_middleware,
            ))
            .route_layer(middleware::from_fn_with_state(
                state.version,
                check_version_middleware,
            ))
            .layer(axum::extract::Extension(schema))
            .layer(axum::extract::Extension(watermark_task.lock()))
            .layer(Self::cors()?);

        Ok(Server {
            router,
            address: address
                .parse()
                .map_err(|_| Error::Internal(format!("Failed to parse address {}", address)))?,
            watermark_task,
            system_package_task,
            trigger_exchange_rates_task,
            state,
            db_reader,
        })
    }

    /// Instantiate a `ServerBuilder` from a `ServerConfig`, typically called
    /// when building the GraphQL service for production usage.
    pub async fn from_config(
        config: &ServerConfig,
        version: &Version,
        cancellation_token: CancellationToken,
    ) -> Result<Self, Error> {
        // PROMETHEUS
        let prom_addr: SocketAddr = format!(
            "{}:{}",
            config.connection.prom_url, config.connection.prom_port
        )
        .parse()
        .map_err(|_| {
            Error::Internal(format!(
                "Failed to parse url {}, port {} into socket address",
                config.connection.prom_url, config.connection.prom_port
            ))
        })?;

        let registry_service = iota_metrics::start_prometheus_server(prom_addr);
        info!("Starting Prometheus HTTP endpoint at {}", prom_addr);
        let registry = registry_service.default_registry();
        registry
            .register(iota_metrics::uptime_metric(
                "graphql",
                version.full,
                "unknown",
            ))
            .unwrap();

        // METRICS
        let metrics = Metrics::new(&registry);
        let state = AppState::new(
            config.connection.clone(),
            config.service.clone(),
            metrics.clone(),
            cancellation_token,
            *version,
        );
        let mut builder = ServerBuilder::new(state);

        let iota_names_config = config.service.iota_names.clone();
        let zklogin_config = config.service.zklogin.clone();
        let reader = PgManager::reader_with_config(
            config.connection.db_url.clone(),
            config.connection.db_pool_size,
            // Bound each statement in a request with the overall request timeout, to bound DB
            // utilisation (in the worst case we will use 2x the request timeout in DB wall-clock
            // time).
            config.service.limits.request_timeout_ms.into(),
        )
        .map_err(|e| Error::Internal(format!("Failed to create pg connection pool: {}", e)))?;

        // DB
        let db = Db::new(
            reader.clone(),
            config.service.limits.clone(),
            metrics.clone(),
        );
        let loader = DataLoader::new(db.clone());
        let pg_conn_pool = PgManager::new(reader.clone());
        let package_store = DbPackageStore::new(loader.clone());
        let resolver = Arc::new(Resolver::new_with_limits(
            PackageStoreWithLruCache::new(package_store),
            config.service.limits.package_resolver_limits(),
        ));

        builder.db_reader = Some(db.clone());
        builder.resolver = Some(resolver.clone());

        // SDK for talking to fullnode. Used for executing transactions only
        // TODO: fail fast if no url, once we enable mutations fully
        let iota_sdk_client = if let Some(url) = &config.tx_exec_full_node.node_rpc_url {
            Some(
                IotaClientBuilder::default()
                    .request_timeout(RPC_TIMEOUT_ERR_SLEEP_RETRY_PERIOD)
                    .max_concurrent_requests(MAX_CONCURRENT_REQUESTS)
                    .build(url)
                    .await
                    .map_err(|e| {
                        Error::Internal(format!(
                            "Failed to connect to fullnode {e}. Is the node server running?"
                        ))
                    })?,
            )
        } else {
            warn!(
                "No fullnode url found in config. `dryRunTransactionBlock` and `executeTransactionBlock` will not work"
            );
            None
        };

        builder = builder
            .context_data(config.service.clone())
            .context_data(loader)
            .context_data(db)
            .context_data(pg_conn_pool)
            .context_data(resolver)
            .context_data(iota_sdk_client)
            .context_data(iota_names_config)
            .context_data(zklogin_config)
            .context_data(metrics.clone())
            .context_data(config.clone());

        if config.internal_features.feature_gate {
            builder = builder.extension(FeatureGate);
        }

        if config.internal_features.logger {
            builder = builder.extension(Logger::default());
        }

        if config.internal_features.query_limits_checker {
            builder = builder.extension(QueryLimitsChecker);
        }

        if config.internal_features.directive_checker {
            builder = builder.extension(DirectiveChecker);
        }

        if config.internal_features.query_timeout {
            builder = builder.extension(Timeout);
        }

        if config.internal_features.tracing {
            builder = builder.extension(Tracing);
        }

        if config.internal_features.apollo_tracing {
            builder = builder.extension(ApolloTracing);
        }

        // TODO: uncomment once impl
        // if config.internal_features.open_telemetry { }

        Ok(builder)
    }
}
525
526fn schema_builder() -> SchemaBuilder<Query, Mutation, EmptySubscription> {
527    async_graphql::Schema::build(Query, Mutation, EmptySubscription)
528        .register_output_type::<IMoveObject>()
529        .register_output_type::<IObject>()
530        .register_output_type::<IOwner>()
531        .register_output_type::<IMoveDatatype>()
532}
533
534/// Return the string representation of the schema used by this server.
535pub fn export_schema() -> String {
536    schema_builder().finish().sdl()
537}
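
// For example, a CLI subcommand could persist the SDL produced by
// `export_schema` to disk (the file name is an illustrative assumption):
//
//     std::fs::write("schema.graphql", export_schema())?;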

/// Entry point for GraphQL requests. Each request is stamped with a unique ID,
/// a `ShowUsage` flag if set in the request headers, and the watermark as set
/// by the background task.
async fn graphql_handler(
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    schema: Extension<IotaGraphQLSchema>,
    Extension(watermark_lock): Extension<WatermarkLock>,
    headers: HeaderMap,
    req: GraphQLRequest,
) -> (axum::http::Extensions, GraphQLResponse) {
    let mut req = req.into_inner();
    req.data.insert(Uuid::new_v4());
    if headers.contains_key(ShowUsage::name()) {
        req.data.insert(ShowUsage)
    }
    // Capture the IP address of the client.
    // Note: if a load balancer is used, it must be configured to forward the
    // client IP address.
    req.data.insert(addr);

    req.data.insert(Watermark::new(watermark_lock).await);

    let result = schema.execute(req).await;

    // If there are errors, insert them as an extension so that the Metrics
    // callback handler can pull them out later.
    let mut extensions = axum::http::Extensions::new();
    if result.is_err() {
        extensions.insert(GraphqlErrors(std::sync::Arc::new(result.errors.clone())));
    };
    (extensions, result.into())
}

#[derive(Clone)]
struct MetricsMakeCallbackHandler {
    metrics: Metrics,
}

impl MakeCallbackHandler for MetricsMakeCallbackHandler {
    type Handler = MetricsCallbackHandler;

    fn make_handler(&self, _request: &http::request::Parts) -> Self::Handler {
        let start = Instant::now();
        let metrics = self.metrics.clone();

        metrics.request_metrics.inflight_requests.inc();
        metrics.inc_num_queries();

        MetricsCallbackHandler { metrics, start }
    }
}

struct MetricsCallbackHandler {
    metrics: Metrics,
    start: Instant,
}

impl ResponseHandler for MetricsCallbackHandler {
    fn on_response(self, response: &http::response::Parts) {
        if let Some(errors) = response.extensions.get::<GraphqlErrors>() {
            self.metrics.inc_errors(&errors.0);
        }
    }

    fn on_error<E>(self, _error: &E) {
        // Do nothing if the whole service errored.
        //
        // In axum this isn't possible since all services are required to have
        // an error type of `Infallible`.
    }
}

impl Drop for MetricsCallbackHandler {
    fn drop(&mut self) {
        self.metrics.query_latency(self.start.elapsed());
        self.metrics.request_metrics.inflight_requests.dec();
    }
}

#[derive(Debug, Clone)]
struct GraphqlErrors(std::sync::Arc<Vec<async_graphql::ServerError>>);

/// Connect via a `TcpStream` to the DB to check whether it is alive.
async fn db_health_check(State(connection): State<ConnectionConfig>) -> StatusCode {
    let Ok(url) = reqwest::Url::parse(connection.db_url.as_str()) else {
        return StatusCode::INTERNAL_SERVER_ERROR;
    };

    let Some(host) = url.host_str() else {
        return StatusCode::INTERNAL_SERVER_ERROR;
    };

    let tcp_url = if let Some(port) = url.port() {
        format!("{host}:{port}")
    } else {
        host.to_string()
    };

    if TcpStream::connect(tcp_url).is_err() {
        StatusCode::INTERNAL_SERVER_ERROR
    } else {
        StatusCode::OK
    }
}

#[derive(serde::Deserialize)]
struct HealthParam {
    max_checkpoint_lag_ms: Option<u64>,
}

/// Endpoint for querying the health of the service.
/// It returns 500 for any internal error, including failure to connect to the
/// DB, and 504 if the latest checkpoint timestamp lags the current time by
/// more than the `max_checkpoint_lag_ms` query parameter (or the default value
/// if the parameter is not provided).
async fn health_check(
    State(connection): State<ConnectionConfig>,
    Extension(watermark_lock): Extension<WatermarkLock>,
    AxumQuery(query_params): AxumQuery<HealthParam>,
) -> StatusCode {
    let db_health_check = db_health_check(axum::extract::State(connection)).await;
    if db_health_check != StatusCode::OK {
        return db_health_check;
    }

    let max_checkpoint_lag_ms = query_params
        .max_checkpoint_lag_ms
        .map(Duration::from_millis)
        .unwrap_or(DEFAULT_MAX_CHECKPOINT_LAG);

    let checkpoint_timestamp =
        Duration::from_millis(watermark_lock.read().await.checkpoint_timestamp_ms);

    let now_millis = Utc::now().timestamp_millis();

    // Check for negative timestamp or conversion failure
    let now: Duration = match u64::try_from(now_millis) {
        Ok(val) => Duration::from_millis(val),
        Err(_) => return StatusCode::INTERNAL_SERVER_ERROR,
    };

    // Use a saturating subtraction so that clock skew (a checkpoint timestamp
    // slightly ahead of the local clock) does not underflow and panic.
    if now.saturating_sub(checkpoint_timestamp) > max_checkpoint_lag_ms {
        return StatusCode::GATEWAY_TIMEOUT;
    }

    db_health_check
}
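
// Example request against a locally running service (host, port, and lag value
// are illustrative assumptions):
//
//     GET http://localhost:8000/graphql/health?max_checkpoint_lag_ms=30000
//
// This responds 200 while the latest checkpoint is at most 30s behind
// wall-clock time, and 504 once it falls further behind.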

// One server per process, so a process-wide static is okay.
async fn get_or_init_server_start_time() -> &'static Instant {
    static ONCE: OnceCell<Instant> = OnceCell::const_new();
    ONCE.get_or_init(|| async move { Instant::now() }).await
}

pub mod tests {
    use std::{sync::Arc, time::Duration};

    use async_graphql::{
        Response,
        extensions::{Extension, ExtensionContext, NextExecute},
    };
    use iota_sdk::{IotaClient, wallet_context::WalletContext};
    use iota_types::transaction::TransactionData;
    use uuid::Uuid;

    use super::*;
    use crate::{
        config::{ConnectionConfig, Limits, ServiceConfig, Version},
        context_data::db_data_provider::PgManager,
        extensions::{query_limits_checker::QueryLimitsChecker, timeout::Timeout},
    };

    /// Prepares a schema for tests dealing with extensions. Returns a
    /// `ServerBuilder` that can be further extended with `context_data` and
    /// `extension` for testing.
    fn prep_schema(
        connection_config: Option<ConnectionConfig>,
        service_config: Option<ServiceConfig>,
    ) -> ServerBuilder {
        let connection_config = connection_config.unwrap_or_default();
        let service_config = service_config.unwrap_or_default();

        let reader = PgManager::reader_with_config(
            connection_config.db_url.clone(),
            connection_config.db_pool_size,
            service_config.limits.request_timeout_ms.into(),
        )
        .expect("Failed to create pg connection pool");

        let version = Version::for_testing();
        let metrics = metrics();
        let db = Db::new(
            reader.clone(),
            service_config.limits.clone(),
            metrics.clone(),
        );
        let pg_conn_pool = PgManager::new(reader);
        let cancellation_token = CancellationToken::new();
        let watermark = Watermark {
            checkpoint: 1,
            checkpoint_timestamp_ms: 1,
            epoch: 0,
        };
        let state = AppState::new(
            connection_config.clone(),
            service_config.clone(),
            metrics.clone(),
            cancellation_token.clone(),
            version,
        );
        ServerBuilder::new(state)
            .context_data(db)
            .context_data(pg_conn_pool)
            .context_data(service_config)
            .context_data(query_id())
            .context_data(ip_address())
            .context_data(watermark)
            .context_data(metrics)
    }

    fn metrics() -> Metrics {
        let binding_address: SocketAddr = "0.0.0.0:9185".parse().unwrap();
        let registry = iota_metrics::start_prometheus_server(binding_address).default_registry();
        Metrics::new(&registry)
    }

    fn ip_address() -> SocketAddr {
        let binding_address: SocketAddr = "0.0.0.0:51515".parse().unwrap();
        binding_address
    }

    fn query_id() -> Uuid {
        Uuid::new_v4()
    }

    pub async fn test_timeout_impl(wallet: &WalletContext) {
        struct TimedExecuteExt {
            pub min_req_delay: Duration,
        }

        impl ExtensionFactory for TimedExecuteExt {
            fn create(&self) -> Arc<dyn Extension> {
                Arc::new(TimedExecuteExt {
                    min_req_delay: self.min_req_delay,
                })
            }
        }

        #[async_trait::async_trait]
        impl Extension for TimedExecuteExt {
            async fn execute(
                &self,
                ctx: &ExtensionContext<'_>,
                operation_name: Option<&str>,
                next: NextExecute<'_>,
            ) -> Response {
                tokio::time::sleep(self.min_req_delay).await;
                next.run(ctx, operation_name).await
            }
        }

        async fn test_timeout(
            delay: Duration,
            timeout: Duration,
            query: &str,
            iota_client: &IotaClient,
        ) -> Response {
            let mut cfg = ServiceConfig::default();
            cfg.limits.request_timeout_ms = timeout.as_millis() as u32;
            cfg.limits.mutation_timeout_ms = timeout.as_millis() as u32;

            let schema = prep_schema(None, Some(cfg))
                .context_data(Some(iota_client.clone()))
                .extension(Timeout)
                .extension(TimedExecuteExt {
                    min_req_delay: delay,
                })
                .build_schema();

            schema.execute(query).await
        }

        let query = "{ chainIdentifier }";
        let timeout = Duration::from_millis(1000);
        let delay = Duration::from_millis(100);
        let iota_client = wallet.get_client().await.unwrap();

        test_timeout(delay, timeout, query, &iota_client)
            .await
            .into_result()
            .expect("Should complete successfully");

        // Should timeout
        let errs: Vec<_> = test_timeout(delay, delay, query, &iota_client)
            .await
            .into_result()
            .unwrap_err()
            .into_iter()
            .map(|e| e.message)
            .collect();
        let exp = format!("Query request timed out. Limit: {}s", delay.as_secs_f32());
        assert_eq!(errs, vec![exp]);

        // Should timeout for mutation
        // Create a transaction and sign it, and use the tx_bytes + signatures for the
        // GraphQL executeTransactionBlock mutation call.
        let addresses = wallet.get_addresses();
        let gas = wallet
            .get_one_gas_object_owned_by_address(addresses[0])
            .await
            .unwrap();
        let tx_data = TransactionData::new_transfer_iota(
            addresses[1],
            addresses[0],
            Some(1000),
            gas.unwrap(),
            1_000_000,
            wallet.get_reference_gas_price().await.unwrap(),
        );

        let tx = wallet.sign_transaction(&tx_data);
        let (tx_bytes, signatures) = tx.to_tx_bytes_and_signatures();

        let signature_base64 = &signatures[0];
        let query = format!(
            r#"
            mutation {{
              executeTransactionBlock(txBytes: "{}", signatures: "{}") {{
                effects {{
                  status
                }}
              }}
            }}"#,
            tx_bytes.encoded(),
            signature_base64.encoded()
        );
        let errs: Vec<_> = test_timeout(delay, delay, &query, &iota_client)
            .await
            .into_result()
            .unwrap_err()
            .into_iter()
            .map(|e| e.message)
            .collect();
        let exp = format!(
            "Mutation request timed out. Limit: {}s",
            delay.as_secs_f32()
        );
        assert_eq!(errs, vec![exp]);
    }

    pub async fn test_query_depth_limit_impl() {
        async fn exec_query_depth_limit(depth: u32, query: &str) -> Response {
            let service_config = ServiceConfig {
                limits: Limits {
                    max_query_depth: depth,
                    ..Default::default()
                },
                ..Default::default()
            };

            let schema = prep_schema(None, Some(service_config))
                .extension(QueryLimitsChecker)
                .build_schema();
            schema.execute(query).await
        }

        exec_query_depth_limit(1, "{ chainIdentifier }")
            .await
            .into_result()
            .expect("Should complete successfully");

        exec_query_depth_limit(
            5,
            "{ chainIdentifier protocolConfig { configs { value key }} }",
        )
        .await
        .into_result()
        .expect("Should complete successfully");

        // Should fail
        let errs: Vec<_> = exec_query_depth_limit(0, "{ chainIdentifier }")
            .await
            .into_result()
            .unwrap_err()
            .into_iter()
            .map(|e| e.message)
            .collect();

        assert_eq!(errs, vec!["Query nesting is over 0".to_string()]);
        let errs: Vec<_> = exec_query_depth_limit(
            2,
            "{ chainIdentifier protocolConfig { configs { value key }} }",
        )
        .await
        .into_result()
        .unwrap_err()
        .into_iter()
        .map(|e| e.message)
        .collect();
        assert_eq!(errs, vec!["Query nesting is over 2".to_string()]);
    }

    pub async fn test_query_node_limit_impl() {
        async fn exec_query_node_limit(nodes: u32, query: &str) -> Response {
            let service_config = ServiceConfig {
                limits: Limits {
                    max_query_nodes: nodes,
                    ..Default::default()
                },
                ..Default::default()
            };

            let schema = prep_schema(None, Some(service_config))
                .extension(QueryLimitsChecker)
                .build_schema();
            schema.execute(query).await
        }

        exec_query_node_limit(1, "{ chainIdentifier }")
            .await
            .into_result()
            .expect("Should complete successfully");

        exec_query_node_limit(
            5,
            "{ chainIdentifier protocolConfig { configs { value key }} }",
        )
        .await
        .into_result()
        .expect("Should complete successfully");

        // Should fail
        let err: Vec<_> = exec_query_node_limit(0, "{ chainIdentifier }")
            .await
            .into_result()
            .unwrap_err()
            .into_iter()
            .map(|e| e.message)
            .collect();
        assert_eq!(err, vec!["Query has over 0 nodes".to_string()]);

        let err: Vec<_> = exec_query_node_limit(
            4,
            "{ chainIdentifier protocolConfig { configs { value key }} }",
        )
        .await
        .into_result()
        .unwrap_err()
        .into_iter()
        .map(|e| e.message)
        .collect();
        assert_eq!(err, vec!["Query has over 4 nodes".to_string()]);
    }

    pub async fn test_query_default_page_limit_impl(connection_config: ConnectionConfig) {
        let service_config = ServiceConfig {
            limits: Limits {
                default_page_size: 1,
                ..Default::default()
            },
            ..Default::default()
        };
        let schema = prep_schema(Some(connection_config), Some(service_config)).build_schema();

        let resp = schema
            .execute("{ checkpoints { nodes { sequenceNumber } } }")
            .await;
        let data = resp.data.clone().into_json().unwrap();
        let checkpoints = data
            .get("checkpoints")
            .unwrap()
            .get("nodes")
            .unwrap()
            .as_array()
            .unwrap();
        assert_eq!(
            checkpoints.len(),
            1,
            "Checkpoints should have exactly one element"
        );

        let resp = schema
            .execute("{ checkpoints(first: 2) { nodes { sequenceNumber } } }")
            .await;
        let data = resp.data.clone().into_json().unwrap();
        let checkpoints = data
            .get("checkpoints")
            .unwrap()
            .get("nodes")
            .unwrap()
            .as_array()
            .unwrap();
        assert_eq!(
            checkpoints.len(),
            2,
            "Checkpoints should return two elements"
        );
    }

    pub async fn test_query_max_page_limit_impl() {
        let schema = prep_schema(None, None).build_schema();

        schema
            .execute("{ objects(first: 1) { nodes { version } } }")
            .await
            .into_result()
            .expect("Should complete successfully");

        // Should fail
        let err: Vec<_> = schema
            .execute("{ objects(first: 51) { nodes { version } } }")
            .await
            .into_result()
            .unwrap_err()
            .into_iter()
            .map(|e| e.message)
            .collect();
        assert_eq!(
            err,
            vec!["Connection's page size of 51 exceeds max of 50".to_string()]
        );
    }

    pub async fn test_query_complexity_metrics_impl() {
        let server_builder = prep_schema(None, None);
        let metrics = server_builder.state.metrics.clone();
        let schema = server_builder
            .extension(QueryLimitsChecker) // QueryLimitsChecker is where we actually set the metrics
            .build_schema();

        schema
            .execute("{ chainIdentifier }")
            .await
            .into_result()
            .expect("Should complete successfully");

        let req_metrics = metrics.request_metrics;
        assert_eq!(req_metrics.input_nodes.get_sample_count(), 1);
        assert_eq!(req_metrics.output_nodes.get_sample_count(), 1);
        assert_eq!(req_metrics.query_depth.get_sample_count(), 1);
        assert_eq!(req_metrics.input_nodes.get_sample_sum(), 1.);
        assert_eq!(req_metrics.output_nodes.get_sample_sum(), 1.);
        assert_eq!(req_metrics.query_depth.get_sample_sum(), 1.);

        schema
            .execute("{ chainIdentifier protocolConfig { configs { value key }} }")
            .await
            .into_result()
            .expect("Should complete successfully");

        assert_eq!(req_metrics.input_nodes.get_sample_count(), 2);
        assert_eq!(req_metrics.output_nodes.get_sample_count(), 2);
        assert_eq!(req_metrics.query_depth.get_sample_count(), 2);
        assert_eq!(req_metrics.input_nodes.get_sample_sum(), 2. + 4.);
        assert_eq!(req_metrics.output_nodes.get_sample_sum(), 2. + 4.);
        assert_eq!(req_metrics.query_depth.get_sample_sum(), 1. + 3.);
    }

    pub async fn test_health_check_impl() {
        let server_builder = prep_schema(None, None);
        let url = format!(
            "http://{}:{}/health",
            server_builder.state.connection.host, server_builder.state.connection.port
        );
        server_builder.build_schema();

        let resp = reqwest::get(&url).await.unwrap();
        assert_eq!(resp.status(), reqwest::StatusCode::OK);

        let url_with_param = format!("{}?max_checkpoint_lag_ms=1", url);
        let resp = reqwest::get(&url_with_param).await.unwrap();
        assert_eq!(resp.status(), reqwest::StatusCode::GATEWAY_TIMEOUT);
    }
}