iota_graphql_rpc/data/
pg.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{fmt, time::Instant};
6
7use async_trait::async_trait;
8use diesel::{
9    QueryResult, RunQueryDsl,
10    pg::Pg,
11    query_builder::{Query, QueryFragment, QueryId},
12    query_dsl::LoadQuery,
13};
14use iota_indexer::{
15    read::IndexerReader, run_query_async, run_query_repeatable_async, spawn_read_only_blocking,
16};
17use tracing::error;
18
19use crate::{config::Limits, data::QueryExecutor, error::Error, metrics::Metrics};
20
21#[derive(Clone)]
22pub(crate) struct PgExecutor {
23    pub inner: IndexerReader,
24    pub limits: Limits,
25    pub metrics: Metrics,
26}
27
28pub(crate) struct PgConnection<'c> {
29    max_cost: u32,
30    conn: &'c mut diesel::PgConnection,
31}
32
/// Wrapper around a byte slice whose `Display` output is a Postgres
/// hex-format `bytea` literal.
///
/// It only borrows the slice, so it is cheap to copy and can be formatted any
/// number of times.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ByteaLiteral<'a>(pub &'a [u8]);
34
35impl PgExecutor {
36    pub(crate) fn new(inner: IndexerReader, limits: Limits, metrics: Metrics) -> Self {
37        Self {
38            inner,
39            limits,
40            metrics,
41        }
42    }
43}
44
45#[async_trait]
46impl QueryExecutor for PgExecutor {
47    type Connection = diesel::PgConnection;
48    type Backend = Pg;
49    type DbConnection<'c> = PgConnection<'c>;
50
51    async fn execute<T, U, E>(&self, txn: T) -> Result<U, Error>
52    where
53        T: FnOnce(&mut Self::DbConnection<'_>) -> Result<U, E>,
54        E: From<diesel::result::Error> + std::error::Error,
55        T: Send + 'static,
56        U: Send + 'static,
57        E: Send + 'static,
58    {
59        let max_cost = self.limits.max_db_query_cost;
60        let instant = Instant::now();
61        let pool = self.inner.get_pool();
62        #[allow(unexpected_cfgs)]
63        let result = run_query_async!(&pool, move |conn| txn(&mut PgConnection { max_cost, conn }));
64        self.metrics
65            .observe_db_data(instant.elapsed(), result.is_ok());
66        if let Err(e) = &result {
67            error!("db query error: {e:?}");
68        }
69        result.map_err(|e| Error::Internal(e.to_string()))
70    }
71
72    async fn execute_repeatable<T, U, E>(&self, txn: T) -> Result<U, Error>
73    where
74        T: FnOnce(&mut Self::DbConnection<'_>) -> Result<U, E>,
75        E: From<diesel::result::Error> + std::error::Error,
76        T: Send + 'static,
77        U: Send + 'static,
78        E: Send + 'static,
79    {
80        let max_cost = self.limits.max_db_query_cost;
81        let instant = Instant::now();
82        let pool = self.inner.get_pool();
83        #[allow(unexpected_cfgs)]
84        let result = run_query_repeatable_async!(&pool, move |conn| txn(&mut PgConnection {
85            max_cost,
86            conn
87        }));
88        self.metrics
89            .observe_db_data(instant.elapsed(), result.is_ok());
90        if let Err(e) = &result {
91            error!("db query error: {e:?}");
92        }
93        result.map_err(|e| Error::Internal(e.to_string()))
94    }
95}
96
97impl super::DbConnection for PgConnection<'_> {
98    type Connection = diesel::PgConnection;
99    type Backend = Pg;
100
101    fn result<Q, U>(&mut self, query: impl Fn() -> Q) -> QueryResult<U>
102    where
103        Q: diesel::query_builder::Query,
104        Q: LoadQuery<'static, Self::Connection, U>,
105        Q: QueryId + QueryFragment<Self::Backend>,
106    {
107        query_cost::log(self.conn, self.max_cost, query());
108        query().get_result(self.conn)
109    }
110
111    fn results<Q, U>(&mut self, query: impl Fn() -> Q) -> QueryResult<Vec<U>>
112    where
113        Q: diesel::query_builder::Query,
114        Q: LoadQuery<'static, Self::Connection, U>,
115        Q: QueryId + QueryFragment<Self::Backend>,
116    {
117        query_cost::log(self.conn, self.max_cost, query());
118        query().get_results(self.conn)
119    }
120}
121
122impl fmt::Display for ByteaLiteral<'_> {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        write!(f, "'\\x{}'::bytea", hex::encode(self.0))
125    }
126}
127
128pub(crate) fn bytea_literal(slice: &[u8]) -> ByteaLiteral<'_> {
129    ByteaLiteral(slice)
130}
131
132/// Support for calculating estimated query cost using EXPLAIN and then logging
133/// it.
134mod query_cost {
135    use diesel::{PgConnection, QueryResult, query_builder::AstPass, sql_types::Text};
136    use serde_json::Value;
137    use tap::{TapFallible, TapOptional};
138    use tracing::{debug, info, warn};
139
140    use super::*;
141
142    #[derive(Debug, Clone, Copy, QueryId)]
143    struct Explained<Q> {
144        query: Q,
145    }
146
147    impl<Q: Query> Query for Explained<Q> {
148        type SqlType = Text;
149    }
150
151    impl<Q> RunQueryDsl<PgConnection> for Explained<Q> {}
152
153    impl<Q: QueryFragment<Pg>> QueryFragment<Pg> for Explained<Q> {
154        fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
155            out.push_sql("EXPLAIN (FORMAT JSON) ");
156            self.query.walk_ast(out.reborrow())?;
157            Ok(())
158        }
159    }
160
161    /// Run `EXPLAIN` on the `query`, and log the estimated cost.
162    pub(crate) fn log<Q>(conn: &mut PgConnection, max_db_query_cost: u32, query: Q)
163    where
164        Q: Query + QueryId + QueryFragment<Pg> + RunQueryDsl<PgConnection>,
165    {
166        debug!("Estimating: {}", diesel::debug_query(&query).to_string());
167
168        let Some(cost) = explain(conn, query) else {
169            warn!("Failed to extract cost from EXPLAIN.");
170            return;
171        };
172
173        if cost > max_db_query_cost as f64 {
174            warn!(cost, max_db_query_cost, exceeds = true, "Estimated cost");
175        } else {
176            info!(cost, max_db_query_cost, exceeds = false, "Estimated cost");
177        }
178    }
179
180    pub(crate) fn explain<Q>(conn: &mut PgConnection, query: Q) -> Option<f64>
181    where
182        Q: Query + QueryId + QueryFragment<Pg> + RunQueryDsl<PgConnection>,
183    {
184        let result: String = Explained { query }
185            .get_result(conn)
186            .tap_err(|e| warn!("Failed to run EXPLAIN: {e}"))
187            .ok()?;
188
189        let parsed = serde_json::from_str(&result)
190            .tap_err(|e| warn!("Failed to parse EXPLAIN result: {e}"))
191            .ok()?;
192
193        extract_cost(&parsed).tap_none(|| warn!("Failed to extract cost from EXPLAIN"))
194    }
195
196    fn extract_cost(parsed: &Value) -> Option<f64> {
197        parsed.get(0)?.get("Plan")?.get("Total Cost")?.as_f64()
198    }
199}
200
#[cfg(all(test, feature = "pg_integration"))]
mod tests {
    use diesel::QueryDsl;
    use iota_framework::BuiltInFramework;
    use iota_indexer::{
        db::{ConnectionPoolConfig, get_pool_connection, new_connection_pool, reset_database},
        models::objects::StoredObject,
        schema::objects,
        types::IndexedObject,
    };

    use super::*;
    use crate::config::ConnectionConfig;

    /// A `LIMIT 1` select over `objects` should get a lower EXPLAIN cost
    /// estimate than the same select without a limit.
    #[test]
    fn test_query_cost() {
        use objects::dsl;

        let config = ConnectionConfig::default();
        let pool_config = ConnectionPoolConfig {
            pool_size: config.db_pool_size,
            ..Default::default()
        };
        let pool = new_connection_pool(&config.db_url, &pool_config).unwrap();
        let mut conn = get_pool_connection(&pool).unwrap();
        reset_database(&mut conn).unwrap();

        // Seed `objects` with the genesis objects of the built-in system packages.
        let rows: Vec<StoredObject> = BuiltInFramework::iter_system_packages()
            .map(|pkg| IndexedObject::from_object(1, pkg.genesis_object(), None).into())
            .collect();
        let expected_inserts = rows.len();

        let inserted = diesel::insert_into(dsl::objects)
            .values(rows)
            .execute(&mut conn)
            .unwrap();
        assert_eq!(expected_inserts, inserted, "Failed to write objects");

        let limited = dsl::objects.select(dsl::objects.star()).limit(1);
        let unlimited = dsl::objects.select(dsl::objects.star());

        let cost_one = query_cost::explain(&mut conn, limited).unwrap();
        let cost_all = query_cost::explain(&mut conn, unlimited).unwrap();

        assert!(
            cost_one < cost_all,
            "cost_one = {cost_one} >= {cost_all} = cost_all"
        );
    }
}