// iota_indexer/db.rs

// Copyright (c) Mysten Labs, Inc.
// Modifications Copyright (c) 2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

//! Types and logic to setup and maintain the database.
//!
//! Creating connections, applying or validating migrations are examples of
//! operations included in this scope.
10use std::{collections::HashSet, time::Duration};
11
12use anyhow::anyhow;
13use clap::Args;
14use diesel::{
15    PgConnection, QueryableByName,
16    connection::BoxableConnection,
17    query_dsl::RunQueryDsl,
18    r2d2::{ConnectionManager, Pool, PooledConnection, R2D2Connection},
19};
20use strum::IntoEnumIterator;
21use tracing::info;
22
23use crate::{errors::IndexerError, pruning::pruner::PrunableTable};
24
/// An r2d2 pool of Postgres connections.
pub type ConnectionPool = Pool<ConnectionManager<PgConnection>>;
/// A single connection checked out of a [`ConnectionPool`].
pub type PoolConnection = PooledConnection<ConnectionManager<PgConnection>>;
27
/// CLI / environment configuration for the Postgres connection pool.
///
/// NOTE(review): the default literals here (`100`, `"30"`, `"3600"`) duplicate
/// the `DEFAULT_*` constants declared on the `impl` block below and must be
/// kept in sync by hand.
#[derive(Args, Debug, Clone)]
pub struct ConnectionPoolConfig {
    /// Maximum number of connections the pool may hold.
    #[arg(long, default_value_t = 100)]
    #[arg(env = "DB_POOL_SIZE")]
    pub pool_size: u32,
    /// How long to wait when checking a connection out of the pool.
    /// Given in whole seconds on the CLI (see `parse_duration`).
    #[arg(long, value_parser = parse_duration, default_value = "30")]
    #[arg(env = "DB_CONNECTION_TIMEOUT")]
    pub connection_timeout: Duration,
    /// Per-session Postgres `statement_timeout`, applied on connection acquire.
    /// Given in whole seconds on the CLI (see `parse_duration`).
    #[arg(long, value_parser = parse_duration, default_value = "3600")]
    #[arg(env = "DB_STATEMENT_TIMEOUT")]
    pub statement_timeout: Duration,
}
40
/// clap value parser: interpret a CLI argument as a number of whole seconds.
fn parse_duration(arg: &str) -> Result<std::time::Duration, std::num::ParseIntError> {
    arg.parse::<u64>().map(std::time::Duration::from_secs)
}
45
impl ConnectionPoolConfig {
    /// Default maximum pool size; mirrors the clap `default_value_t` above.
    pub const DEFAULT_POOL_SIZE: u32 = 100;
    /// Default connection checkout timeout in seconds; mirrors the clap default.
    pub const DEFAULT_CONNECTION_TIMEOUT: u64 = 30;
    /// Default Postgres statement timeout in seconds; mirrors the clap default.
    pub const DEFAULT_STATEMENT_TIMEOUT: u64 = 3600;

    /// Per-connection session settings derived from this pool config.
    /// Pools built from this config are always read-write.
    fn connection_config(&self) -> ConnectionConfig {
        ConnectionConfig {
            statement_timeout: self.statement_timeout,
            read_only: false,
        }
    }

    /// Override the maximum pool size.
    pub fn set_pool_size(&mut self, size: u32) {
        self.pool_size = size;
    }

    /// Override the connection checkout timeout.
    pub fn set_connection_timeout(&mut self, timeout: Duration) {
        self.connection_timeout = timeout;
    }

    /// Override the Postgres statement timeout.
    pub fn set_statement_timeout(&mut self, timeout: Duration) {
        self.statement_timeout = timeout;
    }
}
70
71impl Default for ConnectionPoolConfig {
72    fn default() -> Self {
73        Self {
74            pool_size: Self::DEFAULT_POOL_SIZE,
75            connection_timeout: Duration::from_secs(Self::DEFAULT_CONNECTION_TIMEOUT),
76            statement_timeout: Duration::from_secs(Self::DEFAULT_STATEMENT_TIMEOUT),
77        }
78    }
79}
80
/// Session-level settings applied to every connection as the pool hands it
/// out (see the `CustomizeConnection` impl below).
#[derive(Debug, Clone, Copy)]
pub struct ConnectionConfig {
    /// Value written to the Postgres `statement_timeout` session variable.
    pub statement_timeout: Duration,
    /// When true, the session gets `default_transaction_read_only = 't'`.
    pub read_only: bool,
}
86
87impl<T: R2D2Connection + 'static> diesel::r2d2::CustomizeConnection<T, diesel::r2d2::Error>
88    for ConnectionConfig
89{
90    fn on_acquire(&self, _conn: &mut T) -> std::result::Result<(), diesel::r2d2::Error> {
91        _conn
92            .as_any_mut()
93            .downcast_mut::<diesel::PgConnection>()
94            .map_or_else(
95                || {
96                    Err(diesel::r2d2::Error::QueryError(
97                        diesel::result::Error::DeserializationError(
98                            "failed to downcast connection to PgConnection"
99                                .to_string()
100                                .into(),
101                        ),
102                    ))
103                },
104                |pg_conn| {
105                    diesel::sql_query(format!(
106                        "SET statement_timeout = {}",
107                        self.statement_timeout.as_millis(),
108                    ))
109                    .execute(pg_conn)
110                    .map_err(diesel::r2d2::Error::QueryError)?;
111
112                    if self.read_only {
113                        diesel::sql_query("SET default_transaction_read_only = 't'")
114                            .execute(pg_conn)
115                            .map_err(diesel::r2d2::Error::QueryError)?;
116                    }
117                    Ok(())
118                },
119            )?;
120        Ok(())
121    }
122}
123
124pub fn new_connection_pool(
125    db_url: &str,
126    config: &ConnectionPoolConfig,
127) -> Result<ConnectionPool, IndexerError> {
128    let manager = ConnectionManager::<PgConnection>::new(db_url);
129
130    Pool::builder()
131        .max_size(config.pool_size)
132        .connection_timeout(config.connection_timeout)
133        .connection_customizer(Box::new(config.connection_config()))
134        .build(manager)
135        .map_err(|e| {
136            IndexerError::PgConnectionPoolInit(format!(
137                "failed to initialize connection pool for {db_url} with error: {e:?}"
138            ))
139        })
140}
141
142pub fn get_pool_connection(pool: &ConnectionPool) -> Result<PoolConnection, IndexerError> {
143    pool.get().map_err(|e| {
144        IndexerError::PgPoolConnection(format!(
145            "failed to get connection from PG connection pool with error: {e:?}"
146        ))
147    })
148}
149
150pub fn reset_database(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
151    {
152        conn.as_any_mut()
153            .downcast_mut::<PoolConnection>()
154            .map_or_else(
155                || Err(anyhow!("failed to downcast connection to PgConnection")),
156                |pg_conn| {
157                    setup_postgres::reset_database(pg_conn)?;
158                    Ok(())
159                },
160            )?;
161    }
162    Ok(())
163}
164
165/// Check that prunable tables exist in the database.
166pub async fn check_prunable_tables_valid(conn: &mut PoolConnection) -> Result<(), IndexerError> {
167    info!("Starting compatibility check");
168
169    use diesel::RunQueryDsl;
170
171    let select_parent_tables = r#"
172    SELECT c.relname AS table_name
173    FROM pg_class c
174    JOIN pg_namespace n ON n.oid = c.relnamespace
175    LEFT JOIN pg_partitioned_table pt ON pt.partrelid = c.oid
176    WHERE c.relkind IN ('r', 'p')  -- 'r' for regular tables, 'p' for partitioned tables
177        AND n.nspname = 'public'
178        AND (
179            pt.partrelid IS NOT NULL  -- This is a partitioned (parent) table
180            OR NOT EXISTS (  -- This is not a partition (child table)
181                SELECT 1
182                FROM pg_inherits i
183                WHERE i.inhrelid = c.oid
184            )
185        );
186    "#;
187
188    #[derive(QueryableByName)]
189    struct TableName {
190        #[diesel(sql_type = diesel::sql_types::Text)]
191        table_name: String,
192    }
193
194    let result: Vec<TableName> = diesel::sql_query(select_parent_tables)
195        .load(conn)
196        .map_err(|e| IndexerError::DbMigration(format!("failed to fetch tables: {e}")))?;
197
198    let parent_tables_from_db: HashSet<_> = result.into_iter().map(|t| t.table_name).collect();
199
200    for key in PrunableTable::iter() {
201        if !parent_tables_from_db.contains(key.as_ref()) {
202            return Err(IndexerError::Generic(format!(
203                "invalid retention policy override provided for table {key}: does not exist in the database",
204            )));
205        }
206    }
207
208    info!("Compatibility check passed");
209    Ok(())
210}
211
212pub mod setup_postgres {
213    use anyhow::anyhow;
214    use diesel::{
215        RunQueryDsl,
216        migration::{Migration, MigrationConnection, MigrationSource, MigrationVersion},
217        pg::Pg,
218        prelude::*,
219    };
220    use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
221    use tracing::info;
222
223    use crate::{IndexerError, db::PoolConnection};
224
    // Hand-written schema for diesel's migration bookkeeping table so that
    // applied migration versions can be read with a plain SELECT — without
    // going through MigrationHarness, which implicitly CREATEs the table
    // (see check_db_migration_consistency_impl).
    table! {
        __diesel_schema_migrations (version) {
            version -> VarChar,
            run_on -> Timestamp,
        }
    }

    /// Migration scripts embedded into the binary at compile time.
    const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/pg");
233
234    pub fn reset_database(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
235        info!("Resetting PG database ...");
236
237        let drop_all_tables = "
238        DO $$ DECLARE
239            r RECORD;
240        BEGIN
241        FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
242            LOOP
243                EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
244            END LOOP;
245        END $$;";
246        diesel::sql_query(drop_all_tables).execute(conn)?;
247        info!("Dropped all tables.");
248
249        let drop_all_procedures = "
250        DO $$ DECLARE
251            r RECORD;
252        BEGIN
253            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
254                      FROM pg_proc INNER JOIN pg_namespace ns ON (pg_proc.pronamespace = ns.oid)
255                      WHERE ns.nspname = 'public' AND prokind = 'p')
256            LOOP
257                EXECUTE 'DROP PROCEDURE IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
258            END LOOP;
259        END $$;";
260        diesel::sql_query(drop_all_procedures).execute(conn)?;
261        info!("Dropped all procedures.");
262
263        let drop_all_functions = "
264        DO $$ DECLARE
265            r RECORD;
266        BEGIN
267            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
268                      FROM pg_proc INNER JOIN pg_namespace ON (pg_proc.pronamespace = pg_namespace.oid)
269                      WHERE pg_namespace.nspname = 'public' AND prokind = 'f')
270            LOOP
271                EXECUTE 'DROP FUNCTION IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
272            END LOOP;
273        END $$;";
274        diesel::sql_query(drop_all_functions).execute(conn)?;
275        info!("Dropped all functions.");
276
277        conn.setup()?;
278        info!("Created __diesel_schema_migrations table.");
279
280        run_migrations(conn)?;
281        info!("Reset database complete.");
282        Ok(())
283    }
284
285    /// Execute all unapplied migrations.
286    pub fn run_migrations(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
287        conn.run_pending_migrations(MIGRATIONS)
288            .map_err(|e| anyhow!("failed to run migrations {e}"))?;
289        Ok(())
290    }
291
292    /// Checks that the local migration scripts are a prefix of the records in
293    /// the database. This allows to run migration scripts against a DB at
294    /// any time, without worrying about existing readers failing over.
295    ///
296    /// # Deployment Requirement
297    /// Whenever deploying a new version of either the reader or writer,
298    /// migration scripts **must** be run first. This ensures that there are
299    /// never more local migration scripts than those recorded in the database.
300    ///
301    /// # Backward Compatibility
302    /// All new migrations must be **backward compatible** with the previous
303    /// data model. Do **not** remove or rename columns, tables, or change types
304    /// in a way that would break older versions of the reader or writer.
305    ///
306    /// Only after all services are running the new code and no old versions
307    /// are in use, can you safely remove deprecated fields or make breaking
308    /// changes.
309    ///
310    /// This approach supports rolling upgrades and prevents unnecessary
311    /// failures during deployment.
312    pub fn check_db_migration_consistency(conn: &mut PoolConnection) -> Result<(), IndexerError> {
313        info!("Starting compatibility check");
314        let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().map_err(|err| {
315            IndexerError::DbMigration(format!(
316                "failed to fetch local migrations from schema: {err}"
317            ))
318        })?;
319
320        let local_migrations = migrations
321            .iter()
322            .map(|m| m.name().version())
323            .collect::<Vec<_>>();
324
325        check_db_migration_consistency_impl(conn, local_migrations)?;
326        info!("Compatibility check passed");
327        Ok(())
328    }
329
330    fn check_db_migration_consistency_impl(
331        conn: &mut PoolConnection,
332        local_migrations: Vec<MigrationVersion>,
333    ) -> Result<(), IndexerError> {
334        // Unfortunately we cannot call applied_migrations() directly on the connection,
335        // since it implicitly creates the __diesel_schema_migrations table if it
336        // doesn't exist, which is a write operation that we don't want to do in
337        // this function.
338        let applied_migrations: Vec<MigrationVersion> = __diesel_schema_migrations::table
339            .select(__diesel_schema_migrations::version)
340            .order(__diesel_schema_migrations::version.asc())
341            .load(conn)?;
342
343        // We check that the local migrations is a prefix of the applied migrations.
344        if local_migrations.len() > applied_migrations.len() {
345            return Err(IndexerError::DbMigration(format!(
346                "the number of local migrations is greater than the number of applied migrations. Local migrations: {local_migrations:?}, Applied migrations: {applied_migrations:?}",
347            )));
348        }
349        for (local_migration, applied_migration) in local_migrations.iter().zip(&applied_migrations)
350        {
351            if local_migration != applied_migration {
352                return Err(IndexerError::DbMigration(format!(
353                    "the next applied migration `{applied_migration:?}` diverges from the local migration `{local_migration:?}`",
354                )));
355            }
356        }
357        Ok(())
358    }
359
    // Integration tests: require a live Postgres instance, hence the
    // `pg_integration` feature gate on top of `cfg(test)`.
    #[cfg(feature = "pg_integration")]
    #[cfg(test)]
    mod tests {
        use diesel::{
            migration::{Migration, MigrationSource},
            pg::Pg,
        };
        use diesel_migrations::MigrationHarness;

        use crate::{
            db::setup_postgres::{self, MIGRATIONS},
            test_utils::{TestDatabase, db_url},
        };

        // Check that the migration records in the database created from the local
        // schema pass the consistency check.
        #[test]
        fn db_migration_consistency_smoke_test() {
            let mut database = TestDatabase::new(db_url("db_migration_consistency_smoke_test"));
            database.recreate();
            database.reset_db();
            {
                // Scope the pool so its connections are dropped before the
                // database itself is dropped below.
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();
                setup_postgres::check_db_migration_consistency(&mut conn).unwrap();
            }
            database.drop_if_exists();
        }

        // Reverting the last applied migration leaves the DB with fewer records
        // than local scripts, so the prefix check must fail; re-applying the
        // pending migrations must make it pass again.
        #[test]
        fn db_migration_consistency_non_prefix_test() {
            let mut database =
                TestDatabase::new(db_url("db_migration_consistency_non_prefix_test"));
            database.recreate();
            database.reset_db();
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();
                conn.revert_migration(MIGRATIONS.migrations().unwrap().last().unwrap())
                    .unwrap();
                // Local migrations is one record more than the applied migrations.
                // This will fail the consistency check since it's not a prefix.
                assert!(setup_postgres::check_db_migration_consistency(&mut conn).is_err());

                conn.run_pending_migrations(MIGRATIONS).unwrap();
                // After running pending migrations they should be consistent.
                setup_postgres::check_db_migration_consistency(&mut conn).unwrap();
            }
            database.drop_if_exists();
        }

        // A strict prefix of the applied migrations (one local record fewer)
        // must still pass the consistency check.
        #[test]
        fn db_migration_consistency_prefix_test() {
            let mut database = TestDatabase::new(db_url("db_migration_consistency_prefix_test"));
            database.recreate();
            database.reset_db();
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();

                let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().unwrap();
                let mut local_migrations: Vec<_> =
                    migrations.iter().map(|m| m.name().version()).collect();
                local_migrations.pop();
                // Local migrations is one record less than the applied migrations.
                // This should pass the consistency check since it's still a prefix.
                setup_postgres::check_db_migration_consistency_impl(&mut conn, local_migrations)
                    .unwrap();
            }
            database.drop_if_exists();
        }
    }
432}