From 73253e99fdacec672f88cb5ae6be869428640181 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Tue, 25 Oct 2022 21:14:49 +0800 Subject: [PATCH] Fix parsing Postgres user-defined types --- src/postgres/def/types.rs | 126 ++++++++++++++++++---------------- src/postgres/parser/column.rs | 21 ++++++ src/postgres/query/column.rs | 13 ++-- src/postgres/writer/column.rs | 6 +- 4 files changed, 98 insertions(+), 68 deletions(-) diff --git a/src/postgres/def/types.rs b/src/postgres/def/types.rs index fcd957cd..af9ec07c 100644 --- a/src/postgres/def/types.rs +++ b/src/postgres/def/types.rs @@ -111,7 +111,7 @@ pub enum Type { JsonBinary, /// Variable-length multidimensional array - Array(SeaRc), + Array(ArrayDef), // TODO: // /// The structure of a row or record; a list of field names and types @@ -149,66 +149,60 @@ impl Type { // TODO: Support more types #[allow(clippy::should_implement_trait)] pub fn from_str(name: &str) -> Type { - fn parse_type(name: &str) -> Type { - match name.to_lowercase().as_str() { - "smallint" | "int2" => Type::SmallInt, - "integer" | "int" | "int4" => Type::Integer, - "bigint" | "int8" => Type::BigInt, - "decimal" => Type::Decimal(ArbitraryPrecisionNumericAttr::default()), - "numeric" => Type::Numeric(ArbitraryPrecisionNumericAttr::default()), - "real" | "float4" => Type::Real, - "double precision" | "double" | "float8" => Type::DoublePrecision, - "smallserial" | "serial2" => Type::SmallSerial, - "serial" | "serial4" => Type::Serial, - "bigserial" | "serial8" => Type::BigSerial, - "money" => Type::Money, - "character varying" | "varchar" => Type::Varchar(StringAttr::default()), - "character" | "char" => Type::Char(StringAttr::default()), - "text" => Type::Text, - "bytea" => Type::Bytea, - "timestamp" | "timestamp without time zone" => Type::Timestamp(TimeAttr::default()), - "timestamp with time zone" => Type::TimestampWithTimeZone(TimeAttr::default()), - "date" => Type::Date, - "time" | "time without time zone" => Type::Time(TimeAttr::default()), - 
"time with time zone" => Type::TimeWithTimeZone(TimeAttr::default()), - "interval" => Type::Interval(IntervalAttr::default()), - "boolean" => Type::Boolean, - "point" => Type::Point, - "line" => Type::Line, - "lseg" => Type::Lseg, - "box" => Type::Box, - "path" => Type::Path, - "polygon" => Type::Polygon, - "circle" => Type::Circle, - "cidr" => Type::Cidr, - "inet" => Type::Inet, - "macaddr" => Type::MacAddr, - "macaddr8" => Type::MacAddr8, - "bit" => Type::Bit(BitAttr::default()), - "tsvector" => Type::TsVector, - "tsquery" => Type::TsQuery, - "uuid" => Type::Uuid, - "xml" => Type::Xml, - "json" => Type::Json, - "jsonb" => Type::JsonBinary, - // "" => Type::Composite, - "int4range" => Type::Int4Range, - "int8range" => Type::Int8Range, - "numrange" => Type::NumRange, - "tsrange" => Type::TsRange, - "tstzrange" => Type::TsTzRange, - "daterange" => Type::DateRange, - // "" => Type::Domain, - "pg_lsn" => Type::PgLsn, - "user-defined" => Type::Enum(EnumDef::default()), - _ if name.ends_with("[]") => { - let col_type = parse_type(&name.replacen("[]", "", 1)); - Type::Array(SeaRc::new(col_type)) - } - _ => Type::Unknown(name.to_owned()), - } + match name.to_lowercase().as_str() { + "smallint" | "int2" => Type::SmallInt, + "integer" | "int" | "int4" => Type::Integer, + "bigint" | "int8" => Type::BigInt, + "decimal" => Type::Decimal(ArbitraryPrecisionNumericAttr::default()), + "numeric" => Type::Numeric(ArbitraryPrecisionNumericAttr::default()), + "real" | "float4" => Type::Real, + "double precision" | "double" | "float8" => Type::DoublePrecision, + "smallserial" | "serial2" => Type::SmallSerial, + "serial" | "serial4" => Type::Serial, + "bigserial" | "serial8" => Type::BigSerial, + "money" => Type::Money, + "character varying" | "varchar" => Type::Varchar(StringAttr::default()), + "character" | "char" => Type::Char(StringAttr::default()), + "text" => Type::Text, + "bytea" => Type::Bytea, + "timestamp" | "timestamp without time zone" => 
Type::Timestamp(TimeAttr::default()), + "timestamp with time zone" => Type::TimestampWithTimeZone(TimeAttr::default()), + "date" => Type::Date, + "time" | "time without time zone" => Type::Time(TimeAttr::default()), + "time with time zone" => Type::TimeWithTimeZone(TimeAttr::default()), + "interval" => Type::Interval(IntervalAttr::default()), + "boolean" => Type::Boolean, + "point" => Type::Point, + "line" => Type::Line, + "lseg" => Type::Lseg, + "box" => Type::Box, + "path" => Type::Path, + "polygon" => Type::Polygon, + "circle" => Type::Circle, + "cidr" => Type::Cidr, + "inet" => Type::Inet, + "macaddr" => Type::MacAddr, + "macaddr8" => Type::MacAddr8, + "bit" => Type::Bit(BitAttr::default()), + "tsvector" => Type::TsVector, + "tsquery" => Type::TsQuery, + "uuid" => Type::Uuid, + "xml" => Type::Xml, + "json" => Type::Json, + "jsonb" => Type::JsonBinary, + // "" => Type::Composite, + "int4range" => Type::Int4Range, + "int8range" => Type::Int8Range, + "numrange" => Type::NumRange, + "tsrange" => Type::TsRange, + "tstzrange" => Type::TsTzRange, + "daterange" => Type::DateRange, + // "" => Type::Domain, + "pg_lsn" => Type::PgLsn, + "user-defined" => Type::Enum(EnumDef::default()), + "array" => Type::Array(ArrayDef::default()), + _ => Type::Unknown(name.to_owned()), } - parse_type(name) } } @@ -259,6 +253,14 @@ pub struct EnumDef { pub typename: String, } +/// Defines an array for the PostgreSQL module +#[derive(Clone, Debug, PartialEq, Default)] +#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))] +pub struct ArrayDef { + /// Array type + pub col_type: Option>, +} + impl Type { pub fn has_numeric_attr(&self) -> bool { matches!(self, Type::Numeric(_) | Type::Decimal(_)) @@ -289,4 +291,8 @@ impl Type { pub fn has_enum_attr(&self) -> bool { matches!(self, Type::Enum(_)) } + + pub fn has_array_attr(&self) -> bool { + matches!(self, Type::Array(_)) + } } diff --git a/src/postgres/parser/column.rs b/src/postgres/parser/column.rs index dbf961ea..59b23219 
100644 --- a/src/postgres/parser/column.rs +++ b/src/postgres/parser/column.rs @@ -1,4 +1,5 @@ use crate::postgres::{def::*, parser::yes_or_no_to_bool, query::ColumnQueryResult}; +use sea_query::SeaRc; use std::{collections::HashMap, convert::TryFrom}; impl ColumnQueryResult { @@ -44,6 +45,9 @@ pub fn parse_column_type(result: &ColumnQueryResult) -> ColumnType { if ctype.has_enum_attr() { ctype = parse_enum_attributes(result.udt_name.as_ref(), ctype); } + if ctype.has_array_attr() { + ctype = parse_array_attributes(result.udt_name_regtype.as_ref(), ctype); + } ctype } @@ -183,6 +187,23 @@ pub fn parse_enum_attributes(udt_name: Option<&String>, mut ctype: ColumnType) - ctype } +pub fn parse_array_attributes( + udt_name_regtype: Option<&String>, + mut ctype: ColumnType, +) -> ColumnType { + match ctype { + Type::Array(ref mut def) => { + def.col_type = match udt_name_regtype { + None => panic!("parse_array_attributes(_) received an empty udt_name_regtype"), + Some(typename) => Some(SeaRc::new(Type::from_str(&typename.replacen("[]", "", 1)))), + }; + } + _ => panic!("parse_array_attributes(_) received a type that does not have EnumDef"), + }; + + ctype +} + impl ColumnInfo { pub fn parse_enum_variants(mut self, enums: &HashMap>) -> Self { if let Type::Enum(ref mut enum_def) = self.col_type { diff --git a/src/postgres/query/column.rs b/src/postgres/query/column.rs index dfb95f8a..aa001aef 100644 --- a/src/postgres/query/column.rs +++ b/src/postgres/query/column.rs @@ -66,6 +66,7 @@ pub struct ColumnQueryResult { pub interval_precision: Option, pub udt_name: Option, + pub udt_name_regtype: Option, } impl SchemaQueryBuilder { @@ -75,12 +76,9 @@ impl SchemaQueryBuilder { table: SeaRc, ) -> SelectStatement { Query::select() - .column(ColumnsField::ColumnName) - .expr( - Expr::expr(Expr::cust("udt_name::regtype").cast_as(Alias::new("text"))) - .binary(BinOper::As, Expr::col(ColumnsField::DataType)), - ) .columns([ + ColumnsField::ColumnName, + ColumnsField::DataType, 
ColumnsField::ColumnDefault, ColumnsField::GenerationExpression, ColumnsField::IsNullable, @@ -95,6 +93,10 @@ impl SchemaQueryBuilder { ColumnsField::IntervalPrecision, ColumnsField::UdtName, ]) + .expr( + Expr::expr(Expr::cust("udt_name::regtype").cast_as(Alias::new("text"))) + .binary(BinOper::As, Expr::col(Alias::new("udt_name_regtype"))), + ) .from((InformationSchema::Schema, InformationSchema::Columns)) .and_where(Expr::col(ColumnsField::TableSchema).eq(schema.to_string())) .and_where(Expr::col(ColumnsField::TableName).eq(table.to_string())) @@ -122,6 +124,7 @@ impl From<&PgRow> for ColumnQueryResult { interval_type: row.get(12), interval_precision: row.get(13), udt_name: row.get(14), + udt_name_regtype: row.get(15), } } } diff --git a/src/postgres/writer/column.rs b/src/postgres/writer/column.rs index 41d60b65..71f9e25a 100644 --- a/src/postgres/writer/column.rs +++ b/src/postgres/writer/column.rs @@ -146,9 +146,9 @@ impl ColumnInfo { .collect(); ColumnType::Enum { name, variants } } - Type::Array(col_type) => { - ColumnType::Array(SeaRc::new(Box::new(write_type(col_type)))) - } + Type::Array(array_def) => ColumnType::Array(SeaRc::new(Box::new(write_type( + array_def.col_type.as_ref().expect("Array type not defined"), + )))), } } write_type(&self.col_type)