[VL] Remove complex type fallback for parquet #6712

Merged · 3 commits · Sep 5, 2024
@@ -98,55 +98,15 @@ object VeloxBackendSettings extends BackendSettingsApi {
}
}

- val parquetTypeValidatorWithComplexTypeFallback: PartialFunction[StructField, String] = {
- case StructField(_, arrayType: ArrayType, _, _) =>
- arrayType.simpleString + " is forced to fallback."
- case StructField(_, mapType: MapType, _, _) =>
- mapType.simpleString + " is forced to fallback."
- case StructField(_, structType: StructType, _, _) =>
- structType.simpleString + " is forced to fallback."
- case StructField(_, timestampType: TimestampType, _, _)
- if GlutenConfig.getConf.forceParquetTimestampTypeScanFallbackEnabled =>
- timestampType.simpleString + " is forced to fallback."
- }
- val orcTypeValidatorWithComplexTypeFallback: PartialFunction[StructField, String] = {
- case StructField(_, arrayType: ArrayType, _, _) =>
- arrayType.simpleString + " is forced to fallback."
- case StructField(_, mapType: MapType, _, _) =>
- mapType.simpleString + " is forced to fallback."
- case StructField(_, structType: StructType, _, _) =>
- structType.simpleString + " is forced to fallback."
- case StructField(_, stringType: StringType, _, metadata)
- if isCharType(stringType, metadata) =>
- CharVarcharUtils.getRawTypeString(metadata) + " not support"
- case StructField(_, TimestampType, _, _) => "TimestampType not support"
- }
format match {
case ParquetReadFormat =>
val typeValidator: PartialFunction[StructField, String] = {
- // Parquet scan of nested array with struct/array as element type is unsupported in Velox.
- case StructField(_, arrayType: ArrayType, _, _)
- if arrayType.elementType.isInstanceOf[StructType] =>
- "StructType as element in ArrayType"
- case StructField(_, arrayType: ArrayType, _, _)
- if arrayType.elementType.isInstanceOf[ArrayType] =>
- "ArrayType as element in ArrayType"
- // Parquet scan of nested map with struct as key type,
- // or array type as value type is not supported in Velox.
- case StructField(_, mapType: MapType, _, _) if mapType.keyType.isInstanceOf[StructType] =>
- "StructType as Key in MapType"
- case StructField(_, mapType: MapType, _, _)
- if mapType.valueType.isInstanceOf[ArrayType] =>
- "ArrayType as Value in MapType"
// Parquet timestamp is not fully supported yet
case StructField(_, TimestampType, _, _)
if GlutenConfig.getConf.forceParquetTimestampTypeScanFallbackEnabled =>
Review comment (Contributor):
Reading the int64 timestamp in the parquet file does not seem to be supported yet.
facebookincubator/velox#8325

Reply (Author):
Yes. I'll keep this check, and we can try setting forceParquetTimestampTypeScanFallbackEnabled to false once the related support is merged.
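For context, a rough sketch of that experiment (hedged: `spark` is an active SparkSession, and the table/column names are placeholders; the config key is the one exercised by the removed fallback test further down):

    // Sketch: disable the forced timestamp-scan fallback once Velox supports
    // reading int64 timestamps, then check whether the Parquet scan offloads.
    // Assumes an active SparkSession `spark`; names are placeholders.
    spark.conf.set("spark.gluten.sql.parquet.timestampType.scan.fallback.enabled", "false")
    spark.sql("SELECT ts_col FROM parquet_tbl").explain()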

"TimestampType"
}
- if (!GlutenConfig.getConf.forceComplexTypeScanFallbackEnabled) {
- validateTypes(typeValidator)
- } else {
- validateTypes(parquetTypeValidatorWithComplexTypeFallback)
- }
+ validateTypes(typeValidator)
case DwrfReadFormat => ValidationResult.succeeded
case OrcReadFormat =>
if (!GlutenConfig.getConf.veloxOrcScanEnabled) {
@@ -170,11 +130,7 @@
CharVarcharUtils.getRawTypeString(metadata) + " not support"
case StructField(_, TimestampType, _, _) => "TimestampType not support"
}
- if (!GlutenConfig.getConf.forceComplexTypeScanFallbackEnabled) {
- validateTypes(typeValidator)
- } else {
- validateTypes(orcTypeValidatorWithComplexTypeFallback)
- }
+ validateTypes(typeValidator)
}
case _ => ValidationResult.failed(s"Unsupported file format for $format.")
}
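These validators are PartialFunction[StructField, String] values: matching a field marks its type unsupported, and the returned string becomes the fallback reason. A minimal sketch of how such a validator can be applied to a schema (validateSchema is an illustrative stand-in, not Gluten's actual validateTypes helper):

    import org.apache.spark.sql.types.StructField

    // Illustrative stand-in for validateTypes: scan the schema with the
    // PartialFunction validator and surface the first fallback reason found.
    def validateSchema(
        fields: Seq[StructField],
        validator: PartialFunction[StructField, String]): Option[String] =
      fields.collectFirst(validator) // Some(reason) means the scan falls back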
@@ -713,7 +713,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
childTransformer: ExpressionTransformer,
ordinal: Int,
original: GetStructField): ExpressionTransformer = {
- VeloxGetStructFieldTransformer(substraitExprName, childTransformer, original)
+ VeloxGetStructFieldTransformer(substraitExprName, childTransformer, ordinal, original)
}

/**
@@ -54,19 +54,20 @@ case class VeloxNamedStructTransformer(
case class VeloxGetStructFieldTransformer(
substraitExprName: String,
child: ExpressionTransformer,
+ ordinal: Int,
original: GetStructField)
extends UnaryExpressionTransformer {
override def doTransform(args: Object): ExpressionNode = {
val childNode = child.doTransform(args)
childNode match {
case node: StructLiteralNode =>
- node.getFieldLiteral(original.ordinal)
+ node.getFieldLiteral(ordinal)
case node: SelectionNode =>
// Append the nested index to selection node.
- node.addNestedChildIdx(JInteger.valueOf(original.ordinal))
+ node.addNestedChildIdx(JInteger.valueOf(ordinal))
case node: NullLiteralNode =>
val nodeType =
- node.getTypeNode.asInstanceOf[StructNode].getFieldTypes.get(original.ordinal)
+ node.getTypeNode.asInstanceOf[StructNode].getFieldTypes.get(ordinal)
ExpressionBuilder.makeNullLiteral(nodeType)
case other =>
throw new GlutenNotSupportException(s"$other is not supported.")
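The transformer now takes the ordinal explicitly instead of reading original.ordinal, so the caller can pass an index rebound against the actual (possibly pruned) input schema. A hypothetical illustration of the mismatch this avoids, with assumed field names:

    import org.apache.spark.sql.types._

    // Hypothetical: nested schema pruning narrows a struct, so an ordinal
    // recorded against the full struct goes stale. Field names are assumed.
    val full = StructType(Seq(
      StructField("a", IntegerType),
      StructField("b", IntegerType),
      StructField("c", IntegerType)))
    val pruned = StructType(Seq(StructField("c", IntegerType)))
    val staleOrdinal = full.fieldIndex("c") // 2, valid only for the full struct
    val reboundOrdinal = pruned.fieldIndex("c") // 0, correct for the pruned input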
@@ -1713,7 +1713,7 @@ class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa

sql("CREATE TABLE t2(id INT, l ARRAY<STRUCT<k: INT, v: INT>>) USING PARQUET")
sql("INSERT INTO t2 VALUES(1, ARRAY(STRUCT(1, 100))), (2, ARRAY(STRUCT(2, 200)))")
- runQueryAndCompare("SELECT first(l) FROM t2")(df => checkFallbackOperators(df, 1))
+ runQueryAndCompare("SELECT first(l) FROM t2")(df => checkFallbackOperators(df, 0))
}
}

@@ -427,22 +427,6 @@ class VeloxParquetDataTypeValidationSuite extends VeloxWholeStageTransformerSuit
}
}

test("Force complex type scan fallback") {
withSQLConf(("spark.gluten.sql.complexType.scan.fallback.enabled", "true")) {
val df = spark.sql("select struct from type1")
val executedPlan = getExecutedPlan(df)
assert(!executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
}
}

test("Force timestamp type scan fallback") {
withSQLConf(("spark.gluten.sql.parquet.timestampType.scan.fallback.enabled", "true")) {
val df = spark.sql("select timestamp from type1")
val executedPlan = getExecutedPlan(df)
assert(!executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
}
}

test("Decimal type") {
// Validation: BatchScan Project Aggregate Expand Sort Limit
runQueryAndCompare(
@@ -33,6 +33,8 @@ import org.apache.spark.sql.hive.HiveUDFTransformer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

+ import scala.collection.mutable.ArrayBuffer
+
trait Transformable {
def getTransformer(childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer
}
@@ -345,12 +347,23 @@ object ExpressionConverter extends SQLConfHelper with Logging {
expr => replaceWithExpressionTransformer0(expr, attributeSeq, expressionsMap)),
m)
case getStructField: GetStructField =>
- // Different backends may have different result.
- BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer(
- substraitExprName,
- replaceWithExpressionTransformer0(getStructField.child, attributeSeq, expressionsMap),
- getStructField.ordinal,
- getStructField)
+ try {
+ val bindRef =
+ bindGetStructField(getStructField, attributeSeq)
+ // Different backends may produce different results.
+ BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer(
+ substraitExprName,
+ replaceWithExpressionTransformer0(getStructField.child, attributeSeq, expressionsMap),
+ bindRef.ordinal,
+ getStructField)
+ } catch {
+ case e: IllegalStateException =>
+ // This situation may need a developer fix; for now, throw the exception
+ // below so that the corresponding operator falls back.
+ throw new UnsupportedOperationException(
+ s"Failed to bind reference for $getStructField: ${e.getMessage}")
+ }
+
case getArrayStructFields: GetArrayStructFields =>
GenericExpressionTransformer(
substraitExprName,
@@ -693,4 +706,49 @@
}
substraitExprName
}

+ private def bindGetStructField(
+ structField: GetStructField,
+ input: AttributeSeq): BoundReference = {
+ // Get the new ordinal against the actual input attributes.
+ var newOrdinal: Int = -1
+ val names = new ArrayBuffer[String]
+ var root: Expression = structField
+ while (root.isInstanceOf[GetStructField]) {
+ val curField = root.asInstanceOf[GetStructField]
+ val name = curField.childSchema.fields(curField.ordinal).name
+ names += name
+ root = root.asInstanceOf[GetStructField].child
+ }
+ // For map/array types, the reference is correct regardless of NESTED_SCHEMA_PRUNING_ENABLED.
+ if (!root.isInstanceOf[AttributeReference]) {
+ return BoundReference(structField.ordinal, structField.dataType, structField.nullable)
+ }
+ names += root.asInstanceOf[AttributeReference].name
+ input.attrs.foreach(
+ attribute => {
+ var level = names.size - 1
+ if (names(level) == attribute.name) {
+ var candidateFields: Array[StructField] = null
+ var dtType = attribute.dataType
+ while (dtType.isInstanceOf[StructType] && level >= 1) {
+ candidateFields = dtType.asInstanceOf[StructType].fields
+ level -= 1
+ val curName = names(level)
+ for (i <- 0 until candidateFields.length) {
+ if (candidateFields(i).name == curName) {
+ dtType = candidateFields(i).dataType
+ newOrdinal = i
+ }
+ }
+ }
+ }
+ })
+ if (newOrdinal == -1) {
+ throw new IllegalStateException(
+ s"Couldn't find $structField in ${input.attrs.mkString("[", ",", "]")}")
+ } else {
+ BoundReference(newOrdinal, structField.dataType, structField.nullable)
+ }
+ }
}
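bindGetStructField unwinds a chain of nested GetStructField accesses from the outermost field down to the root attribute, collecting field names along the way, then re-resolves those names level by level against the matching input attribute's StructType to recover the ordinal that is valid after nested schema pruning. If no match is found, the IllegalStateException is turned into an UnsupportedOperationException by the caller above, so the operator falls back instead of reading the wrong field. A hedged sketch, with assumed names, of the expression shape the walk handles:

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GetStructField}
    import org.apache.spark.sql.types._

    // Hypothetical input s: struct<inner: struct<c: int>>; the query reads s.inner.c.
    val innerType = StructType(Seq(StructField("c", IntegerType)))
    val sType = StructType(Seq(StructField("inner", innerType)))
    val s = AttributeReference("s", sType)()
    val access = GetStructField(GetStructField(s, 0), 0) // represents s.inner.c
    // bindGetStructField(access, input) collects the names ["c", "inner", "s"],
    // then walks the bound attribute's StructType by name to rebind the ordinal.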