Skip to content

Commit

Permalink
Unshim many expressions [databricks] (#5036)
Browse files Browse the repository at this point in the history
* Unshim GpuRegExpReplace

Signed-off-by: Jason Lowe <[email protected]>

* Unshim Lead and Lag

Signed-off-by: Jason Lowe <[email protected]>

* Unshim GpuGetArrayItem

Signed-off-by: Jason Lowe <[email protected]>

* Unshim GetMapValue

Signed-off-by: Jason Lowe <[email protected]>

* Unshim ElementAt

Signed-off-by: Jason Lowe <[email protected]>

* Remove GpuTimeSub

* Update generated docs

* Fix unused import in Spark31XdbShims
  • Loading branch information
jlowe authored Mar 25, 2022
1 parent fa4c63c commit e845b87
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 543 deletions.
2 changes: 1 addition & 1 deletion docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -5060,7 +5060,7 @@ are limited.
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>If it's map, only primitive key types supported.;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>If it's map, only primitive key types are supported.;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
</tr>
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -297,139 +297,6 @@ abstract class Spark31XShims extends SparkShims with Spark31Xuntil33XShims with
(a, conf, p, r) => new UnaryAstExprMeta[Abs](a, conf, p, r) {
// ANSI support for ABS was added in 3.2.0 SPARK-33275
override def convertToGpu(child: Expression): GpuExpression = GpuAbs(child, false)
}),
GpuOverrides.expr[RegExpReplace](
"String replace using a regular expression pattern",
ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING,
Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING),
ParamCheck("regex", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
ParamCheck("rep", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
ParamCheck("pos", TypeSig.lit(TypeEnum.INT)
.withPsNote(TypeEnum.INT, "only a value of 1 is supported"),
TypeSig.lit(TypeEnum.INT)))),
(a, conf, p, r) => new GpuRegExpReplaceMeta(a, conf, p, r)),
// Spark 3.1.1-specific LEAD expression, using custom OffsetWindowFunctionMeta.
GpuOverrides.expr[Lead](
"Window function that returns N entries ahead of this one",
ExprChecks.windowOnly(
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all,
Seq(
ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all)
)
),
(lead, conf, p, r) => new OffsetWindowFunctionMeta[Lead](lead, conf, p, r) {
override def convertToGpu(): GpuExpression =
GpuLead(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu())
}),
// Spark 3.1.1-specific LAG expression, using custom OffsetWindowFunctionMeta.
GpuOverrides.expr[Lag](
"Window function that returns N entries behind this one",
ExprChecks.windowOnly(
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all,
Seq(
ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all)
)
),
(lag, conf, p, r) => new OffsetWindowFunctionMeta[Lag](lag, conf, p, r) {
override def convertToGpu(): GpuExpression = {
GpuLag(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu())
}
}),
GpuOverrides.expr[GetArrayItem](
"Gets the field at `ordinal` in the Array",
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all)),
("ordinal", TypeSig.INT, TypeSig.INT)),
(in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r){
override def convertToGpu(arr: Expression, ordinal: Expression): GpuExpression =
GpuGetArrayItem(arr, ordinal, in.failOnError)
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map",
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
override def convertToGpu(map: Expression, key: Expression): GpuExpression =
GpuGetMapValue(map, key, in.failOnError)
}),
GpuOverrides.expr[ElementAt](
"Returns element of array at given(1-based) index in value if column is array. " +
"Returns value for the given key in value if column is map.",
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP)
.withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.INT + TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL))
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map",
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all)),
("ordinal", TypeSig.INT, TypeSig.INT))
case _ => throw new IllegalStateException("Only Array or Map is supported as input.")
}
checks.tag(this)
}
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
GpuElementAt(lhs, rhs, SQLConf.get.ansiEnabled)
}
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.window.WindowExecBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.execution.python._
import org.apache.spark.sql.rapids.shims._
Expand Down Expand Up @@ -163,139 +162,6 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
(a, conf, p, r) => new UnaryAstExprMeta[Abs](a, conf, p, r) {
// ANSI support for ABS was added in 3.2.0 SPARK-33275
override def convertToGpu(child: Expression): GpuExpression = GpuAbs(child, false)
}),
GpuOverrides.expr[RegExpReplace](
"String replace using a regular expression pattern",
ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING,
Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING),
ParamCheck("regex", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
ParamCheck("rep", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
ParamCheck("pos", TypeSig.lit(TypeEnum.INT)
.withPsNote(TypeEnum.INT, "only a value of 1 is supported"),
TypeSig.lit(TypeEnum.INT)))),
(a, conf, p, r) => new GpuRegExpReplaceMeta(a, conf, p, r)),
// Spark 3.1.1-specific LEAD expression, using custom OffsetWindowFunctionMeta.
GpuOverrides.expr[Lead](
"Window function that returns N entries ahead of this one",
ExprChecks.windowOnly(
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all,
Seq(
ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all)
)
),
(lead, conf, p, r) => new OffsetWindowFunctionMeta[Lead](lead, conf, p, r) {
override def convertToGpu(): GpuExpression =
GpuLead(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu())
}),
// Spark 3.1.1-specific LAG expression, using custom OffsetWindowFunctionMeta.
GpuOverrides.expr[Lag](
"Window function that returns N entries behind this one",
ExprChecks.windowOnly(
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all,
Seq(
ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all)
)
),
(lag, conf, p, r) => new OffsetWindowFunctionMeta[Lag](lag, conf, p, r) {
override def convertToGpu(): GpuExpression = {
GpuLag(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu())
}
}),
GpuOverrides.expr[GetArrayItem](
"Gets the field at `ordinal` in the Array",
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all)),
("ordinal", TypeSig.INT, TypeSig.INT)),
(in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r){
override def convertToGpu(arr: Expression, ordinal: Expression): GpuExpression =
GpuGetArrayItem(arr, ordinal, in.failOnError)
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map",
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
override def convertToGpu(map: Expression, key: Expression): GpuExpression =
GpuGetMapValue(map, key, in.failOnError)
}),
GpuOverrides.expr[ElementAt](
"Returns element of array at given(1-based) index in value if column is array. " +
"Returns value for the given key in value if column is map.",
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP)
.withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types are supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.INT + TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL))
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all)),
("ordinal", TypeSig.INT, TypeSig.INT))
case _ => throw new IllegalStateException("Only Array or Map is supported as input.")
}
checks.tag(this)
}
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
GpuElementAt(lhs, rhs, SQLConf.get.ansiEnabled)
}
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap

Expand Down
Loading

0 comments on commit e845b87

Please sign in to comment.