-
Notifications
You must be signed in to change notification settings - Fork 236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support float case of format_number with format_float kernel #9790
Changes from 10 commits
ada8d7a
238c061
bc08d57
a375433
92845cc
40450e8
7140c9f
fcca63c
6cac1a9
4fb4691
7729c32
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Copyright (c) 2020-2023, NVIDIA CORPORATION. | ||
# Copyright (c) 2020-2024, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
|
@@ -823,24 +823,19 @@ def test_format_number_supported(data_gen): | |
) | ||
|
||
float_format_number_conf = {'spark.rapids.sql.formatNumberFloat.enabled': 'true'} | ||
format_number_float_gens = [DoubleGen(min_exp=-300, max_exp=15)] | ||
|
||
@pytest.mark.parametrize('data_gen', format_number_float_gens, ids=idfn) | ||
def test_format_number_float_limited(data_gen): | ||
format_number_float_gens = [(DoubleGen(), 0.05), (FloatGen(), 0.5), | ||
(SetValuesGen(FloatType(), [float('nan'), float('inf'), float('-inf'), 0.0, -0.0]), 1.0), | ||
(SetValuesGen(DoubleType(), [float('nan'), float('inf'), float('-inf'), 0.0, -0.0]), 1.0)] | ||
# The actual error rate is 2% for double and 42% for float | ||
# set threshold to 5% and 50% to avoid bad luck | ||
|
||
@pytest.mark.parametrize('data_gen,max_err', format_number_float_gens, ids=idfn) | ||
def test_format_number_float_limited(data_gen, max_err): | ||
gen = data_gen | ||
assert_gpu_and_cpu_are_equal_collect( | ||
lambda spark: unary_op_df(spark, gen).selectExpr( | ||
'format_number(a, 5)'), | ||
conf = float_format_number_conf | ||
) | ||
|
||
# format_number for float/double is disabled by default due to compatibility issue | ||
# GPU will generate result with less precision than CPU | ||
@allow_non_gpu('ProjectExec') | ||
@pytest.mark.parametrize('data_gen', [float_gen, double_gen], ids=idfn) | ||
def test_format_number_float_fallback(data_gen): | ||
assert_gpu_fallback_collect( | ||
lambda spark: unary_op_df(spark, data_gen).selectExpr( | ||
'format_number(a, 5)'), | ||
'FormatNumber' | ||
) | ||
cpu = with_cpu_session(lambda spark: unary_op_df(spark, gen).selectExpr('*', | ||
'format_number(a, 5)').collect(), conf = float_format_number_conf) | ||
gpu = with_gpu_session(lambda spark: unary_op_df(spark, gen).selectExpr('*', | ||
'format_number(a, 5)').collect(), conf = float_format_number_conf) | ||
mismatched = sum(x[0] != x[1] for x in zip(cpu, gpu)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I preferred the version that checked that when we parsed them back to a float the numbers were within the error bounds instead of saying that we cannot be wrong more than some set percentage. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done. |
||
all_values = len(cpu) | ||
assert mismatched / all_values <= max_err |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: also uses
ryu
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done