Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix parsing of Postgres user-defined types #84

Merged
merged 1 commit into from
Oct 26, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 66 additions & 60 deletions src/postgres/def/types.rs
Original file line number Diff line number Diff line change
@@ -111,7 +111,7 @@ pub enum Type {
JsonBinary,

/// Variable-length multidimensional array
Array(SeaRc<Type>),
Array(ArrayDef),

// TODO:
// /// The structure of a row or record; a list of field names and types
@@ -149,66 +149,60 @@ impl Type {
// TODO: Support more types
#[allow(clippy::should_implement_trait)]
pub fn from_str(name: &str) -> Type {
fn parse_type(name: &str) -> Type {
match name.to_lowercase().as_str() {
"smallint" | "int2" => Type::SmallInt,
"integer" | "int" | "int4" => Type::Integer,
"bigint" | "int8" => Type::BigInt,
"decimal" => Type::Decimal(ArbitraryPrecisionNumericAttr::default()),
"numeric" => Type::Numeric(ArbitraryPrecisionNumericAttr::default()),
"real" | "float4" => Type::Real,
"double precision" | "double" | "float8" => Type::DoublePrecision,
"smallserial" | "serial2" => Type::SmallSerial,
"serial" | "serial4" => Type::Serial,
"bigserial" | "serial8" => Type::BigSerial,
"money" => Type::Money,
"character varying" | "varchar" => Type::Varchar(StringAttr::default()),
"character" | "char" => Type::Char(StringAttr::default()),
"text" => Type::Text,
"bytea" => Type::Bytea,
"timestamp" | "timestamp without time zone" => Type::Timestamp(TimeAttr::default()),
"timestamp with time zone" => Type::TimestampWithTimeZone(TimeAttr::default()),
"date" => Type::Date,
"time" | "time without time zone" => Type::Time(TimeAttr::default()),
"time with time zone" => Type::TimeWithTimeZone(TimeAttr::default()),
"interval" => Type::Interval(IntervalAttr::default()),
"boolean" => Type::Boolean,
"point" => Type::Point,
"line" => Type::Line,
"lseg" => Type::Lseg,
"box" => Type::Box,
"path" => Type::Path,
"polygon" => Type::Polygon,
"circle" => Type::Circle,
"cidr" => Type::Cidr,
"inet" => Type::Inet,
"macaddr" => Type::MacAddr,
"macaddr8" => Type::MacAddr8,
"bit" => Type::Bit(BitAttr::default()),
"tsvector" => Type::TsVector,
"tsquery" => Type::TsQuery,
"uuid" => Type::Uuid,
"xml" => Type::Xml,
"json" => Type::Json,
"jsonb" => Type::JsonBinary,
// "" => Type::Composite,
"int4range" => Type::Int4Range,
"int8range" => Type::Int8Range,
"numrange" => Type::NumRange,
"tsrange" => Type::TsRange,
"tstzrange" => Type::TsTzRange,
"daterange" => Type::DateRange,
// "" => Type::Domain,
"pg_lsn" => Type::PgLsn,
"user-defined" => Type::Enum(EnumDef::default()),
_ if name.ends_with("[]") => {
let col_type = parse_type(&name.replacen("[]", "", 1));
Type::Array(SeaRc::new(col_type))
}
_ => Type::Unknown(name.to_owned()),
}
match name.to_lowercase().as_str() {
"smallint" | "int2" => Type::SmallInt,
"integer" | "int" | "int4" => Type::Integer,
"bigint" | "int8" => Type::BigInt,
"decimal" => Type::Decimal(ArbitraryPrecisionNumericAttr::default()),
"numeric" => Type::Numeric(ArbitraryPrecisionNumericAttr::default()),
"real" | "float4" => Type::Real,
"double precision" | "double" | "float8" => Type::DoublePrecision,
"smallserial" | "serial2" => Type::SmallSerial,
"serial" | "serial4" => Type::Serial,
"bigserial" | "serial8" => Type::BigSerial,
"money" => Type::Money,
"character varying" | "varchar" => Type::Varchar(StringAttr::default()),
"character" | "char" => Type::Char(StringAttr::default()),
"text" => Type::Text,
"bytea" => Type::Bytea,
"timestamp" | "timestamp without time zone" => Type::Timestamp(TimeAttr::default()),
"timestamp with time zone" => Type::TimestampWithTimeZone(TimeAttr::default()),
"date" => Type::Date,
"time" | "time without time zone" => Type::Time(TimeAttr::default()),
"time with time zone" => Type::TimeWithTimeZone(TimeAttr::default()),
"interval" => Type::Interval(IntervalAttr::default()),
"boolean" => Type::Boolean,
"point" => Type::Point,
"line" => Type::Line,
"lseg" => Type::Lseg,
"box" => Type::Box,
"path" => Type::Path,
"polygon" => Type::Polygon,
"circle" => Type::Circle,
"cidr" => Type::Cidr,
"inet" => Type::Inet,
"macaddr" => Type::MacAddr,
"macaddr8" => Type::MacAddr8,
"bit" => Type::Bit(BitAttr::default()),
"tsvector" => Type::TsVector,
"tsquery" => Type::TsQuery,
"uuid" => Type::Uuid,
"xml" => Type::Xml,
"json" => Type::Json,
"jsonb" => Type::JsonBinary,
// "" => Type::Composite,
"int4range" => Type::Int4Range,
"int8range" => Type::Int8Range,
"numrange" => Type::NumRange,
"tsrange" => Type::TsRange,
"tstzrange" => Type::TsTzRange,
"daterange" => Type::DateRange,
// "" => Type::Domain,
"pg_lsn" => Type::PgLsn,
"user-defined" => Type::Enum(EnumDef::default()),
"array" => Type::Array(ArrayDef::default()),
_ => Type::Unknown(name.to_owned()),
}
parse_type(name)
}
}

@@ -259,6 +253,14 @@ pub struct EnumDef {
pub typename: String,
}

/// Defines an enum for the PostgreSQL module
#[derive(Clone, Debug, PartialEq, Default)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct ArrayDef {
/// Array type
pub col_type: Option<SeaRc<Type>>,
}

impl Type {
pub fn has_numeric_attr(&self) -> bool {
matches!(self, Type::Numeric(_) | Type::Decimal(_))
@@ -289,4 +291,8 @@ impl Type {
pub fn has_enum_attr(&self) -> bool {
matches!(self, Type::Enum(_))
}

pub fn has_array_attr(&self) -> bool {
matches!(self, Type::Array(_))
}
}
21 changes: 21 additions & 0 deletions src/postgres/parser/column.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::postgres::{def::*, parser::yes_or_no_to_bool, query::ColumnQueryResult};
use sea_query::SeaRc;
use std::{collections::HashMap, convert::TryFrom};

impl ColumnQueryResult {
@@ -44,6 +45,9 @@ pub fn parse_column_type(result: &ColumnQueryResult) -> ColumnType {
if ctype.has_enum_attr() {
ctype = parse_enum_attributes(result.udt_name.as_ref(), ctype);
}
if ctype.has_array_attr() {
ctype = parse_array_attributes(result.udt_name_regtype.as_ref(), ctype);
}

ctype
}
@@ -183,6 +187,23 @@ pub fn parse_enum_attributes(udt_name: Option<&String>, mut ctype: ColumnType) -
ctype
}

pub fn parse_array_attributes(
udt_name_regtype: Option<&String>,
mut ctype: ColumnType,
) -> ColumnType {
match ctype {
Type::Array(ref mut def) => {
def.col_type = match udt_name_regtype {
None => panic!("parse_array_attributes(_) received an empty udt_name_regtype"),
Some(typename) => Some(SeaRc::new(Type::from_str(&typename.replacen("[]", "", 1)))),
};
}
_ => panic!("parse_array_attributes(_) received a type that does not have EnumDef"),
};

ctype
}

impl ColumnInfo {
pub fn parse_enum_variants(mut self, enums: &HashMap<String, Vec<String>>) -> Self {
if let Type::Enum(ref mut enum_def) = self.col_type {
13 changes: 8 additions & 5 deletions src/postgres/query/column.rs
Original file line number Diff line number Diff line change
@@ -66,6 +66,7 @@ pub struct ColumnQueryResult {
pub interval_precision: Option<i32>,

pub udt_name: Option<String>,
pub udt_name_regtype: Option<String>,
}

impl SchemaQueryBuilder {
@@ -75,12 +76,9 @@ impl SchemaQueryBuilder {
table: SeaRc<dyn Iden>,
) -> SelectStatement {
Query::select()
.column(ColumnsField::ColumnName)
.expr(
Expr::expr(Expr::cust("udt_name::regtype").cast_as(Alias::new("text")))
.binary(BinOper::As, Expr::col(ColumnsField::DataType)),
)
.columns([
ColumnsField::ColumnName,
ColumnsField::DataType,
ColumnsField::ColumnDefault,
ColumnsField::GenerationExpression,
ColumnsField::IsNullable,
@@ -95,6 +93,10 @@ impl SchemaQueryBuilder {
ColumnsField::IntervalPrecision,
ColumnsField::UdtName,
])
.expr(
Expr::expr(Expr::cust("udt_name::regtype").cast_as(Alias::new("text")))
.binary(BinOper::As, Expr::col(Alias::new("udt_name_regtype"))),
)
.from((InformationSchema::Schema, InformationSchema::Columns))
.and_where(Expr::col(ColumnsField::TableSchema).eq(schema.to_string()))
.and_where(Expr::col(ColumnsField::TableName).eq(table.to_string()))
@@ -122,6 +124,7 @@ impl From<&PgRow> for ColumnQueryResult {
interval_type: row.get(12),
interval_precision: row.get(13),
udt_name: row.get(14),
udt_name_regtype: row.get(15),
}
}
}
6 changes: 3 additions & 3 deletions src/postgres/writer/column.rs
Original file line number Diff line number Diff line change
@@ -146,9 +146,9 @@ impl ColumnInfo {
.collect();
ColumnType::Enum { name, variants }
}
Type::Array(col_type) => {
ColumnType::Array(SeaRc::new(Box::new(write_type(col_type))))
}
Type::Array(array_def) => ColumnType::Array(SeaRc::new(Box::new(write_type(
array_def.col_type.as_ref().expect("Array type not defined"),
)))),
}
}
write_type(&self.col_type)