Skip to content

Commit

Permalink
Parse Postgres citext as Type::Unknown (#94)
Browse files Browse the repository at this point in the history
* Parse Postgres citext as Type::Unknown

* refactoring
  • Loading branch information
billy1624 authored Jan 5, 2023
1 parent b5bbc55 commit c86710c
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 49 deletions.
11 changes: 7 additions & 4 deletions src/postgres/def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ pub enum Type {
impl Type {
// TODO: Support more types
#[allow(clippy::should_implement_trait)]
pub fn from_str(name: &str) -> Type {
match name.to_lowercase().as_str() {
pub fn from_str(column_type: &str, udt_name: Option<&str>, is_enum: bool) -> Type {
match column_type.to_lowercase().as_str() {
"smallint" | "int2" => Type::SmallInt,
"integer" | "int" | "int4" => Type::Integer,
"bigint" | "int8" => Type::BigInt,
Expand Down Expand Up @@ -199,9 +199,12 @@ impl Type {
"daterange" => Type::DateRange,
// "" => Type::Domain,
"pg_lsn" => Type::PgLsn,
"user-defined" => Type::Enum(EnumDef::default()),
"user-defined" if is_enum => Type::Enum(EnumDef::default()),
"user-defined" if !is_enum && udt_name.is_some() => {
Type::Unknown(udt_name.unwrap().to_owned())
}
"array" => Type::Array(ArrayDef::default()),
_ => Type::Unknown(name.to_owned()),
_ => Type::Unknown(column_type.to_owned()),
}
}
}
Expand Down
34 changes: 14 additions & 20 deletions src/postgres/discovery/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use std::collections::HashMap;
mod executor;
pub use executor::*;

pub(crate) type EnumVariantMap = HashMap<String, Vec<String>>;

pub struct SchemaDiscovery {
pub query: SchemaQueryBuilder,
pub executor: Executor,
Expand All @@ -33,30 +35,20 @@ impl SchemaDiscovery {
}

pub async fn discover(&self) -> Schema {
let enums: HashMap<_, _> = self
let enums: EnumVariantMap = self
.discover_enums()
.await
.into_iter()
.map(|enum_def| (enum_def.typename, enum_def.values))
.collect();
let tables = self.discover_tables().await;
let tables = future::join_all(
tables
self.discover_tables()
.await
.into_iter()
.map(|t| (self, t))
.map(|t| (self, t, &enums))
.map(Self::discover_table_static),
)
.await
.into_iter()
.map(|mut table| {
table.columns = table
.columns
.into_iter()
.map(|col| col.parse_enum_variants(&enums))
.collect();
table
})
.collect();
.await;

Schema {
schema: self.schema.to_string(),
Expand Down Expand Up @@ -84,16 +76,17 @@ impl SchemaDiscovery {
tables
}

async fn discover_table_static(params: (&Self, TableInfo)) -> TableDef {
async fn discover_table_static(params: (&Self, TableInfo, &EnumVariantMap)) -> TableDef {
let this = params.0;
let info = params.1;
Self::discover_table(this, info).await
let enums = params.2;
Self::discover_table(this, info, enums).await
}

pub async fn discover_table(&self, info: TableInfo) -> TableDef {
pub async fn discover_table(&self, info: TableInfo, enums: &EnumVariantMap) -> TableDef {
let table = SeaRc::new(Alias::new(info.name.as_str()));
let columns = self
.discover_columns(self.schema.clone(), table.clone())
.discover_columns(self.schema.clone(), table.clone(), enums)
.await;
let constraints = self
.discover_constraints(self.schema.clone(), table.clone())
Expand Down Expand Up @@ -143,6 +136,7 @@ impl SchemaDiscovery {
&self,
schema: SeaRc<dyn Iden>,
table: SeaRc<dyn Iden>,
enums: &EnumVariantMap,
) -> Vec<ColumnInfo> {
let rows = self
.executor
Expand All @@ -153,7 +147,7 @@ impl SchemaDiscovery {
.map(|row| {
let result: ColumnQueryResult = (&row).into();
debug_print!("{:?}", result);
let column = result.parse();
let column = result.parse(enums);
debug_print!("{:?}", column);
column
})
Expand Down
62 changes: 38 additions & 24 deletions src/postgres/parser/column.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,35 @@
use crate::postgres::{def::*, parser::yes_or_no_to_bool, query::ColumnQueryResult};
use crate::postgres::{
def::*, discovery::EnumVariantMap, parser::yes_or_no_to_bool, query::ColumnQueryResult,
};
use sea_query::SeaRc;
use std::{collections::HashMap, convert::TryFrom};

impl ColumnQueryResult {
pub fn parse(self) -> ColumnInfo {
parse_column_query_result(self)
pub fn parse(self, enums: &EnumVariantMap) -> ColumnInfo {
parse_column_query_result(self, enums)
}
}

pub fn parse_column_query_result(result: ColumnQueryResult) -> ColumnInfo {
pub fn parse_column_query_result(result: ColumnQueryResult, enums: &EnumVariantMap) -> ColumnInfo {
ColumnInfo {
name: result.column_name.clone(),
col_type: parse_column_type(&result),
col_type: parse_column_type(&result, enums),
default: ColumnExpression::from_option_string(result.column_default),
generated: ColumnExpression::from_option_string(result.column_generated),
not_null: NotNull::from_bool(!yes_or_no_to_bool(&result.is_nullable)),
is_identity: yes_or_no_to_bool(&result.is_identity),
}
}

pub fn parse_column_type(result: &ColumnQueryResult) -> ColumnType {
let mut ctype = Type::from_str(result.column_type.as_str());
pub fn parse_column_type(result: &ColumnQueryResult, enums: &EnumVariantMap) -> ColumnType {
let is_enum = result
.udt_name
.as_ref()
.map_or(false, |udt_name| enums.contains_key(udt_name));
let mut ctype = Type::from_str(
result.column_type.as_str(),
result.udt_name.as_deref(),
is_enum,
);

if ctype.has_numeric_attr() {
ctype = parse_numeric_attributes(
Expand All @@ -43,10 +52,10 @@ pub fn parse_column_type(result: &ColumnQueryResult) -> ColumnType {
ctype = parse_bit_attributes(result.character_maximum_length, ctype);
}
if ctype.has_enum_attr() {
ctype = parse_enum_attributes(result.udt_name.as_ref(), ctype);
ctype = parse_enum_attributes(result.udt_name.as_deref(), ctype, enums);
}
if ctype.has_array_attr() {
ctype = parse_array_attributes(result.udt_name_regtype.as_ref(), ctype);
ctype = parse_array_attributes(result.udt_name_regtype.as_deref(), ctype, enums);
}

ctype
Expand Down Expand Up @@ -173,13 +182,20 @@ pub fn parse_bit_attributes(
ctype
}

pub fn parse_enum_attributes(udt_name: Option<&String>, mut ctype: ColumnType) -> ColumnType {
pub fn parse_enum_attributes(
udt_name: Option<&str>,
mut ctype: ColumnType,
enums: &EnumVariantMap,
) -> ColumnType {
match ctype {
Type::Enum(ref mut def) => {
def.typename = match udt_name {
None => panic!("parse_enum_attributes(_) received an empty udt_name"),
Some(typename) => typename.to_string(),
};
if let Some(variants) = enums.get(&def.typename) {
def.values = variants.clone()
}
}
_ => panic!("parse_enum_attributes(_) received a type that does not have EnumDef"),
};
Expand All @@ -188,29 +204,27 @@ pub fn parse_enum_attributes(udt_name: Option<&String>, mut ctype: ColumnType) -
}

pub fn parse_array_attributes(
udt_name_regtype: Option<&String>,
udt_name_regtype: Option<&str>,
mut ctype: ColumnType,
enums: &EnumVariantMap,
) -> ColumnType {
match ctype {
Type::Array(ref mut def) => {
def.col_type = match udt_name_regtype {
None => panic!("parse_array_attributes(_) received an empty udt_name_regtype"),
Some(typename) => Some(SeaRc::new(Type::from_str(&typename.replacen("[]", "", 1)))),
Some(typename) => {
let typename = typename.replacen("[]", "", 1);
let is_enum = enums.contains_key(&typename);
Some(SeaRc::new(Type::from_str(
&typename,
Some(&typename),
is_enum,
)))
}
};
}
_ => panic!("parse_array_attributes(_) received a type that does not have EnumDef"),
};

ctype
}

impl ColumnInfo {
pub fn parse_enum_variants(mut self, enums: &HashMap<String, Vec<String>>) -> Self {
if let Type::Enum(ref mut enum_def) = self.col_type {
if let Some(def) = enums.get(&enum_def.typename) {
enum_def.values = def.clone()
}
}
self
}
}
2 changes: 1 addition & 1 deletion src/postgres/parser/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ pub fn parse_table_query_result(table_query: TableQueryResult) -> TableInfo {
name: table_query.table_name,
of_type: table_query
.user_defined_type_name
.map(|type_name| Type::from_str(&type_name)),
.map(|type_name| Type::from_str(&type_name, Some(&type_name), false)),
}
}
10 changes: 10 additions & 0 deletions tests/live/postgres/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ async fn main() {
let connection = setup(&url, "sea-schema").await;
let mut executor = connection.acquire().await.unwrap();

sqlx::query("CREATE EXTENSION IF NOT EXISTS citext")
.execute(&mut executor)
.await
.unwrap();

let create_enum_stmt = Type::create()
.as_enum(Alias::new("crazy_enum"))
.values(vec![
Expand Down Expand Up @@ -337,5 +342,10 @@ fn create_collection_table() -> TableCreateStatement {
.not_null(),
)
.col(ColumnDef::new(Alias::new("integers_opt")).array(ColumnType::Integer))
.col(
ColumnDef::new(Alias::new("case_insensitive_text"))
.custom(Alias::new("citext"))
.not_null(),
)
.to_owned()
}

0 comments on commit c86710c

Please sign in to comment.