diff --git a/src/postgres/def/types.rs b/src/postgres/def/types.rs index f644ffdd..fcd957cd 100644 --- a/src/postgres/def/types.rs +++ b/src/postgres/def/types.rs @@ -1,3 +1,4 @@ +use sea_query::SeaRc; #[cfg(feature = "with-serde")] use serde::{Deserialize, Serialize}; @@ -110,7 +111,7 @@ pub enum Type { JsonBinary, /// Variable-length multidimensional array - Array, + Array(SeaRc), // TODO: // /// The structure of a row or record; a list of field names and types @@ -148,61 +149,66 @@ impl Type { // TODO: Support more types #[allow(clippy::should_implement_trait)] pub fn from_str(name: &str) -> Type { - match name.to_lowercase().as_str() { - "smallint" | "int2" => Type::SmallInt, - "integer" | "int" | "int4" => Type::Integer, - "bigint" | "int8" => Type::BigInt, - "decimal" => Type::Decimal(ArbitraryPrecisionNumericAttr::default()), - "numeric" => Type::Numeric(ArbitraryPrecisionNumericAttr::default()), - "real" | "float4" => Type::Real, - "double precision" | "double" | "float8" => Type::DoublePrecision, - "smallserial" | "serial2" => Type::SmallSerial, - "serial" | "serial4" => Type::Serial, - "bigserial" | "serial8" => Type::BigSerial, - "money" => Type::Money, - "character varying" | "varchar" => Type::Varchar(StringAttr::default()), - "character" | "char" => Type::Char(StringAttr::default()), - "text" => Type::Text, - "bytea" => Type::Bytea, - "timestamp" | "timestamp without time zone" => Type::Timestamp(TimeAttr::default()), - "timestamp with time zone" => Type::TimestampWithTimeZone(TimeAttr::default()), - "date" => Type::Date, - "time" | "time without time zone" => Type::Time(TimeAttr::default()), - "time with time zone" => Type::TimeWithTimeZone(TimeAttr::default()), - "interval" => Type::Interval(IntervalAttr::default()), - "boolean" => Type::Boolean, - "point" => Type::Point, - "line" => Type::Line, - "lseg" => Type::Lseg, - "box" => Type::Box, - "path" => Type::Path, - "polygon" => Type::Polygon, - "circle" => Type::Circle, - "cidr" => Type::Cidr, - "inet" => Type::Inet, - "macaddr" => Type::MacAddr, - "macaddr8" => Type::MacAddr8, - "bit" => Type::Bit(BitAttr::default()), - "tsvector" => Type::TsVector, - "tsquery" => Type::TsQuery, - "uuid" => Type::Uuid, - "xml" => Type::Xml, - "json" => Type::Json, - "jsonb" => Type::JsonBinary, - "array" => Type::Array, - // "" => Type::Composite, - "int4range" => Type::Int4Range, - "int8range" => Type::Int8Range, - "numrange" => Type::NumRange, - "tsrange" => Type::TsRange, - "tstzrange" => Type::TsTzRange, - "daterange" => Type::DateRange, - // "" => Type::Domain, - "pg_lsn" => Type::PgLsn, - "user-defined" => Type::Enum(EnumDef::default()), - - _ => Type::Unknown(name.to_owned()), + fn parse_type(name: &str) -> Type { + match name.to_lowercase().as_str() { + "smallint" | "int2" => Type::SmallInt, + "integer" | "int" | "int4" => Type::Integer, + "bigint" | "int8" => Type::BigInt, + "decimal" => Type::Decimal(ArbitraryPrecisionNumericAttr::default()), + "numeric" => Type::Numeric(ArbitraryPrecisionNumericAttr::default()), + "real" | "float4" => Type::Real, + "double precision" | "double" | "float8" => Type::DoublePrecision, + "smallserial" | "serial2" => Type::SmallSerial, + "serial" | "serial4" => Type::Serial, + "bigserial" | "serial8" => Type::BigSerial, + "money" => Type::Money, + "character varying" | "varchar" => Type::Varchar(StringAttr::default()), + "character" | "char" => Type::Char(StringAttr::default()), + "text" => Type::Text, + "bytea" => Type::Bytea, + "timestamp" | "timestamp without time zone" => Type::Timestamp(TimeAttr::default()), + "timestamp with time zone" => Type::TimestampWithTimeZone(TimeAttr::default()), + "date" => Type::Date, + "time" | "time without time zone" => Type::Time(TimeAttr::default()), + "time with time zone" => Type::TimeWithTimeZone(TimeAttr::default()), + "interval" => Type::Interval(IntervalAttr::default()), + "boolean" => Type::Boolean, + "point" => Type::Point, + "line" => Type::Line, + "lseg" => Type::Lseg, + "box" => Type::Box, + "path" => Type::Path, + "polygon" => Type::Polygon, + "circle" => Type::Circle, + "cidr" => Type::Cidr, + "inet" => Type::Inet, + "macaddr" => Type::MacAddr, + "macaddr8" => Type::MacAddr8, + "bit" => Type::Bit(BitAttr::default()), + "tsvector" => Type::TsVector, + "tsquery" => Type::TsQuery, + "uuid" => Type::Uuid, + "xml" => Type::Xml, + "json" => Type::Json, + "jsonb" => Type::JsonBinary, + // "" => Type::Composite, + "int4range" => Type::Int4Range, + "int8range" => Type::Int8Range, + "numrange" => Type::NumRange, + "tsrange" => Type::TsRange, + "tstzrange" => Type::TsTzRange, + "daterange" => Type::DateRange, + // "" => Type::Domain, + "pg_lsn" => Type::PgLsn, + "user-defined" => Type::Enum(EnumDef::default()), + _ if name.ends_with("[]") => { + let col_type = parse_type(&name.replacen("[]", "", 1)); + Type::Array(SeaRc::new(col_type)) + } + _ => Type::Unknown(name.to_owned()), + } } + parse_type(name) } } diff --git a/src/postgres/query/column.rs b/src/postgres/query/column.rs index 998fc0d6..dfb95f8a 100644 --- a/src/postgres/query/column.rs +++ b/src/postgres/query/column.rs @@ -1,6 +1,6 @@ use super::{InformationSchema, SchemaQueryBuilder}; use crate::sqlx_types::postgres::PgRow; -use sea_query::{Expr, Iden, Query, SeaRc, SelectStatement}; +use sea_query::{Alias, BinOper, Expr, Iden, Query, SeaRc, SelectStatement}; #[derive(Debug, sea_query::Iden)] /// Ref: https://www.postgresql.org/docs/13/infoschema-columns.html @@ -75,9 +75,12 @@ impl SchemaQueryBuilder { table: SeaRc, ) -> SelectStatement { Query::select() - .columns(vec![ - ColumnsField::ColumnName, - ColumnsField::DataType, + .column(ColumnsField::ColumnName) + .expr( + Expr::expr(Expr::cust("udt_name::regtype").cast_as(Alias::new("text"))) + .binary(BinOper::As, Expr::col(ColumnsField::DataType)), + ) + .columns([ ColumnsField::ColumnDefault, ColumnsField::GenerationExpression, ColumnsField::IsNullable, diff --git a/src/postgres/writer/column.rs b/src/postgres/writer/column.rs index 102ad7f0..41d60b65 100644 --- a/src/postgres/writer/column.rs +++ b/src/postgres/writer/column.rs @@ -1,11 +1,10 @@ use crate::postgres::def::{ColumnInfo, Type}; -use sea_query::{Alias, ColumnDef, DynIden, IntoIden, PgInterval}; +use sea_query::{Alias, BlobSize, ColumnDef, ColumnType, DynIden, IntoIden, PgInterval, SeaRc}; use std::{convert::TryFrom, fmt::Write}; impl ColumnInfo { pub fn write(&self) -> ColumnDef { let mut col_info = self.clone(); - let mut col_def = ColumnDef::new(Alias::new(self.name.as_str())); let mut extras: Vec = Vec::new(); if let Some(default) = self.default.as_ref() { if default.0.starts_with("nextval") { @@ -16,10 +15,17 @@ impl ColumnInfo { extras.push(string); } } + let col_type = col_info.write_col_type(); + let mut col_def = ColumnDef::new_with_type(Alias::new(self.name.as_str()), col_type); if self.is_identity { col_info = Self::convert_to_serial(col_info); } - col_def = col_info.write_col_type(col_def); + if matches!( + col_info.col_type, + Type::SmallSerial | Type::Serial | Type::BigSerial + ) { + col_def.auto_increment(); + } if self.not_null.is_some() { col_def.not_null(); } @@ -45,203 +51,106 @@ impl ColumnInfo { col_info } - pub fn write_col_type(&self, mut col_def: ColumnDef) -> ColumnDef { - match &self.col_type { - Type::SmallInt => { - col_def.small_integer(); - } - Type::Integer => { - col_def.integer(); - } - Type::BigInt => { - col_def.big_integer(); - } - Type::Decimal(num_attr) | Type::Numeric(num_attr) => { - if num_attr.precision.is_none() & num_attr.scale.is_none() { - col_def.decimal(); - } else { - col_def.decimal_len( - num_attr.precision.unwrap_or(0) as u32, - num_attr.scale.unwrap_or(0) as u32, - ); + pub fn write_col_type(&self) -> ColumnType { + fn write_type(col_type: &Type) -> ColumnType { + match col_type { + Type::SmallInt => ColumnType::SmallInteger(None), + Type::Integer => ColumnType::Integer(None), + Type::BigInt => ColumnType::BigInteger(None), + Type::Decimal(num_attr) | Type::Numeric(num_attr) => { + match (num_attr.precision, num_attr.scale) { + (None, None) => ColumnType::Decimal(None), + (precision, scale) => ColumnType::Decimal(Some(( + precision.unwrap_or(0).into(), + scale.unwrap_or(0).into(), + ))), + } } - } - Type::Real => { - col_def.float(); - } - Type::DoublePrecision => { - col_def.double(); - } - Type::SmallSerial => { - col_def.small_integer().auto_increment(); - } - Type::Serial => { - col_def.integer().auto_increment(); - } - Type::BigSerial => { - col_def.big_integer().auto_increment(); - } - Type::Money => { - col_def.money(); - } - Type::Varchar(string_attr) => { - match string_attr.length { - Some(length) => col_def.string_len(length.into()), - None => col_def.string(), - }; - } - Type::Char(string_attr) => { - match string_attr.length { - Some(length) => col_def.char_len(length.into()), - None => col_def.char(), - }; - } - Type::Text => { - col_def.text(); - } - Type::Bytea => { - col_def.binary(); - } - Type::Timestamp(time_attr) => { + Type::Real => ColumnType::Float(None), + Type::DoublePrecision => ColumnType::Double(None), + Type::SmallSerial => ColumnType::SmallInteger(None), + Type::Serial => ColumnType::Integer(None), + Type::BigSerial => ColumnType::BigInteger(None), + Type::Money => ColumnType::Money(None), + Type::Varchar(string_attr) => { + ColumnType::String(string_attr.length.map(Into::into)) + } + Type::Char(string_attr) => ColumnType::Char(string_attr.length.map(Into::into)), + Type::Text => ColumnType::Text, + Type::Bytea => ColumnType::Binary(BlobSize::Blob(None)), // The SQL standard requires that writing just timestamp be equivalent to timestamp without time zone, // and PostgreSQL honors that behavior. (https://www.postgresql.org/docs/current/datatype-datetime.html) - match time_attr.precision { - Some(precision) => col_def.date_time_len(precision.into()), - None => col_def.date_time(), - }; - } - Type::TimestampWithTimeZone(time_attr) => { - match time_attr.precision { - Some(precision) => col_def.timestamp_with_time_zone_len(precision.into()), - None => col_def.timestamp_with_time_zone(), - }; - } - Type::Date => { - col_def.date(); - } - Type::Time(time_attr) => { - match time_attr.precision { - Some(precision) => col_def.time_len(precision.into()), - None => col_def.time(), - }; - } - Type::TimeWithTimeZone(time_attr) => { - match time_attr.precision { - Some(precision) => col_def.time_len(precision.into()), - None => col_def.time(), - }; - } - Type::Interval(interval_attr) => { - let field = match &interval_attr.field { - Some(field) => PgInterval::try_from(field).ok(), - None => None, - }; - let precision = interval_attr.precision.map(Into::into); - col_def.interval(field, precision); - } - Type::Boolean => { - col_def.boolean(); - } - Type::Point => { - col_def.custom(Alias::new("point")); - } - Type::Line => { - col_def.custom(Alias::new("line")); - } - Type::Lseg => { - col_def.custom(Alias::new("lseg")); - } - Type::Box => { - col_def.custom(Alias::new("box")); - } - Type::Path => { - col_def.custom(Alias::new("path")); - } - Type::Polygon => { - col_def.custom(Alias::new("polygon")); - } - Type::Circle => { - col_def.custom(Alias::new("circle")); - } - Type::Cidr => { - col_def.custom(Alias::new("cidr")); - } - Type::Inet => { - col_def.custom(Alias::new("inet")); - } - Type::MacAddr => { - col_def.custom(Alias::new("macaddr")); - } - Type::MacAddr8 => { - col_def.custom(Alias::new("macaddr8")); - } - Type::Bit(bit_attr) => { - let mut str = String::new(); - write!(str, "bit").unwrap(); - if bit_attr.length.is_some() { - write!(str, "(").unwrap(); - if let Some(length) = bit_attr.length { - write!(str, "{}", length).unwrap(); + Type::Timestamp(time_attr) => { + ColumnType::DateTime(time_attr.precision.map(Into::into)) + } + Type::TimestampWithTimeZone(time_attr) => { + ColumnType::TimestampWithTimeZone(time_attr.precision.map(Into::into)) + } + Type::Date => ColumnType::Date, + Type::Time(time_attr) => ColumnType::Time(time_attr.precision.map(Into::into)), + Type::TimeWithTimeZone(time_attr) => { + ColumnType::Time(time_attr.precision.map(Into::into)) + } + Type::Interval(interval_attr) => { + let field = match &interval_attr.field { + Some(field) => PgInterval::try_from(field).ok(), + None => None, + }; + let precision = interval_attr.precision.map(Into::into); + ColumnType::Interval(field, precision) + } + Type::Boolean => ColumnType::Boolean, + Type::Point => ColumnType::Custom(Alias::new("point").into_iden()), + Type::Line => ColumnType::Custom(Alias::new("line").into_iden()), + Type::Lseg => ColumnType::Custom(Alias::new("lseg").into_iden()), + Type::Box => ColumnType::Custom(Alias::new("box").into_iden()), + Type::Path => ColumnType::Custom(Alias::new("path").into_iden()), + Type::Polygon => ColumnType::Custom(Alias::new("polygon").into_iden()), + Type::Circle => ColumnType::Custom(Alias::new("circle").into_iden()), + Type::Cidr => ColumnType::Custom(Alias::new("cidr").into_iden()), + Type::Inet => ColumnType::Custom(Alias::new("inet").into_iden()), + Type::MacAddr => ColumnType::Custom(Alias::new("macaddr").into_iden()), + Type::MacAddr8 => ColumnType::Custom(Alias::new("macaddr8").into_iden()), + Type::Bit(bit_attr) => { + let mut str = String::new(); + write!(str, "bit").unwrap(); + if bit_attr.length.is_some() { + write!(str, "(").unwrap(); + if let Some(length) = bit_attr.length { + write!(str, "{}", length).unwrap(); + } + write!(str, ")").unwrap(); } - write!(str, ")").unwrap(); + ColumnType::Custom(Alias::new(&str).into_iden()) + } + Type::TsVector => ColumnType::Custom(Alias::new("tsvector").into_iden()), + Type::TsQuery => ColumnType::Custom(Alias::new("tsquery").into_iden()), + Type::Uuid => ColumnType::Uuid, + Type::Xml => ColumnType::Custom(Alias::new("xml").into_iden()), + Type::Json => ColumnType::Json, + Type::JsonBinary => ColumnType::JsonBinary, + Type::Int4Range => ColumnType::Custom(Alias::new("int4range").into_iden()), + Type::Int8Range => ColumnType::Custom(Alias::new("int8range").into_iden()), + Type::NumRange => ColumnType::Custom(Alias::new("numrange").into_iden()), + Type::TsRange => ColumnType::Custom(Alias::new("tsrange").into_iden()), + Type::TsTzRange => ColumnType::Custom(Alias::new("tstzrange").into_iden()), + Type::DateRange => ColumnType::Custom(Alias::new("daterange").into_iden()), + Type::PgLsn => ColumnType::Custom(Alias::new("pg_lsn").into_iden()), + Type::Unknown(s) => ColumnType::Custom(Alias::new(s).into_iden()), + Type::Enum(enum_def) => { + let name = Alias::new(&enum_def.typename).into_iden(); + let variants: Vec = enum_def + .values + .iter() + .map(|variant| Alias::new(variant).into_iden()) + .collect(); + ColumnType::Enum { name, variants } + } + Type::Array(col_type) => { + ColumnType::Array(SeaRc::new(Box::new(write_type(col_type)))) } - col_def.custom(Alias::new(&str)); - } - Type::TsVector => { - col_def.custom(Alias::new("tsvector")); - } - Type::TsQuery => { - col_def.custom(Alias::new("tsquery")); - } - Type::Uuid => { - col_def.uuid(); - } - Type::Xml => { - col_def.custom(Alias::new("xml")); - } - Type::Json => { - col_def.json(); - } - Type::JsonBinary => { - col_def.json_binary(); - } - Type::Array => { - col_def.custom(Alias::new("array")); - } - Type::Int4Range => { - col_def.custom(Alias::new("int4range")); - } - Type::Int8Range => { - col_def.custom(Alias::new("int8range")); - } - Type::NumRange => { - col_def.custom(Alias::new("numrange")); - } - Type::TsRange => { - col_def.custom(Alias::new("tsrange")); - } - Type::TsTzRange => { - col_def.custom(Alias::new("tstzrange")); - } - Type::DateRange => { - col_def.custom(Alias::new("daterange")); - } - Type::PgLsn => { - col_def.custom(Alias::new("pg_lsn")); - } - Type::Unknown(s) => { - col_def.custom(Alias::new(s)); - } - Type::Enum(enum_def) => { - let name = Alias::new(&enum_def.typename); - let variants: Vec = enum_def - .values - .iter() - .map(|variant| Alias::new(variant).into_iden()) - .collect(); - col_def.enumeration(name, variants); } - }; - col_def + } + write_type(&self.col_type) } } diff --git a/tests/live/postgres/src/main.rs b/tests/live/postgres/src/main.rs index 7d0dde71..1cbdaf8f 100644 --- a/tests/live/postgres/src/main.rs +++ b/tests/live/postgres/src/main.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use sea_schema::postgres::{def::TableDef, discovery::SchemaDiscovery}; use sea_schema::sea_query::TableRef; use sea_schema::sea_query::{ - extension::postgres::Type, Alias, ColumnDef, ForeignKey, ForeignKeyAction, Index, + extension::postgres::Type, Alias, ColumnDef, ColumnType, ForeignKey, ForeignKeyAction, Index, PostgresQueryBuilder, Table, TableCreateStatement, }; use sqlx::{PgPool, Pool, Postgres}; @@ -46,6 +46,7 @@ async fn main() { create_cake_table(), create_cakes_bakers_table(), create_lineitem_table(), + create_collection_table(), ]; for tbl_create_stmt in tbl_create_stmts.iter() { @@ -320,3 +321,21 @@ fn create_cake_table() -> TableCreateStatement { ) .to_owned() } + +fn create_collection_table() -> TableCreateStatement { + Table::create() + .table(Alias::new("collection")) + .col( + ColumnDef::new(Alias::new("id")) + .integer() + .not_null() + .auto_increment(), + ) + .col( + ColumnDef::new(Alias::new("integers")) + .array(ColumnType::Integer(None)) + .not_null(), + ) + .col(ColumnDef::new(Alias::new("integers_opt")).array(ColumnType::Integer(None))) + .to_owned() +}