Skip to content

Commit

Permalink
add helper function for sql bindings, get_mixed_query_params()
Browse files Browse the repository at this point in the history
  • Loading branch information
lmcmicu committed Sep 16, 2024
1 parent 35f2464 commit 41c9ede
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 51 deletions.
46 changes: 46 additions & 0 deletions src/toolkit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,14 @@ pub enum QueryAsIfKind {
Replace,
}

/// Used to represent a generic query parameter for binding to a SQLx query.
pub enum QueryParam {
Numeric(f64),
Real(f64),
Integer(i32),
String(String),
}

/// Given a string representing the location of a database, return a database connection pool.
pub async fn get_pool_from_connection_string(database: &str) -> Result<AnyPool> {
let connection_options;
Expand Down Expand Up @@ -3743,6 +3751,44 @@ pub fn compile_condition(
}
}

/// Given a list of [SerdeValue]s and the SQL type of the column that they come from, return
/// a SQL string consisting of a comma-separated list of [SQL_PARAM] placeholders to use for the
/// binding, and the list of parameters that will need to be bound to the string before executing.
pub fn get_mixed_query_params(
values: &Vec<SerdeValue>,
sql_type: &str,
) -> (String, Vec<QueryParam>) {
let mut param_values = vec![];
let mut param_placeholders = vec![];

for value in values {
param_placeholders.push(SQL_PARAM);
let param_value = value
.as_str()
.expect(&format!("'{}' is not a string", value));
if sql_type == "numeric" {
let numeric_value: f64 = param_value
.parse()
.expect(&format!("{param_value} is not numeric"));
param_values.push(QueryParam::Numeric(numeric_value));
} else if sql_type == "integer" {
let integer_value: i32 = param_value
.parse()
.expect(&format!("{param_value} is not an integer"));
param_values.push(QueryParam::Integer(integer_value));
} else if sql_type == "real" {
let real_value: f64 = param_value
.parse()
.expect(&format!("{param_value} is not a real"));
param_values.push(QueryParam::Real(real_value));
} else {
param_values.push(QueryParam::String(param_value.to_string()));
}
}

(param_placeholders.join(", "), param_values)
}

/// Given the config map, the name of a datatype, and a database connection pool used to determine
/// the database type, climb the datatype tree (as required), and return the first 'SQL type' found.
/// If there is no SQL type defined for the given datatype, return TEXT.
Expand Down
117 changes: 73 additions & 44 deletions src/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use crate::{
ast::Expression,
toolkit::{
cast_sql_param_from_text, get_column_value, get_column_value_as_string,
get_datatype_ancestors, get_sql_type_from_global_config, get_table_options_from_config,
get_value_type, is_sql_type_error, local_sql_syntax, ColumnRule, CompiledCondition,
QueryAsIf, QueryAsIfKind, ValueType,
get_datatype_ancestors, get_mixed_query_params, get_sql_type_from_global_config,
get_table_options_from_config, get_value_type, is_sql_type_error, local_sql_syntax,
ColumnRule, CompiledCondition, QueryAsIf, QueryAsIfKind, QueryParam, ValueType,
},
valve::{
ValveCell, ValveCellMessage, ValveConfig, ValveRow, ValveRuleConfig, ValveTreeConstraint,
Expand Down Expand Up @@ -309,34 +309,42 @@ pub async fn validate_rows_constraints(
let sql_type =
get_sql_type_from_global_config(config, &fkey.ftable, &fkey.fcolumn, pool)
.to_lowercase();

// TODO: Here.
let values_str = received_values
let values = received_values
.get(&*fkey.column)
.unwrap()
.iter()
.map(|value| {
let value = value
.as_str()
.expect(&format!("'{}' is not a string", value));
if vec!["integer", "numeric", "real"].contains(&sql_type.as_str()) {
format!("{}", value)
} else {
format!("'{}'", value.replace("'", "''"))
}
.filter(|value| {
!is_sql_type_error(
&sql_type,
value
.as_str()
.expect(&format!("'{}' is not a string", value)),
)
})
.filter(|value| !is_sql_type_error(&sql_type, value))
.collect::<Vec<_>>()
.join(", ");
.cloned()
.collect::<Vec<_>>();
let (lookup_sql, param_values) = get_mixed_query_params(&values, &sql_type);

// Foreign keys always correspond to columns with unique constraints so we do not
// need to use the keyword 'DISTINCT' when querying the normal version of the table:
let sql = format!(
r#"SELECT "{}" FROM "{}" WHERE "{}" IN ({})"#,
fkey.fcolumn, fkey.ftable, fkey.fcolumn, values_str
let sql = local_sql_syntax(
pool,
&format!(
r#"SELECT "{}" FROM "{}" WHERE "{}" IN ({})"#,
fkey.fcolumn, fkey.ftable, fkey.fcolumn, lookup_sql
),
);
let mut query = sqlx_query(&sql);
for param_value in &param_values {
match param_value {
QueryParam::Integer(p) => query = query.bind(p),
QueryParam::Numeric(p) => query = query.bind(p),
QueryParam::Real(p) => query = query.bind(p),
QueryParam::String(p) => query = query.bind(p),
}
}

let allowed_values = sqlx_query(&sql)
let allowed_values = query
.fetch_all(pool)
.await?
.iter()
Expand All @@ -353,11 +361,23 @@ pub async fn validate_rows_constraints(
// The conflict table has no keys other than on row_number so in principle
// it could have duplicate values of the foreign constraint, therefore we
// add the DISTINCT keyword here:
let sql = format!(
r#"SELECT DISTINCT "{}" FROM "{}_conflict" WHERE "{}" IN ({})"#,
fkey.fcolumn, fkey.ftable, fkey.fcolumn, values_str
let sql = local_sql_syntax(
pool,
&format!(
r#"SELECT DISTINCT "{}" FROM "{}_conflict" WHERE "{}" IN ({})"#,
fkey.fcolumn, fkey.ftable, fkey.fcolumn, lookup_sql
),
);
sqlx_query(&sql)
let mut query = sqlx_query(&sql);
for param_value in &param_values {
match param_value {
QueryParam::Integer(p) => query = query.bind(p),
QueryParam::Numeric(p) => query = query.bind(p),
QueryParam::Real(p) => query = query.bind(p),
QueryParam::String(p) => query = query.bind(p),
}
}
query
.fetch_all(pool)
.await?
.iter()
Expand Down Expand Up @@ -411,33 +431,42 @@ pub async fn validate_rows_constraints(
}
};

// TODO: Here.
let sql_type =
get_sql_type_from_global_config(config, &table, &column, pool).to_lowercase();
let values_str = received_values
let values = received_values
.get(&*column)
.unwrap()
.iter()
.map(|value| {
let value = value
.as_str()
.expect(&format!("'{}' is not a string", value));
if vec!["integer", "numeric", "real"].contains(&sql_type.as_str()) {
format!("{}", value)
} else {
format!("'{}'", value.replace("'", "''"))
}
.filter(|value| {
!is_sql_type_error(
&sql_type,
value
.as_str()
.expect(&format!("'{}' is not a string", value)),
)
})
.filter(|value| !is_sql_type_error(&sql_type, value))
.collect::<Vec<_>>()
.join(", ");
.cloned()
.collect::<Vec<_>>();
let (lookup_sql, param_values) = get_mixed_query_params(&values, &sql_type);

let sql = format!(
r#"SELECT {} "{}" FROM "{}" WHERE "{}" IN ({})"#,
query_modifier, column, query_table, column, values_str
let sql = local_sql_syntax(
pool,
&format!(
r#"SELECT {} "{}" FROM "{}" WHERE "{}" IN ({})"#,
query_modifier, column, query_table, column, lookup_sql
),
);
let mut query = sqlx_query(&sql);
for param_value in &param_values {
match param_value {
QueryParam::Integer(p) => query = query.bind(p),
QueryParam::Numeric(p) => query = query.bind(p),
QueryParam::Real(p) => query = query.bind(p),
QueryParam::String(p) => query = query.bind(p),
}
}

let forbidden_values = sqlx_query(&sql)
let forbidden_values = query
.fetch_all(pool)
.await?
.iter()
Expand Down
30 changes: 23 additions & 7 deletions src/valve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,6 @@ impl Valve {
}
}
} else {
println!("FAVVOOOM");
let sql = format!(
r#"SELECT
ccu.table_name AS foreign_table_name,
Expand Down Expand Up @@ -1961,21 +1960,38 @@ impl Valve {
// Collect the paths and possibly the options of all of the tables that were requested to be
// saved:
let options_enabled = self.column_enabled_in_db("table", "options").await?;
// TODO: Here.

// Build the query to get the path and options info from the table table:
let mut params = vec![];
let sql_param_str = tables
.iter()
.map(|table| {
params.push(table);
SQL_PARAM.to_string()
})
.collect::<Vec<_>>()
.join(", ");
let sql = {
if options_enabled {
format!(
r#"SELECT "table", "path", "options" FROM "table" WHERE "table" IN ('{}')"#,
tables.join("', '")
r#"SELECT "table", "path", "options" FROM "table" WHERE "table" IN ({})"#,
sql_param_str
)
} else {
format!(
r#"SELECT "table", "path" FROM "table" WHERE "table" IN ('{}')"#,
tables.join("', '")
r#"SELECT "table", "path" FROM "table" WHERE "table" IN ({})"#,
sql_param_str
)
}
};
let mut stream = sqlx_query(&sql).fetch(&self.pool);
let sql = local_sql_syntax(&self.pool, &sql);
let mut query = sqlx_query(&sql);
for param in &params {
query = query.bind(param);
}

// Query the db:
let mut stream = query.fetch(&self.pool);
while let Some(row) = stream.try_next().await? {
let table = row
.try_get::<&str, &str>("table")
Expand Down

0 comments on commit 41c9ede

Please sign in to comment.