[SPARK-15214][SQL] Code-generation for Generate
This is a backport of apache@7ca7a63.

## What changes were proposed in this pull request?

This PR adds code generation to `Generate`. It supports two code paths:
- General `TraversableOnce` based iteration. This is used for regular `Generator` expressions that support code generation. This code path expects the expression to return a `TraversableOnce[InternalRow]` and iterates over the returned collection. This PR adds code generation for the `stack` generator.
- Specialized `ArrayData/MapData` based iteration. This is used for the `explode`, `posexplode` & `inline` functions and operates directly on the `ArrayData`/`MapData` result that the child of the generator returns.
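
The difference between the two code paths can be illustrated with a minimal sketch in plain Scala. It is not the generated code from this PR: `Row`, `iterateGeneric`, and `iterateArray` are hypothetical stand-ins, with an `Iterator` playing the role of `TraversableOnce[InternalRow]` and an `Array[Any]` playing the role of `ArrayData`.

```scala
// Simplified stand-ins for illustration only; none of these names are Spark APIs.
object GenerateIterationSketch {
  type Row = Seq[Any]

  // Path 1 (general): the generator hands back an opaque collection of rows
  // and the consumer simply walks it.
  def iterateGeneric(rows: Iterator[Row])(consume: Row => Unit): Unit =
    rows.foreach(consume)

  // Path 2 (specialized): the consumer receives the raw array and indexes into
  // it directly, optionally emitting the element position as posexplode does.
  def iterateArray(data: Array[Any], position: Boolean)(consume: Row => Unit): Unit = {
    var i = 0
    while (i < data.length) {
      consume(if (position) Seq(i, data(i)) else Seq(data(i)))
      i += 1
    }
  }

  def main(args: Array[String]): Unit = {
    iterateGeneric(Iterator(Seq(1, "a"), Seq(2, "b")))(println)
    iterateArray(Array[Any](10, 20), position = true)(println)
  }
}
```

In this sketch the second path never materializes one row object per element before the consumer sees it, which is the kind of overhead the specialized path is presumably meant to avoid.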

### Benchmarks
I have added some benchmarks. With whole-stage code generation on, the generators below run roughly 7x to 15x faster than with it off:
#### Environment
```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
Intel(R) Core(TM) i7-4980HQ CPU  2.80GHz
```
#### Explode Array
##### Before
```
generate explode array:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode array wholestage off         7377 / 7607          2.3         439.7       1.0X
generate explode array wholestage on          6055 / 6086          2.8         360.9       1.2X
```
##### After
```
generate explode array:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode array wholestage off         7432 / 7696          2.3         443.0       1.0X
generate explode array wholestage on           631 /  646         26.6          37.6      11.8X
```
#### Explode Map
##### Before
```
generate explode map:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode map wholestage off         12792 / 12848          1.3         762.5       1.0X
generate explode map wholestage on          11181 / 11237          1.5         666.5       1.1X
```
##### After
```
generate explode map:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode map wholestage off         10949 / 10972          1.5         652.6       1.0X
generate explode map wholestage on             870 /  913         19.3          51.9      12.6X
```
#### Posexplode
##### Before
```
generate posexplode array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate posexplode array wholestage off      7547 / 7580          2.2         449.8       1.0X
generate posexplode array wholestage on       5786 / 5838          2.9         344.9       1.3X
```
##### After
```
generate posexplode array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate posexplode array wholestage off      7535 / 7548          2.2         449.1       1.0X
generate posexplode array wholestage on        620 /  624         27.1          37.0      12.1X
```
#### Inline
##### Before
```
generate inline array:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate inline array wholestage off          6935 / 6978          2.4         413.3       1.0X
generate inline array wholestage on           6360 / 6400          2.6         379.1       1.1X
```
##### After
```
generate inline array:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate inline array wholestage off          6940 / 6966          2.4         413.6       1.0X
generate inline array wholestage on           1002 / 1012         16.7          59.7       6.9X
```
#### Stack
##### Before
```
generate stack:                          Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate stack wholestage off               12980 / 13104          1.3         773.7       1.0X
generate stack wholestage on                11566 / 11580          1.5         689.4       1.1X
```
##### After
```
generate stack:                          Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate stack wholestage off               12875 / 12949          1.3         767.4       1.0X
generate stack wholestage on                   840 /  845         20.0          50.0      15.3X
```
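
The derived columns in these tables can be cross-checked with a little arithmetic. The sketch below assumes each run processes `1 << 24` rows; that row count is not stated in the PR description and is only inferred from the reported times and per-row figures. The names `BenchmarkColumns`, `perRowNs`, `rateMPerSec`, and `relative` are hypothetical.

```scala
object BenchmarkColumns {
  // Assumption: 1 << 24 (16,777,216) rows per run, inferred from the tables.
  val assumedRows: Long = 1L << 24

  // Per Row(ns): best time converted to nanoseconds, divided by the row count.
  def perRowNs(bestMs: Double): Double = bestMs * 1e6 / assumedRows

  // Rate(M/s): millions of rows processed per second.
  def rateMPerSec(bestMs: Double): Double = (assumedRows / 1e6) / (bestMs / 1e3)

  // Relative: best time of the first row in the table over this row's best time.
  def relative(baselineBestMs: Double, bestMs: Double): Double = baselineBestMs / bestMs

  def main(args: Array[String]): Unit = {
    // "Explode Array / After": wholestage off = 7432 ms, wholestage on = 631 ms.
    println(f"per row:  ${perRowNs(631)}%.1f ns")
    println(f"rate:     ${rateMPerSec(631)}%.1f M/s")
    println(f"relative: ${relative(7432, 631)}%.1fX")
  }
}
```

Because the table reports times rounded to whole milliseconds, recomputed values can differ from the printed ones in the last digit.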
## How was this patch tested?

Existing tests.

Author: Herman van Hovell <[email protected]>
Author: Kousuke Saruta <[email protected]>

Closes apache#230 from hvanhovell/SPARK-15214.
hvanhovell committed Feb 17, 2017
1 parent 2a5eab5 commit ed284c0
Showing 7 changed files with 485 additions and 37 deletions.
@@ -17,10 +17,12 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._

@@ -60,6 +62,26 @@ trait Generator extends Expression {
* rows can be made here.
*/
def terminate(): TraversableOnce[InternalRow] = Nil

/**
* Check if this generator supports code generation.
*/
def supportCodegen: Boolean = !isInstanceOf[CodegenFallback]
}

/**
* A collection producing [[Generator]]. This trait provides a different path for code generation,
* by allowing code generation to return either an [[ArrayData]] or a [[MapData]] object.
*/
trait CollectionGenerator extends Generator {
/** The position of an element within the collection should also be returned. */
def position: Boolean

/** Rows will be inlined during generation. */
def inline: Boolean

/** The type of the returned collection object. */
def collectionType: DataType = dataType
}

/**
@@ -77,7 +99,9 @@ case class UserDefinedGenerator(
private def initializeConverters(): Unit = {
inputRow = new InterpretedProjection(children)
convertToScala = {
val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
val inputSchema = StructType(children.map { e =>
StructField(e.simpleString, e.dataType, nullable = true)
})
CatalystTypeConverters.createToScalaConverter(inputSchema)
}.asInstanceOf[InternalRow => Row]
}
@@ -109,8 +133,7 @@ case class UserDefinedGenerator(
1 2
3 NULL
""")
case class Stack(children: Seq[Expression])
extends Expression with Generator with CodegenFallback {
case class Stack(children: Seq[Expression]) extends Generator {

private lazy val numRows = children.head.eval().asInstanceOf[Int]
private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt
@@ -149,29 +172,58 @@ case class Stack(children: Seq[Expression])
InternalRow(fields: _*)
}
}


/**
* Only support code generation when stack produces 50 rows or less.
*/
override def supportCodegen: Boolean = numRows <= 50

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Rows - we write these into an array.
val rowData = ctx.freshName("rows")
ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];")
val values = children.tail
val dataTypes = values.take(numFields).map(_.dataType)
val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row =>
val fields = Seq.tabulate(numFields) { col =>
val index = row * numFields + col
if (index < values.length) values(index) else Literal(null, dataTypes(col))
}
val eval = CreateStruct(fields).genCode(ctx)
s"${eval.code}\nthis.$rowData[$row] = ${eval.value};"
})

// Create the collection.
val wrapperClass = classOf[mutable.WrappedArray[_]].getName
ctx.addMutableState(
s"$wrapperClass<InternalRow>",
ev.value,
s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);")
ev.copy(code = code, isNull = "false")
}
}

/**
* A base class for Explode and PosExplode
* A base class for [[Explode]] and [[PosExplode]].
*/
abstract class ExplodeBase(child: Expression, position: Boolean)
extends UnaryExpression with Generator with CodegenFallback with Serializable {
abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with Serializable {
override val inline: Boolean = false

override def checkInputDataTypes(): TypeCheckResult = {
if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) {
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case _: ArrayType | _: MapType =>
TypeCheckResult.TypeCheckSuccess
} else {
case _ =>
TypeCheckResult.TypeCheckFailure(
s"input to function explode should be array or map type, not ${child.dataType}")
}
}

// hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
override def elementSchema: StructType = child.dataType match {
case ArrayType(et, containsNull) =>
if (position) {
new StructType()
.add("pos", IntegerType, false)
.add("pos", IntegerType, nullable = false)
.add("col", et, containsNull)
} else {
new StructType()
@@ -180,12 +232,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
case MapType(kt, vt, valueContainsNull) =>
if (position) {
new StructType()
.add("pos", IntegerType, false)
.add("key", kt, false)
.add("pos", IntegerType, nullable = false)
.add("key", kt, nullable = false)
.add("value", vt, valueContainsNull)
} else {
new StructType()
.add("key", kt, false)
.add("key", kt, nullable = false)
.add("value", vt, valueContainsNull)
}
}
@@ -218,6 +270,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
}
}
}

override def collectionType: DataType = child.dataType

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
child.genCode(ctx)
}
}

/**
@@ -239,7 +297,9 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
20
""")
// scalastyle:on line.size.limit
case class Explode(child: Expression) extends ExplodeBase(child, position = false)
case class Explode(child: Expression) extends ExplodeBase {
override val position: Boolean = false
}

/**
* Given an input array produces a sequence of rows for each position and value in the array.
@@ -260,7 +320,9 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals
1 20
""")
// scalastyle:on line.size.limit
case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
case class PosExplode(child: Expression) extends ExplodeBase {
override val position = true
}

/**
* Explodes an array of structs into a table.
@@ -273,20 +335,24 @@ case class PosExplode(child: Expression) extends ExplodeBase(child, position = t
1 a
2 b
""")
case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
case class Inline(child: Expression) extends UnaryExpression with CollectionGenerator {
override val inline: Boolean = true
override val position: Boolean = false

override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case ArrayType(et, _) if et.isInstanceOf[StructType] =>
case ArrayType(st: StructType, _) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName should be array of struct type, not ${child.dataType}")
}

override def elementSchema: StructType = child.dataType match {
case ArrayType(et : StructType, _) => et
case ArrayType(st: StructType, _) => st
}

override def collectionType: DataType = child.dataType

private lazy val numFields = elementSchema.fields.length

override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -298,4 +364,8 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with
yield inputArray.getStruct(i, numFields)
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
child.genCode(ctx)
}
}
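
The row layout that `Stack.doGenCode` builds in the diff above (index `row * numFields + col`, padded with typed null literals) can be tried out in isolation. The following is a minimal sketch in plain Scala, not Spark code: `StackLayoutSketch` and `layout` are hypothetical names, and `Option` stands in for a nullable field.

```scala
object StackLayoutSketch {
  // Mirrors the index arithmetic from the diff: values are laid out row by row,
  // and positions past the end of the input are padded (None here, where the
  // diff uses a null Literal of the column's data type).
  def layout[A](numRows: Int, values: Seq[A]): Seq[Seq[Option[A]]] = {
    val numFields = math.ceil(values.length.toDouble / numRows).toInt
    Seq.tabulate(numRows) { row =>
      Seq.tabulate(numFields) { col =>
        val index = row * numFields + col
        if (index < values.length) Some(values(index)) else None
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Corresponds to stack(2, 1, 2, 3): two rows, with the last field padded.
    layout(2, Seq(1, 2, 3)).foreach(println)
  }
}
```

Here `values` corresponds to `children.tail` in the diff, so `values.length` equals `children.length - 1` and the `numFields` computation matches the one in `Stack`.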
@@ -17,7 +17,8 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types.{DataType, IntegerType}

class SubexpressionEliminationSuite extends SparkFunSuite {
test("Semantic equals and hash") {
@@ -162,13 +163,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
test("Children of CodegenFallback") {
val one = Literal(1)
val two = Add(one, one)
val explode = Explode(two)
val add = Add(two, explode)
val fallback = CodegenFallbackExpression(two)
val add = Add(two, fallback)

var equivalence = new EquivalentExpressions
val equivalence = new EquivalentExpressions
equivalence.addExprTree(add, true)
// the `two` inside `explode` should not be added
// the `two` inside `fallback` should not be added
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, fallback
}
}

case class CodegenFallbackExpression(child: Expression)
extends UnaryExpression with CodegenFallback {
override def dataType: DataType = child.dataType
}