1use std::{collections::HashSet, time::Duration};
11
12use anyhow::anyhow;
13use clap::Args;
14use diesel::{
15 PgConnection, QueryableByName,
16 connection::BoxableConnection,
17 query_dsl::RunQueryDsl,
18 r2d2::{ConnectionManager, Pool, PooledConnection, R2D2Connection},
19};
20use strum::IntoEnumIterator;
21use tracing::info;
22
23use crate::{errors::IndexerError, pruning::pruner::PrunableTable};
24
/// An r2d2 pool of Postgres connections.
pub type ConnectionPool = Pool<ConnectionManager<PgConnection>>;
/// A single Postgres connection checked out of a [`ConnectionPool`].
pub type PoolConnection = PooledConnection<ConnectionManager<PgConnection>>;
27
/// CLI/environment-configurable settings for the Postgres connection pool.
///
/// Timeouts are given on the command line / in env vars as whole seconds and
/// parsed into [`Duration`]s by `parse_duration`.
// NOTE(review): plain `//` comments are used on fields (not `///`) so that
// clap's derived help text is unchanged. The literal defaults below should
// stay in sync with the DEFAULT_* consts in the impl — TODO confirm there is
// no single source of truth intended.
#[derive(Args, Debug, Clone)]
pub struct ConnectionPoolConfig {
    // Maximum number of connections held by the pool.
    #[arg(long, default_value_t = 100)]
    #[arg(env = "DB_POOL_SIZE")]
    pub pool_size: u32,
    // How long to wait when acquiring a connection from the pool, in seconds.
    #[arg(long, value_parser = parse_duration, default_value = "30")]
    #[arg(env = "DB_CONNECTION_TIMEOUT")]
    pub connection_timeout: Duration,
    // Postgres statement timeout applied to each pooled connection, in seconds.
    #[arg(long, value_parser = parse_duration, default_value = "3600")]
    #[arg(env = "DB_STATEMENT_TIMEOUT")]
    pub statement_timeout: Duration,
}
40
/// Parses a whole number of seconds (e.g. `"30"`) into a [`Duration`].
///
/// Returns the underlying [`std::num::ParseIntError`] when `arg` is not a
/// valid non-negative integer.
fn parse_duration(arg: &str) -> Result<std::time::Duration, std::num::ParseIntError> {
    arg.parse::<u64>().map(std::time::Duration::from_secs)
}
45
46impl ConnectionPoolConfig {
47 pub const DEFAULT_POOL_SIZE: u32 = 100;
48 pub const DEFAULT_CONNECTION_TIMEOUT: u64 = 30;
49 pub const DEFAULT_STATEMENT_TIMEOUT: u64 = 3600;
50
51 fn connection_config(&self) -> ConnectionConfig {
52 ConnectionConfig {
53 statement_timeout: self.statement_timeout,
54 read_only: false,
55 }
56 }
57
58 pub fn set_pool_size(&mut self, size: u32) {
59 self.pool_size = size;
60 }
61
62 pub fn set_connection_timeout(&mut self, timeout: Duration) {
63 self.connection_timeout = timeout;
64 }
65
66 pub fn set_statement_timeout(&mut self, timeout: Duration) {
67 self.statement_timeout = timeout;
68 }
69}
70
71impl Default for ConnectionPoolConfig {
72 fn default() -> Self {
73 Self {
74 pool_size: Self::DEFAULT_POOL_SIZE,
75 connection_timeout: Duration::from_secs(Self::DEFAULT_CONNECTION_TIMEOUT),
76 statement_timeout: Duration::from_secs(Self::DEFAULT_STATEMENT_TIMEOUT),
77 }
78 }
79}
80
/// Session settings applied to each connection as the pool hands it out.
#[derive(Debug, Clone, Copy)]
pub struct ConnectionConfig {
    /// Value for Postgres `statement_timeout` (applied in milliseconds).
    pub statement_timeout: Duration,
    /// When true, `default_transaction_read_only` is enabled on the session.
    pub read_only: bool,
}
86
87impl<T: R2D2Connection + 'static> diesel::r2d2::CustomizeConnection<T, diesel::r2d2::Error>
88 for ConnectionConfig
89{
90 fn on_acquire(&self, _conn: &mut T) -> std::result::Result<(), diesel::r2d2::Error> {
91 _conn
92 .as_any_mut()
93 .downcast_mut::<diesel::PgConnection>()
94 .map_or_else(
95 || {
96 Err(diesel::r2d2::Error::QueryError(
97 diesel::result::Error::DeserializationError(
98 "failed to downcast connection to PgConnection"
99 .to_string()
100 .into(),
101 ),
102 ))
103 },
104 |pg_conn| {
105 diesel::sql_query(format!(
106 "SET statement_timeout = {}",
107 self.statement_timeout.as_millis(),
108 ))
109 .execute(pg_conn)
110 .map_err(diesel::r2d2::Error::QueryError)?;
111
112 if self.read_only {
113 diesel::sql_query("SET default_transaction_read_only = 't'")
114 .execute(pg_conn)
115 .map_err(diesel::r2d2::Error::QueryError)?;
116 }
117 Ok(())
118 },
119 )?;
120 Ok(())
121 }
122}
123
124pub fn new_connection_pool(
125 db_url: &str,
126 config: &ConnectionPoolConfig,
127) -> Result<ConnectionPool, IndexerError> {
128 let manager = ConnectionManager::<PgConnection>::new(db_url);
129
130 Pool::builder()
131 .max_size(config.pool_size)
132 .connection_timeout(config.connection_timeout)
133 .connection_customizer(Box::new(config.connection_config()))
134 .build(manager)
135 .map_err(|e| {
136 IndexerError::PgConnectionPoolInit(format!(
137 "failed to initialize connection pool for {db_url} with error: {e:?}"
138 ))
139 })
140}
141
142pub fn get_pool_connection(pool: &ConnectionPool) -> Result<PoolConnection, IndexerError> {
143 pool.get().map_err(|e| {
144 IndexerError::PgPoolConnection(format!(
145 "failed to get connection from PG connection pool with error: {e:?}"
146 ))
147 })
148}
149
150pub fn reset_database(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
151 {
152 conn.as_any_mut()
153 .downcast_mut::<PoolConnection>()
154 .map_or_else(
155 || Err(anyhow!("failed to downcast connection to PgConnection")),
156 |pg_conn| {
157 setup_postgres::reset_database(pg_conn)?;
158 Ok(())
159 },
160 )?;
161 }
162 Ok(())
163}
164
165pub async fn check_prunable_tables_valid(conn: &mut PoolConnection) -> Result<(), IndexerError> {
167 info!("Starting compatibility check");
168
169 use diesel::RunQueryDsl;
170
171 let select_parent_tables = r#"
172 SELECT c.relname AS table_name
173 FROM pg_class c
174 JOIN pg_namespace n ON n.oid = c.relnamespace
175 LEFT JOIN pg_partitioned_table pt ON pt.partrelid = c.oid
176 WHERE c.relkind IN ('r', 'p') -- 'r' for regular tables, 'p' for partitioned tables
177 AND n.nspname = 'public'
178 AND (
179 pt.partrelid IS NOT NULL -- This is a partitioned (parent) table
180 OR NOT EXISTS ( -- This is not a partition (child table)
181 SELECT 1
182 FROM pg_inherits i
183 WHERE i.inhrelid = c.oid
184 )
185 );
186 "#;
187
188 #[derive(QueryableByName)]
189 struct TableName {
190 #[diesel(sql_type = diesel::sql_types::Text)]
191 table_name: String,
192 }
193
194 let result: Vec<TableName> = diesel::sql_query(select_parent_tables)
195 .load(conn)
196 .map_err(|e| IndexerError::DbMigration(format!("failed to fetch tables: {e}")))?;
197
198 let parent_tables_from_db: HashSet<_> = result.into_iter().map(|t| t.table_name).collect();
199
200 for key in PrunableTable::iter() {
201 if !parent_tables_from_db.contains(key.as_ref()) {
202 return Err(IndexerError::Generic(format!(
203 "invalid retention policy override provided for table {key}: does not exist in the database",
204 )));
205 }
206 }
207
208 info!("Compatibility check passed");
209 Ok(())
210}
211
212pub mod setup_postgres {
213 use anyhow::anyhow;
214 use diesel::{
215 RunQueryDsl,
216 migration::{Migration, MigrationConnection, MigrationSource, MigrationVersion},
217 pg::Pg,
218 prelude::*,
219 };
220 use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
221 use tracing::info;
222
223 use crate::{IndexerError, db::PoolConnection};
224
    // Minimal diesel view of diesel's own migration bookkeeping table, used
    // below to query which migration versions have already been applied.
    table! {
        __diesel_schema_migrations (version) {
            version -> VarChar,
            run_on -> Timestamp,
        }
    }

    // All SQL migrations under `migrations/pg`, embedded into the binary at
    // compile time.
    const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/pg");
233
234 pub fn reset_database(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
235 info!("Resetting PG database ...");
236
237 let drop_all_tables = "
238 DO $$ DECLARE
239 r RECORD;
240 BEGIN
241 FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
242 LOOP
243 EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
244 END LOOP;
245 END $$;";
246 diesel::sql_query(drop_all_tables).execute(conn)?;
247 info!("Dropped all tables.");
248
249 let drop_all_procedures = "
250 DO $$ DECLARE
251 r RECORD;
252 BEGIN
253 FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
254 FROM pg_proc INNER JOIN pg_namespace ns ON (pg_proc.pronamespace = ns.oid)
255 WHERE ns.nspname = 'public' AND prokind = 'p')
256 LOOP
257 EXECUTE 'DROP PROCEDURE IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
258 END LOOP;
259 END $$;";
260 diesel::sql_query(drop_all_procedures).execute(conn)?;
261 info!("Dropped all procedures.");
262
263 let drop_all_functions = "
264 DO $$ DECLARE
265 r RECORD;
266 BEGIN
267 FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
268 FROM pg_proc INNER JOIN pg_namespace ON (pg_proc.pronamespace = pg_namespace.oid)
269 WHERE pg_namespace.nspname = 'public' AND prokind = 'f')
270 LOOP
271 EXECUTE 'DROP FUNCTION IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
272 END LOOP;
273 END $$;";
274 diesel::sql_query(drop_all_functions).execute(conn)?;
275 info!("Dropped all functions.");
276
277 conn.setup()?;
278 info!("Created __diesel_schema_migrations table.");
279
280 run_migrations(conn)?;
281 info!("Reset database complete.");
282 Ok(())
283 }
284
285 pub fn run_migrations(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
287 let pending_migrations = conn
288 .pending_migrations(MIGRATIONS)
289 .map_err(|e| anyhow!("failed to identify pending migrations {e}"))?;
290 for migration in pending_migrations {
291 info!("Applying migration {}", migration.name());
292 conn.run_migration(&migration)
293 .map_err(|e| anyhow!("failed to run migration {e}"))?;
294 }
295 Ok(())
296 }
297
298 pub fn check_db_migration_consistency(conn: &mut PoolConnection) -> Result<(), IndexerError> {
319 info!("Starting compatibility check");
320 let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().map_err(|err| {
321 IndexerError::DbMigration(format!(
322 "failed to fetch local migrations from schema: {err}"
323 ))
324 })?;
325
326 let local_migrations = migrations
327 .iter()
328 .map(|m| m.name().version())
329 .collect::<Vec<_>>();
330
331 check_db_migration_consistency_impl(conn, local_migrations)?;
332 info!("Compatibility check passed");
333 Ok(())
334 }
335
336 fn check_db_migration_consistency_impl(
337 conn: &mut PoolConnection,
338 local_migrations: Vec<MigrationVersion>,
339 ) -> Result<(), IndexerError> {
340 let applied_migrations: Vec<MigrationVersion> = __diesel_schema_migrations::table
345 .select(__diesel_schema_migrations::version)
346 .order(__diesel_schema_migrations::version.asc())
347 .load(conn)?;
348
349 if local_migrations.len() > applied_migrations.len() {
351 return Err(IndexerError::DbMigration(format!(
352 "the number of local migrations is greater than the number of applied migrations. Local migrations: {local_migrations:?}, Applied migrations: {applied_migrations:?}",
353 )));
354 }
355 for (local_migration, applied_migration) in local_migrations.iter().zip(&applied_migrations)
356 {
357 if local_migration != applied_migration {
358 return Err(IndexerError::DbMigration(format!(
359 "the next applied migration `{applied_migration:?}` diverges from the local migration `{local_migration:?}`",
360 )));
361 }
362 }
363 Ok(())
364 }
365
    // Integration tests: each creates a throwaway database via the project's
    // `TestDatabase` harness, so they only run with the `pg_integration`
    // feature (a live Postgres instance is required).
    #[cfg(feature = "pg_integration")]
    #[cfg(test)]
    mod tests {
        use diesel::{
            migration::{Migration, MigrationSource},
            pg::Pg,
        };
        use diesel_migrations::MigrationHarness;

        use crate::{
            db::setup_postgres::{self, MIGRATIONS},
            test_utils::{TestDatabase, db_url},
        };

        // Happy path: a freshly reset (fully migrated) database passes the
        // consistency check.
        #[test]
        fn db_migration_consistency_smoke_test() {
            let mut database = TestDatabase::new(db_url("db_migration_consistency_smoke_test"));
            database.recreate();
            database.reset_db();
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();
                setup_postgres::check_db_migration_consistency(&mut conn).unwrap();
            }
            database.drop_if_exists();
        }

        // Reverting the latest migration makes the applied set NOT a superset
        // of the local set, so the check must fail; re-applying it must make
        // the check pass again.
        #[test]
        fn db_migration_consistency_non_prefix_test() {
            let mut database =
                TestDatabase::new(db_url("db_migration_consistency_non_prefix_test"));
            database.recreate();
            database.reset_db();
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();
                conn.revert_migration(MIGRATIONS.migrations().unwrap().last().unwrap())
                    .unwrap();
                assert!(setup_postgres::check_db_migration_consistency(&mut conn).is_err());

                conn.run_pending_migrations(MIGRATIONS).unwrap();
                setup_postgres::check_db_migration_consistency(&mut conn).unwrap();
            }
            database.drop_if_exists();
        }

        // A strict prefix of the local migrations is allowed: the database
        // may be ahead of the binary.
        #[test]
        fn db_migration_consistency_prefix_test() {
            let mut database = TestDatabase::new(db_url("db_migration_consistency_prefix_test"));
            database.recreate();
            database.reset_db();
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();

                let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().unwrap();
                let mut local_migrations: Vec<_> =
                    migrations.iter().map(|m| m.name().version()).collect();
                local_migrations.pop();
                // Drop the last local migration: the remaining prefix must
                // still be consistent with the fully-migrated database.
                setup_postgres::check_db_migration_consistency_impl(&mut conn, local_migrations)
                    .unwrap();
            }
            database.drop_if_exists();
        }
    }
438}