Skip to content

Commit

Permalink
[BUG]: sql functions case sensitivit (#3063)
Browse files Browse the repository at this point in the history
closes #3056
  • Loading branch information
universalmind303 authored Oct 16, 2024
1 parent 271ec7c commit d243cee
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/daft-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ impl SQLPlanner {

// lookup function variant(s) by name
let fns = &SQL_FUNCTIONS;
let fn_name = func.name.to_string();
// SQL function names are case-insensitive
let fn_name = func.name.to_string().to_lowercase();
let fn_match = match fns.get(&fn_name) {
Some(func) => func,
None => unsupported_sql_err!("Function `{}` not found", fn_name),
Expand Down
4 changes: 4 additions & 0 deletions src/daft-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ mod tests {
#[case::ilike("select utf8 ilike 'a' as ilike from tbl1")]
#[case::datestring("select DATE '2021-08-01' as dt from tbl1")]
#[case::datetime("select DATETIME '2021-08-01 00:00:00' as dt from tbl1")]
#[case::countstar("select COUNT(*) as count from tbl1")]
#[case::countstarlower("select COUNT(*) as count from tbl1")]
#[case::count("select COUNT(i32) as count from tbl1")]
#[case::countcasing("select CoUnT(i32) as count from tbl1")]
// #[case::to_datetime("select to_datetime(utf8, 'YYYY-MM-DD') as to_datetime from tbl1")]
fn test_compiles_funcs(mut planner: SQLPlanner, #[case] query: &str) -> SQLPlannerResult<()> {
let plan = planner.plan_sql(query);
Expand Down
12 changes: 12 additions & 0 deletions tests/sql/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ def test_hash_exprs():
daft.sql("SELECT minhash(a) as hash_a FROM df").collect()


def test_count_star():
df = daft.from_pydict(
{
"a": [1, 2, 3, 4],
}
)

actual = daft.sql("SELECT COUNT(*) FROM df").collect()
expected = df.agg(daft.col("*").count().alias("count")).collect()
assert actual.to_pydict() == expected.to_pydict()


def test_between():
df = daft.from_pydict(
{
Expand Down

0 comments on commit d243cee

Please sign in to comment.