diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 6fd584f25f..1adbd18dfe 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -344,17 +344,13 @@ WHERE rngtypid = $1 } // language=SQL - let (oid,): (Oid,) = query_as( - " -SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1 - ", - ) - .bind(name) - .fetch_optional(&mut *self) - .await? - .ok_or_else(|| Error::TypeNotFound { - type_name: String::from(name), - })?; + let (oid,): (Oid,) = query_as("SELECT $1::regtype::oid") + .bind(name) + .fetch_optional(&mut *self) + .await? + .ok_or_else(|| Error::TypeNotFound { + type_name: String::from(name), + })?; self.cache_type_oid.insert(name.to_string().into(), oid); Ok(oid) diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index a021e4875e..e1a44e80e9 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -176,6 +176,8 @@ impl Connection for PgConnection { fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { + self.cache_type_oid.clear(); + let mut cleared = 0_usize; self.wait_until_ready().await?; diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 67be84fc2a..68650a10c8 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -1388,6 +1388,69 @@ VALUES Ok(()) } +#[sqlx_macros::test] +async fn custom_type_resolution_respects_search_path() -> anyhow::Result<()> { + let mut conn = new::().await?; + + conn.execute( + r#" +DROP TYPE IF EXISTS some_enum_type; +DROP SCHEMA IF EXISTS another CASCADE; + +CREATE SCHEMA another; +CREATE TYPE some_enum_type AS ENUM ('a', 'b', 'c'); +CREATE TYPE another.some_enum_type AS ENUM ('d', 'e', 'f'); + "#, + ) + .await?; + + #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct SomeEnumType(String); + + impl sqlx::Type for SomeEnumType { + fn type_info() -> sqlx::postgres::PgTypeInfo { + sqlx::postgres::PgTypeInfo::with_name("some_enum_type") + } + + fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool { + *ty == Self::type_info() + } + } + + impl<'r> sqlx::Decode<'r, Postgres> for SomeEnumType { + fn decode( + value: sqlx::postgres::PgValueRef<'r>, + ) -> Result> { + Ok(Self(>::decode(value)?)) + } + } + + impl<'q> sqlx::Encode<'q, Postgres> for SomeEnumType { + fn encode_by_ref( + &self, + buf: &mut sqlx::postgres::PgArgumentBuffer, + ) -> sqlx::encode::IsNull { + >::encode_by_ref(&self.0, buf) + } + } + + let mut conn = new::().await?; + + sqlx::query("set search_path = 'another'") + .execute(&mut conn) + .await?; + + let result = sqlx::query("SELECT 1 WHERE $1::some_enum_type = 'd'::some_enum_type;") + .bind(SomeEnumType("d".into())) + .fetch_all(&mut conn) + .await; + + let result = result.unwrap(); + assert_eq!(result.len(), 1); + + Ok(()) +} + #[sqlx_macros::test] async fn test_pg_server_num() -> anyhow::Result<()> { let conn = new::().await?;