Prevent bad practice in python tests #482

Merged · 2 commits · Jul 31, 2020
13 changes: 8 additions & 5 deletions integration_tests/src/main/python/asserts.py
@@ -301,16 +301,19 @@ def assert_gpu_and_cpu_are_equal_iterator(func, conf={}):
_assert_gpu_and_cpu_are_equal(func, False, conf=conf)


-def assert_gpu_and_cpu_are_equal_sql(df, tableName, sql, conf=None):
+def assert_gpu_and_cpu_are_equal_sql(df_fun, table_name, sql, conf=None):
"""
Assert that the specified SQL query produces equal results on CPU and GPU.
-:param df: Input dataframe
-:param tableName: Name of table to be created with the dataframe
+:param df_fun: a function that will create the dataframe
+:param table_name: Name of table to be created with the dataframe
:param sql: SQL query to be run on the specified table
:param conf: Any user-specified confs. Empty by default.
:return: Assertion failure, if results from CPU and GPU do not match.
"""
if conf is None:
conf = {}
-df.createOrReplaceTempView(tableName)
-assert_gpu_and_cpu_are_equal_collect(lambda spark: spark.sql(sql), conf)
+def do_it_all(spark):
+    df = df_fun(spark)
+    df.createOrReplaceTempView(table_name)
+    return spark.sql(sql)
+assert_gpu_and_cpu_are_equal_collect(do_it_all, conf)
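With this change the SQL assertion takes a function that builds the DataFrame rather than an already-built DataFrame, so the input data is created inside the same run that executes the query on both the CPU and the GPU. A minimal sketch of the new calling convention, assuming the integration tests' data_gen helpers (gen_df, LongGen); the test name, column, and table name are illustrative:

from asserts import assert_gpu_and_cpu_are_equal_sql
from data_gen import gen_df, LongGen

def test_example_count():
    assert_gpu_and_cpu_are_equal_sql(
        # pass a DataFrame-producing function, not a DataFrame
        lambda spark: gen_df(spark, [('a', LongGen())], length=100),
        "example_table",
        'select count(a) from example_table')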
12 changes: 7 additions & 5 deletions integration_tests/src/main/python/conftest.py
@@ -14,7 +14,7 @@

import pytest
import random
-from spark_init_internal import spark
+from spark_init_internal import get_spark_i_know_what_i_am_doing
from pyspark.sql.dataframe import DataFrame

_approximate_float_args = None
@@ -175,7 +175,7 @@ def spark_tmp_path(request):
ret = '/tmp/pyspark_tests/'
ret = ret + '/' + str(random.randint(0, 1000000)) + '/'
# Make sure it is there and accessible
-sc = spark.sparkContext
+sc = get_spark_i_know_what_i_am_doing().sparkContext
config = sc._jsc.hadoopConfiguration()
path = sc._jvm.org.apache.hadoop.fs.Path(ret)
fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(config)
@@ -191,13 +191,13 @@ def _get_jvm(spark):
return spark.sparkContext._jvm

def spark_jvm():
-return _get_jvm(spark)
+return _get_jvm(get_spark_i_know_what_i_am_doing())

class TpchRunner:
def __init__(self, tpch_format, tpch_path):
self.tpch_format = tpch_format
self.tpch_path = tpch_path
-self.setup(spark)
+self.setup(get_spark_i_know_what_i_am_doing())

def setup(self, spark):
jvm_session = _get_jvm_session(spark)
@@ -210,6 +210,7 @@ def setup(self, spark):
formats.get(self.tpch_format)(jvm_session, self.tpch_path)

def do_test_query(self, query):
+spark = get_spark_i_know_what_i_am_doing()
jvm_session = _get_jvm_session(spark)
jvm = _get_jvm(spark)
tests = {
@@ -256,7 +257,7 @@ class TpcxbbRunner:
def __init__(self, tpcxbb_format, tpcxbb_path):
self.tpcxbb_format = tpcxbb_format
self.tpcxbb_path = tpcxbb_path
-self.setup(spark)
+self.setup(get_spark_i_know_what_i_am_doing())

def setup(self, spark):
jvm_session = _get_jvm_session(spark)
@@ -269,6 +270,7 @@ def setup(self, spark):
formats.get(self.tpcxbb_format)(jvm_session,self.tpcxbb_path)

def do_test_query(self, query):
+spark = get_spark_i_know_what_i_am_doing()
jvm_session = _get_jvm_session(spark)
jvm = _get_jvm(spark)
tests = {
10 changes: 5 additions & 5 deletions integration_tests/src/main/python/data_gen.py
@@ -19,7 +19,7 @@
import pyspark.sql.functions as f
import pytest
import random
-from spark_session import spark, is_tz_utc
+from spark_session import is_tz_utc
import sre_yield
import struct

@@ -489,8 +489,8 @@ def gen_array():
return [self._child_gen.gen() for _ in range(0, length)]
self._start(rand, gen_array)

-def skip_if_not_utc(spark):
-if (not is_tz_utc(spark)):
+def skip_if_not_utc():
+if (not is_tz_utc()):
pytest.skip('The java system time zone is not set to UTC')

def gen_df(spark, data_gen, length=2048, seed=0):
@@ -504,7 +504,7 @@ def gen_df(spark, data_gen, length=2048, seed=0):

# Before we get too far we need to verify that we can run with timestamps
if src.contains_ts():
-skip_if_not_utc(spark)
+skip_if_not_utc()

rand = random.Random(seed)
src.start(rand)
@@ -525,7 +525,7 @@ def _gen_scalars_common(data_gen, count, seed=0):

# Before we get too far we need to verify that we can run with timestamps
if src.contains_ts():
-skip_if_not_utc(spark)
+skip_if_not_utc()

rand = random.Random(seed)
src.start(rand)
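skip_if_not_utc now resolves the session on its own (through is_tz_utc's default argument), so timestamp-dependent code no longer needs a SparkSession handed to it just for the time zone check. A hedged sketch of a test using the helper directly; the test name and body are hypothetical:

from data_gen import skip_if_not_utc

def test_something_with_timestamps():
    # skips via pytest.skip(...) unless the JVM time zone is UTC
    skip_if_not_utc()
    # ... timestamp-dependent assertions would follow here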
10 changes: 5 additions & 5 deletions integration_tests/src/main/python/hash_aggregate_test.py
@@ -19,7 +19,7 @@
from pyspark.sql.types import *
from marks import *
import pyspark.sql.functions as f
-from spark_session import with_cpu_session, with_spark_session
+from spark_session import with_spark_session

_no_nans_float_conf = {'spark.rapids.sql.variableFloatAgg.enabled': 'true',
'spark.rapids.sql.hasNans': 'false',
@@ -276,7 +276,7 @@ def test_hash_count_with_filter(data_gen, conf):
@pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
def test_hash_multiple_filters(data_gen, conf):
assert_gpu_and_cpu_are_equal_sql(
-with_cpu_session(lambda spark : gen_df(spark, data_gen, length=100)),
+lambda spark : gen_df(spark, data_gen, length=100),
"hash_agg_table",
'select count(a) filter (where c > 50),' +
'count(b) filter (where c > 100),' +
@@ -296,7 +296,7 @@ def test_hash_multiple_filters(data_gen, conf):
@pytest.mark.parametrize('data_gen', [_longs_with_nulls], ids=idfn)
def test_hash_multiple_filters_fail(data_gen):
assert_gpu_and_cpu_are_equal_sql(
-with_cpu_session(lambda spark : gen_df(spark, data_gen, length=100)),
+lambda spark : gen_df(spark, data_gen, length=100),
"hash_agg_table",
'select avg(b) filter (where b > 20) from hash_agg_table group by a',
_no_nans_float_conf_partial)
@@ -318,7 +318,7 @@ def test_hash_query_max_bug(data_gen):
_grpkey_doubles_with_nan_zero_grouping_keys], ids=idfn)
def test_hash_agg_with_nan_keys(data_gen):
assert_gpu_and_cpu_are_equal_sql(
-with_cpu_session(lambda spark : gen_df(spark, data_gen, length=1024)),
+lambda spark : gen_df(spark, data_gen, length=1024),
"hash_agg_table",
'select a, '
'count(*) as count_stars, '
@@ -342,7 +342,7 @@ def test_hash_agg_with_nan_keys(data_gen):
@pytest.mark.parametrize('data_gen', [ _grpkey_doubles_with_nan_zero_grouping_keys], ids=idfn)
def test_count_distinct_with_nan_floats(data_gen):
assert_gpu_and_cpu_are_equal_sql(
-with_cpu_session(lambda spark : gen_df(spark, data_gen, length=1024)),
+lambda spark : gen_df(spark, data_gen, length=1024),
"hash_agg_table",
'select a, count(distinct b) as count_distinct_bees from hash_agg_table group by a',
_no_nans_float_conf)
11 changes: 10 additions & 1 deletion integration_tests/src/main/python/spark_init_internal.py
@@ -28,5 +28,14 @@ def _spark__init():
_s.sparkContext.setLogLevel("WARN")
return _s

-spark = _spark__init()
+_spark = _spark__init()
+
+def get_spark_i_know_what_i_am_doing():
+    """
+    Get the current SparkSession.
+    This should almost never be called directly. Instead call with_spark_session,
+    with_cpu_session, or with_gpu_session from spark_session.
+    This is to guarantee that the session and its config are set up in a repeatable way.
+    """
+    return _spark
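The session is now private to spark_init_internal, and the deliberately awkward getter name discourages direct use. In a test the intended pattern is to go through the spark_session wrappers so configs are applied and reset consistently. A minimal sketch, assuming the wrappers shown in this PR; the helper name, Parquet path, and lambda body are illustrative:

from spark_session import with_cpu_session

def read_row_count(path):
    # the lambda receives the managed session and returns plain Python data,
    # never a DataFrame (see _check_for_proper_return_values below)
    return with_cpu_session(lambda spark: spark.read.parquet(path).count())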

29 changes: 18 additions & 11 deletions integration_tests/src/main/python/spark_session.py
@@ -13,8 +13,8 @@
# limitations under the License.

from conftest import is_allowing_any_non_gpu, get_non_gpu_allowed
-from pyspark.sql import SparkSession
-from spark_init_internal import spark as internal_spark
+from pyspark.sql import SparkSession, DataFrame
+from spark_init_internal import get_spark_i_know_what_i_am_doing

def _from_scala_map(scala_map):
ret = {}
@@ -25,12 +25,12 @@ def _from_scala_map(scala_map):
ret[key] = scala_map.get(key).get()
return ret

-spark = internal_spark
+_spark = get_spark_i_know_what_i_am_doing()
# Have to reach into a private member to get access to the API we need
-_orig_conf = _from_scala_map(spark.conf._jconf.getAll())
+_orig_conf = _from_scala_map(_spark.conf._jconf.getAll())
_orig_conf_keys = _orig_conf.keys()

-def is_tz_utc(spark=spark):
+def is_tz_utc(spark=_spark):
"""
true if the tz is UTC else false
"""
@@ -42,25 +42,32 @@ def is_tz_utc(spark=spark):

def _set_all_confs(conf):
for key, value in conf.items():
-if spark.conf.get(key, None) != value:
-    spark.conf.set(key, value)
+if _spark.conf.get(key, None) != value:
+    _spark.conf.set(key, value)

def reset_spark_session_conf():
"""Reset all of the configs for a given spark session."""
_set_all_confs(_orig_conf)
#We should clear the cache
-spark.catalog.clearCache()
+_spark.catalog.clearCache()
# Have to reach into a private member to get access to the API we need
-current_keys = _from_scala_map(spark.conf._jconf.getAll()).keys()
+current_keys = _from_scala_map(_spark.conf._jconf.getAll()).keys()
for key in current_keys:
if key not in _orig_conf_keys:
-spark.conf.unset(key)
+_spark.conf.unset(key)

+def _check_for_proper_return_values(something):
+    """We don't want to return a DataFrame or Dataset from a with_spark_session. You will not get what you expect"""
+    if (isinstance(something, DataFrame)):
+        raise RuntimeError("You should never return a DataFrame from a with_*_session, you will not get the results that you expect")
+
def with_spark_session(func, conf={}):
"""Run func that takes a spark session as input with the given configs set."""
reset_spark_session_conf()
_set_all_confs(conf)
-return func(spark)
+ret = func(_spark)
+_check_for_proper_return_values(ret)
+return ret

def with_cpu_session(func, conf={}):
"""Run func that takes a spark session as input with the given configs set on the CPU."""
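The new guard makes the rule explicit: a with_*_session call should return plain Python results, not a DataFrame, presumably because a DataFrame is lazy and would only execute later, after the wrapper has already reset the session configs it set up. An illustrative before/after, assuming the wrappers from this file (spark.range is just a stand-in workload):

from spark_session import with_cpu_session

# Bad: returning the DataFrame now raises RuntimeError from
# _check_for_proper_return_values, because nothing has executed yet.
# df = with_cpu_session(lambda spark: spark.range(100))

# Good: force evaluation inside the wrapper and return plain data.
rows = with_cpu_session(lambda spark: spark.range(100).collect())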
9 changes: 4 additions & 5 deletions integration_tests/src/main/python/window_function_test.py
@@ -18,7 +18,6 @@
from data_gen import *
from pyspark.sql.types import *
from marks import *
-from spark_session import with_cpu_session

_grpkey_longs_with_no_nulls = [
('a', RepeatSeqGen(LongGen(nullable=False), length=20)),
@@ -48,7 +47,7 @@
_grpkey_longs_with_nullable_timestamps], ids=idfn)
def test_window_aggs_for_rows(data_gen):
assert_gpu_and_cpu_are_equal_sql(
-with_cpu_session(lambda spark : gen_df(spark, data_gen, length=2048)),
+lambda spark : gen_df(spark, data_gen, length=2048),
"window_agg_table",
'select '
' sum(c) over '
@@ -72,7 +71,7 @@ def test_window_aggs_for_rows(data_gen):
_grpkey_longs_with_nullable_timestamps], ids=idfn)
def test_window_aggs_for_ranges(data_gen):
assert_gpu_and_cpu_are_equal_sql(
-with_cpu_session(lambda spark: gen_df(spark, data_gen, length=2048)),
+lambda spark: gen_df(spark, data_gen, length=2048),
"window_agg_table",
'select '
' sum(c) over '
@@ -102,7 +101,7 @@ def test_window_aggs_for_ranges(data_gen):
@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_timestamps], ids=idfn)
def test_window_aggs_for_ranges_of_dates(data_gen):
assert_gpu_and_cpu_are_equal_sql(
-with_cpu_session(lambda spark: gen_df(spark, data_gen, length=2048)),
+lambda spark: gen_df(spark, data_gen, length=2048),
"window_agg_table",
'select '
' sum(c) over '
@@ -118,7 +117,7 @@ def test_window_aggs_for_ranges_of_dates(data_gen):
@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_no_nulls], ids=idfn)
def test_window_aggs_for_rows_count_non_null(data_gen):
assert_gpu_and_cpu_are_equal_sql(
-with_cpu_session(lambda spark: gen_df(spark, data_gen, length=2048)),
+lambda spark: gen_df(spark, data_gen, length=2048),
"window_agg_table",
'select '
' count(c) over '