diff --git a/sqlx-core/src/postgres/connection/describe.rs b/sqlx-core/src/postgres/connection/describe.rs index 6d7a3f7dc0..3b8d936fb3 100644 --- a/sqlx-core/src/postgres/connection/describe.rs +++ b/sqlx-core/src/postgres/connection/describe.rs @@ -16,6 +16,7 @@ use std::sync::Arc; /// Describes the type of the `pg_type.typtype` column /// /// See +#[derive(Copy, Clone, Debug, Eq, PartialEq)] enum TypType { Base, Composite, @@ -45,6 +46,7 @@ impl TryFrom for TypType { /// Describes the type of the `pg_type.typcategory` column /// /// See +#[derive(Copy, Clone, Debug, Eq, PartialEq)] enum TypCategory { Array, Boolean, @@ -198,7 +200,9 @@ impl PgConnection { (Ok(TypType::Base), Ok(TypCategory::Array)) => { Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { - kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?), + kind: PgTypeKind::Array( + self.maybe_fetch_type_info_by_oid(element, true).await?, + ), name: name.into(), oid, })))) diff --git a/sqlx-core/src/postgres/type_info.rs b/sqlx-core/src/postgres/type_info.rs index 37c018f798..97d5efa0cc 100644 --- a/sqlx-core/src/postgres/type_info.rs +++ b/sqlx-core/src/postgres/type_info.rs @@ -1,5 +1,6 @@ #![allow(dead_code)] +use std::borrow::Cow; use std::fmt::{self, Display, Formatter}; use std::ops::Deref; use std::sync::Arc; @@ -750,6 +751,125 @@ impl PgType { } } } + + /// If `self` is an array type, return the type info for its element. + /// + /// This method should only be called on resolved types: calling it on + /// a type that is merely declared (DeclareWithOid/Name) is a bug. + pub(crate) fn try_array_element(&self) -> Option> { + // We explicitly match on all the `None` cases to ensure an exhaustive match. + match self { + PgType::Bool => None, + PgType::BoolArray => Some(Cow::Owned(PgTypeInfo(PgType::Bool))), + PgType::Bytea => None, + PgType::ByteaArray => Some(Cow::Owned(PgTypeInfo(PgType::Bytea))), + PgType::Char => None, + PgType::CharArray => Some(Cow::Owned(PgTypeInfo(PgType::Char))), + PgType::Name => None, + PgType::NameArray => Some(Cow::Owned(PgTypeInfo(PgType::Name))), + PgType::Int8 => None, + PgType::Int8Array => Some(Cow::Owned(PgTypeInfo(PgType::Int8))), + PgType::Int2 => None, + PgType::Int2Array => Some(Cow::Owned(PgTypeInfo(PgType::Int2))), + PgType::Int4 => None, + PgType::Int4Array => Some(Cow::Owned(PgTypeInfo(PgType::Int4))), + PgType::Text => None, + PgType::TextArray => Some(Cow::Owned(PgTypeInfo(PgType::Text))), + PgType::Oid => None, + PgType::OidArray => Some(Cow::Owned(PgTypeInfo(PgType::Oid))), + PgType::Json => None, + PgType::JsonArray => Some(Cow::Owned(PgTypeInfo(PgType::Json))), + PgType::Point => None, + PgType::PointArray => Some(Cow::Owned(PgTypeInfo(PgType::Point))), + PgType::Lseg => None, + PgType::LsegArray => Some(Cow::Owned(PgTypeInfo(PgType::Lseg))), + PgType::Path => None, + PgType::PathArray => Some(Cow::Owned(PgTypeInfo(PgType::Path))), + PgType::Box => None, + PgType::BoxArray => Some(Cow::Owned(PgTypeInfo(PgType::Box))), + PgType::Polygon => None, + PgType::PolygonArray => Some(Cow::Owned(PgTypeInfo(PgType::Polygon))), + PgType::Line => None, + PgType::LineArray => Some(Cow::Owned(PgTypeInfo(PgType::Line))), + PgType::Cidr => None, + PgType::CidrArray => Some(Cow::Owned(PgTypeInfo(PgType::Cidr))), + PgType::Float4 => None, + PgType::Float4Array => Some(Cow::Owned(PgTypeInfo(PgType::Float4))), + PgType::Float8 => None, + PgType::Float8Array => Some(Cow::Owned(PgTypeInfo(PgType::Float8))), + PgType::Circle => None, + PgType::CircleArray => Some(Cow::Owned(PgTypeInfo(PgType::Circle))), + PgType::Macaddr8 => None, + PgType::Macaddr8Array => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr8))), + PgType::Money => None, + PgType::MoneyArray => Some(Cow::Owned(PgTypeInfo(PgType::Money))), + PgType::Macaddr => None, + PgType::MacaddrArray => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr))), + PgType::Inet => None, + PgType::InetArray => Some(Cow::Owned(PgTypeInfo(PgType::Inet))), + PgType::Bpchar => None, + PgType::BpcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Bpchar))), + PgType::Varchar => None, + PgType::VarcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Varchar))), + PgType::Date => None, + PgType::DateArray => Some(Cow::Owned(PgTypeInfo(PgType::Date))), + PgType::Time => None, + PgType::TimeArray => Some(Cow::Owned(PgTypeInfo(PgType::Time))), + PgType::Timestamp => None, + PgType::TimestampArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamp))), + PgType::Timestamptz => None, + PgType::TimestamptzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamptz))), + PgType::Interval => None, + PgType::IntervalArray => Some(Cow::Owned(PgTypeInfo(PgType::Interval))), + PgType::Timetz => None, + PgType::TimetzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timetz))), + PgType::Bit => None, + PgType::BitArray => Some(Cow::Owned(PgTypeInfo(PgType::Bit))), + PgType::Varbit => None, + PgType::VarbitArray => Some(Cow::Owned(PgTypeInfo(PgType::Varbit))), + PgType::Numeric => None, + PgType::NumericArray => Some(Cow::Owned(PgTypeInfo(PgType::Numeric))), + PgType::Record => None, + PgType::RecordArray => Some(Cow::Owned(PgTypeInfo(PgType::Record))), + PgType::Uuid => None, + PgType::UuidArray => Some(Cow::Owned(PgTypeInfo(PgType::Uuid))), + PgType::Jsonb => None, + PgType::JsonbArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonb))), + PgType::Int4Range => None, + PgType::Int4RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int4Range))), + PgType::NumRange => None, + PgType::NumRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::NumRange))), + PgType::TsRange => None, + PgType::TsRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TsRange))), + PgType::TstzRange => None, + PgType::TstzRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TstzRange))), + PgType::DateRange => None, + PgType::DateRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::DateRange))), + PgType::Int8Range => None, + PgType::Int8RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int8Range))), + PgType::Jsonpath => None, + PgType::JsonpathArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonpath))), + // There is no `UnknownArray` + PgType::Unknown => None, + // There is no `VoidArray` + PgType::Void => None, + PgType::Custom(ty) => match &ty.kind { + PgTypeKind::Simple => None, + PgTypeKind::Pseudo => None, + PgTypeKind::Domain(_) => None, + PgTypeKind::Composite(_) => None, + PgTypeKind::Array(ref elem_type_info) => Some(Cow::Borrowed(elem_type_info)), + PgTypeKind::Enum(_) => None, + PgTypeKind::Range(_) => None, + }, + PgType::DeclareWithOid(oid) => { + unreachable!("(bug) use of unresolved type declaration [oid={}]", oid); + } + PgType::DeclareWithName(name) => { + unreachable!("(bug) use of unresolved type declaration [name={}]", name); + } + } + } } impl TypeInfo for PgTypeInfo { diff --git a/sqlx-core/src/postgres/types/array.rs b/sqlx-core/src/postgres/types/array.rs index 55ebb7e520..68a0ac385e 100644 --- a/sqlx-core/src/postgres/types/array.rs +++ b/sqlx-core/src/postgres/types/array.rs @@ -1,4 +1,5 @@ use bytes::Buf; +use std::borrow::Cow; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; @@ -103,7 +104,6 @@ where T: for<'a> Decode<'a, Postgres> + Type, { fn decode(value: PgValueRef<'r>) -> Result { - let element_type_info; let format = value.format(); match format { @@ -131,7 +131,8 @@ where // the OID of the element let element_type_oid = buf.get_u32(); - element_type_info = PgTypeInfo::try_from_oid(element_type_oid) + let element_type_info: PgTypeInfo = PgTypeInfo::try_from_oid(element_type_oid) + .or_else(|| value.type_info.try_array_element().map(Cow::into_owned)) .unwrap_or_else(|| PgTypeInfo(PgType::DeclareWithOid(element_type_oid))); // length of the array axis @@ -159,7 +160,7 @@ where PgValueFormat::Text => { // no type is provided from the database for the element - element_type_info = T::type_info(); + let element_type_info = T::type_info(); let s = value.as_str()?; diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index c244a93dc1..ac755018bc 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -1094,6 +1094,93 @@ CREATE TABLE heating_bills ( Ok(()) } +#[sqlx_macros::test] +async fn it_resolves_custom_type_in_array() -> anyhow::Result<()> { + // Only supported in Postgres 11+ + let mut conn = new::().await?; + if matches!(conn.server_version_num(), Some(version) if version < 110000) { + return Ok(()); + } + + // language=PostgreSQL + conn.execute( + r#" +DROP TABLE IF EXISTS pets; +DROP TYPE IF EXISTS pet_name_and_race; + +CREATE TYPE pet_name_and_race AS ( + name TEXT, + race TEXT +); +CREATE TABLE pets ( + owner TEXT NOT NULL, + name TEXT NOT NULL, + race TEXT NOT NULL, + PRIMARY KEY (owner, name) +); +INSERT INTO pets(owner, name, race) +VALUES + ('Alice', 'Foo', 'cat'); +INSERT INTO pets(owner, name, race) +VALUES + ('Alice', 'Bar', 'dog'); + "#, + ) + .await?; + + #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct PetNameAndRace { + name: String, + race: String, + } + + impl sqlx::Type for PetNameAndRace { + fn type_info() -> sqlx::postgres::PgTypeInfo { + sqlx::postgres::PgTypeInfo::with_name("pet_name_and_race") + } + } + + impl<'r> sqlx::Decode<'r, Postgres> for PetNameAndRace { + fn decode( + value: sqlx::postgres::PgValueRef<'r>, + ) -> Result> { + let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?; + let name = decoder.try_decode::()?; + let race = decoder.try_decode::()?; + Ok(Self { name, race }) + } + } + + #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct PetNameAndRaceArray(Vec); + + impl sqlx::Type for PetNameAndRaceArray { + fn type_info() -> sqlx::postgres::PgTypeInfo { + // Array type name is the name of the element type prefixed with `_` + sqlx::postgres::PgTypeInfo::with_name("_pet_name_and_race") + } + } + + impl<'r> sqlx::Decode<'r, Postgres> for PetNameAndRaceArray { + fn decode( + value: sqlx::postgres::PgValueRef<'r>, + ) -> Result> { + Ok(Self(Vec::::decode(value)?)) + } + } + + let mut conn = new::().await?; + + let row = sqlx::query("select owner, array_agg(row(name, race)::pet_name_and_race) as pets from pets group by owner") + .fetch_one(&mut conn) + .await?; + + let pets: PetNameAndRaceArray = row.get("pets"); + + assert_eq!(pets.0.len(), 2); + Ok(()) +} + #[sqlx_macros::test] async fn test_pg_server_num() -> anyhow::Result<()> { use sqlx::postgres::PgConnectionInfo;