From ddd2b5675d131f25ae3312bca269736ad5082437 Mon Sep 17 00:00:00 2001 From: Charles Samborski Date: Fri, 8 Oct 2021 13:25:15 +0200 Subject: [PATCH] Fix support for Postgres array of custom types This commit fixes the array decoder to support custom types. The core of the issue was that the array decoder did not use the type info retrieved from the database. It means that it only supported native types. This commit fixes the issue by using the element type info fetched from the database. A new internal helper method is added to the `PgType` struct: it returns the type info for the inner array element, if available. Closes #1477 --- sqlx-core/src/postgres/connection/describe.rs | 6 +- sqlx-core/src/postgres/type_info.rs | 120 ++++++++++++++++++ sqlx-core/src/postgres/types/array.rs | 7 +- tests/postgres/postgres.rs | 87 +++++++++++++ 4 files changed, 216 insertions(+), 4 deletions(-) diff --git a/sqlx-core/src/postgres/connection/describe.rs b/sqlx-core/src/postgres/connection/describe.rs index 6d7a3f7dc0..3b8d936fb3 100644 --- a/sqlx-core/src/postgres/connection/describe.rs +++ b/sqlx-core/src/postgres/connection/describe.rs @@ -16,6 +16,7 @@ use std::sync::Arc; /// Describes the type of the `pg_type.typtype` column /// /// See +#[derive(Copy, Clone, Debug, Eq, PartialEq)] enum TypType { Base, Composite, @@ -45,6 +46,7 @@ impl TryFrom for TypType { /// Describes the type of the `pg_type.typcategory` column /// /// See +#[derive(Copy, Clone, Debug, Eq, PartialEq)] enum TypCategory { Array, Boolean, @@ -198,7 +200,9 @@ impl PgConnection { (Ok(TypType::Base), Ok(TypCategory::Array)) => { Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { - kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?), + kind: PgTypeKind::Array( + self.maybe_fetch_type_info_by_oid(element, true).await?, + ), name: name.into(), oid, })))) diff --git a/sqlx-core/src/postgres/type_info.rs b/sqlx-core/src/postgres/type_info.rs index 37c018f798..97d5efa0cc 100644 --- a/sqlx-core/src/postgres/type_info.rs +++ b/sqlx-core/src/postgres/type_info.rs @@ -1,5 +1,6 @@ #![allow(dead_code)] +use std::borrow::Cow; use std::fmt::{self, Display, Formatter}; use std::ops::Deref; use std::sync::Arc; @@ -750,6 +751,125 @@ impl PgType { } } } + + /// If `self` is an array type, return the type info for its element. + /// + /// This method should only be called on resolved types: calling it on + /// a type that is merely declared (DeclareWithOid/Name) is a bug. + pub(crate) fn try_array_element(&self) -> Option> { + // We explicitly match on all the `None` cases to ensure an exhaustive match. + match self { + PgType::Bool => None, + PgType::BoolArray => Some(Cow::Owned(PgTypeInfo(PgType::Bool))), + PgType::Bytea => None, + PgType::ByteaArray => Some(Cow::Owned(PgTypeInfo(PgType::Bytea))), + PgType::Char => None, + PgType::CharArray => Some(Cow::Owned(PgTypeInfo(PgType::Char))), + PgType::Name => None, + PgType::NameArray => Some(Cow::Owned(PgTypeInfo(PgType::Name))), + PgType::Int8 => None, + PgType::Int8Array => Some(Cow::Owned(PgTypeInfo(PgType::Int8))), + PgType::Int2 => None, + PgType::Int2Array => Some(Cow::Owned(PgTypeInfo(PgType::Int2))), + PgType::Int4 => None, + PgType::Int4Array => Some(Cow::Owned(PgTypeInfo(PgType::Int4))), + PgType::Text => None, + PgType::TextArray => Some(Cow::Owned(PgTypeInfo(PgType::Text))), + PgType::Oid => None, + PgType::OidArray => Some(Cow::Owned(PgTypeInfo(PgType::Oid))), + PgType::Json => None, + PgType::JsonArray => Some(Cow::Owned(PgTypeInfo(PgType::Json))), + PgType::Point => None, + PgType::PointArray => Some(Cow::Owned(PgTypeInfo(PgType::Point))), + PgType::Lseg => None, + PgType::LsegArray => Some(Cow::Owned(PgTypeInfo(PgType::Lseg))), + PgType::Path => None, + PgType::PathArray => Some(Cow::Owned(PgTypeInfo(PgType::Path))), + PgType::Box => None, + PgType::BoxArray => Some(Cow::Owned(PgTypeInfo(PgType::Box))), + PgType::Polygon => None, + PgType::PolygonArray => Some(Cow::Owned(PgTypeInfo(PgType::Polygon))), + PgType::Line => None, + PgType::LineArray => Some(Cow::Owned(PgTypeInfo(PgType::Line))), + PgType::Cidr => None, + PgType::CidrArray => Some(Cow::Owned(PgTypeInfo(PgType::Cidr))), + PgType::Float4 => None, + PgType::Float4Array => Some(Cow::Owned(PgTypeInfo(PgType::Float4))), + PgType::Float8 => None, + PgType::Float8Array => Some(Cow::Owned(PgTypeInfo(PgType::Float8))), + PgType::Circle => None, + PgType::CircleArray => Some(Cow::Owned(PgTypeInfo(PgType::Circle))), + PgType::Macaddr8 => None, + PgType::Macaddr8Array => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr8))), + PgType::Money => None, + PgType::MoneyArray => Some(Cow::Owned(PgTypeInfo(PgType::Money))), + PgType::Macaddr => None, + PgType::MacaddrArray => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr))), + PgType::Inet => None, + PgType::InetArray => Some(Cow::Owned(PgTypeInfo(PgType::Inet))), + PgType::Bpchar => None, + PgType::BpcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Bpchar))), + PgType::Varchar => None, + PgType::VarcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Varchar))), + PgType::Date => None, + PgType::DateArray => Some(Cow::Owned(PgTypeInfo(PgType::Date))), + PgType::Time => None, + PgType::TimeArray => Some(Cow::Owned(PgTypeInfo(PgType::Time))), + PgType::Timestamp => None, + PgType::TimestampArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamp))), + PgType::Timestamptz => None, + PgType::TimestamptzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamptz))), + PgType::Interval => None, + PgType::IntervalArray => Some(Cow::Owned(PgTypeInfo(PgType::Interval))), + PgType::Timetz => None, + PgType::TimetzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timetz))), + PgType::Bit => None, + PgType::BitArray => Some(Cow::Owned(PgTypeInfo(PgType::Bit))), + PgType::Varbit => None, + PgType::VarbitArray => Some(Cow::Owned(PgTypeInfo(PgType::Varbit))), + PgType::Numeric => None, + PgType::NumericArray => Some(Cow::Owned(PgTypeInfo(PgType::Numeric))), + PgType::Record => None, + PgType::RecordArray => Some(Cow::Owned(PgTypeInfo(PgType::Record))), + PgType::Uuid => None, + PgType::UuidArray => Some(Cow::Owned(PgTypeInfo(PgType::Uuid))), + PgType::Jsonb => None, + PgType::JsonbArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonb))), + PgType::Int4Range => None, + PgType::Int4RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int4Range))), + PgType::NumRange => None, + PgType::NumRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::NumRange))), + PgType::TsRange => None, + PgType::TsRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TsRange))), + PgType::TstzRange => None, + PgType::TstzRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TstzRange))), + PgType::DateRange => None, + PgType::DateRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::DateRange))), + PgType::Int8Range => None, + PgType::Int8RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int8Range))), + PgType::Jsonpath => None, + PgType::JsonpathArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonpath))), + // There is no `UnknownArray` + PgType::Unknown => None, + // There is no `VoidArray` + PgType::Void => None, + PgType::Custom(ty) => match &ty.kind { + PgTypeKind::Simple => None, + PgTypeKind::Pseudo => None, + PgTypeKind::Domain(_) => None, + PgTypeKind::Composite(_) => None, + PgTypeKind::Array(ref elem_type_info) => Some(Cow::Borrowed(elem_type_info)), + PgTypeKind::Enum(_) => None, + PgTypeKind::Range(_) => None, + }, + PgType::DeclareWithOid(oid) => { + unreachable!("(bug) use of unresolved type declaration [oid={}]", oid); + } + PgType::DeclareWithName(name) => { + unreachable!("(bug) use of unresolved type declaration [name={}]", name); + } + } + } } impl TypeInfo for PgTypeInfo { diff --git a/sqlx-core/src/postgres/types/array.rs b/sqlx-core/src/postgres/types/array.rs index cf2baea40a..ef586d6e42 100644 --- a/sqlx-core/src/postgres/types/array.rs +++ b/sqlx-core/src/postgres/types/array.rs @@ -1,4 +1,5 @@ use bytes::Buf; +use std::borrow::Cow; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; @@ -77,7 +78,6 @@ where T: for<'a> Decode<'a, Postgres> + Type, { fn decode(value: PgValueRef<'r>) -> Result { - let element_type_info; let format = value.format(); match format { @@ -105,7 +105,8 @@ where // the OID of the element let element_type_oid = buf.get_u32(); - element_type_info = PgTypeInfo::try_from_oid(element_type_oid) + let element_type_info: PgTypeInfo = PgTypeInfo::try_from_oid(element_type_oid) + .or_else(|| value.type_info.try_array_element().map(Cow::into_owned)) .unwrap_or_else(|| PgTypeInfo(PgType::DeclareWithOid(element_type_oid))); // length of the array axis @@ -133,7 +134,7 @@ where PgValueFormat::Text => { // no type is provided from the database for the element - element_type_info = T::type_info(); + let element_type_info = T::type_info(); let s = value.as_str()?; diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 51dfbc6d37..c03e580a14 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -1094,6 +1094,93 @@ CREATE TABLE heating_bills ( Ok(()) } +#[sqlx_macros::test] +async fn it_resolves_custom_type_in_array() -> anyhow::Result<()> { + // Only supported in Postgres 11+ + let mut conn = new::().await?; + if matches!(conn.server_version_num(), Some(version) if version < 110000) { + return Ok(()); + } + + // language=PostgreSQL + conn.execute( + r#" +DROP TABLE IF EXISTS pets; +DROP TYPE IF EXISTS pet_name_and_race; + +CREATE TYPE pet_name_and_race AS ( + name TEXT, + race TEXT +); +CREATE TABLE pets ( + owner TEXT NOT NULL, + name TEXT NOT NULL, + race TEXT NOT NULL, + PRIMARY KEY (owner, name) +); +INSERT INTO pets(owner, name, race) +VALUES + ('Alice', 'Foo', 'cat'); +INSERT INTO pets(owner, name, race) +VALUES + ('Alice', 'Bar', 'dog'); + "#, + ) + .await?; + + #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct PetNameAndRace { + name: String, + race: String, + } + + impl sqlx::Type for PetNameAndRace { + fn type_info() -> sqlx::postgres::PgTypeInfo { + sqlx::postgres::PgTypeInfo::with_name("pet_name_and_race") + } + } + + impl<'r> sqlx::Decode<'r, Postgres> for PetNameAndRace { + fn decode( + value: sqlx::postgres::PgValueRef<'r>, + ) -> Result> { + let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?; + let name = decoder.try_decode::()?; + let race = decoder.try_decode::()?; + Ok(Self { name, race }) + } + } + + #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct PetNameAndRaceArray(Vec); + + impl sqlx::Type for PetNameAndRaceArray { + fn type_info() -> sqlx::postgres::PgTypeInfo { + // Array type name is the name of the element type prefixed with `_` + sqlx::postgres::PgTypeInfo::with_name("_pet_name_and_race") + } + } + + impl<'r> sqlx::Decode<'r, Postgres> for PetNameAndRaceArray { + fn decode( + value: sqlx::postgres::PgValueRef<'r>, + ) -> Result> { + Ok(Self(Vec::::decode(value)?)) + } + } + + let mut conn = new::().await?; + + let row = sqlx::query("select owner, array_agg(row(name, race)::pet_name_and_race) as pets from pets group by owner") + .fetch_one(&mut conn) + .await?; + + let pets: PetNameAndRaceArray = row.get("pets"); + + assert_eq!(pets.0.len(), 2); + Ok(()) +} + #[sqlx_macros::test] async fn test_pg_server_num() -> anyhow::Result<()> { use sqlx::postgres::PgConnectionInfo;