iota_graphql_rpc/data/
pg.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{fmt, time::Instant};
6
7use async_trait::async_trait;
8use diesel::{
9    QueryResult, RunQueryDsl,
10    pg::Pg,
11    query_builder::{Query, QueryFragment, QueryId},
12    query_dsl::LoadQuery,
13};
14use iota_indexer::{
15    indexer_reader::IndexerReader, run_query_async, run_query_repeatable_async,
16    spawn_read_only_blocking,
17};
18use tracing::error;
19
20use crate::{config::Limits, data::QueryExecutor, error::Error, metrics::Metrics};
21
22#[derive(Clone)]
23pub(crate) struct PgExecutor {
24    pub inner: IndexerReader,
25    pub limits: Limits,
26    pub metrics: Metrics,
27}
28
29pub(crate) struct PgConnection<'c> {
30    max_cost: u32,
31    conn: &'c mut diesel::PgConnection,
32}
33
/// Borrowed byte string whose `Display` output is a Postgres `bytea` literal.
pub(crate) struct ByteaLiteral<'bytes>(pub &'bytes [u8]);
35
36impl PgExecutor {
37    pub(crate) fn new(inner: IndexerReader, limits: Limits, metrics: Metrics) -> Self {
38        Self {
39            inner,
40            limits,
41            metrics,
42        }
43    }
44}
45
46#[async_trait]
47impl QueryExecutor for PgExecutor {
48    type Connection = diesel::PgConnection;
49    type Backend = Pg;
50    type DbConnection<'c> = PgConnection<'c>;
51
52    async fn execute<T, U, E>(&self, txn: T) -> Result<U, Error>
53    where
54        T: FnOnce(&mut Self::DbConnection<'_>) -> Result<U, E>,
55        E: From<diesel::result::Error> + std::error::Error,
56        T: Send + 'static,
57        U: Send + 'static,
58        E: Send + 'static,
59    {
60        let max_cost = self.limits.max_db_query_cost;
61        let instant = Instant::now();
62        let pool = self.inner.get_pool();
63        #[allow(unexpected_cfgs)]
64        let result = run_query_async!(&pool, move |conn| txn(&mut PgConnection { max_cost, conn }));
65        self.metrics
66            .observe_db_data(instant.elapsed(), result.is_ok());
67        if let Err(e) = &result {
68            error!("DB query error: {e:?}");
69        }
70        result.map_err(|e| Error::Internal(e.to_string()))
71    }
72
73    async fn execute_repeatable<T, U, E>(&self, txn: T) -> Result<U, Error>
74    where
75        T: FnOnce(&mut Self::DbConnection<'_>) -> Result<U, E>,
76        E: From<diesel::result::Error> + std::error::Error,
77        T: Send + 'static,
78        U: Send + 'static,
79        E: Send + 'static,
80    {
81        let max_cost = self.limits.max_db_query_cost;
82        let instant = Instant::now();
83        let pool = self.inner.get_pool();
84        #[allow(unexpected_cfgs)]
85        let result = run_query_repeatable_async!(&pool, move |conn| txn(&mut PgConnection {
86            max_cost,
87            conn
88        }));
89        self.metrics
90            .observe_db_data(instant.elapsed(), result.is_ok());
91        if let Err(e) = &result {
92            error!("DB query error: {e:?}");
93        }
94        result.map_err(|e| Error::Internal(e.to_string()))
95    }
96}
97
98impl super::DbConnection for PgConnection<'_> {
99    type Connection = diesel::PgConnection;
100    type Backend = Pg;
101
102    fn result<Q, U>(&mut self, query: impl Fn() -> Q) -> QueryResult<U>
103    where
104        Q: diesel::query_builder::Query,
105        Q: LoadQuery<'static, Self::Connection, U>,
106        Q: QueryId + QueryFragment<Self::Backend>,
107    {
108        query_cost::log(self.conn, self.max_cost, query());
109        query().get_result(self.conn)
110    }
111
112    fn results<Q, U>(&mut self, query: impl Fn() -> Q) -> QueryResult<Vec<U>>
113    where
114        Q: diesel::query_builder::Query,
115        Q: LoadQuery<'static, Self::Connection, U>,
116        Q: QueryId + QueryFragment<Self::Backend>,
117    {
118        query_cost::log(self.conn, self.max_cost, query());
119        query().get_results(self.conn)
120    }
121}
122
123impl fmt::Display for ByteaLiteral<'_> {
124    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125        write!(f, "'\\x{}'::bytea", hex::encode(self.0))
126    }
127}
128
129pub(crate) fn bytea_literal(slice: &[u8]) -> ByteaLiteral<'_> {
130    ByteaLiteral(slice)
131}
132
133/// Support for calculating estimated query cost using EXPLAIN and then logging
134/// it.
135mod query_cost {
136    use diesel::{PgConnection, QueryResult, query_builder::AstPass, sql_types::Text};
137    use serde_json::Value;
138    use tap::{TapFallible, TapOptional};
139    use tracing::{debug, info, warn};
140
141    use super::*;
142
143    #[derive(Debug, Clone, Copy, QueryId)]
144    struct Explained<Q> {
145        query: Q,
146    }
147
148    impl<Q: Query> Query for Explained<Q> {
149        type SqlType = Text;
150    }
151
152    impl<Q> RunQueryDsl<PgConnection> for Explained<Q> {}
153
154    impl<Q: QueryFragment<Pg>> QueryFragment<Pg> for Explained<Q> {
155        fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
156            out.push_sql("EXPLAIN (FORMAT JSON) ");
157            self.query.walk_ast(out.reborrow())?;
158            Ok(())
159        }
160    }
161
162    /// Run `EXPLAIN` on the `query`, and log the estimated cost.
163    pub(crate) fn log<Q>(conn: &mut PgConnection, max_db_query_cost: u32, query: Q)
164    where
165        Q: Query + QueryId + QueryFragment<Pg> + RunQueryDsl<PgConnection>,
166    {
167        debug!("Estimating: {}", diesel::debug_query(&query).to_string());
168
169        let Some(cost) = explain(conn, query) else {
170            warn!("Failed to extract cost from EXPLAIN.");
171            return;
172        };
173
174        if cost > max_db_query_cost as f64 {
175            warn!(cost, max_db_query_cost, exceeds = true, "Estimated cost");
176        } else {
177            info!(cost, max_db_query_cost, exceeds = false, "Estimated cost");
178        }
179    }
180
181    pub(crate) fn explain<Q>(conn: &mut PgConnection, query: Q) -> Option<f64>
182    where
183        Q: Query + QueryId + QueryFragment<Pg> + RunQueryDsl<PgConnection>,
184    {
185        let result: String = Explained { query }
186            .get_result(conn)
187            .tap_err(|e| warn!("Failed to run EXPLAIN: {e}"))
188            .ok()?;
189
190        let parsed = serde_json::from_str(&result)
191            .tap_err(|e| warn!("Failed to parse EXPLAIN result: {e}"))
192            .ok()?;
193
194        extract_cost(&parsed).tap_none(|| warn!("Failed to extract cost from EXPLAIN"))
195    }
196
197    fn extract_cost(parsed: &Value) -> Option<f64> {
198        parsed.get(0)?.get("Plan")?.get("Total Cost")?.as_f64()
199    }
200}
201
#[cfg(all(test, feature = "pg_integration"))]
mod tests {
    use diesel::QueryDsl;
    use iota_framework::BuiltInFramework;
    use iota_indexer::{
        db::{get_pool_connection, new_connection_pool, reset_database},
        models::objects::StoredObject,
        schema::objects,
        types::IndexedObject,
    };

    use super::*;
    use crate::config::ConnectionConfig;

    /// Seeds the `objects` table with the built-in system packages. It then
    /// checks that EXPLAIN estimates a `LIMIT 1` scan as cheaper than an
    /// unbounded scan of the same table.
    #[test]
    fn test_query_cost() {
        use objects::dsl;

        let config = ConnectionConfig::default();
        let pool = new_connection_pool(&config.db_url, Some(config.db_pool_size)).unwrap();
        let mut conn = get_pool_connection(&pool).unwrap();
        // Reset first so the test does not depend on leftover state (presumably
        // this recreates the schema — see `reset_database`).
        reset_database(&mut conn).unwrap();

        let stored_objects: Vec<StoredObject> = BuiltInFramework::iter_system_packages()
            .map(|pkg| IndexedObject::from_object(1, pkg.genesis_object(), None).into())
            .collect();
        let expected_rows = stored_objects.len();

        let inserted_rows = diesel::insert_into(dsl::objects)
            .values(stored_objects)
            .execute(&mut conn)
            .unwrap();
        assert_eq!(expected_rows, inserted_rows, "Failed to write objects");

        let limited = dsl::objects.select(dsl::objects.star()).limit(1);
        let unbounded = dsl::objects.select(dsl::objects.star());

        // Test estimating query costs
        let cost_one = query_cost::explain(&mut conn, limited).unwrap();
        let cost_all = query_cost::explain(&mut conn, unbounded).unwrap();

        assert!(
            cost_one < cost_all,
            "cost_one = {cost_one} >= {cost_all} = cost_all"
        );
    }
}