From dc95b47818ed0dda3c27f14d69acd7bf994a5643 Mon Sep 17 00:00:00 2001 From: Matt Palmer Date: Sat, 4 May 2024 09:08:44 +1000 Subject: [PATCH 1/2] Allow Postgres tests to be run on a different database Not everyone has a "scratch" PostgreSQL running on localhost:5432 for refinery to scribble all over. Now you can specify an arbitrary PostgreSQL server to work on with the `DB_URI` environment variable (which appears to be what `refinery-cli` already uses) to test in. --- refinery/tests/postgres.rs | 108 ++++++++++++++++--------------------- 1 file changed, 45 insertions(+), 63 deletions(-) diff --git a/refinery/tests/postgres.rs b/refinery/tests/postgres.rs index 2ee4e756..d58950b3 100644 --- a/refinery/tests/postgres.rs +++ b/refinery/tests/postgres.rs @@ -5,13 +5,11 @@ mod postgres { use assert_cmd::prelude::*; use predicates::str::contains; use refinery::{ - config::{Config, ConfigDbType}, - embed_migrations, - error::Kind, - Migrate, Migration, Runner, Target, + config::Config, embed_migrations, error::Kind, Migrate, Migration, Runner, Target, }; use refinery_core::postgres::{Client, NoTls}; use std::process::Command; + use std::str::FromStr; use time::OffsetDateTime; const DEFAULT_TABLE_NAME: &str = "refinery_schema_history"; @@ -31,6 +29,10 @@ mod postgres { embed_migrations!("./tests/migrations_missing"); } + fn db_uri() -> String { + std::env::var("DB_URI").unwrap_or("postgres://postgres@localhost:5432/postgres".to_string()) + } + fn get_migrations() -> Vec { embed_migrations!("./tests/migrations"); @@ -65,35 +67,44 @@ mod postgres { } fn clean_database() { - let mut client = - Client::connect("postgres://postgres@localhost:5432/template1", NoTls).unwrap(); + let uri = db_uri(); + let db_name = uri.split('/').last().unwrap(); + + let mut client = Client::connect( + &(uri.strip_suffix(db_name).unwrap().to_string() + "template1"), + NoTls, + ) + .unwrap(); client .execute( - "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname='postgres'", - &[], + "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname=$1", + &[&db_name], ) .unwrap(); - client.execute("DROP DATABASE POSTGRES", &[]).unwrap(); - client.execute("CREATE DATABASE POSTGRES", &[]).unwrap(); + client + .execute(&"DROP DATABASE IF EXISTS $1".replace("$1", db_name), &[]) + .unwrap(); + client + .execute(&"CREATE DATABASE $1".replace("$1", db_name), &[]) + .unwrap(); } fn run_test(test: T) where T: FnOnce() + std::panic::UnwindSafe, { - let result = std::panic::catch_unwind(test); - clean_database(); + let result = std::panic::catch_unwind(test); + assert!(result.is_ok()) } #[test] fn report_contains_applied_migrations() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); let report = embedded::migrations::runner().run(&mut client).unwrap(); @@ -122,8 +133,7 @@ mod postgres { #[test] fn creates_migration_table() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); embedded::migrations::runner().run(&mut client).unwrap(); for row in &client .query( @@ -144,8 +154,7 @@ mod postgres { #[test] fn creates_migration_table_grouped_transaction() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); embedded::migrations::runner() .set_grouped(true) @@ -171,8 +180,7 @@ mod postgres { #[test] fn applies_migration() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); embedded::migrations::runner().run(&mut client).unwrap(); client .execute( @@ -192,8 +200,7 @@ mod postgres { #[test] fn applies_migration_grouped_transaction() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); embedded::migrations::runner() .set_grouped(false) @@ -218,8 +225,7 @@ mod postgres { #[test] fn updates_schema_history() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); embedded::migrations::runner().run(&mut client).unwrap(); @@ -239,8 +245,7 @@ mod postgres { #[test] fn updates_schema_history_grouped_transaction() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); embedded::migrations::runner() .set_grouped(false) @@ -262,8 +267,7 @@ mod postgres { #[test] fn updates_to_last_working_if_not_grouped() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); let result = broken::migrations::runner().run(&mut client); @@ -300,8 +304,7 @@ mod postgres { #[test] fn doesnt_update_to_last_working_if_grouped() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); let result = broken::migrations::runner() .set_grouped(true) @@ -320,8 +323,7 @@ mod postgres { #[test] fn gets_applied_migrations() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); embedded::migrations::runner().run(&mut client).unwrap(); @@ -349,8 +351,7 @@ mod postgres { #[test] fn applies_new_migration() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); embedded::migrations::runner().run(&mut client).unwrap(); @@ -381,8 +382,7 @@ mod postgres { #[test] fn migrates_to_target_migration() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); let report = embedded::migrations::runner() .set_target(Target::Version(3)) @@ -417,8 +417,7 @@ mod postgres { #[test] fn migrates_to_target_migration_grouped() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); let report = embedded::migrations::runner() .set_target(Target::Version(3)) @@ -454,8 +453,7 @@ mod postgres { #[test] fn aborts_on_missing_migration_on_filesystem() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); embedded::migrations::runner().run(&mut client).unwrap(); @@ -488,8 +486,7 @@ mod postgres { #[test] fn aborts_on_divergent_migration() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); embedded::migrations::runner().run(&mut client).unwrap(); @@ -523,8 +520,7 @@ mod postgres { #[test] fn aborts_on_missing_migration_on_database() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); missing::migrations::runner().run(&mut client).unwrap(); @@ -568,11 +564,7 @@ mod postgres { #[test] fn migrates_from_config() { run_test(|| { - let mut config = Config::new(ConfigDbType::Postgres) - .set_db_name("postgres") - .set_db_user("postgres") - .set_db_host("localhost") - .set_db_port("5432"); + let mut config = Config::from_str(&db_uri()).unwrap(); let migrations = get_migrations(); let runner = Runner::new(&migrations) @@ -608,11 +600,7 @@ mod postgres { #[test] fn migrate_from_config_report_contains_migrations() { run_test(|| { - let mut config = Config::new(ConfigDbType::Postgres) - .set_db_name("postgres") - .set_db_user("postgres") - .set_db_host("localhost") - .set_db_port("5432"); + let mut config = Config::from_str(&db_uri()).unwrap(); let migrations = get_migrations(); let runner = Runner::new(&migrations) @@ -648,11 +636,7 @@ mod postgres { #[test] fn migrate_from_config_report_returns_last_applied_migration() { run_test(|| { - let mut config = Config::new(ConfigDbType::Postgres) - .set_db_name("postgres") - .set_db_user("postgres") - .set_db_host("localhost") - .set_db_port("5432"); + let mut config = Config::from_str(&db_uri()).unwrap(); let migrations = get_migrations(); let runner = Runner::new(&migrations) @@ -677,8 +661,7 @@ mod postgres { #[test] fn doesnt_run_migrations_if_fake() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); let report = embedded::migrations::runner() .set_target(Target::Fake) @@ -712,8 +695,7 @@ mod postgres { #[test] fn doesnt_run_migrations_if_fake_version() { run_test(|| { - let mut client = - Client::connect("postgres://postgres@localhost:5432/postgres", NoTls).unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); let report = embedded::migrations::runner() .set_target(Target::FakeVersion(2)) From 2955bdd6f227045f04ee19c6667158f78b120339 Mon Sep 17 00:00:00 2001 From: Matt Palmer Date: Thu, 9 May 2024 16:50:55 +1000 Subject: [PATCH 2/2] Improve DB reset process * Use a more appropriate name for the function that does the work * Clean just the `public` schema, rather than drop/create the whole DB. This means that running the tests no longer requires superuser privs, and that we don't have to temporarily hide out in `template1`. * Drop the `catch_unwind`, because it's needed any more. --- refinery/tests/postgres.rs | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/refinery/tests/postgres.rs b/refinery/tests/postgres.rs index d58950b3..59a263b8 100644 --- a/refinery/tests/postgres.rs +++ b/refinery/tests/postgres.rs @@ -66,39 +66,26 @@ mod postgres { vec![migration1, migration2, migration3, migration4, migration5] } - fn clean_database() { + fn prep_database() { let uri = db_uri(); - let db_name = uri.split('/').last().unwrap(); - let mut client = Client::connect( - &(uri.strip_suffix(db_name).unwrap().to_string() + "template1"), - NoTls, - ) - .unwrap(); + let mut client = Client::connect(&db_uri(), NoTls).unwrap(); client - .execute( - "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname=$1", - &[&db_name], - ) - .unwrap(); - client - .execute(&"DROP DATABASE IF EXISTS $1".replace("$1", db_name), &[]) + .execute("DROP SCHEMA IF EXISTS public CASCADE", &[]) .unwrap(); client - .execute(&"CREATE DATABASE $1".replace("$1", db_name), &[]) + .execute("CREATE SCHEMA IF NOT EXISTS public", &[]) .unwrap(); } fn run_test(test: T) where - T: FnOnce() + std::panic::UnwindSafe, + T: FnOnce(), { - clean_database(); - - let result = std::panic::catch_unwind(test); + prep_database(); - assert!(result.is_ok()) + test(); } #[test]