iota_graphql_rpc/extensions/
timeout.rs1use std::{
6 net::SocketAddr,
7 sync::{
8 Arc, Mutex,
9 atomic::{AtomicBool, Ordering},
10 },
11 time::Duration,
12};
13
14use async_graphql::{
15 Response, ServerError, ServerResult,
16 extensions::{Extension, ExtensionContext, ExtensionFactory, NextExecute, NextParseQuery},
17 parser::types::{ExecutableDocument, OperationType},
18};
19use async_graphql_value::Variables;
20use tokio::time::timeout;
21use tracing::error;
22use uuid::Uuid;
23
24use crate::{config::ServiceConfig, error::code};
25
/// Extension factory that enforces an execution time limit on GraphQL requests.
///
/// Mutations and other operations are bounded by separate limits, read from
/// `ServiceConfig::limits` (`mutation_timeout_ms` / `request_timeout_ms`) when the
/// request is executed.
pub(crate) struct Timeout;
28
29#[derive(Debug, Default)]
30struct TimeoutExt {
31 pub query: Mutex<Option<String>>,
32 pub is_mutation: AtomicBool,
33}
34
35impl ExtensionFactory for Timeout {
36 fn create(&self) -> Arc<dyn Extension> {
37 Arc::new(TimeoutExt {
38 query: Mutex::new(None),
39 is_mutation: AtomicBool::new(false),
40 })
41 }
42}
43
44#[async_trait::async_trait]
45impl Extension for TimeoutExt {
46 async fn parse_query(
47 &self,
48 ctx: &ExtensionContext<'_>,
49 query: &str,
50 variables: &Variables,
51 next: NextParseQuery<'_>,
52 ) -> ServerResult<ExecutableDocument> {
53 let document = next.run(ctx, query, variables).await?;
54 *self.query.lock().unwrap() = Some(ctx.stringify_execute_doc(&document, variables));
55
56 let is_mutation = document
57 .operations
58 .iter()
59 .any(|(_, operation)| operation.node.ty == OperationType::Mutation);
60 self.is_mutation.store(is_mutation, Ordering::Relaxed);
61
62 Ok(document)
63 }
64
65 async fn execute(
66 &self,
67 ctx: &ExtensionContext<'_>,
68 operation_name: Option<&str>,
69 next: NextExecute<'_>,
70 ) -> Response {
71 let cfg: &ServiceConfig = ctx
72 .data()
73 .expect("No service config provided in schema data");
74
75 let is_mutation = self.is_mutation.load(Ordering::Relaxed);
77 let request_timeout = if is_mutation {
78 Duration::from_millis(cfg.limits.mutation_timeout_ms.into())
79 } else {
80 Duration::from_millis(cfg.limits.request_timeout_ms.into())
81 };
82
83 timeout(request_timeout, next.run(ctx, operation_name))
84 .await
85 .unwrap_or_else(|_| {
86 let query_id: &Uuid = ctx.data_unchecked();
87 let session_id: &SocketAddr = ctx.data_unchecked();
88 let error_code = code::REQUEST_TIMEOUT;
89 let guard = self.query.lock().unwrap();
90 let query = match guard.as_ref() {
91 Some(s) => s.as_str(),
92 None => "",
93 };
94
95 error!(
96 %query_id,
97 %session_id,
98 %error_code,
99 %query
100 );
101 let error_msg = if is_mutation {
102 format!(
103 "Mutation request timed out. Limit: {}s",
104 request_timeout.as_secs_f32()
105 )
106 } else {
107 format!(
108 "Query request timed out. Limit: {}s",
109 request_timeout.as_secs_f32()
110 )
111 };
112 Response::from_errors(vec![ServerError::new(error_msg, None)])
113 })
114 }
115}