Add test for selecting a single complex field array and its parent struct array [databricks] (NVIDIA#8744)

* Added a test for schema pruning for orc

* Added a test for schema pruning for parquet
---------

Signed-off-by: Raza Jafri <[email protected]>
razajafri authored Aug 8, 2023
1 parent 6c50e8d commit 1fb5fc4
Showing 2 changed files with 121 additions and 4 deletions.
85 changes: 83 additions & 2 deletions integration_tests/src/main/python/prune_partition_column_test.py
@@ -15,10 +15,12 @@
import os
import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, run_with_cpu_and_gpu, assert_equal
from data_gen import *
from marks import *
from spark_session import with_cpu_session
from pyspark.sql.types import IntegerType
from spark_session import with_cpu_session, is_before_spark_320
from conftest import spark_jvm

# Several values to avoid generating too many folders for partitions.
part1_gen = SetValuesGen(IntegerType(), [-10, -1, 0, 1, 10])
@@ -127,3 +129,82 @@ def test_prune_partition_column_when_filter_fallback_project(spark_tmp_path, pru
                                                              filter_col, file_format):
    do_prune_partition_column_when_filter_project(spark_tmp_path, prune_part_enabled, file_format,
                                                  filter_col, gpu_project_enabled=False)

# This method generates two contact tables and writes them out as Parquet/ORC files (partitioned
# when requested). The files are then read back using the read function that is passed in.
def create_contacts_table_and_read(is_partitioned, format, data_path, expected_schemata, func, conf, table_name):
    full_name_type = StructGen([('first', StringGen()), ('middle', StringGen()), ('last', StringGen())])
    name_type = StructGen([('first', StringGen()), ('last', StringGen())])
    contacts_data_gen = StructGen([
        ('id', IntegerGen()),
        ('name', full_name_type),
        ('address', StringGen()),
        ('friends', ArrayGen(full_name_type, max_length=10, nullable=False))], nullable=False)

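    # A narrower contacts schema, with no middle name and no friends column, used for the second
    # table/partition.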
    brief_contacts_data_gen = StructGen([
        ('id', IntegerGen()),
        ('name', name_type),
        ('address', StringGen())], nullable=False)

    # We are adding the field 'p' twice just like it is done in the Spark tests
    # https://github.com/apache/spark/blob/85e252e8503534009f4fb5ea005d44c9eda31447/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala#L193
    def contact_gen_df(spark, data_gen, partition):
        gen = gen_df(spark, data_gen)
        if is_partitioned:
            return gen.withColumn('p', f.lit(partition))
        else:
            return gen

    with_cpu_session(lambda spark: contact_gen_df(spark, contacts_data_gen, 1).write.format(format).save(data_path + f"/{table_name}/p=1"))
    with_cpu_session(lambda spark: contact_gen_df(spark, brief_contacts_data_gen, 2).write.format(format).save(data_path + f"/{table_name}/p=2"))

    # Schema to read in.
    read_schema = contacts_data_gen.data_type.add("p", IntegerType(), True) if is_partitioned else contacts_data_gen.data_type

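    # Run the same query on the CPU and on the GPU, capturing both the collected results and the
    # DataFrames so the executed plans can be inspected afterwards.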
    (from_cpu, cpu_df), (from_gpu, gpu_df) = run_with_cpu_and_gpu(
        func(read_schema),
        'COLLECT_WITH_DATAFRAME',
        conf=conf)

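    # Ask the JVM-side ExecutionPlanCaptureCallback to verify that the file scans in both plans
    # pruned the read schema down to the expected schema, then check that CPU and GPU results match.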
    jvm = spark_jvm()
    jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertSchemataMatch(cpu_df._jdf, gpu_df._jdf, expected_schemata)
    assert_equal(from_cpu, from_gpu)

# https://github.com/NVIDIA/spark-rapids/issues/8712
# https://github.com/NVIDIA/spark-rapids/issues/8713
# https://github.com/NVIDIA/spark-rapids/issues/8714
@pytest.mark.parametrize('query_and_expected_schemata', [("select friends.middle, friends from {} where p=1", "struct<friends:array<struct<first:string,middle:string,last:string>>>"),
                                                         pytest.param(("select name.middle, address from {} where p=2", "struct<name:struct<middle:string>,address:string>"), marks=pytest.mark.skip(reason='https://github.com/NVIDIA/spark-rapids/issues/8788')),
                                                         ("select name.first from {} where name.first = 'Jane'", "struct<name:struct<first:string>>")])
@pytest.mark.parametrize('is_partitioned', [True, False])
@pytest.mark.parametrize('format', ["parquet", "orc"])
def test_select_complex_field(format, spark_tmp_path, query_and_expected_schemata, is_partitioned, spark_tmp_table_factory):
    table_name = spark_tmp_table_factory.get()
    query, expected_schemata = query_and_expected_schemata
    data_path = spark_tmp_path + "/DATA"
    def read_temp_view(schema):
        def do_it(spark):
            spark.read.format(format).schema(schema).load(data_path + f"/{table_name}").createOrReplaceTempView(table_name)
            return spark.sql(query.format(table_name))
        return do_it
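    # Explicitly enable Spark's vectorized Parquet reader for the CPU run (this config only affects
    # the parquet format).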
conf={"spark.sql.parquet.enableVectorizedReader": "true"}
create_contacts_table_and_read(is_partitioned, format, data_path, expected_schemata, read_temp_view, conf, table_name)

# https://github.com/NVIDIA/spark-rapids/issues/8715
@pytest.mark.parametrize('select_and_expected_schemata', [("friend.First", "struct<friends:array<struct<first:string>>>"),
                                                          ("friend.MIDDLE", "struct<friends:array<struct<middle:string>>>")])
@pytest.mark.skipif(is_before_spark_320(), reason='https://issues.apache.org/jira/browse/SPARK-34638')
@pytest.mark.parametrize('is_partitioned', [True, False])
@pytest.mark.parametrize('format', ["parquet", "orc"])
def test_nested_column_prune_on_generator_output(format, spark_tmp_path, select_and_expected_schemata, is_partitioned, spark_tmp_table_factory):
    table_name = spark_tmp_table_factory.get()
    query, expected_schemata = select_and_expected_schemata
    data_path = spark_tmp_path + "/DATA"
    def read_temp_view(schema):
        def do_it(spark):
            spark.read.format(format).schema(schema).load(data_path + f"/{table_name}").createOrReplaceTempView(table_name)
            return spark.table(table_name).select(f.explode(f.col("friends")).alias("friend")).select(query)
        return do_it
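    # The selected columns use mixed case ('First', 'MIDDLE'), so run with case-insensitive analysis
    # to make sure nested-column pruning still resolves them after the explode.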
conf = {"spark.sql.caseSensitive": "false",
"spark.sql.parquet.enableVectorizedReader": "true"}
create_contacts_table_and_read(is_partitioned, format, data_path, expected_schemata, read_temp_view, conf, table_name)
40 changes: 38 additions & 2 deletions .../ExecutionPlanCaptureCallback.scala
@@ -25,11 +25,11 @@ import com.nvidia.spark.rapids.{PlanShims, PlanUtils}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.{ExecSubqueryExpression, QueryExecution, ReusedSubqueryExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.util.QueryExecutionListener

object ExecutionPlanCaptureCallback {
object ExecutionPlanCaptureCallback extends AdaptiveSparkPlanHelper {
  private[this] var shouldCapture: Boolean = false
  private[this] val execPlans: ArrayBuffer[SparkPlan] = ArrayBuffer.empty

@@ -83,6 +83,42 @@ object ExecutionPlanCaptureCallback {
    fallbackCpuClassList.foreach(fallbackCpuClass => assertDidFallBack(gpuPlans, fallbackCpuClass))
  }

  /**
   * This method is used by the Python integration tests.
   * It checks the schemata used in the GPU and CPU executed plans and compares them to the
   * expected schema to make sure we are not reading more data than needed.
   */
  def assertSchemataMatch(cpuDf: DataFrame, gpuDf: DataFrame, expectedSchema: String): Unit = {
    import org.apache.spark.sql.execution.FileSourceScanExec
    import org.apache.spark.sql.types.StructType
    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

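    // Collect the pruned read schema (requiredSchema) from every file scan in each executed plan.
    // The collect here comes from AdaptiveSparkPlanHelper, so scans nested inside adaptive query
    // stages are visited as well.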
    val cpuFileSourceScanSchemata = collect(cpuDf.queryExecution.executedPlan) {
      case scan: FileSourceScanExec => scan.requiredSchema
    }
    val gpuFileSourceScanSchemata = collect(gpuDf.queryExecution.executedPlan) {
      case scan: GpuFileSourceScanExec => scan.requiredSchema
    }
    assert(cpuFileSourceScanSchemata.size == gpuFileSourceScanSchemata.size,
      s"Found ${cpuFileSourceScanSchemata.size} file source scans in the CPU plan, " +
      s"but ${gpuFileSourceScanSchemata.size} in the GPU plan")

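    // Pair each CPU scan schema with the corresponding GPU scan schema and verify that both match
    // the expected schema parsed from its DDL string.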
    cpuFileSourceScanSchemata.zip(gpuFileSourceScanSchemata).foreach {
      case (cpuScanSchema, gpuScanSchema) =>
        cpuScanSchema match {
          case structType: StructType =>
            assert(gpuScanSchema.sameType(structType))
            val expectedStructType = CatalystSqlParser.parseDataType(expectedSchema)
            assert(gpuScanSchema.sameType(expectedStructType),
              s"GPU schema ${gpuScanSchema.toDDL} doesn't match $expectedSchema")
            assert(cpuScanSchema.sameType(expectedStructType),
              s"CPU schema ${cpuScanSchema.toDDL} doesn't match $expectedSchema")
          case otherType => assert(false, s"Expected the CPU scan schema $cpuScanSchema" +
            s" to be a StructType, but found $otherType")
        }
    }
  }

  def assertCapturedAndGpuFellBack(fallbackCpuClass: String, timeoutMs: Long = 2000): Unit = {
    val gpuPlans = getResultsWithTimeout(timeoutMs = timeoutMs)
    assert(gpuPlans.nonEmpty, "Did not capture a plan")
