Skip to content

Commit

Permalink
Made SPI query arguments type safe (#1858)
Browse files Browse the repository at this point in the history
  • Loading branch information
YohDeadfall authored Oct 28, 2024
1 parent 039c24f commit ae0335b
Show file tree
Hide file tree
Showing 30 changed files with 267 additions and 294 deletions.
2 changes: 1 addition & 1 deletion pgrx-examples/bgworker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub extern "C" fn background_worker_main(arg: pg_sys::Datum) {
let tuple_table = client.select(
"SELECT 'Hi', id, ''||a FROM (SELECT id, 42 from generate_series(1,10) id) a ",
None,
None,
&[],
)?;
for tuple in tuple_table {
let a = tuple.get_datum_by_ordinal(1)?.value::<String>()?;
Expand Down
2 changes: 1 addition & 1 deletion pgrx-examples/custom_sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ mod tests {
let buf = Spi::connect(|client| {
Ok::<_, spi::Error>(
client
.select("SELECT * FROM extension_sql", None, None)?
.select("SELECT * FROM extension_sql", None, &[])?
.flat_map(|tup| {
tup.get_datum_by_ordinal(1)
.ok()
Expand Down
4 changes: 2 additions & 2 deletions pgrx-examples/schemas/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ mod tests {
fn test_my_some_schema_type() -> Result<(), spi::Error> {
Spi::connect(|mut c| {
// "MySomeSchemaType" is in 'some_schema', so it needs to be discoverable
c.update("SET search_path TO some_schema,public", None, None)?;
c.update("SET search_path TO some_schema,public", None, &[])?;
assert_eq!(
String::from("test"),
c.select("SELECT '\"test\"'::MySomeSchemaType", None, None)?
c.select("SELECT '\"test\"'::MySomeSchemaType", None, &[])?
.first()
.get_one::<MySomeSchemaType>()
.expect("get_one::<MySomeSchemaType>() failed")
Expand Down
19 changes: 6 additions & 13 deletions pgrx-examples/spi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ fn spi_return_query() -> Result<

Spi::connect(|client| {
client
.select(query, None, None)?
.select(query, None, &[])?
.map(|row| Ok((row["oid"].value()?, row[2].value()?)))
.collect::<Result<Vec<_>, _>>()
})
Expand All @@ -62,21 +62,14 @@ fn spi_query_random_id() -> Result<Option<i64>, pgrx::spi::Error> {

#[pg_extern]
fn spi_query_title(title: &str) -> Result<Option<i64>, pgrx::spi::Error> {
Spi::get_one_with_args(
"SELECT id FROM spi.spi_example WHERE title = $1;",
vec![(PgBuiltInOids::TEXTOID.oid(), title.into_datum())],
)
Spi::get_one_with_args("SELECT id FROM spi.spi_example WHERE title = $1;", &[title.into()])
}

#[pg_extern]
fn spi_query_by_id(id: i64) -> Result<Option<String>, spi::Error> {
let (returned_id, title) = Spi::connect(|client| {
let tuptable = client
.select(
"SELECT id, title FROM spi.spi_example WHERE id = $1",
None,
Some(vec![(PgBuiltInOids::INT8OID.oid(), id.into_datum())]),
)?
.select("SELECT id, title FROM spi.spi_example WHERE id = $1", None, &[id.into()])?
.first();

tuptable.get_two::<i64, String>()
Expand All @@ -90,7 +83,7 @@ fn spi_query_by_id(id: i64) -> Result<Option<String>, spi::Error> {
fn spi_insert_title(title: &str) -> Result<Option<i64>, spi::Error> {
Spi::get_one_with_args(
"INSERT INTO spi.spi_example(title) VALUES ($1) RETURNING id",
vec![(PgBuiltInOids::TEXTOID.oid(), title.into_datum())],
&[title.into()],
)
}

Expand All @@ -100,7 +93,7 @@ fn spi_insert_title2(
) -> TableIterator<(name!(id, Option<i64>), name!(title, Option<String>))> {
let tuple = Spi::get_two_with_args(
"INSERT INTO spi.spi_example(title) VALUES ($1) RETURNING id, title",
vec![(PgBuiltInOids::TEXTOID.oid(), title.into_datum())],
&[title.into()],
)
.unwrap();

Expand All @@ -110,7 +103,7 @@ fn spi_insert_title2(
#[pg_extern]
fn issue1209_fixed() -> Result<Option<String>, Box<dyn std::error::Error>> {
let res = Spi::connect(|c| {
let mut cursor = c.open_cursor("SELECT 'hello' FROM generate_series(1, 10000)", None);
let mut cursor = c.try_open_cursor("SELECT 'hello' FROM generate_series(1, 10000)", &[])?;
let table = cursor.fetch(10000)?;
table.into_iter().map(|row| row.get::<&str>(1)).collect::<Result<Vec<_>, _>>()
})?;
Expand Down
6 changes: 3 additions & 3 deletions pgrx-examples/spi_srf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn calculate_human_years() -> Result<

Spi::connect(|client| {
let mut results = Vec::new();
let tup_table = client.select(query, None, None)?;
let tup_table = client.select(query, None, &[])?;

for row in tup_table {
let dog_name = row["dog_name"].value::<String>();
Expand Down Expand Up @@ -89,10 +89,10 @@ fn filter_by_breed(
*/

let query = "SELECT * FROM spi_srf.dog_daycare WHERE dog_breed = $1;";
let args = vec![(PgBuiltInOids::TEXTOID.oid(), breed.into_datum())];
let args = vec![breed.into()];

Spi::connect(|client| {
let tup_table = client.select(query, None, Some(args))?;
let tup_table = client.select(query, None, &args)?;

let filtered = tup_table
.map(|row| {
Expand Down
24 changes: 6 additions & 18 deletions pgrx-tests/src/tests/aggregate_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,9 @@ mod tests {
fn aggregate_first_json() -> Result<(), pgrx::spi::Error> {
let retval = Spi::get_one_with_args::<pgrx::Json>(
"SELECT FirstJson(value) FROM UNNEST(ARRAY [$1, $2]) as value;",
vec![
(
PgBuiltInOids::JSONOID.oid(),
pgrx::Json(serde_json::json!({ "foo": "one" })).into_datum(),
),
(
PgBuiltInOids::JSONOID.oid(),
pgrx::Json(serde_json::json!({ "foo": "two" })).into_datum(),
),
&[
pgrx::Json(serde_json::json!({ "foo": "one" })).into(),
pgrx::Json(serde_json::json!({ "foo": "two" })).into(),
],
)?
.map(|json| json.0);
Expand All @@ -285,15 +279,9 @@ mod tests {
fn aggregate_first_jsonb() -> Result<(), pgrx::spi::Error> {
let retval = Spi::get_one_with_args::<pgrx::JsonB>(
"SELECT FirstJsonB(value) FROM UNNEST(ARRAY [$1, $2]) as value;",
vec![
(
PgBuiltInOids::JSONBOID.oid(),
pgrx::JsonB(serde_json::json!({ "foo": "one" })).into_datum(),
),
(
PgBuiltInOids::JSONBOID.oid(),
pgrx::JsonB(serde_json::json!({ "foo": "two" })).into_datum(),
),
&[
pgrx::JsonB(serde_json::json!({ "foo": "one" })).into(),
pgrx::JsonB(serde_json::json!({ "foo": "two" })).into(),
],
)?
.map(|json| json.0);
Expand Down
9 changes: 4 additions & 5 deletions pgrx-tests/src/tests/anyelement_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@ mod tests {
#[allow(unused_imports)]
use crate as pgrx_tests;

use pgrx::{prelude::*, AnyElement};
use pgrx::{datum::DatumWithOid, prelude::*, AnyElement};

#[pg_test]
fn test_anyelement_arg() -> Result<(), pgrx::spi::Error> {
let element = Spi::get_one_with_args::<AnyElement>(
"SELECT anyelement_arg($1);",
vec![(PgBuiltInOids::ANYELEMENTOID.oid(), 123.into_datum())],
)?
let element = Spi::get_one_with_args::<AnyElement>("SELECT anyelement_arg($1);", unsafe {
&[DatumWithOid::new(123, AnyElement::type_oid())]
})?
.map(|e| e.datum());

assert_eq!(element, 123.into_datum());
Expand Down
8 changes: 3 additions & 5 deletions pgrx-tests/src/tests/anynumeric_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@ mod tests {

#[pg_test]
fn test_anynumeric_arg() -> Result<(), pgrx::spi::Error> {
let numeric = Spi::get_one_with_args::<AnyNumeric>(
"SELECT anynumeric_arg($1);",
vec![(PgBuiltInOids::INT4OID.oid(), 123.into_datum())],
)?
.map(|n| n.normalize().to_string());
let numeric =
Spi::get_one_with_args::<AnyNumeric>("SELECT anynumeric_arg($1);", &[123.into()])?
.map(|n| n.normalize().to_string());

assert_eq!(numeric, Some("123".to_string()));

Expand Down
7 changes: 2 additions & 5 deletions pgrx-tests/src/tests/array_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ mod tests {

use super::ArrayTestEnum;
use pgrx::prelude::*;
use pgrx::{IntoDatum, Json};
use pgrx::Json;
use serde_json::json;

#[pg_test]
Expand Down Expand Up @@ -300,10 +300,7 @@ mod tests {
.select(
"SELECT serde_serialize_array_i32($1)",
None,
Some(vec![(
PgBuiltInOids::INT4ARRAYOID.oid(),
owned_vec.as_slice().into_datum(),
)]),
&[owned_vec.as_slice().into()],
)?
.first()
.get_one::<Json>()
Expand Down
19 changes: 4 additions & 15 deletions pgrx-tests/src/tests/bgworker_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ pub extern "C" fn bgworker(arg: pg_sys::Datum) {
Spi::run("CREATE TABLE tests.bgworker_test (v INTEGER);")?;
Spi::connect(|mut client| {
client
.update(
"INSERT INTO tests.bgworker_test VALUES ($1);",
None,
Some(vec![(PgOid::BuiltIn(PgBuiltInOids::INT4OID), arg.into_datum())]),
)
.update("INSERT INTO tests.bgworker_test VALUES ($1);", None, &[arg.into()])
.map(|_| ())
})
})
Expand Down Expand Up @@ -61,10 +57,7 @@ pub extern "C" fn bgworker_return_value(arg: pg_sys::Datum) {
let val = if arg > 0 {
BackgroundWorker::transaction(|| {
Spi::run("CREATE TABLE tests.bgworker_test_return (v INTEGER);")?;
Spi::get_one_with_args::<i32>(
"SELECT $1",
vec![(PgOid::BuiltIn(PgBuiltInOids::INT4OID), arg.into_datum())],
)
Spi::get_one_with_args::<i32>("SELECT $1", &[arg.into()])
})
.expect("bgworker transaction failed")
.unwrap()
Expand All @@ -74,12 +67,8 @@ pub extern "C" fn bgworker_return_value(arg: pg_sys::Datum) {
while BackgroundWorker::wait_latch(Some(Duration::from_millis(100))) {}
BackgroundWorker::transaction(|| {
Spi::connect(|mut c| {
c.update(
"INSERT INTO tests.bgworker_test_return VALUES ($1)",
None,
Some(vec![(PgOid::BuiltIn(PgBuiltInOids::INT4OID), val.into_datum())]),
)
.map(|_| ())
c.update("INSERT INTO tests.bgworker_test_return VALUES ($1)", None, &[val.into()])
.map(|_| ())
})
})
.expect("bgworker transaction failed");
Expand Down
4 changes: 2 additions & 2 deletions pgrx-tests/src/tests/borrow_datum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ macro_rules! clonetrip_test {
let expected: $rtype = $expected;
let result: $rtype = Spi::get_one_with_args(
&format!("SELECT {}($1)", stringify!(tests.$fname)),
vec![(PgOid::from(<$rtype>::type_oid()), expected.into_datum())],
&[expected.into()],
)?
.unwrap();

Expand All @@ -44,7 +44,7 @@ macro_rules! clonetrip_test {
let value: $own_ty = $value;
let result: $own_ty = Spi::get_one_with_args(
&format!("SELECT {}($1)", stringify!(tests.$fname)),
vec![(PgOid::from(<$ref_ty>::type_oid()), value.into_datum())],
&[value.into()],
)?
.unwrap();

Expand Down
2 changes: 1 addition & 1 deletion pgrx-tests/src/tests/guc_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ mod tests {
Spi::run("SET test.no_reset_all TO false;").expect("SPI failed");
assert_eq!(GUC_NO_RESET_ALL.get(), false);
Spi::connect(|mut client| {
let r = client.update("SHOW ALL", None, None).expect("SPI failed");
let r = client.update("SHOW ALL", None, &[]).expect("SPI failed");

let mut no_reset_guc_in_show_all = false;
for row in r {
Expand Down
2 changes: 1 addition & 1 deletion pgrx-tests/src/tests/heap_tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ mod tests {
fn test_tuple_desc_clone() -> Result<(), spi::Error> {
let result = Spi::connect(|client| {
let query = "select * from generate_lots_of_dogs()";
client.select(query, None, None).map(|table| table.len())
client.select(query, None, &[]).map(|table| table.len())
})?;
assert_eq!(result, 10_000);
Ok(())
Expand Down
10 changes: 2 additions & 8 deletions pgrx-tests/src/tests/json_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,7 @@ mod tests {
fn test_json_arg() -> Result<(), pgrx::spi::Error> {
let json = Spi::get_one_with_args::<Json>(
"SELECT json_arg($1);",
vec![(
PgBuiltInOids::JSONOID.oid(),
Json(serde_json::json!({ "foo": "bar" })).into_datum(),
)],
&[Json(serde_json::json!({ "foo": "bar" })).into()],
)?
.expect("json was null");

Expand All @@ -95,10 +92,7 @@ mod tests {
fn test_jsonb_arg() -> Result<(), pgrx::spi::Error> {
let json = Spi::get_one_with_args::<JsonB>(
"SELECT jsonb_arg($1);",
vec![(
PgBuiltInOids::JSONBOID.oid(),
JsonB(serde_json::json!({ "foo": "bar" })).into_datum(),
)],
&[JsonB(serde_json::json!({ "foo": "bar" })).into()],
)?
.expect("json was null");

Expand Down
4 changes: 2 additions & 2 deletions pgrx-tests/src/tests/pg_cast_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ mod tests {
#[pg_test]
fn test_pg_cast_assignment_type_cast() {
let _ = Spi::connect(|mut client| {
client.update("CREATE TABLE test_table(value int4);", None, None)?;
client.update("INSERT INTO test_table VALUES('{\"a\": 1}'::json->'a');", None, None)?;
client.update("CREATE TABLE test_table(value int4);", None, &[])?;
client.update("INSERT INTO test_table VALUES('{\"a\": 1}'::json->'a');", None, &[])?;

Ok::<_, spi::Error>(())
});
Expand Down
3 changes: 1 addition & 2 deletions pgrx-tests/src/tests/proptests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ pub fn [<$datetime_ty:lower _spi_roundtrip>] () {
proptest
.run(&strat, |datetime| {
let query = concat!("SELECT ", stringify!($nop_fn), "($1)");
let builtin_oid = PgOid::BuiltIn(pg_sys::BuiltinOid::from_u32(<$datetime_ty as IntoDatum>::type_oid().as_u32()).unwrap());
let args = vec![(builtin_oid, datetime.into_datum())];
let args = &[datetime.into()];
let spi_ret: $datetime_ty = Spi::get_one_with_args(query, args).unwrap().unwrap();
// 5. A condition on which the test is accepted or rejected:
// this is easily done via `prop_assert!` and its friends,
Expand Down
4 changes: 2 additions & 2 deletions pgrx-tests/src/tests/roundtrip_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ mod tests {
let expected: $rtype = Clone::clone(&value);
let result: $rtype = Spi::get_one_with_args(
&format!("SELECT {}($1)", stringify!(tests.$fname)),
vec![(PgOid::from(<$rtype>::type_oid()), value.into_datum())],
&[value.into()],
)?
.unwrap();

Expand All @@ -102,7 +102,7 @@ mod tests {
let output: $otype = $output;
let result: $otype = Spi::get_one_with_args(
&format!("SELECT {}($1)", stringify!(tests.$fname)),
vec![(PgOid::from(<$itype>::type_oid()), input.into_datum())],
&[input.into()],
)?
.unwrap();

Expand Down
Loading

0 comments on commit ae0335b

Please sign in to comment.