diff --git a/integration_tests/src/main/python/asserts.py b/integration_tests/src/main/python/asserts.py
index bc564e8fa3f..8e5100f1d62 100644
--- a/integration_tests/src/main/python/asserts.py
+++ b/integration_tests/src/main/python/asserts.py
@@ -301,16 +301,19 @@ def assert_gpu_and_cpu_are_equal_iterator(func, conf={}):
     _assert_gpu_and_cpu_are_equal(func, False, conf=conf)
 
-def assert_gpu_and_cpu_are_equal_sql(df, tableName, sql, conf=None):
+def assert_gpu_and_cpu_are_equal_sql(df_fun, table_name, sql, conf=None):
     """
     Assert that the specified SQL query produces equal results on CPU and GPU.
-    :param df: Input dataframe
-    :param tableName: Name of table to be created with the dataframe
+    :param df_fun: a function that will create the dataframe
+    :param table_name: Name of table to be created with the dataframe
     :param sql: SQL query to be run on the specified table
     :param conf: Any user-specified confs. Empty by default.
     :return: Assertion failure, if results from CPU and GPU do not match.
     """
     if conf is None:
         conf = {}
-    df.createOrReplaceTempView(tableName)
-    assert_gpu_and_cpu_are_equal_collect(lambda spark: spark.sql(sql), conf)
+    def do_it_all(spark):
+        df = df_fun(spark)
+        df.createOrReplaceTempView(table_name)
+        return spark.sql(sql)
+    assert_gpu_and_cpu_are_equal_collect(do_it_all, conf)
 
diff --git a/integration_tests/src/main/python/conftest.py b/integration_tests/src/main/python/conftest.py
index 506e8be8be5..4bd04049f8b 100644
--- a/integration_tests/src/main/python/conftest.py
+++ b/integration_tests/src/main/python/conftest.py
@@ -14,7 +14,7 @@
 import pytest
 import random
 
-from spark_init_internal import spark
+from spark_init_internal import get_spark_i_know_what_i_am_doing
 from pyspark.sql.dataframe import DataFrame
 
 _approximate_float_args = None
@@ -175,7 +175,7 @@ def spark_tmp_path(request):
     ret = '/tmp/pyspark_tests/'
     ret = ret + '/' + str(random.randint(0, 1000000)) + '/'
     # Make sure it is there and accessible
-    sc = spark.sparkContext
+    sc = get_spark_i_know_what_i_am_doing().sparkContext
     config = sc._jsc.hadoopConfiguration()
     path = sc._jvm.org.apache.hadoop.fs.Path(ret)
     fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(config)
@@ -191,13 +191,13 @@ def _get_jvm(spark):
     return spark.sparkContext._jvm
 
 def spark_jvm():
-    return _get_jvm(spark)
+    return _get_jvm(get_spark_i_know_what_i_am_doing())
 
 class TpchRunner:
     def __init__(self, tpch_format, tpch_path):
         self.tpch_format = tpch_format
         self.tpch_path = tpch_path
-        self.setup(spark)
+        self.setup(get_spark_i_know_what_i_am_doing())
 
     def setup(self, spark):
         jvm_session = _get_jvm_session(spark)
@@ -210,6 +210,7 @@ def setup(self, spark):
         formats.get(self.tpch_format)(jvm_session, self.tpch_path)
 
     def do_test_query(self, query):
+        spark = get_spark_i_know_what_i_am_doing()
         jvm_session = _get_jvm_session(spark)
         jvm = _get_jvm(spark)
         tests = {
@@ -256,7 +257,7 @@ class TpcxbbRunner:
     def __init__(self, tpcxbb_format, tpcxbb_path):
         self.tpcxbb_format = tpcxbb_format
         self.tpcxbb_path = tpcxbb_path
-        self.setup(spark)
+        self.setup(get_spark_i_know_what_i_am_doing())
 
     def setup(self, spark):
         jvm_session = _get_jvm_session(spark)
@@ -269,6 +270,7 @@ def setup(self, spark):
         formats.get(self.tpcxbb_format)(jvm_session,self.tpcxbb_path)
 
     def do_test_query(self, query):
+        spark = get_spark_i_know_what_i_am_doing()
         jvm_session = _get_jvm_session(spark)
         jvm = _get_jvm(spark)
         tests = {
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index b8230dde452..c9d11b2da07 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -19,7 +19,7 @@
 import pyspark.sql.functions as f
 import pytest
 import random
-from spark_session import spark, is_tz_utc
+from spark_session import is_tz_utc
 import sre_yield
 import struct
 
@@ -489,8 +489,8 @@ def gen_array():
             return [self._child_gen.gen() for _ in range(0, length)]
         self._start(rand, gen_array)
 
-def skip_if_not_utc(spark):
-    if (not is_tz_utc(spark)):
+def skip_if_not_utc():
+    if (not is_tz_utc()):
         pytest.skip('The java system time zone is not set to UTC')
 
 def gen_df(spark, data_gen, length=2048, seed=0):
@@ -504,7 +504,7 @@ def gen_df(spark, data_gen, length=2048, seed=0):
 
     # Before we get too far we need to verify that we can run with timestamps
     if src.contains_ts():
-        skip_if_not_utc(spark)
+        skip_if_not_utc()
 
     rand = random.Random(seed)
     src.start(rand)
@@ -525,7 +525,7 @@ def _gen_scalars_common(data_gen, count, seed=0):
 
     # Before we get too far we need to verify that we can run with timestamps
    if src.contains_ts():
-        skip_if_not_utc(spark)
+        skip_if_not_utc()
 
     rand = random.Random(seed)
     src.start(rand)
diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py
index 2cad2ee7cc1..e0864a77377 100644
--- a/integration_tests/src/main/python/hash_aggregate_test.py
+++ b/integration_tests/src/main/python/hash_aggregate_test.py
@@ -19,7 +19,7 @@
 from pyspark.sql.types import *
 from marks import *
 import pyspark.sql.functions as f
-from spark_session import with_cpu_session, with_spark_session
+from spark_session import with_spark_session
 
 _no_nans_float_conf = {'spark.rapids.sql.variableFloatAgg.enabled': 'true',
                        'spark.rapids.sql.hasNans': 'false',
@@ -276,7 +276,7 @@ def test_hash_count_with_filter(data_gen, conf):
 @pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
 def test_hash_multiple_filters(data_gen, conf):
     assert_gpu_and_cpu_are_equal_sql(
-        with_cpu_session(lambda spark : gen_df(spark, data_gen, length=100)),
+        lambda spark : gen_df(spark, data_gen, length=100),
         "hash_agg_table",
         'select count(a) filter (where c > 50),' +
         'count(b) filter (where c > 100),' +
@@ -296,7 +296,7 @@ def test_hash_multiple_filters(data_gen, conf):
 @pytest.mark.parametrize('data_gen', [_longs_with_nulls], ids=idfn)
 def test_hash_multiple_filters_fail(data_gen):
     assert_gpu_and_cpu_are_equal_sql(
-        with_cpu_session(lambda spark : gen_df(spark, data_gen, length=100)),
+        lambda spark : gen_df(spark, data_gen, length=100),
         "hash_agg_table",
         'select avg(b) filter (where b > 20) from hash_agg_table group by a',
         _no_nans_float_conf_partial)
@@ -318,7 +318,7 @@ def test_hash_query_max_bug(data_gen):
                                       _grpkey_doubles_with_nan_zero_grouping_keys], ids=idfn)
 def test_hash_agg_with_nan_keys(data_gen):
     assert_gpu_and_cpu_are_equal_sql(
-        with_cpu_session(lambda spark : gen_df(spark, data_gen, length=1024)),
+        lambda spark : gen_df(spark, data_gen, length=1024),
         "hash_agg_table",
         'select a, '
         'count(*) as count_stars, '
@@ -342,7 +342,7 @@ def test_hash_agg_with_nan_keys(data_gen):
 @pytest.mark.parametrize('data_gen', [ _grpkey_doubles_with_nan_zero_grouping_keys], ids=idfn)
 def test_count_distinct_with_nan_floats(data_gen):
     assert_gpu_and_cpu_are_equal_sql(
-        with_cpu_session(lambda spark : gen_df(spark, data_gen, length=1024)),
+        lambda spark : gen_df(spark, data_gen, length=1024),
         "hash_agg_table",
         'select a, count(distinct b) as count_distinct_bees from hash_agg_table group by a',
         _no_nans_float_conf)
diff --git a/integration_tests/src/main/python/spark_init_internal.py b/integration_tests/src/main/python/spark_init_internal.py
index 08e965f7446..61fba592f11 100644
--- a/integration_tests/src/main/python/spark_init_internal.py
+++ b/integration_tests/src/main/python/spark_init_internal.py
@@ -28,5 +28,14 @@ def _spark__init():
     _s.sparkContext.setLogLevel("WARN")
     return _s
 
 
-spark = _spark__init()
+_spark = _spark__init()
+
+def get_spark_i_know_what_i_am_doing():
+    """
+    Get the current SparkSession.
+    This should almost never be called directly; instead, you should call
+    with_spark_session, with_cpu_session, or with_gpu_session from spark_session.
+    This is to guarantee that the session and its config are set up in a repeatable way.
+    """
+    return _spark
diff --git a/integration_tests/src/main/python/spark_session.py b/integration_tests/src/main/python/spark_session.py
index cbaa74fa01c..1a4b8a7ba29 100644
--- a/integration_tests/src/main/python/spark_session.py
+++ b/integration_tests/src/main/python/spark_session.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 
 from conftest import is_allowing_any_non_gpu, get_non_gpu_allowed
-from pyspark.sql import SparkSession
-from spark_init_internal import spark as internal_spark
+from pyspark.sql import SparkSession, DataFrame
+from spark_init_internal import get_spark_i_know_what_i_am_doing
 
 def _from_scala_map(scala_map):
     ret = {}
@@ -25,12 +25,12 @@ def _from_scala_map(scala_map):
         ret[key] = scala_map.get(key).get()
     return ret
 
-spark = internal_spark
+_spark = get_spark_i_know_what_i_am_doing()
 # Have to reach into a private member to get access to the API we need
-_orig_conf = _from_scala_map(spark.conf._jconf.getAll())
+_orig_conf = _from_scala_map(_spark.conf._jconf.getAll())
 _orig_conf_keys = _orig_conf.keys()
 
-def is_tz_utc(spark=spark):
+def is_tz_utc(spark=_spark):
     """
     true if the tz is UTC else false
     """
@@ -42,25 +42,32 @@ def is_tz_utc(spark=spark):
 
 def _set_all_confs(conf):
     for key, value in conf.items():
-        if spark.conf.get(key, None) != value:
-            spark.conf.set(key, value)
+        if _spark.conf.get(key, None) != value:
+            _spark.conf.set(key, value)
 
 def reset_spark_session_conf():
     """Reset all of the configs for a given spark session."""
     _set_all_confs(_orig_conf)
     #We should clear the cache
-    spark.catalog.clearCache()
+    _spark.catalog.clearCache()
     # Have to reach into a private member to get access to the API we need
-    current_keys = _from_scala_map(spark.conf._jconf.getAll()).keys()
+    current_keys = _from_scala_map(_spark.conf._jconf.getAll()).keys()
     for key in current_keys:
         if key not in _orig_conf_keys:
-            spark.conf.unset(key)
+            _spark.conf.unset(key)
+
+def _check_for_proper_return_values(something):
+    """We don't want to return a DataFrame or Dataset from a with_spark_session. You will not get what you expect."""
+    if (isinstance(something, DataFrame)):
+        raise RuntimeError("You should never return a DataFrame from a with_*_session, you will not get the results that you expect")
 
 def with_spark_session(func, conf={}):
     """Run func that takes a spark session as input with the given configs set."""
     reset_spark_session_conf()
     _set_all_confs(conf)
-    return func(spark)
+    ret = func(_spark)
+    _check_for_proper_return_values(ret)
+    return ret
 
 def with_cpu_session(func, conf={}):
     """Run func that takes a spark session as input with the given configs set on the CPU."""
diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py
index d3e52021460..360e0cd1781 100644
--- a/integration_tests/src/main/python/window_function_test.py
+++ b/integration_tests/src/main/python/window_function_test.py
@@ -18,7 +18,6 @@
 from data_gen import *
 from pyspark.sql.types import *
 from marks import *
-from spark_session import with_cpu_session
 
 _grpkey_longs_with_no_nulls = [
     ('a', RepeatSeqGen(LongGen(nullable=False), length=20)),
@@ -48,7 +47,7 @@
                                       _grpkey_longs_with_nullable_timestamps], ids=idfn)
 def test_window_aggs_for_rows(data_gen):
     assert_gpu_and_cpu_are_equal_sql(
-        with_cpu_session(lambda spark : gen_df(spark, data_gen, length=2048)),
+        lambda spark : gen_df(spark, data_gen, length=2048),
         "window_agg_table",
         'select '
         ' sum(c) over '
@@ -72,7 +71,7 @@
                                       _grpkey_longs_with_nullable_timestamps], ids=idfn)
 def test_window_aggs_for_ranges(data_gen):
     assert_gpu_and_cpu_are_equal_sql(
-        with_cpu_session(lambda spark: gen_df(spark, data_gen, length=2048)),
+        lambda spark: gen_df(spark, data_gen, length=2048),
         "window_agg_table",
         'select '
         ' sum(c) over '
@@ -102,7 +101,7 @@
 @pytest.mark.parametrize('data_gen', [_grpkey_longs_with_timestamps], ids=idfn)
 def test_window_aggs_for_ranges_of_dates(data_gen):
     assert_gpu_and_cpu_are_equal_sql(
-        with_cpu_session(lambda spark: gen_df(spark, data_gen, length=2048)),
+        lambda spark: gen_df(spark, data_gen, length=2048),
         "window_agg_table",
         'select '
         ' sum(c) over '
@@ -118,7 +117,7 @@
 @pytest.mark.parametrize('data_gen', [_grpkey_longs_with_no_nulls], ids=idfn)
 def test_window_aggs_for_rows_count_non_null(data_gen):
     assert_gpu_and_cpu_are_equal_sql(
-        with_cpu_session(lambda spark: gen_df(spark, data_gen, length=2048)),
+        lambda spark: gen_df(spark, data_gen, length=2048),
         "window_agg_table",
         'select '
         ' count(c) over '
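
A minimal usage sketch of the new assert_gpu_and_cpu_are_equal_sql contract (illustrative only: the test name, table name, and generator list below are hypothetical stand-ins modeled on the tests in this patch). Because the first argument is now a DataFrame-producing function rather than a pre-built DataFrame, do_it_all invokes it separately under the CPU run and the GPU run, so the temp view is registered against the session that actually executes the query:

    # Hypothetical test showing the new calling convention. The lambda is
    # deferred: do_it_all calls it once per session, so the data and temp view
    # are rebuilt for both the CPU and GPU runs instead of being captured
    # eagerly from an earlier with_cpu_session call.
    from asserts import assert_gpu_and_cpu_are_equal_sql
    from data_gen import gen_df, LongGen

    def test_simple_count():
        data_gen = [('a', LongGen()), ('b', LongGen())]
        assert_gpu_and_cpu_are_equal_sql(
            lambda spark: gen_df(spark, data_gen, length=100),
            "simple_count_table",
            'select count(a) from simple_count_table')

Relatedly, a function passed to one of the with_*_session helpers must not return the DataFrame itself; the new _check_for_proper_return_values guard raises a RuntimeError if it does.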