Skip to main content

iota_graphql_rpc/data/
pg.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{fmt, time::Instant};
6
7use async_trait::async_trait;
8use diesel::{
9    QueryResult, RunQueryDsl,
10    pg::Pg,
11    query_builder::{Query, QueryFragment, QueryId},
12    query_dsl::LoadQuery,
13};
14use iota_indexer::{
15    read::IndexerReader, run_query_async, run_query_repeatable_async, spawn_read_only_blocking,
16};
17use tracing::error;
18
19use crate::{config::Limits, data::QueryExecutor, error::Error, metrics::Metrics};
20
/// Executes database work for the GraphQL service against the indexer's
/// Postgres database, using the connection pool exposed by `inner`.
#[derive(Clone)]
pub(crate) struct PgExecutor {
    /// Reader over the indexer database; supplies the connection pool that
    /// queries are run on.
    pub inner: IndexerReader,
    /// Service limits. `max_db_query_cost` is the threshold that each query's
    /// estimated `EXPLAIN` cost is compared against when it is logged.
    pub limits: Limits,
    /// Metrics sink; receives the duration and success of every database call.
    pub metrics: Metrics,
    /// Maximum size of the `availableRange` window (`MAX_AVAILABLE_RANGE`).
    /// Caps how many checkpoints below `latest` a consistent view can be
    /// requested for.
    pub max_available_range: u64,
}
31
/// Connection handle passed to the closures run by [`PgExecutor`]: a borrowed
/// diesel Postgres connection plus the cost threshold used when logging each
/// query's `EXPLAIN` estimate.
pub(crate) struct PgConnection<'c> {
    /// Estimated-cost threshold, copied from `Limits::max_db_query_cost`.
    max_cost: u32,
    /// Underlying diesel connection, borrowed for the duration of the closure.
    conn: &'c mut diesel::PgConnection,
}
36
/// Newtype over a byte slice whose `Display` impl renders it as a Postgres
/// hex-format `bytea` literal, e.g. `'\xdead'::bytea`.
pub(crate) struct ByteaLiteral<'a>(pub &'a [u8]);
38
39impl PgExecutor {
40    pub(crate) fn new(
41        inner: IndexerReader,
42        limits: Limits,
43        metrics: Metrics,
44        max_available_range: u64,
45    ) -> Self {
46        Self {
47            inner,
48            limits,
49            metrics,
50            max_available_range,
51        }
52    }
53}
54
55#[async_trait]
56impl QueryExecutor for PgExecutor {
57    type Connection = diesel::PgConnection;
58    type Backend = Pg;
59    type DbConnection<'c> = PgConnection<'c>;
60
61    async fn execute<T, U, E>(&self, txn: T) -> Result<U, Error>
62    where
63        T: FnOnce(&mut Self::DbConnection<'_>) -> Result<U, E>,
64        E: From<diesel::result::Error> + std::error::Error,
65        T: Send + 'static,
66        U: Send + 'static,
67        E: Send + 'static,
68    {
69        let max_cost = self.limits.max_db_query_cost;
70        let instant = Instant::now();
71        let pool = self.inner.get_pool();
72        #[allow(unexpected_cfgs)]
73        let result = run_query_async!(&pool, move |conn| txn(&mut PgConnection { max_cost, conn }));
74        self.metrics
75            .observe_db_data(instant.elapsed(), result.is_ok());
76        if let Err(e) = &result {
77            error!("db query error: {e:?}");
78        }
79        result.map_err(|e| Error::Internal(e.to_string()))
80    }
81
82    async fn execute_repeatable<T, U, E>(&self, txn: T) -> Result<U, Error>
83    where
84        T: FnOnce(&mut Self::DbConnection<'_>) -> Result<U, E>,
85        E: From<diesel::result::Error> + std::error::Error,
86        T: Send + 'static,
87        U: Send + 'static,
88        E: Send + 'static,
89    {
90        let max_cost = self.limits.max_db_query_cost;
91        let instant = Instant::now();
92        let pool = self.inner.get_pool();
93        #[allow(unexpected_cfgs)]
94        let result = run_query_repeatable_async!(&pool, move |conn| txn(&mut PgConnection {
95            max_cost,
96            conn
97        }));
98        self.metrics
99            .observe_db_data(instant.elapsed(), result.is_ok());
100        if let Err(e) = &result {
101            error!("db query error: {e:?}");
102        }
103        result.map_err(|e| Error::Internal(e.to_string()))
104    }
105}
106
107impl super::DbConnection for PgConnection<'_> {
108    type Connection = diesel::PgConnection;
109    type Backend = Pg;
110
111    fn result<Q, U>(&mut self, query: impl Fn() -> Q) -> QueryResult<U>
112    where
113        Q: diesel::query_builder::Query,
114        Q: LoadQuery<'static, Self::Connection, U>,
115        Q: QueryId + QueryFragment<Self::Backend>,
116    {
117        query_cost::log(self.conn, self.max_cost, query());
118        query().get_result(self.conn)
119    }
120
121    fn results<Q, U>(&mut self, query: impl Fn() -> Q) -> QueryResult<Vec<U>>
122    where
123        Q: diesel::query_builder::Query,
124        Q: LoadQuery<'static, Self::Connection, U>,
125        Q: QueryId + QueryFragment<Self::Backend>,
126    {
127        query_cost::log(self.conn, self.max_cost, query());
128        query().get_results(self.conn)
129    }
130}
131
132impl fmt::Display for ByteaLiteral<'_> {
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        write!(f, "'\\x{}'::bytea", hex::encode(self.0))
135    }
136}
137
138pub(crate) fn bytea_literal(slice: &[u8]) -> ByteaLiteral<'_> {
139    ByteaLiteral(slice)
140}
141
142/// Support for calculating estimated query cost using EXPLAIN and then logging
143/// it.
144mod query_cost {
145    use diesel::{PgConnection, QueryResult, query_builder::AstPass, sql_types::Text};
146    use serde_json::Value;
147    use tap::{TapFallible, TapOptional};
148    use tracing::{debug, info, warn};
149
150    use super::*;
151
152    #[derive(Debug, Clone, Copy, QueryId)]
153    struct Explained<Q> {
154        query: Q,
155    }
156
157    impl<Q: Query> Query for Explained<Q> {
158        type SqlType = Text;
159    }
160
161    impl<Q> RunQueryDsl<PgConnection> for Explained<Q> {}
162
163    impl<Q: QueryFragment<Pg>> QueryFragment<Pg> for Explained<Q> {
164        fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
165            out.push_sql("EXPLAIN (FORMAT JSON) ");
166            self.query.walk_ast(out.reborrow())?;
167            Ok(())
168        }
169    }
170
171    /// Run `EXPLAIN` on the `query`, and log the estimated cost.
172    pub(crate) fn log<Q>(conn: &mut PgConnection, max_db_query_cost: u32, query: Q)
173    where
174        Q: Query + QueryId + QueryFragment<Pg> + RunQueryDsl<PgConnection>,
175    {
176        debug!("Estimating: {}", diesel::debug_query(&query).to_string());
177
178        let Some(cost) = explain(conn, query) else {
179            warn!("Failed to extract cost from EXPLAIN.");
180            return;
181        };
182
183        if cost > max_db_query_cost as f64 {
184            warn!(cost, max_db_query_cost, exceeds = true, "Estimated cost");
185        } else {
186            info!(cost, max_db_query_cost, exceeds = false, "Estimated cost");
187        }
188    }
189
190    pub(crate) fn explain<Q>(conn: &mut PgConnection, query: Q) -> Option<f64>
191    where
192        Q: Query + QueryId + QueryFragment<Pg> + RunQueryDsl<PgConnection>,
193    {
194        let result: String = Explained { query }
195            .get_result(conn)
196            .tap_err(|e| warn!("Failed to run EXPLAIN: {e}"))
197            .ok()?;
198
199        let parsed = serde_json::from_str(&result)
200            .tap_err(|e| warn!("Failed to parse EXPLAIN result: {e}"))
201            .ok()?;
202
203        extract_cost(&parsed).tap_none(|| warn!("Failed to extract cost from EXPLAIN"))
204    }
205
206    fn extract_cost(parsed: &Value) -> Option<f64> {
207        parsed.get(0)?.get("Plan")?.get("Total Cost")?.as_f64()
208    }
209}
210
#[cfg(all(test, feature = "pg_integration"))]
mod tests {
    use diesel::QueryDsl;
    use iota_framework::BuiltInFramework;
    use iota_indexer::{
        db::{ConnectionPoolConfig, get_pool_connection, new_connection_pool, reset_database},
        models::objects::StoredObject,
        schema::objects,
        types::IndexedObject,
    };

    use super::*;
    use crate::config::ConnectionConfig;

    /// Seeds a freshly reset database with the genesis objects of the
    /// built-in system packages. It then checks that `EXPLAIN` estimates a
    /// `LIMIT 1` scan of `objects` as strictly cheaper than an unbounded scan.
    #[test]
    fn test_query_cost() {
        use objects::dsl;

        let config = ConnectionConfig::default();
        let pool_config = ConnectionPoolConfig {
            pool_size: config.db_pool_size,
            ..Default::default()
        };
        let pool = new_connection_pool(&config.db_url, &pool_config).unwrap();
        let mut conn = get_pool_connection(&pool).unwrap();
        reset_database(&mut conn).unwrap();

        // One `objects` row per built-in system package (its genesis object).
        let rows: Vec<StoredObject> = BuiltInFramework::iter_system_packages()
            .map(|pkg| IndexedObject::from_object(Some(1), pkg.genesis_object(), None).into())
            .collect();

        let expected_rows = rows.len();
        let inserted_rows = diesel::insert_into(dsl::objects)
            .values(rows)
            .execute(&mut conn)
            .unwrap();
        assert_eq!(expected_rows, inserted_rows, "Failed to write objects");

        // Estimate a bounded and an unbounded scan of the same table.
        let bounded = dsl::objects.select(dsl::objects.star()).limit(1);
        let unbounded = dsl::objects.select(dsl::objects.star());
        let cost_one = query_cost::explain(&mut conn, bounded).unwrap();
        let cost_all = query_cost::explain(&mut conn, unbounded).unwrap();

        assert!(
            cost_one < cost_all,
            "cost_one = {cost_one} >= {cost_all} = cost_all"
        );
    }
}