diff --git a/refinery/tests/mysql.rs b/refinery/tests/mysql.rs index 0bfdf971..76facb03 100644 --- a/refinery/tests/mysql.rs +++ b/refinery/tests/mysql.rs @@ -45,7 +45,7 @@ mod mysql { .unwrap(); let migration3 = Migration::unapplied( - "V3__add_brand_to_cars_table", + "V3__add_brand_to_cars_table.sql", include_str!("./migrations/V3/V3__add_brand_to_cars_table.sql"), ) .unwrap(); @@ -57,7 +57,7 @@ mod mysql { .unwrap(); let migration5 = Migration::unapplied( - "V5__add_year_field_to_cars", + "V5__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -471,7 +471,7 @@ mod mysql { embedded::migrations::runner().run(&mut conn).unwrap(); let migration = Migration::unapplied( - "V4__add_year_field_to_cars", + "V4__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -507,7 +507,7 @@ mod mysql { embedded::migrations::runner().run(&mut conn).unwrap(); let migration = Migration::unapplied( - "V2__add_year_field_to_cars", + "V2__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -544,7 +544,7 @@ mod mysql { missing::migrations::runner().run(&mut conn).unwrap(); let migration1 = Migration::unapplied( - "V1__initial", + "V1__initial.sql", concat!( "CREATE TABLE persons (", "id int,", @@ -556,7 +556,7 @@ mod mysql { .unwrap(); let migration2 = Migration::unapplied( - "V2__add_cars_table", + "V2__add_cars_table.sql", include_str!("./migrations_missing/V2__add_cars_table.sql"), ) .unwrap(); diff --git a/refinery/tests/mysql_async.rs b/refinery/tests/mysql_async.rs index 4b3e9cff..4ec96e50 100644 --- a/refinery/tests/mysql_async.rs +++ b/refinery/tests/mysql_async.rs @@ -29,7 +29,7 @@ mod mysql_async { .unwrap(); let migration3 = Migration::unapplied( - "V3__add_brand_to_cars_table", + "V3__add_brand_to_cars_table.sql", include_str!("./migrations/V3/V3__add_brand_to_cars_table.sql"), ) .unwrap(); @@ -41,7 +41,7 @@ mod mysql_async { .unwrap(); let migration5 = Migration::unapplied( - "V5__add_year_field_to_cars", + "V5__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -481,7 +481,7 @@ mod mysql_async { .unwrap(); let migration = Migration::unapplied( - "V4__add_year_field_to_cars", + "V4__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -526,7 +526,7 @@ mod mysql_async { .unwrap(); let migration = Migration::unapplied( - "V2__add_year_field_to_cars", + "V2__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -567,7 +567,7 @@ mod mysql_async { .unwrap(); let migration1 = Migration::unapplied( - "V1__initial", + "V1__initial.sql", concat!( "CREATE TABLE persons (", "id int,", @@ -579,7 +579,7 @@ mod mysql_async { .unwrap(); let migration2 = Migration::unapplied( - "V2__add_cars_table", + "V2__add_cars_table.sql", include_str!("./migrations_missing/V2__add_cars_table.sql"), ) .unwrap(); diff --git a/refinery/tests/postgres.rs b/refinery/tests/postgres.rs index 39321ed8..c3125f5b 100644 --- a/refinery/tests/postgres.rs +++ b/refinery/tests/postgres.rs @@ -44,7 +44,7 @@ mod postgres { .unwrap(); let migration3 = Migration::unapplied( - "V3__add_brand_to_cars_table", + "V3__add_brand_to_cars_table.sql", include_str!("./migrations/V3/V3__add_brand_to_cars_table.sql"), ) .unwrap(); @@ -56,7 +56,7 @@ mod postgres { .unwrap(); let migration5 = Migration::unapplied( - "V5__add_year_field_to_cars", + "V5__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -460,7 +460,7 @@ mod postgres { embedded::migrations::runner().run(&mut client).unwrap(); let migration = Migration::unapplied( - "V4__add_year_field_to_cars", + "V4__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -494,7 +494,7 @@ mod postgres { embedded::migrations::runner().run(&mut client).unwrap(); let migration = Migration::unapplied( - "V2__add_year_field_to_cars", + "V2__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -529,7 +529,7 @@ mod postgres { missing::migrations::runner().run(&mut client).unwrap(); let migration1 = Migration::unapplied( - "V1__initial", + "V1__initial.sql", concat!( "CREATE TABLE persons (", "id int,", @@ -541,7 +541,7 @@ mod postgres { .unwrap(); let migration2 = Migration::unapplied( - "V2__add_cars_table", + "V2__add_cars_table.sql", include_str!("./migrations_missing/V2__add_cars_table.sql"), ) .unwrap(); diff --git a/refinery/tests/rusqlite.rs b/refinery/tests/rusqlite.rs index 19fed5b8..b60675c4 100644 --- a/refinery/tests/rusqlite.rs +++ b/refinery/tests/rusqlite.rs @@ -60,7 +60,7 @@ mod rusqlite { .unwrap(); let migration3 = Migration::unapplied( - "V3__add_brand_to_cars_table", + "V3__add_brand_to_cars_table.sql", include_str!("./migrations/V3/V3__add_brand_to_cars_table.sql"), ) .unwrap(); @@ -72,7 +72,7 @@ mod rusqlite { .unwrap(); let migration5 = Migration::unapplied( - "V5__add_year_field_to_cars", + "V5__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -408,7 +408,7 @@ mod rusqlite { embedded::migrations::runner().run(&mut conn).unwrap(); let migration = Migration::unapplied( - "V4__add_year_field_to_cars", + "V4__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -439,7 +439,7 @@ mod rusqlite { embedded::migrations::runner().run(&mut conn).unwrap(); let migration = Migration::unapplied( - "V2__add_year_field_to_cars", + "V2__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -471,7 +471,7 @@ mod rusqlite { missing::migrations::runner().run(&mut conn).unwrap(); let migration1 = Migration::unapplied( - "V1__initial", + "V1__initial.sql", concat!( "CREATE TABLE persons (", "id int,", @@ -483,7 +483,7 @@ mod rusqlite { .unwrap(); let migration2 = Migration::unapplied( - "V2__add_cars_table", + "V2__add_cars_table.sql", include_str!("./migrations_missing/V2__add_cars_table.sql"), ) .unwrap(); diff --git a/refinery/tests/tiberius.rs b/refinery/tests/tiberius.rs index f9011404..165a8214 100644 --- a/refinery/tests/tiberius.rs +++ b/refinery/tests/tiberius.rs @@ -32,7 +32,7 @@ mod tiberius { .unwrap(); let migration3 = Migration::unapplied( - "V3__add_brand_to_cars_table", + "V3__add_brand_to_cars_table.sql", include_str!("./migrations/V3/V3__add_brand_to_cars_table.sql"), ) .unwrap(); @@ -44,7 +44,7 @@ mod tiberius { .unwrap(); let migration5 = Migration::unapplied( - "V5__add_year_field_to_cars", + "V5__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -126,7 +126,7 @@ mod tiberius { .unwrap(); let migration = Migration::unapplied( - "V4__add_year_field_to_cars", + "V4__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -177,7 +177,7 @@ mod tiberius { .unwrap(); let migration = Migration::unapplied( - "V2__add_year_field_to_cars", + "V2__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -229,7 +229,7 @@ mod tiberius { .unwrap(); let migration1 = Migration::unapplied( - "V1__initial", + "V1__initial.sql", concat!( "CREATE TABLE persons (", "id int,", @@ -241,7 +241,7 @@ mod tiberius { .unwrap(); let migration2 = Migration::unapplied( - "V2__add_cars_table", + "V2__add_cars_table.sql", include_str!("./migrations_missing/V2__add_cars_table.sql"), ) .unwrap(); diff --git a/refinery/tests/tokio_postgres.rs b/refinery/tests/tokio_postgres.rs index 2f7f7fc6..60d71767 100644 --- a/refinery/tests/tokio_postgres.rs +++ b/refinery/tests/tokio_postgres.rs @@ -29,7 +29,7 @@ mod tokio_postgres { .unwrap(); let migration3 = Migration::unapplied( - "V3__add_brand_to_cars_table", + "V3__add_brand_to_cars_table.sql", include_str!("./migrations/V3/V3__add_brand_to_cars_table.sql"), ) .unwrap(); @@ -41,7 +41,7 @@ mod tokio_postgres { .unwrap(); let migration5 = Migration::unapplied( - "V5__add_year_field_to_cars", + "V5__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -619,7 +619,7 @@ mod tokio_postgres { .unwrap(); let migration = Migration::unapplied( - "V4__add_year_field_to_cars", + "V4__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -664,7 +664,7 @@ mod tokio_postgres { .unwrap(); let migration = Migration::unapplied( - "V2__add_year_field_to_cars", + "V2__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -711,7 +711,7 @@ mod tokio_postgres { .unwrap(); let migration1 = Migration::unapplied( - "V1__initial", + "V1__initial.sql", concat!( "CREATE TABLE persons (", "id int,", @@ -723,7 +723,7 @@ mod tokio_postgres { .unwrap(); let migration2 = Migration::unapplied( - "V2__add_cars_table", + "V2__add_cars_table.sql", include_str!("./migrations_missing/V2__add_cars_table.sql"), ) .unwrap(); diff --git a/refinery_cli/src/migrate.rs b/refinery_cli/src/migrate.rs index 6ece17d6..d3432968 100644 --- a/refinery_cli/src/migrate.rs +++ b/refinery_cli/src/migrate.rs @@ -36,16 +36,20 @@ fn run_migrations( let mut migrations = Vec::new(); for path in migration_files_path { let sql = std::fs::read_to_string(path.as_path()) - .with_context(|| format!("could not read migration file name {}", path.display()))?; + .with_context(|| format!("Could not read contents of file {}", path.display()))?; //safe to call unwrap as find_migration_filenames returns canonical paths let filename = path - .file_stem() + .file_name() .and_then(|file| file.to_os_string().into_string().ok()) .unwrap(); - let migration = Migration::unapplied(&filename, &sql) - .with_context(|| format!("could not read migration file name {}", path.display()))?; + let migration = Migration::unapplied(&filename, &sql).with_context(|| { + format!( + "Could not create new migration from contents of file {}", + path.display() + ) + })?; migrations.push(migration); } let mut config = config(config_location, env_var_opt)?; diff --git a/refinery_core/src/error.rs b/refinery_core/src/error.rs index 3e865c1e..4d5c4211 100644 --- a/refinery_core/src/error.rs +++ b/refinery_core/src/error.rs @@ -43,11 +43,14 @@ impl std::error::Error for Error { #[derive(Debug, TError)] pub enum Kind { /// An Error from an invalid file name migration - #[error("migration name must be in the format V{{number}}__{{name}}")] - InvalidName, + #[error("migration filename must be in the format V{{number}}__{{name}}.rs|sql")] + InvalidFilename, /// An Error from an invalid version on a file name migration #[error("migration version must be a valid integer")] InvalidVersion, + /// An Error from an invalid version type on a file name migration + #[error("migration version type must be either the V or U character.")] + InvalidType, /// An Error from a repeated version, migration version numbers must be unique #[error("migration {0} is repeated, migration versions must be unique")] RepeatedVersion(Migration), diff --git a/refinery_core/src/runner.rs b/refinery_core/src/runner.rs index aebf0833..ff3f0fa3 100644 --- a/refinery_core/src/runner.rs +++ b/refinery_core/src/runner.rs @@ -11,9 +11,10 @@ use crate::traits::DEFAULT_MIGRATION_TABLE_NAME; use crate::{AsyncMigrate, Error, Migrate}; use std::fmt::Formatter; -// regex used to match file names +// Regex matching migration semantics for filenames. pub fn file_match_re() -> Regex { - Regex::new(r"^([U|V])(\d+(?:\.\d+)?)__(\w+)").unwrap() + Regex::new(r"^(?P[^_])(?P[^_]+)__(?P.+)(?P\.(sql|rs))$") + .unwrap() // } lazy_static::lazy_static! { @@ -81,23 +82,25 @@ pub struct Migration { } impl Migration { - /// Create an unapplied migration, name and version are parsed from the input_name, - /// which must be named in the format (U|V){1}__{2}.rs where {1} represents the migration version and {2} the name. - pub fn unapplied(input_name: &str, sql: &str) -> Result { + /// Create an unapplied migration, name, version and prefix are parsed from the input_name. + /// input_name must be named in the format (U|V){1}__{2}.rs where {1} represents the migration version(integer) and {2} the name. + pub fn unapplied(input_file_name: &str, sql: &str) -> Result { let captures = RE - .captures(input_name) - .filter(|caps| caps.len() == 4) - .ok_or_else(|| Error::new(Kind::InvalidName, None))?; - let version: i32 = captures[2] + .captures(input_file_name) + .ok_or_else(|| Error::new(Kind::InvalidFilename, None))?; + let version: i32 = captures + .name("version") + .unwrap() + .as_str() .parse() .map_err(|_| Error::new(Kind::InvalidVersion, None))?; - let name: String = (&captures[3]).into(); - let prefix = match &captures[1] { - "V" => Type::Versioned, - "U" => Type::Unversioned, - _ => unreachable!(), - }; + let name: String = captures.name("name").unwrap().as_str().to_owned(); + let prefix = match captures.name("type").unwrap().as_str() { + "V" => Ok(Type::Versioned), + "U" => Ok(Type::Unversioned), + _ => Err(Error::new(Kind::InvalidType, None)), + }?; // Previously, `std::collections::hash_map::DefaultHasher` was used // to calculate the checksum and the implementation at that time @@ -209,6 +212,88 @@ impl PartialOrd for Migration { } } +#[cfg(test)] +mod tests { + use super::{Error, Kind, Migration}; + + fn is_invalid_version(err: Error) -> bool { + match err.kind() { + Kind::InvalidVersion => true, + _ => false, + } + } + + fn is_invalid_type(err: Error) -> bool { + match err.kind() { + Kind::InvalidType => true, + _ => false, + } + } + + fn is_invalid_filename(err: Error) -> bool { + match err.kind() { + Kind::InvalidFilename => true, + _ => false, + } + } + + #[test] + fn filename_has_bad_extension() { + assert!(is_invalid_filename( + Migration::unapplied("V1__name.txt", "select 1").expect_err("expected error") + )); + } + + #[test] + fn filename_stem_missing_double_underscores() { + assert!(is_invalid_filename( + Migration::unapplied("V1_name.rs", "select 1").expect_err("expected error") + )); + } + + #[test] + fn filename_stem_has_bad_version_number_format() { + assert!(is_invalid_version( + Migration::unapplied("V1.1__name.rs", "select 1").expect_err("expected error") + )); + assert!(is_invalid_version( + Migration::unapplied("V1f__name.rs", "select 1").expect_err("expected error") + )); + assert!(is_invalid_version( + Migration::unapplied("V0,5__name.rs", "select 1").expect_err("expected error") + )); + assert!(is_invalid_version( + Migration::unapplied("Vff__name.rs", "select 1").expect_err("expected error") + )); + } + #[test] + fn filename_stem_has_bad_prefix_format() { + assert!(is_invalid_type( + Migration::unapplied("z1__name.rs", "select 1").expect_err("expected error") + )); + assert!(is_invalid_type( + Migration::unapplied("v1__name.rs", "select 1").expect_err("expected error") + )); + assert!(is_invalid_type( + Migration::unapplied("u1__name.rs", "select 1").expect_err("expected error") + )); + } + + #[test] + fn Filename_has_good_format() { + // accepted prefix variants + assert!(Migration::unapplied("V1__name.sql", "select 1").is_ok()); + assert!(Migration::unapplied("U1__name.rs", "select 1").is_ok()); + // accepted version number format + assert!(Migration::unapplied("V1__name.rs", "select 1").is_ok()); + assert!(Migration::unapplied("V001__name.rs", "select 1").is_ok()); + assert!(Migration::unapplied("V000__name.rs", "select 1").is_ok()); + // accepted migration name + assert!(Migration::unapplied("V000__name-with-dashes.rs", "select 1").is_ok()); + assert!(Migration::unapplied("V000__name with spaces.rs", "select 1").is_ok()); + assert!(Migration::unapplied("V000__name1with2numbers.rs", "select 1").is_ok()); + } +} /// Struct that represents the report of the migration cycle, /// a `Report` instance is returned by the [`Runner::run`] and [`Runner::run_async`] methods /// via [`Result`]``, on case of an [`Error`] during a migration, you can access the `Report` with [`Error.report`] diff --git a/refinery_core/src/traits/mod.rs b/refinery_core/src/traits/mod.rs index b070aeeb..02f484af 100644 --- a/refinery_core/src/traits/mod.rs +++ b/refinery_core/src/traits/mod.rs @@ -124,13 +124,13 @@ mod tests { .unwrap(); let migration3 = Migration::unapplied( - "V3__add_brand_to_cars_table", + "V3__add_brand_to_cars_table.sql", include_str!("../../../refinery/tests/migrations/V3/V3__add_brand_to_cars_table.sql"), ) .unwrap(); let migration4 = Migration::unapplied( - "V4__add_year_field_to_cars", + "V4__add_year_field_to_cars.sql", "ALTER TABLE cars ADD year INTEGER;", ) .unwrap(); @@ -166,7 +166,7 @@ mod tests { migrations[0].clone(), migrations[1].clone(), Migration::unapplied( - "V3__add_brand_to_cars_tableeee", + "V3__add_brand_to_cars_tableeee.sql", include_str!( "../../../refinery/tests/migrations/V3/V3__add_brand_to_cars_table.sql" ), @@ -192,7 +192,7 @@ mod tests { migrations[0].clone(), migrations[1].clone(), Migration::unapplied( - "V3__add_brand_to_cars_tableeee", + "V3__add_brand_to_cars_tableeee.sql", include_str!( "../../../refinery/tests/migrations/V3/V3__add_brand_to_cars_table.sql" ), @@ -264,7 +264,7 @@ mod tests { let mut migrations = get_migrations(); migrations.push( Migration::unapplied( - "U0__merge_out_of_order", + "U0__merge_out_of_order.sql", include_str!( "../../../refinery/tests/migrations_unversioned/U0__merge_out_of_order.sql" ), diff --git a/refinery_core/src/util.rs b/refinery_core/src/util.rs index d46dea27..122d5f36 100644 --- a/refinery_core/src/util.rs +++ b/refinery_core/src/util.rs @@ -4,10 +4,6 @@ use std::ffi::OsStr; use std::path::{Path, PathBuf}; use walkdir::{DirEntry, WalkDir}; -lazy_static::lazy_static! { - static ref RE: regex::Regex = Regex::new(r"^(U|V)(\d+(?:\.\d+)?)__\w+\.(rs|sql)$").unwrap(); -} - /// enum containing the migration types used to search for migrations /// either just .sql files or both .sql and .rs pub enum MigrationType { @@ -21,7 +17,7 @@ impl MigrationType { MigrationType::All => "(rs|sql)", MigrationType::Sql => "sql", }; - let re_str = format!(r"^(U|V)(\d+(?:\.\d+)?)__(\w+)\.{}$", ext); + let re_str = format!(r"^.*\.{}$", ext); Regex::new(re_str.as_str()).unwrap() } } @@ -44,15 +40,12 @@ pub fn find_migration_files( .into_iter() .filter_map(Result::ok) .map(DirEntry::into_path) - // filter by migration file regex + // Filter by migration type encoded in file extension. .filter( move |entry| match entry.file_name().and_then(OsStr::to_str) { Some(file_name) if re.is_match(file_name) => true, Some(file_name) => { - log::warn!( - "File \"{}\" does not adhere to the migration naming convention. Migrations must be named in the format [U|V]{{1}}__{{2}}.sql or [U|V]{{1}}__{{2}}.rs, where {{1}} represents the migration version and {{2}} the name.", - file_name - ); + log::warn!("Filename \"{}\" has not supported extension.", file_name); false } None => false, @@ -70,84 +63,42 @@ mod tests { use tempfile::TempDir; #[test] - fn finds_mod_migrations() { + fn ignores_files_without_supported_file_extension() { let tmp_dir = TempDir::new().unwrap(); let migrations_dir = tmp_dir.path().join("migrations"); fs::create_dir(&migrations_dir).unwrap(); - let sql1 = migrations_dir.join("V1__first.rs"); - fs::File::create(&sql1).unwrap(); - let sql2 = migrations_dir.join("V2__second.rs"); - fs::File::create(&sql2).unwrap(); - - let mut mods: Vec = find_migration_files(migrations_dir, MigrationType::All) - .unwrap() - .collect(); - mods.sort(); - assert_eq!(sql1.canonicalize().unwrap(), mods[0]); - assert_eq!(sql2.canonicalize().unwrap(), mods[1]); - } + let file1 = migrations_dir.join("V1__first.txt"); + fs::File::create(&file1).unwrap(); - #[test] - fn ignores_mod_files_without_migration_regex_match() { - let tmp_dir = TempDir::new().unwrap(); - let migrations_dir = tmp_dir.path().join("migrations"); - fs::create_dir(&migrations_dir).unwrap(); - let sql1 = migrations_dir.join("V1first.rs"); - fs::File::create(&sql1).unwrap(); - let sql2 = migrations_dir.join("V2second.rs"); - fs::File::create(&sql2).unwrap(); + let mut all = find_migration_files(migrations_dir, MigrationType::All).unwrap(); + assert!(all.next().is_none()); - let mut mods = find_migration_files(migrations_dir, MigrationType::All).unwrap(); - assert!(mods.next().is_none()); + let sql_migrations_dir = tmp_dir.path().join("migrations"); + let mut sqls = find_migration_files(sql_migrations_dir, MigrationType::Sql).unwrap(); + assert!(sqls.next().is_none()); } #[test] - fn finds_sql_migrations() { + fn finds_files_with_supported_file_extension() { let tmp_dir = TempDir::new().unwrap(); let migrations_dir = tmp_dir.path().join("migrations"); fs::create_dir(&migrations_dir).unwrap(); - let sql1 = migrations_dir.join("V1__first.sql"); - fs::File::create(&sql1).unwrap(); - let sql2 = migrations_dir.join("V2__second.sql"); - fs::File::create(&sql2).unwrap(); + let file1 = migrations_dir.join("V1__all_good.rs"); + fs::File::create(&file1).unwrap(); + let file2 = migrations_dir.join("V2_invalid_format_but_good_extension.sql"); + fs::File::create(&file2).unwrap(); - let mut mods: Vec = find_migration_files(migrations_dir, MigrationType::All) + let sqls: Vec = find_migration_files(migrations_dir, MigrationType::Sql) .unwrap() .collect(); - mods.sort(); - assert_eq!(sql1.canonicalize().unwrap(), mods[0]); - assert_eq!(sql2.canonicalize().unwrap(), mods[1]); - } + assert_eq!(file2.canonicalize().unwrap(), sqls[0]); - #[test] - fn finds_unversioned_migrations() { - let tmp_dir = TempDir::new().unwrap(); - let migrations_dir = tmp_dir.path().join("migrations"); - fs::create_dir(&migrations_dir).unwrap(); - let sql1 = migrations_dir.join("U1__first.sql"); - fs::File::create(&sql1).unwrap(); - let sql2 = migrations_dir.join("U2__second.sql"); - fs::File::create(&sql2).unwrap(); - - let mut mods: Vec = find_migration_files(migrations_dir, MigrationType::All) + let all_migrations_dir = tmp_dir.path().join("migrations"); + let mut all: Vec = find_migration_files(all_migrations_dir, MigrationType::All) .unwrap() .collect(); - mods.sort(); - assert_eq!(sql1.canonicalize().unwrap(), mods[0]); - assert_eq!(sql2.canonicalize().unwrap(), mods[1]); - } - - #[test] - fn ignores_sql_files_without_migration_regex_match() { - let tmp_dir = TempDir::new().unwrap(); - let migrations_dir = tmp_dir.path().join("migrations"); - fs::create_dir(&migrations_dir).unwrap(); - let sql1 = migrations_dir.join("V1first.sql"); - fs::File::create(&sql1).unwrap(); - let sql2 = migrations_dir.join("V2second.sql"); - fs::File::create(&sql2).unwrap(); - - let mut mods = find_migration_files(migrations_dir, MigrationType::All).unwrap(); - assert!(mods.next().is_none()); + all.sort(); + assert_eq!(file1.canonicalize().unwrap(), all[0]); + assert_eq!(file2.canonicalize().unwrap(), all[1]); } } diff --git a/refinery_macros/src/lib.rs b/refinery_macros/src/lib.rs index 891ddd48..7e147485 100644 --- a/refinery_macros/src/lib.rs +++ b/refinery_macros/src/lib.rs @@ -20,10 +20,10 @@ fn migration_fn_quoted(_migrations: Vec) -> TokenStream2 { let result = quote! { use refinery::{Migration, Runner}; pub fn runner() -> Runner { - let quoted_migrations: Vec<(&str, String)> = vec![#(#_migrations),*]; + let quoted_migrations: Vec<(&str, &str, String)> = vec![#(#_migrations),*]; let mut migrations: Vec = Vec::new(); for module in quoted_migrations.into_iter() { - migrations.push(Migration::unapplied(module.0, &module.1).unwrap()); + migrations.push(Migration::unapplied(format!("{}.{}",module.0,module.1).as_str(), &module.2).unwrap()); } Runner::new(&migrations) } @@ -59,27 +59,31 @@ pub fn embed_migrations(input: TokenStream) -> TokenStream { for migration in migration_files { // safe to call unwrap as find_migration_filenames returns canonical paths - let filename = migration + let file_stem = migration .file_stem() - .and_then(|file| file.to_os_string().into_string().ok()) + .and_then(|stem| stem.to_os_string().into_string().ok()) .unwrap(); let path = migration.display().to_string(); - let extension = migration.extension().unwrap(); + let file_extension = migration + .extension() + .and_then(|ext| ext.to_os_string().into_string().ok()) + .unwrap(); - if extension == "sql" { - _migrations.push(quote! {(#filename, include_str!(#path).to_string())}); - } else if extension == "rs" { + if file_extension == "sql" { + _migrations + .push(quote! {(#file_stem, #file_extension, include_str!(#path).to_string())}); + } else if file_extension == "rs" { let rs_content = fs::read_to_string(&path) .unwrap() .parse::() .unwrap(); - let ident = Ident::new(&filename, Span2::call_site()); + let ident = Ident::new(&file_stem, Span2::call_site()); let mig_mod = quote! {pub mod #ident { #rs_content // also include the file as str so we trigger recompilation if it changes const _recompile_if_changed: &str = include_str!(#path); }}; - _migrations.push(quote! {(#filename, #ident::migration())}); + _migrations.push(quote! {(#file_stem, #file_extension, #ident::migration())}); migrations_mods.push(mig_mod); } } @@ -100,14 +104,14 @@ mod tests { #[test] fn test_quote_fn() { - let migs = vec![quote!("V1__first", "valid_sql_file")]; + let migs = vec![quote!("V1__first", "sql", "valid_sql_file")]; let expected = concat! { "use refinery :: { Migration , Runner } ; ", "pub fn runner () -> Runner { ", - "let quoted_migrations : Vec < (& str , String) > = vec ! [\"V1__first\" , \"valid_sql_file\"] ; ", + "let quoted_migrations : Vec < (& str , & str , String) > = vec ! [\"V1__first\" , \"sql\" , \"valid_sql_file\"] ; ", "let mut migrations : Vec < Migration > = Vec :: new () ; ", "for module in quoted_migrations . into_iter () { ", - "migrations . push (Migration :: unapplied (module . 0 , & module . 1) . unwrap ()) ; ", + "migrations . push (Migration :: unapplied (format ! (\"{}.{}\" , module . 0 , module . 1) . as_str () , & module . 2) . unwrap ()) ; ", "} ", "Runner :: new (& migrations) }" };