Skip to content

Commit

Permalink
Use planner for in-memory scans.
Browse files Browse the repository at this point in the history
  • Loading branch information
marmbrus committed Jun 13, 2014
1 parent 1c04652 commit 8757c8e
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 35 deletions.
14 changes: 7 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.{SetCommand, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
import org.apache.spark.sql.columnar.InMemoryRelation

import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.SparkStrategies
Expand Down Expand Up @@ -166,22 +166,21 @@ class SQLContext(@transient val sparkContext: SparkContext)
val useCompression =
sparkContext.conf.getBoolean("spark.sql.inMemoryColumnarStorage.compressed", false)
val asInMemoryRelation =
InMemoryColumnarTableScan(
currentTable.output, executePlan(currentTable).executedPlan, useCompression)
InMemoryRelation(useCompression, executePlan(currentTable).executedPlan)

catalog.registerTable(None, tableName, SparkLogicalPlan(asInMemoryRelation))
catalog.registerTable(None, tableName, asInMemoryRelation)
}

/** Removes the specified table from the in-memory cache. */
def uncacheTable(tableName: String): Unit = {
EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) match {
// This is kind of a hack to make sure that if this was just an RDD registered as a table,
// we reregister the RDD as a table.
case SparkLogicalPlan(inMem @ InMemoryColumnarTableScan(_, e: ExistingRdd, _)) =>
case inMem @ InMemoryRelation(_, _, e: ExistingRdd) =>
inMem.cachedColumnBuffers.unpersist()
catalog.unregisterTable(None, tableName)
catalog.registerTable(None, tableName, SparkLogicalPlan(e))
case SparkLogicalPlan(inMem: InMemoryColumnarTableScan) =>
case inMem: InMemoryRelation =>
inMem.cachedColumnBuffers.unpersist()
catalog.unregisterTable(None, tableName)
case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan")
Expand All @@ -192,7 +191,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
def isCached(tableName: String): Boolean = {
val relation = catalog.lookupRelation(None, tableName)
EliminateAnalysisOperators(relation) match {
case SparkLogicalPlan(_: InMemoryColumnarTableScan) => true
case _: InMemoryRelation => true
case _ => false
}
}
Expand All @@ -208,6 +207,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
PartialAggregation ::
LeftSemiJoin ::
HashJoin ::
InMemoryScans ::
ParquetOperations ::
BasicOperators ::
CartesianProduct ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,29 @@

package org.apache.spark.sql.columnar

import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{SparkPlan, LeafNode}
import org.apache.spark.sql.Row
import org.apache.spark.SparkConf

private[sql] case class InMemoryColumnarTableScan(
attributes: Seq[Attribute],
child: SparkPlan,
useCompression: Boolean)
extends LeafNode {
object InMemoryRelation {
def apply(useCompression: Boolean, child: SparkPlan): InMemoryRelation =
new InMemoryRelation(child.output, useCompression, child)
}

override def output: Seq[Attribute] = attributes
private[sql] case class InMemoryRelation(
output: Seq[Attribute],
useCompression: Boolean,
child: SparkPlan)
extends LogicalPlan with MultiInstanceRelation {

override def children = Seq.empty
override def references = Set.empty

override def newInstance() =
new InMemoryRelation(output.map(_.newInstance), useCompression, child).asInstanceOf[this.type]

lazy val cachedColumnBuffers = {
val output = child.output
Expand All @@ -55,14 +66,26 @@ private[sql] case class InMemoryColumnarTableScan(
cached.count()
cached
}
}

private[sql] case class InMemoryColumnarTableScan(
attributes: Seq[Attribute],
relation: InMemoryRelation)
extends LeafNode {

override def output: Seq[Attribute] = attributes

override def execute() = {
cachedColumnBuffers.mapPartitions { iterator =>
relation.cachedColumnBuffers.mapPartitions { iterator =>
val columnBuffers = iterator.next()
assert(!iterator.hasNext)

new Iterator[Row] {
val columnAccessors = columnBuffers.map(ColumnAccessor(_))
// Find the ordinals of the requested columns. If none are requested, use the first.
val requestedColumns =
if (attributes.isEmpty) Seq(0) else attributes.map(relation.output.indexOf(_))

val columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_))
val nextRow = new GenericMutableRow(columnAccessors.length)

override def next() = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)
SparkLogicalPlan(
alreadyPlanned match {
case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
case scan @ InMemoryColumnarTableScan(output, _, _) =>
scan.copy(attributes = output.map(_.newInstance))
case _ => sys.error("Multiple instance of the same relation detected.")
}).asInstanceOf[this.type]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.parquet._
import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}

private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SQLContext#SparkPlanner =>
Expand Down Expand Up @@ -191,6 +192,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}

object InMemoryScans extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(projectList, filters, mem: InMemoryRelation) =>
pruneFilterProject(
projectList,
filters,
identity[Seq[Expression]], // No filters are pushed down.
InMemoryColumnarTableScan(_, mem)) :: Nil
case _ => Nil
}
}

// Can we automate these 'pass through' operations?
object BasicOperators extends Strategy {
def numPartitions = self.numPartitions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
package org.apache.spark.sql

import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
import org.apache.spark.sql.execution.SparkLogicalPlan
import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
import org.apache.spark.sql.test.TestSQLContext

class CachedTableSuite extends QueryTest {
Expand All @@ -34,7 +33,7 @@ class CachedTableSuite extends QueryTest {
)

TestSQLContext.table("testData").queryExecution.analyzed match {
case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching
case _ : InMemoryRelation => // Found evidence of caching
case noCache => fail(s"No cache node found in plan $noCache")
}

Expand All @@ -46,7 +45,7 @@ class CachedTableSuite extends QueryTest {
)

TestSQLContext.table("testData").queryExecution.analyzed match {
case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) =>
case cachePlan: InMemoryRelation =>
fail(s"Table still cached after uncache: $cachePlan")
case noCache => // Table uncached successfully
}
Expand All @@ -61,13 +60,17 @@ class CachedTableSuite extends QueryTest {
test("SELECT Star Cached Table") {
TestSQLContext.sql("SELECT * FROM testData").registerAsTable("selectStar")
TestSQLContext.cacheTable("selectStar")
TestSQLContext.sql("SELECT * FROM selectStar")
TestSQLContext.sql("SELECT * FROM selectStar WHERE key = 1").collect()
TestSQLContext.uncacheTable("selectStar")
}

test("Self-join cached") {
val unCachedAnswer =
TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect()
TestSQLContext.cacheTable("testData")
TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key")
checkAnswer(
TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"),
unCachedAnswer.toSeq)
TestSQLContext.uncacheTable("testData")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ class InMemoryColumnarQuerySuite extends QueryTest {

test("simple columnar query") {
val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan
val scan = SparkLogicalPlan(InMemoryColumnarTableScan(plan.output, plan, true))
val scan = InMemoryRelation(useCompression = true, plan)

checkAnswer(scan, testData.collect().toSeq)
}

test("projection") {
val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
val scan = SparkLogicalPlan(InMemoryColumnarTableScan(plan.output, plan, true))
val scan = InMemoryRelation(useCompression = true, plan)

checkAnswer(scan, testData.collect().map {
case Row(key: Int, value: String) => value -> key
Expand All @@ -44,7 +44,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {

test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan
val scan = SparkLogicalPlan(InMemoryColumnarTableScan(plan.output, plan, true))
val scan = InMemoryRelation(useCompression = true, plan)

checkAnswer(scan, testData.collect().toSeq)
checkAnswer(scan, testData.collect().toSeq)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
CommandStrategy(self),
TakeOrdered,
ParquetOperations,
InMemoryScans,
HiveTableScans,
DataSinks,
Scripts,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.execution.SparkLogicalPlan
import org.apache.spark.sql.hive.execution.{HiveTableScan, InsertIntoHiveTable}
import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}

/* Implicit conversions */
import scala.collection.JavaConversions._
Expand Down Expand Up @@ -130,8 +130,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
castChildOutput(p, table, child)

case p @ logical.InsertIntoTable(SparkLogicalPlan(InMemoryColumnarTableScan(
_, HiveTableScan(_, table, _), _)), _, child, _) =>
case p @ logical.InsertIntoTable(
InMemoryRelation(_, _,
HiveTableScan(_, table, _)), _, child, _) =>
castChildOutput(p, table, child)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
import org.apache.spark.sql.columnar.InMemoryRelation

private[hive] trait HiveStrategies {
// Possibly being too clever with types here... or not clever enough.
Expand All @@ -44,8 +44,9 @@ private[hive] trait HiveStrategies {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) =>
InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
case logical.InsertIntoTable(SparkLogicalPlan(InMemoryColumnarTableScan(
_, HiveTableScan(_, table, _), _)), partition, child, overwrite) =>
case logical.InsertIntoTable(
InMemoryRelation(_, _,
HiveTableScan(_, table, _)), partition, child, overwrite) =>
InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
case _ => Nil
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.hive

import org.apache.spark.sql.execution.SparkLogicalPlan
import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
import org.apache.spark.sql.hive.execution.HiveComparisonTest
import org.apache.spark.sql.hive.test.TestHive

Expand All @@ -34,7 +34,7 @@ class CachedTableSuite extends HiveComparisonTest {

test("check that table is cached and uncache") {
TestHive.table("src").queryExecution.analyzed match {
case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching
case _ : InMemoryRelation => // Found evidence of caching
case noCache => fail(s"No cache node found in plan $noCache")
}
TestHive.uncacheTable("src")
Expand All @@ -45,7 +45,7 @@ class CachedTableSuite extends HiveComparisonTest {

test("make sure table is uncached") {
TestHive.table("src").queryExecution.analyzed match {
case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) =>
case cachePlan: InMemoryRelation =>
fail(s"Table still cached after uncache: $cachePlan")
case noCache => // Table uncached successfully
}
Expand Down

0 comments on commit 8757c8e

Please sign in to comment.