iota_graphql_rpc/extensions/timeout.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    net::SocketAddr,
7    sync::{
8        Arc, Mutex,
9        atomic::{AtomicBool, Ordering},
10    },
11    time::Duration,
12};
13
14use async_graphql::{
15    Response, ServerError, ServerResult,
16    extensions::{Extension, ExtensionContext, ExtensionFactory, NextExecute, NextParseQuery},
17    parser::types::{ExecutableDocument, OperationType},
18};
19use async_graphql_value::Variables;
20use tokio::time::timeout;
21use tracing::error;
22use uuid::Uuid;
23
24use crate::{config::ServiceConfig, error::code};
25
/// Extension factory for the request-timeout extension.
///
/// Each call to `create` builds a fresh `TimeoutExt`, so every query gets its own
/// timeout state (captured query text, operation kinds) rather than sharing one.
pub(crate) struct Timeout;
28
29#[derive(Debug, Default)]
30struct TimeoutExt {
31    pub query: Mutex<Option<String>>,
32    pub is_mutation: AtomicBool,
33}
34
35impl ExtensionFactory for Timeout {
36    fn create(&self) -> Arc<dyn Extension> {
37        Arc::new(TimeoutExt {
38            query: Mutex::new(None),
39            is_mutation: AtomicBool::new(false),
40        })
41    }
42}
43
44#[async_trait::async_trait]
45impl Extension for TimeoutExt {
46    async fn parse_query(
47        &self,
48        ctx: &ExtensionContext<'_>,
49        query: &str,
50        variables: &Variables,
51        next: NextParseQuery<'_>,
52    ) -> ServerResult<ExecutableDocument> {
53        let document = next.run(ctx, query, variables).await?;
54        *self.query.lock().unwrap() = Some(ctx.stringify_execute_doc(&document, variables));
55
56        let is_mutation = document
57            .operations
58            .iter()
59            .any(|(_, operation)| operation.node.ty == OperationType::Mutation);
60        self.is_mutation.store(is_mutation, Ordering::Relaxed);
61
62        Ok(document)
63    }
64
65    async fn execute(
66        &self,
67        ctx: &ExtensionContext<'_>,
68        operation_name: Option<&str>,
69        next: NextExecute<'_>,
70    ) -> Response {
71        let cfg: &ServiceConfig = ctx
72            .data()
73            .expect("No service config provided in schema data");
74
75        // increase the timeout if the request is a mutation
76        let is_mutation = self.is_mutation.load(Ordering::Relaxed);
77        let request_timeout = if is_mutation {
78            Duration::from_millis(cfg.limits.mutation_timeout_ms.into())
79        } else {
80            Duration::from_millis(cfg.limits.request_timeout_ms.into())
81        };
82
83        timeout(request_timeout, next.run(ctx, operation_name))
84            .await
85            .unwrap_or_else(|_| {
86                let query_id: &Uuid = ctx.data_unchecked();
87                let session_id: &SocketAddr = ctx.data_unchecked();
88                let error_code = code::REQUEST_TIMEOUT;
89                let guard = self.query.lock().unwrap();
90                let query = match guard.as_ref() {
91                    Some(s) => s.as_str(),
92                    None => "",
93                };
94
95                error!(
96                    %query_id,
97                    %session_id,
98                    %error_code,
99                    %query
100                );
101                let error_msg = if is_mutation {
102                    format!(
103                        "Mutation request timed out. Limit: {}s",
104                        request_timeout.as_secs_f32()
105                    )
106                } else {
107                    format!(
108                        "Query request timed out. Limit: {}s",
109                        request_timeout.as_secs_f32()
110                    )
111                };
112                Response::from_errors(vec![ServerError::new(error_msg, None)])
113            })
114    }
115}