1use std::{
6 any::Any,
7 convert::Infallible,
8 net::{SocketAddr, TcpStream},
9 sync::Arc,
10 time::{Duration, Instant},
11};
12
13use async_graphql::{
14 EmptySubscription, Schema, SchemaBuilder,
15 extensions::{ApolloTracing, ExtensionFactory, Tracing},
16};
17use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
18use axum::{
19 Extension, Router,
20 body::Body,
21 extract::{ConnectInfo, FromRef, Query as AxumQuery, State},
22 http::{HeaderMap, StatusCode},
23 middleware::{self},
24 response::IntoResponse,
25 routing::{MethodRouter, Route, get, post},
26};
27use chrono::Utc;
28use http::{HeaderValue, Method, Request};
29use iota_graphql_rpc_headers::LIMITS_HEADER;
30use iota_metrics::spawn_monitored_task;
31use iota_network_stack::callback::{CallbackLayer, MakeCallbackHandler, ResponseHandler};
32use iota_package_resolver::{PackageStoreWithLruCache, Resolver};
33use iota_sdk::IotaClientBuilder;
34use tokio::{join, net::TcpListener, sync::OnceCell};
35use tokio_util::sync::CancellationToken;
36use tower::{Layer, Service};
37use tower_http::cors::{AllowOrigin, CorsLayer};
38use tracing::{info, warn};
39use uuid::Uuid;
40
41use crate::{
42 config::{
43 ConnectionConfig, MAX_CONCURRENT_REQUESTS, RPC_TIMEOUT_ERR_SLEEP_RETRY_PERIOD,
44 ServerConfig, ServiceConfig, Version,
45 },
46 context_data::db_data_provider::PgManager,
47 data::{
48 DataLoader, Db,
49 package_resolver::{DbPackageStore, PackageResolver},
50 },
51 error::Error,
52 extensions::{
53 directive_checker::DirectiveChecker,
54 feature_gate::FeatureGate,
55 logger::Logger,
56 query_limits_checker::{QueryLimitsChecker, ShowUsage},
57 timeout::Timeout,
58 },
59 metrics::Metrics,
60 mutation::Mutation,
61 server::{
62 compatibility_check::check_all_tables,
63 exchange_rates_task::TriggerExchangeRatesTask,
64 system_package_task::SystemPackageTask,
65 version::{check_version_middleware, set_version_middleware},
66 watermark_task::{Watermark, WatermarkLock, WatermarkTask},
67 },
68 types::{
69 datatype::IMoveDatatype,
70 move_object::IMoveObject,
71 object::IObject,
72 owner::IOwner,
73 query::{IotaGraphQLSchema, Query},
74 },
75};
76
/// Largest gap (five minutes) tolerated between "now" and the latest indexed checkpoint's
/// timestamp before `/health` reports `GATEWAY_TIMEOUT`, unless the caller passes
/// `max_checkpoint_lag_ms` explicitly.
const DEFAULT_MAX_CHECKPOINT_LAG: Duration = Duration::from_secs(5 * 60);
80
81pub(crate) struct Server {
82 router: Router,
83 address: SocketAddr,
84 watermark_task: WatermarkTask,
85 system_package_task: SystemPackageTask,
86 trigger_exchange_rates_task: TriggerExchangeRatesTask,
87 state: AppState,
88 db_reader: Db,
89}
90
91impl Server {
92 pub async fn run(mut self) -> Result<(), Error> {
96 get_or_init_server_start_time().await;
97
98 info!("Starting compatibility check");
100 check_all_tables(&self.db_reader).await?;
101 info!("Compatibility check passed");
102
103 let watermark_task = {
107 info!("Starting watermark update task");
108 spawn_monitored_task!(async move {
109 self.watermark_task.run().await;
110 })
111 };
112
113 let system_package_task = {
116 info!("Starting system package task");
117 spawn_monitored_task!(async move {
118 self.system_package_task.run().await;
119 })
120 };
121
122 let trigger_exchange_rates_task = {
123 info!("Starting trigger exchange rates task");
124 spawn_monitored_task!(async move {
125 self.trigger_exchange_rates_task.run().await;
126 })
127 };
128
129 let server_task = {
130 info!("Starting graphql service");
131 let cancellation_token = self.state.cancellation_token.clone();
132 let address = self.address;
133 let router = self.router;
134 spawn_monitored_task!(async move {
135 axum::serve(
136 TcpListener::bind(address)
137 .await
138 .map_err(|e| Error::Internal(format!("listener bind failed: {}", e)))?,
139 router.into_make_service_with_connect_info::<SocketAddr>(),
140 )
141 .with_graceful_shutdown(async move {
142 cancellation_token.cancelled().await;
143 info!("Shutdown signal received, terminating graphql service");
144 })
145 .await
146 .map_err(|e| Error::Internal(format!("Server run failed: {}", e)))
147 })
148 };
149
150 let _ = join!(
154 watermark_task,
155 system_package_task,
156 trigger_exchange_rates_task,
157 server_task
158 );
159
160 Ok(())
161 }
162}
163
164pub(crate) struct ServerBuilder {
165 state: AppState,
166 schema: SchemaBuilder<Query, Mutation, EmptySubscription>,
167 router: Option<Router>,
168 db_reader: Option<Db>,
169 resolver: Option<PackageResolver>,
170}
171
172#[derive(Clone)]
173pub(crate) struct AppState {
174 connection: ConnectionConfig,
175 service: ServiceConfig,
176 metrics: Metrics,
177 cancellation_token: CancellationToken,
178 pub version: Version,
179}
180
181impl AppState {
182 pub(crate) fn new(
183 connection: ConnectionConfig,
184 service: ServiceConfig,
185 metrics: Metrics,
186 cancellation_token: CancellationToken,
187 version: Version,
188 ) -> Self {
189 Self {
190 connection,
191 service,
192 metrics,
193 cancellation_token,
194 version,
195 }
196 }
197}
198
199impl FromRef<AppState> for ConnectionConfig {
200 fn from_ref(app_state: &AppState) -> ConnectionConfig {
201 app_state.connection.clone()
202 }
203}
204
205impl FromRef<AppState> for Metrics {
206 fn from_ref(app_state: &AppState) -> Metrics {
207 app_state.metrics.clone()
208 }
209}
210
211impl ServerBuilder {
212 pub fn new(state: AppState) -> Self {
213 Self {
214 state,
215 schema: schema_builder(),
216 router: None,
217 db_reader: None,
218 resolver: None,
219 }
220 }
221
222 pub fn address(&self) -> String {
223 format!(
224 "{}:{}",
225 self.state.connection.host, self.state.connection.port
226 )
227 }
228
229 pub fn context_data(mut self, context_data: impl Any + Send + Sync) -> Self {
230 self.schema = self.schema.data(context_data);
231 self
232 }
233
234 pub fn extension(mut self, extension: impl ExtensionFactory) -> Self {
235 self.schema = self.schema.extension(extension);
236 self
237 }
238
239 fn build_schema(self) -> Schema<Query, Mutation, EmptySubscription> {
240 self.schema.finish()
241 }
242
243 fn build_components(
246 self,
247 ) -> (
248 String,
249 Schema<Query, Mutation, EmptySubscription>,
250 Db,
251 PackageResolver,
252 Router,
253 ) {
254 let address = self.address();
255 let ServerBuilder {
256 state: _,
257 schema,
258 db_reader,
259 resolver,
260 router,
261 } = self;
262 (
263 address,
264 schema.finish(),
265 db_reader.expect("DB reader not initialized"),
266 resolver.expect("Package resolver not initialized"),
267 router.expect("Router not initialized"),
268 )
269 }
270
271 fn init_router(&mut self) {
272 if self.router.is_none() {
273 let router: Router = Router::new()
274 .route("/", post(graphql_handler))
275 .route("/:version", post(graphql_handler))
276 .route("/graphql", post(graphql_handler))
277 .route("/graphql/:version", post(graphql_handler))
278 .route("/health", get(health_check))
279 .route("/graphql/health", get(health_check))
280 .route("/graphql/:version/health", get(health_check))
281 .with_state(self.state.clone())
282 .route_layer(CallbackLayer::new(MetricsMakeCallbackHandler {
283 metrics: self.state.metrics.clone(),
284 }));
285 self.router = Some(router);
286 }
287 }
288
289 pub fn route(mut self, path: &str, method_handler: MethodRouter) -> Self {
290 self.init_router();
291 self.router = self.router.map(|router| router.route(path, method_handler));
292 self
293 }
294
295 pub fn layer<L>(mut self, layer: L) -> Self
296 where
297 L: Layer<Route> + Clone + Send + 'static,
298 L::Service: Service<Request<Body>> + Clone + Send + 'static,
299 <L::Service as Service<Request<Body>>>::Response: IntoResponse + 'static,
300 <L::Service as Service<Request<Body>>>::Error: Into<Infallible> + 'static,
301 <L::Service as Service<Request<Body>>>::Future: Send + 'static,
302 {
303 self.init_router();
304 self.router = self.router.map(|router| router.layer(layer));
305 self
306 }
307
308 fn cors() -> Result<CorsLayer, Error> {
309 let acl = match std::env::var("ACCESS_CONTROL_ALLOW_ORIGIN") {
310 Ok(value) => {
311 let allow_hosts = value
312 .split(',')
313 .map(HeaderValue::from_str)
314 .collect::<Result<Vec<_>, _>>()
315 .map_err(|_| {
316 Error::Internal(
317 "Cannot resolve access control origin env variable".to_string(),
318 )
319 })?;
320 AllowOrigin::list(allow_hosts)
321 }
322 _ => AllowOrigin::any(),
323 };
324 info!("Access control allow origin set to: {acl:?}");
325
326 let cors = CorsLayer::new()
327 .allow_methods([Method::POST])
329 .allow_origin(acl)
331 .allow_headers([hyper::header::CONTENT_TYPE, LIMITS_HEADER.clone()]);
332 Ok(cors)
333 }
334
335 pub fn build(self) -> Result<Server, Error> {
337 let state = self.state.clone();
338 let (address, schema, db_reader, resolver, router) = self.build_components();
339
340 let watermark_task = WatermarkTask::new(
342 db_reader.clone(),
343 state.metrics.clone(),
344 std::time::Duration::from_millis(state.service.background_tasks.watermark_update_ms),
345 state.cancellation_token.clone(),
346 );
347
348 let system_package_task = SystemPackageTask::new(
349 resolver,
350 watermark_task.epoch_receiver(),
351 state.cancellation_token.clone(),
352 );
353
354 let trigger_exchange_rates_task = TriggerExchangeRatesTask::new(
355 db_reader.clone(),
356 watermark_task.epoch_receiver(),
357 state.cancellation_token.clone(),
358 );
359
360 let router = router
361 .route_layer(middleware::from_fn_with_state(
362 state.version,
363 set_version_middleware,
364 ))
365 .route_layer(middleware::from_fn_with_state(
366 state.version,
367 check_version_middleware,
368 ))
369 .layer(axum::extract::Extension(schema))
370 .layer(axum::extract::Extension(watermark_task.lock()))
371 .layer(Self::cors()?);
372
373 Ok(Server {
374 router,
375 address: address
376 .parse()
377 .map_err(|_| Error::Internal(format!("Failed to parse address {}", address)))?,
378 watermark_task,
379 system_package_task,
380 trigger_exchange_rates_task,
381 state,
382 db_reader,
383 })
384 }
385
386 pub async fn from_config(
389 config: &ServerConfig,
390 version: &Version,
391 cancellation_token: CancellationToken,
392 ) -> Result<Self, Error> {
393 let prom_addr: SocketAddr = format!(
395 "{}:{}",
396 config.connection.prom_url, config.connection.prom_port
397 )
398 .parse()
399 .map_err(|_| {
400 Error::Internal(format!(
401 "Failed to parse url {}, port {} into socket address",
402 config.connection.prom_url, config.connection.prom_port
403 ))
404 })?;
405
406 let registry_service = iota_metrics::start_prometheus_server(prom_addr);
407 info!("Starting Prometheus HTTP endpoint at {}", prom_addr);
408 let registry = registry_service.default_registry();
409 registry
410 .register(iota_metrics::uptime_metric(
411 "graphql",
412 version.full,
413 "unknown",
414 ))
415 .unwrap();
416
417 let metrics = Metrics::new(®istry);
419 let state = AppState::new(
420 config.connection.clone(),
421 config.service.clone(),
422 metrics.clone(),
423 cancellation_token,
424 *version,
425 );
426 let mut builder = ServerBuilder::new(state);
427
428 let iota_names_config = config.service.iota_names.clone();
429 let zklogin_config = config.service.zklogin.clone();
430 let reader = PgManager::reader_with_config(
431 config.connection.db_url.clone(),
432 config.connection.db_pool_size,
433 config.service.limits.request_timeout_ms.into(),
437 )
438 .map_err(|e| Error::Internal(format!("Failed to create pg connection pool: {}", e)))?;
439
440 let db = Db::new(
442 reader.clone(),
443 config.service.limits.clone(),
444 metrics.clone(),
445 );
446 let loader = DataLoader::new(db.clone());
447 let pg_conn_pool = PgManager::new(reader.clone());
448 let package_store = DbPackageStore::new(loader.clone());
449 let resolver = Arc::new(Resolver::new_with_limits(
450 PackageStoreWithLruCache::new(package_store),
451 config.service.limits.package_resolver_limits(),
452 ));
453
454 builder.db_reader = Some(db.clone());
455 builder.resolver = Some(resolver.clone());
456
457 let iota_sdk_client = if let Some(url) = &config.tx_exec_full_node.node_rpc_url {
460 Some(
461 IotaClientBuilder::default()
462 .request_timeout(RPC_TIMEOUT_ERR_SLEEP_RETRY_PERIOD)
463 .max_concurrent_requests(MAX_CONCURRENT_REQUESTS)
464 .build(url)
465 .await
466 .map_err(|e| {
467 Error::Internal(format!(
468 "Failed to connect to fullnode {e}. Is the node server running?"
469 ))
470 })?,
471 )
472 } else {
473 warn!(
474 "No fullnode url found in config. `dryRunTransactionBlock` and `executeTransactionBlock` will not work"
475 );
476 None
477 };
478
479 builder = builder
480 .context_data(config.service.clone())
481 .context_data(loader)
482 .context_data(db)
483 .context_data(pg_conn_pool)
484 .context_data(resolver)
485 .context_data(iota_sdk_client)
486 .context_data(iota_names_config)
487 .context_data(zklogin_config)
488 .context_data(metrics.clone())
489 .context_data(config.clone());
490
491 if config.internal_features.feature_gate {
492 builder = builder.extension(FeatureGate);
493 }
494
495 if config.internal_features.logger {
496 builder = builder.extension(Logger::default());
497 }
498
499 if config.internal_features.query_limits_checker {
500 builder = builder.extension(QueryLimitsChecker);
501 }
502
503 if config.internal_features.directive_checker {
504 builder = builder.extension(DirectiveChecker);
505 }
506
507 if config.internal_features.query_timeout {
508 builder = builder.extension(Timeout);
509 }
510
511 if config.internal_features.tracing {
512 builder = builder.extension(Tracing);
513 }
514
515 if config.internal_features.apollo_tracing {
516 builder = builder.extension(ApolloTracing);
517 }
518
519 Ok(builder)
523 }
524}
525
526fn schema_builder() -> SchemaBuilder<Query, Mutation, EmptySubscription> {
527 async_graphql::Schema::build(Query, Mutation, EmptySubscription)
528 .register_output_type::<IMoveObject>()
529 .register_output_type::<IObject>()
530 .register_output_type::<IOwner>()
531 .register_output_type::<IMoveDatatype>()
532}
533
534pub fn export_schema() -> String {
536 schema_builder().finish().sdl()
537}
538
539async fn graphql_handler(
543 ConnectInfo(addr): ConnectInfo<SocketAddr>,
544 schema: Extension<IotaGraphQLSchema>,
545 Extension(watermark_lock): Extension<WatermarkLock>,
546 headers: HeaderMap,
547 req: GraphQLRequest,
548) -> (axum::http::Extensions, GraphQLResponse) {
549 let mut req = req.into_inner();
550 req.data.insert(Uuid::new_v4());
551 if headers.contains_key(ShowUsage::name()) {
552 req.data.insert(ShowUsage)
553 }
554 req.data.insert(addr);
558
559 req.data.insert(Watermark::new(watermark_lock).await);
560
561 let result = schema.execute(req).await;
562
563 let mut extensions = axum::http::Extensions::new();
566 if result.is_err() {
567 extensions.insert(GraphqlErrors(std::sync::Arc::new(result.errors.clone())));
568 };
569 (extensions, result.into())
570}
571
572#[derive(Clone)]
573struct MetricsMakeCallbackHandler {
574 metrics: Metrics,
575}
576
577impl MakeCallbackHandler for MetricsMakeCallbackHandler {
578 type Handler = MetricsCallbackHandler;
579
580 fn make_handler(&self, _request: &http::request::Parts) -> Self::Handler {
581 let start = Instant::now();
582 let metrics = self.metrics.clone();
583
584 metrics.request_metrics.inflight_requests.inc();
585 metrics.inc_num_queries();
586
587 MetricsCallbackHandler { metrics, start }
588 }
589}
590
591struct MetricsCallbackHandler {
592 metrics: Metrics,
593 start: Instant,
594}
595
596impl ResponseHandler for MetricsCallbackHandler {
597 fn on_response(self, response: &http::response::Parts) {
598 if let Some(errors) = response.extensions.get::<GraphqlErrors>() {
599 self.metrics.inc_errors(&errors.0);
600 }
601 }
602
603 fn on_error<E>(self, _error: &E) {
604 }
609}
610
611impl Drop for MetricsCallbackHandler {
612 fn drop(&mut self) {
613 self.metrics.query_latency(self.start.elapsed());
614 self.metrics.request_metrics.inflight_requests.dec();
615 }
616}
617
618#[derive(Debug, Clone)]
619struct GraphqlErrors(std::sync::Arc<Vec<async_graphql::ServerError>>);
620
621async fn db_health_check(State(connection): State<ConnectionConfig>) -> StatusCode {
623 let Ok(url) = reqwest::Url::parse(connection.db_url.as_str()) else {
624 return StatusCode::INTERNAL_SERVER_ERROR;
625 };
626
627 let Some(host) = url.host_str() else {
628 return StatusCode::INTERNAL_SERVER_ERROR;
629 };
630
631 let tcp_url = if let Some(port) = url.port() {
632 format!("{host}:{port}")
633 } else {
634 host.to_string()
635 };
636
637 if TcpStream::connect(tcp_url).is_err() {
638 StatusCode::INTERNAL_SERVER_ERROR
639 } else {
640 StatusCode::OK
641 }
642}
643
644#[derive(serde::Deserialize)]
645struct HealthParam {
646 max_checkpoint_lag_ms: Option<u64>,
647}
648
649async fn health_check(
655 State(connection): State<ConnectionConfig>,
656 Extension(watermark_lock): Extension<WatermarkLock>,
657 AxumQuery(query_params): AxumQuery<HealthParam>,
658) -> StatusCode {
659 let db_health_check = db_health_check(axum::extract::State(connection)).await;
660 if db_health_check != StatusCode::OK {
661 return db_health_check;
662 }
663
664 let max_checkpoint_lag_ms = query_params
665 .max_checkpoint_lag_ms
666 .map(Duration::from_millis)
667 .unwrap_or_else(|| DEFAULT_MAX_CHECKPOINT_LAG);
668
669 let checkpoint_timestamp =
670 Duration::from_millis(watermark_lock.read().await.checkpoint_timestamp_ms);
671
672 let now_millis = Utc::now().timestamp_millis();
673
674 let now: Duration = match u64::try_from(now_millis) {
676 Ok(val) => Duration::from_millis(val),
677 Err(_) => return StatusCode::INTERNAL_SERVER_ERROR,
678 };
679
680 if (now - checkpoint_timestamp) > max_checkpoint_lag_ms {
681 return StatusCode::GATEWAY_TIMEOUT;
682 }
683
684 db_health_check
685}
686
687async fn get_or_init_server_start_time() -> &'static Instant {
689 static ONCE: OnceCell<Instant> = OnceCell::const_new();
690 ONCE.get_or_init(|| async move { Instant::now() }).await
691}
692
693pub mod tests {
694 use std::{sync::Arc, time::Duration};
695
696 use async_graphql::{
697 Response,
698 extensions::{Extension, ExtensionContext, NextExecute},
699 };
700 use iota_sdk::{IotaClient, wallet_context::WalletContext};
701 use iota_types::transaction::TransactionData;
702 use uuid::Uuid;
703
704 use super::*;
705 use crate::{
706 config::{ConnectionConfig, Limits, ServiceConfig, Version},
707 context_data::db_data_provider::PgManager,
708 extensions::{query_limits_checker::QueryLimitsChecker, timeout::Timeout},
709 };
710
711 fn prep_schema(
715 connection_config: Option<ConnectionConfig>,
716 service_config: Option<ServiceConfig>,
717 ) -> ServerBuilder {
718 let connection_config = connection_config.unwrap_or_default();
719 let service_config = service_config.unwrap_or_default();
720
721 let reader = PgManager::reader_with_config(
722 connection_config.db_url.clone(),
723 connection_config.db_pool_size,
724 service_config.limits.request_timeout_ms.into(),
725 )
726 .expect("Failed to create pg connection pool");
727
728 let version = Version::for_testing();
729 let metrics = metrics();
730 let db = Db::new(
731 reader.clone(),
732 service_config.limits.clone(),
733 metrics.clone(),
734 );
735 let pg_conn_pool = PgManager::new(reader);
736 let cancellation_token = CancellationToken::new();
737 let watermark = Watermark {
738 checkpoint: 1,
739 checkpoint_timestamp_ms: 1,
740 epoch: 0,
741 };
742 let state = AppState::new(
743 connection_config.clone(),
744 service_config.clone(),
745 metrics.clone(),
746 cancellation_token.clone(),
747 version,
748 );
749 ServerBuilder::new(state)
750 .context_data(db)
751 .context_data(pg_conn_pool)
752 .context_data(service_config)
753 .context_data(query_id())
754 .context_data(ip_address())
755 .context_data(watermark)
756 .context_data(metrics)
757 }
758
759 fn metrics() -> Metrics {
760 let binding_address: SocketAddr = "0.0.0.0:9185".parse().unwrap();
761 let registry = iota_metrics::start_prometheus_server(binding_address).default_registry();
762 Metrics::new(®istry)
763 }
764
765 fn ip_address() -> SocketAddr {
766 let binding_address: SocketAddr = "0.0.0.0:51515".parse().unwrap();
767 binding_address
768 }
769
770 fn query_id() -> Uuid {
771 Uuid::new_v4()
772 }
773
774 pub async fn test_timeout_impl(wallet: &WalletContext) {
775 struct TimedExecuteExt {
776 pub min_req_delay: Duration,
777 }
778
779 impl ExtensionFactory for TimedExecuteExt {
780 fn create(&self) -> Arc<dyn Extension> {
781 Arc::new(TimedExecuteExt {
782 min_req_delay: self.min_req_delay,
783 })
784 }
785 }
786
787 #[async_trait::async_trait]
788 impl Extension for TimedExecuteExt {
789 async fn execute(
790 &self,
791 ctx: &ExtensionContext<'_>,
792 operation_name: Option<&str>,
793 next: NextExecute<'_>,
794 ) -> Response {
795 tokio::time::sleep(self.min_req_delay).await;
796 next.run(ctx, operation_name).await
797 }
798 }
799
800 async fn test_timeout(
801 delay: Duration,
802 timeout: Duration,
803 query: &str,
804 iota_client: &IotaClient,
805 ) -> Response {
806 let mut cfg = ServiceConfig::default();
807 cfg.limits.request_timeout_ms = timeout.as_millis() as u32;
808 cfg.limits.mutation_timeout_ms = timeout.as_millis() as u32;
809
810 let schema = prep_schema(None, Some(cfg))
811 .context_data(Some(iota_client.clone()))
812 .extension(Timeout)
813 .extension(TimedExecuteExt {
814 min_req_delay: delay,
815 })
816 .build_schema();
817
818 schema.execute(query).await
819 }
820
821 let query = "{ chainIdentifier }";
822 let timeout = Duration::from_millis(1000);
823 let delay = Duration::from_millis(100);
824 let iota_client = wallet.get_client().await.unwrap();
825
826 test_timeout(delay, timeout, query, &iota_client)
827 .await
828 .into_result()
829 .expect("Should complete successfully");
830
831 let errs: Vec<_> = test_timeout(delay, delay, query, &iota_client)
833 .await
834 .into_result()
835 .unwrap_err()
836 .into_iter()
837 .map(|e| e.message)
838 .collect();
839 let exp = format!("Query request timed out. Limit: {}s", delay.as_secs_f32());
840 assert_eq!(errs, vec![exp]);
841
842 let addresses = wallet.get_addresses();
846 let gas = wallet
847 .get_one_gas_object_owned_by_address(addresses[0])
848 .await
849 .unwrap();
850 let tx_data = TransactionData::new_transfer_iota(
851 addresses[1],
852 addresses[0],
853 Some(1000),
854 gas.unwrap(),
855 1_000_000,
856 wallet.get_reference_gas_price().await.unwrap(),
857 );
858
859 let tx = wallet.sign_transaction(&tx_data);
860 let (tx_bytes, signatures) = tx.to_tx_bytes_and_signatures();
861
862 let signature_base64 = &signatures[0];
863 let query = format!(
864 r#"
865 mutation {{
866 executeTransactionBlock(txBytes: "{}", signatures: "{}") {{
867 effects {{
868 status
869 }}
870 }}
871 }}"#,
872 tx_bytes.encoded(),
873 signature_base64.encoded()
874 );
875 let errs: Vec<_> = test_timeout(delay, delay, &query, &iota_client)
876 .await
877 .into_result()
878 .unwrap_err()
879 .into_iter()
880 .map(|e| e.message)
881 .collect();
882 let exp = format!(
883 "Mutation request timed out. Limit: {}s",
884 delay.as_secs_f32()
885 );
886 assert_eq!(errs, vec![exp]);
887 }
888
889 pub async fn test_query_depth_limit_impl() {
890 async fn exec_query_depth_limit(depth: u32, query: &str) -> Response {
891 let service_config = ServiceConfig {
892 limits: Limits {
893 max_query_depth: depth,
894 ..Default::default()
895 },
896 ..Default::default()
897 };
898
899 let schema = prep_schema(None, Some(service_config))
900 .extension(QueryLimitsChecker)
901 .build_schema();
902 schema.execute(query).await
903 }
904
905 exec_query_depth_limit(1, "{ chainIdentifier }")
906 .await
907 .into_result()
908 .expect("Should complete successfully");
909
910 exec_query_depth_limit(
911 5,
912 "{ chainIdentifier protocolConfig { configs { value key }} }",
913 )
914 .await
915 .into_result()
916 .expect("Should complete successfully");
917
918 let errs: Vec<_> = exec_query_depth_limit(0, "{ chainIdentifier }")
920 .await
921 .into_result()
922 .unwrap_err()
923 .into_iter()
924 .map(|e| e.message)
925 .collect();
926
927 assert_eq!(errs, vec!["Query nesting is over 0".to_string()]);
928 let errs: Vec<_> = exec_query_depth_limit(
929 2,
930 "{ chainIdentifier protocolConfig { configs { value key }} }",
931 )
932 .await
933 .into_result()
934 .unwrap_err()
935 .into_iter()
936 .map(|e| e.message)
937 .collect();
938 assert_eq!(errs, vec!["Query nesting is over 2".to_string()]);
939 }
940
941 pub async fn test_query_node_limit_impl() {
942 async fn exec_query_node_limit(nodes: u32, query: &str) -> Response {
943 let service_config = ServiceConfig {
944 limits: Limits {
945 max_query_nodes: nodes,
946 ..Default::default()
947 },
948 ..Default::default()
949 };
950
951 let schema = prep_schema(None, Some(service_config))
952 .extension(QueryLimitsChecker)
953 .build_schema();
954 schema.execute(query).await
955 }
956
957 exec_query_node_limit(1, "{ chainIdentifier }")
958 .await
959 .into_result()
960 .expect("Should complete successfully");
961
962 exec_query_node_limit(
963 5,
964 "{ chainIdentifier protocolConfig { configs { value key }} }",
965 )
966 .await
967 .into_result()
968 .expect("Should complete successfully");
969
970 let err: Vec<_> = exec_query_node_limit(0, "{ chainIdentifier }")
972 .await
973 .into_result()
974 .unwrap_err()
975 .into_iter()
976 .map(|e| e.message)
977 .collect();
978 assert_eq!(err, vec!["Query has over 0 nodes".to_string()]);
979
980 let err: Vec<_> = exec_query_node_limit(
981 4,
982 "{ chainIdentifier protocolConfig { configs { value key }} }",
983 )
984 .await
985 .into_result()
986 .unwrap_err()
987 .into_iter()
988 .map(|e| e.message)
989 .collect();
990 assert_eq!(err, vec!["Query has over 4 nodes".to_string()]);
991 }
992
993 pub async fn test_query_default_page_limit_impl(connection_config: ConnectionConfig) {
994 let service_config = ServiceConfig {
995 limits: Limits {
996 default_page_size: 1,
997 ..Default::default()
998 },
999 ..Default::default()
1000 };
1001 let schema = prep_schema(Some(connection_config), Some(service_config)).build_schema();
1002
1003 let resp = schema
1004 .execute("{ checkpoints { nodes { sequenceNumber } } }")
1005 .await;
1006 let data = resp.data.clone().into_json().unwrap();
1007 let checkpoints = data
1008 .get("checkpoints")
1009 .unwrap()
1010 .get("nodes")
1011 .unwrap()
1012 .as_array()
1013 .unwrap();
1014 assert_eq!(
1015 checkpoints.len(),
1016 1,
1017 "Checkpoints should have exactly one element"
1018 );
1019
1020 let resp = schema
1021 .execute("{ checkpoints(first: 2) { nodes { sequenceNumber } } }")
1022 .await;
1023 let data = resp.data.clone().into_json().unwrap();
1024 let checkpoints = data
1025 .get("checkpoints")
1026 .unwrap()
1027 .get("nodes")
1028 .unwrap()
1029 .as_array()
1030 .unwrap();
1031 assert_eq!(
1032 checkpoints.len(),
1033 2,
1034 "Checkpoints should return two elements"
1035 );
1036 }
1037
1038 pub async fn test_query_max_page_limit_impl() {
1039 let schema = prep_schema(None, None).build_schema();
1040
1041 schema
1042 .execute("{ objects(first: 1) { nodes { version } } }")
1043 .await
1044 .into_result()
1045 .expect("Should complete successfully");
1046
1047 let err: Vec<_> = schema
1049 .execute("{ objects(first: 51) { nodes { version } } }")
1050 .await
1051 .into_result()
1052 .unwrap_err()
1053 .into_iter()
1054 .map(|e| e.message)
1055 .collect();
1056 assert_eq!(
1057 err,
1058 vec!["Connection's page size of 51 exceeds max of 50".to_string()]
1059 );
1060 }
1061
1062 pub async fn test_query_complexity_metrics_impl() {
1063 let server_builder = prep_schema(None, None);
1064 let metrics = server_builder.state.metrics.clone();
1065 let schema = server_builder
1066 .extension(QueryLimitsChecker) .build_schema();
1068
1069 schema
1070 .execute("{ chainIdentifier }")
1071 .await
1072 .into_result()
1073 .expect("Should complete successfully");
1074
1075 let req_metrics = metrics.request_metrics;
1076 assert_eq!(req_metrics.input_nodes.get_sample_count(), 1);
1077 assert_eq!(req_metrics.output_nodes.get_sample_count(), 1);
1078 assert_eq!(req_metrics.query_depth.get_sample_count(), 1);
1079 assert_eq!(req_metrics.input_nodes.get_sample_sum(), 1.);
1080 assert_eq!(req_metrics.output_nodes.get_sample_sum(), 1.);
1081 assert_eq!(req_metrics.query_depth.get_sample_sum(), 1.);
1082
1083 schema
1084 .execute("{ chainIdentifier protocolConfig { configs { value key }} }")
1085 .await
1086 .into_result()
1087 .expect("Should complete successfully");
1088
1089 assert_eq!(req_metrics.input_nodes.get_sample_count(), 2);
1090 assert_eq!(req_metrics.output_nodes.get_sample_count(), 2);
1091 assert_eq!(req_metrics.query_depth.get_sample_count(), 2);
1092 assert_eq!(req_metrics.input_nodes.get_sample_sum(), 2. + 4.);
1093 assert_eq!(req_metrics.output_nodes.get_sample_sum(), 2. + 4.);
1094 assert_eq!(req_metrics.query_depth.get_sample_sum(), 1. + 3.);
1095 }
1096
1097 pub async fn test_health_check_impl() {
1098 let server_builder = prep_schema(None, None);
1099 let url = format!(
1100 "http://{}:{}/health",
1101 server_builder.state.connection.host, server_builder.state.connection.port
1102 );
1103 server_builder.build_schema();
1104
1105 let resp = reqwest::get(&url).await.unwrap();
1106 assert_eq!(resp.status(), reqwest::StatusCode::OK);
1107
1108 let url_with_param = format!("{}?max_checkpoint_lag_ms=1", url);
1109 let resp = reqwest::get(&url_with_param).await.unwrap();
1110 assert_eq!(resp.status(), reqwest::StatusCode::GATEWAY_TIMEOUT);
1111 }
1112}