1use std::{
6 any::Any,
7 convert::Infallible,
8 net::{SocketAddr, TcpStream},
9 sync::Arc,
10 time::{Duration, Instant},
11};
12
13use async_graphql::{
14 EmptySubscription, Schema, SchemaBuilder,
15 extensions::{ApolloTracing, ExtensionFactory, Tracing},
16};
17use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
18use axum::{
19 Extension, Router,
20 body::Body,
21 extract::{ConnectInfo, FromRef, Query as AxumQuery, State},
22 http::{HeaderMap, StatusCode},
23 middleware::{self},
24 response::IntoResponse,
25 routing::{MethodRouter, Route, get, post},
26};
27use chrono::Utc;
28use http::{HeaderValue, Method, Request};
29use iota_graphql_rpc_headers::LIMITS_HEADER;
30use iota_indexer::db::{get_pool_connection, setup_postgres::check_db_migration_consistency};
31use iota_metrics::spawn_monitored_task;
32use iota_network_stack::callback::{CallbackLayer, MakeCallbackHandler, ResponseHandler};
33use iota_package_resolver::{PackageStoreWithLruCache, Resolver};
34use iota_sdk::IotaClientBuilder;
35use tokio::{join, net::TcpListener, sync::OnceCell};
36use tokio_util::sync::CancellationToken;
37use tower::{Layer, Service};
38use tower_http::cors::{AllowOrigin, CorsLayer};
39use tracing::{info, warn};
40use uuid::Uuid;
41
42use crate::{
43 config::{
44 ConnectionConfig, MAX_CONCURRENT_REQUESTS, RPC_TIMEOUT_ERR_SLEEP_RETRY_PERIOD,
45 ServerConfig, ServiceConfig, Version,
46 },
47 context_data::db_data_provider::PgManager,
48 data::{
49 DataLoader, Db,
50 package_resolver::{DbPackageStore, PackageResolver},
51 },
52 error::Error,
53 extensions::{
54 directive_checker::DirectiveChecker,
55 feature_gate::FeatureGate,
56 logger::Logger,
57 query_limits_checker::{QueryLimitsChecker, ShowUsage},
58 timeout::Timeout,
59 },
60 metrics::Metrics,
61 mutation::Mutation,
62 server::{
63 exchange_rates_task::TriggerExchangeRatesTask,
64 system_package_task::SystemPackageTask,
65 version::{check_version_middleware, set_version_middleware},
66 watermark_task::{Watermark, WatermarkLock, WatermarkTask},
67 },
68 types::{
69 datatype::IMoveDatatype,
70 move_object::IMoveObject,
71 object::IObject,
72 owner::IOwner,
73 query::{IotaGraphQLSchema, Query},
74 },
75};
76
/// Checkpoint lag tolerated by the health check when the request does not pass
/// `max_checkpoint_lag_ms`: five minutes.
const DEFAULT_MAX_CHECKPOINT_LAG: Duration = Duration::from_secs(5 * 60);
80
/// A fully assembled GraphQL server, produced by `ServerBuilder::build` and started
/// with `Server::run`.
pub(crate) struct Server {
    /// HTTP routes together with their middleware and layers.
    router: Router,
    /// Socket address the HTTP listener binds to.
    address: SocketAddr,
    /// Background task that keeps the watermark up to date; other tasks subscribe
    /// to its epoch updates.
    watermark_task: WatermarkTask,
    /// Background task for system packages, driven by the watermark's epoch receiver.
    system_package_task: SystemPackageTask,
    /// Background task that triggers exchange-rate updates, driven by the watermark's
    /// epoch receiver.
    trigger_exchange_rates_task: TriggerExchangeRatesTask,
    /// Shared application state; provides the cancellation token used for graceful
    /// shutdown.
    state: AppState,
    /// Database handle, used for the migration compatibility check at startup.
    db_reader: Db,
}
90
91impl Server {
92 pub async fn run(mut self) -> Result<(), Error> {
96 get_or_init_server_start_time().await;
97
98 {
99 info!("Starting compatibility check");
101 let mut connection = get_pool_connection(&self.db_reader.inner.get_pool())?;
102 check_db_migration_consistency(&mut connection)?;
103 info!("Compatibility check passed");
104 }
105
106 let watermark_task = {
110 info!("Starting watermark update task");
111 spawn_monitored_task!(async move {
112 self.watermark_task.run().await;
113 })
114 };
115
116 let system_package_task = {
119 info!("Starting system package task");
120 spawn_monitored_task!(async move {
121 self.system_package_task.run().await;
122 })
123 };
124
125 let trigger_exchange_rates_task = {
126 info!("Starting trigger exchange rates task");
127 spawn_monitored_task!(async move {
128 self.trigger_exchange_rates_task.run().await;
129 })
130 };
131
132 let server_task = {
133 info!("Starting graphql service");
134 let cancellation_token = self.state.cancellation_token.clone();
135 let address = self.address;
136 let router = self.router;
137 spawn_monitored_task!(async move {
138 axum::serve(
139 TcpListener::bind(address)
140 .await
141 .map_err(|e| Error::Internal(format!("listener bind failed: {e}")))?,
142 router.into_make_service_with_connect_info::<SocketAddr>(),
143 )
144 .with_graceful_shutdown(async move {
145 cancellation_token.cancelled().await;
146 info!("Shutdown signal received, terminating graphql service");
147 })
148 .await
149 .map_err(|e| Error::Internal(format!("Server run failed: {e}")))
150 })
151 };
152
153 let _ = join!(
157 watermark_task,
158 system_package_task,
159 trigger_exchange_rates_task,
160 server_task
161 );
162
163 Ok(())
164 }
165}
166
/// Incrementally assembles a `Server`: GraphQL context data and extensions, the HTTP
/// router and its layers, and the database reader and package resolver needed by the
/// background tasks.
pub(crate) struct ServerBuilder {
    /// Application state shared with handlers and carried into the built `Server`.
    state: AppState,
    /// GraphQL schema under construction; `context_data` and `extension` add to it.
    schema: SchemaBuilder<Query, Mutation, EmptySubscription>,
    /// HTTP router; created with the default GraphQL and health routes by
    /// `init_router` on first use.
    router: Option<Router>,
    /// Database handle; must be set before building (done by `from_config`).
    db_reader: Option<Db>,
    /// Move package resolver; must be set before building (done by `from_config`).
    resolver: Option<PackageResolver>,
}
174
/// State shared with the axum handlers and middleware.
#[derive(Clone)]
pub(crate) struct AppState {
    /// Connection settings (listen host/port, database URL, ...).
    connection: ConnectionConfig,
    /// Service configuration (limits, background task intervals, ...).
    service: ServiceConfig,
    /// Request and query metrics.
    metrics: Metrics,
    /// Cancelled to trigger graceful shutdown of the server and its background tasks.
    cancellation_token: CancellationToken,
    /// Service version; passed to the version-checking and version-setting middleware.
    pub version: Version,
}
183
184impl AppState {
185 pub(crate) fn new(
186 connection: ConnectionConfig,
187 service: ServiceConfig,
188 metrics: Metrics,
189 cancellation_token: CancellationToken,
190 version: Version,
191 ) -> Self {
192 Self {
193 connection,
194 service,
195 metrics,
196 cancellation_token,
197 version,
198 }
199 }
200}
201
202impl FromRef<AppState> for ConnectionConfig {
203 fn from_ref(app_state: &AppState) -> ConnectionConfig {
204 app_state.connection.clone()
205 }
206}
207
208impl FromRef<AppState> for Metrics {
209 fn from_ref(app_state: &AppState) -> Metrics {
210 app_state.metrics.clone()
211 }
212}
213
214impl ServerBuilder {
215 pub fn new(state: AppState) -> Self {
216 Self {
217 state,
218 schema: schema_builder(),
219 router: None,
220 db_reader: None,
221 resolver: None,
222 }
223 }
224
225 pub fn address(&self) -> String {
226 format!(
227 "{}:{}",
228 self.state.connection.host, self.state.connection.port
229 )
230 }
231
232 pub fn context_data(mut self, context_data: impl Any + Send + Sync) -> Self {
233 self.schema = self.schema.data(context_data);
234 self
235 }
236
237 pub fn extension(mut self, extension: impl ExtensionFactory) -> Self {
238 self.schema = self.schema.extension(extension);
239 self
240 }
241
242 fn build_schema(self) -> Schema<Query, Mutation, EmptySubscription> {
243 self.schema.finish()
244 }
245
246 fn build_components(
249 self,
250 ) -> (
251 String,
252 Schema<Query, Mutation, EmptySubscription>,
253 Db,
254 PackageResolver,
255 Router,
256 ) {
257 let address = self.address();
258 let ServerBuilder {
259 state: _,
260 schema,
261 db_reader,
262 resolver,
263 router,
264 } = self;
265 (
266 address,
267 schema.finish(),
268 db_reader.expect("DB reader not initialized"),
269 resolver.expect("Package resolver not initialized"),
270 router.expect("Router not initialized"),
271 )
272 }
273
274 fn init_router(&mut self) {
275 if self.router.is_none() {
276 let router: Router = Router::new()
277 .route("/", post(graphql_handler))
278 .route("/{version}", post(graphql_handler))
279 .route("/graphql", post(graphql_handler))
280 .route("/graphql/{version}", post(graphql_handler))
281 .route("/health", get(health_check))
282 .route("/graphql/health", get(health_check))
283 .route("/graphql/{version}/health", get(health_check))
284 .with_state(self.state.clone())
285 .route_layer(CallbackLayer::new(MetricsMakeCallbackHandler {
286 metrics: self.state.metrics.clone(),
287 }));
288 self.router = Some(router);
289 }
290 }
291
292 pub fn route(mut self, path: &str, method_handler: MethodRouter) -> Self {
293 self.init_router();
294 self.router = self.router.map(|router| router.route(path, method_handler));
295 self
296 }
297
298 pub fn layer<L>(mut self, layer: L) -> Self
299 where
300 L: Layer<Route> + Clone + Send + Sync + 'static,
301 L::Service: Service<Request<Body>> + Clone + Send + Sync + 'static,
302 <L::Service as Service<Request<Body>>>::Response: IntoResponse + 'static,
303 <L::Service as Service<Request<Body>>>::Error: Into<Infallible> + 'static,
304 <L::Service as Service<Request<Body>>>::Future: Send + 'static,
305 {
306 self.init_router();
307 self.router = self.router.map(|router| router.layer(layer));
308 self
309 }
310
311 fn cors() -> Result<CorsLayer, Error> {
312 let acl = match std::env::var("ACCESS_CONTROL_ALLOW_ORIGIN") {
313 Ok(value) => {
314 let allow_hosts = value
315 .split(',')
316 .map(HeaderValue::from_str)
317 .collect::<Result<Vec<_>, _>>()
318 .map_err(|_| {
319 Error::Internal(
320 "Cannot resolve access control origin env variable".to_string(),
321 )
322 })?;
323 AllowOrigin::list(allow_hosts)
324 }
325 _ => AllowOrigin::any(),
326 };
327 info!("Access control allow origin set to: {acl:?}");
328
329 let cors = CorsLayer::new()
330 .allow_methods([Method::POST])
332 .allow_origin(acl)
334 .allow_headers([hyper::header::CONTENT_TYPE, LIMITS_HEADER.clone()]);
335 Ok(cors)
336 }
337
338 pub fn build(self) -> Result<Server, Error> {
340 let state = self.state.clone();
341 let (address, schema, db_reader, resolver, router) = self.build_components();
342
343 let watermark_task = WatermarkTask::new(
345 db_reader.clone(),
346 state.metrics.clone(),
347 std::time::Duration::from_millis(state.service.background_tasks.watermark_update_ms),
348 state.cancellation_token.clone(),
349 );
350
351 let system_package_task = SystemPackageTask::new(
352 resolver,
353 watermark_task.epoch_receiver(),
354 state.cancellation_token.clone(),
355 );
356
357 let trigger_exchange_rates_task = TriggerExchangeRatesTask::new(
358 db_reader.clone(),
359 watermark_task.epoch_receiver(),
360 state.cancellation_token.clone(),
361 );
362
363 let router = router
364 .route_layer(middleware::from_fn_with_state(
365 state.version,
366 set_version_middleware,
367 ))
368 .route_layer(middleware::from_fn_with_state(
369 state.version,
370 check_version_middleware,
371 ))
372 .layer(axum::extract::Extension(schema))
373 .layer(axum::extract::Extension(watermark_task.lock()))
374 .layer(Self::cors()?);
375
376 Ok(Server {
377 router,
378 address: address
379 .parse()
380 .map_err(|_| Error::Internal(format!("Failed to parse address {address}")))?,
381 watermark_task,
382 system_package_task,
383 trigger_exchange_rates_task,
384 state,
385 db_reader,
386 })
387 }
388
389 pub async fn from_config(
392 config: &ServerConfig,
393 version: &Version,
394 cancellation_token: CancellationToken,
395 ) -> Result<Self, Error> {
396 let prom_addr: SocketAddr = format!(
398 "{}:{}",
399 config.connection.prom_url, config.connection.prom_port
400 )
401 .parse()
402 .map_err(|_| {
403 Error::Internal(format!(
404 "Failed to parse url {}, port {} into socket address",
405 config.connection.prom_url, config.connection.prom_port
406 ))
407 })?;
408
409 let registry_service = iota_metrics::start_prometheus_server(prom_addr);
410 info!("Starting Prometheus HTTP endpoint at {}", prom_addr);
411 let registry = registry_service.default_registry();
412 registry
413 .register(iota_metrics::uptime_metric(
414 "graphql",
415 version.full,
416 "unknown",
417 ))
418 .unwrap();
419
420 let metrics = Metrics::new(®istry);
422 let state = AppState::new(
423 config.connection.clone(),
424 config.service.clone(),
425 metrics.clone(),
426 cancellation_token,
427 *version,
428 );
429 let mut builder = ServerBuilder::new(state);
430
431 let iota_names_config = config.service.iota_names.clone();
432 let zklogin_config = config.service.zklogin.clone();
433 let reader = PgManager::reader_with_config(
434 config.connection.db_url.clone(),
435 config.connection.db_pool_size,
436 config.service.limits.request_timeout_ms.into(),
440 )
441 .map_err(|e| Error::Internal(format!("Failed to create pg connection pool: {e}")))?;
442
443 let db = Db::new(
445 reader.clone(),
446 config.service.limits.clone(),
447 metrics.clone(),
448 );
449 let loader = DataLoader::new(db.clone());
450 let pg_conn_pool = PgManager::new(reader.clone());
451 let package_store = DbPackageStore::new(loader.clone());
452 let resolver = Arc::new(Resolver::new_with_limits(
453 PackageStoreWithLruCache::new(package_store),
454 config.service.limits.package_resolver_limits(),
455 ));
456
457 builder.db_reader = Some(db.clone());
458 builder.resolver = Some(resolver.clone());
459
460 let iota_sdk_client = if let Some(url) = &config.tx_exec_full_node.node_rpc_url {
463 Some(
464 IotaClientBuilder::default()
465 .request_timeout(RPC_TIMEOUT_ERR_SLEEP_RETRY_PERIOD)
466 .max_concurrent_requests(MAX_CONCURRENT_REQUESTS)
467 .build(url)
468 .await
469 .map_err(|e| {
470 Error::Internal(format!(
471 "Failed to connect to fullnode {e}. Is the node server running?"
472 ))
473 })?,
474 )
475 } else {
476 warn!(
477 "No fullnode url found in config. `dryRunTransactionBlock` and `executeTransactionBlock` will not work"
478 );
479 None
480 };
481
482 builder = builder
483 .context_data(config.service.clone())
484 .context_data(loader)
485 .context_data(db)
486 .context_data(pg_conn_pool)
487 .context_data(resolver)
488 .context_data(iota_sdk_client)
489 .context_data(iota_names_config)
490 .context_data(zklogin_config)
491 .context_data(metrics.clone())
492 .context_data(config.clone());
493
494 if config.internal_features.feature_gate {
495 builder = builder.extension(FeatureGate);
496 }
497
498 if config.internal_features.logger {
499 builder = builder.extension(Logger::default());
500 }
501
502 if config.internal_features.query_limits_checker {
503 builder = builder.extension(QueryLimitsChecker);
504 }
505
506 if config.internal_features.directive_checker {
507 builder = builder.extension(DirectiveChecker);
508 }
509
510 if config.internal_features.query_timeout {
511 builder = builder.extension(Timeout);
512 }
513
514 if config.internal_features.tracing {
515 builder = builder.extension(Tracing);
516 }
517
518 if config.internal_features.apollo_tracing {
519 builder = builder.extension(ApolloTracing);
520 }
521
522 Ok(builder)
526 }
527}
528
529fn schema_builder() -> SchemaBuilder<Query, Mutation, EmptySubscription> {
530 async_graphql::Schema::build(Query, Mutation, EmptySubscription)
531 .register_output_type::<IMoveObject>()
532 .register_output_type::<IObject>()
533 .register_output_type::<IOwner>()
534 .register_output_type::<IMoveDatatype>()
535}
536
537pub fn export_schema() -> String {
539 schema_builder().finish().sdl()
540}
541
542async fn graphql_handler(
546 ConnectInfo(addr): ConnectInfo<SocketAddr>,
547 schema: Extension<IotaGraphQLSchema>,
548 Extension(watermark_lock): Extension<WatermarkLock>,
549 headers: HeaderMap,
550 req: GraphQLRequest,
551) -> (axum::http::Extensions, GraphQLResponse) {
552 let mut req = req.into_inner();
553 req.data.insert(Uuid::new_v4());
554 if headers.contains_key(ShowUsage::name()) {
555 req.data.insert(ShowUsage)
556 }
557 req.data.insert(addr);
561
562 req.data.insert(Watermark::new(watermark_lock).await);
563
564 let result = schema.execute(req).await;
565
566 let mut extensions = axum::http::Extensions::new();
569 if result.is_err() {
570 extensions.insert(GraphqlErrors(std::sync::Arc::new(result.errors.clone())));
571 };
572 (extensions, result.into())
573}
574
/// Callback factory attached to the default routes. It creates one
/// `MetricsCallbackHandler` per request.
#[derive(Clone)]
struct MetricsMakeCallbackHandler {
    /// Metrics updated for every request.
    metrics: Metrics,
}
579
580impl MakeCallbackHandler for MetricsMakeCallbackHandler {
581 type Handler = MetricsCallbackHandler;
582
583 fn make_handler(&self, _request: &http::request::Parts) -> Self::Handler {
584 let start = Instant::now();
585 let metrics = self.metrics.clone();
586
587 metrics.request_metrics.inflight_requests.inc();
588 metrics.inc_num_queries();
589
590 MetricsCallbackHandler { metrics, start }
591 }
592}
593
/// Per-request metrics handler. It counts GraphQL errors on response, and records
/// latency and releases the in-flight gauge when dropped.
struct MetricsCallbackHandler {
    /// Metrics to update for this request.
    metrics: Metrics,
    /// When the request started being tracked; used to compute latency on drop.
    start: Instant,
}
598
599impl ResponseHandler for MetricsCallbackHandler {
600 fn on_response(self, response: &http::response::Parts) {
601 if let Some(errors) = response.extensions.get::<GraphqlErrors>() {
602 self.metrics.inc_errors(&errors.0);
603 }
604 }
605
606 fn on_error<E>(self, _error: &E) {
607 }
612}
613
614impl Drop for MetricsCallbackHandler {
615 fn drop(&mut self) {
616 self.metrics.query_latency(self.start.elapsed());
617 self.metrics.request_metrics.inflight_requests.dec();
618 }
619}
620
/// GraphQL errors of a response, passed from `graphql_handler` to the metrics
/// callback through the HTTP response extensions. The `Arc` keeps the clone cheap.
#[derive(Debug, Clone)]
struct GraphqlErrors(std::sync::Arc<Vec<async_graphql::ServerError>>);
623
624async fn db_health_check(State(connection): State<ConnectionConfig>) -> StatusCode {
626 let Ok(url) = reqwest::Url::parse(connection.db_url.as_str()) else {
627 return StatusCode::INTERNAL_SERVER_ERROR;
628 };
629
630 let Some(host) = url.host_str() else {
631 return StatusCode::INTERNAL_SERVER_ERROR;
632 };
633
634 let tcp_url = if let Some(port) = url.port() {
635 format!("{host}:{port}")
636 } else {
637 host.to_string()
638 };
639
640 if TcpStream::connect(tcp_url).is_err() {
641 StatusCode::INTERNAL_SERVER_ERROR
642 } else {
643 StatusCode::OK
644 }
645}
646
/// Query parameters accepted by the health check endpoints.
#[derive(serde::Deserialize)]
struct HealthParam {
    /// Maximum tolerated age of the latest checkpoint, in milliseconds; falls back to
    /// `DEFAULT_MAX_CHECKPOINT_LAG` when absent.
    max_checkpoint_lag_ms: Option<u64>,
}
651
652async fn health_check(
658 State(connection): State<ConnectionConfig>,
659 Extension(watermark_lock): Extension<WatermarkLock>,
660 AxumQuery(query_params): AxumQuery<HealthParam>,
661) -> StatusCode {
662 let db_health_check = db_health_check(axum::extract::State(connection)).await;
663 if db_health_check != StatusCode::OK {
664 return db_health_check;
665 }
666
667 let max_checkpoint_lag_ms = query_params
668 .max_checkpoint_lag_ms
669 .map(Duration::from_millis)
670 .unwrap_or_else(|| DEFAULT_MAX_CHECKPOINT_LAG);
671
672 let checkpoint_timestamp =
673 Duration::from_millis(watermark_lock.read().await.checkpoint_timestamp_ms);
674
675 let now_millis = Utc::now().timestamp_millis();
676
677 let now: Duration = match u64::try_from(now_millis) {
679 Ok(val) => Duration::from_millis(val),
680 Err(_) => return StatusCode::INTERNAL_SERVER_ERROR,
681 };
682
683 if (now - checkpoint_timestamp) > max_checkpoint_lag_ms {
684 return StatusCode::GATEWAY_TIMEOUT;
685 }
686
687 db_health_check
688}
689
690async fn get_or_init_server_start_time() -> &'static Instant {
692 static ONCE: OnceCell<Instant> = OnceCell::const_new();
693 ONCE.get_or_init(|| async move { Instant::now() }).await
694}
695
696pub mod tests {
697 use std::{sync::Arc, time::Duration};
698
699 use async_graphql::{
700 Response,
701 extensions::{Extension, ExtensionContext, NextExecute},
702 };
703 use iota_sdk::{IotaClient, wallet_context::WalletContext};
704 use iota_types::transaction::TransactionData;
705 use uuid::Uuid;
706
707 use super::*;
708 use crate::{
709 config::{ConnectionConfig, Limits, ServiceConfig, Version},
710 context_data::db_data_provider::PgManager,
711 extensions::{query_limits_checker::QueryLimitsChecker, timeout::Timeout},
712 };
713
714 fn prep_schema(
718 connection_config: Option<ConnectionConfig>,
719 service_config: Option<ServiceConfig>,
720 ) -> ServerBuilder {
721 let connection_config = connection_config.unwrap_or_default();
722 let service_config = service_config.unwrap_or_default();
723
724 let reader = PgManager::reader_with_config(
725 connection_config.db_url.clone(),
726 connection_config.db_pool_size,
727 service_config.limits.request_timeout_ms.into(),
728 )
729 .expect("Failed to create pg connection pool");
730
731 let version = Version::for_testing();
732 let metrics = metrics();
733 let db = Db::new(
734 reader.clone(),
735 service_config.limits.clone(),
736 metrics.clone(),
737 );
738 let pg_conn_pool = PgManager::new(reader);
739 let cancellation_token = CancellationToken::new();
740 let watermark = Watermark {
741 checkpoint: 1,
742 checkpoint_timestamp_ms: 1,
743 epoch: 0,
744 };
745 let state = AppState::new(
746 connection_config.clone(),
747 service_config.clone(),
748 metrics.clone(),
749 cancellation_token.clone(),
750 version,
751 );
752 ServerBuilder::new(state)
753 .context_data(db)
754 .context_data(pg_conn_pool)
755 .context_data(service_config)
756 .context_data(query_id())
757 .context_data(ip_address())
758 .context_data(watermark)
759 .context_data(metrics)
760 }
761
762 fn metrics() -> Metrics {
763 let binding_address: SocketAddr = "0.0.0.0:9185".parse().unwrap();
764 let registry = iota_metrics::start_prometheus_server(binding_address).default_registry();
765 Metrics::new(®istry)
766 }
767
768 fn ip_address() -> SocketAddr {
769 let binding_address: SocketAddr = "0.0.0.0:51515".parse().unwrap();
770 binding_address
771 }
772
773 fn query_id() -> Uuid {
774 Uuid::new_v4()
775 }
776
777 pub async fn test_timeout_impl(wallet: &WalletContext) {
778 struct TimedExecuteExt {
779 pub min_req_delay: Duration,
780 }
781
782 impl ExtensionFactory for TimedExecuteExt {
783 fn create(&self) -> Arc<dyn Extension> {
784 Arc::new(TimedExecuteExt {
785 min_req_delay: self.min_req_delay,
786 })
787 }
788 }
789
790 #[async_trait::async_trait]
791 impl Extension for TimedExecuteExt {
792 async fn execute(
793 &self,
794 ctx: &ExtensionContext<'_>,
795 operation_name: Option<&str>,
796 next: NextExecute<'_>,
797 ) -> Response {
798 tokio::time::sleep(self.min_req_delay).await;
799 next.run(ctx, operation_name).await
800 }
801 }
802
803 async fn test_timeout(
804 delay: Duration,
805 timeout: Duration,
806 query: &str,
807 iota_client: &IotaClient,
808 ) -> Response {
809 let mut cfg = ServiceConfig::default();
810 cfg.limits.request_timeout_ms = timeout.as_millis() as u32;
811 cfg.limits.mutation_timeout_ms = timeout.as_millis() as u32;
812
813 let schema = prep_schema(None, Some(cfg))
814 .context_data(Some(iota_client.clone()))
815 .extension(Timeout)
816 .extension(TimedExecuteExt {
817 min_req_delay: delay,
818 })
819 .build_schema();
820
821 schema.execute(query).await
822 }
823
824 let query = "{ chainIdentifier }";
825 let timeout = Duration::from_millis(1000);
826 let delay = Duration::from_millis(100);
827 let iota_client = wallet.get_client().await.unwrap();
828
829 test_timeout(delay, timeout, query, &iota_client)
830 .await
831 .into_result()
832 .expect("Should complete successfully");
833
834 let errs: Vec<_> = test_timeout(delay, delay, query, &iota_client)
836 .await
837 .into_result()
838 .unwrap_err()
839 .into_iter()
840 .map(|e| e.message)
841 .collect();
842 let exp = format!("Query request timed out. Limit: {}s", delay.as_secs_f32());
843 assert_eq!(errs, vec![exp]);
844
845 let addresses = wallet.get_addresses();
849 let gas = wallet
850 .get_one_gas_object_owned_by_address(addresses[0])
851 .await
852 .unwrap();
853 let tx_data = TransactionData::new_transfer_iota(
854 addresses[1],
855 addresses[0],
856 Some(1000),
857 gas.unwrap(),
858 1_000_000,
859 wallet.get_reference_gas_price().await.unwrap(),
860 );
861
862 let tx = wallet.sign_transaction(&tx_data);
863 let (tx_bytes, signatures) = tx.to_tx_bytes_and_signatures();
864
865 let signature_base64 = &signatures[0];
866 let query = format!(
867 r#"
868 mutation {{
869 executeTransactionBlock(txBytes: "{}", signatures: "{}") {{
870 effects {{
871 status
872 }}
873 }}
874 }}"#,
875 tx_bytes.encoded(),
876 signature_base64.encoded()
877 );
878 let errs: Vec<_> = test_timeout(delay, delay, &query, &iota_client)
879 .await
880 .into_result()
881 .unwrap_err()
882 .into_iter()
883 .map(|e| e.message)
884 .collect();
885 let exp = format!(
886 "Mutation request timed out. Limit: {}s",
887 delay.as_secs_f32()
888 );
889 assert_eq!(errs, vec![exp]);
890 }
891
892 pub async fn test_query_depth_limit_impl() {
893 async fn exec_query_depth_limit(depth: u32, query: &str) -> Response {
894 let service_config = ServiceConfig {
895 limits: Limits {
896 max_query_depth: depth,
897 ..Default::default()
898 },
899 ..Default::default()
900 };
901
902 let schema = prep_schema(None, Some(service_config))
903 .extension(QueryLimitsChecker)
904 .build_schema();
905 schema.execute(query).await
906 }
907
908 exec_query_depth_limit(1, "{ chainIdentifier }")
909 .await
910 .into_result()
911 .expect("Should complete successfully");
912
913 exec_query_depth_limit(
914 5,
915 "{ chainIdentifier protocolConfig { configs { value key }} }",
916 )
917 .await
918 .into_result()
919 .expect("Should complete successfully");
920
921 let errs: Vec<_> = exec_query_depth_limit(0, "{ chainIdentifier }")
923 .await
924 .into_result()
925 .unwrap_err()
926 .into_iter()
927 .map(|e| e.message)
928 .collect();
929
930 assert_eq!(errs, vec!["Query nesting is over 0".to_string()]);
931 let errs: Vec<_> = exec_query_depth_limit(
932 2,
933 "{ chainIdentifier protocolConfig { configs { value key }} }",
934 )
935 .await
936 .into_result()
937 .unwrap_err()
938 .into_iter()
939 .map(|e| e.message)
940 .collect();
941 assert_eq!(errs, vec!["Query nesting is over 2".to_string()]);
942 }
943
944 pub async fn test_query_node_limit_impl() {
945 async fn exec_query_node_limit(nodes: u32, query: &str) -> Response {
946 let service_config = ServiceConfig {
947 limits: Limits {
948 max_query_nodes: nodes,
949 ..Default::default()
950 },
951 ..Default::default()
952 };
953
954 let schema = prep_schema(None, Some(service_config))
955 .extension(QueryLimitsChecker)
956 .build_schema();
957 schema.execute(query).await
958 }
959
960 exec_query_node_limit(1, "{ chainIdentifier }")
961 .await
962 .into_result()
963 .expect("Should complete successfully");
964
965 exec_query_node_limit(
966 5,
967 "{ chainIdentifier protocolConfig { configs { value key }} }",
968 )
969 .await
970 .into_result()
971 .expect("Should complete successfully");
972
973 let err: Vec<_> = exec_query_node_limit(0, "{ chainIdentifier }")
975 .await
976 .into_result()
977 .unwrap_err()
978 .into_iter()
979 .map(|e| e.message)
980 .collect();
981 assert_eq!(err, vec!["Query has over 0 nodes".to_string()]);
982
983 let err: Vec<_> = exec_query_node_limit(
984 4,
985 "{ chainIdentifier protocolConfig { configs { value key }} }",
986 )
987 .await
988 .into_result()
989 .unwrap_err()
990 .into_iter()
991 .map(|e| e.message)
992 .collect();
993 assert_eq!(err, vec!["Query has over 4 nodes".to_string()]);
994 }
995
996 pub async fn test_query_default_page_limit_impl(connection_config: ConnectionConfig) {
997 let service_config = ServiceConfig {
998 limits: Limits {
999 default_page_size: 1,
1000 ..Default::default()
1001 },
1002 ..Default::default()
1003 };
1004 let schema = prep_schema(Some(connection_config), Some(service_config)).build_schema();
1005
1006 let resp = schema
1007 .execute("{ checkpoints { nodes { sequenceNumber } } }")
1008 .await;
1009 let data = resp.data.clone().into_json().unwrap();
1010 let checkpoints = data
1011 .get("checkpoints")
1012 .unwrap()
1013 .get("nodes")
1014 .unwrap()
1015 .as_array()
1016 .unwrap();
1017 assert_eq!(
1018 checkpoints.len(),
1019 1,
1020 "Checkpoints should have exactly one element"
1021 );
1022
1023 let resp = schema
1024 .execute("{ checkpoints(first: 2) { nodes { sequenceNumber } } }")
1025 .await;
1026 let data = resp.data.clone().into_json().unwrap();
1027 let checkpoints = data
1028 .get("checkpoints")
1029 .unwrap()
1030 .get("nodes")
1031 .unwrap()
1032 .as_array()
1033 .unwrap();
1034 assert_eq!(
1035 checkpoints.len(),
1036 2,
1037 "Checkpoints should return two elements"
1038 );
1039 }
1040
1041 pub async fn test_query_max_page_limit_impl() {
1042 let schema = prep_schema(None, None).build_schema();
1043
1044 schema
1045 .execute("{ objects(first: 1) { nodes { version } } }")
1046 .await
1047 .into_result()
1048 .expect("Should complete successfully");
1049
1050 let err: Vec<_> = schema
1052 .execute("{ objects(first: 51) { nodes { version } } }")
1053 .await
1054 .into_result()
1055 .unwrap_err()
1056 .into_iter()
1057 .map(|e| e.message)
1058 .collect();
1059 assert_eq!(
1060 err,
1061 vec!["Connection's page size of 51 exceeds max of 50".to_string()]
1062 );
1063 }
1064
1065 pub async fn test_query_complexity_metrics_impl() {
1066 let server_builder = prep_schema(None, None);
1067 let metrics = server_builder.state.metrics.clone();
1068 let schema = server_builder
1069 .extension(QueryLimitsChecker) .build_schema();
1071
1072 schema
1073 .execute("{ chainIdentifier }")
1074 .await
1075 .into_result()
1076 .expect("Should complete successfully");
1077
1078 let req_metrics = metrics.request_metrics;
1079 assert_eq!(req_metrics.input_nodes.get_sample_count(), 1);
1080 assert_eq!(req_metrics.output_nodes.get_sample_count(), 1);
1081 assert_eq!(req_metrics.query_depth.get_sample_count(), 1);
1082 assert_eq!(req_metrics.input_nodes.get_sample_sum(), 1.);
1083 assert_eq!(req_metrics.output_nodes.get_sample_sum(), 1.);
1084 assert_eq!(req_metrics.query_depth.get_sample_sum(), 1.);
1085
1086 schema
1087 .execute("{ chainIdentifier protocolConfig { configs { value key }} }")
1088 .await
1089 .into_result()
1090 .expect("Should complete successfully");
1091
1092 assert_eq!(req_metrics.input_nodes.get_sample_count(), 2);
1093 assert_eq!(req_metrics.output_nodes.get_sample_count(), 2);
1094 assert_eq!(req_metrics.query_depth.get_sample_count(), 2);
1095 assert_eq!(req_metrics.input_nodes.get_sample_sum(), 2. + 4.);
1096 assert_eq!(req_metrics.output_nodes.get_sample_sum(), 2. + 4.);
1097 assert_eq!(req_metrics.query_depth.get_sample_sum(), 1. + 3.);
1098 }
1099
1100 pub async fn test_health_check_impl() {
1101 let server_builder = prep_schema(None, None);
1102 let url = format!(
1103 "http://{}:{}/health",
1104 server_builder.state.connection.host, server_builder.state.connection.port
1105 );
1106 server_builder.build_schema();
1107
1108 let resp = reqwest::get(&url).await.unwrap();
1109 assert_eq!(resp.status(), reqwest::StatusCode::OK);
1110
1111 let url_with_param = format!("{url}?max_checkpoint_lag_ms=1");
1112 let resp = reqwest::get(&url_with_param).await.unwrap();
1113 assert_eq!(resp.status(), reqwest::StatusCode::GATEWAY_TIMEOUT);
1114 }
1115}