diff --git a/sqlx-core/src/any/database.rs b/sqlx-core/src/any/database.rs index 9c3f15bb1f..0b45116a13 100644 --- a/sqlx-core/src/any/database.rs +++ b/sqlx-core/src/any/database.rs @@ -33,6 +33,8 @@ impl Database for Any { const NAME: &'static str = "Any"; const URL_SCHEMES: &'static [&'static str] = &[]; + + const TYPE_IMPORT_PATH: &'static str = "sqlx::any::database::Any"; } // This _may_ be true, depending on the selected database diff --git a/sqlx-core/src/database.rs b/sqlx-core/src/database.rs index d17621c719..0f130546f0 100644 --- a/sqlx-core/src/database.rs +++ b/sqlx-core/src/database.rs @@ -108,6 +108,10 @@ pub trait Database: 'static + Sized + Send + Debug { /// The schemes for database URLs that should match this driver. const URL_SCHEMES: &'static [&'static str]; + + // This can be removed once https://github.com/rust-lang/rust/issues/63084 is resolved and type_name is available in const fns. + /// The path to the database-specific type system. + const TYPE_IMPORT_PATH: &'static str; } /// A [`Database`] that maintains a client-side cache of prepared statements. diff --git a/sqlx-macros-core/src/query/input.rs b/sqlx-macros-core/src/query/input.rs index 63e35ec77d..c664d8190a 100644 --- a/sqlx-macros-core/src/query/input.rs +++ b/sqlx-macros-core/src/query/input.rs @@ -7,6 +7,7 @@ use syn::{Expr, LitBool, LitStr, Token}; use syn::{ExprArray, Type}; /// Macro input shared by `query!()` and `query_file!()` +#[derive(Clone)] pub struct QueryMacroInput { pub(super) sql: String, @@ -19,6 +20,9 @@ pub struct QueryMacroInput { pub(super) checked: bool, pub(super) file_path: Option, + + // TODO: This should be some type and not a string + pub(super) driver: Option, } enum QuerySrc { @@ -26,6 +30,7 @@ enum QuerySrc { File(String), } +#[derive(Clone)] pub enum RecordType { Given(Type), Scalar, @@ -38,6 +43,7 @@ impl Parse for QueryMacroInput { let mut args: Option> = None; let mut record_type = RecordType::Generated; let mut checked = true; + let mut driver = None; let mut expect_comma = false; @@ -82,6 +88,9 @@ impl Parse for QueryMacroInput { } else if key == "checked" { let lit_bool = input.parse::()?; checked = lit_bool.value; + } else if key == "driver" { + // TODO: This should be some actual type and not a string + driver = Some(input.parse::()?); } else { let message = format!("unexpected input key: {key}"); return Err(syn::Error::new_spanned(key, message)); @@ -104,6 +113,7 @@ impl Parse for QueryMacroInput { arg_exprs, checked, file_path, + driver, }) } } diff --git a/sqlx-macros-core/src/query/mod.rs b/sqlx-macros-core/src/query/mod.rs index 1536eebaa1..8fe0a05865 100644 --- a/sqlx-macros-core/src/query/mod.rs +++ b/sqlx-macros-core/src/query/mod.rs @@ -27,6 +27,7 @@ pub struct QueryDriver { db_name: &'static str, url_schemes: &'static [&'static str], expand: fn(QueryMacroInput, QueryDataSource) -> crate::Result, + db_type_name: &'static str, } impl QueryDriver { @@ -38,41 +39,94 @@ impl QueryDriver { db_name: DB::NAME, url_schemes: DB::URL_SCHEMES, expand: expand_with::, + db_type_name: DB::TYPE_IMPORT_PATH, } } } + +#[derive(Clone)] +pub struct QueryDataSourceUrl<'a> { + database_url: &'a str, + database_url_parsed: Url, +} + +impl<'a> From<&'a String> for QueryDataSourceUrl<'a> { + fn from(database_url: &'a String) -> Self { + let database_url_parsed = Url::parse(database_url).expect("invalid URL"); + + QueryDataSourceUrl { + database_url, + database_url_parsed, + } + } +} + +#[derive(Clone)] pub enum QueryDataSource<'a> { Live { - database_url: &'a str, - database_url_parsed: Url, + database_urls: Vec>, }, Cached(DynQueryData), } impl<'a> QueryDataSource<'a> { - pub fn live(database_url: &'a str) -> crate::Result { + pub fn live(database_urls: Vec>) -> crate::Result { Ok(QueryDataSource::Live { - database_url, - database_url_parsed: database_url.parse()?, + database_urls, }) } pub fn matches_driver(&self, driver: &QueryDriver) -> bool { match self { Self::Live { - database_url_parsed, + database_urls, .. - } => driver.url_schemes.contains(&database_url_parsed.scheme()), + } => driver.url_schemes.iter().any(|scheme| { + database_urls.iter().any(|url| url.database_url_parsed.scheme() == *scheme) + }), Self::Cached(dyn_data) => dyn_data.db_name == driver.db_name, } } + + pub fn get_url_for_schemes(&self, schemes: &[&str]) -> Option<&QueryDataSourceUrl> { + match self { + Self::Live { + database_urls, + .. + } => { + for scheme in schemes { + if let Some(url) = database_urls.iter().find(|url| url.database_url_parsed.scheme() == *scheme) { + return Some(url); + } + } + None + } + Self::Cached(_) => { + None + } + } + } + + pub fn supported_schemes(&self) -> Vec<&str> { + match self { + Self::Live { + database_urls, + .. + } => { + let mut schemes = vec![]; + schemes.extend(database_urls.iter().map(|url| url.database_url_parsed.scheme())); + schemes + } + Self::Cached(..) => vec![], + } + } } struct Metadata { #[allow(unused)] manifest_dir: PathBuf, offline: bool, - database_url: Option, + database_urls: Vec, workspace_root: Arc>>, } @@ -139,12 +193,10 @@ static METADATA: Lazy = Lazy::new(|| { .map(|s| s.eq_ignore_ascii_case("true") || s == "1") .unwrap_or(false); - let database_url = env("DATABASE_URL").ok(); - Metadata { manifest_dir, offline, - database_url, + database_urls: env_db_urls(), workspace_root: Arc::new(Mutex::new(None)), } }); @@ -156,9 +208,11 @@ pub fn expand_input<'a>( let data_source = match &*METADATA { Metadata { offline: false, - database_url: Some(db_url), + database_urls: db_urls, .. - } => QueryDataSource::live(db_url)?, + } => { + QueryDataSource::live(db_urls.iter().map(QueryDataSourceUrl::from).collect())? + }, Metadata { offline, .. } => { // Try load the cached query metadata file. @@ -189,6 +243,54 @@ pub fn expand_input<'a>( } }; + let mut working_drivers = vec![]; + + // If the driver was explicitly set, use it directly. + if let Some(input_driver) = input.driver.clone() { + for driver in drivers { + if (driver.expand)(input.clone(), data_source.clone()).is_ok() { + working_drivers.push(driver); + } + } + + return match working_drivers.len() { + 0 => { + Err(format!( + "no database driver found matching for query; the corresponding Cargo feature may need to be enabled" + ).into()) + } + 1 => { + let driver = working_drivers.pop().unwrap(); + (driver.expand)(input, data_source) + } + _ => { + let expansions = working_drivers.iter().map(|driver| { + let driver_name = driver.db_type_name; + let driver_type: Type = syn::parse_str(driver_name).unwrap(); + let expanded = (driver.expand)(input.clone(), data_source.clone()).unwrap(); + quote! { + impl ProvideQuery<#driver_type> for #driver_type { + fn provide_query<'a>() -> Query<'a, #driver_type, <#driver_type as sqlx::Database>::Arguments<'a>> { + #expanded + } + } + } + }); + Ok(quote! { + { + use sqlx::query::Query; + trait ProvideQuery { + fn provide_query<'a>() -> Query<'a, DB, DB::Arguments<'a>>; + } + #(#expansions)* + #input_driver::provide_query() + } + }) + } + } + } + + // If no driver was set, try to find a matching driver for the data source. for driver in drivers { if data_source.matches_driver(driver) { return (driver.expand)(input, data_source); @@ -196,12 +298,9 @@ pub fn expand_input<'a>( } match data_source { - QueryDataSource::Live { - database_url_parsed, - .. - } => Err(format!( + QueryDataSource::Live{..} => Err(format!( "no database driver found matching URL scheme {:?}; the corresponding Cargo feature may need to be enabled", - database_url_parsed.scheme() + data_source.supported_schemes() ).into()), QueryDataSource::Cached(data) => { Err(format!( @@ -221,8 +320,9 @@ where { let (query_data, offline): (QueryData, bool) = match data_source { QueryDataSource::Cached(dyn_data) => (QueryData::from_dyn_data(dyn_data)?, true), - QueryDataSource::Live { database_url, .. } => { - let describe = DB::describe_blocking(&input.sql, database_url)?; + QueryDataSource::Live { .. } => { + let data_source_url = data_source.get_url_for_schemes(DB::URL_SCHEMES).unwrap(); + let describe = DB::describe_blocking(&input.sql, data_source_url.database_url)?; (QueryData::from_describe(&input.sql, describe), false) } }; @@ -386,3 +486,7 @@ fn env(name: &str) -> Result { std::env::var(name) } } + +fn env_db_urls() -> Vec { + std::env::vars().filter(|(k, _)| k.starts_with("DATABASE_URL")).map(|(_, v)| v).collect() +} diff --git a/sqlx-mysql/src/database.rs b/sqlx-mysql/src/database.rs index d03a567284..97a93f43a2 100644 --- a/sqlx-mysql/src/database.rs +++ b/sqlx-mysql/src/database.rs @@ -33,6 +33,8 @@ impl Database for MySql { const NAME: &'static str = "MySQL"; const URL_SCHEMES: &'static [&'static str] = &["mysql", "mariadb"]; + + const TYPE_IMPORT_PATH: &'static str = "sqlx::mysql::MySql"; } impl HasStatementCache for MySql {} diff --git a/sqlx-postgres/src/database.rs b/sqlx-postgres/src/database.rs index 876e295899..7956acbabf 100644 --- a/sqlx-postgres/src/database.rs +++ b/sqlx-postgres/src/database.rs @@ -35,6 +35,8 @@ impl Database for Postgres { const NAME: &'static str = "PostgreSQL"; const URL_SCHEMES: &'static [&'static str] = &["postgres", "postgresql"]; + + const TYPE_IMPORT_PATH: &'static str = "sqlx::postgres::Postgres"; } impl HasStatementCache for Postgres {} diff --git a/sqlx-sqlite/src/database.rs b/sqlx-sqlite/src/database.rs index c89c7b8322..7d70975eaa 100644 --- a/sqlx-sqlite/src/database.rs +++ b/sqlx-sqlite/src/database.rs @@ -34,6 +34,8 @@ impl Database for Sqlite { const NAME: &'static str = "SQLite"; const URL_SCHEMES: &'static [&'static str] = &["sqlite"]; + + const TYPE_IMPORT_PATH: &'static str = "sqlx::sqlite::Sqlite"; } impl HasStatementCache for Sqlite {} diff --git a/src/macros/mod.rs b/src/macros/mod.rs index 2523c259d7..24e44d5d11 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -316,6 +316,9 @@ macro_rules! query ( ($query:expr) => ({ $crate::sqlx_macros::expand_query!(source = $query) }); + ($driver:ty, $query:expr) => ({ + $crate::sqlx_macros::expand_query!(source = $query, driver = $driver) + }); // RFC: this semantically should be `$($args:expr),*` (with `$(,)?` to allow trailing comma) // but that doesn't work in 1.45 because `expr` fragments get wrapped in a way that changes // their hygiene, which is fixed in 1.46 so this is technically just a temp. workaround. @@ -326,6 +329,9 @@ macro_rules! query ( // not like it makes them magically understandable at-a-glance. ($query:expr, $($args:tt)*) => ({ $crate::sqlx_macros::expand_query!(source = $query, args = [$($args)*]) + }); + ($driver:ty, $query:expr, $($args:tt)*) => ({ + $crate::sqlx_macros::expand_query!(source = $query, args = [$($args)*], driver = $driver) }) );