1use std::{
5 any::Any,
6 convert::Infallible,
7 net::{SocketAddr, TcpStream},
8 sync::Arc,
9 time::{Duration, Instant},
10};
11
12use async_graphql::{
13 Context, EmptySubscription, ResultExt, Schema, SchemaBuilder,
14 extensions::{ApolloTracing, ExtensionFactory, Tracing},
15};
16use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
17use axum::{
18 Extension, Router,
19 body::Body,
20 extract::{ConnectInfo, FromRef, Query as AxumQuery, State},
21 http::{HeaderMap, StatusCode},
22 middleware::{self},
23 response::IntoResponse,
24 routing::{MethodRouter, Route, get, post},
25};
26use chrono::Utc;
27use http::{HeaderValue, Method, Request};
28use iota_graphql_rpc_headers::LIMITS_HEADER;
29use iota_indexer::{
30 db::{get_pool_connection, setup_postgres::check_db_migration_consistency},
31 metrics::IndexerMetrics,
32 store::PgIndexerStore,
33};
34use iota_metrics::spawn_monitored_task;
35use iota_network_stack::callback::{CallbackLayer, MakeCallbackHandler, ResponseHandler};
36use iota_package_resolver::{PackageStoreWithLruCache, Resolver};
37use iota_sdk::{IotaClient, IotaClientBuilder};
38use tokio::{join, net::TcpListener, sync::OnceCell};
39use tokio_util::sync::CancellationToken;
40use tower::{Layer, Service};
41use tower_http::cors::{AllowOrigin, CorsLayer};
42use tracing::{info, warn};
43use uuid::Uuid;
44
45use crate::{
46 config::{
47 ConnectionConfig, MAX_CONCURRENT_REQUESTS, RPC_TIMEOUT_ERR_SLEEP_RETRY_PERIOD,
48 ServerConfig, ServiceConfig, Version,
49 },
50 context_data::db_data_provider::PgManager,
51 data::{
52 DataLoader, Db,
53 package_resolver::{DbPackageStore, PackageResolver},
54 },
55 error::Error,
56 extensions::{
57 directive_checker::DirectiveChecker,
58 feature_gate::FeatureGate,
59 logger::Logger,
60 query_limits_checker::{QueryLimitsChecker, ShowUsage},
61 timeout::Timeout,
62 },
63 metrics::Metrics,
64 mutation::Mutation,
65 server::{
66 exchange_rates_task::TriggerExchangeRatesTask,
67 system_package_task::SystemPackageTask,
68 version::{check_version_middleware, set_version_middleware},
69 watermark_task::{Watermark, WatermarkLock, WatermarkTask},
70 },
71 types::{
72 chain_identifier::ChainIdentifierCache,
73 datatype::IMoveDatatype,
74 move_object::IMoveObject,
75 object::IObject,
76 owner::IOwner,
77 query::{IotaGraphQLSchema, Query},
78 },
79};
80
/// Maximum allowed age of the latest checkpoint (relative to the current wall-clock time) before
/// `health_check` reports `GATEWAY_TIMEOUT`. Used when the request does not supply the
/// `max_checkpoint_lag_ms` query parameter.
const DEFAULT_MAX_CHECKPOINT_LAG: Duration = Duration::from_secs(300);
84
/// A fully assembled GraphQL server, produced by [`ServerBuilder::build`] and started with
/// [`Server::run`].
pub(crate) struct Server {
    /// HTTP routes with all middleware and extension layers already applied.
    router: Router,
    /// Socket address the HTTP listener binds to.
    address: SocketAddr,
    /// Background task that updates the shared watermark (started as the "watermark update task").
    watermark_task: WatermarkTask,
    /// Background task holding the package resolver, driven by the watermark's epoch receiver.
    system_package_task: SystemPackageTask,
    /// Background task holding a DB reader, driven by the watermark's epoch receiver.
    trigger_exchange_rates_task: TriggerExchangeRatesTask,
    /// Shared application state; its cancellation token drives graceful shutdown in `run`.
    state: AppState,
    /// Database reader, used by `run` for the migration consistency check.
    db_reader: Db,
}
94
95impl Server {
96 pub async fn run(mut self) -> Result<(), Error> {
100 get_or_init_server_start_time().await;
101
102 {
103 info!("Starting compatibility check");
105 let mut connection = get_pool_connection(&self.db_reader.inner.get_pool())?;
106 check_db_migration_consistency(&mut connection)?;
107 info!("Compatibility check passed");
108 }
109
110 let watermark_task = {
114 info!("Starting watermark update task");
115 spawn_monitored_task!(async move {
116 self.watermark_task.run().await;
117 })
118 };
119
120 let system_package_task = {
123 info!("Starting system package task");
124 spawn_monitored_task!(async move {
125 self.system_package_task.run().await;
126 })
127 };
128
129 let trigger_exchange_rates_task = {
130 info!("Starting trigger exchange rates task");
131 spawn_monitored_task!(async move {
132 self.trigger_exchange_rates_task.run().await;
133 })
134 };
135
136 let server_task = {
137 info!("Starting graphql service");
138 let cancellation_token = self.state.cancellation_token.clone();
139 let address = self.address;
140 let router = self.router;
141 spawn_monitored_task!(async move {
142 axum::serve(
143 TcpListener::bind(address)
144 .await
145 .map_err(|e| Error::Internal(format!("listener bind failed: {e}")))?,
146 router.into_make_service_with_connect_info::<SocketAddr>(),
147 )
148 .with_graceful_shutdown(async move {
149 cancellation_token.cancelled().await;
150 info!("Shutdown signal received, terminating graphql service");
151 })
152 .await
153 .map_err(|e| Error::Internal(format!("Server run failed: {e}")))
154 })
155 };
156
157 let _ = join!(
161 watermark_task,
162 system_package_task,
163 trigger_exchange_rates_task,
164 server_task
165 );
166
167 Ok(())
168 }
169}
170
/// Incrementally assembles a [`Server`]: schema data and extensions, HTTP routes/layers, and the
/// DB reader / package resolver needed by the background tasks.
pub(crate) struct ServerBuilder {
    /// Shared application state handed to the router and, later, to the built server.
    state: AppState,
    /// GraphQL schema under construction; context data and extensions accumulate here.
    schema: SchemaBuilder<Query, Mutation, EmptySubscription>,
    /// HTTP router; `None` until `init_router` creates the default routes.
    router: Option<Router>,
    /// Database reader; must be set before `build` (done by `from_config`).
    db_reader: Option<Db>,
    /// Move package resolver; must be set before `build` (done by `from_config`).
    resolver: Option<PackageResolver>,
}
178
/// State shared with the axum router and the server's background machinery.
#[derive(Clone)]
pub(crate) struct AppState {
    /// Connection settings (bind host/port, DB URL, ...); extractable by handlers via `FromRef`.
    connection: ConnectionConfig,
    /// Service-level configuration (limits, background task intervals, ...).
    service: ServiceConfig,
    /// Request metrics; extractable by handlers via `FromRef`.
    metrics: Metrics,
    /// Token whose cancellation triggers graceful shutdown of the server and its tasks.
    cancellation_token: CancellationToken,
    /// Service version, consumed by the version-setting and version-checking middleware.
    pub version: Version,
}
187
188impl AppState {
189 pub(crate) fn new(
190 connection: ConnectionConfig,
191 service: ServiceConfig,
192 metrics: Metrics,
193 cancellation_token: CancellationToken,
194 version: Version,
195 ) -> Self {
196 Self {
197 connection,
198 service,
199 metrics,
200 cancellation_token,
201 version,
202 }
203 }
204}
205
206impl FromRef<AppState> for ConnectionConfig {
207 fn from_ref(app_state: &AppState) -> ConnectionConfig {
208 app_state.connection.clone()
209 }
210}
211
212impl FromRef<AppState> for Metrics {
213 fn from_ref(app_state: &AppState) -> Metrics {
214 app_state.metrics.clone()
215 }
216}
217
218impl ServerBuilder {
219 pub fn new(state: AppState) -> Self {
220 Self {
221 state,
222 schema: schema_builder(),
223 router: None,
224 db_reader: None,
225 resolver: None,
226 }
227 }
228
229 pub fn address(&self) -> String {
230 format!(
231 "{}:{}",
232 self.state.connection.host, self.state.connection.port
233 )
234 }
235
236 pub fn context_data(mut self, context_data: impl Any + Send + Sync) -> Self {
237 self.schema = self.schema.data(context_data);
238 self
239 }
240
241 pub fn extension(mut self, extension: impl ExtensionFactory) -> Self {
242 self.schema = self.schema.extension(extension);
243 self
244 }
245
246 fn build_schema(self) -> Schema<Query, Mutation, EmptySubscription> {
247 self.schema.finish()
248 }
249
250 fn build_components(
253 self,
254 ) -> (
255 String,
256 Schema<Query, Mutation, EmptySubscription>,
257 Db,
258 PackageResolver,
259 Router,
260 ) {
261 let address = self.address();
262 let ServerBuilder {
263 state: _,
264 schema,
265 db_reader,
266 resolver,
267 router,
268 } = self;
269 (
270 address,
271 schema.finish(),
272 db_reader.expect("db reader not initialized"),
273 resolver.expect("package resolver not initialized"),
274 router.expect("router not initialized"),
275 )
276 }
277
278 fn init_router(&mut self) {
279 if self.router.is_none() {
280 let router: Router = Router::new()
281 .route("/", post(graphql_handler))
282 .route("/{version}", post(graphql_handler))
283 .route("/graphql", post(graphql_handler))
284 .route("/graphql/{version}", post(graphql_handler))
285 .route("/health", get(health_check))
286 .route("/graphql/health", get(health_check))
287 .route("/graphql/{version}/health", get(health_check))
288 .with_state(self.state.clone())
289 .route_layer(CallbackLayer::new(MetricsMakeCallbackHandler {
290 metrics: self.state.metrics.clone(),
291 }));
292 self.router = Some(router);
293 }
294 }
295
296 pub fn route(mut self, path: &str, method_handler: MethodRouter) -> Self {
297 self.init_router();
298 self.router = self.router.map(|router| router.route(path, method_handler));
299 self
300 }
301
302 pub fn layer<L>(mut self, layer: L) -> Self
303 where
304 L: Layer<Route> + Clone + Send + Sync + 'static,
305 L::Service: Service<Request<Body>> + Clone + Send + Sync + 'static,
306 <L::Service as Service<Request<Body>>>::Response: IntoResponse + 'static,
307 <L::Service as Service<Request<Body>>>::Error: Into<Infallible> + 'static,
308 <L::Service as Service<Request<Body>>>::Future: Send + 'static,
309 {
310 self.init_router();
311 self.router = self.router.map(|router| router.layer(layer));
312 self
313 }
314
315 fn cors() -> Result<CorsLayer, Error> {
316 let acl = match std::env::var("ACCESS_CONTROL_ALLOW_ORIGIN") {
317 Ok(value) => {
318 let allow_hosts = value
319 .split(',')
320 .map(HeaderValue::from_str)
321 .collect::<Result<Vec<_>, _>>()
322 .map_err(|_| {
323 Error::Internal(
324 "Cannot resolve access control origin env variable".to_string(),
325 )
326 })?;
327 AllowOrigin::list(allow_hosts)
328 }
329 _ => AllowOrigin::any(),
330 };
331 info!("Access control allow origin set to: {acl:?}");
332
333 let cors = CorsLayer::new()
334 .allow_methods([Method::POST])
336 .allow_origin(acl)
338 .allow_headers([hyper::header::CONTENT_TYPE, LIMITS_HEADER.clone()]);
339 Ok(cors)
340 }
341
342 pub fn build(self) -> Result<Server, Error> {
344 let state = self.state.clone();
345 let (address, schema, db_reader, resolver, router) = self.build_components();
346
347 let watermark_task = WatermarkTask::new(
349 db_reader.clone(),
350 state.metrics.clone(),
351 std::time::Duration::from_millis(state.service.background_tasks.watermark_update_ms),
352 state.cancellation_token.clone(),
353 );
354
355 let system_package_task = SystemPackageTask::new(
356 resolver,
357 watermark_task.epoch_receiver(),
358 state.cancellation_token.clone(),
359 );
360
361 let trigger_exchange_rates_task = TriggerExchangeRatesTask::new(
362 db_reader.clone(),
363 watermark_task.epoch_receiver(),
364 state.cancellation_token.clone(),
365 );
366
367 let router = router
368 .route_layer(middleware::from_fn_with_state(
369 state.version,
370 set_version_middleware,
371 ))
372 .route_layer(middleware::from_fn_with_state(
373 state.version,
374 check_version_middleware,
375 ))
376 .layer(axum::extract::Extension(schema))
377 .layer(axum::extract::Extension(watermark_task.lock()))
378 .layer(Self::cors()?);
379
380 Ok(Server {
381 router,
382 address: address
383 .parse()
384 .map_err(|_| Error::Internal(format!("Failed to parse address {address}")))?,
385 watermark_task,
386 system_package_task,
387 trigger_exchange_rates_task,
388 state,
389 db_reader,
390 })
391 }
392
393 pub async fn from_config(
396 config: &ServerConfig,
397 version: &Version,
398 cancellation_token: CancellationToken,
399 ) -> Result<Self, Error> {
400 let prom_addr: SocketAddr = format!(
402 "{}:{}",
403 config.connection.prom_url, config.connection.prom_port
404 )
405 .parse()
406 .map_err(|_| {
407 Error::Internal(format!(
408 "Failed to parse url {}, port {} into socket address",
409 config.connection.prom_url, config.connection.prom_port
410 ))
411 })?;
412
413 let registry_service = iota_metrics::start_prometheus_server(prom_addr);
414 info!("Starting Prometheus HTTP endpoint at {}", prom_addr);
415 let registry = registry_service.default_registry();
416 registry
417 .register(iota_metrics::uptime_metric(
418 "graphql",
419 version.full,
420 "unknown",
421 ))
422 .unwrap();
423
424 let metrics = Metrics::new(®istry);
426 let indexer_metrics = IndexerMetrics::new(®istry);
427 let state = AppState::new(
428 config.connection.clone(),
429 config.service.clone(),
430 metrics.clone(),
431 cancellation_token,
432 *version,
433 );
434 let mut builder = ServerBuilder::new(state);
435
436 let iota_names_config = config.service.iota_names.clone();
437 let zklogin_config = config.service.zklogin.clone();
438 let reader = PgManager::reader_with_config(
439 config.connection.db_url.clone(),
440 config.connection.db_pool_size,
441 config.service.limits.request_timeout_ms.into(),
445 )
446 .map_err(|e| Error::Internal(format!("Failed to create pg connection pool: {e}")))?;
447
448 let db = Db::new(
450 reader.clone(),
451 config.service.limits.clone(),
452 metrics.clone(),
453 );
454 let loader = DataLoader::new(db.clone());
455 let pg_conn_pool = PgManager::new(reader.clone());
456 let package_store = DbPackageStore::new(loader.clone());
457 let resolver = Arc::new(Resolver::new_with_limits(
458 PackageStoreWithLruCache::new(package_store),
459 config.service.limits.package_resolver_limits(),
460 ));
461
462 builder.db_reader = Some(db.clone());
463 builder.resolver = Some(resolver.clone());
464
465 let (iota_sdk_client, optimistic_tx_executor) = if let Some(url) =
468 &config.tx_exec_full_node.node_rpc_url
469 {
470 let iota_sdk_client = IotaClientBuilder::default()
471 .request_timeout(RPC_TIMEOUT_ERR_SLEEP_RETRY_PERIOD)
472 .max_concurrent_requests(MAX_CONCURRENT_REQUESTS)
473 .build(url)
474 .await
475 .map_err(|e| {
476 Error::Internal(format!(
477 "Failed to connect to fullnode {e}. Is the node server running?"
478 ))
479 })?;
480 let indexer_store = PgIndexerStore::new(reader.get_pool(), indexer_metrics.clone());
481 let optimistic_tx_executor =
482 iota_indexer::optimistic_indexing::OptimisticTransactionExecutor::new(
483 url,
484 reader.clone(),
485 indexer_store,
486 indexer_metrics.clone(),
487 );
488 (Some(iota_sdk_client), Some(optimistic_tx_executor))
489 } else {
490 warn!(
491 "No fullnode url found in config. `dryRunTransactionBlock` and `executeTransactionBlock` will not work"
492 );
493 (None, None)
494 };
495
496 builder = builder
497 .context_data(config.service.clone())
498 .context_data(loader)
499 .context_data(db)
500 .context_data(pg_conn_pool)
501 .context_data(resolver)
502 .context_data(iota_sdk_client)
503 .context_data(optimistic_tx_executor)
504 .context_data(iota_names_config)
505 .context_data(zklogin_config)
506 .context_data(metrics.clone())
507 .context_data(config.clone())
508 .context_data(ChainIdentifierCache::default());
509
510 if config.internal_features.feature_gate {
511 builder = builder.extension(FeatureGate);
512 }
513
514 if config.internal_features.logger {
515 builder = builder.extension(Logger::default());
516 }
517
518 if config.internal_features.query_limits_checker {
519 builder = builder.extension(QueryLimitsChecker);
520 }
521
522 if config.internal_features.directive_checker {
523 builder = builder.extension(DirectiveChecker);
524 }
525
526 if config.internal_features.query_timeout {
527 builder = builder.extension(Timeout);
528 }
529
530 if config.internal_features.tracing {
531 builder = builder.extension(Tracing);
532 }
533
534 if config.internal_features.apollo_tracing {
535 builder = builder.extension(ApolloTracing);
536 }
537
538 Ok(builder)
542 }
543}
544
545fn schema_builder() -> SchemaBuilder<Query, Mutation, EmptySubscription> {
546 async_graphql::Schema::build(Query, Mutation, EmptySubscription)
547 .register_output_type::<IMoveObject>()
548 .register_output_type::<IObject>()
549 .register_output_type::<IOwner>()
550 .register_output_type::<IMoveDatatype>()
551}
552
553pub(crate) fn get_fullnode_client<'ctx>(
554 ctx: &'ctx Context<'_>,
555) -> async_graphql::Result<&'ctx IotaClient> {
556 let iota_sdk_client: &Option<IotaClient> = ctx
557 .data()
558 .map_err(|_| Error::Internal("Unable to fetch IOTA SDK client".to_string()))
559 .extend()?;
560 iota_sdk_client
561 .as_ref()
562 .ok_or_else(|| Error::Internal("IOTA SDK client not initialized".to_string()))
563 .extend()
564}
565
566pub fn export_schema() -> String {
568 schema_builder().finish().sdl()
569}
570
571async fn graphql_handler(
575 ConnectInfo(addr): ConnectInfo<SocketAddr>,
576 schema: Extension<IotaGraphQLSchema>,
577 Extension(watermark_lock): Extension<WatermarkLock>,
578 headers: HeaderMap,
579 req: GraphQLRequest,
580) -> (axum::http::Extensions, GraphQLResponse) {
581 let mut req = req.into_inner();
582 req.data.insert(Uuid::new_v4());
583 if headers.contains_key(ShowUsage::name()) {
584 req.data.insert(ShowUsage)
585 }
586 req.data.insert(addr);
590
591 req.data.insert(Watermark::new(watermark_lock).await);
592
593 let result = schema.execute(req).await;
594
595 let mut extensions = axum::http::Extensions::new();
598 if result.is_err() {
599 extensions.insert(GraphqlErrors(std::sync::Arc::new(result.errors.clone())));
600 };
601 (extensions, result.into())
602}
603
604#[derive(Clone)]
605struct MetricsMakeCallbackHandler {
606 metrics: Metrics,
607}
608
609impl MakeCallbackHandler for MetricsMakeCallbackHandler {
610 type Handler = MetricsCallbackHandler;
611
612 fn make_handler(&self, _request: &http::request::Parts) -> Self::Handler {
613 let start = Instant::now();
614 let metrics = self.metrics.clone();
615
616 metrics.request_metrics.inflight_requests.inc();
617 metrics.inc_num_queries();
618
619 MetricsCallbackHandler { metrics, start }
620 }
621}
622
623struct MetricsCallbackHandler {
624 metrics: Metrics,
625 start: Instant,
626}
627
628impl ResponseHandler for MetricsCallbackHandler {
629 fn on_response(&mut self, response: &http::response::Parts) {
630 if let Some(errors) = response.extensions.get::<GraphqlErrors>() {
631 self.metrics.inc_errors(&errors.0);
632 }
633 }
634
635 fn on_error<E>(&mut self, _error: &E) {
636 }
641}
642
643impl Drop for MetricsCallbackHandler {
644 fn drop(&mut self) {
645 self.metrics.query_latency(self.start.elapsed());
646 self.metrics.request_metrics.inflight_requests.dec();
647 }
648}
649
/// GraphQL errors of a response, carried in the HTTP response extensions from
/// `graphql_handler` to the metrics callback. Wrapped in `Arc` so cloning the extension is cheap.
#[derive(Debug, Clone)]
struct GraphqlErrors(std::sync::Arc<Vec<async_graphql::ServerError>>);
652
653async fn db_health_check(State(connection): State<ConnectionConfig>) -> StatusCode {
655 let Ok(url) = reqwest::Url::parse(connection.db_url.as_str()) else {
656 return StatusCode::INTERNAL_SERVER_ERROR;
657 };
658
659 let Some(host) = url.host_str() else {
660 return StatusCode::INTERNAL_SERVER_ERROR;
661 };
662
663 let tcp_url = if let Some(port) = url.port() {
664 format!("{host}:{port}")
665 } else {
666 host.to_string()
667 };
668
669 if TcpStream::connect(tcp_url).is_err() {
670 StatusCode::INTERNAL_SERVER_ERROR
671 } else {
672 StatusCode::OK
673 }
674}
675
/// Query parameters accepted by the health endpoints.
#[derive(serde::Deserialize)]
struct HealthParam {
    /// Maximum tolerated checkpoint lag in milliseconds; `DEFAULT_MAX_CHECKPOINT_LAG` when absent.
    max_checkpoint_lag_ms: Option<u64>,
}
680
681async fn health_check(
687 State(connection): State<ConnectionConfig>,
688 Extension(watermark_lock): Extension<WatermarkLock>,
689 AxumQuery(query_params): AxumQuery<HealthParam>,
690) -> StatusCode {
691 let db_health_check = db_health_check(axum::extract::State(connection)).await;
692 if db_health_check != StatusCode::OK {
693 return db_health_check;
694 }
695
696 let max_checkpoint_lag_ms = query_params
697 .max_checkpoint_lag_ms
698 .map(Duration::from_millis)
699 .unwrap_or_else(|| DEFAULT_MAX_CHECKPOINT_LAG);
700
701 let checkpoint_timestamp =
702 Duration::from_millis(watermark_lock.read().await.checkpoint_timestamp_ms);
703
704 let now_millis = Utc::now().timestamp_millis();
705
706 let now: Duration = match u64::try_from(now_millis) {
708 Ok(val) => Duration::from_millis(val),
709 Err(_) => return StatusCode::INTERNAL_SERVER_ERROR,
710 };
711
712 if (now - checkpoint_timestamp) > max_checkpoint_lag_ms {
713 return StatusCode::GATEWAY_TIMEOUT;
714 }
715
716 db_health_check
717}
718
719async fn get_or_init_server_start_time() -> &'static Instant {
721 static ONCE: OnceCell<Instant> = OnceCell::const_new();
722 ONCE.get_or_init(|| async move { Instant::now() }).await
723}
724
725pub mod tests {
726 use std::{sync::Arc, time::Duration};
727
728 use async_graphql::{
729 Response,
730 extensions::{Extension, ExtensionContext, NextExecute},
731 };
732 use iota_indexer::optimistic_indexing::OptimisticTransactionExecutor;
733 use iota_sdk::wallet_context::WalletContext;
734 use iota_types::transaction::TransactionData;
735 use uuid::Uuid;
736
737 use super::*;
738 use crate::{
739 config::{ConnectionConfig, Limits, ServiceConfig, Version},
740 context_data::db_data_provider::PgManager,
741 extensions::{query_limits_checker::QueryLimitsChecker, timeout::Timeout},
742 };
743
744 fn prep_schema(
748 connection_config: Option<ConnectionConfig>,
749 service_config: Option<ServiceConfig>,
750 ) -> ServerBuilder {
751 let connection_config = connection_config.unwrap_or_default();
752 let service_config = service_config.unwrap_or_default();
753
754 let reader = PgManager::reader_with_config(
755 connection_config.db_url.clone(),
756 connection_config.db_pool_size,
757 service_config.limits.request_timeout_ms.into(),
758 )
759 .expect("failed to create pg connection pool");
760
761 let version = Version::for_testing();
762 let metrics = metrics();
763 let db = Db::new(
764 reader.clone(),
765 service_config.limits.clone(),
766 metrics.clone(),
767 );
768 let loader = DataLoader::new(db.clone());
769 let pg_conn_pool = PgManager::new(reader);
770 let cancellation_token = CancellationToken::new();
771 let watermark = Watermark {
772 checkpoint: 1,
773 checkpoint_timestamp_ms: 1,
774 epoch: 0,
775 };
776 let state = AppState::new(
777 connection_config.clone(),
778 service_config.clone(),
779 metrics.clone(),
780 cancellation_token.clone(),
781 version,
782 );
783 ServerBuilder::new(state)
784 .context_data(db)
785 .context_data(loader)
786 .context_data(pg_conn_pool)
787 .context_data(service_config)
788 .context_data(query_id())
789 .context_data(ip_address())
790 .context_data(watermark)
791 .context_data(ChainIdentifierCache::default())
792 .context_data(metrics)
793 }
794
795 fn metrics() -> Metrics {
796 let binding_address: SocketAddr = "0.0.0.0:9185".parse().unwrap();
797 let registry = iota_metrics::start_prometheus_server(binding_address).default_registry();
798 Metrics::new(®istry)
799 }
800
801 fn ip_address() -> SocketAddr {
802 let binding_address: SocketAddr = "0.0.0.0:51515".parse().unwrap();
803 binding_address
804 }
805
806 fn query_id() -> Uuid {
807 Uuid::new_v4()
808 }
809
810 pub async fn test_timeout_impl(
811 wallet: &WalletContext,
812 optimistic_tx_executor: &OptimisticTransactionExecutor,
813 ) {
814 struct TimedExecuteExt {
815 pub min_req_delay: Duration,
816 }
817
818 impl ExtensionFactory for TimedExecuteExt {
819 fn create(&self) -> Arc<dyn Extension> {
820 Arc::new(TimedExecuteExt {
821 min_req_delay: self.min_req_delay,
822 })
823 }
824 }
825
826 #[async_trait::async_trait]
827 impl Extension for TimedExecuteExt {
828 async fn execute(
829 &self,
830 ctx: &ExtensionContext<'_>,
831 operation_name: Option<&str>,
832 next: NextExecute<'_>,
833 ) -> Response {
834 tokio::time::sleep(self.min_req_delay).await;
835 next.run(ctx, operation_name).await
836 }
837 }
838
839 async fn test_timeout(
840 delay: Duration,
841 timeout: Duration,
842 query: &str,
843 optimistic_tx_executor: &OptimisticTransactionExecutor,
844 ) -> Response {
845 let mut cfg = ServiceConfig::default();
846 cfg.limits.request_timeout_ms = timeout.as_millis() as u32;
847 cfg.limits.mutation_timeout_ms = timeout.as_millis() as u32;
848
849 let schema = prep_schema(None, Some(cfg))
850 .context_data(Some(optimistic_tx_executor.clone()))
851 .extension(Timeout)
852 .extension(TimedExecuteExt {
853 min_req_delay: delay,
854 })
855 .build_schema();
856
857 schema.execute(query).await
858 }
859
860 let query = r#"{ checkpoint(id: {sequenceNumber: 0 }) { digest }}"#;
861 let timeout = Duration::from_millis(1000);
862 let delay = Duration::from_millis(100);
863
864 test_timeout(delay, timeout, query, optimistic_tx_executor)
865 .await
866 .into_result()
867 .expect("should complete successfully");
868
869 let errs: Vec<_> = test_timeout(delay, delay, query, optimistic_tx_executor)
871 .await
872 .into_result()
873 .unwrap_err()
874 .into_iter()
875 .map(|e| e.message)
876 .collect();
877 let exp = format!("Query request timed out. Limit: {}s", delay.as_secs_f32());
878 assert_eq!(errs, vec![exp]);
879
880 let addresses = wallet.get_addresses();
884 let gas = wallet
885 .get_one_gas_object_owned_by_address(addresses[0])
886 .await
887 .unwrap();
888 let tx_data = TransactionData::new_transfer_iota(
889 addresses[1],
890 addresses[0],
891 Some(1000),
892 gas.unwrap(),
893 1_000_000,
894 wallet.get_reference_gas_price().await.unwrap(),
895 );
896
897 let tx = wallet.sign_transaction(&tx_data);
898 let (tx_bytes, signatures) = tx.to_tx_bytes_and_signatures();
899
900 let signature_base64 = &signatures[0];
901 let query = format!(
902 r#"
903 mutation {{
904 executeTransactionBlock(txBytes: "{}", signatures: "{}") {{
905 effects {{
906 status
907 }}
908 }}
909 }}"#,
910 tx_bytes.encoded(),
911 signature_base64.encoded()
912 );
913 let errs: Vec<_> = test_timeout(delay, delay, &query, optimistic_tx_executor)
914 .await
915 .into_result()
916 .unwrap_err()
917 .into_iter()
918 .map(|e| e.message)
919 .collect();
920 let exp = format!(
921 "Mutation request timed out. Limit: {}s",
922 delay.as_secs_f32()
923 );
924 assert_eq!(errs, vec![exp]);
925 }
926
927 pub async fn test_query_depth_limit_impl() {
928 async fn exec_query_depth_limit(depth: u32, query: &str) -> Response {
929 let service_config = ServiceConfig {
930 limits: Limits {
931 max_query_depth: depth,
932 ..Default::default()
933 },
934 ..Default::default()
935 };
936
937 let schema = prep_schema(None, Some(service_config))
938 .extension(QueryLimitsChecker)
939 .build_schema();
940 schema.execute(query).await
941 }
942
943 exec_query_depth_limit(1, "{ chainIdentifier }")
944 .await
945 .into_result()
946 .expect("should complete successfully");
947
948 exec_query_depth_limit(
949 5,
950 "{ chainIdentifier protocolConfig { configs { value key }} }",
951 )
952 .await
953 .into_result()
954 .expect("should complete successfully");
955
956 let errs: Vec<_> = exec_query_depth_limit(0, "{ chainIdentifier }")
958 .await
959 .into_result()
960 .unwrap_err()
961 .into_iter()
962 .map(|e| e.message)
963 .collect();
964
965 assert_eq!(errs, vec!["Query nesting is over 0".to_string()]);
966 let errs: Vec<_> = exec_query_depth_limit(
967 2,
968 "{ chainIdentifier protocolConfig { configs { value key }} }",
969 )
970 .await
971 .into_result()
972 .unwrap_err()
973 .into_iter()
974 .map(|e| e.message)
975 .collect();
976 assert_eq!(errs, vec!["Query nesting is over 2".to_string()]);
977 }
978
979 pub async fn test_query_node_limit_impl() {
980 async fn exec_query_node_limit(nodes: u32, query: &str) -> Response {
981 let service_config = ServiceConfig {
982 limits: Limits {
983 max_query_nodes: nodes,
984 ..Default::default()
985 },
986 ..Default::default()
987 };
988
989 let schema = prep_schema(None, Some(service_config))
990 .extension(QueryLimitsChecker)
991 .build_schema();
992 schema.execute(query).await
993 }
994
995 exec_query_node_limit(1, "{ chainIdentifier }")
996 .await
997 .into_result()
998 .expect("should complete successfully");
999
1000 exec_query_node_limit(
1001 5,
1002 "{ chainIdentifier protocolConfig { configs { value key }} }",
1003 )
1004 .await
1005 .into_result()
1006 .expect("should complete successfully");
1007
1008 let err: Vec<_> = exec_query_node_limit(0, "{ chainIdentifier }")
1010 .await
1011 .into_result()
1012 .unwrap_err()
1013 .into_iter()
1014 .map(|e| e.message)
1015 .collect();
1016 assert_eq!(err, vec!["Query has over 0 nodes".to_string()]);
1017
1018 let err: Vec<_> = exec_query_node_limit(
1019 4,
1020 "{ chainIdentifier protocolConfig { configs { value key }} }",
1021 )
1022 .await
1023 .into_result()
1024 .unwrap_err()
1025 .into_iter()
1026 .map(|e| e.message)
1027 .collect();
1028 assert_eq!(err, vec!["Query has over 4 nodes".to_string()]);
1029 }
1030
1031 pub async fn test_query_default_page_limit_impl(connection_config: ConnectionConfig) {
1032 let service_config = ServiceConfig {
1033 limits: Limits {
1034 default_page_size: 1,
1035 ..Default::default()
1036 },
1037 ..Default::default()
1038 };
1039 let schema = prep_schema(Some(connection_config), Some(service_config)).build_schema();
1040
1041 let resp = schema
1042 .execute("{ checkpoints { nodes { sequenceNumber } } }")
1043 .await;
1044 let data = resp.data.clone().into_json().unwrap();
1045 let checkpoints = data
1046 .get("checkpoints")
1047 .unwrap()
1048 .get("nodes")
1049 .unwrap()
1050 .as_array()
1051 .unwrap();
1052 assert_eq!(
1053 checkpoints.len(),
1054 1,
1055 "Checkpoints should have exactly one element"
1056 );
1057
1058 let resp = schema
1059 .execute("{ checkpoints(first: 2) { nodes { sequenceNumber } } }")
1060 .await;
1061 let data = resp.data.clone().into_json().unwrap();
1062 let checkpoints = data
1063 .get("checkpoints")
1064 .unwrap()
1065 .get("nodes")
1066 .unwrap()
1067 .as_array()
1068 .unwrap();
1069 assert_eq!(
1070 checkpoints.len(),
1071 2,
1072 "Checkpoints should return two elements"
1073 );
1074 }
1075
1076 pub async fn test_query_max_page_limit_impl() {
1077 let schema = prep_schema(None, None).build_schema();
1078
1079 schema
1080 .execute("{ objects(first: 1) { nodes { version } } }")
1081 .await
1082 .into_result()
1083 .expect("should complete successfully");
1084
1085 let err: Vec<_> = schema
1087 .execute("{ objects(first: 51) { nodes { version } } }")
1088 .await
1089 .into_result()
1090 .unwrap_err()
1091 .into_iter()
1092 .map(|e| e.message)
1093 .collect();
1094 assert_eq!(
1095 err,
1096 vec!["Connection's page size of 51 exceeds max of 50".to_string()]
1097 );
1098 }
1099
1100 pub async fn test_query_complexity_metrics_impl() {
1101 let server_builder = prep_schema(None, None);
1102 let metrics = server_builder.state.metrics.clone();
1103 let schema = server_builder
1104 .extension(QueryLimitsChecker) .build_schema();
1106
1107 schema
1108 .execute("{ chainIdentifier }")
1109 .await
1110 .into_result()
1111 .expect("should complete successfully");
1112
1113 let req_metrics = metrics.request_metrics;
1114 assert_eq!(req_metrics.input_nodes.get_sample_count(), 1);
1115 assert_eq!(req_metrics.output_nodes.get_sample_count(), 1);
1116 assert_eq!(req_metrics.query_depth.get_sample_count(), 1);
1117 assert_eq!(req_metrics.input_nodes.get_sample_sum(), 1.);
1118 assert_eq!(req_metrics.output_nodes.get_sample_sum(), 1.);
1119 assert_eq!(req_metrics.query_depth.get_sample_sum(), 1.);
1120
1121 schema
1122 .execute("{ chainIdentifier protocolConfig { configs { value key }} }")
1123 .await
1124 .into_result()
1125 .expect("should complete successfully");
1126
1127 assert_eq!(req_metrics.input_nodes.get_sample_count(), 2);
1128 assert_eq!(req_metrics.output_nodes.get_sample_count(), 2);
1129 assert_eq!(req_metrics.query_depth.get_sample_count(), 2);
1130 assert_eq!(req_metrics.input_nodes.get_sample_sum(), 2. + 4.);
1131 assert_eq!(req_metrics.output_nodes.get_sample_sum(), 2. + 4.);
1132 assert_eq!(req_metrics.query_depth.get_sample_sum(), 1. + 3.);
1133 }
1134
1135 pub async fn test_health_check_impl() {
1136 let server_builder = prep_schema(None, None);
1137 let url = format!(
1138 "http://{}:{}/health",
1139 server_builder.state.connection.host, server_builder.state.connection.port
1140 );
1141 server_builder.build_schema();
1142
1143 let resp = reqwest::get(&url).await.unwrap();
1144 assert_eq!(resp.status(), reqwest::StatusCode::OK);
1145
1146 let url_with_param = format!("{url}?max_checkpoint_lag_ms=1");
1147 let resp = reqwest::get(&url_with_param).await.unwrap();
1148 assert_eq!(resp.status(), reqwest::StatusCode::GATEWAY_TIMEOUT);
1149 }
1150}