From 04710f3f4c2fffdf17b09db058df9d5d23869c01 Mon Sep 17 00:00:00 2001 From: tobymao Date: Thu, 8 Feb 2024 15:34:50 -0800 Subject: [PATCH] fix: ignore nulls with order by closes #2932 --- sqlglot/expressions.py | 4 ++-- sqlglot/generator.py | 18 ++++++++++++------ tests/dialects/test_bigquery.py | 12 ++++++++++-- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index c26342b33a..055eb30ea3 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -4284,11 +4284,11 @@ class Interval(TimeUnit): class IgnoreNulls(Expression): - pass + arg_types = {"this": True, "inline": False} class RespectNulls(Expression): - pass + arg_types = {"this": True, "inline": False} # Functions diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 81af56d8c8..c8afb8bb4e 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -2850,12 +2850,18 @@ def respectnulls_sql(self, expression: exp.RespectNulls) -> str: return self._embed_ignore_nulls(expression, "RESPECT NULLS") def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str: - if self.IGNORE_NULLS_IN_FUNC: - this = expression.find(exp.AggFunc) - if this: - sql = self.sql(this) - sql = sql[:-1] + f" {text})" - return sql + if self.IGNORE_NULLS_IN_FUNC and not expression.args.get("inline"): + for klass in (exp.Order, exp.Limit): + mod = expression.find(klass) + + if mod: + mod.this.replace(expression.__class__(this=mod.this.copy(), inline=True)) + return self.sql(expression.this) + + agg_func = expression.find(exp.AggFunc) + + if agg_func: + return self.sql(agg_func)[:-1] + f" {text})" return f"{self.sql(expression, 'this')} {text}" diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 340630c2e5..41b969804d 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -18,6 +18,11 @@ class TestBigQuery(Validator): maxDiff = None def test_bigquery(self): + self.validate_identity("ARRAY_AGG(x IGNORE NULLS LIMIT 1)") + self.validate_identity("ARRAY_AGG(x IGNORE NULLS ORDER BY x LIMIT 1)") + self.validate_identity("ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY x LIMIT 1)") + self.validate_identity("ARRAY_AGG(x IGNORE NULLS)") + self.validate_all( "SELECT SUM(x IGNORE NULLS) AS x", read={ @@ -55,6 +60,7 @@ def test_bigquery(self): self.validate_all( "SELECT PERCENTILE_CONT(x, 0.5 RESPECT NULLS) OVER ()", write={ + "bigquery": "SELECT PERCENTILE_CONT(x, 0.5 RESPECT NULLS) OVER ()", "duckdb": "SELECT QUANTILE_CONT(x, 0.5 RESPECT NULLS) OVER ()", "spark": "SELECT PERCENTILE_CONT(x, 0.5) RESPECT NULLS OVER ()", }, @@ -62,14 +68,16 @@ def test_bigquery(self): self.validate_all( "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x", write={ - "duckdb": "SELECT ARRAY_AGG(DISTINCT x ORDER BY a NULLS FIRST, b DESC LIMIT 10 IGNORE NULLS) AS x", + "bigquery": "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x", + "duckdb": "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a NULLS FIRST, b DESC LIMIT 10) AS x", "spark": "SELECT COLLECT_LIST(DISTINCT x ORDER BY a, b DESC LIMIT 10) IGNORE NULLS AS x", }, ) self.validate_all( "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 1, 10) AS x", write={ - "duckdb": "SELECT ARRAY_AGG(DISTINCT x ORDER BY a NULLS FIRST, b DESC LIMIT 1, 10 IGNORE NULLS) AS x", + "bigquery": "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 1, 10) AS x", + "duckdb": "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a NULLS FIRST, b DESC LIMIT 1, 10) AS x", "spark": "SELECT COLLECT_LIST(DISTINCT x ORDER BY a, b DESC LIMIT 1, 10) IGNORE NULLS AS x", }, )