// iota_graphql_rpc/server/builder.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4use std::{
5    any::Any,
6    convert::Infallible,
7    net::{SocketAddr, TcpStream},
8    sync::Arc,
9    time::{Duration, Instant},
10};
11
12use async_graphql::{
13    Context, EmptySubscription, ResultExt, Schema, SchemaBuilder,
14    extensions::{ApolloTracing, ExtensionFactory, Tracing},
15};
16use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
17use axum::{
18    Extension, Router,
19    body::Body,
20    extract::{ConnectInfo, FromRef, Query as AxumQuery, State},
21    http::{HeaderMap, StatusCode},
22    middleware::{self},
23    response::IntoResponse,
24    routing::{MethodRouter, Route, get, post},
25};
26use chrono::Utc;
27use http::{HeaderValue, Method, Request};
28use iota_graphql_rpc_headers::LIMITS_HEADER;
29use iota_indexer::{
30    db::{get_pool_connection, setup_postgres::check_db_migration_consistency},
31    metrics::IndexerMetrics,
32    store::PgIndexerStore,
33};
34use iota_metrics::spawn_monitored_task;
35use iota_network_stack::callback::{CallbackLayer, MakeCallbackHandler, ResponseHandler};
36use iota_package_resolver::{PackageStoreWithLruCache, Resolver};
37use iota_sdk::{IotaClient, IotaClientBuilder};
38use tokio::{join, net::TcpListener, sync::OnceCell};
39use tokio_util::sync::CancellationToken;
40use tower::{Layer, Service};
41use tower_http::cors::{AllowOrigin, CorsLayer};
42use tracing::{info, warn};
43use uuid::Uuid;
44
45use crate::{
46    config::{
47        ConnectionConfig, MAX_CONCURRENT_REQUESTS, RPC_TIMEOUT_ERR_SLEEP_RETRY_PERIOD,
48        ServerConfig, ServiceConfig, Version,
49    },
50    context_data::db_data_provider::PgManager,
51    data::{
52        DataLoader, Db,
53        package_resolver::{DbPackageStore, PackageResolver},
54    },
55    error::Error,
56    extensions::{
57        directive_checker::DirectiveChecker,
58        feature_gate::FeatureGate,
59        logger::Logger,
60        query_limits_checker::{QueryLimitsChecker, ShowUsage},
61        timeout::Timeout,
62    },
63    metrics::Metrics,
64    mutation::Mutation,
65    server::{
66        exchange_rates_task::TriggerExchangeRatesTask,
67        system_package_task::SystemPackageTask,
68        version::{check_version_middleware, set_version_middleware},
69        watermark_task::{Watermark, WatermarkLock, WatermarkTask},
70    },
71    types::{
72        chain_identifier::ChainIdentifierCache,
73        datatype::IMoveDatatype,
74        move_object::IMoveObject,
75        object::IObject,
76        owner::IOwner,
77        query::{IotaGraphQLSchema, Query},
78    },
79};
80
/// The default allowed maximum lag between the current timestamp and the
/// checkpoint timestamp. Used by the health check when the caller does not
/// supply a `max_checkpoint_lag_ms` query parameter.
const DEFAULT_MAX_CHECKPOINT_LAG: Duration = Duration::from_secs(300);
84
/// A fully-wired GraphQL service together with the background tasks it
/// depends on. Built via `ServerBuilder::build` and started with `run`.
pub(crate) struct Server {
    // Axum router with all GraphQL and health-check routes installed.
    router: Router,
    // Socket address the HTTP server binds to.
    address: SocketAddr,
    // Periodically refreshes the `Watermark` (checkpoint upper bound + epoch).
    watermark_task: WatermarkTask,
    // Evicts cached system packages on epoch changes.
    system_package_task: SystemPackageTask,
    // Triggers exchange-rate updates, driven by the watermark's epoch channel.
    trigger_exchange_rates_task: TriggerExchangeRatesTask,
    // Shared config/metrics/cancellation state, cloned into request handlers.
    state: AppState,
    // Read-only DB handle; used at startup for the migration-consistency check.
    db_reader: Db,
}
94
impl Server {
    /// Start the GraphQL service and any background tasks it is dependent on.
    /// When a cancellation signal is received, the method waits for all
    /// tasks to complete before returning.
    pub async fn run(mut self) -> Result<(), Error> {
        // Record the server start time (process-wide; first caller wins).
        get_or_init_server_start_time().await;

        {
            // Compatibility check: refuse to serve if the DB schema has
            // diverged from the migrations this binary expects.
            info!("Starting compatibility check");
            let mut connection = get_pool_connection(&self.db_reader.inner.get_pool())?;
            check_db_migration_consistency(&mut connection)?;
            info!("Compatibility check passed");
        }

        // A handle that spawns a background task to periodically update the
        // `Watermark`, which consists of the checkpoint upper bound and current
        // epoch.
        // NOTE: each `async move` block below moves only the field it touches
        // out of `self` — presumably relying on edition-2021 disjoint capture.
        let watermark_task = {
            info!("Starting watermark update task");
            spawn_monitored_task!(async move {
                self.watermark_task.run().await;
            })
        };

        // A handle that spawns a background task to evict system packages on epoch
        // changes.
        let system_package_task = {
            info!("Starting system package task");
            spawn_monitored_task!(async move {
                self.system_package_task.run().await;
            })
        };

        // A handle that spawns a background task to trigger exchange-rate
        // updates on epoch changes.
        let trigger_exchange_rates_task = {
            info!("Starting trigger exchange rates task");
            spawn_monitored_task!(async move {
                self.trigger_exchange_rates_task.run().await;
            })
        };

        // The HTTP server itself; shuts down gracefully when the shared
        // cancellation token fires.
        let server_task = {
            info!("Starting graphql service");
            let cancellation_token = self.state.cancellation_token.clone();
            let address = self.address;
            let router = self.router;
            spawn_monitored_task!(async move {
                axum::serve(
                    TcpListener::bind(address)
                        .await
                        .map_err(|e| Error::Internal(format!("listener bind failed: {e}")))?,
                    router.into_make_service_with_connect_info::<SocketAddr>(),
                )
                .with_graceful_shutdown(async move {
                    cancellation_token.cancelled().await;
                    info!("Shutdown signal received, terminating graphql service");
                })
                .await
                .map_err(|e| Error::Internal(format!("Server run failed: {e}")))
            })
        };

        // Wait for all tasks to complete. This ensures that the service doesn't fully
        // shut down until all tasks and the server have completed their
        // shutdown processes.
        let _ = join!(
            watermark_task,
            system_package_task,
            trigger_exchange_rates_task,
            server_task
        );

        Ok(())
    }
}
170
/// Builder for a [`Server`]. Accumulates schema context data, extensions,
/// extra routes and layers before being finalized with `build`.
pub(crate) struct ServerBuilder {
    // Shared application state, cloned into the router and resulting server.
    state: AppState,
    // GraphQL schema under construction; finalized in `build_components`.
    schema: SchemaBuilder<Query, Mutation, EmptySubscription>,
    // Lazily created by `init_router`; `None` until first `route`/`layer` call.
    router: Option<Router>,
    // Database reader; must be populated before `build` (see `from_config`).
    db_reader: Option<Db>,
    // Package resolver; must be populated before `build` (see `from_config`).
    resolver: Option<PackageResolver>,
}
178
/// State shared between routes and middleware; extracted per request via
/// axum's `State` / `FromRef` machinery.
#[derive(Clone)]
pub(crate) struct AppState {
    // Connection settings (host/port, DB url, Prometheus address, ...).
    connection: ConnectionConfig,
    // Service-level configuration (limits, background task intervals, ...).
    service: ServiceConfig,
    // Prometheus-backed metrics handles.
    metrics: Metrics,
    // Token used to signal graceful shutdown to the server and its tasks.
    cancellation_token: CancellationToken,
    // Service version, consumed by the version middleware in `build`.
    pub version: Version,
}
187
188impl AppState {
189    pub(crate) fn new(
190        connection: ConnectionConfig,
191        service: ServiceConfig,
192        metrics: Metrics,
193        cancellation_token: CancellationToken,
194        version: Version,
195    ) -> Self {
196        Self {
197            connection,
198            service,
199            metrics,
200            cancellation_token,
201            version,
202        }
203    }
204}
205
206impl FromRef<AppState> for ConnectionConfig {
207    fn from_ref(app_state: &AppState) -> ConnectionConfig {
208        app_state.connection.clone()
209    }
210}
211
212impl FromRef<AppState> for Metrics {
213    fn from_ref(app_state: &AppState) -> Metrics {
214        app_state.metrics.clone()
215    }
216}
217
218impl ServerBuilder {
219    pub fn new(state: AppState) -> Self {
220        Self {
221            state,
222            schema: schema_builder(),
223            router: None,
224            db_reader: None,
225            resolver: None,
226        }
227    }
228
229    pub fn address(&self) -> String {
230        format!(
231            "{}:{}",
232            self.state.connection.host, self.state.connection.port
233        )
234    }
235
236    pub fn context_data(mut self, context_data: impl Any + Send + Sync) -> Self {
237        self.schema = self.schema.data(context_data);
238        self
239    }
240
241    pub fn extension(mut self, extension: impl ExtensionFactory) -> Self {
242        self.schema = self.schema.extension(extension);
243        self
244    }
245
246    fn build_schema(self) -> Schema<Query, Mutation, EmptySubscription> {
247        self.schema.finish()
248    }
249
250    /// Prepares the components of the server to be run. Finalizes the graphql
251    /// schema, and expects the `Db` and `Router` to have been initialized.
252    fn build_components(
253        self,
254    ) -> (
255        String,
256        Schema<Query, Mutation, EmptySubscription>,
257        Db,
258        PackageResolver,
259        Router,
260    ) {
261        let address = self.address();
262        let ServerBuilder {
263            state: _,
264            schema,
265            db_reader,
266            resolver,
267            router,
268        } = self;
269        (
270            address,
271            schema.finish(),
272            db_reader.expect("db reader not initialized"),
273            resolver.expect("package resolver not initialized"),
274            router.expect("router not initialized"),
275        )
276    }
277
278    fn init_router(&mut self) {
279        if self.router.is_none() {
280            let router: Router = Router::new()
281                .route("/", post(graphql_handler))
282                .route("/{version}", post(graphql_handler))
283                .route("/graphql", post(graphql_handler))
284                .route("/graphql/{version}", post(graphql_handler))
285                .route("/health", get(health_check))
286                .route("/graphql/health", get(health_check))
287                .route("/graphql/{version}/health", get(health_check))
288                .with_state(self.state.clone())
289                .route_layer(CallbackLayer::new(MetricsMakeCallbackHandler {
290                    metrics: self.state.metrics.clone(),
291                }));
292            self.router = Some(router);
293        }
294    }
295
296    pub fn route(mut self, path: &str, method_handler: MethodRouter) -> Self {
297        self.init_router();
298        self.router = self.router.map(|router| router.route(path, method_handler));
299        self
300    }
301
302    pub fn layer<L>(mut self, layer: L) -> Self
303    where
304        L: Layer<Route> + Clone + Send + Sync + 'static,
305        L::Service: Service<Request<Body>> + Clone + Send + Sync + 'static,
306        <L::Service as Service<Request<Body>>>::Response: IntoResponse + 'static,
307        <L::Service as Service<Request<Body>>>::Error: Into<Infallible> + 'static,
308        <L::Service as Service<Request<Body>>>::Future: Send + 'static,
309    {
310        self.init_router();
311        self.router = self.router.map(|router| router.layer(layer));
312        self
313    }
314
315    fn cors() -> Result<CorsLayer, Error> {
316        let acl = match std::env::var("ACCESS_CONTROL_ALLOW_ORIGIN") {
317            Ok(value) => {
318                let allow_hosts = value
319                    .split(',')
320                    .map(HeaderValue::from_str)
321                    .collect::<Result<Vec<_>, _>>()
322                    .map_err(|_| {
323                        Error::Internal(
324                            "Cannot resolve access control origin env variable".to_string(),
325                        )
326                    })?;
327                AllowOrigin::list(allow_hosts)
328            }
329            _ => AllowOrigin::any(),
330        };
331        info!("Access control allow origin set to: {acl:?}");
332
333        let cors = CorsLayer::new()
334            // Allow `POST` when accessing the resource
335            .allow_methods([Method::POST])
336            // Allow requests from any origin
337            .allow_origin(acl)
338            .allow_headers([hyper::header::CONTENT_TYPE, LIMITS_HEADER.clone()]);
339        Ok(cors)
340    }
341
    /// Consumes the `ServerBuilder` to create a `Server` that can be run.
    ///
    /// Wires the background tasks (watermark, system packages, exchange
    /// rates) and installs the version middleware, schema and watermark-lock
    /// extensions, and the CORS layer on the router.
    pub fn build(self) -> Result<Server, Error> {
        let state = self.state.clone();
        let (address, schema, db_reader, resolver, router) = self.build_components();

        // Initialize the watermark background task struct.
        let watermark_task = WatermarkTask::new(
            db_reader.clone(),
            state.metrics.clone(),
            std::time::Duration::from_millis(state.service.background_tasks.watermark_update_ms),
            state.cancellation_token.clone(),
        );

        // Both epoch-driven tasks subscribe to the watermark task's epoch
        // channel so they react to epoch changes it detects.
        let system_package_task = SystemPackageTask::new(
            resolver,
            watermark_task.epoch_receiver(),
            state.cancellation_token.clone(),
        );

        let trigger_exchange_rates_task = TriggerExchangeRatesTask::new(
            db_reader.clone(),
            watermark_task.epoch_receiver(),
            state.cancellation_token.clone(),
        );

        // Version middlewares run per-route; the schema and watermark lock
        // are exposed to handlers as request extensions.
        let router = router
            .route_layer(middleware::from_fn_with_state(
                state.version,
                set_version_middleware,
            ))
            .route_layer(middleware::from_fn_with_state(
                state.version,
                check_version_middleware,
            ))
            .layer(axum::extract::Extension(schema))
            .layer(axum::extract::Extension(watermark_task.lock()))
            .layer(Self::cors()?);

        Ok(Server {
            router,
            address: address
                .parse()
                .map_err(|_| Error::Internal(format!("Failed to parse address {address}")))?,
            watermark_task,
            system_package_task,
            trigger_exchange_rates_task,
            state,
            db_reader,
        })
    }
392
    /// Instantiate a `ServerBuilder` from a `ServerConfig`, typically called
    /// when building the graphql service for production usage.
    ///
    /// Sets up the Prometheus endpoint, metrics, the DB reader/loader stack,
    /// the package resolver, the (optional) fullnode SDK client, and installs
    /// the configured GraphQL extensions.
    pub async fn from_config(
        config: &ServerConfig,
        version: &Version,
        cancellation_token: CancellationToken,
    ) -> Result<Self, Error> {
        // PROMETHEUS
        let prom_addr: SocketAddr = format!(
            "{}:{}",
            config.connection.prom_url, config.connection.prom_port
        )
        .parse()
        .map_err(|_| {
            Error::Internal(format!(
                "Failed to parse url {}, port {} into socket address",
                config.connection.prom_url, config.connection.prom_port
            ))
        })?;

        let registry_service = iota_metrics::start_prometheus_server(prom_addr);
        info!("Starting Prometheus HTTP endpoint at {}", prom_addr);
        let registry = registry_service.default_registry();
        registry
            .register(iota_metrics::uptime_metric(
                "graphql",
                version.full,
                "unknown",
            ))
            .unwrap();

        // METRICS
        let metrics = Metrics::new(&registry);
        let indexer_metrics = IndexerMetrics::new(&registry);
        let state = AppState::new(
            config.connection.clone(),
            config.service.clone(),
            metrics.clone(),
            cancellation_token,
            *version,
        );
        let mut builder = ServerBuilder::new(state);

        let iota_names_config = config.service.iota_names.clone();
        let zklogin_config = config.service.zklogin.clone();
        let reader = PgManager::reader_with_config(
            config.connection.db_url.clone(),
            config.connection.db_pool_size,
            // Bound each statement in a request with the overall request timeout, to bound DB
            // utilisation (in the worst case we will use 2x the request timeout time in DB wall
            // time).
            config.service.limits.request_timeout_ms.into(),
        )
        .map_err(|e| Error::Internal(format!("Failed to create pg connection pool: {e}")))?;

        // DB
        let db = Db::new(
            reader.clone(),
            config.service.limits.clone(),
            metrics.clone(),
        );
        let loader = DataLoader::new(db.clone());
        let pg_conn_pool = PgManager::new(reader.clone());
        let package_store = DbPackageStore::new(loader.clone());
        // LRU cache in front of the DB-backed package store.
        let resolver = Arc::new(Resolver::new_with_limits(
            PackageStoreWithLruCache::new(package_store),
            config.service.limits.package_resolver_limits(),
        ));

        builder.db_reader = Some(db.clone());
        builder.resolver = Some(resolver.clone());

        // SDK for talking to fullnode. Used for executing transactions only
        // TODO: fail fast if no url, once we enable mutations fully
        let (iota_sdk_client, optimistic_tx_executor) = if let Some(url) =
            &config.tx_exec_full_node.node_rpc_url
        {
            let iota_sdk_client = IotaClientBuilder::default()
                .request_timeout(RPC_TIMEOUT_ERR_SLEEP_RETRY_PERIOD)
                .max_concurrent_requests(MAX_CONCURRENT_REQUESTS)
                .build(url)
                .await
                .map_err(|e| {
                    Error::Internal(format!(
                        "Failed to connect to fullnode {e}. Is the node server running?"
                    ))
                })?;
            let indexer_store = PgIndexerStore::new(reader.get_pool(), indexer_metrics.clone());
            let optimistic_tx_executor =
                iota_indexer::optimistic_indexing::OptimisticTransactionExecutor::new(
                    url,
                    reader.clone(),
                    indexer_store,
                    indexer_metrics.clone(),
                );
            (Some(iota_sdk_client), Some(optimistic_tx_executor))
        } else {
            warn!(
                "No fullnode url found in config. `dryRunTransactionBlock` and `executeTransactionBlock` will not work"
            );
            (None, None)
        };

        // Everything resolvers may need is injected into the schema context.
        // Note the client/executor are inserted as `Option`s so resolvers can
        // report a clean error when no fullnode is configured.
        builder = builder
            .context_data(config.service.clone())
            .context_data(loader)
            .context_data(db)
            .context_data(pg_conn_pool)
            .context_data(resolver)
            .context_data(iota_sdk_client)
            .context_data(optimistic_tx_executor)
            .context_data(iota_names_config)
            .context_data(zklogin_config)
            .context_data(metrics.clone())
            .context_data(config.clone())
            .context_data(ChainIdentifierCache::default());

        // Optional GraphQL extensions, toggled individually by config.
        if config.internal_features.feature_gate {
            builder = builder.extension(FeatureGate);
        }

        if config.internal_features.logger {
            builder = builder.extension(Logger::default());
        }

        if config.internal_features.query_limits_checker {
            builder = builder.extension(QueryLimitsChecker);
        }

        if config.internal_features.directive_checker {
            builder = builder.extension(DirectiveChecker);
        }

        if config.internal_features.query_timeout {
            builder = builder.extension(Timeout);
        }

        if config.internal_features.tracing {
            builder = builder.extension(Tracing);
        }

        if config.internal_features.apollo_tracing {
            builder = builder.extension(ApolloTracing);
        }

        // TODO: uncomment once impl
        // if config.internal_features.open_telemetry { }

        Ok(builder)
    }
543}
544
545fn schema_builder() -> SchemaBuilder<Query, Mutation, EmptySubscription> {
546    async_graphql::Schema::build(Query, Mutation, EmptySubscription)
547        .register_output_type::<IMoveObject>()
548        .register_output_type::<IObject>()
549        .register_output_type::<IOwner>()
550        .register_output_type::<IMoveDatatype>()
551}
552
553pub(crate) fn get_fullnode_client<'ctx>(
554    ctx: &'ctx Context<'_>,
555) -> async_graphql::Result<&'ctx IotaClient> {
556    let iota_sdk_client: &Option<IotaClient> = ctx
557        .data()
558        .map_err(|_| Error::Internal("Unable to fetch IOTA SDK client".to_string()))
559        .extend()?;
560    iota_sdk_client
561        .as_ref()
562        .ok_or_else(|| Error::Internal("IOTA SDK client not initialized".to_string()))
563        .extend()
564}
565
566/// Return the string representation of the schema used by this server.
567pub fn export_schema() -> String {
568    schema_builder().finish().sdl()
569}
570
/// Entry point for graphql requests. Each request is stamped with a unique ID,
/// a `ShowUsage` flag if set in the request headers, and the watermark as set
/// by the background task.
///
/// Returns the GraphQL response together with an `Extensions` bag carrying
/// any GraphQL errors, which the metrics callback reads in `on_response`.
async fn graphql_handler(
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    schema: Extension<IotaGraphQLSchema>,
    Extension(watermark_lock): Extension<WatermarkLock>,
    headers: HeaderMap,
    req: GraphQLRequest,
) -> (axum::http::Extensions, GraphQLResponse) {
    let mut req = req.into_inner();
    // Unique ID for correlating logs/metrics of this request.
    req.data.insert(Uuid::new_v4());
    if headers.contains_key(ShowUsage::name()) {
        req.data.insert(ShowUsage)
    }
    // Capture the IP address of the client
    // Note: if a load balancer is used it must be configured to forward the client
    // IP address
    req.data.insert(addr);

    // Snapshot of the current watermark (checkpoint bound / epoch) taken
    // under the shared lock.
    req.data.insert(Watermark::new(watermark_lock).await);

    let result = schema.execute(req).await;

    // If there are errors, insert them as an extension so that the Metrics callback
    // handler can pull it out later.
    let mut extensions = axum::http::Extensions::new();
    if result.is_err() {
        extensions.insert(GraphqlErrors(std::sync::Arc::new(result.errors.clone())));
    };
    (extensions, result.into())
}
603
/// Factory that produces a per-request metrics handler for each incoming
/// HTTP request; installed on the router via `CallbackLayer` in
/// `init_router`.
#[derive(Clone)]
struct MetricsMakeCallbackHandler {
    metrics: Metrics,
}
608
609impl MakeCallbackHandler for MetricsMakeCallbackHandler {
610    type Handler = MetricsCallbackHandler;
611
612    fn make_handler(&self, _request: &http::request::Parts) -> Self::Handler {
613        let start = Instant::now();
614        let metrics = self.metrics.clone();
615
616        metrics.request_metrics.inflight_requests.inc();
617        metrics.inc_num_queries();
618
619        MetricsCallbackHandler { metrics, start }
620    }
621}
622
/// Per-request handle: records error metrics when the response is produced,
/// and latency / in-flight gauge when dropped (i.e. when the request ends).
struct MetricsCallbackHandler {
    metrics: Metrics,
    // Request start time, captured in `make_handler`.
    start: Instant,
}
627
impl ResponseHandler for MetricsCallbackHandler {
    fn on_response(&mut self, response: &http::response::Parts) {
        // `graphql_handler` stashes GraphQL errors in the response
        // extensions; count them here.
        if let Some(errors) = response.extensions.get::<GraphqlErrors>() {
            self.metrics.inc_errors(&errors.0);
        }
    }

    fn on_error<E>(&mut self, _error: &E) {
        // Do nothing if the whole service errored
        //
        // in Axum this isn't possible since all services are required to have
        // an error type of Infallible
    }
}
642
impl Drop for MetricsCallbackHandler {
    // Runs when the request finishes (successfully or not): record latency
    // measured from `make_handler` and release the in-flight slot.
    fn drop(&mut self) {
        self.metrics.query_latency(self.start.elapsed());
        self.metrics.request_metrics.inflight_requests.dec();
    }
}
649
/// Newtype carrying the GraphQL errors of a request through the response
/// extensions, from `graphql_handler` to `MetricsCallbackHandler`.
#[derive(Debug, Clone)]
struct GraphqlErrors(std::sync::Arc<Vec<async_graphql::ServerError>>);
652
653/// Connect via a TCPStream to the DB to check if it is alive
654async fn db_health_check(State(connection): State<ConnectionConfig>) -> StatusCode {
655    let Ok(url) = reqwest::Url::parse(connection.db_url.as_str()) else {
656        return StatusCode::INTERNAL_SERVER_ERROR;
657    };
658
659    let Some(host) = url.host_str() else {
660        return StatusCode::INTERNAL_SERVER_ERROR;
661    };
662
663    let tcp_url = if let Some(port) = url.port() {
664        format!("{host}:{port}")
665    } else {
666        host.to_string()
667    };
668
669    if TcpStream::connect(tcp_url).is_err() {
670        StatusCode::INTERNAL_SERVER_ERROR
671    } else {
672        StatusCode::OK
673    }
674}
675
/// Query parameters accepted by the health-check endpoint.
#[derive(serde::Deserialize)]
struct HealthParam {
    // Maximum tolerated checkpoint lag in milliseconds; when absent the
    // handler falls back to `DEFAULT_MAX_CHECKPOINT_LAG`.
    max_checkpoint_lag_ms: Option<u64>,
}
680
681/// Endpoint for querying the health of the service.
682/// It returns 500 for any internal error, including not connecting to the DB,
683/// and 504 if the checkpoint timestamp is too far behind the current timestamp
684/// as per the max checkpoint timestamp lag query parameter, or the default
685/// value if not provided.
686async fn health_check(
687    State(connection): State<ConnectionConfig>,
688    Extension(watermark_lock): Extension<WatermarkLock>,
689    AxumQuery(query_params): AxumQuery<HealthParam>,
690) -> StatusCode {
691    let db_health_check = db_health_check(axum::extract::State(connection)).await;
692    if db_health_check != StatusCode::OK {
693        return db_health_check;
694    }
695
696    let max_checkpoint_lag_ms = query_params
697        .max_checkpoint_lag_ms
698        .map(Duration::from_millis)
699        .unwrap_or_else(|| DEFAULT_MAX_CHECKPOINT_LAG);
700
701    let checkpoint_timestamp =
702        Duration::from_millis(watermark_lock.read().await.checkpoint_timestamp_ms);
703
704    let now_millis = Utc::now().timestamp_millis();
705
706    // Check for negative timestamp or conversion failure
707    let now: Duration = match u64::try_from(now_millis) {
708        Ok(val) => Duration::from_millis(val),
709        Err(_) => return StatusCode::INTERNAL_SERVER_ERROR,
710    };
711
712    if (now - checkpoint_timestamp) > max_checkpoint_lag_ms {
713        return StatusCode::GATEWAY_TIMEOUT;
714    }
715
716    db_health_check
717}
718
719// One server per proc, so this is okay
720async fn get_or_init_server_start_time() -> &'static Instant {
721    static ONCE: OnceCell<Instant> = OnceCell::const_new();
722    ONCE.get_or_init(|| async move { Instant::now() }).await
723}
724
725pub mod tests {
726    use std::{sync::Arc, time::Duration};
727
728    use async_graphql::{
729        Response,
730        extensions::{Extension, ExtensionContext, NextExecute},
731    };
732    use iota_indexer::optimistic_indexing::OptimisticTransactionExecutor;
733    use iota_sdk::wallet_context::WalletContext;
734    use iota_types::transaction::TransactionData;
735    use uuid::Uuid;
736
737    use super::*;
738    use crate::{
739        config::{ConnectionConfig, Limits, ServiceConfig, Version},
740        context_data::db_data_provider::PgManager,
741        extensions::{query_limits_checker::QueryLimitsChecker, timeout::Timeout},
742    };
743
    /// Prepares a schema for tests dealing with extensions. Returns a
    /// `ServerBuilder` that can be further extended with `context_data` and
    /// `extension` for testing.
    fn prep_schema(
        connection_config: Option<ConnectionConfig>,
        service_config: Option<ServiceConfig>,
    ) -> ServerBuilder {
        let connection_config = connection_config.unwrap_or_default();
        let service_config = service_config.unwrap_or_default();

        // Bound DB statements by the request timeout, mirroring the
        // production setup in `from_config`.
        let reader = PgManager::reader_with_config(
            connection_config.db_url.clone(),
            connection_config.db_pool_size,
            service_config.limits.request_timeout_ms.into(),
        )
        .expect("failed to create pg connection pool");

        let version = Version::for_testing();
        let metrics = metrics();
        let db = Db::new(
            reader.clone(),
            service_config.limits.clone(),
            metrics.clone(),
        );
        let loader = DataLoader::new(db.clone());
        let pg_conn_pool = PgManager::new(reader);
        let cancellation_token = CancellationToken::new();
        // Fixed watermark so tests don't depend on the background updater.
        let watermark = Watermark {
            checkpoint: 1,
            checkpoint_timestamp_ms: 1,
            epoch: 0,
        };
        let state = AppState::new(
            connection_config.clone(),
            service_config.clone(),
            metrics.clone(),
            cancellation_token.clone(),
            version,
        );
        // Inject everything resolvers expect to find in the context.
        ServerBuilder::new(state)
            .context_data(db)
            .context_data(loader)
            .context_data(pg_conn_pool)
            .context_data(service_config)
            .context_data(query_id())
            .context_data(ip_address())
            .context_data(watermark)
            .context_data(ChainIdentifierCache::default())
            .context_data(metrics)
    }
795    fn metrics() -> Metrics {
796        let binding_address: SocketAddr = "0.0.0.0:9185".parse().unwrap();
797        let registry = iota_metrics::start_prometheus_server(binding_address).default_registry();
798        Metrics::new(&registry)
799    }
800
801    fn ip_address() -> SocketAddr {
802        let binding_address: SocketAddr = "0.0.0.0:51515".parse().unwrap();
803        binding_address
804    }
805
    /// Fresh random query ID inserted into the schema context for tests.
    fn query_id() -> Uuid {
        Uuid::new_v4()
    }
809
    /// Exercises the `Timeout` extension: a query/mutation whose execution is
    /// artificially delayed past the configured timeout must fail with the
    /// corresponding timeout error message, while a delay within the limit
    /// must succeed.
    pub async fn test_timeout_impl(
        wallet: &WalletContext,
        optimistic_tx_executor: &OptimisticTransactionExecutor,
    ) {
        // Test extension that injects a fixed delay before every execution.
        struct TimedExecuteExt {
            pub min_req_delay: Duration,
        }

        impl ExtensionFactory for TimedExecuteExt {
            fn create(&self) -> Arc<dyn Extension> {
                Arc::new(TimedExecuteExt {
                    min_req_delay: self.min_req_delay,
                })
            }
        }

        #[async_trait::async_trait]
        impl Extension for TimedExecuteExt {
            async fn execute(
                &self,
                ctx: &ExtensionContext<'_>,
                operation_name: Option<&str>,
                next: NextExecute<'_>,
            ) -> Response {
                // Sleep first, then run the real execution pipeline.
                tokio::time::sleep(self.min_req_delay).await;
                next.run(ctx, operation_name).await
            }
        }

        // Execute `query` against a schema whose request/mutation timeouts are
        // both `timeout` and whose execution is delayed by `delay`.
        async fn test_timeout(
            delay: Duration,
            timeout: Duration,
            query: &str,
            optimistic_tx_executor: &OptimisticTransactionExecutor,
        ) -> Response {
            let mut cfg = ServiceConfig::default();
            cfg.limits.request_timeout_ms = timeout.as_millis() as u32;
            cfg.limits.mutation_timeout_ms = timeout.as_millis() as u32;

            let schema = prep_schema(None, Some(cfg))
                .context_data(Some(optimistic_tx_executor.clone()))
                .extension(Timeout)
                .extension(TimedExecuteExt {
                    min_req_delay: delay,
                })
                .build_schema();

            schema.execute(query).await
        }

        let query = r#"{ checkpoint(id: {sequenceNumber: 0 }) { digest }}"#;
        let timeout = Duration::from_millis(1000);
        let delay = Duration::from_millis(100);

        // Delay well under the timeout: the query must succeed.
        test_timeout(delay, timeout, query, optimistic_tx_executor)
            .await
            .into_result()
            .expect("should complete successfully");

        // Should timeout
        let errs: Vec<_> = test_timeout(delay, delay, query, optimistic_tx_executor)
            .await
            .into_result()
            .unwrap_err()
            .into_iter()
            .map(|e| e.message)
            .collect();
        let exp = format!("Query request timed out. Limit: {}s", delay.as_secs_f32());
        assert_eq!(errs, vec![exp]);

        // Should timeout for mutation
        // Create a transaction and sign it, and use the tx_bytes + signatures for the
        // GraphQL executeTransactionBlock mutation call.
        let addresses = wallet.get_addresses();
        let gas = wallet
            .get_one_gas_object_owned_by_address(addresses[0])
            .await
            .unwrap();
        let tx_data = TransactionData::new_transfer_iota(
            addresses[1],
            addresses[0],
            Some(1000),
            gas.unwrap(),
            1_000_000,
            wallet.get_reference_gas_price().await.unwrap(),
        );

        let tx = wallet.sign_transaction(&tx_data);
        let (tx_bytes, signatures) = tx.to_tx_bytes_and_signatures();

        let signature_base64 = &signatures[0];
        let query = format!(
            r#"
            mutation {{
              executeTransactionBlock(txBytes: "{}", signatures: "{}") {{
                effects {{
                  status
                }}
              }}
            }}"#,
            tx_bytes.encoded(),
            signature_base64.encoded()
        );
        let errs: Vec<_> = test_timeout(delay, delay, &query, optimistic_tx_executor)
            .await
            .into_result()
            .unwrap_err()
            .into_iter()
            .map(|e| e.message)
            .collect();
        let exp = format!(
            "Mutation request timed out. Limit: {}s",
            delay.as_secs_f32()
        );
        assert_eq!(errs, vec![exp]);
    }
926
927    pub async fn test_query_depth_limit_impl() {
928        async fn exec_query_depth_limit(depth: u32, query: &str) -> Response {
929            let service_config = ServiceConfig {
930                limits: Limits {
931                    max_query_depth: depth,
932                    ..Default::default()
933                },
934                ..Default::default()
935            };
936
937            let schema = prep_schema(None, Some(service_config))
938                .extension(QueryLimitsChecker)
939                .build_schema();
940            schema.execute(query).await
941        }
942
943        exec_query_depth_limit(1, "{ chainIdentifier }")
944            .await
945            .into_result()
946            .expect("should complete successfully");
947
948        exec_query_depth_limit(
949            5,
950            "{ chainIdentifier protocolConfig { configs { value key }} }",
951        )
952        .await
953        .into_result()
954        .expect("should complete successfully");
955
956        // Should fail
957        let errs: Vec<_> = exec_query_depth_limit(0, "{ chainIdentifier }")
958            .await
959            .into_result()
960            .unwrap_err()
961            .into_iter()
962            .map(|e| e.message)
963            .collect();
964
965        assert_eq!(errs, vec!["Query nesting is over 0".to_string()]);
966        let errs: Vec<_> = exec_query_depth_limit(
967            2,
968            "{ chainIdentifier protocolConfig { configs { value key }} }",
969        )
970        .await
971        .into_result()
972        .unwrap_err()
973        .into_iter()
974        .map(|e| e.message)
975        .collect();
976        assert_eq!(errs, vec!["Query nesting is over 2".to_string()]);
977    }
978
979    pub async fn test_query_node_limit_impl() {
980        async fn exec_query_node_limit(nodes: u32, query: &str) -> Response {
981            let service_config = ServiceConfig {
982                limits: Limits {
983                    max_query_nodes: nodes,
984                    ..Default::default()
985                },
986                ..Default::default()
987            };
988
989            let schema = prep_schema(None, Some(service_config))
990                .extension(QueryLimitsChecker)
991                .build_schema();
992            schema.execute(query).await
993        }
994
995        exec_query_node_limit(1, "{ chainIdentifier }")
996            .await
997            .into_result()
998            .expect("should complete successfully");
999
1000        exec_query_node_limit(
1001            5,
1002            "{ chainIdentifier protocolConfig { configs { value key }} }",
1003        )
1004        .await
1005        .into_result()
1006        .expect("should complete successfully");
1007
1008        // Should fail
1009        let err: Vec<_> = exec_query_node_limit(0, "{ chainIdentifier }")
1010            .await
1011            .into_result()
1012            .unwrap_err()
1013            .into_iter()
1014            .map(|e| e.message)
1015            .collect();
1016        assert_eq!(err, vec!["Query has over 0 nodes".to_string()]);
1017
1018        let err: Vec<_> = exec_query_node_limit(
1019            4,
1020            "{ chainIdentifier protocolConfig { configs { value key }} }",
1021        )
1022        .await
1023        .into_result()
1024        .unwrap_err()
1025        .into_iter()
1026        .map(|e| e.message)
1027        .collect();
1028        assert_eq!(err, vec!["Query has over 4 nodes".to_string()]);
1029    }
1030
1031    pub async fn test_query_default_page_limit_impl(connection_config: ConnectionConfig) {
1032        let service_config = ServiceConfig {
1033            limits: Limits {
1034                default_page_size: 1,
1035                ..Default::default()
1036            },
1037            ..Default::default()
1038        };
1039        let schema = prep_schema(Some(connection_config), Some(service_config)).build_schema();
1040
1041        let resp = schema
1042            .execute("{ checkpoints { nodes { sequenceNumber } } }")
1043            .await;
1044        let data = resp.data.clone().into_json().unwrap();
1045        let checkpoints = data
1046            .get("checkpoints")
1047            .unwrap()
1048            .get("nodes")
1049            .unwrap()
1050            .as_array()
1051            .unwrap();
1052        assert_eq!(
1053            checkpoints.len(),
1054            1,
1055            "Checkpoints should have exactly one element"
1056        );
1057
1058        let resp = schema
1059            .execute("{ checkpoints(first: 2) { nodes { sequenceNumber } } }")
1060            .await;
1061        let data = resp.data.clone().into_json().unwrap();
1062        let checkpoints = data
1063            .get("checkpoints")
1064            .unwrap()
1065            .get("nodes")
1066            .unwrap()
1067            .as_array()
1068            .unwrap();
1069        assert_eq!(
1070            checkpoints.len(),
1071            2,
1072            "Checkpoints should return two elements"
1073        );
1074    }
1075
1076    pub async fn test_query_max_page_limit_impl() {
1077        let schema = prep_schema(None, None).build_schema();
1078
1079        schema
1080            .execute("{ objects(first: 1) { nodes { version } } }")
1081            .await
1082            .into_result()
1083            .expect("should complete successfully");
1084
1085        // Should fail
1086        let err: Vec<_> = schema
1087            .execute("{ objects(first: 51) { nodes { version } } }")
1088            .await
1089            .into_result()
1090            .unwrap_err()
1091            .into_iter()
1092            .map(|e| e.message)
1093            .collect();
1094        assert_eq!(
1095            err,
1096            vec!["Connection's page size of 51 exceeds max of 50".to_string()]
1097        );
1098    }
1099
1100    pub async fn test_query_complexity_metrics_impl() {
1101        let server_builder = prep_schema(None, None);
1102        let metrics = server_builder.state.metrics.clone();
1103        let schema = server_builder
1104            .extension(QueryLimitsChecker) // QueryLimitsChecker is where we actually set the metrics
1105            .build_schema();
1106
1107        schema
1108            .execute("{ chainIdentifier }")
1109            .await
1110            .into_result()
1111            .expect("should complete successfully");
1112
1113        let req_metrics = metrics.request_metrics;
1114        assert_eq!(req_metrics.input_nodes.get_sample_count(), 1);
1115        assert_eq!(req_metrics.output_nodes.get_sample_count(), 1);
1116        assert_eq!(req_metrics.query_depth.get_sample_count(), 1);
1117        assert_eq!(req_metrics.input_nodes.get_sample_sum(), 1.);
1118        assert_eq!(req_metrics.output_nodes.get_sample_sum(), 1.);
1119        assert_eq!(req_metrics.query_depth.get_sample_sum(), 1.);
1120
1121        schema
1122            .execute("{ chainIdentifier protocolConfig { configs { value key }} }")
1123            .await
1124            .into_result()
1125            .expect("should complete successfully");
1126
1127        assert_eq!(req_metrics.input_nodes.get_sample_count(), 2);
1128        assert_eq!(req_metrics.output_nodes.get_sample_count(), 2);
1129        assert_eq!(req_metrics.query_depth.get_sample_count(), 2);
1130        assert_eq!(req_metrics.input_nodes.get_sample_sum(), 2. + 4.);
1131        assert_eq!(req_metrics.output_nodes.get_sample_sum(), 2. + 4.);
1132        assert_eq!(req_metrics.query_depth.get_sample_sum(), 1. + 3.);
1133    }
1134
1135    pub async fn test_health_check_impl() {
1136        let server_builder = prep_schema(None, None);
1137        let url = format!(
1138            "http://{}:{}/health",
1139            server_builder.state.connection.host, server_builder.state.connection.port
1140        );
1141        server_builder.build_schema();
1142
1143        let resp = reqwest::get(&url).await.unwrap();
1144        assert_eq!(resp.status(), reqwest::StatusCode::OK);
1145
1146        let url_with_param = format!("{url}?max_checkpoint_lag_ms=1");
1147        let resp = reqwest::get(&url_with_param).await.unwrap();
1148        assert_eq!(resp.status(), reqwest::StatusCode::GATEWAY_TIMEOUT);
1149    }
1150}