diff --git a/pgrx-macros/src/lib.rs b/pgrx-macros/src/lib.rs index ef8e6ff8e..6dfa4b315 100644 --- a/pgrx-macros/src/lib.rs +++ b/pgrx-macros/src/lib.rs @@ -148,6 +148,13 @@ pub fn initialize(_attr: TokenStream, item: TokenStream) -> TokenStream { item } +/// Declare a function as `#[pg_cast]` to indicate that it represents a Postgres cast +/// `cargo pgrx schema` will automatically generate the underlying SQL +#[proc_macro_attribute] +pub fn pg_cast(_attr: TokenStream, item: TokenStream) -> TokenStream { + item +} + /// Declare a function as `#[pg_operator]` to indicate that it represents a Postgres operator /// `cargo pgrx schema` will automatically generate the underlying SQL #[proc_macro_attribute] diff --git a/pgrx-sql-entity-graph/src/lib.rs b/pgrx-sql-entity-graph/src/lib.rs index 964a17a55..0b774ac86 100644 --- a/pgrx-sql-entity-graph/src/lib.rs +++ b/pgrx-sql-entity-graph/src/lib.rs @@ -27,10 +27,10 @@ pub use extension_sql::{ExtensionSql, ExtensionSqlFile, SqlDeclared}; pub use extern_args::{parse_extern_attributes, ExternArgs}; pub use mapping::RustSqlMapping; pub use pg_extern::entity::{ - PgExternArgumentEntity, PgExternEntity, PgExternReturnEntity, PgExternReturnEntityIteratedItem, - PgOperatorEntity, + PgCastEntity, PgExternArgumentEntity, PgExternEntity, PgExternReturnEntity, + PgExternReturnEntityIteratedItem, PgOperatorEntity, }; -pub use pg_extern::{NameMacro, PgExtern, PgExternArgument, PgOperator}; +pub use pg_extern::{NameMacro, PgCast, PgExtern, PgExternArgument, PgOperator}; pub use pg_trigger::attribute::PgTriggerAttribute; pub use pg_trigger::entity::PgTriggerEntity; pub use pg_trigger::PgTrigger; diff --git a/pgrx-sql-entity-graph/src/pg_extern/cast.rs b/pgrx-sql-entity-graph/src/pg_extern/cast.rs new file mode 100644 index 000000000..a8cd57e53 --- /dev/null +++ b/pgrx-sql-entity-graph/src/pg_extern/cast.rs @@ -0,0 +1,46 @@ +//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC. +//LICENSE +//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +//LICENSE +//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. +//LICENSE +//LICENSE All rights reserved. +//LICENSE +//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. +/*! + +`#[pg_cast]` related macro expansion for Rust to SQL translation + +> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal** +to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution. + +*/ +use proc_macro2::TokenStream as TokenStream2; +use quote::{quote, ToTokens, TokenStreamExt}; + +/// A parsed `#[pg_cast]` operator. +/// +/// It is created during [`PgExtern`](crate::PgExtern) parsing. +#[derive(Debug, Clone)] +pub enum PgCast { + Default, + Assignment, + Implicit, +} + +impl ToTokens for PgCast { + fn to_tokens(&self, tokens: &mut TokenStream2) { + let quoted = match self { + PgCast::Default => quote! { + ::pgrx::pgrx_sql_entity_graph::PgCastEntity::Default + }, + PgCast::Assignment => quote! { + ::pgrx::pgrx_sql_entity_graph::PgCastEntity::Assignment + }, + PgCast::Implicit => quote! { + ::pgrx::pgrx_sql_entity_graph::PgCastEntity::Implicit + }, + }; + tokens.append_all(quoted); + } +} diff --git a/pgrx-sql-entity-graph/src/pg_extern/entity/cast.rs b/pgrx-sql-entity-graph/src/pg_extern/entity/cast.rs new file mode 100644 index 000000000..cae7ae02e --- /dev/null +++ b/pgrx-sql-entity-graph/src/pg_extern/entity/cast.rs @@ -0,0 +1,25 @@ +//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC. +//LICENSE +//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +//LICENSE +//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. +//LICENSE +//LICENSE All rights reserved. +//LICENSE +//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. +/*! + +`#[pg_extern]` related cast entities for Rust to SQL translation + +> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal** +to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution. + +*/ + +/// The output of a [`PgCast`](crate::PgCast) from `quote::ToTokens::to_tokens`. +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub enum PgCastEntity { + Default, + Assignment, + Implicit, +} diff --git a/pgrx-sql-entity-graph/src/pg_extern/entity/mod.rs b/pgrx-sql-entity-graph/src/pg_extern/entity/mod.rs index 28c72f80e..168553cb1 100644 --- a/pgrx-sql-entity-graph/src/pg_extern/entity/mod.rs +++ b/pgrx-sql-entity-graph/src/pg_extern/entity/mod.rs @@ -16,10 +16,12 @@ to the `pgrx` framework and very subject to change between versions. While you m */ mod argument; +mod cast; mod operator; mod returning; pub use argument::PgExternArgumentEntity; +pub use cast::PgCastEntity; pub use operator::PgOperatorEntity; pub use returning::{PgExternReturnEntity, PgExternReturnEntityIteratedItem}; @@ -49,6 +51,7 @@ pub struct PgExternEntity { pub extern_attrs: Vec, pub search_path: Option>, pub operator: Option, + pub cast: Option, pub to_sql_config: ToSqlConfigEntity, } @@ -452,7 +455,6 @@ impl ToSql for PgExternEntity { .map(|schema| format!("{}.", schema)) .unwrap_or_else(|| context.schema_prefix_for(&self_index)); - eprintln!("schema={schema}"); let operator_sql = format!("\n\n\ -- {file}:{line}\n\ -- {module_path}::{name}\n\ @@ -472,6 +474,122 @@ impl ToSql for PgExternEntity { optionals = if !optionals.is_empty() { optionals.join(",\n") + "\n" } else { "".to_string() }, ); ext_sql + &operator_sql + } else if let Some(cast) = &self.cast { + let target_arg = &self.metadata.retval; + let target_fn_arg = &self.fn_return; + let target_arg_graph_index = context + .graph + .neighbors_undirected(self_index) + .find(|neighbor| match (&context.graph[*neighbor], target_fn_arg) { + (SqlGraphEntity::Type(ty), PgExternReturnEntity::Type { ty: rty }) => { + ty.id_matches(&rty.ty_id) + } + (SqlGraphEntity::Enum(en), PgExternReturnEntity::Type { ty: rty }) => { + en.id_matches(&rty.ty_id) + } + (SqlGraphEntity::BuiltinType(defined), _) => defined == target_arg.type_name, + _ => false, + }) + .ok_or_else(|| { + eyre!("Could not find source type in graph. Got: {:?}", target_arg) + })?; + let target_arg_sql = match target_arg.argument_sql { + Ok(SqlMapping::As(ref sql)) => sql.clone(), + Ok(SqlMapping::Composite { array_brackets }) => { + if array_brackets { + let composite_type = self.fn_args[0].used_ty.composite_type + .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?; + format!("{composite_type}[]") + } else { + self.fn_args[0].used_ty.composite_type + .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?.to_string() + } + } + Ok(SqlMapping::Skip) => { + return Err(eyre!("Found an skipped SQL type in a cast, this is not valid")) + } + Err(err) => return Err(err.into()), + }; + if self.metadata.arguments.len() != 1 { + return Err(eyre!( + "PG cast function must have exactly one argument, got {}", + self.metadata.arguments.len() + )); + } + if self.fn_args.len() != 1 { + return Err(eyre!( + "PG cast function must have exactly one argument, got {}", + self.fn_args.len() + )); + } + let source_arg = self + .metadata + .arguments + .first() + .ok_or_else(|| eyre!("Did not find source type for cast `{}`.", self.name))?; + let source_fn_arg = self + .fn_args + .first() + .ok_or_else(|| eyre!("Did not find source type for cast `{}`.", self.name))?; + let source_arg_graph_index = context + .graph + .neighbors_undirected(self_index) + .find(|neighbor| match &context.graph[*neighbor] { + SqlGraphEntity::Type(ty) => ty.id_matches(&source_fn_arg.used_ty.ty_id), + SqlGraphEntity::Enum(en) => en.id_matches(&source_fn_arg.used_ty.ty_id), + SqlGraphEntity::BuiltinType(defined) => defined == source_arg.type_name, + _ => false, + }) + .ok_or_else(|| { + eyre!("Could not find source type in graph. Got: {:?}", source_arg) + })?; + let source_arg_sql = match source_arg.argument_sql { + Ok(SqlMapping::As(ref sql)) => sql.clone(), + Ok(SqlMapping::Composite { array_brackets }) => { + if array_brackets { + let composite_type = self.fn_args[0].used_ty.composite_type + .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?; + format!("{composite_type}[]") + } else { + self.fn_args[0].used_ty.composite_type + .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?.to_string() + } + } + Ok(SqlMapping::Skip) => { + return Err(eyre!("Found an skipped SQL type in a cast, this is not valid")) + } + Err(err) => return Err(err.into()), + }; + let optional = match cast { + PgCastEntity::Default => String::from(""), + PgCastEntity::Assignment => String::from(" AS ASSIGNMENT"), + PgCastEntity::Implicit => String::from(" AS IMPLICIT"), + }; + + let cast_sql = format!("\n\n\ + -- {file}:{line}\n\ + -- {module_path}::{name}\n\ + CREATE CAST (\n\ + \t{schema_prefix_source}{source_arg_sql} /* {source_name} */\n\ + \tAS\n\ + \t{schema_prefix_target}{target_arg_sql} /* {target_name} */\n\ + )\n\ + WITH FUNCTION {function_name}{optional};\ + ", + file = self.file, + line = self.line, + name = self.name, + module_path = self.module_path, + source_arg_sql = source_arg_sql, + schema_prefix_source = context.schema_prefix_for(&source_arg_graph_index), + source_name = source_arg.type_name, + target_arg_sql = target_arg_sql, + schema_prefix_target = context.schema_prefix_for(&target_arg_graph_index), + target_name = target_arg.type_name, + function_name = self.name, + optional = optional, + ); + ext_sql + &cast_sql } else { ext_sql }; diff --git a/pgrx-sql-entity-graph/src/pg_extern/mod.rs b/pgrx-sql-entity-graph/src/pg_extern/mod.rs index f0bc02275..e0b5a0222 100644 --- a/pgrx-sql-entity-graph/src/pg_extern/mod.rs +++ b/pgrx-sql-entity-graph/src/pg_extern/mod.rs @@ -17,14 +17,17 @@ to the `pgrx` framework and very subject to change between versions. While you m */ mod argument; mod attribute; +mod cast; pub mod entity; mod operator; mod returning; mod search_path; pub use argument::PgExternArgument; +pub use cast::PgCast; pub use operator::PgOperator; pub use returning::NameMacro; +use syn::Expr; use crate::ToSqlConfig; use attribute::Attribute; @@ -74,6 +77,7 @@ pub struct PgExtern { func: syn::ItemFn, to_sql_config: ToSqlConfig, operator: Option, + cast: Option, search_path: Option, inputs: Vec, input_types: Vec, @@ -109,6 +113,7 @@ impl PgExtern { crate::ident_is_acceptable_to_postgres(&func.sig.ident)?; } let operator = Self::operator(&func)?; + let cast = Self::cast(&func)?; let search_path = Self::search_path(&func)?; let inputs = Self::inputs(&func)?; let input_types = Self::input_types(&func)?; @@ -118,6 +123,7 @@ impl PgExtern { func, to_sql_config, operator, + cast, search_path, inputs, input_types, @@ -230,6 +236,36 @@ impl PgExtern { Ok(skel) } + fn cast(func: &syn::ItemFn) -> syn::Result> { + let mut skel = Option::::default(); + for attr in &func.attrs { + let last_segment = attr.path.segments.last().unwrap(); + match last_segment.ident.to_string().as_str() { + "pg_cast" => { + let mut cast = PgCast::Default; + if !attr.tokens.is_empty() { + match attr.parse_args::() { + Ok(Expr::Path(p)) => { + match p.path.segments.last().unwrap().ident.to_string().as_str() { + "implicit" => cast = PgCast::Implicit, + "assignment" => cast = PgCast::Assignment, + _ => eprintln!("Unrecognized option: {}. Using default cast options.", p.path.to_token_stream()), + } + } + _ => eprintln!( + "Unable to parse attribute to pg_cast as a Rust Expr: {}. Using default cast options.", + attr.tokens + ), + } + } + skel = Some(cast); + } + _ => (), + } + } + Ok(skel) + } + fn search_path(func: &syn::ItemFn) -> syn::Result> { func.attrs .iter() @@ -282,6 +318,7 @@ impl PgExtern { }; let operator = self.operator.clone().into_iter(); + let cast = self.cast.clone().into_iter(); let to_sql_config = match self.overridden() { None => self.to_sql_config.clone(), Some(content) => ToSqlConfig { content: Some(content), ..self.to_sql_config.clone() }, @@ -316,6 +353,7 @@ impl PgExtern { search_path: None #( .unwrap_or_else(|| Some(vec![#search_path])) )*, #[allow(clippy::or_fun_call)] operator: None #( .unwrap_or_else(|| Some(#operator)) )*, + cast: None #( .unwrap_or_else(|| Some(#cast)) )*, to_sql_config: #to_sql_config, }; ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity::Function(submission) diff --git a/pgrx-tests/src/tests/mod.rs b/pgrx-tests/src/tests/mod.rs index 1621568b3..1d372b616 100644 --- a/pgrx-tests/src/tests/mod.rs +++ b/pgrx-tests/src/tests/mod.rs @@ -37,6 +37,7 @@ mod log_tests; mod memcxt_tests; mod name_tests; mod numeric_tests; +mod pg_cast_tests; mod pg_extern_tests; mod pg_guard_tests; mod pg_operator_tests; diff --git a/pgrx-tests/src/tests/pg_cast_tests.rs b/pgrx-tests/src/tests/pg_cast_tests.rs new file mode 100644 index 000000000..29c7d6241 --- /dev/null +++ b/pgrx-tests/src/tests/pg_cast_tests.rs @@ -0,0 +1,64 @@ +//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC. +//LICENSE +//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc. +//LICENSE +//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. +//LICENSE +//LICENSE All rights reserved. +//LICENSE +//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. +use pgrx::prelude::*; + +#[pg_schema] +mod pg_catalog { + use pgrx::{pg_cast, pg_extern}; + use serde_json::Value::Number; + + #[pg_extern] + #[pg_cast(implicit)] + fn int4_from_json(value: pgrx::Json) -> i32 { + if let Number(num) = &value.0 { + if num.is_i64() { + return num.as_i64().unwrap() as i32; + } else if num.is_f64() { + return num.as_f64().unwrap() as i32; + } else if num.is_u64() { + return num.as_u64().unwrap() as i32; + } + }; + panic!("Error casting json value {} to an integer", value.0) + } +} + +#[cfg(any(test, feature = "pg_test"))] +#[pg_schema] +mod tests { + #[allow(unused_imports)] + use crate as pgrx_tests; + use pgrx::prelude::*; + + #[pg_test] + fn test_pg_cast_explicit_type_cast() { + assert_eq!( + Spi::get_one::("SELECT CAST('{\"a\": 1}'::json->'a' AS int4);"), + Ok(Some(1)) + ); + assert_eq!(Spi::get_one::("SELECT ('{\"a\": 1}'::json->'a')::int4;"), Ok(Some(1))); + } + + #[pg_test] + fn test_pg_cast_assignment_type_cast() { + let _ = Spi::connect(|mut client| { + client.update("CREATE TABLE test_table(value int4);", None, None)?; + client.update("INSERT INTO test_table VALUES('{\"a\": 1}'::json->'a');", None, None)?; + + Ok::<_, spi::Error>(()) + }); + assert_eq!(Spi::get_one::("SELECT value FROM test_table"), Ok(Some(1))); + } + + #[pg_test] + fn test_pg_cast_implicit_type_cast() { + assert_eq!(Spi::get_one::("SELECT 1 + ('{\"a\": 1}'::json->'a')"), Ok(Some(2))); + } +}