diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/ColumnarGuardRule.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/ColumnarGuardRule.scala index cda4b8731..39bd2465c 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/ColumnarGuardRule.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/ColumnarGuardRule.scala @@ -73,10 +73,7 @@ case class ColumnarGuardRule(conf: SparkConf) extends Rule[SparkPlan] { } plan case plan: InMemoryTableScanExec => - if (plan.supportsColumnar) { - return false - } - plan + new ColumnarInMemoryTableScanExec(plan.attributes, plan.predicates, plan.relation) case plan: ProjectExec => if(!enableColumnarProjFilter) return false new ColumnarConditionProjectExec(null, plan.projectList, plan.child) diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/ColumnarPlugin.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/ColumnarPlugin.scala index 46ec91e19..34bcd55e2 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/ColumnarPlugin.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/ColumnarPlugin.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive._ import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.exchange._ import org.apache.spark.sql.execution.joins._ @@ -60,6 +61,9 @@ case class ColumnarPreOverrides(conf: SparkConf) extends Rule[SparkPlan] { case plan: BatchScanExec => logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") new ColumnarBatchScanExec(plan.output, plan.scan) + case plan: InMemoryTableScanExec => + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + new ColumnarInMemoryTableScanExec(plan.attributes, plan.predicates, plan.relation) case plan: ProjectExec => val columnarChild = replaceWithColumnarPlan(plan.child) logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarInMemoryRelation.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarInMemoryRelation.scala new file mode 100644 index 000000000..d4e375e70 --- /dev/null +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarInMemoryRelation.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.oap.execution + +import java.io._ +import org.apache.commons.lang3.StringUtils + +import com.intel.oap.expression._ +import com.intel.oap.vectorized.ArrowWritableColumnVector +import com.intel.oap.vectorized.CloseableColumnBatchIterator +import org.apache.arrow.memory.ArrowBuf +import org.apache.spark.TaskContext +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.columnar.{ + CachedBatch, + CachedBatchSerializer, + SimpleMetricsCachedBatch, + SimpleMetricsCachedBatchSerializer +} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.vectorized.{WritableColumnVector} +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.{LongAccumulator, Utils} +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import sun.misc.Cleaner + +private class Deallocator(var arrowColumnarBatch: Array[ColumnarBatch]) extends Runnable { + + override def run(): Unit = { + try { + Option(arrowColumnarBatch) match { + case Some(buffer) => + //System.out.println(s"ArrowCachedBatch released in DeAllocator, First buffer name is ${buffer(0)}") + buffer.foreach(_.close) + case other => + } + } catch { + case e: Exception => + // We should suppress all possible errors in Cleaner to prevent JVM from being shut down + //System.err.println("ArrowCachedBatch-Deallocator: Error running deallocator") + e.printStackTrace() + } + } +} + +/** + * The default implementation of CachedBatch. + * + * @param numRows The total number of rows in this batch + * @param buffers The buffers for serialized columns + * @param stats The stat of columns + */ +case class ArrowCachedBatch( + var numRows: Int, + var buffer: Array[ColumnarBatch], + stats: InternalRow) + extends SimpleMetricsCachedBatch + with Externalizable { + if (buffer != null) { + //System.out.println(s"ArrowCachedBatch constructed First buffer name is ${buffer(0)}") + Cleaner.create(this, new Deallocator(buffer)) + } + def this() = { + this(0, null, null) + } + def release() = { + //System.out.println(s"ArrowCachedBatch released by clear cache, First buffer name is ${buffer(0)}") + buffer.foreach(_.close) + } + lazy val estimatedSize: Long = { + var size: Long = 0 + buffer.foreach(batch => { + size += ConverterUtils.calcuateEstimatedSize(batch) + }) + //System.out.println(s"ArrowCachedBatch${buffer(0)} estimated size is ${size}") + size + } + override def sizeInBytes: Long = estimatedSize + override def writeExternal(out: ObjectOutput): Unit = { + // System.out.println(s"writeExternal for $this") + val rawArrowData = ConverterUtils.convertToNetty(buffer) + out.writeObject(rawArrowData) + buffer.foreach(_.close) + } + + override def readExternal(in: ObjectInput): Unit = { + numRows = 0 + val rawArrowData = in.readObject().asInstanceOf[Array[Byte]] + buffer = ConverterUtils.convertFromNetty(null, new ByteArrayInputStream(rawArrowData)).toArray + //System.out.println(s"ArrowCachedBatch constructed by deserilizer, First buffer name is ${buffer(0)}") + Cleaner.create(this, new Deallocator(buffer)) + } +} + +/** + * The default implementation of CachedBatchSerializer. + */ +class ArrowColumnarCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { + override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = true + + override def convertColumnarBatchToCachedBatch( + input: RDD[ColumnarBatch], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + val batchSize = conf.columnBatchSize + val useCompression = conf.useCompression + convertForCacheInternal(input, schema, batchSize, useCompression) + } + + override def convertInternalRowToCachedBatch( + input: RDD[InternalRow], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = + throw new IllegalStateException("InternalRow input is not supported") + + def convertForCacheInternal( + input: RDD[ColumnarBatch], + output: Seq[Attribute], + batchSize: Int, + useCompression: Boolean): RDD[CachedBatch] = { + input.mapPartitions { iter => + var processed = false + new Iterator[ArrowCachedBatch] { + def next(): ArrowCachedBatch = { + processed = true + var _numRows: Int = 0 + val _input = new ArrayBuffer[ColumnarBatch]() + while (iter.hasNext) { + val batch = iter.next + if (batch.numRows > 0) { + (0 until batch.numCols).foreach(i => + batch.column(i).asInstanceOf[ArrowWritableColumnVector].retain()) + _numRows += batch.numRows + _input += batch + } + } + // To avoid mem copy, we only save columnVector reference here + val res = ArrowCachedBatch(_numRows, _input.toArray, null) + // System.out.println(s"convertForCacheInternal cachedBatch is ${res}") + res + } + + def hasNext: Boolean = !processed + } + } + } + + override def convertCachedBatchToColumnarBatch( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[ColumnarBatch] = { + val columnIndices = + selectedAttributes.map(a => cacheAttributes.map(o => o.exprId).indexOf(a.exprId)).toArray + def createAndDecompressColumn(cachedIter: Iterator[CachedBatch]): Iterator[ColumnarBatch] = { + val res = new Iterator[ColumnarBatch] { + var iter: Iterator[ColumnarBatch] = null + if (cachedIter.hasNext) { + val cachedColumnarBatch: ArrowCachedBatch = + cachedIter.next.asInstanceOf[ArrowCachedBatch] + // System.out.println( + // s"convertCachedBatchToColumnarBatch cachedBatch is ${cachedColumnarBatch}") + val rawData = cachedColumnarBatch.buffer + + iter = new Iterator[ColumnarBatch] { + val numBatches = rawData.size + var batchIdx = 0 + override def hasNext: Boolean = batchIdx < numBatches + override def next(): ColumnarBatch = { + val vectors = columnIndices.map(i => rawData(batchIdx).column(i)) + vectors.foreach(v => v.asInstanceOf[ArrowWritableColumnVector].retain()) + val numRows = rawData(batchIdx).numRows + batchIdx += 1 + new ColumnarBatch(vectors, numRows) + } + } + } + def next(): ColumnarBatch = + if (iter != null) { + iter.next + } else { + val resultStructType = StructType(selectedAttributes.map(a => + StructField(a.name, a.dataType, a.nullable, a.metadata))) + val resultColumnVectors = + ArrowWritableColumnVector.allocateColumns(0, resultStructType).toArray + new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0) + } + def hasNext: Boolean = iter.hasNext + } + new CloseableColumnBatchIterator(res) + } + input.mapPartitions(createAndDecompressColumn) + } + + override def convertCachedBatchToInternalRow( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[InternalRow] = { + // Find the ordinals and data types of the requested columns. + val columnarBatchRdd = + convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf) + columnarBatchRdd.mapPartitions { batches => + val toUnsafe = UnsafeProjection.create(selectedAttributes, selectedAttributes) + batches.flatMap { batch => batch.rowIterator().asScala.map(toUnsafe) } + } + } + + override def supportsColumnarOutput(schema: StructType): Boolean = true + + override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = + Option(Seq.fill(attributes.length)(classOf[ArrowWritableColumnVector].getName)) + +} diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarInMemoryTableScanExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarInMemoryTableScanExec.scala new file mode 100644 index 000000000..6a3452bba --- /dev/null +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarInMemoryTableScanExec.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.oap.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.columnar.CachedBatch +import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.vectorized.ColumnarBatch + +case class ColumnarInMemoryTableScanExec( + attributes: Seq[Attribute], + predicates: Seq[Expression], + @transient relation: InMemoryRelation) + extends LeafExecNode { + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + + override val nodeName: String = { + relation.cacheBuilder.tableName match { + case Some(_) => + "Scan " + relation.cacheBuilder.cachedName + case _ => + super.nodeName + } + } + + override def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren + + override def doCanonicalize(): SparkPlan = + copy( + attributes = attributes.map(QueryPlan.normalizeExpressions(_, relation.output)), + predicates = predicates.map(QueryPlan.normalizeExpressions(_, relation.output)), + relation = relation.canonicalized.asInstanceOf[InMemoryRelation]) + + override def vectorTypes: Option[Seq[String]] = + relation.cacheBuilder.serializer.vectorTypes(attributes, conf) + + /** + * If true, get data from ColumnVector in ColumnarBatch, which are generally faster. + * If false, get data from UnsafeRow build from CachedBatch + */ + override val supportsColumnar: Boolean = true + + private lazy val columnarInputRDD: RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val buffers = filteredCachedBatches() + relation.cacheBuilder.serializer + .convertCachedBatchToColumnarBatch(buffers, relation.output, attributes, conf) + .map { cb => + numOutputRows += cb.numRows() + cb + } + } + + private lazy val inputRDD: RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + // Using these variables here to avoid serialization of entire objects (if referenced + // directly) within the map Partitions closure. + val relOutput = relation.output + val serializer = relation.cacheBuilder.serializer + + // update SQL metrics + val withMetrics = + filteredCachedBatches().mapPartitions { iter => + iter.map { batch => + numOutputRows += batch.numRows + batch + } + } + serializer.convertCachedBatchToInternalRow(withMetrics, relOutput, attributes, conf) + } + + override def output: Seq[Attribute] = attributes + + private def updateAttribute(expr: Expression): Expression = { + // attributes can be pruned so using relation's output. + // E.g., relation.output is [id, item] but this scan's output can be [item] only. + val attrMap = AttributeMap(relation.cachedPlan.output.zip(relation.output)) + expr.transform { + case attr: Attribute => attrMap.getOrElse(attr, attr) + } + } + + // The cached version does not change the outputPartitioning of the original SparkPlan. + // But the cached version could alias output, so we need to replace output. + override def outputPartitioning: Partitioning = { + relation.cachedPlan.outputPartitioning match { + case e: Expression => updateAttribute(e).asInstanceOf[Partitioning] + case other => other + } + } + + // The cached version does not change the outputOrdering of the original SparkPlan. + // But the cached version could alias output, so we need to replace output. + override def outputOrdering: Seq[SortOrder] = + relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) + + // Accumulators used for testing purposes + lazy val readPartitions = sparkContext.longAccumulator + lazy val readBatches = sparkContext.longAccumulator + + private def filteredCachedBatches(): RDD[CachedBatch] = { + relation.cacheBuilder.cachedColumnBuffers + } + + protected override def doExecute(): RDD[InternalRow] = { + inputRDD + } + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + columnarInputRDD + } +} diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarConditionProjector.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarConditionProjector.scala index 8016bdfd2..3c04ecd1b 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarConditionProjector.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarConditionProjector.scala @@ -22,7 +22,11 @@ import java.util.Objects import java.util.concurrent.TimeUnit import com.google.common.collect.Lists -import com.intel.oap.expression.ColumnarConditionProjector.{FieldOptimizedProjector, FilterProjector, ProjectorWrapper} +import com.intel.oap.expression.ColumnarConditionProjector.{ + FieldOptimizedProjector, + FilterProjector, + ProjectorWrapper +} import com.intel.oap.vectorized.ArrowWritableColumnVector import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ @@ -106,7 +110,11 @@ class ColumnarConditionProjector( val resultSchema = ArrowUtils.fromArrowSchema(resultArrowSchema) if (skip) { logWarning( - s"Will do skip!!!\nconditionArrowSchema is ${conditionArrowSchema}, conditionOrdinalList is ${conditionOrdinalList}, \nprojectionArrowSchema is ${projectionArrowSchema}, projectionOrinalList is ${projectOrdinalList}, \nresult schema is ${resultArrowSchema}") + s"Will do skip!!!\nconditionArrowSchema is ${conditionArrowSchema}," + + s" conditionOrdinalList is ${conditionOrdinalList}, " + + s"\nprojectionArrowSchema is ${projectionArrowSchema}, " + + s"projectionOrinalList is ${projectOrdinalList}, " + + s"\nresult schema is ${resultArrowSchema}") } val conditioner = if (skip == false && condPrepareList != null) { @@ -149,7 +157,11 @@ class ColumnarConditionProjector( val fieldNodesList = prepareList.map(_._1).toList.asJava try { if (withCond) { - new FilterProjector(projectionSchema, resultSchema, fieldNodesList, SelectionVectorType.SV_INT16) + new FilterProjector( + projectionSchema, + resultSchema, + fieldNodesList, + SelectionVectorType.SV_INT16) } else { new FieldOptimizedProjector(projectionSchema, resultSchema, fieldNodesList) } @@ -157,7 +169,8 @@ class ColumnarConditionProjector( case e => logError( s"\noriginalInputAttributes is ${originalInputAttributes} ${originalInputAttributes.map( - _.dataType)}, \nprojectionSchema is ${projectionSchema}, \nresultSchema is ${resultSchema}, \nProjection is ${prepareList.map(_._1.toProtobuf)}") + _.dataType)}, \nprojectionSchema is ${projectionSchema}, \nresultSchema is ${resultSchema}, \nProjection is ${prepareList + .map(_._1.toProtobuf)}") throw e } } @@ -451,7 +464,10 @@ object ColumnarConditionProjector extends Logging { throw new UnsupportedOperationException } - def evaluate(recordBatch: ArrowRecordBatch, numRows: Int, selectionVector: SelectionVector): ColumnarBatch = { + def evaluate( + recordBatch: ArrowRecordBatch, + numRows: Int, + selectionVector: SelectionVector): ColumnarBatch = { throw new UnsupportedOperationException } @@ -461,8 +477,11 @@ object ColumnarConditionProjector extends Logging { /** * Proxy projector that is optimized for field projections. */ - class FieldOptimizedProjector(projectionSchema: Schema, resultSchema: Schema, - exprs: java.util.List[ExpressionTree]) extends ProjectorWrapper { + class FieldOptimizedProjector( + projectionSchema: Schema, + resultSchema: Schema, + exprs: java.util.List[ExpressionTree]) + extends ProjectorWrapper { val fieldExprs = ListBuffer[(ExpressionTree, Int)]() val fieldExprNames = new util.HashSet[String]() @@ -484,17 +503,15 @@ object ColumnarConditionProjector extends Logging { } } - val fieldResultSchema = new Schema( - fieldExprs.map { - case (_, i) => - resultSchema.getFields.get(i) - }.asJava) + val fieldResultSchema = new Schema(fieldExprs.map { + case (_, i) => + resultSchema.getFields.get(i) + }.asJava) - val nonFieldResultSchema = new Schema( - nonFieldExprs.map { - case (_, i) => - resultSchema.getFields.get(i) - }.asJava) + val nonFieldResultSchema = new Schema(nonFieldExprs.map { + case (_, i) => + resultSchema.getFields.get(i) + }.asJava) val nonFieldProjector: Option[Projector] = if (nonFieldExprs.isEmpty) { @@ -502,9 +519,13 @@ object ColumnarConditionProjector extends Logging { } else { Some( Projector.make( - projectionSchema, nonFieldExprs.map { - case (e, _) => e - }.toList.asJava)) + projectionSchema, + nonFieldExprs + .map { + case (e, _) => e + } + .toList + .asJava)) } override def evaluate(recordBatch: ArrowRecordBatch): ColumnarBatch = { @@ -513,15 +534,16 @@ object ColumnarConditionProjector extends Logging { // Execute expression-based projections val nonFieldResultColumnVectors = - ArrowWritableColumnVector.allocateColumns(numRows, + ArrowWritableColumnVector.allocateColumns( + numRows, ArrowUtils.fromArrowSchema(nonFieldResultSchema)) val outputVectors = nonFieldResultColumnVectors - .map(columnVector => { - columnVector.getValueVector - }) - .toList - .asJava + .map(columnVector => { + columnVector.getValueVector + }) + .toList + .asJava nonFieldProjector.foreach { _.evaluate(recordBatch, outputVectors) @@ -564,11 +586,10 @@ object ColumnarConditionProjector extends Logging { inAVs.foreach(_.close()) // Projected vector count check - projectedAVs.foreach { - arrowVector => - if (arrowVector == null) { - throw new IllegalStateException() - } + projectedAVs.foreach { arrowVector => + if (arrowVector == null) { + throw new IllegalStateException() + } } val outputBatch = @@ -582,22 +603,29 @@ object ColumnarConditionProjector extends Logging { } } - class FilterProjector(projectionSchema: Schema, resultSchema: Schema, + class FilterProjector( + projectionSchema: Schema, + resultSchema: Schema, exprs: java.util.List[ExpressionTree], - selectionVectorType: GandivaTypes.SelectionVectorType) extends ProjectorWrapper { + selectionVectorType: GandivaTypes.SelectionVectorType) + extends ProjectorWrapper { val projector = Projector.make(projectionSchema, exprs, selectionVectorType) - override def evaluate(recordBatch: ArrowRecordBatch, numRows: Int, + override def evaluate( + recordBatch: ArrowRecordBatch, + numRows: Int, selectionVector: SelectionVector): ColumnarBatch = { val resultColumnVectors = - ArrowWritableColumnVector.allocateColumns(numRows, ArrowUtils.fromArrowSchema(resultSchema)) + ArrowWritableColumnVector.allocateColumns( + numRows, + ArrowUtils.fromArrowSchema(resultSchema)) val outputVectors = resultColumnVectors - .map(columnVector => { - columnVector.getValueVector - }) - .toList - .asJava + .map(columnVector => { + columnVector.getValueVector + }) + .toList + .asJava projector.evaluate(recordBatch, selectionVector, outputVectors) diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ConverterUtils.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ConverterUtils.scala index 68dad0aa6..f43a044f1 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ConverterUtils.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ConverterUtils.scala @@ -74,29 +74,21 @@ import java.io.{InputStream, OutputStream} import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision} object ConverterUtils extends Logging { + def calcuateEstimatedSize(columnarBatch: ColumnarBatch): Long = { + val cols = (0 until columnarBatch.numCols).toList.map(i => + columnarBatch.column(i).asInstanceOf[ArrowWritableColumnVector].getValueVector()) + val nodes = new java.util.ArrayList[ArrowFieldNode]() + val buffers = new java.util.ArrayList[ArrowBuf]() + cols.foreach(vector => { + appendNodes(vector.asInstanceOf[FieldVector], nodes, buffers); + }) + buffers.asScala.map(_.getPossibleMemoryConsumed()).sum + } def createArrowRecordBatch(columnarBatch: ColumnarBatch): ArrowRecordBatch = { val numRowsInBatch = columnarBatch.numRows() val cols = (0 until columnarBatch.numCols).toList.map(i => columnarBatch.column(i).asInstanceOf[ArrowWritableColumnVector].getValueVector()) createArrowRecordBatch(numRowsInBatch, cols) - - /*val fieldNodes = new ListBuffer[ArrowFieldNode]() - val inputData = new ListBuffer[ArrowBuf]() - for (i <- 0 until columnarBatch.numCols()) { - val inputVector = - columnarBatch.column(i).asInstanceOf[ArrowWritableColumnVector].getValueVector() - fieldNodes += new ArrowFieldNode(numRowsInBatch, inputVector.getNullCount()) - //FIXME for projection + in test - //fieldNodes += new ArrowFieldNode(numRowsInBatch, inputVector.getNullCount()) - inputData += inputVector.getValidityBuffer() - if (inputVector.isInstanceOf[VarCharVector]) { - inputData += inputVector.getOffsetBuffer() - } - inputData += inputVector.getDataBuffer() - //FIXME for projection + in test - //inputData += inputVector.getValidityBuffer() - } - new ArrowRecordBatch(numRowsInBatch, fieldNodes.toList.asJava, inputData.toList.asJava)*/ } def createArrowRecordBatch(numRowsInBatch: Int, cols: List[ValueVector]): ArrowRecordBatch = { @@ -225,13 +217,21 @@ object ConverterUtils extends Logging { def convertFromNetty( attributes: Seq[Attribute], - data: Array[Array[Byte]]): Iterator[ColumnarBatch] = { + data: Array[Array[Byte]], + columnIndices: Array[Int] = null): Iterator[ColumnarBatch] = { if (data.size == 0) { return new Iterator[ColumnarBatch] { override def hasNext: Boolean = false override def next(): ColumnarBatch = { - val resultStructType = StructType( - attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + val resultStructType = if (columnIndices == null) { + StructType( + attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + } else { + StructType( + columnIndices + .map(i => attributes(i)) + .map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + } val resultColumnVectors = ArrowWritableColumnVector.allocateColumns(0, resultStructType).toArray return new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0) @@ -306,7 +306,14 @@ object ConverterUtils extends Logging { val vectors = fromArrowRecordBatch(schema, batch, allocator) val length = batch.getLength batch.close - new ColumnarBatch(vectors.map(_.asInstanceOf[ColumnVector]), length) + if (columnIndices == null) { + new ColumnarBatch(vectors.map(_.asInstanceOf[ColumnVector]), length) + } else { + new ColumnarBatch( + columnIndices.map(i => vectors(i).asInstanceOf[ColumnVector]), + length) + } + } catch { case e: Throwable => messageReader.close diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/columnar/CachedBatchSerializer.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/columnar/CachedBatchSerializer.scala new file mode 100644 index 000000000..f14e7a811 --- /dev/null +++ b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/columnar/CachedBatchSerializer.scala @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar + +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BindReferences, EqualNullSafe, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, Length, LessThan, LessThanOrEqual, Literal, Or, Predicate, StartsWith} +import org.apache.spark.sql.execution.columnar.{ColumnStatisticsSchema, PartitionStatistics} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{AtomicType, BinaryType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.storage.StorageLevel + +/** + * Basic interface that all cached batches of data must support. This is primarily to allow + * for metrics to be handled outside of the encoding and decoding steps in a standard way. + */ +@DeveloperApi +@Since("3.1.0") +trait CachedBatch { + def numRows: Int + def sizeInBytes: Long +} + +/** + * Provides APIs that handle transformations of SQL data associated with the cache/persist APIs. + */ +@DeveloperApi +@Since("3.1.0") +trait CachedBatchSerializer extends Serializable { + /** + * Can `convertColumnarBatchToCachedBatch()` be called instead of + * `convertInternalRowToCachedBatch()` for this given schema? True if it can and false if it + * cannot. Columnar input is only supported if the plan could produce columnar output. Currently + * this is mostly supported by input formats like parquet and orc, but more operations are likely + * to be supported soon. + * @param schema the schema of the data being stored. + * @return True if columnar input can be supported, else false. + */ + def supportsColumnarInput(schema: Seq[Attribute]): Boolean + + /** + * Convert an `RDD[InternalRow]` into an `RDD[CachedBatch]` in preparation for caching the data. + * @param input the input `RDD` to be converted. + * @param schema the schema of the data being stored. + * @param storageLevel where the data will be stored. + * @param conf the config for the query. + * @return The data converted into a format more suitable for caching. + */ + def convertInternalRowToCachedBatch( + input: RDD[InternalRow], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] + + /** + * Convert an `RDD[ColumnarBatch]` into an `RDD[CachedBatch]` in preparation for caching the data. + * This will only be called if `supportsColumnarInput()` returned true for the given schema and + * the plan up to this point would could produce columnar output without modifying it. + * @param input the input `RDD` to be converted. + * @param schema the schema of the data being stored. + * @param storageLevel where the data will be stored. + * @param conf the config for the query. + * @return The data converted into a format more suitable for caching. + */ + def convertColumnarBatchToCachedBatch( + input: RDD[ColumnarBatch], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] + + /** + * Builds a function that can be used to filter batches prior to being decompressed. + * In most cases extending [[SimpleMetricsCachedBatchSerializer]] will provide the filter logic + * necessary. You will need to provide metrics for this to work. [[SimpleMetricsCachedBatch]] + * provides the APIs to hold those metrics and explains the metrics used, really just min and max. + * Note that this is intended to skip batches that are not needed, and the actual filtering of + * individual rows is handled later. + * @param predicates the set of expressions to use for filtering. + * @param cachedAttributes the schema/attributes of the data that is cached. This can be helpful + * if you don't store it with the data. + * @return a function that takes the partition id and the iterator of batches in the partition. + * It returns an iterator of batches that should be decompressed. + */ + def buildFilter( + predicates: Seq[Expression], + cachedAttributes: Seq[Attribute]): (Int, Iterator[CachedBatch]) => Iterator[CachedBatch] + + /** + * Can `convertCachedBatchToColumnarBatch()` be called instead of + * `convertCachedBatchToInternalRow()` for this given schema? True if it can and false if it + * cannot. Columnar output is typically preferred because it is more efficient. Note that + * `convertCachedBatchToInternalRow()` must always be supported as there are other checks that + * can force row based output. + * @param schema the schema of the data being checked. + * @return true if columnar output should be used for this schema, else false. + */ + def supportsColumnarOutput(schema: StructType): Boolean + + /** + * The exact java types of the columns that are output in columnar processing mode. This + * is a performance optimization for code generation and is optional. + * @param attributes the attributes to be output. + * @param conf the config for the query that will read the data. + */ + def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = None + + /** + * Convert the cached data into a ColumnarBatch. This currently is only used if + * `supportsColumnarOutput()` returns true for the associated schema, but there are other checks + * that can force row based output. One of the main advantages of doing columnar output over row + * based output is that the code generation is more standard and can be combined with code + * generation for downstream operations. + * @param input the cached batches that should be converted. + * @param cacheAttributes the attributes of the data in the batch. + * @param selectedAttributes the fields that should be loaded from the data and the order they + * should appear in the output batch. + * @param conf the configuration for the job. + * @return an RDD of the input cached batches transformed into the ColumnarBatch format. + */ + def convertCachedBatchToColumnarBatch( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[ColumnarBatch] + + /** + * Convert the cached batch into `InternalRow`s. If you want this to be performant, code + * generation is advised. + * @param input the cached batches that should be converted. + * @param cacheAttributes the attributes of the data in the batch. + * @param selectedAttributes the field that should be loaded from the data and the order they + * should appear in the output rows. + * @param conf the configuration for the job. + * @return RDD of the rows that were stored in the cached batches. + */ + def convertCachedBatchToInternalRow( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[InternalRow] +} + +/** + * A [[CachedBatch]] that stores some simple metrics that can be used for filtering of batches with + * the [[SimpleMetricsCachedBatchSerializer]]. + * The metrics are returned by the stats value. For each column in the batch 5 columns of metadata + * are needed in the row. + */ +@DeveloperApi +@Since("3.1.0") +trait SimpleMetricsCachedBatch extends CachedBatch { + /** + * Holds stats for each cached column. The optional `upperBound` and `lowerBound` should be + * of the same type as the original column. If they are null, then it is assumed that they + * are not provided, and will not be used for filtering. + * <ul> + * <li>`upperBound` (optional)</li> + * <li>`lowerBound` (Optional)</li> + * <li>`nullCount`: `Int`</li> + * <li>`rowCount`: `Int`</li> + * <li>`sizeInBytes`: `Long`</li> + * </ul> + * These are repeated for each column in the original cached data. + */ + val stats: InternalRow + override def sizeInBytes: Long = + Range.apply(4, stats.numFields, 5).map(stats.getLong).sum +} + +// Currently, uses statistics for all atomic types that are not `BinaryType`. +private object ExtractableLiteral { + def unapply(expr: Expression): Option[Literal] = expr match { + case lit: Literal => lit.dataType match { + case BinaryType => None + case _: AtomicType => Some(lit) + case _ => None + } + case _ => None + } +} + +/** + * Provides basic filtering for [[CachedBatchSerializer]] implementations. + * The requirement to extend this is that all of the batches produced by your serializer are + * instances of [[SimpleMetricsCachedBatch]]. + * This does not calculate the metrics needed to be stored in the batches. That is up to each + * implementation. The metrics required are really just min and max values and those are optional + * especially for complex types. Because those metrics are simple and it is likely that compression + * will also be done on the data we thought it best to let each implementation decide on the most + * efficient way to calculate the metrics, possibly combining them with compression passes that + * might also be done across the data. + */ +@DeveloperApi +@Since("3.1.0") +abstract class SimpleMetricsCachedBatchSerializer extends CachedBatchSerializer with Logging { + override def buildFilter( + predicates: Seq[Expression], + cachedAttributes: Seq[Attribute]): (Int, Iterator[CachedBatch]) => Iterator[CachedBatch] = { + throw new UnsupportedOperationException("buildFilter is not yet supported") + } + /*override def buildFilter( + predicates: Seq[Expression], + cachedAttributes: Seq[Attribute]): (Int, Iterator[CachedBatch]) => Iterator[CachedBatch] = { + val stats = new PartitionStatistics(cachedAttributes) + val statsSchema = stats.schema + + def statsFor(a: Attribute): ColumnStatisticsSchema = { + stats.forAttribute(a) + } + + // Returned filter predicate should return false iff it is impossible for the input expression + // to evaluate to `true` based on statistics collected about this partition batch. + @transient lazy val buildFilter: PartialFunction[Expression, Expression] = { + case And(lhs: Expression, rhs: Expression) + if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) => + (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _) + + case Or(lhs: Expression, rhs: Expression) + if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) => + buildFilter(lhs) || buildFilter(rhs) + + case EqualTo(a: AttributeReference, ExtractableLiteral(l)) => + statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound + case EqualTo(ExtractableLiteral(l), a: AttributeReference) => + statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound + + case EqualNullSafe(a: AttributeReference, ExtractableLiteral(l)) => + statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound + case EqualNullSafe(ExtractableLiteral(l), a: AttributeReference) => + statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound + + case LessThan(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound < l + case LessThan(ExtractableLiteral(l), a: AttributeReference) => l < statsFor(a).upperBound + + case LessThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) => + statsFor(a).lowerBound <= l + case LessThanOrEqual(ExtractableLiteral(l), a: AttributeReference) => + l <= statsFor(a).upperBound + + case GreaterThan(a: AttributeReference, ExtractableLiteral(l)) => l < statsFor(a).upperBound + case GreaterThan(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound < l + + case GreaterThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) => + l <= statsFor(a).upperBound + case GreaterThanOrEqual(ExtractableLiteral(l), a: AttributeReference) => + statsFor(a).lowerBound <= l + + case IsNull(a: Attribute) => statsFor(a).nullCount > 0 + case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 + + case In(a: AttributeReference, list: Seq[Expression]) + if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty => + list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && + l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) + // This is an example to explain how it works, imagine that the id column stored as follows: + // __________________________________________ + // | Partition ID | lowerBound | upperBound | + // |--------------|------------|------------| + // | p1 | '1' | '9' | + // | p2 | '10' | '19' | + // | p3 | '20' | '29' | + // | p4 | '30' | '39' | + // | p5 | '40' | '49' | + // |______________|____________|____________| + // + // A filter: df.filter($"id".startsWith("2")). + // In this case it substr lowerBound and upperBound: + // ________________________________________________________________________________________ + // | Partition ID | lowerBound.substr(0, Length("2")) | upperBound.substr(0, Length("2")) | + // |--------------|-----------------------------------|-----------------------------------| + // | p1 | '1' | '9' | + // | p2 | '1' | '1' | + // | p3 | '2' | '2' | + // | p4 | '3' | '3' | + // | p5 | '4' | '4' | + // |______________|___________________________________|___________________________________| + // + // We can see that we only need to read p1 and p3. + case StartsWith(a: AttributeReference, ExtractableLiteral(l)) => + statsFor(a).lowerBound.substr(0, Length(l)) <= l && + l <= statsFor(a).upperBound.substr(0, Length(l)) + } + + // When we bind the filters we need to do it against the stats schema + val partitionFilters: Seq[Expression] = { + predicates.flatMap { p => + val filter = buildFilter.lift(p) + val boundFilter = + filter.map( + BindReferences.bindReference( + _, + statsSchema, + allowFailures = true)) + + boundFilter.foreach(_ => + filter.foreach(f => logInfo(s"Predicate $p generates partition filter: $f"))) + + // If the filter can't be resolved then we are missing required statistics. + boundFilter.filter(_.resolved) + } + } + + def ret(index: Int, cachedBatchIterator: Iterator[CachedBatch]): Iterator[CachedBatch] = { + val partitionFilter = Predicate.create( + partitionFilters.reduceOption(And).getOrElse(Literal(true)), + cachedAttributes) + + partitionFilter.initialize(index) + val schemaIndex = cachedAttributes.zipWithIndex + + cachedBatchIterator.filter { cb => + val cachedBatch = cb.asInstanceOf[SimpleMetricsCachedBatch] + if (!partitionFilter.eval(cachedBatch.stats)) { + logDebug { + val statsString = schemaIndex.map { case (a, i) => + val value = cachedBatch.stats.get(i, a.dataType) + s"${a.name}: $value" + }.mkString(", ") + s"Skipping partition based on stats $statsString" + } + false + } else { + true + } + } + } + ret + }*/ +} diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala new file mode 100644 index 000000000..3d7efe4ef --- /dev/null +++ b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.TaskContext +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.columnar.{ + CachedBatch, + CachedBatchSerializer, + SimpleMetricsCachedBatch, + SimpleMetricsCachedBatchSerializer +} +import org.apache.spark.sql.execution.{ + InputAdapter, + QueryExecution, + SparkPlan, + WholeStageCodegenExec, + ColumnarToRowExec +} +import org.apache.spark.sql.execution.vectorized.{ + OffHeapColumnVector, + OnHeapColumnVector, + WritableColumnVector +} +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.types.{ + BooleanType, + ByteType, + DoubleType, + FloatType, + IntegerType, + LongType, + ShortType, + StructType, + UserDefinedType +} +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.{LongAccumulator, Utils} + +/** + * The default implementation of CachedBatch. + * + * @param numRows The total number of rows in this batch + * @param buffers The buffers for serialized columns + * @param stats The stat of columns + */ +case class DefaultCachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) + extends SimpleMetricsCachedBatch + +private[sql] case class CachedRDDBuilder( + serializer: CachedBatchSerializer, + storageLevel: StorageLevel, + @transient cachedPlan: SparkPlan, + tableName: Option[String]) { + + @transient @volatile private var _cachedColumnBuffers + : RDD[org.apache.spark.sql.columnar.CachedBatch] = null + + val sizeInBytesStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator + val rowCountStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator + + val cachedName = tableName + .map(n => s"In-memory table $n") + .getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024)) + + def cachedColumnBuffers: RDD[org.apache.spark.sql.columnar.CachedBatch] = { + if (_cachedColumnBuffers == null) { + synchronized { + if (_cachedColumnBuffers == null) { + _cachedColumnBuffers = buildBuffers() + } + } + } + _cachedColumnBuffers + } + + def clearCache(blocking: Boolean = false): Unit = { + if (_cachedColumnBuffers != null) { + synchronized { + if (_cachedColumnBuffers != null) { + _cachedColumnBuffers.foreach(buffer => buffer match { + case b: com.intel.oap.execution.ArrowCachedBatch => + b.release + case other => + }) + _cachedColumnBuffers.unpersist(blocking) + _cachedColumnBuffers = null + } + } + } + } + + def isCachedColumnBuffersLoaded: Boolean = { + _cachedColumnBuffers != null + } + + private def buildBuffers(): RDD[org.apache.spark.sql.columnar.CachedBatch] = { + val cb = serializer.convertColumnarBatchToCachedBatch( + cachedPlan.executeColumnar(), + cachedPlan.output, + storageLevel, + cachedPlan.conf) + + val cached = cb + .map { batch => + sizeInBytesStats.add(batch.sizeInBytes) + rowCountStats.add(batch.numRows) + batch + } + .persist(storageLevel) + cached.setName(cachedName) + cached + } +} + +object InMemoryRelation { + + private[this] var ser: Option[CachedBatchSerializer] = None + private[this] def getSerializer(sqlConf: SQLConf): CachedBatchSerializer = synchronized { + if (ser.isEmpty) { + val serClass = + Utils.classForName("com.intel.oap.execution.ArrowColumnarCachedBatchSerializer") + val instance = serClass.getConstructor().newInstance().asInstanceOf[CachedBatchSerializer] + ser = Some(instance) + } + ser.get + } + + /* Visible for testing */ + private[columnar] def clearSerializer(): Unit = synchronized { ser = None } + + def convertToColumnarIfPossible(plan: SparkPlan): SparkPlan = plan match { + case gen: WholeStageCodegenExec => + gen.child match { + case c2r: ColumnarToRowExec => + c2r.child match { + case ia: InputAdapter => ia.child + case _ => plan + } + case _ => plan + } + case c2r: ColumnarToRowExec => // This matches when whole stage code gen is disabled. + c2r.child + case _ => plan + } + + def apply( + useCompression: Boolean, + batchSize: Int, + storageLevel: StorageLevel, + child: SparkPlan, + tableName: Option[String], + optimizedPlan: LogicalPlan): InMemoryRelation = { + val serializer = getSerializer(optimizedPlan.conf) + val columnarChild = convertToColumnarIfPossible(child) + val cacheBuilder = CachedRDDBuilder(serializer, storageLevel, columnarChild, tableName) + val relation = + new InMemoryRelation(columnarChild.output, cacheBuilder, optimizedPlan.outputOrdering) + relation.statsOfPlanToCache = optimizedPlan.stats + relation + } + + /** + * This API is intended only to be used for testing. + */ + def apply( + serializer: CachedBatchSerializer, + storageLevel: StorageLevel, + child: SparkPlan, + tableName: Option[String], + optimizedPlan: LogicalPlan): InMemoryRelation = { + val cacheBuilder = CachedRDDBuilder(serializer, storageLevel, child, tableName) + val relation = new InMemoryRelation(child.output, cacheBuilder, optimizedPlan.outputOrdering) + relation.statsOfPlanToCache = optimizedPlan.stats + relation + } + + def apply(cacheBuilder: CachedRDDBuilder, qe: QueryExecution): InMemoryRelation = { + val optimizedPlan = qe.optimizedPlan + val newBuilder = if (cacheBuilder.serializer.supportsColumnarInput(optimizedPlan.output)) { + cacheBuilder.copy(cachedPlan = convertToColumnarIfPossible(qe.executedPlan)) + } else { + cacheBuilder.copy(cachedPlan = qe.executedPlan) + } + val relation = + new InMemoryRelation(newBuilder.cachedPlan.output, newBuilder, optimizedPlan.outputOrdering) + relation.statsOfPlanToCache = optimizedPlan.stats + relation + } + + def apply( + output: Seq[Attribute], + cacheBuilder: CachedRDDBuilder, + outputOrdering: Seq[SortOrder], + statsOfPlanToCache: Statistics): InMemoryRelation = { + val relation = InMemoryRelation(output, cacheBuilder, outputOrdering) + relation.statsOfPlanToCache = statsOfPlanToCache + relation + } +} + +case class InMemoryRelation( + output: Seq[Attribute], + @transient cacheBuilder: CachedRDDBuilder, + override val outputOrdering: Seq[SortOrder]) + extends logical.LeafNode + with MultiInstanceRelation { + + @volatile var statsOfPlanToCache: Statistics = null + + override def innerChildren: Seq[SparkPlan] = Seq(cachedPlan) + + override def doCanonicalize(): logical.LogicalPlan = + copy( + output = output.map(QueryPlan.normalizeExpressions(_, cachedPlan.output)), + cacheBuilder, + outputOrdering) + + @transient val partitionStatistics = new PartitionStatistics(output) + + def cachedPlan: SparkPlan = cacheBuilder.cachedPlan + + private[sql] def updateStats(rowCount: Long, newColStats: Map[Attribute, ColumnStat]): Unit = + this.synchronized { + val newStats = statsOfPlanToCache.copy( + rowCount = Some(rowCount), + attributeStats = AttributeMap((statsOfPlanToCache.attributeStats ++ newColStats).toSeq)) + statsOfPlanToCache = newStats + } + + override def computeStats(): Statistics = { + if (!cacheBuilder.isCachedColumnBuffersLoaded) { + // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. + statsOfPlanToCache + } else { + statsOfPlanToCache.copy( + sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue, + rowCount = Some(cacheBuilder.rowCountStats.value.longValue)) + } + } + + def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = + InMemoryRelation(newOutput, cacheBuilder, outputOrdering, statsOfPlanToCache) + + override def newInstance(): this.type = { + InMemoryRelation( + output.map(_.newInstance()), + cacheBuilder, + outputOrdering, + statsOfPlanToCache).asInstanceOf[this.type] + } + + // override `clone` since the default implementation won't carry over mutable states. + override def clone(): LogicalPlan = { + val cloned = this.copy() + cloned.statsOfPlanToCache = this.statsOfPlanToCache + cloned + } + + override def simpleString(maxFields: Int): String = + s"InMemoryRelation [${truncatedString(output, ", ", maxFields)}], ${cacheBuilder.storageLevel}" +}