Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unshim many expressions [databricks] #5036

Merged
merged 8 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -5060,7 +5060,7 @@ are limited.
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>If it's map, only primitive key types supported.;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>If it's map, only primitive key types are supported.;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
</tr>
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -297,139 +297,6 @@ abstract class Spark31XShims extends SparkShims with Spark31Xuntil33XShims with
(a, conf, p, r) => new UnaryAstExprMeta[Abs](a, conf, p, r) {
// ANSI support for ABS was added in 3.2.0 SPARK-33275
override def convertToGpu(child: Expression): GpuExpression = GpuAbs(child, false)
}),
GpuOverrides.expr[RegExpReplace](
"String replace using a regular expression pattern",
ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING,
Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING),
ParamCheck("regex", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
ParamCheck("rep", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
ParamCheck("pos", TypeSig.lit(TypeEnum.INT)
.withPsNote(TypeEnum.INT, "only a value of 1 is supported"),
TypeSig.lit(TypeEnum.INT)))),
(a, conf, p, r) => new GpuRegExpReplaceMeta(a, conf, p, r)),
// Spark 3.1.1-specific LEAD expression, using custom OffsetWindowFunctionMeta.
GpuOverrides.expr[Lead](
"Window function that returns N entries ahead of this one",
ExprChecks.windowOnly(
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all,
Seq(
ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all)
)
),
(lead, conf, p, r) => new OffsetWindowFunctionMeta[Lead](lead, conf, p, r) {
override def convertToGpu(): GpuExpression =
GpuLead(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu())
}),
// Spark 3.1.1-specific LAG expression, using custom OffsetWindowFunctionMeta.
GpuOverrides.expr[Lag](
"Window function that returns N entries behind this one",
ExprChecks.windowOnly(
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all,
Seq(
ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all)
)
),
(lag, conf, p, r) => new OffsetWindowFunctionMeta[Lag](lag, conf, p, r) {
override def convertToGpu(): GpuExpression = {
GpuLag(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu())
}
}),
GpuOverrides.expr[GetArrayItem](
"Gets the field at `ordinal` in the Array",
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all)),
("ordinal", TypeSig.INT, TypeSig.INT)),
(in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r){
override def convertToGpu(arr: Expression, ordinal: Expression): GpuExpression =
GpuGetArrayItem(arr, ordinal, in.failOnError)
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map",
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
override def convertToGpu(map: Expression, key: Expression): GpuExpression =
GpuGetMapValue(map, key, in.failOnError)
}),
GpuOverrides.expr[ElementAt](
"Returns element of array at given(1-based) index in value if column is array. " +
"Returns value for the given key in value if column is map.",
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP)
.withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.INT + TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL))
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map",
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all)),
("ordinal", TypeSig.INT, TypeSig.INT))
case _ => throw new IllegalStateException("Only Array or Map is supported as input.")
}
checks.tag(this)
}
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
GpuElementAt(lhs, rhs, SQLConf.get.ansiEnabled)
}
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.window.WindowExecBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.execution.python._
import org.apache.spark.sql.rapids.shims._
Expand Down Expand Up @@ -163,139 +162,6 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
(a, conf, p, r) => new UnaryAstExprMeta[Abs](a, conf, p, r) {
// ANSI support for ABS was added in 3.2.0 SPARK-33275
override def convertToGpu(child: Expression): GpuExpression = GpuAbs(child, false)
}),
GpuOverrides.expr[RegExpReplace](
"String replace using a regular expression pattern",
ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING,
Seq(ParamCheck("str", TypeSig.STRING, TypeSig.STRING),
ParamCheck("regex", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
ParamCheck("rep", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
ParamCheck("pos", TypeSig.lit(TypeEnum.INT)
.withPsNote(TypeEnum.INT, "only a value of 1 is supported"),
TypeSig.lit(TypeEnum.INT)))),
(a, conf, p, r) => new GpuRegExpReplaceMeta(a, conf, p, r)),
// Spark 3.1.1-specific LEAD expression, using custom OffsetWindowFunctionMeta.
GpuOverrides.expr[Lead](
"Window function that returns N entries ahead of this one",
ExprChecks.windowOnly(
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all,
Seq(
ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all)
)
),
(lead, conf, p, r) => new OffsetWindowFunctionMeta[Lead](lead, conf, p, r) {
override def convertToGpu(): GpuExpression =
GpuLead(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu())
}),
// Spark 3.1.1-specific LAG expression, using custom OffsetWindowFunctionMeta.
GpuOverrides.expr[Lag](
"Window function that returns N entries behind this one",
ExprChecks.windowOnly(
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all,
Seq(
ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all)
)
),
(lag, conf, p, r) => new OffsetWindowFunctionMeta[Lag](lag, conf, p, r) {
override def convertToGpu(): GpuExpression = {
GpuLag(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu())
}
}),
GpuOverrides.expr[GetArrayItem](
"Gets the field at `ordinal` in the Array",
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all)),
("ordinal", TypeSig.INT, TypeSig.INT)),
(in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r){
override def convertToGpu(arr: Expression, ordinal: Expression): GpuExpression =
GpuGetArrayItem(arr, ordinal, in.failOnError)
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map",
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
override def convertToGpu(map: Expression, key: Expression): GpuExpression =
GpuGetMapValue(map, key, in.failOnError)
}),
GpuOverrides.expr[ElementAt](
"Returns element of array at given(1-based) index in value if column is array. " +
"Returns value for the given key in value if column is map.",
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP)
.withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types are supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.INT + TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL))
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all)),
("ordinal", TypeSig.INT, TypeSig.INT))
case _ => throw new IllegalStateException("Only Array or Map is supported as input.")
}
checks.tag(this)
}
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
GpuElementAt(lhs, rhs, SQLConf.get.ansiEnabled)
}
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap

Expand Down
Loading