From b70a394222bf209298026fd100f6b9498acf9fff Mon Sep 17 00:00:00 2001 From: tobymao Date: Fri, 9 Feb 2024 08:21:15 -0800 Subject: [PATCH] fix: if doesn't support different types --- sqlglot/transforms.py | 5 ++++- tests/dialects/test_bigquery.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index caaa8acc7d..4777609551 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -214,12 +214,15 @@ def new_name(names: t.Set[str], name: str) -> str: explode_arg = explode.this if isinstance(explode, exp.ExplodeOuter): + bracket = explode_arg[0] + bracket.set("safe", True) + bracket.set("offset", True) explode_arg = exp.func( "IF", exp.func( "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array()) ).eq(0), - exp.array(exp.null(), copy=False), + exp.array(bracket, copy=False), explode_arg, ) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 0c90a80c41..bbc9a7f821 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -621,7 +621,7 @@ def test_bigquery(self): }, ) self.validate_all( - "SELECT IF(pos = pos_2, col, NULL) AS col FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [NULL], []))) - 1)) AS pos CROSS JOIN UNNEST(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [NULL], [])) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [NULL], [])) - 1) AND pos_2 = (ARRAY_LENGTH(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [NULL], [])) - 1))", + "SELECT IF(pos = pos_2, col, NULL) AS col FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [[][SAFE_ORDINAL(0)]], []))) - 1)) AS pos CROSS JOIN UNNEST(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [[][SAFE_ORDINAL(0)]], [])) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [[][SAFE_ORDINAL(0)]], [])) - 1) AND pos_2 = (ARRAY_LENGTH(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [[][SAFE_ORDINAL(0)]], [])) - 1))", read={"spark": "select explode_outer([])"}, ) self.validate_all(