Skip to content

Commit

Permalink
Fix parsing Postgres user-defined types (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
billy1624 authored Oct 26, 2022
1 parent 697552a commit b4a6836
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 68 deletions.
126 changes: 66 additions & 60 deletions src/postgres/def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ pub enum Type {
JsonBinary,

/// Variable-length multidimensional array
Array(SeaRc<Type>),
Array(ArrayDef),

// TODO:
// /// The structure of a row or record; a list of field names and types
Expand Down Expand Up @@ -149,66 +149,60 @@ impl Type {
// TODO: Support more types
#[allow(clippy::should_implement_trait)]
pub fn from_str(name: &str) -> Type {
fn parse_type(name: &str) -> Type {
match name.to_lowercase().as_str() {
"smallint" | "int2" => Type::SmallInt,
"integer" | "int" | "int4" => Type::Integer,
"bigint" | "int8" => Type::BigInt,
"decimal" => Type::Decimal(ArbitraryPrecisionNumericAttr::default()),
"numeric" => Type::Numeric(ArbitraryPrecisionNumericAttr::default()),
"real" | "float4" => Type::Real,
"double precision" | "double" | "float8" => Type::DoublePrecision,
"smallserial" | "serial2" => Type::SmallSerial,
"serial" | "serial4" => Type::Serial,
"bigserial" | "serial8" => Type::BigSerial,
"money" => Type::Money,
"character varying" | "varchar" => Type::Varchar(StringAttr::default()),
"character" | "char" => Type::Char(StringAttr::default()),
"text" => Type::Text,
"bytea" => Type::Bytea,
"timestamp" | "timestamp without time zone" => Type::Timestamp(TimeAttr::default()),
"timestamp with time zone" => Type::TimestampWithTimeZone(TimeAttr::default()),
"date" => Type::Date,
"time" | "time without time zone" => Type::Time(TimeAttr::default()),
"time with time zone" => Type::TimeWithTimeZone(TimeAttr::default()),
"interval" => Type::Interval(IntervalAttr::default()),
"boolean" => Type::Boolean,
"point" => Type::Point,
"line" => Type::Line,
"lseg" => Type::Lseg,
"box" => Type::Box,
"path" => Type::Path,
"polygon" => Type::Polygon,
"circle" => Type::Circle,
"cidr" => Type::Cidr,
"inet" => Type::Inet,
"macaddr" => Type::MacAddr,
"macaddr8" => Type::MacAddr8,
"bit" => Type::Bit(BitAttr::default()),
"tsvector" => Type::TsVector,
"tsquery" => Type::TsQuery,
"uuid" => Type::Uuid,
"xml" => Type::Xml,
"json" => Type::Json,
"jsonb" => Type::JsonBinary,
// "" => Type::Composite,
"int4range" => Type::Int4Range,
"int8range" => Type::Int8Range,
"numrange" => Type::NumRange,
"tsrange" => Type::TsRange,
"tstzrange" => Type::TsTzRange,
"daterange" => Type::DateRange,
// "" => Type::Domain,
"pg_lsn" => Type::PgLsn,
"user-defined" => Type::Enum(EnumDef::default()),
_ if name.ends_with("[]") => {
let col_type = parse_type(&name.replacen("[]", "", 1));
Type::Array(SeaRc::new(col_type))
}
_ => Type::Unknown(name.to_owned()),
}
match name.to_lowercase().as_str() {
"smallint" | "int2" => Type::SmallInt,
"integer" | "int" | "int4" => Type::Integer,
"bigint" | "int8" => Type::BigInt,
"decimal" => Type::Decimal(ArbitraryPrecisionNumericAttr::default()),
"numeric" => Type::Numeric(ArbitraryPrecisionNumericAttr::default()),
"real" | "float4" => Type::Real,
"double precision" | "double" | "float8" => Type::DoublePrecision,
"smallserial" | "serial2" => Type::SmallSerial,
"serial" | "serial4" => Type::Serial,
"bigserial" | "serial8" => Type::BigSerial,
"money" => Type::Money,
"character varying" | "varchar" => Type::Varchar(StringAttr::default()),
"character" | "char" => Type::Char(StringAttr::default()),
"text" => Type::Text,
"bytea" => Type::Bytea,
"timestamp" | "timestamp without time zone" => Type::Timestamp(TimeAttr::default()),
"timestamp with time zone" => Type::TimestampWithTimeZone(TimeAttr::default()),
"date" => Type::Date,
"time" | "time without time zone" => Type::Time(TimeAttr::default()),
"time with time zone" => Type::TimeWithTimeZone(TimeAttr::default()),
"interval" => Type::Interval(IntervalAttr::default()),
"boolean" => Type::Boolean,
"point" => Type::Point,
"line" => Type::Line,
"lseg" => Type::Lseg,
"box" => Type::Box,
"path" => Type::Path,
"polygon" => Type::Polygon,
"circle" => Type::Circle,
"cidr" => Type::Cidr,
"inet" => Type::Inet,
"macaddr" => Type::MacAddr,
"macaddr8" => Type::MacAddr8,
"bit" => Type::Bit(BitAttr::default()),
"tsvector" => Type::TsVector,
"tsquery" => Type::TsQuery,
"uuid" => Type::Uuid,
"xml" => Type::Xml,
"json" => Type::Json,
"jsonb" => Type::JsonBinary,
// "" => Type::Composite,
"int4range" => Type::Int4Range,
"int8range" => Type::Int8Range,
"numrange" => Type::NumRange,
"tsrange" => Type::TsRange,
"tstzrange" => Type::TsTzRange,
"daterange" => Type::DateRange,
// "" => Type::Domain,
"pg_lsn" => Type::PgLsn,
"user-defined" => Type::Enum(EnumDef::default()),
"array" => Type::Array(ArrayDef::default()),
_ => Type::Unknown(name.to_owned()),
}
parse_type(name)
}
}

Expand Down Expand Up @@ -259,6 +253,14 @@ pub struct EnumDef {
pub typename: String,
}

/// Defines an enum for the PostgreSQL module
#[derive(Clone, Debug, PartialEq, Default)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct ArrayDef {
/// Array type
pub col_type: Option<SeaRc<Type>>,
}

impl Type {
pub fn has_numeric_attr(&self) -> bool {
matches!(self, Type::Numeric(_) | Type::Decimal(_))
Expand Down Expand Up @@ -289,4 +291,8 @@ impl Type {
pub fn has_enum_attr(&self) -> bool {
matches!(self, Type::Enum(_))
}

pub fn has_array_attr(&self) -> bool {
matches!(self, Type::Array(_))
}
}
21 changes: 21 additions & 0 deletions src/postgres/parser/column.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::postgres::{def::*, parser::yes_or_no_to_bool, query::ColumnQueryResult};
use sea_query::SeaRc;
use std::{collections::HashMap, convert::TryFrom};

impl ColumnQueryResult {
Expand Down Expand Up @@ -44,6 +45,9 @@ pub fn parse_column_type(result: &ColumnQueryResult) -> ColumnType {
if ctype.has_enum_attr() {
ctype = parse_enum_attributes(result.udt_name.as_ref(), ctype);
}
if ctype.has_array_attr() {
ctype = parse_array_attributes(result.udt_name_regtype.as_ref(), ctype);
}

ctype
}
Expand Down Expand Up @@ -183,6 +187,23 @@ pub fn parse_enum_attributes(udt_name: Option<&String>, mut ctype: ColumnType) -
ctype
}

pub fn parse_array_attributes(
udt_name_regtype: Option<&String>,
mut ctype: ColumnType,
) -> ColumnType {
match ctype {
Type::Array(ref mut def) => {
def.col_type = match udt_name_regtype {
None => panic!("parse_array_attributes(_) received an empty udt_name_regtype"),
Some(typename) => Some(SeaRc::new(Type::from_str(&typename.replacen("[]", "", 1)))),
};
}
_ => panic!("parse_array_attributes(_) received a type that does not have EnumDef"),
};

ctype
}

impl ColumnInfo {
pub fn parse_enum_variants(mut self, enums: &HashMap<String, Vec<String>>) -> Self {
if let Type::Enum(ref mut enum_def) = self.col_type {
Expand Down
13 changes: 8 additions & 5 deletions src/postgres/query/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ pub struct ColumnQueryResult {
pub interval_precision: Option<i32>,

pub udt_name: Option<String>,
pub udt_name_regtype: Option<String>,
}

impl SchemaQueryBuilder {
Expand All @@ -75,12 +76,9 @@ impl SchemaQueryBuilder {
table: SeaRc<dyn Iden>,
) -> SelectStatement {
Query::select()
.column(ColumnsField::ColumnName)
.expr(
Expr::expr(Expr::cust("udt_name::regtype").cast_as(Alias::new("text")))
.binary(BinOper::As, Expr::col(ColumnsField::DataType)),
)
.columns([
ColumnsField::ColumnName,
ColumnsField::DataType,
ColumnsField::ColumnDefault,
ColumnsField::GenerationExpression,
ColumnsField::IsNullable,
Expand All @@ -95,6 +93,10 @@ impl SchemaQueryBuilder {
ColumnsField::IntervalPrecision,
ColumnsField::UdtName,
])
.expr(
Expr::expr(Expr::cust("udt_name::regtype").cast_as(Alias::new("text")))
.binary(BinOper::As, Expr::col(Alias::new("udt_name_regtype"))),
)
.from((InformationSchema::Schema, InformationSchema::Columns))
.and_where(Expr::col(ColumnsField::TableSchema).eq(schema.to_string()))
.and_where(Expr::col(ColumnsField::TableName).eq(table.to_string()))
Expand Down Expand Up @@ -122,6 +124,7 @@ impl From<&PgRow> for ColumnQueryResult {
interval_type: row.get(12),
interval_precision: row.get(13),
udt_name: row.get(14),
udt_name_regtype: row.get(15),
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/postgres/writer/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ impl ColumnInfo {
.collect();
ColumnType::Enum { name, variants }
}
Type::Array(col_type) => {
ColumnType::Array(SeaRc::new(Box::new(write_type(col_type))))
}
Type::Array(array_def) => ColumnType::Array(SeaRc::new(Box::new(write_type(
array_def.col_type.as_ref().expect("Array type not defined"),
)))),
}
}
write_type(&self.col_type)
Expand Down

0 comments on commit b4a6836

Please sign in to comment.