// iota_indexer/db.rs

1// Copyright (c) Mysten Labs, Inc.
2// Modifications Copyright (c) 2024 IOTA Stiftung
3// SPDX-License-Identifier: Apache-2.0
4
5use std::time::Duration;
6
7use anyhow::anyhow;
8use clap::Args;
9use diesel::{
10    PgConnection,
11    connection::BoxableConnection,
12    query_dsl::RunQueryDsl,
13    r2d2::{ConnectionManager, Pool, PooledConnection, R2D2Connection},
14};
15
16use crate::errors::IndexerError;
17
18pub type ConnectionPool = Pool<ConnectionManager<PgConnection>>;
19pub type PoolConnection = PooledConnection<ConnectionManager<PgConnection>>;
20
21#[derive(Args, Debug, Clone)]
22pub struct ConnectionPoolConfig {
23    #[arg(long, default_value_t = 100)]
24    #[arg(env = "DB_POOL_SIZE")]
25    pub pool_size: u32,
26    #[arg(long, value_parser = parse_duration, default_value = "30")]
27    #[arg(env = "DB_CONNECTION_TIMEOUT")]
28    pub connection_timeout: Duration,
29    #[arg(long, value_parser = parse_duration, default_value = "3600")]
30    #[arg(env = "DB_STATEMENT_TIMEOUT")]
31    pub statement_timeout: Duration,
32}
33
/// clap value parser: interprets the argument as a whole number of seconds.
fn parse_duration(arg: &str) -> Result<std::time::Duration, std::num::ParseIntError> {
    arg.parse::<u64>().map(std::time::Duration::from_secs)
}
38
39impl ConnectionPoolConfig {
40    pub const DEFAULT_POOL_SIZE: u32 = 100;
41    pub const DEFAULT_CONNECTION_TIMEOUT: u64 = 30;
42    pub const DEFAULT_STATEMENT_TIMEOUT: u64 = 3600;
43
44    fn connection_config(&self) -> ConnectionConfig {
45        ConnectionConfig {
46            statement_timeout: self.statement_timeout,
47            read_only: false,
48        }
49    }
50
51    pub fn set_pool_size(&mut self, size: u32) {
52        self.pool_size = size;
53    }
54
55    pub fn set_connection_timeout(&mut self, timeout: Duration) {
56        self.connection_timeout = timeout;
57    }
58
59    pub fn set_statement_timeout(&mut self, timeout: Duration) {
60        self.statement_timeout = timeout;
61    }
62}
63
64impl Default for ConnectionPoolConfig {
65    fn default() -> Self {
66        Self {
67            pool_size: Self::DEFAULT_POOL_SIZE,
68            connection_timeout: Duration::from_secs(Self::DEFAULT_CONNECTION_TIMEOUT),
69            statement_timeout: Duration::from_secs(Self::DEFAULT_STATEMENT_TIMEOUT),
70        }
71    }
72}
73
/// Per-connection session settings, applied each time a pooled connection is
/// acquired (see the `CustomizeConnection` impl below).
#[derive(Debug, Clone, Copy)]
pub struct ConnectionConfig {
    // Value for Postgres `statement_timeout`; applied in milliseconds.
    pub statement_timeout: Duration,
    // When true, the session gets `default_transaction_read_only = 't'`.
    pub read_only: bool,
}
79
80impl<T: R2D2Connection + 'static> diesel::r2d2::CustomizeConnection<T, diesel::r2d2::Error>
81    for ConnectionConfig
82{
83    fn on_acquire(&self, _conn: &mut T) -> std::result::Result<(), diesel::r2d2::Error> {
84        _conn
85            .as_any_mut()
86            .downcast_mut::<diesel::PgConnection>()
87            .map_or_else(
88                || {
89                    Err(diesel::r2d2::Error::QueryError(
90                        diesel::result::Error::DeserializationError(
91                            "Failed to downcast connection to PgConnection"
92                                .to_string()
93                                .into(),
94                        ),
95                    ))
96                },
97                |pg_conn| {
98                    diesel::sql_query(format!(
99                        "SET statement_timeout = {}",
100                        self.statement_timeout.as_millis(),
101                    ))
102                    .execute(pg_conn)
103                    .map_err(diesel::r2d2::Error::QueryError)?;
104
105                    if self.read_only {
106                        diesel::sql_query("SET default_transaction_read_only = 't'")
107                            .execute(pg_conn)
108                            .map_err(diesel::r2d2::Error::QueryError)?;
109                    }
110                    Ok(())
111                },
112            )?;
113        Ok(())
114    }
115}
116
117pub fn new_connection_pool(
118    db_url: &str,
119    config: &ConnectionPoolConfig,
120) -> Result<ConnectionPool, IndexerError> {
121    let manager = ConnectionManager::<PgConnection>::new(db_url);
122
123    Pool::builder()
124        .max_size(config.pool_size)
125        .connection_timeout(config.connection_timeout)
126        .connection_customizer(Box::new(config.connection_config()))
127        .build(manager)
128        .map_err(|e| {
129            IndexerError::PgConnectionPoolInit(format!(
130                "Failed to initialize connection pool for {db_url} with error: {e:?}"
131            ))
132        })
133}
134
135pub fn get_pool_connection(pool: &ConnectionPool) -> Result<PoolConnection, IndexerError> {
136    pool.get().map_err(|e| {
137        IndexerError::PgPoolConnection(format!(
138            "Failed to get connection from PG connection pool with error: {:?}",
139            e
140        ))
141    })
142}
143
144pub fn reset_database(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
145    {
146        conn.as_any_mut()
147            .downcast_mut::<PoolConnection>()
148            .map_or_else(
149                || Err(anyhow!("Failed to downcast connection to PgConnection")),
150                |pg_conn| {
151                    setup_postgres::reset_database(pg_conn)?;
152                    Ok(())
153                },
154            )?;
155    }
156    Ok(())
157}
158
159pub mod setup_postgres {
160    use anyhow::anyhow;
161    use diesel::{
162        RunQueryDsl,
163        migration::{Migration, MigrationConnection, MigrationSource, MigrationVersion},
164        pg::Pg,
165        prelude::*,
166    };
167    use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
168    use tracing::info;
169
170    use crate::{IndexerError, db::PoolConnection};
171
    // Mirrors Diesel's `__diesel_schema_migrations` bookkeeping table so the
    // applied migration versions can be queried read-only (without the
    // implicit table creation that `applied_migrations()` performs).
    table! {
        __diesel_schema_migrations (version) {
            version -> VarChar,
            run_on -> Timestamp,
        }
    }

    // All SQL migration scripts under `migrations/pg`, embedded into the
    // binary at compile time.
    const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/pg");
180
    /// Completely wipes the `public` schema (tables, procedures, functions)
    /// and re-applies all embedded migrations from scratch.
    ///
    /// Destructive — intended for test / development environments.
    pub fn reset_database(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
        info!("Resetting PG database ...");

        // Drop every table in `public`; CASCADE also removes dependent
        // objects (indexes, constraints, views).
        let drop_all_tables = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
        FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
            LOOP
                EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
            END LOOP;
        END $$;";
        diesel::sql_query(drop_all_tables).execute(conn)?;
        info!("Dropped all tables.");

        // Drop all stored procedures (`prokind = 'p'`) in `public`, matching
        // each by name plus argument types to handle overloads.
        let drop_all_procedures = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
                      FROM pg_proc INNER JOIN pg_namespace ns ON (pg_proc.pronamespace = ns.oid)
                      WHERE ns.nspname = 'public' AND prokind = 'p')
            LOOP
                EXECUTE 'DROP PROCEDURE IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
            END LOOP;
        END $$;";
        diesel::sql_query(drop_all_procedures).execute(conn)?;
        info!("Dropped all procedures.");

        // Drop all plain functions (`prokind = 'f'`) in `public`, again keyed
        // by name plus argument types.
        let drop_all_functions = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
                      FROM pg_proc INNER JOIN pg_namespace ON (pg_proc.pronamespace = pg_namespace.oid)
                      WHERE pg_namespace.nspname = 'public' AND prokind = 'f')
            LOOP
                EXECUTE 'DROP FUNCTION IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
            END LOOP;
        END $$;";
        diesel::sql_query(drop_all_functions).execute(conn)?;
        info!("Dropped all functions.");

        // Recreate Diesel's `__diesel_schema_migrations` bookkeeping table.
        conn.setup()?;
        info!("Created __diesel_schema_migrations table.");

        // Re-apply every embedded migration against the now-empty schema.
        run_migrations(conn)?;
        info!("Reset database complete.");
        Ok(())
    }
231
232    /// Execute all unapplied migrations.
233    pub fn run_migrations(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
234        conn.run_pending_migrations(MIGRATIONS)
235            .map_err(|e| anyhow!("Failed to run migrations {e}"))?;
236        Ok(())
237    }
238
239    /// Checks that the local migration scripts are a prefix of the records in
240    /// the database. This allows to run migration scripts against a DB at
241    /// any time, without worrying about existing readers failing over.
242    ///
243    /// # Deployment Requirement
244    /// Whenever deploying a new version of either the reader or writer,
245    /// migration scripts **must** be run first. This ensures that there are
246    /// never more local migration scripts than those recorded in the database.
247    ///
248    /// # Backward Compatibility
249    /// All new migrations must be **backward compatible** with the previous
250    /// data model. Do **not** remove or rename columns, tables, or change types
251    /// in a way that would break older versions of the reader or writer.
252    ///
253    /// Only after all services are running the new code and no old versions
254    /// are in use, can you safely remove deprecated fields or make breaking
255    /// changes.
256    ///
257    /// This approach supports rolling upgrades and prevents unnecessary
258    /// failures during deployment.
259    pub fn check_db_migration_consistency(conn: &mut PoolConnection) -> Result<(), IndexerError> {
260        info!("Starting compatibility check");
261        let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().map_err(|err| {
262            IndexerError::DbMigration(format!(
263                "Failed to fetch local migrations from schema: {err}"
264            ))
265        })?;
266
267        let local_migrations = migrations
268            .iter()
269            .map(|m| m.name().version())
270            .collect::<Vec<_>>();
271
272        check_db_migration_consistency_impl(conn, local_migrations)?;
273        info!("Compatibility check passed");
274        Ok(())
275    }
276
277    fn check_db_migration_consistency_impl(
278        conn: &mut PoolConnection,
279        local_migrations: Vec<MigrationVersion>,
280    ) -> Result<(), IndexerError> {
281        // Unfortunately we cannot call applied_migrations() directly on the connection,
282        // since it implicitly creates the __diesel_schema_migrations table if it
283        // doesn't exist, which is a write operation that we don't want to do in
284        // this function.
285        let applied_migrations: Vec<MigrationVersion> = __diesel_schema_migrations::table
286            .select(__diesel_schema_migrations::version)
287            .order(__diesel_schema_migrations::version.asc())
288            .load(conn)?;
289
290        // We check that the local migrations is a prefix of the applied migrations.
291        if local_migrations.len() > applied_migrations.len() {
292            return Err(IndexerError::DbMigration(format!(
293                "The number of local migrations is greater than the number of applied migrations. Local migrations: {local_migrations:?}, Applied migrations: {applied_migrations:?}",
294            )));
295        }
296        for (local_migration, applied_migration) in local_migrations.iter().zip(&applied_migrations)
297        {
298            if local_migration != applied_migration {
299                return Err(IndexerError::DbMigration(format!(
300                    "The next applied migration `{applied_migration:?}` diverges from the local migration `{local_migration:?}`",
301                )));
302            }
303        }
304        Ok(())
305    }
306
    // Integration tests: require a live local Postgres (enabled via the
    // `pg_integration` feature) and the crate's `TestDatabase` helper.
    #[cfg(feature = "pg_integration")]
    #[cfg(test)]
    mod tests {
        use diesel::{
            migration::{Migration, MigrationSource},
            pg::Pg,
        };
        use diesel_migrations::MigrationHarness;

        use crate::{
            db::setup_postgres::{self, MIGRATIONS},
            test_utils::TestDatabase,
        };

        // Connection URL for a throwaway test database.
        // NOTE(review): host/credentials are hardcoded for the local
        // `pg_integration` environment — confirm before reusing elsewhere.
        fn database_url(db_name: &str) -> String {
            format!("postgres://postgres:postgrespw@localhost:5432/{db_name}")
        }

        // Check that the migration records in the database created from the local
        // schema pass the consistency check.
        #[test]
        fn db_migration_consistency_smoke_test() {
            let mut database =
                TestDatabase::new(database_url("db_migration_consistency_smoke_test"));
            database.recreate();
            database.reset_db();
            {
                // Inner scope drops the pool (and its connections) before the
                // database itself is dropped below.
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();
                setup_postgres::check_db_migration_consistency(&mut conn).unwrap();
            }
            database.drop_if_exists();
        }

        // Reverting the newest applied migration leaves the DB one record
        // short of the local scripts — the prefix check must fail.
        #[test]
        fn db_migration_consistency_non_prefix_test() {
            let mut database =
                TestDatabase::new(database_url("db_migration_consistency_non_prefix_test"));
            database.recreate();
            database.reset_db();
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();
                conn.revert_migration(MIGRATIONS.migrations().unwrap().last().unwrap())
                    .unwrap();
                // Local migrations is one record more than the applied migrations.
                // This will fail the consistency check since it's not a prefix.
                assert!(setup_postgres::check_db_migration_consistency(&mut conn).is_err());

                conn.run_pending_migrations(MIGRATIONS).unwrap();
                // After running pending migrations they should be consistent.
                setup_postgres::check_db_migration_consistency(&mut conn).unwrap();
            }
            database.drop_if_exists();
        }

        // Dropping the last local migration makes the local list a strict
        // prefix of the applied records — the check must still pass.
        #[test]
        fn db_migration_consistency_prefix_test() {
            let mut database =
                TestDatabase::new(database_url("db_migration_consistency_prefix_test"));
            database.recreate();
            database.reset_db();
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();

                let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().unwrap();
                let mut local_migrations: Vec<_> =
                    migrations.iter().map(|m| m.name().version()).collect();
                local_migrations.pop();
                // Local migrations is one record less than the applied migrations.
                // This should pass the consistency check since it's still a prefix.
                setup_postgres::check_db_migration_consistency_impl(&mut conn, local_migrations)
                    .unwrap();
            }
            database.drop_if_exists();
        }
    }
385}