diff --git a/src/postgres/def/types.rs b/src/postgres/def/types.rs index af9ec07c..3e59330b 100644 --- a/src/postgres/def/types.rs +++ b/src/postgres/def/types.rs @@ -148,8 +148,8 @@ pub enum Type { impl Type { // TODO: Support more types #[allow(clippy::should_implement_trait)] - pub fn from_str(name: &str) -> Type { - match name.to_lowercase().as_str() { + pub fn from_str(column_type: &str, udt_name: Option<&str>, is_enum: bool) -> Type { + match column_type.to_lowercase().as_str() { "smallint" | "int2" => Type::SmallInt, "integer" | "int" | "int4" => Type::Integer, "bigint" | "int8" => Type::BigInt, @@ -199,9 +199,12 @@ impl Type { "daterange" => Type::DateRange, // "" => Type::Domain, "pg_lsn" => Type::PgLsn, - "user-defined" => Type::Enum(EnumDef::default()), + "user-defined" if is_enum => Type::Enum(EnumDef::default()), + "user-defined" if !is_enum && udt_name.is_some() => { + Type::Unknown(udt_name.unwrap().to_owned()) + } "array" => Type::Array(ArrayDef::default()), - _ => Type::Unknown(name.to_owned()), + _ => Type::Unknown(column_type.to_owned()), } } } diff --git a/src/postgres/discovery/mod.rs b/src/postgres/discovery/mod.rs index e560999f..be7ea507 100644 --- a/src/postgres/discovery/mod.rs +++ b/src/postgres/discovery/mod.rs @@ -14,6 +14,8 @@ use std::collections::HashMap; mod executor; pub use executor::*; +pub(crate) type EnumVariantMap = HashMap>; + pub struct SchemaDiscovery { pub query: SchemaQueryBuilder, pub executor: Executor, @@ -33,30 +35,20 @@ impl SchemaDiscovery { } pub async fn discover(&self) -> Schema { - let enums: HashMap<_, _> = self + let enums: EnumVariantMap = self .discover_enums() .await .into_iter() .map(|enum_def| (enum_def.typename, enum_def.values)) .collect(); - let tables = self.discover_tables().await; let tables = future::join_all( - tables + self.discover_tables() + .await .into_iter() - .map(|t| (self, t)) + .map(|t| (self, t, &enums)) .map(Self::discover_table_static), ) - .await - .into_iter() - .map(|mut table| { - table.columns = table - .columns - .into_iter() - .map(|col| col.parse_enum_variants(&enums)) - .collect(); - table - }) - .collect(); + .await; Schema { schema: self.schema.to_string(), @@ -84,16 +76,17 @@ impl SchemaDiscovery { tables } - async fn discover_table_static(params: (&Self, TableInfo)) -> TableDef { + async fn discover_table_static(params: (&Self, TableInfo, &EnumVariantMap)) -> TableDef { let this = params.0; let info = params.1; - Self::discover_table(this, info).await + let enums = params.2; + Self::discover_table(this, info, enums).await } - pub async fn discover_table(&self, info: TableInfo) -> TableDef { + pub async fn discover_table(&self, info: TableInfo, enums: &EnumVariantMap) -> TableDef { let table = SeaRc::new(Alias::new(info.name.as_str())); let columns = self - .discover_columns(self.schema.clone(), table.clone()) + .discover_columns(self.schema.clone(), table.clone(), enums) .await; let constraints = self .discover_constraints(self.schema.clone(), table.clone()) @@ -143,6 +136,7 @@ impl SchemaDiscovery { &self, schema: SeaRc, table: SeaRc, + enums: &EnumVariantMap, ) -> Vec { let rows = self .executor @@ -153,7 +147,7 @@ impl SchemaDiscovery { .map(|row| { let result: ColumnQueryResult = (&row).into(); debug_print!("{:?}", result); - let column = result.parse(); + let column = result.parse(enums); debug_print!("{:?}", column); column }) diff --git a/src/postgres/parser/column.rs b/src/postgres/parser/column.rs index 59b23219..e7787bdb 100644 --- a/src/postgres/parser/column.rs +++ b/src/postgres/parser/column.rs @@ -1,17 +1,18 @@ -use crate::postgres::{def::*, parser::yes_or_no_to_bool, query::ColumnQueryResult}; +use crate::postgres::{ + def::*, discovery::EnumVariantMap, parser::yes_or_no_to_bool, query::ColumnQueryResult, +}; use sea_query::SeaRc; -use std::{collections::HashMap, convert::TryFrom}; impl ColumnQueryResult { - pub fn parse(self) -> ColumnInfo { - parse_column_query_result(self) + pub fn parse(self, enums: &EnumVariantMap) -> ColumnInfo { + parse_column_query_result(self, enums) } } -pub fn parse_column_query_result(result: ColumnQueryResult) -> ColumnInfo { +pub fn parse_column_query_result(result: ColumnQueryResult, enums: &EnumVariantMap) -> ColumnInfo { ColumnInfo { name: result.column_name.clone(), - col_type: parse_column_type(&result), + col_type: parse_column_type(&result, enums), default: ColumnExpression::from_option_string(result.column_default), generated: ColumnExpression::from_option_string(result.column_generated), not_null: NotNull::from_bool(!yes_or_no_to_bool(&result.is_nullable)), @@ -19,8 +20,16 @@ pub fn parse_column_query_result(result: ColumnQueryResult) -> ColumnInfo { } } -pub fn parse_column_type(result: &ColumnQueryResult) -> ColumnType { - let mut ctype = Type::from_str(result.column_type.as_str()); +pub fn parse_column_type(result: &ColumnQueryResult, enums: &EnumVariantMap) -> ColumnType { + let is_enum = result + .udt_name + .as_ref() + .map_or(false, |udt_name| enums.contains_key(udt_name)); + let mut ctype = Type::from_str( + result.column_type.as_str(), + result.udt_name.as_deref(), + is_enum, + ); if ctype.has_numeric_attr() { ctype = parse_numeric_attributes( @@ -43,10 +52,10 @@ pub fn parse_column_type(result: &ColumnQueryResult) -> ColumnType { ctype = parse_bit_attributes(result.character_maximum_length, ctype); } if ctype.has_enum_attr() { - ctype = parse_enum_attributes(result.udt_name.as_ref(), ctype); + ctype = parse_enum_attributes(result.udt_name.as_deref(), ctype, enums); } if ctype.has_array_attr() { - ctype = parse_array_attributes(result.udt_name_regtype.as_ref(), ctype); + ctype = parse_array_attributes(result.udt_name_regtype.as_deref(), ctype, enums); } ctype @@ -173,13 +182,20 @@ pub fn parse_bit_attributes( ctype } -pub fn parse_enum_attributes(udt_name: Option<&String>, mut ctype: ColumnType) -> ColumnType { +pub fn parse_enum_attributes( + udt_name: Option<&str>, + mut ctype: ColumnType, + enums: &EnumVariantMap, +) -> ColumnType { match ctype { Type::Enum(ref mut def) => { def.typename = match udt_name { None => panic!("parse_enum_attributes(_) received an empty udt_name"), Some(typename) => typename.to_string(), }; + if let Some(variants) = enums.get(&def.typename) { + def.values = variants.clone() + } } _ => panic!("parse_enum_attributes(_) received a type that does not have EnumDef"), }; @@ -188,14 +204,23 @@ pub fn parse_enum_attributes(udt_name: Option<&String>, mut ctype: ColumnType) - } pub fn parse_array_attributes( - udt_name_regtype: Option<&String>, + udt_name_regtype: Option<&str>, mut ctype: ColumnType, + enums: &EnumVariantMap, ) -> ColumnType { match ctype { Type::Array(ref mut def) => { def.col_type = match udt_name_regtype { None => panic!("parse_array_attributes(_) received an empty udt_name_regtype"), - Some(typename) => Some(SeaRc::new(Type::from_str(&typename.replacen("[]", "", 1)))), + Some(typename) => { + let typename = typename.replacen("[]", "", 1); + let is_enum = enums.contains_key(&typename); + Some(SeaRc::new(Type::from_str( + &typename, + Some(&typename), + is_enum, + ))) + } }; } _ => panic!("parse_array_attributes(_) received a type that does not have EnumDef"), @@ -203,14 +228,3 @@ pub fn parse_array_attributes( ctype } - -impl ColumnInfo { - pub fn parse_enum_variants(mut self, enums: &HashMap>) -> Self { - if let Type::Enum(ref mut enum_def) = self.col_type { - if let Some(def) = enums.get(&enum_def.typename) { - enum_def.values = def.clone() - } - } - self - } -} diff --git a/src/postgres/parser/table.rs b/src/postgres/parser/table.rs index fea64824..9e2ddc20 100644 --- a/src/postgres/parser/table.rs +++ b/src/postgres/parser/table.rs @@ -12,6 +12,6 @@ pub fn parse_table_query_result(table_query: TableQueryResult) -> TableInfo { name: table_query.table_name, of_type: table_query .user_defined_type_name - .map(|type_name| Type::from_str(&type_name)), + .map(|type_name| Type::from_str(&type_name, Some(&type_name), false)), } } diff --git a/tests/live/postgres/src/main.rs b/tests/live/postgres/src/main.rs index c0cfce78..9583234d 100644 --- a/tests/live/postgres/src/main.rs +++ b/tests/live/postgres/src/main.rs @@ -22,6 +22,11 @@ async fn main() { let connection = setup(&url, "sea-schema").await; let mut executor = connection.acquire().await.unwrap(); + sqlx::query("CREATE EXTENSION IF NOT EXISTS citext") + .execute(&mut executor) + .await + .unwrap(); + let create_enum_stmt = Type::create() .as_enum(Alias::new("crazy_enum")) .values(vec![ @@ -337,5 +342,10 @@ fn create_collection_table() -> TableCreateStatement { .not_null(), ) .col(ColumnDef::new(Alias::new("integers_opt")).array(ColumnType::Integer)) + .col( + ColumnDef::new(Alias::new("case_insensitive_text")) + .custom(Alias::new("citext")) + .not_null(), + ) .to_owned() }