Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
Support AQE for ColumnarWriter (#526)
Browse files Browse the repository at this point in the history
Signed-off-by: Chendi Xue <[email protected]>
  • Loading branch information
xuechendi authored Sep 29, 2021
1 parent 5aea527 commit 4a0dfbe
Showing 1 changed file with 51 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,21 @@ package com.intel.oap.spark.sql
import com.intel.oap.spark.sql.ArrowWriteExtension.ArrowWritePostRule
import com.intel.oap.spark.sql.ArrowWriteExtension.DummyRule
import com.intel.oap.spark.sql.ArrowWriteExtension.SimpleColumnarRule
import com.intel.oap.spark.sql.ArrowWriteExtension.SimpleStrategy
import com.intel.oap.spark.sql.execution.datasources.arrow.ArrowFileFormat
import com.intel.oap.sql.execution.RowToArrowColumnarExec

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.OrderPreservingUnaryNode

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.catalyst.util.MapData
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.CodegenSupport
import org.apache.spark.sql.execution.ColumnarRule
import org.apache.spark.sql.execution.ColumnarToRowExec
Expand All @@ -48,6 +51,7 @@ import org.apache.spark.unsafe.types.UTF8String
class ArrowWriteExtension extends (SparkSessionExtensions => Unit) {
def apply(e: SparkSessionExtensions): Unit = {
e.injectColumnar(session => SimpleColumnarRule(DummyRule, ArrowWritePostRule(session)))
e.injectPlannerStrategy(session => SimpleStrategy())
}
}

Expand All @@ -68,7 +72,7 @@ object ArrowWriteExtension {
cmd match {
case command: InsertIntoHadoopFsRelationCommand =>
if (command.fileFormat
.isInstanceOf[ArrowFileFormat]) {
.isInstanceOf[ArrowFileFormat]) {
rc.withNewChildren(Array(ColumnarToFakeRowAdaptor(child)))
} else {
plan.withNewChildren(plan.children.map(apply))
Expand All @@ -79,8 +83,20 @@ object ArrowWriteExtension {
cmd match {
case command: InsertIntoHadoopFsRelationCommand =>
if (command.fileFormat
.isInstanceOf[ArrowFileFormat]) {
rc.withNewChildren(Array(ColumnarToFakeRowAdaptor(RowToArrowColumnarExec(child))))
.isInstanceOf[ArrowFileFormat]) {
child match {
case c: AdaptiveSparkPlanExec =>
rc.withNewChildren(
Array(
AdaptiveSparkPlanExec(
ColumnarToFakeRowAdaptor(c.inputPlan),
c.context,
c.preprocessingRules,
c.isSubquery)))
case other =>
rc.withNewChildren(
Array(ColumnarToFakeRowAdaptor(RowToArrowColumnarExec(child))))
}
} else {
plan.withNewChildren(plan.children.map(apply))
}
Expand All @@ -90,18 +106,6 @@ object ArrowWriteExtension {
}
}

private case class ColumnarToFakeRowAdaptor(child: SparkPlan) extends ColumnarToRowTransition {
assert(child.supportsColumnar)

override protected def doExecute(): RDD[InternalRow] = {
child.executeColumnar().map { cb =>
new FakeRow(cb)
}
}

override def output: Seq[Attribute] = child.output
}

class FakeRow(val batch: ColumnarBatch) extends InternalRow {
override def numFields: Int = throw new UnsupportedOperationException()
override def setNullAt(i: Int): Unit = throw new UnsupportedOperationException()
Expand All @@ -117,7 +121,8 @@ object ArrowWriteExtension {
override def getDouble(ordinal: Int): Double = throw new UnsupportedOperationException()
override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal =
throw new UnsupportedOperationException()
override def getUTF8String(ordinal: Int): UTF8String = throw new UnsupportedOperationException()
override def getUTF8String(ordinal: Int): UTF8String =
throw new UnsupportedOperationException()
override def getBinary(ordinal: Int): Array[Byte] = throw new UnsupportedOperationException()
override def getInterval(ordinal: Int): CalendarInterval =
throw new UnsupportedOperationException()
Expand All @@ -128,4 +133,31 @@ object ArrowWriteExtension {
override def get(ordinal: Int, dataType: DataType): AnyRef =
throw new UnsupportedOperationException()
}

private case class ColumnarToFakeRowLogicAdaptor(child: LogicalPlan)
extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = child.output
}

private case class ColumnarToFakeRowAdaptor(child: SparkPlan) extends ColumnarToRowTransition {
if (!child.logicalLink.isEmpty) {
setLogicalLink(ColumnarToFakeRowLogicAdaptor(child.logicalLink.get))
}

override protected def doExecute(): RDD[InternalRow] = {
child.executeColumnar().map { cb => new FakeRow(cb) }
}

override def output: Seq[Attribute] = child.output
}

case class SimpleStrategy() extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ColumnarToFakeRowLogicAdaptor(child: LogicalPlan) =>
Seq(ColumnarToFakeRowAdaptor(planLater(child)))
case other =>
Nil
}
}

}

0 comments on commit 4a0dfbe

Please sign in to comment.