Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-21657][SQL] optimize explode quadratic memory consumpation #19683

Closed
wants to merge 31 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
ce7c369
[SPARK-21657][SQL] optimize explode quadratic memory consumpation
uzadude Nov 7, 2017
76aa258
fixed "File line length exceeds 100 characters"
uzadude Nov 7, 2017
7a9bc96
fixed "File line length exceeds 100 characters"
uzadude Nov 7, 2017
a3050e9
Fixed ClassCastException
uzadude Nov 7, 2017
b8b5960
fixed ColumnPruningSuite join plan test
uzadude Nov 8, 2017
b825d6b
changed indentation
uzadude Nov 19, 2017
04b5814
Merge branch 'master' into optimize_explode
uzadude Dec 7, 2017
7cb9454
nit change
uzadude Dec 7, 2017
ccc78e9
changes according to @cloud-fan's comments
uzadude Dec 22, 2017
8ef78af
Merge branch 'optimize_explode' of https://github.com/uzadude/spark i…
uzadude Dec 22, 2017
272a059
a working version
uzadude Dec 22, 2017
f9e69a4
join+omitGeneratorReferences -> requiredChildOutput
uzadude Dec 24, 2017
42aa32d
style fix
uzadude Dec 24, 2017
93816b6
fix "No space after token ="
uzadude Dec 24, 2017
6caa0d5
fixing test
uzadude Dec 25, 2017
11867e2
changed requiredChildOutput to unrequiredChildOutput for easier initi…
uzadude Dec 26, 2017
e00ecaf
fixed "There should at least one a single empty line separating group…
uzadude Dec 26, 2017
c3183d0
fixed tests
uzadude Dec 27, 2017
b6b8694
Fixes after last review by @cloud-fan
uzadude Dec 27, 2017
09e6d05
nothing - just to re-run the tests.
uzadude Dec 27, 2017
227c7af
more review changes - mainly consolidating the logic in the Optimizer
uzadude Dec 27, 2017
17db21e
more review fixes, mainly allowing more Optimizer cases after p @ Pro…
uzadude Dec 28, 2017
9edd864
scala style check fix
uzadude Dec 28, 2017
6c07d2b
Merge commit '28778174208664327b75915e83ae5e611360eef3' into optimize…
uzadude Dec 28, 2017
f68bd2d
fix after merge
uzadude Dec 28, 2017
f92ec11
fix resolve check
uzadude Dec 28, 2017
288aa73
unrequiredChildIndex: Seq[Attribute] -> unrequiredChildIndex: Seq[Int]
uzadude Dec 29, 2017
283340f
fixed indentation
uzadude Dec 29, 2017
8f06dda
last review changes.
uzadude Dec 29, 2017
1c6626a
small fixes
uzadude Dec 29, 2017
4edd884
fixed comments by @viirya
uzadude Dec 29, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ class Analyzer(
(oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))

case oldVersion: Generate
if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct? AttributeSet.intersect is special.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, the implementation was identical in class Generate:
def generatedSet: AttributeSet = AttributeSet(generatorOutput)
override def producedAttributes: AttributeSet = AttributeSet(generatorOutput)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah i see

val newOutput = oldVersion.generatorOutput.map(_.newInstance())
(oldVersion, oldVersion.copy(generatorOutput = newOutput))

Expand Down Expand Up @@ -1138,7 +1138,7 @@ class Analyzer(
case g: Generate =>
val maybeResolvedExprs = exprs.map(resolveExpression(_, g))
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child)
(newExprs, g.copy(join = true, child = newChild))
(newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild))

// For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes
// via its children.
Expand Down Expand Up @@ -1578,7 +1578,7 @@ class Analyzer(
resolvedGenerator =
Generate(
generator,
join = projectList.size > 1, // Only join if there are other expressions in SELECT.
unrequiredChildIndex = Nil,
outer = outer,
qualifier = None,
generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,8 @@ trait CheckAnalysis extends PredicateHelper {
// allows to have correlation under it
// but must not host any outer references.
// Note:
// Generator with join=false is treated as Category 4.
case g: Generate if g.join =>
// Generator with requiredChildOutput.isEmpty is treated as Category 4.
case g: Generate if g.requiredChildOutput.nonEmpty =>
failOnInvalidOuterReference(g)

// Category 4: Any other operators not in the above 3 categories
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,12 @@ package object dsl {

def generate(
generator: Generator,
join: Boolean = false,
unrequiredChildIndex: Seq[Int] = Nil,
outer: Boolean = false,
alias: Option[String] = None,
outputNames: Seq[String] = Nil): LogicalPlan =
Generate(generator, join = join, outer = outer, alias,
outputNames.map(UnresolvedAttribute(_)), logicalPlan)
Generate(generator, unrequiredChildIndex, outer,
alias, outputNames.map(UnresolvedAttribute(_)), logicalPlan)

def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -456,12 +456,15 @@ object ColumnPruning extends Rule[LogicalPlan] {
f.copy(child = prunedChild(child, f.references))
case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
e.copy(child = prunedChild(child, e.references))
case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
g.copy(child = prunedChild(g.child, g.references))

// Turn off `join` for Generate if no column from it's child is used
case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
p.copy(child = g.copy(join = false))
// prune unrequired references
case p @ Project(_, g: Generate) if p.references != g.outputSet =>
val requiredAttrs = p.references -- g.producedAttributes ++ g.generator.references
val newChild = prunedChild(g.child, requiredAttrs)
val unrequired = g.generator.references -- p.references
val unrequiredIndices = newChild.output.zipWithIndex.filter(t => unrequired.contains(t._1))
.map(_._2)
p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices))

// Eliminate unneeded attributes from right side of a Left Existence Join.
case j @ Join(_, right, LeftExistence(_), _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
val expressions = expressionList(ctx.expression)
Generate(
UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions),
join = true,
unrequiredChildIndex = Nil,
outer = ctx.OUTER != null,
Some(ctx.tblName.getText.toLowerCase),
ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,13 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
* their output.
*
* @param generator the generator expression
* @param join when true, each output row is implicitly joined with the input tuple that produced
* it.
* @param unrequiredChildIndex this paramter starts as Nil and gets filled by the Optimizer.
* It's used as an optimization for omitting data generation that will
* be discarded next by a projection.
* A common use case is when we explode(array(..)) and are interested
* only in the exploded data and not in the original array. before this
* optimization the array got duplicated for each of its elements,
* causing O(n^^2) memory consumption. (see [SPARK-21657])
* @param outer when true, each input row will be output at least once, even if the output of the
* given `generator` is empty.
* @param qualifier Qualifier for the attributes of generator(UDTF)
Expand All @@ -83,15 +88,17 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
*/
case class Generate(
generator: Generator,
join: Boolean,
unrequiredChildIndex: Seq[Int],
outer: Boolean,
qualifier: Option[String],
generatorOutput: Seq[Attribute],
child: LogicalPlan)
extends UnaryNode {

/** The set of all attributes produced by this node. */
def generatedSet: AttributeSet = AttributeSet(generatorOutput)
lazy val requiredChildOutput: Seq[Attribute] = {
val unrequiredSet = unrequiredChildIndex.toSet
child.output.zipWithIndex.filterNot(t => unrequiredSet.contains(t._2)).map(_._1)
}

override lazy val resolved: Boolean = {
generator.resolved &&
Expand All @@ -114,9 +121,7 @@ case class Generate(
nullableOutput
}

def output: Seq[Attribute] = {
if (join) child.output ++ qualifiedGeneratorOutput else qualifiedGeneratorOutput
}
def output: Seq[Attribute] = requiredChildOutput ++ qualifiedGeneratorOutput
}

case class Filter(condition: Expression, child: LogicalPlan)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,54 +38,64 @@ class ColumnPruningSuite extends PlanTest {
CollapseProject) :: Nil
}

test("Column pruning for Generate when Generate.join = false") {
val input = LocalRelation('a.int, 'b.array(StringType))
test("Column pruning for Generate when Generate.unrequiredChildIndex = child.output") {
val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))

val query = input.generate(Explode('b), join = false).analyze
val query =
input
.generate(Explode('c), outputNames = "explode" :: Nil)
.select('c, 'explode)
.analyze

val optimized = Optimize.execute(query)

val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze
val correctAnswer =
input
.select('c)
.generate(Explode('c), outputNames = "explode" :: Nil)
.analyze

comparePlans(optimized, correctAnswer)
}

test("Column pruning for Generate when Generate.join = true") {
val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))
test("Fill Generate.unrequiredChildIndex if possible") {
val input = LocalRelation('b.array(StringType))

val query =
input
.generate(Explode('c), join = true, outputNames = "explode" :: Nil)
.select('a, 'explode)
.generate(Explode('b), outputNames = "explode" :: Nil)
.select(('explode + 1).as("result"))
.analyze

val optimized = Optimize.execute(query)

val correctAnswer =
input
.select('a, 'c)
.generate(Explode('c), join = true, outputNames = "explode" :: Nil)
.select('a, 'explode)
.generate(Explode('b), unrequiredChildIndex = input.output.zipWithIndex.map(_._2),
outputNames = "explode" :: Nil)
.select(('explode + 1).as("result"))
.analyze

comparePlans(optimized, correctAnswer)
}

test("Turn Generate.join to false if possible") {
val input = LocalRelation('b.array(StringType))
test("Another fill Generate.unrequiredChildIndex if possible") {
val input = LocalRelation('a.int, 'b.int, 'c1.string, 'c2.string)

val query =
input
.generate(Explode('b), join = true, outputNames = "explode" :: Nil)
.select(('explode + 1).as("result"))
.generate(Explode(CreateArray(Seq('c1, 'c2))), outputNames = "explode" :: Nil)
.select('a, 'c1, 'explode)
.analyze

val optimized = Optimize.execute(query)

val correctAnswer =
input
.generate(Explode('b), join = false, outputNames = "explode" :: Nil)
.select(('explode + 1).as("result"))
.select('a, 'c1, 'c2)
.generate(Explode(CreateArray(Seq('c1, 'c2))),
unrequiredChildIndex = Seq(2),
outputNames = "explode" :: Nil)
.analyze

comparePlans(optimized, correctAnswer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -624,14 +624,14 @@ class FilterPushdownSuite extends PlanTest {
test("generate: predicate referenced no generated column") {
val originalQuery = {
testRelationWithArrayType
.generate(Explode('c_arr), true, false, Some("arr"))
.generate(Explode('c_arr), alias = Some("arr"))
.where(('b >= 5) && ('a > 6))
}
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = {
testRelationWithArrayType
.where(('b >= 5) && ('a > 6))
.generate(Explode('c_arr), true, false, Some("arr")).analyze
.generate(Explode('c_arr), alias = Some("arr")).analyze
}

comparePlans(optimized, correctAnswer)
Expand All @@ -640,14 +640,14 @@ class FilterPushdownSuite extends PlanTest {
test("generate: non-deterministic predicate referenced no generated column") {
val originalQuery = {
testRelationWithArrayType
.generate(Explode('c_arr), true, false, Some("arr"))
.generate(Explode('c_arr), alias = Some("arr"))
.where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('col > 6))
}
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = {
testRelationWithArrayType
.where('b >= 5)
.generate(Explode('c_arr), true, false, Some("arr"))
.generate(Explode('c_arr), alias = Some("arr"))
.where('a + Rand(10).as("rnd") > 6 && 'col > 6)
.analyze
}
Expand All @@ -659,14 +659,14 @@ class FilterPushdownSuite extends PlanTest {
val generator = Explode('c_arr)
val originalQuery = {
testRelationWithArrayType
.generate(generator, true, false, Some("arr"))
.generate(generator, alias = Some("arr"))
.where(('b >= 5) && ('c > 6))
}
val optimized = Optimize.execute(originalQuery.analyze)
val referenceResult = {
testRelationWithArrayType
.where('b >= 5)
.generate(generator, true, false, Some("arr"))
.generate(generator, alias = Some("arr"))
.where('c > 6).analyze
}

Expand All @@ -687,7 +687,7 @@ class FilterPushdownSuite extends PlanTest {
test("generate: all conjuncts referenced generated column") {
val originalQuery = {
testRelationWithArrayType
.generate(Explode('c_arr), true, false, Some("arr"))
.generate(Explode('c_arr), alias = Some("arr"))
.where(('col > 6) || ('b > 5)).analyze
}
val optimized = Optimize.execute(originalQuery)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ class PlanParserSuite extends AnalysisTest {
assertEqual(
"select * from t lateral view explode(x) expl as x",
table("t")
.generate(explode, join = true, outer = false, Some("expl"), Seq("x"))
.generate(explode, alias = Some("expl"), outputNames = Seq("x"))
.select(star()))

// Multiple lateral views
Expand All @@ -286,12 +286,12 @@ class PlanParserSuite extends AnalysisTest {
|lateral view explode(x) expl
|lateral view outer json_tuple(x, y) jtup q, z""".stripMargin,
table("t")
.generate(explode, join = true, outer = false, Some("expl"), Seq.empty)
.generate(jsonTuple, join = true, outer = true, Some("jtup"), Seq("q", "z"))
.generate(explode, alias = Some("expl"))
.generate(jsonTuple, outer = true, alias = Some("jtup"), outputNames = Seq("q", "z"))
.select(star()))

// Multi-Insert lateral views.
val from = table("t1").generate(explode, join = true, outer = false, Some("expl"), Seq("x"))
val from = table("t1").generate(explode, alias = Some("expl"), outputNames = Seq("x"))
assertEqual(
"""from t1
|lateral view explode(x) expl as x
Expand All @@ -303,7 +303,7 @@ class PlanParserSuite extends AnalysisTest {
|where s < 10
""".stripMargin,
Union(from
.generate(jsonTuple, join = true, outer = false, Some("jtup"), Seq("q", "z"))
.generate(jsonTuple, alias = Some("jtup"), outputNames = Seq("q", "z"))
.select(star())
.insertInto("t2"),
from.where('s < 10).select(star()).insertInto("t3")))
Expand All @@ -312,10 +312,8 @@ class PlanParserSuite extends AnalysisTest {
val expected = table("t")
.generate(
UnresolvedGenerator(FunctionIdentifier("posexplode"), Seq('x)),
join = true,
outer = false,
Some("posexpl"),
Seq("x", "y"))
alias = Some("posexpl"),
outputNames = Seq("x", "y"))
.select(star())
assertEqual(
"select * from t lateral view posexplode(x) posexpl as x, y",
Expand Down
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2095,7 +2095,7 @@ class Dataset[T] private[sql](
val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr))

withPlan {
Generate(generator, join = true, outer = false,
Generate(generator, unrequiredChildIndex = Nil, outer = false,
qualifier = None, generatorOutput = Nil, planWithBarrier)
}
}
Expand Down Expand Up @@ -2136,7 +2136,7 @@ class Dataset[T] private[sql](
val generator = UserDefinedGenerator(elementSchema, rowFunction, apply(inputColumn).expr :: Nil)

withPlan {
Generate(generator, join = true, outer = false,
Generate(generator, unrequiredChildIndex = Nil, outer = false,
qualifier = None, generatorOutput = Nil, planWithBarrier)
}
}
Expand Down
Loading