feat: make sqlx work with multiple different databases in the same crate #3397

Closed
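For orientation before the diffs: this PR lets the query macros resolve their database URL from a caller-chosen environment variable instead of only `DATABASE_URL`, and shards the offline query cache into per-variable subdirectories. The sketch below shows the internal macro input this enables; it is illustrative only, since the user-facing wrapper is generated by the new `database_macros!` macro in expand.rs (not rendered below), and it assumes the crate's existing `source = ...` key form plus a hypothetical `OTHER_DB_URL` variable.

```rust
// Illustrative input to the internal `expand_query!` proc macro after this PR:
// `db_url_env` names the environment variable holding the connection URL used
// to check this query at compile time (and to shard its offline cache).
sqlx_macros::expand_query!(
    source = "SELECT id FROM users WHERE id = 1",
    db_url_env = "OTHER_DB_URL"
);
```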
752 changes: 752 additions & 0 deletions sqlx-macros-core/src/database_macro/expand.rs

[752-line diff not rendered. This file implements the expansion of the new `database_macros!` macro (`expand_database_macros`, re-exported in mod.rs below).]

39 changes: 39 additions & 0 deletions sqlx-macros-core/src/database_macro/input.rs
@@ -0,0 +1,39 @@
+use syn::{
+    parse::{Parse, ParseStream},
+    Ident, LitStr,
+};
+
+/// Macro input for `database_macros!()`.
+pub struct DatabaseMacroInput {
+    pub(super) env: String,
+}
+
+impl Parse for DatabaseMacroInput {
+    fn parse(input: ParseStream) -> syn::Result<Self> {
+        let mut env = None;
+
+        let mut expect_comma = false;
+        while !input.is_empty() {
+            if expect_comma {
+                let _ = input.parse::<syn::token::Comma>()?;
+            }
+
+            let key: Ident = input.parse()?;
+
+            let _ = input.parse::<syn::token::Eq>()?;
+
+            if key == "env" {
+                env = Some(input.parse::<LitStr>()?.value());
+            } else {
+                let message = format!("unexpected input key: {key}");
+                return Err(syn::Error::new_spanned(key, message));
+            }
+
+            expect_comma = true;
+        }
+
+        let env = env.ok_or_else(|| input.error("expected `env` key"))?;
+
+        Ok(DatabaseMacroInput { env })
+    }
+}
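Given this parser, `database_macros!` takes `key = value` pairs and supports exactly one key, the required `env`, naming the environment variable that holds the database URL. A sketch of an invocation this grammar accepts (what it expands to lives in expand.rs, not rendered above, and the `sqlx::` re-export path is assumed):

```rust
// Accepted: the required `env` key with a string literal value.
sqlx::database_macros!(env = "ANALYTICS_DATABASE_URL");

// Rejected at compile time with "unexpected input key: url":
// sqlx::database_macros!(url = "postgres://...");
```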
5 changes: 5 additions & 0 deletions sqlx-macros-core/src/database_macro/mod.rs
@@ -0,0 +1,5 @@
+mod expand;
+mod input;
+
+pub use expand::expand_database_macros;
+pub use input::DatabaseMacroInput;
2 changes: 2 additions & 0 deletions sqlx-macros-core/src/lib.rs
@@ -29,6 +29,8 @@ pub type Result<T> = std::result::Result<T, Error>;
 mod common;
 mod database;
 
+#[cfg(feature = "macros")]
+pub mod database_macro;
 #[cfg(feature = "derive")]
 pub mod derives;
 #[cfg(feature = "macros")]
7 changes: 7 additions & 0 deletions sqlx-macros-core/src/query/input.rs
@@ -19,6 +19,8 @@ pub struct QueryMacroInput {
     pub(super) checked: bool,
 
     pub(super) file_path: Option<String>,
+
+    pub(super) db_url_env: Option<String>,
 }
 
 enum QuerySrc {
@@ -38,6 +40,7 @@ impl Parse for QueryMacroInput {
         let mut args: Option<Vec<Expr>> = None;
         let mut record_type = RecordType::Generated;
         let mut checked = true;
+        let mut db_url_env = None;
 
         let mut expect_comma = false;
 
@@ -82,6 +85,9 @@
             } else if key == "checked" {
                 let lit_bool = input.parse::<LitBool>()?;
                 checked = lit_bool.value;
+            } else if key == "db_url_env" {
+                let lit_str = input.parse::<LitStr>()?;
+                db_url_env = Some(lit_str.value());
             } else {
                 let message = format!("unexpected input key: {key}");
                 return Err(syn::Error::new_spanned(key, message));
@@ -104,6 +110,7 @@
             arg_exprs,
             checked,
             file_path,
+            db_url_env,
         })
     }
 }
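For readers tracing the new key, a hypothetical unit test (not part of the PR) showing how `db_url_env` would round-trip through `QueryMacroInput`; it assumes the macro's existing `source = ...` key for the SQL text:

```rust
#[cfg(test)]
mod db_url_env_tests {
    use super::QueryMacroInput;

    #[test]
    fn parses_optional_db_url_env_key() {
        // With the key: the env-var name is captured.
        let input: QueryMacroInput =
            syn::parse_str(r#"source = "SELECT 1", db_url_env = "OTHER_DB_URL""#).unwrap();
        assert_eq!(input.db_url_env.as_deref(), Some("OTHER_DB_URL"));

        // Without it: `None`, meaning "fall back to DATABASE_URL".
        let input: QueryMacroInput = syn::parse_str(r#"source = "SELECT 1""#).unwrap();
        assert!(input.db_url_env.is_none());
    }
}
```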
122 changes: 85 additions & 37 deletions sqlx-macros-core/src/query/mod.rs
@@ -1,3 +1,4 @@
+use std::collections::HashMap;
 use std::path::PathBuf;
 use std::sync::{Arc, Mutex};
 use std::{fs, io};
@@ -72,8 +73,9 @@ struct Metadata {
     #[allow(unused)]
     manifest_dir: PathBuf,
     offline: bool,
-    database_url: Option<String>,
+    default_database_url: Option<String>,
     workspace_root: Arc<Mutex<Option<PathBuf>>>,
+    env_cache: HashMap<String, String>,
 }
 
 impl Metadata {
@@ -139,54 +141,80 @@ static METADATA: Lazy<Metadata> = Lazy::new(|| {
         .map(|s| s.eq_ignore_ascii_case("true") || s == "1")
         .unwrap_or(false);
 
-    let database_url = env("DATABASE_URL").ok();
+    let env_cache = HashMap::from_iter(dotenvy::vars());
+
+    let default_database_url = env("DATABASE_URL").ok();
 
     Metadata {
         manifest_dir,
         offline,
-        database_url,
+        default_database_url,
         workspace_root: Arc::new(Mutex::new(None)),
+        env_cache,
     }
 });
 
 pub fn expand_input<'a>(
     input: QueryMacroInput,
     drivers: impl IntoIterator<Item = &'a QueryDriver>,
 ) -> crate::Result<TokenStream> {
-    let data_source = match &*METADATA {
-        Metadata {
-            offline: false,
-            database_url: Some(db_url),
-            ..
-        } => QueryDataSource::live(db_url)?,
-
-        Metadata { offline, .. } => {
-            // Try load the cached query metadata file.
-            let filename = format!("query-{}.json", hash_string(&input.sql));
-
-            // Check SQLX_OFFLINE_DIR, then local .sqlx, then workspace .sqlx.
-            let dirs = [
-                || env("SQLX_OFFLINE_DIR").ok().map(PathBuf::from),
-                || Some(METADATA.manifest_dir.join(".sqlx")),
-                || Some(METADATA.workspace_root().join(".sqlx")),
-            ];
-            let Some(data_file_path) = dirs
-                .iter()
-                .filter_map(|path| path())
-                .map(|path| path.join(&filename))
-                .find(|path| path.exists())
-            else {
-                return Err(
-                    if *offline {
-                        "`SQLX_OFFLINE=true` but there is no cached data for this query, run `cargo sqlx prepare` to update the query cache or unset `SQLX_OFFLINE`"
-                    } else {
-                        "set `DATABASE_URL` to use query macros online, or run `cargo sqlx prepare` to update the query cache"
-                    }.into()
-                );
-            };
-
-            QueryDataSource::Cached(DynQueryData::from_data_file(&data_file_path, &input.sql)?)
-        }
-    };
+    // If the query is not required to be offline, check whether we have a valid online data-source URL.
+    let online_data_source: Option<QueryDataSource> = if !METADATA.offline {
+        if let Some(ref custom_env) = input.db_url_env {
+            // Look up the custom database-URL environment variable.
+            METADATA
+                .env_cache
+                .get(custom_env)
+                .map(|custom_db_url| QueryDataSource::live(custom_db_url))
+                .transpose()?
+        } else if let Some(default_database_url) = &METADATA.default_database_url {
+            // Fall back to the default `DATABASE_URL`.
+            Some(QueryDataSource::live(default_database_url)?)
+        } else {
+            None
+        }
+    } else {
+        None
+    };
+
+    let data_source = if let Some(data_source) = online_data_source {
+        data_source
+    } else {
+        // If we don't have a live source, try to load the cached query metadata file.
+        let filename = format!("query-{}.json", hash_string(&input.sql));
+
+        // Check SQLX_OFFLINE_DIR, then local .sqlx, then workspace .sqlx.
+        let dirs = [
+            || env("SQLX_OFFLINE_DIR").ok().map(PathBuf::from),
+            || Some(METADATA.manifest_dir.join(".sqlx")),
+            || Some(METADATA.workspace_root().join(".sqlx")),
+        ];
+        let Some(data_file_path) = dirs
+            .iter()
+            .filter_map(|path| path())
+            .map(|path| {
+                // Queries pinned to a custom env var are cached in a subdirectory named after it.
+                if let Some(ref custom_env) = input.db_url_env {
+                    path.join(custom_env).join(&filename)
+                } else {
+                    path.join(&filename)
+                }
+            })
+            .find(|path| path.exists())
+        else {
+            return Err(
+                if METADATA.offline {
+                    "`SQLX_OFFLINE=true` but there is no cached data for this query, run `cargo sqlx prepare` to update the query cache or unset `SQLX_OFFLINE`".to_string()
+                } else if let Some(custom_env) = input.db_url_env {
+                    format!("set custom env `{:?}` to use query macros online, or run `cargo sqlx prepare` to update the query cache", custom_env)
+                } else {
+                    "set `DATABASE_URL` to use query macros online, or run `cargo sqlx prepare` to update the query cache".to_string()
+                }
+                .into(),
+            );
+        };
+
+        QueryDataSource::Cached(DynQueryData::from_data_file(&data_file_path, &input.sql)?)
+    };
 
     for driver in drivers {
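The lookup above keys the offline cache by env-var name. Restated as a standalone sketch of just the path rule (names hypothetical, not code from the PR):

```rust
use std::path::{Path, PathBuf};

// Mirrors the `.map(|path| ...)` closure above: queries pinned to a custom
// env var get their own subdirectory of the offline cache, so prepared data
// for different databases cannot collide on the same query hash.
fn cached_query_path(cache_dir: &Path, db_url_env: Option<&str>, query_hash: &str) -> PathBuf {
    let filename = format!("query-{query_hash}.json");
    match db_url_env {
        // e.g. `.sqlx/OTHER_DB_URL/query-<hash>.json`
        Some(env_name) => cache_dir.join(env_name).join(filename),
        // e.g. `.sqlx/query-<hash>.json` (the unchanged default layout)
        None => cache_dir.join(filename),
    }
}
```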
@@ -364,8 +392,28 @@ where
                             .into());
                         }
 
-                        // .sqlx exists and is a directory, store data.
-                        data.save_in(path)?;
+                        if let Some(custom_db_env) = input.db_url_env {
+                            let full_path: PathBuf = path.join(custom_db_env);
+
+                            match fs::create_dir(&full_path) {
+                                Ok(_) => {}
+                                Err(err) => {
+                                    match err.kind() {
+                                        std::io::ErrorKind::AlreadyExists => {}
+                                        _ => return Err(format!(
+                                            "Failed to create offline cache path {full_path:?}: {err}"
+                                        )
+                                        .into()),
+                                    }
+                                }
+                            };
+
+                            // The subfolder now exists; store the data there.
+                            data.save_in(full_path)?;
+                        } else {
+                            // .sqlx exists and is a directory, store data.
+                            data.save_in(path)?;
+                        }
                     }
                 }
             }
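A design note on the directory creation above: matching on `ErrorKind::AlreadyExists` makes `fs::create_dir` idempotent by hand. `fs::create_dir_all` expresses the same intent more compactly, with the caveat that it also creates missing parents, which the surrounding code has already verified for `.sqlx` itself. An observation, not a change the PR makes:

```rust
use std::{fs, io, path::Path};

// Equivalent idempotent creation of the per-database cache subdirectory:
// `create_dir_all` succeeds when the directory already exists.
fn ensure_cache_dir(full_path: &Path) -> io::Result<()> {
    fs::create_dir_all(full_path)
}
```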
7 changes: 7 additions & 0 deletions sqlx-macros/src/lib.rs
@@ -22,6 +22,13 @@ pub fn expand_query(input: TokenStream) -> TokenStream {
     }
 }
 
+#[cfg(feature = "macros")]
+#[proc_macro]
+pub fn database_macros(input: TokenStream) -> TokenStream {
+    let input = syn::parse_macro_input!(input as database_macro::DatabaseMacroInput);
+    database_macro::expand_database_macros(input).into()
+}
+
 #[cfg(feature = "derive")]
 #[proc_macro_derive(Encode, attributes(sqlx))]
 pub fn derive_encode(tokenstream: TokenStream) -> TokenStream {