From 728715276d99267ea4104fcfbee368afcafa93f2 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 1 Sep 2021 09:13:30 -0500 Subject: [PATCH] Add AST support for null literals Signed-off-by: Jason Lowe --- integration_tests/src/main/python/ast_test.py | 19 +++++++++---------- .../com/nvidia/spark/rapids/literals.scala | 7 ------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/integration_tests/src/main/python/ast_test.py b/integration_tests/src/main/python/ast_test.py index ee3cc9f4573..f60c1a0d844 100644 --- a/integration_tests/src/main/python/ast_test.py +++ b/integration_tests/src/main/python/ast_test.py @@ -75,11 +75,19 @@ def test_literal(spark_tmp_path, data_gen): # Write data to Parquet so Spark generates a plan using just the count of the data. data_path = spark_tmp_path + '/AST_TEST_DATA' with_cpu_session(lambda spark: gen_df(spark, [("a", IntegerGen())]).write.parquet(data_path)) - # AST does not support null literals until https://github.com/rapidsai/cudf/pull/9117 scalar = gen_scalar(data_gen, force_no_nulls=True) assert_gpu_ast(is_supported=True, func=lambda spark: spark.read.parquet(data_path).select(scalar)) +@pytest.mark.parametrize('data_gen', [boolean_gen, byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, timestamp_gen], ids=idfn) +def test_null_literal(spark_tmp_path, data_gen): + # Write data to Parquet so Spark generates a plan using just the count of the data. + data_path = spark_tmp_path + '/AST_TEST_DATA' + with_cpu_session(lambda spark: gen_df(spark, [("a", IntegerGen())]).write.parquet(data_path)) + data_type = data_gen.data_type + assert_gpu_ast(is_supported=True, + func=lambda spark: spark.read.parquet(data_path).select(f.lit(None).cast(data_type))) + @pytest.mark.parametrize('data_descr', ast_integral_descrs, ids=idfn) def test_bitwise_not(data_descr): assert_unary_ast(data_descr, lambda df: df.selectExpr('~a')) @@ -336,15 +344,6 @@ def test_scalar_pow(): 'pow(a, 7.0)', 'pow(-12.0, b)')) -@approximate_float -def test_scalar_pow_fallback(): - # AST null literals not supported until https://github.com/rapidsai/cudf/issues/8831 is fixed - data_gen = [('a', DoubleGen()),('b', DoubleGen().with_special_case(lambda rand: float(rand.randint(-16, 16)), weight=100.0))] - assert_gpu_ast(is_supported=False, - func=lambda spark: gen_df(spark, data_gen).selectExpr( - 'pow(cast(null as DOUBLE), a)', - 'pow(b, cast(null as DOUBLE))')) - @approximate_float @pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/89') @pytest.mark.parametrize('data_descr', ast_double_descr, ids=idfn) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala index 14af56cdb84..39ea6ff2f8c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala @@ -667,11 +667,4 @@ class LiteralExprMeta( super.print(append, depth, all) } } - - override protected def tagSelfForAst(): Unit = { - // Preclude null literals until https://github.com/rapidsai/cudf/issues/8831 is fixed. - if (lit.value == null) { - willNotWorkInAst("null literals are not supported") - } - } }