Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-762] Add complex types support for ColumnarSortExec #763

Merged
merged 10 commits into from
Mar 22, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ public static ArrowWritableColumnVector[] allocateColumns(
int capacity, StructType schema) {
String timeZoneId = SparkSchemaUtils.getLocalTimezoneID();
Schema arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId);
System.setProperty("arrow.struct.conflict.policy", AbstractStructVector.ConflictPolicy.CONFLICT_APPEND.toString());
VectorSchemaRoot new_root =
VectorSchemaRoot.create(arrowSchema, SparkMemoryUtils.contextAllocator());

List<FieldVector> fieldVectors = new_root.getFieldVectors();
ArrowWritableColumnVector[] vectors =
new ArrowWritableColumnVector[fieldVectors.size()];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ object RowToColumnConverter {
case StringType => StringConverter
case BinaryType => BinaryConverter
case CalendarIntervalType => CalendarConverter
case at: ArrayType => new ArrayConverter(getConverterForType(at.elementType, nullable))
case at: ArrayType => new ArrayConverter(getConverterForType(at.elementType, at.containsNull))
case st: StructType => new StructConverter(st.fields.map(
(f) => getConverterForType(f.dataType, f.nullable)))
case dt: DecimalType => new DecimalConverter(dt)
case mt: MapType => new MapConverter(getConverterForType(mt.keyType, nullable),
getConverterForType(mt.valueType, nullable))
case mt: MapType => new MapConverter(getConverterForType(mt.keyType, false),
getConverterForType(mt.valueType, mt.valueContainsNull))
case unknown => throw new UnsupportedOperationException(
s"Type $unknown not supported")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
import org.apache.spark.util.{ExecutorManager, UserAddedJarUtils, Utils}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.types._

import scala.util.control.Breaks.{break, breakable}

Expand Down Expand Up @@ -92,19 +92,32 @@ case class ColumnarSortExec(
buildCheck()

// Validates this sort operator: each output attribute's data type and each sort key's
// data type is passed to a ConverterUtils check, and an UnsupportedOperationException
// naming the offending type is thrown when a check rejects it.
// NOTE(review): this span is a rendered diff hunk without +/- markers, so a superseded
// line and its replacement appear back to back (the bare checkIfTypeSupported call
// directly under `try`, and the "ColumnarSorter." message, look like the pre-change
// lines) — confirm against the merged file before reusing this text as code.
def buildCheck(): Unit = {
// Session-level plugin configuration; enableComplexType selects which type check is used below.
val columnarConf: GazellePluginConfig = GazellePluginConfig.getSessionConf
// check types
for (attr <- output) {
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
// Complex types disabled -> basic type check; enabled -> complex-type check.
if (!columnarConf.enableComplexType) {
ConverterUtils.checkIfTypeSupported(attr.dataType)
} else {
ConverterUtils.checkIfComplexTypeSupported(attr.dataType)
}
} catch {
// NOTE(review): the caught exception `e` is not passed on as the cause of the rethrown one.
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarSorter.")
s"${attr.dataType} is not supported in ColumnarSortExec.")
}
}
// check expr
// NOTE(review): `map` is used only for its side effects here; the resulting list is discarded.
sortOrder.toList.map(expr => {
// The return value is discarded; presumably called so an unsupported expression throws — confirm.
ColumnarExpressionConverter.replaceWithColumnarExpression(expr.child)
val attr = ConverterUtils.getAttrFromExpr(expr.child, true)
// Sort keys always go through the basic type check, regardless of enableComplexType.
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarSortExec keys.")
}
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ class ColumnarSorter(
.toArray)
val outputFieldList: List[Field] = outputAttributes.toList.map(expr => {
val attr = ConverterUtils.getAttrFromExpr(expr, true)
Field.nullable(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
CodeGeneration.getResultType(attr.dataType))
ConverterUtils.createArrowField(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
attr.dataType)
})
val arrowSchema = new Schema(outputFieldList.asJava)
var sort_iterator: BatchIterator = _
Expand Down Expand Up @@ -180,8 +180,8 @@ object ColumnarSorter extends Logging {
def checkIfKeyFound(sortOrder: Seq[SortOrder], outputAttributes: Seq[Attribute]): Unit = {
val outputFieldList: List[Field] = outputAttributes.toList.map(expr => {
val attr = ConverterUtils.getAttrFromExpr(expr, true)
Field.nullable(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
CodeGeneration.getResultType(attr.dataType))
ConverterUtils.createArrowField(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
attr.dataType)
})
sortOrder.toList.foreach(sort => {
val attr = ConverterUtils.getAttrFromExpr(sort.child, true)
Expand Down Expand Up @@ -346,8 +346,8 @@ object ColumnarSorter extends Logging {
_sparkConf: SparkConf): (ExpressionTree, Schema) = {
val outputFieldList: List[Field] = outputAttributes.toList.map(expr => {
val attr = ConverterUtils.getAttrFromExpr(expr, true)
Field.nullable(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
CodeGeneration.getResultType(attr.dataType))
ConverterUtils.createArrowField(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
attr.dataType)
})
val retType = Field.nullable("res", new ArrowType.Int(32, true))
val sort_node =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,25 @@ object ConverterUtils extends Logging {
throw new UnsupportedOperationException(s"Unsupported data type: $dt")
}

/**
 * Accepts the data types allowed when complex-type support is switched on.
 *
 * Only the top-level type is inspected: array, struct and map types are accepted
 * without looking at their element, field, key or value types.
 *
 * @param dt the data type to validate
 * @throws UnsupportedOperationException if `dt` is not one of the accepted types
 */
def checkIfComplexTypeSupported(dt: DataType): Unit = {
  val supported = dt match {
    // Container types.
    case _: ArrayType | _: StructType | _: MapType => true
    // Integral and boolean types.
    case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType => true
    // Fractional and decimal types.
    case _: FloatType | _: DoubleType | _: DecimalType => true
    // String and temporal types.
    case _: StringType | _: DateType | _: TimestampType => true
    case _ => false
  }
  if (!supported) {
    throw new UnsupportedOperationException(s"Unsupported data type: $dt")
  }
}

def checkIfNestTypeSupported(dt: DataType): Unit = dt match {
case d: ArrayType => checkIfTypeSupported(d.elementType)
case d: StructType =>
Expand Down Expand Up @@ -546,7 +565,7 @@ object ConverterUtils extends Logging {
new Field(
name,
new FieldType(nullable, ArrowType.List.INSTANCE, null),
Lists.newArrayList(createArrowField(s"${name}_${dt}", at.elementType)))
Lists.newArrayList(createArrowField("element", at.elementType)))
case st: StructType =>
val fieldlist = new util.ArrayList[Field]
var structField = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,13 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] {
if (columnarConf.enableArrowColumnarToRow) {
val child = replaceWithColumnarPlan(plan.child)
logDebug(s"ColumnarPostOverrides ArrowColumnarToRowExec(${child.getClass})")
new ArrowColumnarToRowExec(child)
try {
new ArrowColumnarToRowExec(child)
} catch {
case _: Throwable =>
logInfo("ArrowColumnarToRowExec: Falling back to ColumnarToRow...")
ColumnarToRowExec(child)
}
} else {
val children = plan.children.map(replaceWithColumnarPlan)
plan.withNewChildren(children)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.intel.oap.execution

import java.nio.file.Files

import com.intel.oap.tpc.util.TPCRunner
import org.apache.log4j.{Level, LogManager}
import org.apache.spark.SparkConf
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.ColumnarShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.test.SharedSparkSession

/**
 * Checks that complex-typed payload columns (arrays, structs, maps and nested
 * combinations of them) can be selected through a sort on primitive key columns,
 * and that the executed plan for such a query contains a [[ColumnarSortExec]].
 *
 * Fixture data is written once in `beforeAll` as parquet and registered as the tables
 * `ltab` and `rtab` using the "arrow" source. `ltab` has two rows whose `kind` and
 * `key` columns are both the constant 1.
 *
 * NOTE(review): `rtab`, `arr_struct_field` and `arr_map_field` are created here but are
 * not queried by any test in this suite.
 */
class PayloadSuite extends QueryTest with SharedSparkSession {

  private val MAX_DIRECT_MEMORY = "5000m"

  // Locations of the parquet data backing ltab / rtab; assigned in beforeAll.
  private var lPath: String = _
  private var rPath: String = _

  /** Spark configuration for the suite: the shared test defaults plus the plugin settings. */
  override protected def sparkConf: SparkConf = {
    // SparkConf.set mutates the conf and returns it, so the chain yields the same instance.
    super.sparkConf
      .set("spark.memory.offHeap.size", MAX_DIRECT_MEMORY)
      .set("spark.plugins", "com.intel.oap.GazellePlugin")
      .set("spark.sql.codegen.wholeStage", "false")
      .set("spark.oap.sql.columnar.tmp_dir", "/tmp/")
      .set("spark.sql.columnar.sort.broadcastJoin", "true")
      .set("spark.storage.blockManagerSlaveTimeoutMs", "3600000")
      .set("spark.executor.heartbeatInterval", "3600000")
      .set("spark.network.timeout", "3601s")
      .set("spark.oap.sql.columnar.preferColumnar", "true")
      .set("spark.oap.sql.columnar.sortmergejoin", "true")
      .set("spark.sql.columnar.codegen.hashAggregate", "false")
      .set("spark.sql.columnar.sort", "true")
      .set("spark.sql.columnar.window", "true")
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
      .set("spark.unsafe.exceptionOnMemoryLeak", "false")
      .set("spark.network.io.preferDirectBufs", "false")
      // Set exactly once: an earlier "" value for this key was always overwritten by this one.
      .set("spark.sql.sources.useV1SourceList", "arrow,parquet")
      .set("spark.sql.autoBroadcastJoinThreshold", "-1")
      .set("spark.oap.sql.columnar.sortmergejoin.lazyread", "true")
      .set("spark.oap.sql.columnar.autorelease", "false")
      .set("spark.sql.shuffle.partitions", "50")
      .set("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "5")
      .set("spark.oap.sql.columnar.shuffledhashjoin.buildsizelimit", "200m")
      // .set("spark.oap.sql.columnar.rowtocolumnar", "false")
      // .set("spark.oap.sql.columnar.columnartorow", "false")
  }

  /** Reserves a unique temporary `.parquet` path and returns its absolute form. */
  private def newParquetPath(): String = {
    val placeholder = Files.createTempFile("", ".parquet").toFile
    placeholder.deleteOnExit()
    placeholder.getAbsolutePath
  }

  /** Deletes `file`, descending into it first when it is a directory. Best effort. */
  private def deleteRecursively(file: java.io.File): Unit = {
    // listFiles() returns null for anything that is not a readable directory.
    Option(file.listFiles()).foreach(_.foreach(deleteRecursively))
    file.delete()
    ()
  }

  /** Writes the fixture data and registers it as the tables `ltab` and `rtab`. */
  override def beforeAll(): Unit = {
    super.beforeAll()
    LogManager.getRootLogger.setLevel(Level.WARN)

    lPath = newParquetPath()
    spark.range(2).select(col("id"), expr("1").as("kind"),
      expr("1").as("key"),
      expr("array(1, 2)").as("arr_field"),
      expr("array(array(1, 2), array(3, 4))").as("arr_arr_field"),
      expr("array(struct(1, 2), struct(1, 2))").as("arr_struct_field"),
      expr("array(map(1, 2), map(3,4))").as("arr_map_field"),
      expr("struct(1, 2)").as("struct_field"),
      expr("struct(1, struct(1, 2))").as("struct_struct_field"),
      expr("struct(1, array(1, 2))").as("struct_array_field"),
      expr("map(1, 2)").as("map_field"),
      expr("map(1, map(3,4))").as("map_map_field"),
      expr("map(1, array(1, 2))").as("map_arr_field"),
      expr("map(struct(1, 2), 2)").as("map_struct_field"))
      .coalesce(1)
      .write
      .mode("overwrite")
      .parquet(lPath)

    rPath = newParquetPath()
    spark.range(2).select(col("id"), expr("id % 2").as("kind"),
      expr("id % 2").as("key"),
      expr("array(1, 2)").as("arr_field"),
      expr("struct(1, 2)").as("struct_field"))
      .coalesce(1)
      .write
      .mode("overwrite")
      .parquet(rPath)

    spark.catalog.createTable("ltab", lPath, "arrow")
    spark.catalog.createTable("rtab", rPath, "arrow")
  }

  /**
   * Selects one `ltab` column ordered by `sortKeys`, then asserts that the executed plan
   * contains a [[ColumnarSortExec]] and that both fixture rows are returned.
   *
   * @param payloadColumn name of the `ltab` column to project
   * @param sortKeys      comma-separated ORDER BY expressions
   */
  private def assertSortKeepsPayload(
      payloadColumn: String,
      sortKeys: String = "ltab.kind"): Unit = {
    val df = spark.sql(s"SELECT ltab.$payloadColumn FROM ltab order by $sortKeys")
    df.explain(false)
    df.show()
    assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarSortExec]).isDefined)
    assert(df.count() == 2)
  }

  test("Test Array in Sort") {
    assertSortKeepsPayload("arr_field")
  }

  test("Test Nest Array in Sort") {
    assertSortKeepsPayload("arr_arr_field")
  }

  test("Test Nest Array in multi-keys Sort") {
    assertSortKeepsPayload("arr_arr_field", sortKeys = "ltab.kind, ltab.key")
  }

  test("Test Struct in Sort") {
    assertSortKeepsPayload("struct_field")
  }

  test("Test Nest Struct in Sort") {
    assertSortKeepsPayload("struct_struct_field")
  }

  test("Test Struct_Array in Sort") {
    assertSortKeepsPayload("struct_array_field")
  }

  test("Test Map in Sort") {
    assertSortKeepsPayload("map_field")
  }

  test("Test Nest Map in Sort") {
    assertSortKeepsPayload("map_map_field")
  }

  test("Test Map_Array in Sort") {
    assertSortKeepsPayload("map_arr_field")
  }

  test("Test Map_Struct in Sort") {
    assertSortKeepsPayload("map_struct_field")
  }

  /**
   * Tears the suite down and then removes the fixture data.
   *
   * The parquet writer is expected to replace each placeholder temp file with a directory
   * of part files, which `File.deleteOnExit` cannot remove once it is non-empty, so the
   * paths are deleted explicitly here.
   */
  override def afterAll(): Unit = {
    try {
      super.afterAll()
    } finally {
      // Either path may still be null if beforeAll failed part-way through.
      Seq(lPath, rPath).flatMap(Option(_)).foreach { path =>
        deleteRecursively(new java.io.File(path))
      }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class ComplexTypeSuite extends QueryTest with SharedSparkSession {
lPath = lfile.getAbsolutePath
spark.range(2).select(col("id"), expr("1").as("kind"),
expr("array(1, 2)").as("arr_field"),
expr("array(\"hello\", \"world\")").as("arr_str_field"),
expr("array(array(1, 2), array(3, 4))").as("arr_arr_field"),
expr("array(struct(1, 2), struct(1, 2))").as("arr_struct_field"),
expr("array(map(1, 2), map(3,4))").as("arr_map_field"),
Expand Down Expand Up @@ -202,6 +203,15 @@ class ComplexTypeSuite extends QueryTest with SharedSparkSession {
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isDefined)
}

test("Test Array String in Shuffle split") {
  // Project only the array<string> column out of a join between ltab and rtab on `kind`.
  val joined =
    spark.sql("SELECT ltab.arr_str_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
  joined.printSchema()
  joined.explain(true)
  joined.show()

  // The executed plan must contain a columnar shuffle exchange.
  val columnarShuffle = joined.queryExecution.executedPlan
    .find(node => node.isInstanceOf[ColumnarShuffleExchangeExec])
  assert(columnarShuffle.isDefined)
  assert(joined.count == 2)
}

override def afterAll(): Unit = {
super.afterAll()
}
Expand Down
Loading