Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse Postgres citext as Type::Unknown #94

Merged
merged 2 commits into from
Jan 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/postgres/def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ pub enum Type {
impl Type {
// TODO: Support more types
#[allow(clippy::should_implement_trait)]
pub fn from_str(name: &str) -> Type {
match name.to_lowercase().as_str() {
pub fn from_str(column_type: &str, udt_name: Option<&String>, is_enum: bool) -> Type {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think this can be udt_name: Option<&str> as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, let me check

match column_type.to_lowercase().as_str() {
"smallint" | "int2" => Type::SmallInt,
"integer" | "int" | "int4" => Type::Integer,
"bigint" | "int8" => Type::BigInt,
Expand Down Expand Up @@ -199,9 +199,12 @@ impl Type {
"daterange" => Type::DateRange,
// "" => Type::Domain,
"pg_lsn" => Type::PgLsn,
"user-defined" => Type::Enum(EnumDef::default()),
"user-defined" if is_enum => Type::Enum(EnumDef::default()),
"user-defined" if !is_enum && udt_name.is_some() => {
Type::Unknown(udt_name.unwrap().to_owned())
}
"array" => Type::Array(ArrayDef::default()),
_ => Type::Unknown(name.to_owned()),
_ => Type::Unknown(column_type.to_owned()),
}
}
}
Expand Down
34 changes: 14 additions & 20 deletions src/postgres/discovery/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use std::collections::HashMap;
mod executor;
pub use executor::*;

pub(crate) type EnumVariantMap = HashMap<String, Vec<String>>;

pub struct SchemaDiscovery {
pub query: SchemaQueryBuilder,
pub executor: Executor,
Expand All @@ -33,30 +35,20 @@ impl SchemaDiscovery {
}

pub async fn discover(&self) -> Schema {
let enums: HashMap<_, _> = self
let enums: EnumVariantMap = self
.discover_enums()
.await
.into_iter()
.map(|enum_def| (enum_def.typename, enum_def.values))
.collect();
let tables = self.discover_tables().await;
let tables = future::join_all(
tables
self.discover_tables()
.await
.into_iter()
.map(|t| (self, t))
.map(|t| (self, t, &enums))
.map(Self::discover_table_static),
)
.await
.into_iter()
.map(|mut table| {
table.columns = table
.columns
.into_iter()
.map(|col| col.parse_enum_variants(&enums))
.collect();
table
})
.collect();
.await;

Schema {
schema: self.schema.to_string(),
Expand Down Expand Up @@ -84,16 +76,17 @@ impl SchemaDiscovery {
tables
}

async fn discover_table_static(params: (&Self, TableInfo)) -> TableDef {
async fn discover_table_static(params: (&Self, TableInfo, &EnumVariantMap)) -> TableDef {
let this = params.0;
let info = params.1;
Self::discover_table(this, info).await
let enums = params.2;
Self::discover_table(this, info, enums).await
}

pub async fn discover_table(&self, info: TableInfo) -> TableDef {
pub async fn discover_table(&self, info: TableInfo, enums: &EnumVariantMap) -> TableDef {
let table = SeaRc::new(Alias::new(info.name.as_str()));
let columns = self
.discover_columns(self.schema.clone(), table.clone())
.discover_columns(self.schema.clone(), table.clone(), enums)
.await;
let constraints = self
.discover_constraints(self.schema.clone(), table.clone())
Expand Down Expand Up @@ -143,6 +136,7 @@ impl SchemaDiscovery {
&self,
schema: SeaRc<dyn Iden>,
table: SeaRc<dyn Iden>,
enums: &EnumVariantMap,
) -> Vec<ColumnInfo> {
let rows = self
.executor
Expand All @@ -153,7 +147,7 @@ impl SchemaDiscovery {
.map(|row| {
let result: ColumnQueryResult = (&row).into();
debug_print!("{:?}", result);
let column = result.parse();
let column = result.parse(enums);
debug_print!("{:?}", column);
column
})
Expand Down
60 changes: 37 additions & 23 deletions src/postgres/parser/column.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,35 @@
use crate::postgres::{def::*, parser::yes_or_no_to_bool, query::ColumnQueryResult};
use crate::postgres::{
def::*, discovery::EnumVariantMap, parser::yes_or_no_to_bool, query::ColumnQueryResult,
};
use sea_query::SeaRc;
use std::{collections::HashMap, convert::TryFrom};

impl ColumnQueryResult {
pub fn parse(self) -> ColumnInfo {
parse_column_query_result(self)
pub fn parse(self, enums: &EnumVariantMap) -> ColumnInfo {
parse_column_query_result(self, enums)
}
}

pub fn parse_column_query_result(result: ColumnQueryResult) -> ColumnInfo {
pub fn parse_column_query_result(result: ColumnQueryResult, enums: &EnumVariantMap) -> ColumnInfo {
ColumnInfo {
name: result.column_name.clone(),
col_type: parse_column_type(&result),
col_type: parse_column_type(&result, enums),
default: ColumnExpression::from_option_string(result.column_default),
generated: ColumnExpression::from_option_string(result.column_generated),
not_null: NotNull::from_bool(!yes_or_no_to_bool(&result.is_nullable)),
is_identity: yes_or_no_to_bool(&result.is_identity),
}
}

pub fn parse_column_type(result: &ColumnQueryResult) -> ColumnType {
let mut ctype = Type::from_str(result.column_type.as_str());
pub fn parse_column_type(result: &ColumnQueryResult, enums: &EnumVariantMap) -> ColumnType {
let is_enum = result
.udt_name
.as_ref()
.map_or(false, |udt_name| enums.contains_key(udt_name));
let mut ctype = Type::from_str(
result.column_type.as_str(),
result.udt_name.as_ref(),
is_enum,
);

if ctype.has_numeric_attr() {
ctype = parse_numeric_attributes(
Expand All @@ -43,10 +52,10 @@ pub fn parse_column_type(result: &ColumnQueryResult) -> ColumnType {
ctype = parse_bit_attributes(result.character_maximum_length, ctype);
}
if ctype.has_enum_attr() {
ctype = parse_enum_attributes(result.udt_name.as_ref(), ctype);
ctype = parse_enum_attributes(result.udt_name.as_ref(), ctype, enums);
}
if ctype.has_array_attr() {
ctype = parse_array_attributes(result.udt_name_regtype.as_ref(), ctype);
ctype = parse_array_attributes(result.udt_name_regtype.as_ref(), ctype, enums);
}

ctype
Expand Down Expand Up @@ -173,13 +182,20 @@ pub fn parse_bit_attributes(
ctype
}

pub fn parse_enum_attributes(udt_name: Option<&String>, mut ctype: ColumnType) -> ColumnType {
pub fn parse_enum_attributes(
udt_name: Option<&String>,
mut ctype: ColumnType,
enums: &EnumVariantMap,
) -> ColumnType {
match ctype {
Type::Enum(ref mut def) => {
def.typename = match udt_name {
None => panic!("parse_enum_attributes(_) received an empty udt_name"),
Some(typename) => typename.to_string(),
};
if let Some(variants) = enums.get(&def.typename) {
def.values = variants.clone()
}
}
_ => panic!("parse_enum_attributes(_) received a type that does not have EnumDef"),
};
Expand All @@ -190,27 +206,25 @@ pub fn parse_enum_attributes(udt_name: Option<&String>, mut ctype: ColumnType) -
pub fn parse_array_attributes(
udt_name_regtype: Option<&String>,
mut ctype: ColumnType,
enums: &EnumVariantMap,
) -> ColumnType {
match ctype {
Type::Array(ref mut def) => {
def.col_type = match udt_name_regtype {
None => panic!("parse_array_attributes(_) received an empty udt_name_regtype"),
Some(typename) => Some(SeaRc::new(Type::from_str(&typename.replacen("[]", "", 1)))),
Some(typename) => {
let typename = typename.replacen("[]", "", 1);
let is_enum = enums.contains_key(&typename);
Some(SeaRc::new(Type::from_str(
&typename,
Some(&typename),
is_enum,
)))
}
};
}
_ => panic!("parse_array_attributes(_) received a type that does not have EnumDef"),
};

ctype
}

impl ColumnInfo {
pub fn parse_enum_variants(mut self, enums: &HashMap<String, Vec<String>>) -> Self {
if let Type::Enum(ref mut enum_def) = self.col_type {
if let Some(def) = enums.get(&enum_def.typename) {
enum_def.values = def.clone()
}
}
self
}
}
2 changes: 1 addition & 1 deletion src/postgres/parser/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ pub fn parse_table_query_result(table_query: TableQueryResult) -> TableInfo {
name: table_query.table_name,
of_type: table_query
.user_defined_type_name
.map(|type_name| Type::from_str(&type_name)),
.map(|type_name| Type::from_str(&type_name, Some(&type_name), false)),
}
}
10 changes: 10 additions & 0 deletions tests/live/postgres/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ async fn main() {
let connection = setup(&url, "sea-schema").await;
let mut executor = connection.acquire().await.unwrap();

sqlx::query("CREATE EXTENSION IF NOT EXISTS citext")
.execute(&mut executor)
.await
.unwrap();

let create_enum_stmt = Type::create()
.as_enum(Alias::new("crazy_enum"))
.values(vec![
Expand Down Expand Up @@ -337,5 +342,10 @@ fn create_collection_table() -> TableCreateStatement {
.not_null(),
)
.col(ColumnDef::new(Alias::new("integers_opt")).array(ColumnType::Integer))
.col(
ColumnDef::new(Alias::new("case_insensitive_text"))
.custom(Alias::new("citext"))
.not_null(),
)
.to_owned()
}