diff --git a/pgrx-macros/src/lib.rs b/pgrx-macros/src/lib.rs index ef8e6ff8e..5da455013 100644 --- a/pgrx-macros/src/lib.rs +++ b/pgrx-macros/src/lib.rs @@ -21,7 +21,7 @@ use operators::{deriving_postgres_eq, deriving_postgres_hash, deriving_postgres_ use pgrx_sql_entity_graph as sql_gen; use sql_gen::{ parse_extern_attributes, CodeEnrichment, ExtensionSql, ExtensionSqlFile, ExternArgs, - PgAggregate, PgExtern, PostgresEnum, Schema, + PgAggregate, PgCast, PgExtern, PostgresEnum, Schema, }; mod operators; @@ -148,6 +148,66 @@ pub fn initialize(_attr: TokenStream, item: TokenStream) -> TokenStream { item } +/** +Declare a function as `#[pg_cast]` to indicate that it represents a Postgres [cast](https://www.postgresql.org/docs/current/sql-createcast.html). + +* `assignment`: Corresponds to [`AS ASSIGNMENT`](https://www.postgresql.org/docs/current/sql-createcast.html). +* `implicit`: Corresponds to [`AS IMPLICIT`](https://www.postgresql.org/docs/current/sql-createcast.html). + +By default if no attribute is specified, the cast function can only be used in an explicit cast. + +Functions MUST accept and return exactly one value whose type MUST be a `pgrx` supported type. `pgrx` supports many PostgreSQL types by default. +New types can be defined via [`macro@PostgresType`] or [`macro@PostgresEnum`]. + +Example usage: +```rust,ignore +use pgrx::*; +#[pg_cast(implicit)] +fn cast_json_to_int(input: Json) -> i32 { todo!() } +*/ +#[proc_macro_attribute] +pub fn pg_cast(attr: TokenStream, item: TokenStream) -> TokenStream { + fn wrapped(attr: TokenStream, item: TokenStream) -> Result { + use syn::parse::Parser; + let mut cast = PgCast::Default; + match syn::punctuated::Punctuated::::parse_terminated.parse(attr) + { + Ok(paths) => { + if paths.len() > 1 { + panic!( + "pg_cast must take either 0 or 1 attribute. Found {}: {}", + paths.len(), + paths.to_token_stream() + ) + } else if paths.len() == 1 { + match paths.first().unwrap().segments.last().unwrap().ident.to_string().as_str() + { + "implicit" => cast = PgCast::Implicit, + "assignment" => cast = PgCast::Assignment, + other => panic!("Unrecognized pg_cast option: {}. ", other), + } + } + } + Err(err) => { + panic!("Failed to parse attribute to pg_cast: {}", err) + } + } + // `pg_cast` does not support other `pg_extern` attributes for now, pass an empty attribute token stream. + let pg_extern = PgExtern::new(TokenStream::new().into(), item.clone().into())?.0; + Ok(CodeEnrichment(pg_extern.as_cast(cast)).to_token_stream().into()) + } + + match wrapped(attr, item) { + Ok(tokens) => tokens, + Err(e) => { + let msg = e.to_string(); + TokenStream::from(quote! { + compile_error!(#msg); + }) + } + } +} + /// Declare a function as `#[pg_operator]` to indicate that it represents a Postgres operator /// `cargo pgrx schema` will automatically generate the underlying SQL #[proc_macro_attribute] diff --git a/pgrx-sql-entity-graph/src/lib.rs b/pgrx-sql-entity-graph/src/lib.rs index 964a17a55..0b774ac86 100644 --- a/pgrx-sql-entity-graph/src/lib.rs +++ b/pgrx-sql-entity-graph/src/lib.rs @@ -27,10 +27,10 @@ pub use extension_sql::{ExtensionSql, ExtensionSqlFile, SqlDeclared}; pub use extern_args::{parse_extern_attributes, ExternArgs}; pub use mapping::RustSqlMapping; pub use pg_extern::entity::{ - PgExternArgumentEntity, PgExternEntity, PgExternReturnEntity, PgExternReturnEntityIteratedItem, - PgOperatorEntity, + PgCastEntity, PgExternArgumentEntity, PgExternEntity, PgExternReturnEntity, + PgExternReturnEntityIteratedItem, PgOperatorEntity, }; -pub use pg_extern::{NameMacro, PgExtern, PgExternArgument, PgOperator}; +pub use pg_extern::{NameMacro, PgCast, PgExtern, PgExternArgument, PgOperator}; pub use pg_trigger::attribute::PgTriggerAttribute; pub use pg_trigger::entity::PgTriggerEntity; pub use pg_trigger::PgTrigger; diff --git a/pgrx-sql-entity-graph/src/pg_extern/cast.rs b/pgrx-sql-entity-graph/src/pg_extern/cast.rs new file mode 100644 index 000000000..a8cd57e53 --- /dev/null +++ b/pgrx-sql-entity-graph/src/pg_extern/cast.rs @@ -0,0 +1,46 @@ +//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC. +//LICENSE +//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +//LICENSE +//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. +//LICENSE +//LICENSE All rights reserved. +//LICENSE +//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. +/*! + +`#[pg_cast]` related macro expansion for Rust to SQL translation + +> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal** +to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution. + +*/ +use proc_macro2::TokenStream as TokenStream2; +use quote::{quote, ToTokens, TokenStreamExt}; + +/// A parsed `#[pg_cast]` operator. +/// +/// It is created during [`PgExtern`](crate::PgExtern) parsing. +#[derive(Debug, Clone)] +pub enum PgCast { + Default, + Assignment, + Implicit, +} + +impl ToTokens for PgCast { + fn to_tokens(&self, tokens: &mut TokenStream2) { + let quoted = match self { + PgCast::Default => quote! { + ::pgrx::pgrx_sql_entity_graph::PgCastEntity::Default + }, + PgCast::Assignment => quote! { + ::pgrx::pgrx_sql_entity_graph::PgCastEntity::Assignment + }, + PgCast::Implicit => quote! { + ::pgrx::pgrx_sql_entity_graph::PgCastEntity::Implicit + }, + }; + tokens.append_all(quoted); + } +} diff --git a/pgrx-sql-entity-graph/src/pg_extern/entity/cast.rs b/pgrx-sql-entity-graph/src/pg_extern/entity/cast.rs new file mode 100644 index 000000000..cae7ae02e --- /dev/null +++ b/pgrx-sql-entity-graph/src/pg_extern/entity/cast.rs @@ -0,0 +1,25 @@ +//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC. +//LICENSE +//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +//LICENSE +//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. +//LICENSE +//LICENSE All rights reserved. +//LICENSE +//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. +/*! + +`#[pg_extern]` related cast entities for Rust to SQL translation + +> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal** +to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution. + +*/ + +/// The output of a [`PgCast`](crate::PgCast) from `quote::ToTokens::to_tokens`. +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub enum PgCastEntity { + Default, + Assignment, + Implicit, +} diff --git a/pgrx-sql-entity-graph/src/pg_extern/entity/mod.rs b/pgrx-sql-entity-graph/src/pg_extern/entity/mod.rs index 28c72f80e..8376827f3 100644 --- a/pgrx-sql-entity-graph/src/pg_extern/entity/mod.rs +++ b/pgrx-sql-entity-graph/src/pg_extern/entity/mod.rs @@ -16,10 +16,12 @@ to the `pgrx` framework and very subject to change between versions. While you m */ mod argument; +mod cast; mod operator; mod returning; pub use argument::PgExternArgumentEntity; +pub use cast::PgCastEntity; pub use operator::PgOperatorEntity; pub use returning::{PgExternReturnEntity, PgExternReturnEntityIteratedItem}; @@ -49,6 +51,7 @@ pub struct PgExternEntity { pub extern_attrs: Vec, pub search_path: Option>, pub operator: Option, + pub cast: Option, pub to_sql_config: ToSqlConfigEntity, } @@ -338,7 +341,7 @@ impl ToSql for PgExternEntity { } }; - let ext_sql = format!( + let mut ext_sql = format!( "\n\ -- {file}:{line}\n\ -- {module_path}::{name}\n\ @@ -346,7 +349,7 @@ impl ToSql for PgExternEntity { {fn_sql}" ); - let rendered = if let Some(op) = &self.operator { + if let Some(op) = &self.operator { let mut optionals = vec![]; if let Some(it) = op.commutator { optionals.push(format!("\tCOMMUTATOR = {}", it)); @@ -452,7 +455,6 @@ impl ToSql for PgExternEntity { .map(|schema| format!("{}.", schema)) .unwrap_or_else(|| context.schema_prefix_for(&self_index)); - eprintln!("schema={schema}"); let operator_sql = format!("\n\n\ -- {file}:{line}\n\ -- {module_path}::{name}\n\ @@ -471,10 +473,124 @@ impl ToSql for PgExternEntity { maybe_comma = if !optionals.is_empty() { "," } else { "" }, optionals = if !optionals.is_empty() { optionals.join(",\n") + "\n" } else { "".to_string() }, ); - ext_sql + &operator_sql - } else { - ext_sql + ext_sql += &operator_sql + }; + if let Some(cast) = &self.cast { + let target_arg = &self.metadata.retval; + let target_fn_arg = &self.fn_return; + let target_arg_graph_index = context + .graph + .neighbors_undirected(self_index) + .find(|neighbor| match (&context.graph[*neighbor], target_fn_arg) { + (SqlGraphEntity::Type(ty), PgExternReturnEntity::Type { ty: rty }) => { + ty.id_matches(&rty.ty_id) + } + (SqlGraphEntity::Enum(en), PgExternReturnEntity::Type { ty: rty }) => { + en.id_matches(&rty.ty_id) + } + (SqlGraphEntity::BuiltinType(defined), _) => defined == target_arg.type_name, + _ => false, + }) + .ok_or_else(|| { + eyre!("Could not find source type in graph. Got: {:?}", target_arg) + })?; + let target_arg_sql = match target_arg.argument_sql { + Ok(SqlMapping::As(ref sql)) => sql.clone(), + Ok(SqlMapping::Composite { array_brackets }) => { + if array_brackets { + let composite_type = self.fn_args[0].used_ty.composite_type + .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?; + format!("{composite_type}[]") + } else { + self.fn_args[0].used_ty.composite_type + .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?.to_string() + } + } + Ok(SqlMapping::Skip) => { + return Err(eyre!("Found an skipped SQL type in a cast, this is not valid")) + } + Err(err) => return Err(err.into()), + }; + if self.metadata.arguments.len() != 1 { + return Err(eyre!( + "PG cast function ({}) must have exactly one argument, got {}", + self.name, + self.metadata.arguments.len() + )); + } + if self.fn_args.len() != 1 { + return Err(eyre!( + "PG cast function ({}) must have exactly one argument, got {}", + self.name, + self.fn_args.len() + )); + } + let source_arg = self + .metadata + .arguments + .first() + .ok_or_else(|| eyre!("Did not find source type for cast `{}`.", self.name))?; + let source_fn_arg = self + .fn_args + .first() + .ok_or_else(|| eyre!("Did not find source type for cast `{}`.", self.name))?; + let source_arg_graph_index = context + .graph + .neighbors_undirected(self_index) + .find(|neighbor| match &context.graph[*neighbor] { + SqlGraphEntity::Type(ty) => ty.id_matches(&source_fn_arg.used_ty.ty_id), + SqlGraphEntity::Enum(en) => en.id_matches(&source_fn_arg.used_ty.ty_id), + SqlGraphEntity::BuiltinType(defined) => defined == source_arg.type_name, + _ => false, + }) + .ok_or_else(|| { + eyre!("Could not find source type in graph. Got: {:?}", source_arg) + })?; + let source_arg_sql = match source_arg.argument_sql { + Ok(SqlMapping::As(ref sql)) => sql.clone(), + Ok(SqlMapping::Composite { array_brackets }) => { + if array_brackets { + let composite_type = self.fn_args[0].used_ty.composite_type + .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?; + format!("{composite_type}[]") + } else { + self.fn_args[0].used_ty.composite_type + .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?.to_string() + } + } + Ok(SqlMapping::Skip) => { + return Err(eyre!("Found an skipped SQL type in a cast, this is not valid")) + } + Err(err) => return Err(err.into()), + }; + let optional = match cast { + PgCastEntity::Default => String::from(""), + PgCastEntity::Assignment => String::from(" AS ASSIGNMENT"), + PgCastEntity::Implicit => String::from(" AS IMPLICIT"), + }; + + let cast_sql = format!("\n\n\ + -- {file}:{line}\n\ + -- {module_path}::{name}\n\ + CREATE CAST (\n\ + \t{schema_prefix_source}{source_arg_sql} /* {source_name} */\n\ + \tAS\n\ + \t{schema_prefix_target}{target_arg_sql} /* {target_name} */\n\ + )\n\ + WITH FUNCTION {function_name}{optional};\ + ", + file = self.file, + line = self.line, + name = self.name, + module_path = self.module_path, + schema_prefix_source = context.schema_prefix_for(&source_arg_graph_index), + source_name = source_arg.type_name, + schema_prefix_target = context.schema_prefix_for(&target_arg_graph_index), + target_name = target_arg.type_name, + function_name = self.name, + ); + ext_sql += &cast_sql }; - Ok(rendered) + Ok(ext_sql) } } diff --git a/pgrx-sql-entity-graph/src/pg_extern/mod.rs b/pgrx-sql-entity-graph/src/pg_extern/mod.rs index f0bc02275..d62919e78 100644 --- a/pgrx-sql-entity-graph/src/pg_extern/mod.rs +++ b/pgrx-sql-entity-graph/src/pg_extern/mod.rs @@ -17,12 +17,14 @@ to the `pgrx` framework and very subject to change between versions. While you m */ mod argument; mod attribute; +mod cast; pub mod entity; mod operator; mod returning; mod search_path; pub use argument::PgExternArgument; +pub use cast::PgCast; pub use operator::PgOperator; pub use returning::NameMacro; @@ -74,6 +76,7 @@ pub struct PgExtern { func: syn::ItemFn, to_sql_config: ToSqlConfig, operator: Option, + cast: Option, search_path: Option, inputs: Vec, input_types: Vec, @@ -118,6 +121,7 @@ impl PgExtern { func, to_sql_config, operator, + cast: None, search_path, inputs, input_types, @@ -125,6 +129,13 @@ impl PgExtern { })) } + /// Returns a new instance of this `PgExtern` with `cast` overwritten to `pg_cast`. + pub fn as_cast(&self, pg_cast: PgCast) -> PgExtern { + let mut result = self.clone(); + result.cast = Some(pg_cast); + result + } + fn input_types(func: &syn::ItemFn) -> syn::Result> { func.sig .inputs @@ -282,6 +293,7 @@ impl PgExtern { }; let operator = self.operator.clone().into_iter(); + let cast = self.cast.clone().into_iter(); let to_sql_config = match self.overridden() { None => self.to_sql_config.clone(), Some(content) => ToSqlConfig { content: Some(content), ..self.to_sql_config.clone() }, @@ -316,6 +328,7 @@ impl PgExtern { search_path: None #( .unwrap_or_else(|| Some(vec![#search_path])) )*, #[allow(clippy::or_fun_call)] operator: None #( .unwrap_or_else(|| Some(#operator)) )*, + cast: None #( .unwrap_or_else(|| Some(#cast)) )*, to_sql_config: #to_sql_config, }; ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity::Function(submission) diff --git a/pgrx-tests/src/tests/mod.rs b/pgrx-tests/src/tests/mod.rs index 1621568b3..1d372b616 100644 --- a/pgrx-tests/src/tests/mod.rs +++ b/pgrx-tests/src/tests/mod.rs @@ -37,6 +37,7 @@ mod log_tests; mod memcxt_tests; mod name_tests; mod numeric_tests; +mod pg_cast_tests; mod pg_extern_tests; mod pg_guard_tests; mod pg_operator_tests; diff --git a/pgrx-tests/src/tests/pg_cast_tests.rs b/pgrx-tests/src/tests/pg_cast_tests.rs new file mode 100644 index 000000000..12178050c --- /dev/null +++ b/pgrx-tests/src/tests/pg_cast_tests.rs @@ -0,0 +1,63 @@ +//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC. +//LICENSE +//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +//LICENSE +//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. +//LICENSE +//LICENSE All rights reserved. +//LICENSE +//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. +use pgrx::prelude::*; + +#[pg_schema] +mod pg_catalog { + use pgrx::pg_cast; + use serde_json::Value::Number; + + #[pg_cast(implicit)] + fn int4_from_json(value: pgrx::Json) -> i32 { + if let Number(num) = &value.0 { + if num.is_i64() { + return num.as_i64().unwrap() as i32; + } else if num.is_f64() { + return num.as_f64().unwrap() as i32; + } else if num.is_u64() { + return num.as_u64().unwrap() as i32; + } + }; + panic!("Error casting json value {} to an integer", value.0) + } +} + +#[cfg(any(test, feature = "pg_test"))] +#[pg_schema] +mod tests { + #[allow(unused_imports)] + use crate as pgrx_tests; + use pgrx::prelude::*; + + #[pg_test] + fn test_pg_cast_explicit_type_cast() { + assert_eq!( + Spi::get_one::("SELECT CAST('{\"a\": 1}'::json->'a' AS int4);"), + Ok(Some(1)) + ); + assert_eq!(Spi::get_one::("SELECT ('{\"a\": 1}'::json->'a')::int4;"), Ok(Some(1))); + } + + #[pg_test] + fn test_pg_cast_assignment_type_cast() { + let _ = Spi::connect(|mut client| { + client.update("CREATE TABLE test_table(value int4);", None, None)?; + client.update("INSERT INTO test_table VALUES('{\"a\": 1}'::json->'a');", None, None)?; + + Ok::<_, spi::Error>(()) + }); + assert_eq!(Spi::get_one::("SELECT value FROM test_table"), Ok(Some(1))); + } + + #[pg_test] + fn test_pg_cast_implicit_type_cast() { + assert_eq!(Spi::get_one::("SELECT 1 + ('{\"a\": 1}'::json->'a')"), Ok(Some(2))); + } +} diff --git a/pgrx-tests/tests/ui/invalid_pgcast_options.rs b/pgrx-tests/tests/ui/invalid_pgcast_options.rs new file mode 100644 index 000000000..25d9918cd --- /dev/null +++ b/pgrx-tests/tests/ui/invalid_pgcast_options.rs @@ -0,0 +1,8 @@ +use pgrx::prelude::*; + +#[pg_cast(invalid_opt)] +pub fn cast_function(foo: i32) -> i32 { + foo +} + +fn main() {} \ No newline at end of file diff --git a/pgrx-tests/tests/ui/invalid_pgcast_options.stderr b/pgrx-tests/tests/ui/invalid_pgcast_options.stderr new file mode 100644 index 000000000..f082a3488 --- /dev/null +++ b/pgrx-tests/tests/ui/invalid_pgcast_options.stderr @@ -0,0 +1,7 @@ +error: custom attribute panicked + --> tests/ui/invalid_pgcast_options.rs:3:1 + | +3 | #[pg_cast(invalid_opt)] + | ^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: message: Unrecognized pg_cast option: invalid_opt.