From 1fb5fc4cbea646588fababca5c50a8829e90348e Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Tue, 8 Aug 2023 09:49:45 -0700 Subject: [PATCH] Add test for selecting a single complex field array and its parent struct array [databricks] (#8744) * Added a test for schema pruning for orc * Added a test for schema pruning for parquet --------- Signed-off-by: Raza Jafri --- .../python/prune_partition_column_test.py | 85 ++++++++++++++++++- .../rapids/ExecutionPlanCaptureCallback.scala | 40 ++++++++- 2 files changed, 121 insertions(+), 4 deletions(-) diff --git a/integration_tests/src/main/python/prune_partition_column_test.py b/integration_tests/src/main/python/prune_partition_column_test.py index f68c32f4d4b..6948cc72519 100644 --- a/integration_tests/src/main/python/prune_partition_column_test.py +++ b/integration_tests/src/main/python/prune_partition_column_test.py @@ -15,10 +15,12 @@ import os import pytest -from asserts import assert_gpu_and_cpu_are_equal_collect +from asserts import assert_gpu_and_cpu_are_equal_collect, run_with_cpu_and_gpu, assert_equal from data_gen import * from marks import * -from spark_session import with_cpu_session +from pyspark.sql.types import IntegerType +from spark_session import with_cpu_session, is_before_spark_320 +from conftest import spark_jvm # Several values to avoid generating too many folders for partitions. part1_gen = SetValuesGen(IntegerType(), [-10, -1, 0, 1, 10]) @@ -127,3 +129,82 @@ def test_prune_partition_column_when_filter_fallback_project(spark_tmp_path, pru filter_col, file_format): do_prune_partition_column_when_filter_project(spark_tmp_path, prune_part_enabled, file_format, filter_col, gpu_project_enabled=False) + +# This method creates two tables and saves them to partitioned Parquet/ORC files. The file is then +# read in using the read function that is passed in +def create_contacts_table_and_read(is_partitioned, format, data_path, expected_schemata, func, conf, table_name): + full_name_type = StructGen([('first', StringGen()), ('middle', StringGen()), ('last', StringGen())]) + name_type = StructGen([('first', StringGen()), ('last', StringGen())]) + contacts_data_gen = StructGen([ + ('id', IntegerGen()), + ('name', full_name_type), + ('address', StringGen()), + ('friends', ArrayGen(full_name_type, max_length=10, nullable=False))], nullable=False) + + brief_contacts_data_gen = StructGen([ + ('id', IntegerGen()), + ('name', name_type), + ('address', StringGen())], nullable=False) + + # We are adding the field 'p' twice just like it is being done in Spark tests + # https://github.com/apache/spark/blob/85e252e8503534009f4fb5ea005d44c9eda31447/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala#L193 + def contact_gen_df(spark, data_gen, partition): + gen = gen_df(spark, data_gen) + if is_partitioned: + return gen.withColumn('p', f.lit(partition)) + else: + return gen + + with_cpu_session(lambda spark: contact_gen_df(spark, contacts_data_gen, 1).write.format(format).save(data_path + f"/{table_name}/p=1")) + with_cpu_session(lambda spark: contact_gen_df(spark, brief_contacts_data_gen, 2).write.format(format).save(data_path + f"/{table_name}/p=2")) + + # Schema to read in. + read_schema = contacts_data_gen.data_type.add("p", IntegerType(), True) if is_partitioned else contacts_data_gen.data_type + + (from_cpu, cpu_df), (from_gpu, gpu_df) = run_with_cpu_and_gpu( + func(read_schema), + 'COLLECT_WITH_DATAFRAME', + conf=conf) + + jvm = spark_jvm() + jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertSchemataMatch(cpu_df._jdf, gpu_df._jdf, expected_schemata) + assert_equal(from_cpu, from_gpu) + +# https://github.com/NVIDIA/spark-rapids/issues/8712 +# https://github.com/NVIDIA/spark-rapids/issues/8713 +# https://github.com/NVIDIA/spark-rapids/issues/8714 +@pytest.mark.parametrize('query_and_expected_schemata', [("select friends.middle, friends from {} where p=1", "struct>>"), + pytest.param(("select name.middle, address from {} where p=2", "struct,address:string>"), marks=pytest.mark.skip(reason='https://github.com/NVIDIA/spark-rapids/issues/8788')), + ("select name.first from {} where name.first = 'Jane'", "struct>")]) +@pytest.mark.parametrize('is_partitioned', [True, False]) +@pytest.mark.parametrize('format', ["parquet", "orc"]) +def test_select_complex_field(format, spark_tmp_path, query_and_expected_schemata, is_partitioned, spark_tmp_table_factory): + table_name = spark_tmp_table_factory.get() + query, expected_schemata = query_and_expected_schemata + data_path = spark_tmp_path + "/DATA" + def read_temp_view(schema): + def do_it(spark): + spark.read.format(format).schema(schema).load(data_path + f"/{table_name}").createOrReplaceTempView(table_name) + return spark.sql(query.format(table_name)) + return do_it + conf={"spark.sql.parquet.enableVectorizedReader": "true"} + create_contacts_table_and_read(is_partitioned, format, data_path, expected_schemata, read_temp_view, conf, table_name) + +# https://github.com/NVIDIA/spark-rapids/issues/8715 +@pytest.mark.parametrize('select_and_expected_schemata', [("friend.First", "struct>>"), + ("friend.MIDDLE", "struct>>")]) +@pytest.mark.skipif(is_before_spark_320(), reason='https://issues.apache.org/jira/browse/SPARK-34638') +@pytest.mark.parametrize('is_partitioned', [True, False]) +@pytest.mark.parametrize('format', ["parquet", "orc"]) +def test_nested_column_prune_on_generator_output(format, spark_tmp_path, select_and_expected_schemata, is_partitioned, spark_tmp_table_factory): + table_name = spark_tmp_table_factory.get() + query, expected_schemata = select_and_expected_schemata + data_path = spark_tmp_path + "/DATA" + def read_temp_view(schema): + def do_it(spark): + spark.read.format(format).schema(schema).load(data_path + f"/{table_name}").createOrReplaceTempView(table_name) + return spark.table(table_name).select(f.explode(f.col("friends")).alias("friend")).select(query) + return do_it + conf = {"spark.sql.caseSensitive": "false", + "spark.sql.parquet.enableVectorizedReader": "true"} + create_contacts_table_and_read(is_partitioned, format, data_path, expected_schemata, read_temp_view, conf, table_name) \ No newline at end of file diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala index fa85f7612af..4b6813d7469 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala @@ -25,11 +25,11 @@ import com.nvidia.spark.rapids.{PlanShims, PlanUtils} import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.{ExecSubqueryExpression, QueryExecution, ReusedSubqueryExec, SparkPlan} -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.util.QueryExecutionListener -object ExecutionPlanCaptureCallback { +object ExecutionPlanCaptureCallback extends AdaptiveSparkPlanHelper { private[this] var shouldCapture: Boolean = false private[this] val execPlans: ArrayBuffer[SparkPlan] = ArrayBuffer.empty @@ -83,6 +83,42 @@ object ExecutionPlanCaptureCallback { fallbackCpuClassList.foreach(fallbackCpuClass => assertDidFallBack(gpuPlans, fallbackCpuClass)) } + /** + * This method is used by the Python integration tests. + * The method checks the schemata used in the GPU and CPU executed plans and compares it to the + * expected schemata to make sure we are not reading more data than needed + */ + def assertSchemataMatch(cpuDf: DataFrame, gpuDf: DataFrame, expectedSchema: String): Unit = { + import org.apache.spark.sql.execution.FileSourceScanExec + import org.apache.spark.sql.types.StructType + import org.apache.spark.sql.catalyst.parser.CatalystSqlParser + + val cpuFileSourceScanSchemata = collect(cpuDf.queryExecution.executedPlan) { + case scan: FileSourceScanExec => scan.requiredSchema + } + val gpuFileSourceScanSchemata = collect(gpuDf.queryExecution.executedPlan) { + case scan: GpuFileSourceScanExec => scan.requiredSchema + } + assert(cpuFileSourceScanSchemata.size == gpuFileSourceScanSchemata.size, + s"Found ${cpuFileSourceScanSchemata.size} file sources in dataframe, " + + s"but expected ${gpuFileSourceScanSchemata.size}") + + cpuFileSourceScanSchemata.zip(gpuFileSourceScanSchemata).foreach { + case (cpuScanSchema, gpuScanSchema) => + cpuScanSchema match { + case otherType: StructType => + assert(gpuScanSchema.sameType(otherType)) + val expectedStructType = CatalystSqlParser.parseDataType(expectedSchema) + assert(gpuScanSchema.sameType(expectedStructType), + s"Type GPU schema ${gpuScanSchema.toDDL} doesn't match $expectedSchema") + assert(cpuScanSchema.sameType(expectedStructType), + s"Type CPU schema ${cpuScanSchema.toDDL} doesn't match $expectedSchema") + case otherType => assert(false, s"The expected type $cpuScanSchema" + + s" doesn't match the actual type $otherType") + } + } + } + def assertCapturedAndGpuFellBack(fallbackCpuClass: String, timeoutMs: Long = 2000): Unit = { val gpuPlans = getResultsWithTimeout(timeoutMs = timeoutMs) assert(gpuPlans.nonEmpty, "Did not capture a plan")