1use std::time::Duration;
6
7use anyhow::anyhow;
8use clap::Args;
9use diesel::{
10 PgConnection,
11 connection::BoxableConnection,
12 query_dsl::RunQueryDsl,
13 r2d2::{ConnectionManager, Pool, PooledConnection, R2D2Connection},
14};
15
16use crate::errors::IndexerError;
17
/// r2d2 pool of diesel [`PgConnection`]s.
pub type ConnectionPool = Pool<ConnectionManager<PgConnection>>;
/// A single Postgres connection checked out of a [`ConnectionPool`].
pub type PoolConnection = PooledConnection<ConnectionManager<PgConnection>>;
20
21#[derive(Args, Debug, Clone)]
22pub struct ConnectionPoolConfig {
23 #[arg(long, default_value_t = 100)]
24 #[arg(env = "DB_POOL_SIZE")]
25 pub pool_size: u32,
26 #[arg(long, value_parser = parse_duration, default_value = "30")]
27 #[arg(env = "DB_CONNECTION_TIMEOUT")]
28 pub connection_timeout: Duration,
29 #[arg(long, value_parser = parse_duration, default_value = "3600")]
30 #[arg(env = "DB_STATEMENT_TIMEOUT")]
31 pub statement_timeout: Duration,
32}
33
/// Parses a whole number of seconds (as given on the CLI or in an env var) into a
/// [`std::time::Duration`].
///
/// # Errors
/// Returns the [`std::num::ParseIntError`] produced when `arg` is not a valid `u64`
/// (empty, negative, fractional or non-numeric input).
fn parse_duration(arg: &str) -> Result<std::time::Duration, std::num::ParseIntError> {
    arg.parse::<u64>().map(std::time::Duration::from_secs)
}
38
39impl ConnectionPoolConfig {
40 pub const DEFAULT_POOL_SIZE: u32 = 100;
41 pub const DEFAULT_CONNECTION_TIMEOUT: u64 = 30;
42 pub const DEFAULT_STATEMENT_TIMEOUT: u64 = 3600;
43
44 fn connection_config(&self) -> ConnectionConfig {
45 ConnectionConfig {
46 statement_timeout: self.statement_timeout,
47 read_only: false,
48 }
49 }
50
51 pub fn set_pool_size(&mut self, size: u32) {
52 self.pool_size = size;
53 }
54
55 pub fn set_connection_timeout(&mut self, timeout: Duration) {
56 self.connection_timeout = timeout;
57 }
58
59 pub fn set_statement_timeout(&mut self, timeout: Duration) {
60 self.statement_timeout = timeout;
61 }
62}
63
64impl Default for ConnectionPoolConfig {
65 fn default() -> Self {
66 Self {
67 pool_size: Self::DEFAULT_POOL_SIZE,
68 connection_timeout: Duration::from_secs(Self::DEFAULT_CONNECTION_TIMEOUT),
69 statement_timeout: Duration::from_secs(Self::DEFAULT_STATEMENT_TIMEOUT),
70 }
71 }
72}
73
/// Session settings applied to each connection when the pool hands it out.
#[derive(Debug, Clone, Copy)]
pub struct ConnectionConfig {
    /// Sent to Postgres as `SET statement_timeout` (converted to milliseconds).
    pub statement_timeout: Duration,
    /// When `true`, also runs `SET default_transaction_read_only = 't'` on the connection.
    pub read_only: bool,
}
79
80impl<T: R2D2Connection + 'static> diesel::r2d2::CustomizeConnection<T, diesel::r2d2::Error>
81 for ConnectionConfig
82{
83 fn on_acquire(&self, _conn: &mut T) -> std::result::Result<(), diesel::r2d2::Error> {
84 _conn
85 .as_any_mut()
86 .downcast_mut::<diesel::PgConnection>()
87 .map_or_else(
88 || {
89 Err(diesel::r2d2::Error::QueryError(
90 diesel::result::Error::DeserializationError(
91 "Failed to downcast connection to PgConnection"
92 .to_string()
93 .into(),
94 ),
95 ))
96 },
97 |pg_conn| {
98 diesel::sql_query(format!(
99 "SET statement_timeout = {}",
100 self.statement_timeout.as_millis(),
101 ))
102 .execute(pg_conn)
103 .map_err(diesel::r2d2::Error::QueryError)?;
104
105 if self.read_only {
106 diesel::sql_query("SET default_transaction_read_only = 't'")
107 .execute(pg_conn)
108 .map_err(diesel::r2d2::Error::QueryError)?;
109 }
110 Ok(())
111 },
112 )?;
113 Ok(())
114 }
115}
116
117pub fn new_connection_pool(
118 db_url: &str,
119 config: &ConnectionPoolConfig,
120) -> Result<ConnectionPool, IndexerError> {
121 let manager = ConnectionManager::<PgConnection>::new(db_url);
122
123 Pool::builder()
124 .max_size(config.pool_size)
125 .connection_timeout(config.connection_timeout)
126 .connection_customizer(Box::new(config.connection_config()))
127 .build(manager)
128 .map_err(|e| {
129 IndexerError::PgConnectionPoolInit(format!(
130 "Failed to initialize connection pool for {db_url} with error: {e:?}"
131 ))
132 })
133}
134
135pub fn get_pool_connection(pool: &ConnectionPool) -> Result<PoolConnection, IndexerError> {
136 pool.get().map_err(|e| {
137 IndexerError::PgPoolConnection(format!(
138 "Failed to get connection from PG connection pool with error: {:?}",
139 e
140 ))
141 })
142}
143
144pub fn reset_database(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
145 {
146 conn.as_any_mut()
147 .downcast_mut::<PoolConnection>()
148 .map_or_else(
149 || Err(anyhow!("Failed to downcast connection to PgConnection")),
150 |pg_conn| {
151 setup_postgres::reset_database(pg_conn)?;
152 Ok(())
153 },
154 )?;
155 }
156 Ok(())
157}
158
159pub mod setup_postgres {
160 use anyhow::anyhow;
161 use diesel::{
162 RunQueryDsl,
163 migration::{Migration, MigrationConnection, MigrationSource, MigrationVersion},
164 pg::Pg,
165 prelude::*,
166 };
167 use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
168 use tracing::info;
169
170 use crate::{IndexerError, db::PoolConnection};
171
    // Diesel's bookkeeping table of applied migrations, declared here so that its
    // `version` column can be queried by `check_db_migration_consistency_impl`.
    table! {
        __diesel_schema_migrations (version) {
            version -> VarChar,
            run_on -> Timestamp,
        }
    }

    /// Postgres migrations embedded into the binary from the `migrations/pg` directory.
    const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/pg");
180
    /// Wipes the `public` schema and rebuilds it from the embedded migrations.
    ///
    /// The steps are:
    /// 1. Drop every table, procedure and function in `public`, each with `CASCADE`.
    /// 2. Ensure diesel's migration table exists via `setup()`.
    /// 3. Run all pending migrations.
    ///
    /// NOTE(review): other object kinds in `public` are not dropped here. Examples are custom
    /// types, standalone sequences, and views that do not depend on a dropped table. Confirm
    /// the migrations never create such objects, or re-running them may conflict.
    ///
    /// # Errors
    /// Returns an error if any SQL statement, the migration-table setup, or a migration fails.
    pub fn reset_database(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
        info!("Resetting PG database ...");

        // Drop each table listed in the catalog. `quote_ident` handles names that need quoting.
        let drop_all_tables = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
        FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
        LOOP
            EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
        END LOOP;
        END $$;";
        diesel::sql_query(drop_all_tables).execute(conn)?;
        info!("Dropped all tables.");

        // `prokind = 'p'` selects procedures. The argument types are part of the DROP because
        // overloaded routines can only be addressed by their full signature.
        let drop_all_procedures = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
                    FROM pg_proc INNER JOIN pg_namespace ns ON (pg_proc.pronamespace = ns.oid)
                    WHERE ns.nspname = 'public' AND prokind = 'p')
            LOOP
                EXECUTE 'DROP PROCEDURE IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
            END LOOP;
        END $$;";
        diesel::sql_query(drop_all_procedures).execute(conn)?;
        info!("Dropped all procedures.");

        // Same as above for plain functions (`prokind = 'f'`).
        let drop_all_functions = "
        DO $$ DECLARE
            r RECORD;
        BEGIN
            FOR r IN (SELECT proname, oidvectortypes(proargtypes) as argtypes
                    FROM pg_proc INNER JOIN pg_namespace ON (pg_proc.pronamespace = pg_namespace.oid)
                    WHERE pg_namespace.nspname = 'public' AND prokind = 'f')
            LOOP
                EXECUTE 'DROP FUNCTION IF EXISTS ' || quote_ident(r.proname) || '(' || r.argtypes || ') CASCADE';
            END LOOP;
        END $$;";
        diesel::sql_query(drop_all_functions).execute(conn)?;
        info!("Dropped all functions.");

        // The table drop above presumably also removed `__diesel_schema_migrations` (if it
        // lives in `public`), so make sure it exists before migrating.
        conn.setup()?;
        info!("Created __diesel_schema_migrations table.");

        run_migrations(conn)?;
        info!("Reset database complete.");
        Ok(())
    }
231
232 pub fn run_migrations(conn: &mut PoolConnection) -> Result<(), anyhow::Error> {
234 conn.run_pending_migrations(MIGRATIONS)
235 .map_err(|e| anyhow!("Failed to run migrations {e}"))?;
236 Ok(())
237 }
238
239 pub fn check_db_migration_consistency(conn: &mut PoolConnection) -> Result<(), IndexerError> {
260 info!("Starting compatibility check");
261 let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().map_err(|err| {
262 IndexerError::DbMigration(format!(
263 "Failed to fetch local migrations from schema: {err}"
264 ))
265 })?;
266
267 let local_migrations = migrations
268 .iter()
269 .map(|m| m.name().version())
270 .collect::<Vec<_>>();
271
272 check_db_migration_consistency_impl(conn, local_migrations)?;
273 info!("Compatibility check passed");
274 Ok(())
275 }
276
277 fn check_db_migration_consistency_impl(
278 conn: &mut PoolConnection,
279 local_migrations: Vec<MigrationVersion>,
280 ) -> Result<(), IndexerError> {
281 let applied_migrations: Vec<MigrationVersion> = __diesel_schema_migrations::table
286 .select(__diesel_schema_migrations::version)
287 .order(__diesel_schema_migrations::version.asc())
288 .load(conn)?;
289
290 if local_migrations.len() > applied_migrations.len() {
292 return Err(IndexerError::DbMigration(format!(
293 "The number of local migrations is greater than the number of applied migrations. Local migrations: {local_migrations:?}, Applied migrations: {applied_migrations:?}",
294 )));
295 }
296 for (local_migration, applied_migration) in local_migrations.iter().zip(&applied_migrations)
297 {
298 if local_migration != applied_migration {
299 return Err(IndexerError::DbMigration(format!(
300 "The next applied migration `{applied_migration:?}` diverges from the local migration `{local_migration:?}`",
301 )));
302 }
303 }
304 Ok(())
305 }
306
    // Integration tests for the migration consistency check. They only compile with the
    // `pg_integration` feature and target a Postgres server at `localhost:5432`
    // (see `database_url`).
    #[cfg(feature = "pg_integration")]
    #[cfg(test)]
    mod tests {
        use diesel::{
            migration::{Migration, MigrationSource},
            pg::Pg,
        };
        use diesel_migrations::MigrationHarness;

        use crate::{
            db::setup_postgres::{self, MIGRATIONS},
            test_utils::TestDatabase,
        };

        /// Connection URL for database `db_name` on the local test server.
        fn database_url(db_name: &str) -> String {
            format!("postgres://postgres:postgrespw@localhost:5432/{db_name}")
        }

        /// After `reset_db` (presumably applying all embedded migrations) the check passes.
        #[test]
        fn db_migration_consistency_smoke_test() {
            let mut database =
                TestDatabase::new(database_url("db_migration_consistency_smoke_test"));
            database.recreate();
            database.reset_db();
            // Scoped so the pool (and its connection) is dropped before `drop_if_exists` runs.
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();
                setup_postgres::check_db_migration_consistency(&mut conn).unwrap();
            }
            database.drop_if_exists();
        }

        /// Reverting the newest migration leaves the database behind the local set, so the
        /// check must fail. Re-applying the pending migrations makes it pass again.
        #[test]
        fn db_migration_consistency_non_prefix_test() {
            let mut database =
                TestDatabase::new(database_url("db_migration_consistency_non_prefix_test"));
            database.recreate();
            database.reset_db();
            // Scoped so the pool (and its connection) is dropped before `drop_if_exists` runs.
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();
                conn.revert_migration(MIGRATIONS.migrations().unwrap().last().unwrap())
                    .unwrap();
                assert!(setup_postgres::check_db_migration_consistency(&mut conn).is_err());

                conn.run_pending_migrations(MIGRATIONS).unwrap();
                setup_postgres::check_db_migration_consistency(&mut conn).unwrap();
            }
            database.drop_if_exists();
        }

        /// A local migration list that is a strict prefix of the applied ones (the newest
        /// local migration dropped) is accepted.
        #[test]
        fn db_migration_consistency_prefix_test() {
            let mut database =
                TestDatabase::new(database_url("db_migration_consistency_prefix_test"));
            database.recreate();
            database.reset_db();
            // Scoped so the pool (and its connection) is dropped before `drop_if_exists` runs.
            {
                let pool = database.to_connection_pool();
                let mut conn = pool.get().unwrap();

                let migrations: Vec<Box<dyn Migration<Pg>>> = MIGRATIONS.migrations().unwrap();
                let mut local_migrations: Vec<_> =
                    migrations.iter().map(|m| m.name().version()).collect();
                local_migrations.pop();
                setup_postgres::check_db_migration_consistency_impl(&mut conn, local_migrations)
                    .unwrap();
            }
            database.drop_if_exists();
        }
    }
385}