Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework SQLite type mapping #117

Merged
merged 6 commits into from
Jan 31, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Rework SQLite type mapping
tyt2y3 authored and billy1624 committed Jan 23, 2024

Verified

This commit was signed with the committer’s verified signature.
billy1624 Billy Chan
commit 699711f6d018515ba9273cce41330cf18887fbc0
4 changes: 2 additions & 2 deletions src/sqlite/def/column.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{DefaultType, Type};
use super::{parse_type, DefaultType, Type};
use sea_query::{
foreign_key::ForeignKeyAction as SeaQueryForeignKeyAction, Alias, Index, IndexCreateStatement,
};
@@ -28,7 +28,7 @@ impl ColumnInfo {
Ok(ColumnInfo {
cid: row.get(0),
name: row.get(1),
r#type: Type::to_type(row.get(2))?,
r#type: parse_type(row.get(2))?,
not_null: col_not_null != 0,
default_value: if default_value == "NULL" {
DefaultType::Null
5 changes: 2 additions & 3 deletions src/sqlite/def/table.rs
Original file line number Diff line number Diff line change
@@ -224,7 +224,8 @@ impl TableDef {
new_table.table(Alias::new(&self.name));

self.columns.iter().for_each(|column_info| {
let mut new_column = ColumnDef::new(Alias::new(&column_info.name));
let mut new_column =
ColumnDef::new_with_type(Alias::new(&column_info.name), column_info.r#type.clone());
if column_info.not_null {
new_column.not_null();
}
@@ -235,8 +236,6 @@ impl TableDef {
primary_keys.push(column_info.name.clone());
}

column_info.r#type.write_type(&mut new_column);

match &column_info.default_value {
DefaultType::Integer(integer_value) => {
new_column.default(Value::Int(Some(*integer_value)));
205 changes: 49 additions & 156 deletions src/sqlite/def/types.rs
Original file line number Diff line number Diff line change
@@ -1,166 +1,59 @@
use sea_query::ColumnDef;
use sea_query::{BlobSize, ColumnType};
use std::num::ParseIntError;

/// A list of the offical SQLite types as outline at the official [SQLite Docs](https://www.sqlite.org/datatype3.html)
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Type {
Int,
Integer,
TinyInt,
SmallInt,
MediumInt,
BigInt,
UnsignedBigInt,
Int2,
Int8,
Character { length: u8 },
VarChar { length: u8 },
VaryingCharacter { length: u8 },
Nchar { length: u8 },
NativeCharacter { length: u8 },
NvarChar { length: u8 },
Text,
Clob,
Blob, //No datatype specified
Real,
Double,
DoublePrecision,
Float,
Numeric,
Decimal { integral: u8, fractional: u8 },
Boolean,
Date,
DateTime,
Timestamp,
}

impl Type {
/// Maps a string type from an `SqliteRow` into a [Type]
pub fn to_type(data_type: &str) -> Result<Type, ParseIntError> {
let data_type = data_type.to_uppercase();

let split_type: Vec<&str> = data_type.split('(').collect();
let type_result = match split_type[0] {
"INT" => Type::Int,
"INTEGER" => Type::Integer,
"TINY INT" | "TINYINT" => Type::TinyInt,
"SMALL INT" | "SMALLINT" => Type::SmallInt,
"MEDIUM INT" | "MEDIUMINT" => Type::MediumInt,
"BIG INT" | "BIGINT" => Type::BigInt,
"UNSIGNED INT" | "UNSIGNEDBIGINT" => Type::UnsignedBigInt,
"INT2" => Type::Int2,
"INT8" => Type::Int8,
"TEXT" => Type::Text,
"CLOB" => Type::Clob,
"BLOB" => Type::Blob,
"REAL" => Type::Real,
"DOUBLE" => Type::Double,
"DOUBLE PRECISION" => Type::DoublePrecision,
"FLOAT" => Type::Float,
"NUMERIC" => Type::Numeric,
"DECIMAL" => {
let decimals = split_type[1].chars().collect::<Vec<_>>();

let integral = decimals[0].to_string().parse::<u8>()?;
let fractional = decimals[2].to_string().parse::<u8>()?;
pub type Type = ColumnType;

Type::Decimal {
integral,
fractional,
pub fn parse_type(data_type: &str) -> Result<ColumnType, ParseIntError> {
let mut type_name = data_type;
let mut parts: Vec<u32> = Vec::new();
if let Some((prefix, suffix)) = data_type.split_once('(') {
if let Some(suffix) = suffix.strip_suffix(')') {
type_name = prefix;
for part in suffix.split(",") {
if let Ok(part) = part.trim().parse() {
parts.push(part);
} else {
break;
}
}
"BOOLEAN" => Type::Boolean,
"DATE" => Type::Date,
"DATETIME" => Type::DateTime,
"TIMESTAMP" => Type::Timestamp,
_ => Type::variable_types(&split_type)?,
};

Ok(type_result)
}

/// Write a [Type] to a [ColumnDef]
pub fn write_type(&self, column_def: &mut ColumnDef) {
match self {
Self::Int | Self::Integer | Self::MediumInt | Self::Int2 | Self::Int8 => {
column_def.integer();
}
Self::TinyInt => {
column_def.tiny_integer();
}
Self::SmallInt => {
column_def.small_integer();
}
Self::BigInt | Self::UnsignedBigInt => {
column_def.big_integer();
}
Self::Character { .. }
| Self::VarChar { .. }
| Self::VaryingCharacter { .. }
| Self::Nchar { .. }
| Self::NativeCharacter { .. }
| Self::NvarChar { .. }
| Self::Text
| Self::Clob => {
column_def.string();
}
Self::Blob => {
column_def.binary();
}
Self::Real | Self::Double | Self::DoublePrecision | Self::Float | Self::Numeric => {
column_def.double();
}
Self::Decimal {
integral,
fractional,
} => {
column_def.decimal_len((*integral) as u32, (*fractional) as u32);
}
Self::Boolean => {
column_def.boolean();
}
Self::Date => {
column_def.date();
}
Self::DateTime => {
column_def.date_time();
}
Self::Timestamp => {
column_def.timestamp();
}
}
}

#[allow(dead_code)]
fn concat_type(&self, type_name: &str, length: &u8) -> String {
let mut value = String::default();
value.push_str(type_name);
value.push('(');
value.push_str(&length.to_string());
value.push(')');

value
}

fn variable_types(split_type: &[&str]) -> Result<Type, ParseIntError> {
let length = if !split_type.len() == 1 {
let maybe_size = split_type[1].replace(')', "");
maybe_size.parse::<u8>()?
Ok(match type_name.to_lowercase().as_str() {
"char" => ColumnType::Char(parts.into_iter().next()),
"varchar" => ColumnType::String(parts.into_iter().next()),
"text" => ColumnType::Text,
"tinyint" => ColumnType::TinyInteger,
"smallint" => ColumnType::SmallInteger,
"integer" => ColumnType::Integer,
"bigint" => ColumnType::BigInteger,
"float" => ColumnType::Float,
"double" => ColumnType::Double,
"decimal" => ColumnType::Decimal(if parts.len() == 2 {
Some((parts[0], parts[1]))
} else {
255_u8
};

let type_result = match split_type[0] {
"VARCHAR" => Type::VarChar { length },
"CHARACTER" => Type::Character { length },
"VARYING CHARACTER" => Type::VaryingCharacter { length },
"NCHAR" => Type::Nchar { length },
"NATIVE CHARACTER" => Type::NativeCharacter { length },
"NVARCHAR" => Type::NvarChar { length },
_ => Type::Blob,
};
Ok(type_result)
}
None
}),
"datetime_text" => ColumnType::DateTime,
"timestamp_text" => ColumnType::Timestamp,
"timestamp_with_timezone_text" => ColumnType::TimestampWithTimeZone,
"time_text" => ColumnType::Time,
"date_text" => ColumnType::Date,
"tinyblob" => ColumnType::Binary(BlobSize::Tiny),
"mediumblob" => ColumnType::Binary(BlobSize::Medium),
"longblob" => ColumnType::Binary(BlobSize::Long),
"blob" => ColumnType::Binary(BlobSize::Blob(parts.into_iter().next())),
"varbinary_blob" if parts.len() == 1 => ColumnType::VarBinary(parts[0]),
"boolean" => ColumnType::Boolean,
"money" => ColumnType::Money(if parts.len() == 2 {
Some((parts[0], parts[1]))
} else {
None
}),
"json_text" => ColumnType::Json,
"jsonb_text" => ColumnType::JsonBinary,
"uuid_text" => ColumnType::Uuid,
_ => ColumnType::custom(data_type),
})
}

/// The default types for an SQLite `dflt_value`
@@ -170,6 +63,6 @@ pub enum DefaultType {
Float(f32),
String(String),
Null,
Unspecified, //FIXME For other types
Unspecified,
CurrentTimestamp,
}
92 changes: 57 additions & 35 deletions tests/live/sqlite/src/main.rs
Original file line number Diff line number Diff line change
@@ -4,8 +4,8 @@ use sqlx::SqlitePool;
use std::collections::HashMap;

use sea_schema::sea_query::{
Alias, ColumnDef, Expr, ForeignKey, ForeignKeyAction, ForeignKeyCreateStatement, Index, Query,
SqliteQueryBuilder, Table, TableCreateStatement, TableRef,
Alias, BlobSize, ColumnDef, Expr, ForeignKey, ForeignKeyAction, ForeignKeyCreateStatement,
Index, Query, SqliteQueryBuilder, Table, TableCreateStatement, TableRef,
};
use sea_schema::sqlite::{
def::TableDef,
@@ -15,10 +15,7 @@ use sea_schema::sqlite::{
#[cfg_attr(test, async_std::test)]
#[cfg_attr(not(test), async_std::main)]
async fn main() -> DiscoveryResult<()> {
// env_logger::builder()
// .filter_level(log::LevelFilter::Debug)
// .is_test(true)
// .init();
env_logger::init();

test_001().await?;
test_002().await?;
@@ -52,7 +49,7 @@ async fn test_001() -> DiscoveryResult<()> {
.table(Alias::new("Programming_Langs"))
.col(
ColumnDef::new(Alias::new("Name"))
.custom(Alias::new("INTEGER"))
.integer()
.not_null()
.auto_increment()
.primary_key(),
@@ -65,7 +62,7 @@ async fn test_001() -> DiscoveryResult<()> {
)
.col(
ColumnDef::new(Alias::new("SemVer"))
.custom(Alias::new("VARCHAR(255)"))
.string_len(255)
.not_null(),
)
.col(
@@ -126,17 +123,13 @@ async fn test_001() -> DiscoveryResult<()> {
// Tests foreign key discovery
let table_create_suppliers = Table::create()
.table(Alias::new("suppliers"))
.col(ColumnDef::new(Alias::new("supplier_id")).custom(Alias::new("INTEGER")))
.col(ColumnDef::new(Alias::new("supplier_id")).integer())
.col(
ColumnDef::new(Alias::new("supplier_name"))
.custom(Alias::new("TEXT"))
.not_null(),
)
.col(
ColumnDef::new(Alias::new("group_id"))
.custom(Alias::new("INTEGER"))
.text()
.not_null(),
)
.col(ColumnDef::new(Alias::new("group_id")).integer().not_null())
.primary_key(Index::create().col(Alias::new("supplier_id")))
.foreign_key(
ForeignKeyCreateStatement::new()
@@ -150,12 +143,8 @@ async fn test_001() -> DiscoveryResult<()> {

let table_create_supplier_groups = Table::create()
.table(Alias::new("supplier_groups"))
.col(ColumnDef::new(Alias::new("group_id")).custom(Alias::new("INTEGER")))
.col(
ColumnDef::new(Alias::new("group_name"))
.custom(Alias::new("TEXT"))
.not_null(),
)
.col(ColumnDef::new(Alias::new("group_id")).integer())
.col(ColumnDef::new(Alias::new("group_name")).text().not_null())
.primary_key(Index::create().col(Alias::new("group_id")))
.to_owned();

@@ -215,26 +204,18 @@ async fn test_001() -> DiscoveryResult<()> {

let schema = SchemaDiscovery::new(sqlite_pool.clone()).discover().await?;

let convert_column_types = |str: String| {
str.replace("INTEGER", "integer")
.replace("INT8", "integer")
.replace("TEXT", "text")
.replace("VARCHAR(255)", "text")
.replace("DATETIME", "text")
};
let expected_sql = [
create_table.to_string(SqliteQueryBuilder),
create_table_inventors.to_string(SqliteQueryBuilder),
table_create_supplier_groups.to_string(SqliteQueryBuilder),
table_create_suppliers.to_string(SqliteQueryBuilder),
]
.into_iter()
.map(convert_column_types)
.collect::<Vec<_>>();
assert_eq!(schema.tables.len(), expected_sql.len());

for (i, table) in schema.tables.into_iter().enumerate() {
let sql = convert_column_types(table.write().to_string(SqliteQueryBuilder));
let sql = table.write().to_string(SqliteQueryBuilder);
if sql == expected_sql[i] {
println!("[OK] {sql}");
}
@@ -273,6 +254,7 @@ async fn test_002() -> DiscoveryResult<()> {
create_lineitem_table(),
create_parent_table(),
create_child_table(),
create_strange_table(),
];

for tbl_create_stmt in tbl_create_stmts.iter() {
@@ -301,12 +283,12 @@ async fn test_002() -> DiscoveryResult<()> {
[
r#"CREATE TABLE "order" ("#,
r#""id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,"#,
r#""total" real,"#,
r#""total" decimal,"#,
r#""bakery_id" integer NOT NULL,"#,
r#""customer_id" integer NOT NULL,"#,
r#""placed_at" text NOT NULL DEFAULT CURRENT_TIMESTAMP,"#,
r#""updated" text NOT NULL DEFAULT '2023-06-07 16:24:00',"#,
r#""net_weight" real NOT NULL DEFAULT 10.05,"#,
r#""placed_at" datetime_text NOT NULL DEFAULT CURRENT_TIMESTAMP,"#,
r#""updated" datetime_text NOT NULL DEFAULT '2023-06-07 16:24:00',"#,
r#""net_weight" double NOT NULL DEFAULT 10.05,"#,
r#""priority" integer NOT NULL DEFAULT 5,"#,
r#"FOREIGN KEY ("customer_id") REFERENCES "customer" ("id") ON DELETE CASCADE ON UPDATE CASCADE,"#,
r#"FOREIGN KEY ("bakery_id") REFERENCES "bakery" ("id") ON DELETE CASCADE ON UPDATE CASCADE"#,
@@ -316,7 +298,7 @@ async fn test_002() -> DiscoveryResult<()> {
[
r#"CREATE TABLE "lineitem" ("#,
r#""id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,"#,
r#""price" real,"#,
r#""price" decimal,"#,
r#""quantity" integer,"#,
r#""order_id" integer NOT NULL,"#,
r#""cake_id" integer NOT NULL,"#,
@@ -566,3 +548,43 @@ fn create_child_table() -> TableCreateStatement {
)
.to_owned()
}

fn create_strange_table() -> TableCreateStatement {
Table::create()
.table(Alias::new("strange"))
.col(
ColumnDef::new(Alias::new("id"))
.integer()
.not_null()
.auto_increment()
.primary_key(),
)
.col(ColumnDef::new(Alias::new("int1")).integer())
.col(ColumnDef::new(Alias::new("int2")).tiny_integer())
.col(ColumnDef::new(Alias::new("int3")).small_integer())
.col(ColumnDef::new(Alias::new("int4")).big_integer())
.col(ColumnDef::new(Alias::new("string1")).string())
.col(ColumnDef::new(Alias::new("string2")).string_len(24))
.col(ColumnDef::new(Alias::new("char1")).char())
.col(ColumnDef::new(Alias::new("char2")).char_len(24))
.col(ColumnDef::new(Alias::new("text_col")).text())
.col(ColumnDef::new(Alias::new("json_col")).json())
.col(ColumnDef::new(Alias::new("uuid_col")).uuid())
.col(ColumnDef::new(Alias::new("decimal1")).decimal())
.col(ColumnDef::new(Alias::new("decimal2")).decimal_len(12, 4))
.col(ColumnDef::new(Alias::new("money1")).money())
.col(ColumnDef::new(Alias::new("money2")).money_len(12, 4))
.col(ColumnDef::new(Alias::new("float_col")).float())
.col(ColumnDef::new(Alias::new("double_col")).double())
.col(ColumnDef::new(Alias::new("date_col")).date())
.col(ColumnDef::new(Alias::new("time_col")).time())
.col(ColumnDef::new(Alias::new("datetime_col")).date_time())
.col(ColumnDef::new(Alias::new("boolean_col")).boolean())
.col(ColumnDef::new(Alias::new("binary1")).binary())
.col(ColumnDef::new(Alias::new("binary2")).binary_len(1024))
.col(ColumnDef::new(Alias::new("binary3")).var_binary(1024))
.col(ColumnDef::new(Alias::new("binary4")).blob(BlobSize::Tiny))
.col(ColumnDef::new(Alias::new("binary5")).blob(BlobSize::Medium))
.col(ColumnDef::new(Alias::new("binary6")).blob(BlobSize::Long))
.to_owned()
}