1use std::{collections::HashSet, time::Duration};
6
7use anyhow::anyhow;
8use clap::Args;
9use diesel::{
10 PgConnection, QueryableByName,
11 connection::BoxableConnection,
12 query_dsl::RunQueryDsl,
13 r2d2::{ConnectionManager, Pool, PooledConnection, R2D2Connection},
14};
15use strum::IntoEnumIterator;
16use tracing::info;
17
18use crate::{errors::IndexerError, handlers::pruner::PrunableTable};
19
/// Shared r2d2 pool of Diesel `PgConnection`s.
pub type ConnectionPool = Pool<ConnectionManager<PgConnection>>;
/// A single connection checked out of a [`ConnectionPool`].
pub type PoolConnection = PooledConnection<ConnectionManager<PgConnection>>;
22
// CLI / environment configuration for the Postgres connection pool.
// NOTE: plain `//` comments are used on purpose — `///` doc comments on a
// `#[derive(Args)]` struct would leak into the generated clap help text.
#[derive(Args, Debug, Clone)]
pub struct ConnectionPoolConfig {
    // Maximum number of simultaneous connections held by the pool.
    #[arg(long, default_value_t = 100)]
    #[arg(env = "DB_POOL_SIZE")]
    pub pool_size: u32,
    // How long to wait when checking a connection out of the pool.
    // Given as whole seconds on the CLI/env and parsed via `parse_duration`.
    #[arg(long, value_parser = parse_duration, default_value = "30")]
    #[arg(env = "DB_CONNECTION_TIMEOUT")]
    pub connection_timeout: Duration,
    // Server-side `statement_timeout` applied to each pooled connection on
    // checkout (see `ConnectionConfig::on_acquire`). Whole seconds on the CLI.
    #[arg(long, value_parser = parse_duration, default_value = "3600")]
    #[arg(env = "DB_STATEMENT_TIMEOUT")]
    pub statement_timeout: Duration,
}
35
/// clap value parser: interprets the argument as a whole number of seconds.
///
/// # Errors
///
/// Returns the underlying `ParseIntError` when `arg` is not a valid `u64`.
fn parse_duration(arg: &str) -> Result<std::time::Duration, std::num::ParseIntError> {
    arg.parse::<u64>().map(std::time::Duration::from_secs)
}
40
41impl ConnectionPoolConfig {
42 pub const DEFAULT_POOL_SIZE: u32 = 100;
43 pub const DEFAULT_CONNECTION_TIMEOUT: u64 = 30;
44 pub const DEFAULT_STATEMENT_TIMEOUT: u64 = 3600;
45
46 fn connection_config(&self) -> ConnectionConfig {
47 ConnectionConfig {
48 statement_timeout: self.statement_timeout,
49 read_only: false,
50 }
51 }
52
53 pub fn set_pool_size(&mut self, size: u32) {
54 self.pool_size = size;
55 }
56
57 pub fn set_connection_timeout(&mut self, timeout: Duration) {
58 self.connection_timeout = timeout;
59 }
60
61 pub fn set_statement_timeout(&mut self, timeout: Duration) {
62 self.statement_timeout = timeout;
63 }
64}
65
66impl Default for ConnectionPoolConfig {
67 fn default() -> Self {
68 Self {
69 pool_size: Self::DEFAULT_POOL_SIZE,
70 connection_timeout: Duration::from_secs(Self::DEFAULT_CONNECTION_TIMEOUT),
71 statement_timeout: Duration::from_secs(Self::DEFAULT_STATEMENT_TIMEOUT),
72 }
73 }
74}
75
/// Session settings applied to every connection when the pool hands it out;
/// see the `CustomizeConnection` impl below.
#[derive(Debug, Clone, Copy)]
pub struct ConnectionConfig {
    /// Server-side `statement_timeout` to set on the session.
    pub statement_timeout: Duration,
    /// When true, the session defaults to read-only transactions.
    pub read_only: bool,
}
81
82impl<T: R2D2Connection + 'static> diesel::r2d2::CustomizeConnection<T, diesel::r2d2::Error>
83 for ConnectionConfig
84{
85 fn on_acquire(&self, _conn: &mut T) -> std::result::Result<(), diesel::r2d2::Error> {
86 _conn
87 .as_any_mut()
88 .downcast_mut::<diesel::PgConnection>()
89 .map_or_else(
90 || {
91 Err(diesel::r2d2::Error::QueryError(
92 diesel::result::Error::DeserializationError(
93 "Failed to downcast connection to PgConnection"
94 .to_string()
95 .into(),
96 ),
97 ))
98 },
99 |pg_conn| {
100 diesel::sql_query(format!(
101 "SET statement_timeout = {}",
102 self.statement_timeout.as_millis(),
103 ))
104 .execute(pg_conn)
105 .map_err(diesel::r2d2::Error::QueryError)?;
106
107 if self.read_only {
108 diesel::sql_query("SET default_transaction_read_only = 't'")
109 .execute(pg_conn)
110 .map_err(diesel::r2d2::Error::QueryError)?;
111 }
112 Ok(())
113 },
114 )?;
115 Ok(())
116 }
117}
118
119pub fn new_connection_pool(
120 db_url: &str,
121 config: &ConnectionPoolConfig,
122) -> Result<ConnectionPool, IndexerError> {
123 let manager = ConnectionManager::<PgConnection>::new(db_url);
124
125 Pool::builder()
126 .max_size(config.pool_size)
127 .connection_timeout(config.connection_timeout)
128 .connection_customizer(Box::new(config.connection_config()))
129 .build(manager)
130 .map_err(|e| {
131 IndexerError::PgConnectionPoolInit(format!(
132 "Failed to initialize connection pool for {db_url} with error: {e:?}"
133 ))
134 })
135}
136
137pub fn get_pool_connection(pool: &ConnectionPool) -> Result<PoolConnection, IndexerError> {
138 pool.get().map_err(|e| {
139 IndexerError::PgPoolConnection(format!(
140 "Failed to get connection from PG connection pool with error: {e:?}"
141 ))
142 })
143}
144
145pub fn reset_database(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
146 {
147 conn.as_any_mut()
148 .downcast_mut::<PoolConnection>()
149 .map_or_else(
150 || Err(anyhow!("Failed to downcast connection to PgConnection")),
151 |pg_conn| {
152 setup_postgres::reset_database(pg_conn)?;
153 Ok(())
154 },
155 )?;
156 }
157 Ok(())
158}
159
160pub async fn check_prunable_tables_valid(conn: &mut PoolConnection) -> Result<(), IndexerError> {
162 info!("Starting compatibility check");
163
164 use diesel::RunQueryDsl;
165
166 let select_parent_tables = r#"
167 SELECT c.relname AS table_name
168 FROM pg_class c
169 JOIN pg_namespace n ON n.oid = c.relnamespace
170 LEFT JOIN pg_partitioned_table pt ON pt.partrelid = c.oid
171 WHERE c.relkind IN ('r', 'p') -- 'r' for regular tables, 'p' for partitioned tables
172 AND n.nspname = 'public'
173 AND (
174 pt.partrelid IS NOT NULL -- This is a partitioned (parent) table
175 OR NOT EXISTS ( -- This is not a partition (child table)
176 SELECT 1
177 FROM pg_inherits i
178 WHERE i.inhrelid = c.oid
179 )
180 );
181 "#;
182
183 #[derive(QueryableByName)]
184 struct TableName {
185 #[diesel(sql_type = diesel::sql_types::Text)]
186 table_name: String,
187 }
188
189 let result: Vec<TableName> = diesel::sql_query(select_parent_tables)
190 .load(conn)
191 .map_err(|e| IndexerError::DbMigration(format!("Failed to fetch tables: {e}")))?;
192
193 let parent_tables_from_db: HashSet<_> = result.into_iter().map(|t| t.table_name).collect();
194
195 for key in PrunableTable::iter() {
196 if !parent_tables_from_db.contains(key.as_ref()) {
197 return Err(IndexerError::Generic(format!(
198 "Invalid retention policy override provided for table {key}: does not exist in the database",
199 )));
200 }
201 }
202
203 info!("Compatibility check passed");
204 Ok(())
205}
206
207pub mod setup_postgres {
208 use anyhow::anyhow;
209 use diesel::{
210 RunQueryDsl,
211 migration::{Migration, MigrationConnection, MigrationSource, MigrationVersion},
212 pg::Pg,
213 prelude::*,
214 };
215 use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
216 use tracing::info;
217
218 use crate::{IndexerError, db::PoolConnection};
219
    // Minimal Diesel schema for the migration bookkeeping table, so applied
    // migration versions can be queried in `check_db_migration_consistency_impl`.
    table! {
        __diesel_schema_migrations (version) {
            version -> VarChar,
            run_on -> Timestamp,
        }
    }
226
227 const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/pg");
228
    /// Drops every table, procedure, and function in the `public` schema, then
    /// recreates Diesel's migration bookkeeping table and re-runs all embedded
    /// migrations. Destructive — intended for tests and local development.
    ///
    /// # Errors
    ///
    /// Returns an error if any DROP batch, the migration-table setup, or a
    /// migration fails.
    pub fn reset_database(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
        info!("Resetting PG database ...");

        // PL/pgSQL DO-block: drop each public table; CASCADE also removes
        // dependents (views, foreign keys, ...). quote_ident guards names.
        let drop_all_tables = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
        FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
            LOOP
                EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
            END LOOP;
        END $$;";
        diesel::sql_query(drop_all_tables).execute(conn)?;
        info!("Dropped all tables.");

        // Drop stored procedures (prokind = 'p'); the argument-type list is
        // required to name an overloaded routine unambiguously.
        let drop_all_procedures = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
                      FROM pg_proc INNER JOIN pg_namespace ns ON (pg_proc.pronamespace = ns.oid)
                      WHERE ns.nspname = 'public' AND prokind = 'p')
            LOOP
                EXECUTE 'DROP PROCEDURE IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
            END LOOP;
        END $$;";
        diesel::sql_query(drop_all_procedures).execute(conn)?;
        info!("Dropped all procedures.");

        // Drop plain functions (prokind = 'f'), same overload handling.
        let drop_all_functions = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
                      FROM pg_proc INNER JOIN pg_namespace ON (pg_proc.pronamespace = pg_namespace.oid)
                      WHERE pg_namespace.nspname = 'public' AND prokind = 'f')
            LOOP
                EXECUTE 'DROP FUNCTION IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
            END LOOP;
        END $$;";
        diesel::sql_query(drop_all_functions).execute(conn)?;
        info!("Dropped all functions.");

        // Recreate `__diesel_schema_migrations` before replaying migrations.
        conn.setup()?;
        info!("Created __diesel_schema_migrations table.");

        run_migrations(conn)?;
        info!("Reset database complete.");
        Ok(())
    }
279
280 pub fn run_migrations(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
282 conn.run_pending_migrations(MIGRATIONS)
283 .map_err(|e| anyhow!("Failed to run migrations {e}"))?;
284 Ok(())
285 }
286
287 pub fn check_db_migration_consistency(conn: &mut PoolConnection) -> Result<(), IndexerError> {
308 info!("Starting compatibility check");
309 let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().map_err(|err| {
310 IndexerError::DbMigration(format!(
311 "Failed to fetch local migrations from schema: {err}"
312 ))
313 })?;
314
315 let local_migrations = migrations
316 .iter()
317 .map(|m| m.name().version())
318 .collect::<Vec<_>>();
319
320 check_db_migration_consistency_impl(conn, local_migrations)?;
321 info!("Compatibility check passed");
322 Ok(())
323 }
324
325 fn check_db_migration_consistency_impl(
326 conn: &mut PoolConnection,
327 local_migrations: Vec<MigrationVersion>,
328 ) -> Result<(), IndexerError> {
329 let applied_migrations: Vec<MigrationVersion> = __diesel_schema_migrations::table
334 .select(__diesel_schema_migrations::version)
335 .order(__diesel_schema_migrations::version.asc())
336 .load(conn)?;
337
338 if local_migrations.len() > applied_migrations.len() {
340 return Err(IndexerError::DbMigration(format!(
341 "The number of local migrations is greater than the number of applied migrations. Local migrations: {local_migrations:?}, Applied migrations: {applied_migrations:?}",
342 )));
343 }
344 for (local_migration, applied_migration) in local_migrations.iter().zip(&applied_migrations)
345 {
346 if local_migration != applied_migration {
347 return Err(IndexerError::DbMigration(format!(
348 "The next applied migration `{applied_migration:?}` diverges from the local migration `{local_migration:?}`",
349 )));
350 }
351 }
362 Ok(())
363 }
364
    // Integration tests: require a live Postgres instance, so they only run
    // with the `pg_integration` feature enabled.
    #[cfg(feature = "pg_integration")]
    #[cfg(test)]
    mod tests {
        use diesel::{
            migration::{Migration, MigrationSource},
            pg::Pg,
        };
        use diesel_migrations::MigrationHarness;

        use crate::{
            db::setup_postgres::{self, MIGRATIONS},
            test_utils::{TestDatabase, db_url},
        };

        // Fresh database + full migrations should pass the consistency check.
        #[test]
        fn db_migration_consistency_smoke_test() {
            let mut database = TestDatabase::new(db_url("db_migration_consistency_smoke_test"));
            database.recreate();
            database.reset_db();
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();
                setup_postgres::check_db_migration_consistency(&mut conn).unwrap();
            }
            database.drop_if_exists();
        }

        // Reverting the last applied migration makes the database lag behind
        // the local schema: the check must fail, then pass again after the
        // migration is re-applied.
        #[test]
        fn db_migration_consistency_non_prefix_test() {
            let mut database =
                TestDatabase::new(db_url("db_migration_consistency_non_prefix_test"));
            database.recreate();
            database.reset_db();
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();
                conn.revert_migration(MIGRATIONS.migrations().unwrap().last().unwrap())
                    .unwrap();
                assert!(setup_postgres::check_db_migration_consistency(&mut conn).is_err());

                conn.run_pending_migrations(MIGRATIONS).unwrap();
                setup_postgres::check_db_migration_consistency(&mut conn).unwrap();
            }
            database.drop_if_exists();
        }

        // The local schema being a strict prefix of the applied migrations
        // (database ahead of the binary) is allowed.
        #[test]
        fn db_migration_consistency_prefix_test() {
            let mut database = TestDatabase::new(db_url("db_migration_consistency_prefix_test"));
            database.recreate();
            database.reset_db();
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();

                let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().unwrap();
                let mut local_migrations: Vec<_> =
                    migrations.iter().map(|m| m.name().version()).collect();
                local_migrations.pop();
                setup_postgres::check_db_migration_consistency_impl(&mut conn, local_migrations)
                    .unwrap();
            }
            database.drop_if_exists();
        }
    }
437}