AnsiMode support for GetArrayItem, GetMapValue and ElementAt for Spark 3.1.1 #2350
Conversation
- doc refine Signed-off-by: Allen Xu <[email protected]>
- for ansi_mode=true Signed-off-by: Allen Xu <[email protected]>
@@ -17,13 +17,11 @@
 package com.nvidia.spark.rapids.shims.spark311

-import java.nio.ByteBuffer
Scala style should be returning errors for these being removed.
Updated.
}

/**
 * Returns the field at `ordinal` in the Array `child`.
 *
 * We need to do type checking here as `ordinal` expression maybe unresolved.
 */
-case class GpuGetArrayItem(child: Expression, ordinal: Expression)
+case class GpuGetArrayItem(child: Expression, ordinal: Expression, failOnError: Boolean = false)
nit: I personally would prefer to not have a default value for failOnError, just so we are explicit about it everywhere.
Done.
@@ -87,15 +87,15 @@ class GpuGetArrayItemMeta(
   override def convertToGpu(
       arr: Expression,
       ordinal: Expression): GpuExpression =
-    GpuGetArrayItem(arr, ordinal)
+    GpuGetArrayItem(arr, ordinal, SQLConf.get.ansiEnabled)
This is wrong. In versions prior to 3.1.1 the default value should be false, not based off of the ansiEnabled config. Otherwise we will fail when Spark does not.
Set it to false and added a comment.
A few more bugs that I saw when I took a closer look:
-if (ordinal.isValid && ordinal.getInt >= 0) {
-  lhs.getBase.extractListElement(ordinal.getInt)
+if (ordinal.isValid) {
+  val minNumElements = lhs.getBase.countElements.min.getInt
This leaks a ColumnVector and a Scalar. The result of countElements must be closed and so must the result of min.
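For illustration only, a minimal try/finally sketch of closing both intermediate results (the plugin normally wraps this in its withResource helper; the helper function and its name below are made up and this is not the PR's code):

```scala
import ai.rapids.cudf.ColumnVector

// Hypothetical helper: compute the smallest element count of a list column
// while making sure both the counts column and the min scalar get closed.
def minNumElements(list: ColumnVector): Int = {
  val counts = list.countElements() // per-row element counts: must be closed
  try {
    val minScalar = counts.min()    // Scalar holding the minimum count: must also be closed
    try {
      minScalar.getInt
    } finally {
      minScalar.close()
    }
  } finally {
    counts.close()
  }
}
```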
-  lhs.getBase.extractListElement(ordinal.getInt)
+if (ordinal.isValid) {
+  val minNumElements = lhs.getBase.countElements.min.getInt
+  if ((ordinal.getInt < 0 || minNumElements < ordinal.getInt + 1) && failOnError) {
What is supposed to happen with a null array? countElements will return a null for a null array, and min skips over nulls unless all of them are null. So is null[1] an error in ANSI mode or not? If it is, then this code will completely miss it. If it is not an error, then we will get an exception, or possibly data corruption, when we try to get the int value from the result of min if the batch is all null arrays.
I did a simple test of this problem:
// row data is like:
+------------------------+
|col_1 |
+------------------------+
|null |
|[Java, Scala, C++, a, b]|
+------------------------+
df.select(col("col_1")[2]).show()
+--------+
|col_1[2]|
+--------+
| null|
| C++|
+--------+
CPU and GPU will return the same result for this case.
But like you said, an error occurs when the column contains all null arrays: countElements works well, but min will return 0 in this case, so the exception here will be thrown (the CPU will still return null for them).
For the all-null case, I plan to compare getNullCount with getRowCount, but getNullCount says it's a very expensive op. Do you think we should apply this method here? @revans2
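As an illustration only (not the PR's final code), a sketch of the getNullCount/getRowCount guard being discussed, using a hypothetical helper name:

```scala
import ai.rapids.cudf.ColumnVector

// Hypothetical sketch: skip the min-based bounds check when every array in
// the batch is null, so an all-null column simply yields null results.
def safeMinNumElements(list: ColumnVector): Option[Int] = {
  val counts = list.countElements()   // count entries are null for null arrays
  try {
    if (counts.getNullCount == counts.getRowCount) {
      None                            // all arrays are null: nothing to check
    } else {
      val minScalar = counts.min()    // min skips nulls, so this is well defined
      try {
        Some(minScalar.getInt)
      } finally {
        minScalar.close()
      }
    }
  } finally {
    counts.close()
  }
}
```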
Updated.
Looking better. My main concern right now is that I don't see how we are setting failOnError properly for Spark 3.1.1+. I think we need to either check the version number in the meta, which is brittle, or preferably put the rule into the Shim layer so Spark 3.1.1 can override the behavior.
         'spark.sql.legacy.allowNegativeScaleOfDecimal': True},
         error_message='java.lang.ArrayIndexOutOfBoundsException')

+@pytest.mark.skipif(not is_before_spark_311(), reason="This will throw exception only in Spark 3.1.1+")
The reason looks like a copy and paste. Under this test it is not clear what it means. It might be nice to update both to say something like "In Spark 3.1.1+ ANSI mode, array index throws on out of range indexes".
> Looking better. My main concern right now is that I don't see how we are setting failOnError properly for Spark 3.1.1+. I think we need to either check the version number in the meta, which is brittle, or preferably put the rule into the Shim layer so Spark 3.1.1 can override the behavior.
Agreed. Currently I set failOnError to false by default in convertToGpu for all Spark versions before 3.1.1, because ANSI mode has no effect on them; that gives the same behavior as Spark 3.1.1 with ANSI mode=false. For Spark 3.1.1, I put the real ANSI mode config in convertToGpu to change the behavior for ANSI=true or false.
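A stripped-down sketch of that split, for illustration only (the real meta classes carry more context, and GpuGetArrayItemMeta311 is a made-up name here; GpuExpression and GpuGetArrayItem are the plugin types shown in the diffs above):

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.internal.SQLConf

// Common / pre-3.1.1 rule: ANSI mode has no effect, so failOnError is
// hard-coded to false, matching Spark 3.1.1 with ANSI mode disabled.
class GpuGetArrayItemMeta /* ... */ {
  def convertToGpu(arr: Expression, ordinal: Expression): GpuExpression =
    GpuGetArrayItem(arr, ordinal, failOnError = false)
}

// Spark 3.1.1 shim overrides the rule so the session's ANSI setting decides.
class GpuGetArrayItemMeta311 /* ... */ extends GpuGetArrayItemMeta {
  override def convertToGpu(arr: Expression, ordinal: Expression): GpuExpression =
    GpuGetArrayItem(arr, ordinal, SQLConf.get.ansiEnabled)
}
```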
Sorry, I missed that. Looks good then; all I have is this nit.
LGTM, only some nits.
build
build
… 3.1.1 (NVIDIA#2350) To match the behavior of GetArrayItem, GetMapValue and ElementAt with CPU in Spark 3.1.1. Signed-off-by: Allen Xu <[email protected]>
Fix #2272 and #2276
This PR adds shim support for GetArrayItem, GetMapValue and ElementAt to match the CPU behavior on Spark 3.1.1.
This relies on rapidsai/cudf#8209 and #2260.
More:
This adds an all_null parameter for ArrayGen in the data_gen part of integration_test. This parameter is used to create a null array instead of an empty array.
A null array is used to create a corner case for GetArrayItem: for a dataframe like the one shown in the example above, df.select(col("col_1")[2]).show() will return without an exception, producing a null for the null-array row.
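For context, here is a rough spark-shell sketch (not from the PR; the data and column name are made up) of the Spark 3.1.1 CPU behavior the plugin is matching:

```scala
import org.apache.spark.sql.functions.col

// Assumes a spark-shell session on Spark 3.1.1, where `spark` and the
// implicits needed for toDF are already in scope.
val df = Seq(Seq("Java", "Scala", "C++")).toDF("col_1")

// ANSI mode off (default): an out-of-range index just returns null.
spark.conf.set("spark.sql.ansi.enabled", "false")
df.select(col("col_1")(5)).show()

// ANSI mode on: the same query fails with
// java.lang.ArrayIndexOutOfBoundsException, which is what failOnError
// reproduces on the GPU.
spark.conf.set("spark.sql.ansi.enabled", "true")
df.select(col("col_1")(5)).show()
```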