Skip to content

Commit

Permalink
fix: ignore nulls with order by closes #2932
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Feb 8, 2024
1 parent b827626 commit c4fb2fa
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
4 changes: 2 additions & 2 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4284,11 +4284,11 @@ class Interval(TimeUnit):


class IgnoreNulls(Expression):
pass
arg_types = {"this": True, "inline": False}


class RespectNulls(Expression):
pass
arg_types = {"this": True, "inline": False}


# Functions
Expand Down
18 changes: 12 additions & 6 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2850,12 +2850,18 @@ def respectnulls_sql(self, expression: exp.RespectNulls) -> str:
return self._embed_ignore_nulls(expression, "RESPECT NULLS")

def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
if self.IGNORE_NULLS_IN_FUNC:
this = expression.find(exp.AggFunc)
if this:
sql = self.sql(this)
sql = sql[:-1] + f" {text})"
return sql
if self.IGNORE_NULLS_IN_FUNC and not expression.args.get("inline"):
for klass in (exp.Order, exp.Limit):
mod = expression.find(klass)

if mod:
mod.this.replace(expression.__class__(this=mod.this.copy(), inline=True))
return self.sql(expression.this)

agg_func = expression.find(exp.AggFunc)

if agg_func:
return self.sql(agg_func)[:-1] + f" {text})"

return f"{self.sql(expression, 'this')} {text}"

Expand Down
9 changes: 7 additions & 2 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ class TestBigQuery(Validator):
maxDiff = None

def test_bigquery(self):
self.validate_identity("ARRAY_AGG(x IGNORE NULLS LIMIT 1)")
self.validate_identity("ARRAY_AGG(x IGNORE NULLS ORDER BY x LIMIT 1)")
self.validate_identity("ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY x LIMIT 1)")
self.validate_identity("ARRAY_AGG(x IGNORE NULLS)")

self.validate_all(
"SELECT SUM(x IGNORE NULLS) AS x",
read={
Expand Down Expand Up @@ -62,14 +67,14 @@ def test_bigquery(self):
self.validate_all(
"SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x",
write={
"duckdb": "SELECT ARRAY_AGG(DISTINCT x ORDER BY a NULLS FIRST, b DESC LIMIT 10 IGNORE NULLS) AS x",
"duckdb": "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a NULLS FIRST, b DESC LIMIT 10) AS x",
"spark": "SELECT COLLECT_LIST(DISTINCT x ORDER BY a, b DESC LIMIT 10) IGNORE NULLS AS x",
},
)
self.validate_all(
"SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 1, 10) AS x",
write={
"duckdb": "SELECT ARRAY_AGG(DISTINCT x ORDER BY a NULLS FIRST, b DESC LIMIT 1, 10 IGNORE NULLS) AS x",
"duckdb": "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a NULLS FIRST, b DESC LIMIT 1, 10) AS x",
"spark": "SELECT COLLECT_LIST(DISTINCT x ORDER BY a, b DESC LIMIT 1, 10) IGNORE NULLS AS x",
},
)
Expand Down

0 comments on commit c4fb2fa

Please sign in to comment.