[NSE-762] Add complex types support for ColumnarSortExec (#763)
* Add complex types support for ColumnarSortExec

* Add struct/map support

* Fix clang format

* Add Key check and fix Attr parse issue

* Fix remaining UTs

* Correct typo

* Correct Typos

* Add enable/disable config in ColumnarSort

* Correct Error message
zhixingheyi-tian authored Mar 22, 2022
1 parent 1f25067 commit df1da54
Showing 10 changed files with 391 additions and 15 deletions.
@@ -81,9 +81,9 @@ public static ArrowWritableColumnVector[] allocateColumns(
int capacity, StructType schema) {
String timeZoneId = SparkSchemaUtils.getLocalTimezoneID();
Schema arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId);
System.setProperty("arrow.struct.conflict.policy", AbstractStructVector.ConflictPolicy.CONFLICT_APPEND.toString());
VectorSchemaRoot new_root =
VectorSchemaRoot.create(arrowSchema, SparkMemoryUtils.contextAllocator());

List<FieldVector> fieldVectors = new_root.getFieldVectors();
ArrowWritableColumnVector[] vectors =
new ArrowWritableColumnVector[fieldVectors.size()];
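The new `System.setProperty` call above configures Arrow's struct-child conflict policy before any vectors are created. Below is a minimal, self-contained sketch of the same call against plain Arrow APIs; `ConflictPolicySketch` is an illustrative name, and the comment on what CONFLICT_APPEND does is an assumption based on the enum name, not something stated in this patch.

import java.util.Arrays

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.complex.AbstractStructVector
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

object ConflictPolicySketch {
  def main(args: Array[String]): Unit = {
    // Same property the patch sets; it steers how struct vectors resolve child
    // vectors whose names collide (assumption: APPEND keeps the new child
    // alongside the old one rather than replacing it).
    System.setProperty("arrow.struct.conflict.policy",
      AbstractStructVector.ConflictPolicy.CONFLICT_APPEND.toString)

    // Build a tiny schema with one struct column and materialise it, much as
    // allocateColumns does with the converted Spark schema.
    val a = Field.nullable("a", new ArrowType.Int(32, true))
    val b = Field.nullable("b", new ArrowType.Int(32, true))
    val struct = new Field("struct_field",
      FieldType.nullable(ArrowType.Struct.INSTANCE), Arrays.asList(a, b))

    val allocator = new RootAllocator(Long.MaxValue)
    val root = VectorSchemaRoot.create(new Schema(Arrays.asList(struct)), allocator)
    println(root.getSchema)
    root.close()
    allocator.close()
  }
}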
@@ -88,12 +88,12 @@ object RowToColumnConverter {
case StringType => StringConverter
case BinaryType => BinaryConverter
case CalendarIntervalType => CalendarConverter
case at: ArrayType => new ArrayConverter(getConverterForType(at.elementType, nullable))
case at: ArrayType => new ArrayConverter(getConverterForType(at.elementType, at.containsNull))
case st: StructType => new StructConverter(st.fields.map(
(f) => getConverterForType(f.dataType, f.nullable)))
case dt: DecimalType => new DecimalConverter(dt)
case mt: MapType => new MapConverter(getConverterForType(mt.keyType, nullable),
getConverterForType(mt.valueType, nullable))
case mt: MapType => new MapConverter(getConverterForType(mt.keyType, false),
getConverterForType(mt.valueType, mt.valueContainsNull))
case unknown => throw new UnsupportedOperationException(
s"Type $unknown not supported")
}
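The replaced lines take nullability from the container type itself (`containsNull`, `valueContainsNull`, and a hard-coded `false` for map keys, which Spark never allows to be null) instead of reusing the column-level `nullable` flag. A small illustration of where those flags live, using only public Spark types; the `NullabilitySketch` object name is made up for this example.

import org.apache.spark.sql.types._

object NullabilitySketch {
  def main(args: Array[String]): Unit = {
    // A nullable column can still carry arrays whose elements are non-nullable,
    // so the element converter must look at containsNull, not the column flag.
    val arr = ArrayType(IntegerType, containsNull = false)
    println(arr.containsNull)        // false

    // Map keys are never null in Spark SQL, hence the hard-coded `false` above;
    // value nullability has its own flag.
    val map = MapType(StringType, IntegerType, valueContainsNull = true)
    println(map.valueContainsNull)   // true
  }
}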
@@ -43,7 +43,7 @@ import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
import org.apache.spark.util.{ExecutorManager, UserAddedJarUtils, Utils}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.types._

import scala.util.control.Breaks.{break, breakable}

@@ -92,19 +92,32 @@ case class ColumnarSortExec(
buildCheck()

def buildCheck(): Unit = {
val columnarConf: GazellePluginConfig = GazellePluginConfig.getSessionConf
// check types
for (attr <- output) {
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
if (!columnarConf.enableComplexType) {
ConverterUtils.checkIfTypeSupported(attr.dataType)
} else {
ConverterUtils.checkIfComplexTypeSupported(attr.dataType)
}
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarSorter.")
s"${attr.dataType} is not supported in ColumnarSortExec.")
}
}
// check expr
sortOrder.toList.map(expr => {
ColumnarExpressionConverter.replaceWithColumnarExpression(expr.child)
val attr = ConverterUtils.getAttrFromExpr(expr.child, true)
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarSortExec keys.")
}
})
}

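The reworked `buildCheck` does two things: when the session-level switch (`columnarConf.enableComplexType`) is on it accepts array/struct/map output columns, and it separately verifies that every sort key is still a simple type. A hedged end-to-end sketch of what that means for user queries follows; `SortPayloadSketch` and the local DataFrame are made up, and the rejection comment paraphrases the error string above.

import org.apache.spark.sql.SparkSession

object SortPayloadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    // One simple key column plus one complex payload column.
    val df = Seq((1, Seq(1, 2)), (0, Seq(3, 4))).toDF("kind", "arr_field")

    // Allowed: the key is IntegerType, the ArrayType column is only payload.
    df.orderBy($"kind").show()

    // Not allowed by buildCheck (the sort key itself is complex); this would
    // be rejected with "... is not supported in ColumnarSortExec keys."
    // df.orderBy($"arr_field").show()

    spark.stop()
  }
}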
@@ -76,8 +76,8 @@ class ColumnarSorter(
.toArray)
val outputFieldList: List[Field] = outputAttributes.toList.map(expr => {
val attr = ConverterUtils.getAttrFromExpr(expr, true)
Field.nullable(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
CodeGeneration.getResultType(attr.dataType))
ConverterUtils.createArrowField(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
attr.dataType)
})
val arrowSchema = new Schema(outputFieldList.asJava)
var sort_iterator: BatchIterator = _
@@ -180,8 +180,8 @@ object ColumnarSorter extends Logging {
def checkIfKeyFound(sortOrder: Seq[SortOrder], outputAttributes: Seq[Attribute]): Unit = {
val outputFieldList: List[Field] = outputAttributes.toList.map(expr => {
val attr = ConverterUtils.getAttrFromExpr(expr, true)
Field.nullable(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
CodeGeneration.getResultType(attr.dataType))
ConverterUtils.createArrowField(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
attr.dataType)
})
sortOrder.toList.foreach(sort => {
val attr = ConverterUtils.getAttrFromExpr(sort.child, true)
@@ -346,8 +346,8 @@ object ColumnarSorter extends Logging {
_sparkConf: SparkConf): (ExpressionTree, Schema) = {
val outputFieldList: List[Field] = outputAttributes.toList.map(expr => {
val attr = ConverterUtils.getAttrFromExpr(expr, true)
Field.nullable(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
CodeGeneration.getResultType(attr.dataType))
ConverterUtils.createArrowField(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
attr.dataType)
})
val retType = Field.nullable("res", new ArrowType.Int(32, true))
val sort_node =
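All three hunks in this file swap `Field.nullable(..., CodeGeneration.getResultType(...))` for `ConverterUtils.createArrowField(...)`, so nested output columns keep their child fields in the Arrow schema handed to the native sorter. The difference is easiest to see with plain Arrow field objects; `NestedFieldSketch` and the field names here are illustrative only.

import java.util.Collections

import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType}

object NestedFieldSketch {
  def main(args: Array[String]): Unit = {
    // A flat field carries no children, so a list column would lose its
    // element type on the way to the native code.
    val flat = Field.nullable("arr_field#1", new ArrowType.Int(32, true))

    // A properly nested list field keeps the element as a child field,
    // mirroring what createArrowField builds for ArrayType in the hunk below.
    val element = Field.nullable("element", new ArrowType.Int(32, true))
    val nested = new Field("arr_field#1",
      new FieldType(true, ArrowType.List.INSTANCE, null),
      Collections.singletonList(element))

    println(flat.getChildren)    // []
    println(nested.getChildren)  // [element: Int(32, true)]
  }
}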
@@ -519,6 +519,25 @@ object ConverterUtils extends Logging {
throw new UnsupportedOperationException(s"Unsupported data type: $dt")
}

def checkIfComplexTypeSupported(dt: DataType): Unit = dt match {
case d: ArrayType =>
case d: StructType =>
case d: MapType =>
case d: BooleanType =>
case d: ByteType =>
case d: ShortType =>
case d: IntegerType =>
case d: LongType =>
case d: FloatType =>
case d: DoubleType =>
case d: StringType =>
case d: DateType =>
case d: DecimalType =>
case d: TimestampType =>
case _ =>
throw new UnsupportedOperationException(s"Unsupported data type: $dt")
}

def checkIfNestTypeSupported(dt: DataType): Unit = dt match {
case d: ArrayType => checkIfTypeSupported(d.elementType)
case d: StructType =>
@@ -546,7 +565,7 @@ object ConverterUtils extends Logging {
new Field(
name,
new FieldType(nullable, ArrowType.List.INSTANCE, null),
Lists.newArrayList(createArrowField(s"${name}_${dt}", at.elementType)))
Lists.newArrayList(createArrowField("element", at.elementType)))
case st: StructType =>
val fieldlist = new util.ArrayList[Field]
var structField = null
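`checkIfComplexTypeSupported` whitelists array/struct/map on top of the scalar types, and the `createArrowField` change gives array children the fixed name `element` instead of one derived from the parent. Below is a simplified standalone version of the whitelist idiom; the `TypeWhitelist` object and the boolean return are this sketch's own shape, whereas the project's method throws UnsupportedOperationException instead.

import org.apache.spark.sql.types._

object TypeWhitelist {
  // Same pattern-match whitelist, reduced to a predicate for illustration.
  def supportedInColumnarSort(dt: DataType): Boolean = dt match {
    case _: ArrayType | _: StructType | _: MapType => true
    case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType |
         _: LongType | _: FloatType | _: DoubleType => true
    case _: StringType | _: DateType | _: DecimalType | _: TimestampType => true
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    println(supportedInColumnarSort(ArrayType(IntegerType)))  // true
    println(supportedInColumnarSort(CalendarIntervalType))    // false
  }
}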
@@ -397,7 +397,13 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] {
if (columnarConf.enableArrowColumnarToRow) {
val child = replaceWithColumnarPlan(plan.child)
logDebug(s"ColumnarPostOverrides ArrowColumnarToRowExec(${child.getClass})")
new ArrowColumnarToRowExec(child)
try {
new ArrowColumnarToRowExec(child)
} catch {
case _: Throwable =>
logInfo("ArrowColumnarToRowExec: Falling back to ColumnarToRow...")
ColumnarToRowExec(child)
}
} else {
val children = plan.children.map(replaceWithColumnarPlan)
plan.withNewChildren(children)
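The guard above wraps construction of `ArrowColumnarToRowExec` in a try/catch and falls back to the vanilla `ColumnarToRowExec` when the columnar variant cannot handle the child's schema. The same idiom in isolation: the `Fallback` helper and the strings are illustrative, and note that the patch catches any `Throwable`, while this sketch narrows to non-fatal exceptions.

import scala.util.control.NonFatal

object Fallback {
  // Try the optimised path first; on failure, report and use the plain path.
  def withFallback[T](fast: => T)(slow: => T)(onFallback: Throwable => Unit): T =
    try fast catch { case NonFatal(e) => onFallback(e); slow }

  def main(args: Array[String]): Unit = {
    val plan = withFallback[String] {
      // Stand-in for `new ArrowColumnarToRowExec(child)` rejecting a schema.
      throw new UnsupportedOperationException("MapType not supported")
    } {
      "ColumnarToRowExec(child)"
    } { e => println(s"ArrowColumnarToRowExec: Falling back to ColumnarToRow... (${e.getMessage})") }

    println(plan)  // ColumnarToRowExec(child)
  }
}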
@@ -0,0 +1,198 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.intel.oap.execution

import java.nio.file.Files

import com.intel.oap.tpc.util.TPCRunner
import org.apache.log4j.{Level, LogManager}
import org.apache.spark.SparkConf
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.ColumnarShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.test.SharedSparkSession

class PayloadSuite extends QueryTest with SharedSparkSession {

private val MAX_DIRECT_MEMORY = "5000m"
private var runner: TPCRunner = _

private var lPath: String = _
private var rPath: String = _
private val scale = 100

override protected def sparkConf: SparkConf = {
val conf = super.sparkConf
conf.set("spark.memory.offHeap.size", String.valueOf(MAX_DIRECT_MEMORY))
.set("spark.plugins", "com.intel.oap.GazellePlugin")
.set("spark.sql.codegen.wholeStage", "false")
.set("spark.sql.sources.useV1SourceList", "")
.set("spark.oap.sql.columnar.tmp_dir", "/tmp/")
.set("spark.sql.columnar.sort.broadcastJoin", "true")
.set("spark.storage.blockManagerSlaveTimeoutMs", "3600000")
.set("spark.executor.heartbeatInterval", "3600000")
.set("spark.network.timeout", "3601s")
.set("spark.oap.sql.columnar.preferColumnar", "true")
.set("spark.oap.sql.columnar.sortmergejoin", "true")
.set("spark.sql.columnar.codegen.hashAggregate", "false")
.set("spark.sql.columnar.sort", "true")
.set("spark.sql.columnar.window", "true")
.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
.set("spark.unsafe.exceptionOnMemoryLeak", "false")
.set("spark.network.io.preferDirectBufs", "false")
.set("spark.sql.sources.useV1SourceList", "arrow,parquet")
.set("spark.sql.autoBroadcastJoinThreshold", "-1")
.set("spark.oap.sql.columnar.sortmergejoin.lazyread", "true")
.set("spark.oap.sql.columnar.autorelease", "false")
.set("spark.sql.shuffle.partitions", "50")
.set("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "5")
.set("spark.oap.sql.columnar.shuffledhashjoin.buildsizelimit", "200m")
// .set("spark.oap.sql.columnar.rowtocolumnar", "false")
// .set("spark.oap.sql.columnar.columnartorow", "false")
return conf
}

override def beforeAll(): Unit = {
super.beforeAll()
LogManager.getRootLogger.setLevel(Level.WARN)

val lfile = Files.createTempFile("", ".parquet").toFile
lfile.deleteOnExit()
lPath = lfile.getAbsolutePath
spark.range(2).select(col("id"), expr("1").as("kind"),
expr("1").as("key"),
expr("array(1, 2)").as("arr_field"),
expr("array(array(1, 2), array(3, 4))").as("arr_arr_field"),
expr("array(struct(1, 2), struct(1, 2))").as("arr_struct_field"),
expr("array(map(1, 2), map(3,4))").as("arr_map_field"),
expr("struct(1, 2)").as("struct_field"),
expr("struct(1, struct(1, 2))").as("struct_struct_field"),
expr("struct(1, array(1, 2))").as("struct_array_field"),
expr("map(1, 2)").as("map_field"),
expr("map(1, map(3,4))").as("map_map_field"),
expr("map(1, array(1, 2))").as("map_arr_field"),
expr("map(struct(1, 2), 2)").as("map_struct_field"))
.coalesce(1)
.write
.format("parquet")
.mode("overwrite")
.parquet(lPath)

val rfile = Files.createTempFile("", ".parquet").toFile
rfile.deleteOnExit()
rPath = rfile.getAbsolutePath
spark.range(2).select(col("id"), expr("id % 2").as("kind"),
expr("id % 2").as("key"),
expr("array(1, 2)").as("arr_field"),
expr("struct(1, 2)").as("struct_field"))
.coalesce(1)
.write
.format("parquet")
.mode("overwrite")
.parquet(rPath)

spark.catalog.createTable("ltab", lPath, "arrow")
spark.catalog.createTable("rtab", rPath, "arrow")
}

test("Test Array in Sort") {
// spark.sql("SELECT * FROM ltab").printSchema()
val df = spark.sql("SELECT ltab.arr_field FROM ltab order by ltab.kind")
df.explain(false)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined)
assert(df.count == 2)
}

test("Test Nest Array in Sort") {
val df = spark.sql("SELECT ltab.arr_arr_field FROM ltab order by ltab.kind")
df.explain(false)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined)
assert(df.count == 2)
}

test("Test Nest Array in multi-keys Sort") {
val df = spark.sql("SELECT ltab.arr_arr_field FROM ltab order by ltab.kind, ltab.key")
df.explain(false)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined)
assert(df.count == 2)
}

test("Test Struct in Sort") {
val df = spark.sql("SELECT ltab.struct_field FROM ltab order by ltab.kind")
df.explain(false)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined)
assert(df.count() == 2)
}

test("Test Nest Struct in Sort") {
val df = spark.sql("SELECT ltab.struct_struct_field FROM ltab order by ltab.kind")
df.explain(false)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined)
assert(df.count() == 2)
}

test("Test Struct_Array in Sort") {
val df = spark.sql("SELECT ltab.struct_array_field FROM ltab order by ltab.kind")
df.explain(false)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined)
assert(df.count() == 2)
}

test("Test Map in Sort") {
val df = spark.sql("SELECT ltab.map_field FROM ltab order by ltab.kind")
df.explain(false)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined)
assert(df.count() == 2)
}

test("Test Nest Map in Sort") {
val df = spark.sql("SELECT ltab.map_map_field FROM ltab order by ltab.kind")
df.explain(false)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined)
assert(df.count() == 2)
}

test("Test Map_Array in Sort") {
val df = spark.sql("SELECT ltab.map_arr_field FROM ltab order by ltab.kind")
df.explain(false)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined)
assert(df.count() == 2)
}

test("Test Map_Struct in Sort") {
val df = spark.sql("SELECT ltab.map_struct_field FROM ltab order by ltab.kind")
df.explain(false)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined)
assert(df.count() == 2)
}

override def afterAll(): Unit = {
super.afterAll()
}
}
@@ -75,6 +75,7 @@ class ComplexTypeSuite extends QueryTest with SharedSparkSession {
lPath = lfile.getAbsolutePath
spark.range(2).select(col("id"), expr("1").as("kind"),
expr("array(1, 2)").as("arr_field"),
expr("array(\"hello\", \"world\")").as("arr_str_field"),
expr("array(array(1, 2), array(3, 4))").as("arr_arr_field"),
expr("array(struct(1, 2), struct(1, 2))").as("arr_struct_field"),
expr("array(map(1, 2), map(3,4))").as("arr_map_field"),
@@ -202,6 +203,15 @@ class ComplexTypeSuite extends QueryTest with SharedSparkSession {
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isDefined)
}

test("Test Array String in Shuffle split") {
val df = spark.sql("SELECT ltab.arr_str_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
df.printSchema()
df.explain(true)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined)
assert(df.count == 2)
}

override def afterAll(): Unit = {
super.afterAll()
}