1use std::{collections::HashSet, time::Duration};
11
12use anyhow::anyhow;
13use clap::Args;
14use diesel::{
15 PgConnection, QueryableByName,
16 connection::BoxableConnection,
17 query_dsl::RunQueryDsl,
18 r2d2::{ConnectionManager, Pool, PooledConnection, R2D2Connection},
19};
20use strum::IntoEnumIterator;
21use tracing::info;
22
23use crate::{errors::IndexerError, pruning::pruner::PrunableTable};
24
/// r2d2 connection pool managing Diesel `PgConnection`s.
pub type ConnectionPool = Pool<ConnectionManager<PgConnection>>;
/// A single connection checked out of a [`ConnectionPool`].
pub type PoolConnection = PooledConnection<ConnectionManager<PgConnection>>;
27
/// CLI/environment configuration for the Postgres connection pool.
///
/// Defaults mirror the `DEFAULT_*` constants on the `impl` block below.
#[derive(Args, Debug, Clone)]
pub struct ConnectionPoolConfig {
    /// Maximum number of connections held by the r2d2 pool.
    #[arg(long, default_value_t = 100)]
    #[arg(env = "DB_POOL_SIZE")]
    pub pool_size: u32,
    /// How long to wait when checking a connection out of the pool (given in seconds).
    #[arg(long, value_parser = parse_duration, default_value = "30")]
    #[arg(env = "DB_CONNECTION_TIMEOUT")]
    pub connection_timeout: Duration,
    /// Postgres `statement_timeout` applied to each pooled connection (given in seconds).
    #[arg(long, value_parser = parse_duration, default_value = "3600")]
    #[arg(env = "DB_STATEMENT_TIMEOUT")]
    pub statement_timeout: Duration,
}
40
/// Parses a CLI argument given in whole seconds into a [`Duration`].
///
/// Returns the underlying [`std::num::ParseIntError`] when the argument is
/// not a valid unsigned integer.
fn parse_duration(arg: &str) -> Result<std::time::Duration, std::num::ParseIntError> {
    arg.parse::<u64>().map(std::time::Duration::from_secs)
}
45
46impl ConnectionPoolConfig {
47 pub const DEFAULT_POOL_SIZE: u32 = 100;
48 pub const DEFAULT_CONNECTION_TIMEOUT: u64 = 30;
49 pub const DEFAULT_STATEMENT_TIMEOUT: u64 = 3600;
50
51 fn connection_config(&self) -> ConnectionConfig {
52 ConnectionConfig {
53 statement_timeout: self.statement_timeout,
54 read_only: false,
55 }
56 }
57
58 pub fn set_pool_size(&mut self, size: u32) {
59 self.pool_size = size;
60 }
61
62 pub fn set_connection_timeout(&mut self, timeout: Duration) {
63 self.connection_timeout = timeout;
64 }
65
66 pub fn set_statement_timeout(&mut self, timeout: Duration) {
67 self.statement_timeout = timeout;
68 }
69}
70
71impl Default for ConnectionPoolConfig {
72 fn default() -> Self {
73 Self {
74 pool_size: Self::DEFAULT_POOL_SIZE,
75 connection_timeout: Duration::from_secs(Self::DEFAULT_CONNECTION_TIMEOUT),
76 statement_timeout: Duration::from_secs(Self::DEFAULT_STATEMENT_TIMEOUT),
77 }
78 }
79}
80
/// Session settings applied to every connection when it is acquired from the
/// pool (via the `CustomizeConnection` implementation below).
#[derive(Debug, Clone, Copy)]
pub struct ConnectionConfig {
    /// Value applied as Postgres `statement_timeout`, in milliseconds.
    pub statement_timeout: Duration,
    /// When true, the session defaults to read-only transactions
    /// (`default_transaction_read_only = 't'`).
    pub read_only: bool,
}
86
87impl<T: R2D2Connection + 'static> diesel::r2d2::CustomizeConnection<T, diesel::r2d2::Error>
88 for ConnectionConfig
89{
90 fn on_acquire(&self, _conn: &mut T) -> std::result::Result<(), diesel::r2d2::Error> {
91 _conn
92 .as_any_mut()
93 .downcast_mut::<diesel::PgConnection>()
94 .map_or_else(
95 || {
96 Err(diesel::r2d2::Error::QueryError(
97 diesel::result::Error::DeserializationError(
98 "failed to downcast connection to PgConnection"
99 .to_string()
100 .into(),
101 ),
102 ))
103 },
104 |pg_conn| {
105 diesel::sql_query(format!(
106 "SET statement_timeout = {}",
107 self.statement_timeout.as_millis(),
108 ))
109 .execute(pg_conn)
110 .map_err(diesel::r2d2::Error::QueryError)?;
111
112 if self.read_only {
113 diesel::sql_query("SET default_transaction_read_only = 't'")
114 .execute(pg_conn)
115 .map_err(diesel::r2d2::Error::QueryError)?;
116 }
117 Ok(())
118 },
119 )?;
120 Ok(())
121 }
122}
123
124pub fn new_connection_pool(
125 db_url: &str,
126 config: &ConnectionPoolConfig,
127) -> Result<ConnectionPool, IndexerError> {
128 let manager = ConnectionManager::<PgConnection>::new(db_url);
129
130 Pool::builder()
131 .max_size(config.pool_size)
132 .connection_timeout(config.connection_timeout)
133 .connection_customizer(Box::new(config.connection_config()))
134 .build(manager)
135 .map_err(|e| {
136 IndexerError::PgConnectionPoolInit(format!(
137 "failed to initialize connection pool for {db_url} with error: {e:?}"
138 ))
139 })
140}
141
142pub fn get_pool_connection(pool: &ConnectionPool) -> Result<PoolConnection, IndexerError> {
143 pool.get().map_err(|e| {
144 IndexerError::PgPoolConnection(format!(
145 "failed to get connection from PG connection pool with error: {e:?}"
146 ))
147 })
148}
149
150pub fn reset_database(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
151 {
152 conn.as_any_mut()
153 .downcast_mut::<PoolConnection>()
154 .map_or_else(
155 || Err(anyhow!("failed to downcast connection to PgConnection")),
156 |pg_conn| {
157 setup_postgres::reset_database(pg_conn)?;
158 Ok(())
159 },
160 )?;
161 }
162 Ok(())
163}
164
165pub async fn check_prunable_tables_valid(conn: &mut PoolConnection) -> Result<(), IndexerError> {
167 info!("Starting compatibility check");
168
169 use diesel::RunQueryDsl;
170
171 let select_parent_tables = r#"
172 SELECT c.relname AS table_name
173 FROM pg_class c
174 JOIN pg_namespace n ON n.oid = c.relnamespace
175 LEFT JOIN pg_partitioned_table pt ON pt.partrelid = c.oid
176 WHERE c.relkind IN ('r', 'p') -- 'r' for regular tables, 'p' for partitioned tables
177 AND n.nspname = 'public'
178 AND (
179 pt.partrelid IS NOT NULL -- This is a partitioned (parent) table
180 OR NOT EXISTS ( -- This is not a partition (child table)
181 SELECT 1
182 FROM pg_inherits i
183 WHERE i.inhrelid = c.oid
184 )
185 );
186 "#;
187
188 #[derive(QueryableByName)]
189 struct TableName {
190 #[diesel(sql_type = diesel::sql_types::Text)]
191 table_name: String,
192 }
193
194 let result: Vec<TableName> = diesel::sql_query(select_parent_tables)
195 .load(conn)
196 .map_err(|e| IndexerError::DbMigration(format!("failed to fetch tables: {e}")))?;
197
198 let parent_tables_from_db: HashSet<_> = result.into_iter().map(|t| t.table_name).collect();
199
200 for key in PrunableTable::iter() {
201 if !parent_tables_from_db.contains(key.as_ref()) {
202 return Err(IndexerError::Generic(format!(
203 "invalid retention policy override provided for table {key}: does not exist in the database",
204 )));
205 }
206 }
207
208 info!("Compatibility check passed");
209 Ok(())
210}
211
/// Postgres bootstrap helpers: full schema reset, embedded-migration
/// application, and consistency checks between the migrations compiled into
/// this binary and those recorded in the database.
pub mod setup_postgres {
    use anyhow::anyhow;
    use diesel::{
        RunQueryDsl,
        migration::{Migration, MigrationConnection, MigrationSource, MigrationVersion},
        pg::Pg,
        prelude::*,
    };
    use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
    use tracing::info;

    use crate::{IndexerError, db::PoolConnection};

    // Hand-declared view of diesel's own bookkeeping table so applied
    // migration versions can be read with the regular query DSL.
    table! {
        __diesel_schema_migrations (version) {
            version -> VarChar,
            run_on -> Timestamp,
        }
    }

    // Migrations from `migrations/pg` compiled into the binary at build time.
    const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/pg");

    /// Drops every table, procedure, and function in the `public` schema,
    /// recreates diesel's `__diesel_schema_migrations` bookkeeping table, and
    /// re-applies all embedded migrations.
    pub fn reset_database(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
        info!("Resetting PG database ...");

        // Drop all user tables (CASCADE removes dependent objects).
        let drop_all_tables = "
        DO $$ DECLARE
        r RECORD;
        BEGIN
        FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
        LOOP
        EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
        END LOOP;
        END $$;";
        diesel::sql_query(drop_all_tables).execute(conn)?;
        info!("Dropped all tables.");

        // Drop all stored procedures (prokind = 'p').
        let drop_all_procedures = "
        DO $$ DECLARE
        r RECORD;
        BEGIN
        FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
        FROM pg_proc INNER JOIN pg_namespace ns ON (pg_proc.pronamespace = ns.oid)
        WHERE ns.nspname = 'public' AND prokind = 'p')
        LOOP
        EXECUTE 'DROP PROCEDURE IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
        END LOOP;
        END $$;";
        diesel::sql_query(drop_all_procedures).execute(conn)?;
        info!("Dropped all procedures.");

        // Drop all plain functions (prokind = 'f').
        let drop_all_functions = "
        DO $$ DECLARE
        r RECORD;
        BEGIN
        FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
        FROM pg_proc INNER JOIN pg_namespace ON (pg_proc.pronamespace = pg_namespace.oid)
        WHERE pg_namespace.nspname = 'public' AND prokind = 'f')
        LOOP
        EXECUTE 'DROP FUNCTION IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
        END LOOP;
        END $$;";
        diesel::sql_query(drop_all_functions).execute(conn)?;
        info!("Dropped all functions.");

        // `MigrationConnection::setup` recreates `__diesel_schema_migrations`.
        conn.setup()?;
        info!("Created __diesel_schema_migrations table.");

        run_migrations(conn)?;
        info!("Reset database complete.");
        Ok(())
    }

    /// Applies any embedded migrations not yet recorded in the database.
    pub fn run_migrations(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
        conn.run_pending_migrations(MIGRATIONS)
            .map_err(|e| anyhow!("failed to run migrations {e}"))?;
        Ok(())
    }

    /// Checks that the migrations embedded in this binary are compatible with
    /// the database: the local migration list must be a prefix of the applied
    /// migrations (see `check_db_migration_consistency_impl`).
    pub fn check_db_migration_consistency(conn: &mut PoolConnection) -> Result<(), IndexerError> {
        info!("Starting compatibility check");
        let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().map_err(|err| {
            IndexerError::DbMigration(format!(
                "failed to fetch local migrations from schema: {err}"
            ))
        })?;

        let local_migrations = migrations
            .iter()
            .map(|m| m.name().version())
            .collect::<Vec<_>>();

        check_db_migration_consistency_impl(conn, local_migrations)?;
        info!("Compatibility check passed");
        Ok(())
    }

    /// Requires `local_migrations` to be a prefix of the migrations applied
    /// to the database, compared in ascending version order.
    fn check_db_migration_consistency_impl(
        conn: &mut PoolConnection,
        local_migrations: Vec<MigrationVersion>,
    ) -> Result<(), IndexerError> {
        let applied_migrations: Vec<MigrationVersion> = __diesel_schema_migrations::table
            .select(__diesel_schema_migrations::version)
            .order(__diesel_schema_migrations::version.asc())
            .load(conn)?;

        // More local migrations than applied ones means the binary is ahead
        // of the database schema, so it cannot be a prefix.
        if local_migrations.len() > applied_migrations.len() {
            return Err(IndexerError::DbMigration(format!(
                "the number of local migrations is greater than the number of applied migrations. Local migrations: {local_migrations:?}, Applied migrations: {applied_migrations:?}",
            )));
        }
        for (local_migration, applied_migration) in local_migrations.iter().zip(&applied_migrations)
        {
            if local_migration != applied_migration {
                return Err(IndexerError::DbMigration(format!(
                    "the next applied migration `{applied_migration:?}` diverges from the local migration `{local_migration:?}`",
                )));
            }
        }
        Ok(())
    }

    // Integration tests require a live Postgres instance, so they are gated
    // behind the `pg_integration` feature in addition to `cfg(test)`.
    #[cfg(feature = "pg_integration")]
    #[cfg(test)]
    mod tests {
        use diesel::{
            migration::{Migration, MigrationSource},
            pg::Pg,
        };
        use diesel_migrations::MigrationHarness;

        use crate::{
            db::setup_postgres::{self, MIGRATIONS},
            test_utils::{TestDatabase, db_url},
        };

        // A fresh database with all migrations applied passes the check.
        #[test]
        fn db_migration_consistency_smoke_test() {
            let mut database = TestDatabase::new(db_url("db_migration_consistency_smoke_test"));
            database.recreate();
            database.reset_db();
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();
                setup_postgres::check_db_migration_consistency(&mut conn).unwrap();
            }
            database.drop_if_exists();
        }

        // Reverting the last applied migration must fail the check until the
        // migration is re-applied.
        #[test]
        fn db_migration_consistency_non_prefix_test() {
            let mut database =
                TestDatabase::new(db_url("db_migration_consistency_non_prefix_test"));
            database.recreate();
            database.reset_db();
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();
                conn.revert_migration(MIGRATIONS.migrations().unwrap().last().unwrap())
                    .unwrap();
                assert!(setup_postgres::check_db_migration_consistency(&mut conn).is_err());

                conn.run_pending_migrations(MIGRATIONS).unwrap();
                setup_postgres::check_db_migration_consistency(&mut conn).unwrap();
            }
            database.drop_if_exists();
        }

        // A strict prefix of the applied migrations is accepted.
        #[test]
        fn db_migration_consistency_prefix_test() {
            let mut database = TestDatabase::new(db_url("db_migration_consistency_prefix_test"));
            database.recreate();
            database.reset_db();
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();

                let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().unwrap();
                let mut local_migrations: Vec<_> =
                    migrations.iter().map(|m| m.name().version()).collect();
                local_migrations.pop();
                setup_postgres::check_db_migration_consistency_impl(&mut conn, local_migrations)
                    .unwrap();
            }
            database.drop_if_exists();
        }
    }
}