From bf30d6f77ea461e490554c3630564bb153883518 Mon Sep 17 00:00:00 2001
From: Ryan Lee
Date: Wed, 16 Oct 2024 17:07:26 -0700
Subject: [PATCH] spark 4 parquet writer test initial fixes

Signed-off-by: Ryan Lee
---
 integration_tests/run_pyspark_from_build.sh   |  4 +--
 .../src/main/python/parquet_write_test.py     | 32 +++++++++++--------
 .../com/nvidia/spark/rapids/ShimLoader.scala  |  3 ++
 3 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/integration_tests/run_pyspark_from_build.sh b/integration_tests/run_pyspark_from_build.sh
index 22a23349791..9bd72b2ada0 100755
--- a/integration_tests/run_pyspark_from_build.sh
+++ b/integration_tests/run_pyspark_from_build.sh
@@ -28,10 +28,10 @@ else
     [[ ! -x "$(command -v zip)" ]] && { echo "fail to find zip command in $PATH"; exit 1; }
     PY4J_TMP=("${SPARK_HOME}"/python/lib/py4j-*-src.zip)
     PY4J_FILE=${PY4J_TMP[0]}
-    # PySpark uses ".dev0" for "-SNAPSHOT", ".dev" for "preview"
+    # PySpark uses ".dev0" for "-SNAPSHOT" and either ".dev" for "preview" or ".devN" for "previewN"
     # https://github.com/apache/spark/blob/66f25e314032d562567620806057fcecc8b71f08/dev/create-release/release-build.sh#L267
     VERSION_STRING=$(PYTHONPATH=${SPARK_HOME}/python:${PY4J_FILE} python -c \
-        "import pyspark, re; print(re.sub('\.dev0?$', '', pyspark.__version__))"
+        "import pyspark, re; print(re.sub('\.dev[012]?$', '', pyspark.__version__))"
     )

     SCALA_VERSION=`$SPARK_HOME/bin/pyspark --version 2>&1| grep Scala | awk '{split($4,v,"."); printf "%s.%s", v[1], v[2]}'`
diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py
index 805a0b8137c..ed317f2e282 100644
--- a/integration_tests/src/main/python/parquet_write_test.py
+++ b/integration_tests/src/main/python/parquet_write_test.py
@@ -37,8 +37,11 @@
 reader_opt_confs = [original_parquet_file_reader_conf, multithreaded_parquet_file_reader_conf,
                     coalesce_parquet_file_reader_conf]
 parquet_decimal_struct_gen= StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(decimal_gens)])
-writer_confs={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
-              'spark.sql.legacy.parquet.int96RebaseModeInWrite': 'CORRECTED'}
+legacy_parquet_datetimeRebaseModeInWrite='spark.sql.parquet.datetimeRebaseModeInWrite' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite'
+legacy_parquet_int96RebaseModeInWrite='spark.sql.parquet.int96RebaseModeInWrite' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.int96RebaseModeInWrite'
+legacy_parquet_int96RebaseModeInRead='spark.sql.parquet.int96RebaseModeInRead' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.int96RebaseModeInRead'
+writer_confs={legacy_parquet_datetimeRebaseModeInWrite: 'CORRECTED',
+              legacy_parquet_int96RebaseModeInWrite: 'CORRECTED'}
 parquet_basic_gen =[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
                     string_gen, boolean_gen, date_gen,
                     TimestampGen(), binary_gen]
@@ -158,8 +161,8 @@ def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase):
         lambda spark, path: unary_op_df(spark, gen).write.parquet(path),
         lambda spark, path: spark.read.parquet(path),
         data_path,
-        conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
-              'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase,
+        conf={legacy_parquet_datetimeRebaseModeInWrite: ts_rebase,
+              legacy_parquet_int96RebaseModeInWrite: ts_rebase,
               'spark.sql.parquet.outputTimestampType': ts_type})


@@ -285,8 +288,8 @@ def test_write_sql_save_table(spark_tmp_path, parquet_gens, spark_tmp_table_fact

 def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, int96_rebase, datetime_rebase, ts_write):
     spark.conf.set('spark.sql.parquet.outputTimestampType', ts_write)
-    spark.conf.set('spark.sql.legacy.parquet.datetimeRebaseModeInWrite', datetime_rebase)
-    spark.conf.set('spark.sql.legacy.parquet.int96RebaseModeInWrite', int96_rebase) # for spark 310
+    spark.conf.set(legacy_parquet_datetimeRebaseModeInWrite, datetime_rebase)
+    spark.conf.set(legacy_parquet_int96RebaseModeInWrite, int96_rebase) # for spark 310
     with pytest.raises(Exception) as e_info:
         df.coalesce(1).write.format("parquet").mode('overwrite').option("path", data_path).saveAsTable(spark_tmp_table_factory.get())
     assert e_info.match(r".*SparkUpgradeException.*")
@@ -544,8 +547,8 @@ def generate_map_with_empty_validity(spark, path):
 def test_parquet_write_fails_legacy_datetime(spark_tmp_path, data_gen, ts_write, ts_rebase_write):
     data_path = spark_tmp_path + '/PARQUET_DATA'
     all_confs = {'spark.sql.parquet.outputTimestampType': ts_write,
-                 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase_write,
-                 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase_write}
+                 legacy_parquet_datetimeRebaseModeInWrite: ts_rebase_write,
+                 legacy_parquet_int96RebaseModeInWrite: ts_rebase_write}
     def writeParquetCatchException(spark, data_gen, data_path):
         with pytest.raises(Exception) as e_info:
             unary_op_df(spark, data_gen).coalesce(1).write.parquet(data_path)
@@ -563,12 +566,12 @@ def test_parquet_write_roundtrip_datetime_with_legacy_rebase(spark_tmp_path, dat
                                                              ts_rebase_write, ts_rebase_read):
     data_path = spark_tmp_path + '/PARQUET_DATA'
     all_confs = {'spark.sql.parquet.outputTimestampType': ts_write,
-                 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase_write[0],
-                 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase_write[1],
+                 legacy_parquet_datetimeRebaseModeInWrite: ts_rebase_write[0],
+                 legacy_parquet_int96RebaseModeInWrite: ts_rebase_write[1],
                  # The rebase modes in read configs should be ignored and overridden by the same
                  # modes in write configs, which are retrieved from the written files.
                  'spark.sql.legacy.parquet.datetimeRebaseModeInRead': ts_rebase_read[0],
-                 'spark.sql.legacy.parquet.int96RebaseModeInRead': ts_rebase_read[1]}
+                 legacy_parquet_int96RebaseModeInRead: ts_rebase_read[1]}
     assert_gpu_and_cpu_writes_are_equal_collect(
         lambda spark, path: unary_op_df(spark, data_gen).coalesce(1).write.parquet(path),
         lambda spark, path: spark.read.parquet(path),
         data_path,
@@ -597,7 +600,8 @@ def test_it(spark):
             spark.sql("CREATE TABLE {} LOCATION '{}/ctas' AS SELECT * FROM {}".format(
                 ctas_with_existing_name, data_path, src_name))
         except pyspark.sql.utils.AnalysisException as e:
-            if allow_non_empty or e.desc.find('non-empty directory') == -1:
+            description= e._desc if is_spark_400_or_later() else e.desc
+            if allow_non_empty or description.find('non-empty directory') == -1:
                 raise e

     with_gpu_session(test_it, conf)
@@ -825,8 +829,8 @@ def write_partitions(spark, table_path):
 )
 def hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, ts_rebase, func):
-    conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
-          'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase}
+    conf={legacy_parquet_datetimeRebaseModeInWrite: ts_rebase,
+          legacy_parquet_int96RebaseModeInWrite: ts_rebase}

     def create_table(spark, path):
         tmp_table = spark_tmp_table_factory.get()
diff --git a/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala b/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala
index bc35dad5372..699a5b5ee8b 100644
--- a/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala
+++ b/sql-plugin-api/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala
@@ -310,6 +310,9 @@ object ShimLoader {
     // hack for databricks, try to find something more reliable?
     if (SPARK_BUILD_USER.equals("Databricks")) {
       SPARK_VERSION + "-databricks"
+    // hack for preview versions to act as the base version
+    } else if (SPARK_VERSION matches """\d+\.\d+\.\d+-preview\d+""") {
+      SPARK_VERSION.split('-')(0)
     } else {
       SPARK_VERSION
     }
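
For reference, the run_pyspark_from_build.sh change widens the version-normalization regex so that PySpark preview builds (".devN" suffixes) are stripped the same way snapshot builds (".dev0") already were. A minimal standalone sketch of the intended behavior; the sample version strings are illustrative assumptions, not values taken from the patch:

```python
# Sketch only: exercises the updated pattern outside the build script.
import re

def normalize(pyspark_version: str) -> str:
    # ".dev0" marks a -SNAPSHOT build; ".dev"/".devN" mark preview/previewN builds.
    return re.sub(r'\.dev[012]?$', '', pyspark_version)

for v in ['3.5.1.dev0', '4.0.0.dev', '4.0.0.dev1', '4.0.0.dev2', '3.5.1']:
    print(f'{v} -> {normalize(v)}')
# e.g. 4.0.0.dev1 -> 4.0.0, while 3.5.1 passes through unchanged
```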
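The parquet_write_test.py hunks all follow one pattern: the patch picks the `spark.sql.parquet.*` rebase-mode config names when running against Spark 4.0 or later, and falls back to the `spark.sql.legacy.parquet.*` names otherwise, computing each key once at module level. A minimal sketch of that selection logic; `rebase_conf_key` is a hypothetical helper and the boolean argument stands in for the integration tests' is_spark_400_or_later():

```python
# Sketch only: not the test module itself.
def rebase_conf_key(name: str, spark_is_400_or_later: bool) -> str:
    # Spark 4.0+ drops the "legacy" segment from these config names.
    prefix = 'spark.sql.parquet.' if spark_is_400_or_later else 'spark.sql.legacy.parquet.'
    return prefix + name

writer_confs = {rebase_conf_key('datetimeRebaseModeInWrite', True): 'CORRECTED',
                rebase_conf_key('int96RebaseModeInWrite', True): 'CORRECTED'}
print(writer_confs)
# {'spark.sql.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
#  'spark.sql.parquet.int96RebaseModeInWrite': 'CORRECTED'}
```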
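The ShimLoader.scala hunk makes a `-previewN` Spark version act as its base version so a preview build resolves to the same shim. The actual change is Scala; the sketch below re-creates the same regex-and-split logic in Python purely to show how the mapping behaves, with sample versions that are assumptions:

```python
# Sketch only: Python re-creation of the Scala preview-version check, not plugin code.
import re

def shim_spark_version(spark_version: str) -> str:
    # Mirrors the Scala branch: SPARK_VERSION matches """\d+\.\d+\.\d+-preview\d+"""
    # followed by SPARK_VERSION.split('-')(0).
    if re.fullmatch(r'\d+\.\d+\.\d+-preview\d+', spark_version):
        return spark_version.split('-')[0]
    return spark_version

for v in ['4.0.0-preview1', '4.0.0-preview2', '3.5.1']:
    print(f'{v} -> {shim_spark_version(v)}')
# 4.0.0-preview1 -> 4.0.0, 3.5.1 -> 3.5.1
```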