diff --git a/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/IteratorWrapper.java b/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/IteratorWrapper.java new file mode 100644 index 000000000..8f24b38a5 --- /dev/null +++ b/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/IteratorWrapper.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.oap.vectorized; + + +import scala.collection.convert.Wrappers; + +import java.io.Serializable; +import java.util.Iterator; +import java.util.List; + +public class IteratorWrapper { + + private Iterator> in; + + public IteratorWrapper(Iterator> in) { + this.in = in; + } + + public boolean hasNext() { + return in.hasNext(); + } + + public List next() { + return in.next(); + } + +} diff --git a/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ShuffleSplitterJniWrapper.java b/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ShuffleSplitterJniWrapper.java index 93d5f3223..2e47f538c 100644 --- a/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ShuffleSplitterJniWrapper.java +++ b/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ShuffleSplitterJniWrapper.java @@ -83,6 +83,28 @@ public native long nativeMake( long memoryPoolId, boolean writeSchema); + public long make( + NativePartitioning part, + long offheapPerTask, + int bufferSize) { + return initSplit( + part.getShortName(), + part.getNumPartitions(), + part.getSchema(), + part.getExprList(), + offheapPerTask, + bufferSize + ); + } + + public native long initSplit( + String shortName, + int numPartitions, + byte[] schema, + byte[] exprList, + long offheapPerTask, + int bufferSize); + /** * * Spill partition data to disk. @@ -113,6 +135,11 @@ public native long split( long splitterId, int numRows, long[] bufAddrs, long[] bufSizes, boolean firstRecordBatch) throws IOException; + /** + * Collect the record batch after splitting. + */ + public native void collect(long splitterId, int numRows) throws IOException; + /** * Update the compress type. */ @@ -127,6 +154,15 @@ public native long split( */ public native SplitResult stop(long splitterId) throws IOException; + /** + * Clear the buffer. And stop processing splitting + * + * @param splitterId splitter instance id + * @return SplitResult + */ + public native SplitResult clear(long splitterId) throws IOException; + + /** * Release resources associated with designated splitter instance. * diff --git a/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/SplitIterator.java b/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/SplitIterator.java new file mode 100644 index 000000000..b758e6bc7 --- /dev/null +++ b/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/SplitIterator.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.oap.vectorized; + + +import com.intel.oap.expression.ConverterUtils; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.ipc.message.ArrowBuffer; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +import java.io.IOException; +import java.io.Serializable; +import java.util.Iterator; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SplitIterator implements Iterator{ + + private static final Logger logger = LoggerFactory.getLogger(SplitIterator.class); + + public static class IteratorOptions implements Serializable { + private static final long serialVersionUID = -1L; + + private int partitionNum; + + private String name; + + private long offheapPerTask; + + private int bufferSize; + + private String expr; + + public NativePartitioning getNativePartitioning() { + return nativePartitioning; + } + + public void setNativePartitioning(NativePartitioning nativePartitioning) { + this.nativePartitioning = nativePartitioning; + } + + NativePartitioning nativePartitioning; + + public int getPartitionNum() { + return partitionNum; + } + + public void setPartitionNum(int partitionNum) { + this.partitionNum = partitionNum; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public long getOffheapPerTask() { + return offheapPerTask; + } + + public void setOffheapPerTask(long offheapPerTask) { + this.offheapPerTask = offheapPerTask; + } + + public int getBufferSize() { + return bufferSize; + } + + public void setBufferSize(int bufferSize) { + this.bufferSize = bufferSize; + } + + public String getExpr() { + return expr; + } + + public void setExpr(String expr) { + this.expr = expr; + } + + } + + ShuffleSplitterJniWrapper jniWrapper = null; + + private long nativeSplitter = 0; + private final Iterator iterator; + private final IteratorOptions options; + + private ColumnarBatch cb = null; + + public SplitIterator(Iterator iterator, IteratorOptions options) { + this.iterator = iterator; + this.options = options; + } + + private void nativeCreateInstance() { + for (int i = 0; i < cb.numCols(); i++) { + ArrowWritableColumnVector vector = (ArrowWritableColumnVector)(cb.column(i)); + vector.getValueVector().setValueCount(cb.numRows()); + } + ArrowRecordBatch recordBatch = ConverterUtils.createArrowRecordBatch(cb); + try { + if (jniWrapper == null) { + jniWrapper = new ShuffleSplitterJniWrapper(); + } + if (nativeSplitter != 0) { + jniWrapper.clear(nativeSplitter); + nativeSplitter = 0; + // throw new Exception("NativeSplitter is not clear."); + } + nativeSplitter = jniWrapper.make( + options.getNativePartitioning(), + options.getOffheapPerTask(), + options.getBufferSize()); + long[] bufAddrs = new long[recordBatch.getBuffers().size()]; + long[] bufSizes = new long[recordBatch.getBuffersLayout().size()]; + int i = 0, j = 0; + for (ArrowBuf buffer: recordBatch.getBuffers()) { + bufAddrs[i++] = buffer.memoryAddress(); + } + for (ArrowBuffer buffer: recordBatch.getBuffersLayout()) { + bufSizes[j++] = buffer.getSize(); + } + if (i != j || i < 1) { + logger.warn("bufAddrs and BuffersLayout have different lengths, and buffer sizes is " + i + " -- " + j); + } + jniWrapper.split(nativeSplitter, cb.numRows(), bufAddrs, bufSizes, false); + jniWrapper.collect(nativeSplitter, cb.numRows()); + } catch (Exception e) { + if (nativeSplitter != 0) { + try { + jniWrapper.clear(nativeSplitter); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + nativeSplitter = 0; + } + throw new RuntimeException(e); + } finally { + ConverterUtils.releaseArrowRecordBatch(recordBatch); + // cb.close(); + } + + } + + private native boolean nativeHasNext(long instance); + + public boolean hasRecordBatch(){ + while (iterator.hasNext()) { + cb = iterator.next(); + if (cb.numRows() != 0 && cb.numCols() != 0) { + nativeCreateInstance(); + return true; + } + } + if (nativeSplitter != 0) { + try { + jniWrapper.clear(nativeSplitter); + nativeSplitter = 0; + } catch (IOException e) { + throw new RuntimeException(e); + } finally { +// jniWrapper.close(nativeSplitter); + } + } + return false; + } + + @Override + public boolean hasNext() { + // 1. Init the native splitter + if (nativeSplitter == 0) { + return hasRecordBatch() && nativeHasNext(nativeSplitter); + } + // 2. Call native hasNext + if (nativeHasNext(nativeSplitter)) { + return true; + } else { + return hasRecordBatch() && nativeHasNext(nativeSplitter); + } + } + + private native byte[] nativeNext(long instance); + + @Override + public ColumnarBatch next() { + byte[] serializedRecordBatch = nativeNext(nativeSplitter); + return ConverterUtils.createRecordBatch(serializedRecordBatch, + options.getNativePartitioning().getSchema()); + } + + private native int nativeNextPartitionId(long nativeSplitter); + + public int nextPartitionId() { + return nativeNextPartitionId(nativeSplitter); + } + + @Override + protected void finalize() throws Throwable { + try { + if (nativeSplitter != 0) { + logger.error("NativeSplitter is not clear."); + jniWrapper.clear(nativeSplitter); + nativeSplitter = 0; + } + } catch (IOException e) { + throw new RuntimeException(e); + } finally { + jniWrapper.close(nativeSplitter); + } + } + +} diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/GazellePluginConfig.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/GazellePluginConfig.scala index 408c70e0e..67c4a4b1c 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/GazellePluginConfig.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/GazellePluginConfig.scala @@ -83,6 +83,11 @@ class GazellePluginConfig(conf: SQLConf) extends Logging { val enableColumnarShuffledHashJoin: Boolean = conf.getConfString("spark.oap.sql.columnar.shuffledhashjoin", "true").toBoolean && enableCpu + // enable or disable fallback shuffle manager + val enableFallbackShuffle: Boolean = conf + .getConfString("spark.oap.sql.columnar.enableFallbackShuffle", "false") + .equals("true") && enableCpu + val enableArrowColumnarToRow: Boolean = conf.getConfString("spark.oap.sql.columnar.columnartorow", "true").toBoolean && enableCpu diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ConverterUtils.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ConverterUtils.scala index 273889bb4..3e287fd9f 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ConverterUtils.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ConverterUtils.scala @@ -21,7 +21,6 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException} import java.nio.channels.Channels import java.nio.ByteBuffer import java.util.ArrayList - import com.intel.oap.vectorized.ArrowWritableColumnVector import io.netty.buffer.{ByteBufAllocator, ByteBufOutputStream} import org.apache.arrow.memory.ArrowBuf @@ -50,13 +49,14 @@ import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} import io.netty.buffer.{ByteBuf, ByteBufAllocator, ByteBufOutputStream} -import java.nio.channels.{Channels, WritableByteChannel} +import java.nio.channels.{Channels, WritableByteChannel} import com.google.common.collect.Lists +import org.apache.arrow.dataset.jni.UnsafeRecordBatchSerializer + import java.io.{InputStream, OutputStream} import java.util import java.util.concurrent.TimeUnit.SECONDS - import org.apache.arrow.vector.complex.MapVector import org.apache.arrow.vector.types.TimeUnit import org.apache.arrow.vector.types.pojo.ArrowType @@ -64,14 +64,27 @@ import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision} import org.apache.spark.sql.catalyst.util.{DateTimeConstants, DateTimeUtils} import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND -import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils -import org.apache.spark.sql.execution.datasources.v2.arrow.SparkVectorUtils +import org.apache.spark.sql.execution.datasources.v2.arrow.{SparkMemoryUtils, SparkSchemaUtils, SparkVectorUtils} object ConverterUtils extends Logging { def calcuateEstimatedSize(columnarBatch: ColumnarBatch): Long = { SparkVectorUtils.estimateSize(columnarBatch) } + def createRecordBatch(serializedRecordBatch: Array[Byte], serializedSchema: Array[Byte]): ColumnarBatch = { + val schema = ConverterUtils.getSchemaFromBytesBuf(serializedSchema); + val allocator = SparkMemoryUtils.contextAllocatorForBufferImport + val resultBatch = UnsafeRecordBatchSerializer.deserializeUnsafe(allocator, serializedRecordBatch) + if (resultBatch == null) { + throw new Exception("Error from SerializedRecordBatch to ColumnarBatch.") + } else { + val resultColumnVectorList = fromArrowRecordBatch(schema, resultBatch) + val length = resultBatch.getLength + ConverterUtils.releaseArrowRecordBatch(resultBatch) + new ColumnarBatch(resultColumnVectorList.map(v => v.asInstanceOf[ColumnVector]), length) + } + } + def createArrowRecordBatch(columnarBatch: ColumnarBatch): ArrowRecordBatch = { SparkVectorUtils.toArrowRecordBatch(columnarBatch) } @@ -369,6 +382,19 @@ object ConverterUtils extends Logging { } } + def getShortAttributeName(attr: Attribute): String = { + val index = attr.name.indexOf("(") + if (index != -1) { + attr.name.substring(0, index) + } else { + attr.name + } + } + + def genColumnNameWithExprId(attr: Attribute): String = { + getShortAttributeName(attr) + "#" + attr.exprId.id + } + def getResultAttrFromExpr( fieldExpr: Expression, name: String = "None", diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala index 461b55503..4f454a5b0 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala @@ -182,7 +182,7 @@ case class ColumnarPreOverrides(session: SparkSession) extends Rule[SparkPlan] { case plan: ShuffleExchangeExec => val child = replaceWithColumnarPlan(plan.child) logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - if ((child.supportsColumnar || columnarConf.enablePreferColumnar) && columnarConf.enableColumnarShuffle) { + if ((child.supportsColumnar || columnarConf.enablePreferColumnar) && (columnarConf.enableColumnarShuffle || columnarConf.enableFallbackShuffle)) { if (isSupportAdaptive) { new ColumnarShuffleExchangeAdaptor( plan.outputPartitioning, @@ -289,7 +289,7 @@ case class ColumnarPreOverrides(session: SparkSession) extends Rule[SparkPlan] { case plan if (SparkShimLoader.getSparkShims.isCustomShuffleReaderExec(plan) - && columnarConf.enableColumnarShuffle) => + && (columnarConf.enableColumnarShuffle || columnarConf.enableFallbackShuffle)) => val child = SparkShimLoader.getSparkShims.getChildOfCustomShuffleReaderExec(plan) val partitionSpecs = SparkShimLoader.getSparkShims.getPartitionSpecsOfCustomShuffleReaderExec(plan) diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/extension/columnar/ColumnarGuardRule.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/extension/columnar/ColumnarGuardRule.scala index 16cb31216..e933a96b4 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/extension/columnar/ColumnarGuardRule.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/extension/columnar/ColumnarGuardRule.scala @@ -56,6 +56,7 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { val preferColumnar = columnarConf.enablePreferColumnar val optimizeLevel = columnarConf.joinOptimizationThrottle val enableColumnarShuffle = columnarConf.enableColumnarShuffle + val enableFallbackShuffle = columnarConf.enableFallbackShuffle val enableColumnarSort = columnarConf.enableColumnarSort val enableColumnarWindow = columnarConf.enableColumnarWindow val enableColumnarSortMergeJoin = columnarConf.enableColumnarSortMergeJoin @@ -133,7 +134,7 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { if (!enableColumnarSort) return false new ColumnarSortExec(plan.sortOrder, plan.global, plan.child, plan.testSpillFrequency) case plan: ShuffleExchangeExec => - if (!enableColumnarShuffle) return false + if (!enableColumnarShuffle && !enableFallbackShuffle) return false new ColumnarShuffleExchangeExec( plan.outputPartitioning, plan.child) diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/vectorized/ArrowColumnarBatchSerializer.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/vectorized/ArrowColumnarBatchSerializer.scala index a5e4d814b..b66ea08ce 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/vectorized/ArrowColumnarBatchSerializer.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/vectorized/ArrowColumnarBatchSerializer.scala @@ -19,22 +19,19 @@ package com.intel.oap.vectorized import java.io._ import java.nio.ByteBuffer - import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ListBuffer import scala.reflect.ClassTag - import com.intel.oap.GazellePluginConfig import com.intel.oap.expression.ConverterUtils import org.apache.arrow.dataset.jni.UnsafeRecordBatchSerializer import org.apache.arrow.memory.ArrowBuf import org.apache.arrow.memory.BufferAllocator -import org.apache.arrow.vector.ipc.ArrowStreamReader -import org.apache.arrow.vector.VectorLoader -import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.{ArrowStreamReader, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} +import org.apache.arrow.vector.{FieldVector, VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.types.pojo.Schema - import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.serializer.DeserializationStream @@ -43,6 +40,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.serializer.SerializerInstance import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils import org.apache.spark.sql.execution.datasources.v2.arrow.SparkVectorUtils +import org.apache.spark.sql.execution.datasources.v2.arrow.SparkVectorUtils.toArrowRecordBatch import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -50,6 +48,8 @@ import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.ColumnVector import org.apache.spark.sql.vectorized.ColumnarBatch +import java.nio.channels.Channels + class ArrowColumnarBatchSerializer( schema: StructType, readBatchNumRows: SQLMetric, numOutputRows: SQLMetric) extends Serializer with Serializable { @@ -252,9 +252,48 @@ private class ArrowColumnarBatchSerializerInstance( } } - // Columnar shuffle write process don't need this. - override def serializeStream(s: OutputStream): SerializationStream = - throw new UnsupportedOperationException + override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { + + override def writeKey[T: ClassTag](key: T): SerializationStream = { + // The key is only needed on the map side when computing partition ids. + // It does not need to be shuffled. + assert(null == key || key.isInstanceOf[Int]) + this + } + + override def writeValue[T: ClassTag](value: T): SerializationStream = { + val cb = value.asInstanceOf[ColumnarBatch] + val recordBatch = ConverterUtils.createArrowRecordBatch(cb) + try { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), recordBatch) + } catch { + case e: Exception => + logError("Failed to serialize current RecordBatch", e) + } finally { + ConverterUtils.releaseArrowRecordBatch(recordBatch) + // cb.close + } + this + } + + override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = { + // This method is never called by shuffle code + throw new UnsupportedOperationException + } + + override def writeObject[T: ClassTag](t: T): SerializationStream = { + // This method is never called by shuffle code + throw new UnsupportedOperationException + } + + override def flush(): Unit = { + out.flush() + } + + override def close(): Unit = { + out.close() + } + } // These methods are never called by shuffle code. override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/vectorized/CloseablePartitionedBatchIterator.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/vectorized/CloseablePartitionedBatchIterator.scala new file mode 100644 index 000000000..fc87a4b47 --- /dev/null +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/vectorized/CloseablePartitionedBatchIterator.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.oap.vectorized + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * An Iterator that insures that the batches [[ColumnarBatch]]s it iterates over are all closed + * properly. + */ +class CloseablePartitionedBatchIterator(itr: Iterator[Product2[Int, ColumnarBatch]]) + extends Iterator[Product2[Int, ColumnarBatch]] + with Logging { + var cb: ColumnarBatch = null + + private def closeCurrentBatch(): Unit = { + if (cb != null) { + cb.close() + cb = null + } + } + + + TaskContext.get().addTaskCompletionListener[Unit] { _ => + closeCurrentBatch() + } + + override def hasNext: Boolean = { + itr.hasNext + } + + override def next(): Product2[Int, ColumnarBatch] = { + closeCurrentBatch() + val value = itr.next() + cb = value._2 + value + } + +} diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index 54dcfebf3..c43adebcd 100644 --- a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++ b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution import com.google.common.collect.Lists import com.intel.oap.expression.{CodeGeneration, ColumnarExpression, ColumnarExpressionConverter, ConverterUtils} import com.intel.oap.GazellePluginConfig -import com.intel.oap.vectorized.{ArrowColumnarBatchSerializer, ArrowWritableColumnVector, NativePartitioning} +import com.intel.oap.vectorized.SplitIterator.IteratorOptions +import com.intel.oap.vectorized.{ArrowColumnarBatchSerializer, ArrowWritableColumnVector, CloseablePartitionedBatchIterator, NativePartitioning, ShuffleSplitterJniWrapper, SplitIterator, SplitResult} import org.apache.arrow.gandiva.expression.TreeBuilder import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} @@ -34,10 +35,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.plans.logical.Statistics -import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.CoalesceExec.EmptyPartition import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.createShuffleWriteProcessor import org.apache.spark.sql.execution.exchange._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} @@ -45,11 +44,10 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.{MutablePair, Utils} + import scala.collection.JavaConverters._ import scala.concurrent.Future -import org.apache.spark.sql.util.ArrowUtils - case class ColumnarShuffleExchangeExec( override val outputPartitioning: Partitioning, child: SparkPlan, @@ -339,6 +337,16 @@ object ColumnarShuffleExchangeExec extends Logging { } } + private val conf = SparkEnv.get.conf + private val offHeadSize = conf.getSizeAsBytes("spark.memory.offHeap.size", 0) + private val executorNum = conf.getInt("spark.executor.cores", 1) + private val offheapPerTask = offHeadSize / executorNum + private val nativeBufferSize = GazellePluginConfig.getSessionConf.shuffleSplitDefaultSize + + private val jniWrapper = new ShuffleSplitterJniWrapper() + private var splitResult: SplitResult = _ + private var firstRecordBatch: Boolean = true + def prepareShuffleDependency( rdd: RDD[ColumnarBatch], outputAttributes: Seq[Attribute], @@ -455,39 +463,132 @@ object ColumnarShuffleExchangeExec extends Logging { // Thus in Columnar Shuffle we never use the "key" part. val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition - val rddWithDummyKey: RDD[Product2[Int, ColumnarBatch]] = newPartitioning match { - case RangePartitioning(sortingExpressions, _) => - rdd.mapPartitionsWithIndexInternal((_, cbIter) => { - val partitionKeyExtractor: InternalRow => Any = { - val projection = - UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) - row => projection(row) - } - val newIter = computeAndAddPartitionId(cbIter, partitionKeyExtractor) + val rddWithPartitionKey: RDD[Product2[Int, ColumnarBatch]] = + if (GazellePluginConfig.getSessionConf.enableColumnarShuffle) { + newPartitioning match { + case RangePartitioning(sortingExpressions, _) => + rdd.mapPartitionsWithIndexInternal((_, cbIter) => { + val partitionKeyExtractor: InternalRow => Any = { + val projection = + UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) + row => projection(row) + } + val newIter = computeAndAddPartitionId(cbIter, partitionKeyExtractor) - SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit] { _ => - newIter.closeAppendedVector() - } + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit] { _ => + newIter.closeAppendedVector() + } - newIter - }, isOrderSensitive = isOrderSensitive) - case _ => - rdd.mapPartitionsWithIndexInternal( - (_, cbIter) => - cbIter.map { cb => - (0 until cb.numCols).foreach( - cb.column(_) - .asInstanceOf[ArrowWritableColumnVector] - .getValueVector - .setValueCount(cb.numRows)) - (0, cb) + newIter + }, isOrderSensitive = isOrderSensitive) + case _ => + rdd.mapPartitionsWithIndexInternal( + (_, cbIter) => + cbIter.map { cb => + (0 until cb.numCols).foreach( + cb.column(_) + .asInstanceOf[ArrowWritableColumnVector] + .getValueVector + .setValueCount(cb.numRows)) + (0, cb) + }, + isOrderSensitive = isOrderSensitive) + } + } else { + val options = new IteratorOptions + options.setExpr("") + options.setOffheapPerTask(offheapPerTask) + options.setBufferSize(nativeBufferSize) + options.setNativePartitioning(nativePartitioning) + newPartitioning match { + case HashPartitioning(exprs, n) => + rdd.mapPartitionsWithIndexInternal( + (_, cbIter) => { + options.setPartitionNum(n) + val fields = exprs.zipWithIndex.map { + case (expr, i) => + val attribute = ConverterUtils.getAttrFromExpr(expr) + ConverterUtils.genColumnNameWithExprId(attribute) + } + options.setExpr(fields.mkString(",")) + options.setName("hash") + // ColumnarBatch Iterator + val iter = new Iterator[Product2[Int, ColumnarBatch]] { + val splitIterator = new SplitIterator(cbIter.asJava, options) + + override def hasNext: Boolean = splitIterator.hasNext + + override def next(): Product2[Int, ColumnarBatch] = + (splitIterator.nextPartitionId(), splitIterator.next()); + } + new CloseablePartitionedBatchIterator(iter) }, - isOrderSensitive = isOrderSensitive) - } + isOrderSensitive = isOrderSensitive + ) + case RoundRobinPartitioning(n) => + rdd.mapPartitionsWithIndexInternal( + (_, cbIter) => { + options.setPartitionNum(n) + options.setName("rr") + // ColumnarBatch Iterator + val iter = new Iterator[Product2[Int, ColumnarBatch]] { + val splitIterator = new SplitIterator(cbIter.asJava, options) + + override def hasNext: Boolean = splitIterator.hasNext + + override def next(): Product2[Int, ColumnarBatch] = + (splitIterator.nextPartitionId(), splitIterator.next()); + } + new CloseablePartitionedBatchIterator(iter) + }, + isOrderSensitive = isOrderSensitive + ) + case SinglePartition => + rdd.mapPartitionsWithIndexInternal( + (_, cbIter) => + cbIter.map { cb => + (0 until cb.numCols).foreach( + cb.column(_) + .asInstanceOf[ArrowWritableColumnVector] + .getValueVector + .setValueCount(cb.numRows)) + (0, cb) + }, + isOrderSensitive = isOrderSensitive + ) + case _ => + logError("Unsupported operations: " + nativePartitioning.getShortName) + rdd.mapPartitionsWithIndexInternal( + (_, cbIter) => + cbIter.map { cb => + (0 until cb.numCols).foreach( + cb.column(_) + .asInstanceOf[ArrowWritableColumnVector] + .getValueVector + .setValueCount(cb.numRows)) + (0, cb) + }, + isOrderSensitive = isOrderSensitive + ) +// rdd.mapPartitionsWithIndexInternal( +// (_, cbIter) => { +// val iter = new Iterator[Product2[Int, ColumnarBatch]] { +// val splitIterator = new SplitIterator(cbIter.asJava, options) +// +// override def hasNext: Boolean = splitIterator.hasNext +// +// override def next(): Product2[Int, ColumnarBatch] = +// (splitIterator.nextPartitionId(), splitIterator.next()); +// } +// new CloseablePartitionedBatchIterator(iter) +// }, +// isOrderSensitive = isOrderSensitive +// ) + }} val dependency = new ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( - rddWithDummyKey, + rddWithPartitionKey, new PartitionIdPassthrough(newPartitioning.numPartitions), serializer, shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics), diff --git a/native-sql-engine/cpp/src/jni/jni_wrapper.cc b/native-sql-engine/cpp/src/jni/jni_wrapper.cc index 52c4b61dc..4eb50f5c9 100644 --- a/native-sql-engine/cpp/src/jni/jni_wrapper.cc +++ b/native-sql-engine/cpp/src/jni/jni_wrapper.cc @@ -1028,6 +1028,122 @@ Java_com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper_nativeEvaluate2( JNI_METHOD_END(nullptr) } +JNIEXPORT jboolean JNICALL +Java_com_intel_oap_vectorized_SplitIterator_nativeHasNext( + JNIEnv* env, jobject, jlong splitter_id) { + JNI_METHOD_START + auto splitter_ = shuffle_splitter_holder_.Lookup(splitter_id); + if (!splitter_) { + std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); + JniThrow(error_message); + } + return splitter_->hasNext(); + JNI_METHOD_END(false) +} + +JNIEXPORT jobject JNICALL +Java_com_intel_oap_vectorized_SplitIterator_nativeNext( + JNIEnv* env, jobject, jlong splitter_id) { + JNI_METHOD_START + auto splitter_ = shuffle_splitter_holder_.Lookup(splitter_id); + if (!splitter_) { + std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); + JniThrow(error_message); + } + jbyteArray serialized_record_batch = JniGetOrThrow( + ToBytes(env, splitter_->nextBatch()), "Error deserializing message"); + return serialized_record_batch; + + JNI_METHOD_END(nullptr) +} + +JNIEXPORT jlong JNICALL +Java_com_intel_oap_vectorized_SplitIterator_nativeNextPartitionId( + JNIEnv* env, jobject, jlong splitter_id) { + JNI_METHOD_START + auto splitter_ = shuffle_splitter_holder_.Lookup(splitter_id); + if (!splitter_) { + std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); + JniThrow(error_message); + } + return splitter_->nextPartitionId(); + JNI_METHOD_END(-1L) +} + +JNIEXPORT jlong JNICALL +Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_initSplit( + JNIEnv* env, jobject, jstring partitioning_name_jstr, jint num_partitions, + jbyteArray schema_arr, jbyteArray expr_arr, jlong offheap_per_task, jint buffer_size) { + JNI_METHOD_START + if (partitioning_name_jstr == NULL) { + JniThrow(std::string("Short partitioning name can't be null")); + return 0; + } + if (schema_arr == NULL) { + JniThrow(std::string("Make splitter schema can't be null")); + } + + auto partitioning_name_c = env->GetStringUTFChars(partitioning_name_jstr, JNI_FALSE); + auto partitioning_name = std::string(partitioning_name_c); + env->ReleaseStringUTFChars(partitioning_name_jstr, partitioning_name_c); + + auto splitOptions = SplitOptions::Defaults(); + splitOptions.prefer_spill = false; + splitOptions.buffered_write = true; + if (buffer_size > 0) { + splitOptions.buffer_size = buffer_size; + } + splitOptions.offheap_per_task = offheap_per_task; + +// auto* pool = reinterpret_cast(memory_pool_id); +// if (pool == nullptr) { +// JniThrow("Memory pool does not exist or has been closed"); +// } +// splitOptions.memory_pool = pool; + + std::shared_ptr schema; + // ValueOrDie in MakeSchema + MakeSchema(env, schema_arr, &schema); + + gandiva::ExpressionVector expr_vector = {}; + if (expr_arr != NULL) { + gandiva::FieldVector ret_types; + JniAssertOkOrThrow(MakeExprVector(env, expr_arr, &expr_vector, &ret_types), + "Failed to parse expressions protobuf"); + } + + jclass cls = env->FindClass("java/lang/Thread"); + jmethodID mid = env->GetStaticMethodID(cls, "currentThread", "()Ljava/lang/Thread;"); + jobject thread = env->CallStaticObjectMethod(cls, mid); + if (thread == NULL) { + std::cout << "Thread.currentThread() return NULL" << std::endl; + } else { + jmethodID mid_getid = env->GetMethodID(cls, "getId", "()J"); + jlong sid = env->CallLongMethod(thread, mid_getid); + splitOptions.thread_id = (int64_t)sid; + } + + jclass tc_cls = env->FindClass("org/apache/spark/TaskContext"); + jmethodID get_tc_mid = + env->GetStaticMethodID(tc_cls, "get", "()Lorg/apache/spark/TaskContext;"); + jobject tc_obj = env->CallStaticObjectMethod(tc_cls, get_tc_mid); + if (tc_obj == NULL) { + std::cout << "TaskContext.get() return NULL" << std::endl; + } else { + jmethodID get_tsk_attmpt_mid = env->GetMethodID(tc_cls, "taskAttemptId", "()J"); + jlong attmpt_id = env->CallLongMethod(tc_obj, get_tsk_attmpt_mid); + splitOptions.task_attempt_id = (int64_t)attmpt_id; + } + + auto splitter = + JniGetOrThrow(Splitter::Make(partitioning_name, std::move(schema), num_partitions, + expr_vector, std::move(splitOptions)), + "Failed create native shuffle splitter"); + return shuffle_splitter_holder_.Insert(std::shared_ptr(splitter)); + + JNI_METHOD_END(-1L) +} + JNIEXPORT jlong JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_nativeMake( JNIEnv* env, jobject, jstring partitioning_name_jstr, jint num_partitions, @@ -1194,6 +1310,48 @@ JNIEXPORT jlong JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_ JNI_METHOD_END(-1L) } +JNIEXPORT void JNICALL +Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_collect( + JNIEnv* env, jobject obj, jlong splitter_id, jint num_rows) { + JNI_METHOD_START + auto splitter_ = shuffle_splitter_holder_.Lookup(splitter_id); + if (!splitter_) { + std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); + JniThrow(error_message); + } + JniAssertOkOrThrow(splitter_->Collect(), "Native split: splitter collect failed"); + JNI_METHOD_END() +} + +JNIEXPORT jobject JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_clear( + JNIEnv* env, jobject, jlong splitter_id) { + JNI_METHOD_START + auto splitter = shuffle_splitter_holder_.Lookup(splitter_id); + if (!splitter) { + std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); + JniThrow(error_message); + } + + JniAssertOkOrThrow(splitter->Clear(), "Native split:: splitter close failed"); + const auto& partition_lengths = splitter->PartitionLengths(); + auto partition_length_arr = env->NewLongArray(partition_lengths.size()); + auto src = reinterpret_cast(partition_lengths.data()); + env->SetLongArrayRegion(partition_length_arr, 0, partition_lengths.size(), src); + + const auto& raw_partition_lengths = splitter->RawPartitionLengths(); + auto raw_partition_length_arr = env->NewLongArray(raw_partition_lengths.size()); + auto raw_src = reinterpret_cast(raw_partition_lengths.data()); + env->SetLongArrayRegion(raw_partition_length_arr, 0, raw_partition_lengths.size(), raw_src); + + jobject split_result = env->NewObject(split_result_class, split_result_constructor, splitter->TotalComputePidTime(), + splitter->TotalWriteTime(), splitter->TotalSpillTime(), + splitter->TotalCompressTime(), splitter->TotalBytesWritten(), + splitter->TotalBytesSpilled(), partition_length_arr, raw_partition_length_arr + ); + return split_result; + JNI_METHOD_END(nullptr) +} + JNIEXPORT jobject JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_stop( JNIEnv* env, jobject, jlong splitter_id) { JNI_METHOD_START diff --git a/native-sql-engine/cpp/src/shuffle/splitter.cc b/native-sql-engine/cpp/src/shuffle/splitter.cc index 4fe7dd7b6..2dbb25fb1 100644 --- a/native-sql-engine/cpp/src/shuffle/splitter.cc +++ b/native-sql-engine/cpp/src/shuffle/splitter.cc @@ -33,6 +33,7 @@ #include #include #include +#include #include "shuffle/utils.h" #include "utils/macros.h" @@ -223,8 +224,6 @@ arrow::Status Splitter::Init() { const auto& fields = schema_->fields(); ARROW_ASSIGN_OR_RAISE(column_type_id_, ToSplitterTypeId(schema_->fields())); - partition_writer_.resize(num_partitions_); - // pre-computed row count for each partition after the record batch split partition_id_cnt_.resize(num_partitions_); // pre-allocated buffer size for each partition, unit is row count @@ -235,8 +234,6 @@ arrow::Status Splitter::Init() { // the offset of each partition during record batch split partition_buffer_idx_offset_.resize(num_partitions_); - partition_cached_recordbatch_.resize(num_partitions_); - partition_cached_recordbatch_size_.resize(num_partitions_); partition_lengths_.resize(num_partitions_); raw_partition_lengths_.resize(num_partitions_); reducer_offset_offset_.resize(num_partitions_ + 1); @@ -294,38 +291,42 @@ arrow::Status Splitter::Init() { partition_list_builders_[i].resize(num_partitions_); } - ARROW_ASSIGN_OR_RAISE(configured_dirs_, GetConfiguredLocalDirs()); - sub_dir_selection_.assign(configured_dirs_.size(), 0); - - // Both data_file and shuffle_index_file should be set through jni. - // For test purpose, Create a temporary subdirectory in the system temporary - // dir with prefix "columnar-shuffle" - if (options_.data_file.length() == 0) { - ARROW_ASSIGN_OR_RAISE(options_.data_file, CreateTempShuffleFile(configured_dirs_[0])); - } + if (!options_.data_file.empty()) { + partition_cached_recordbatch_.resize(num_partitions_); + partition_cached_recordbatch_size_.resize(num_partitions_); + partition_writer_.resize(num_partitions_); - auto& ipc_write_options = options_.ipc_write_options; - ipc_write_options.memory_pool = options_.memory_pool; - ipc_write_options.use_threads = false; + ARROW_ASSIGN_OR_RAISE(configured_dirs_, GetConfiguredLocalDirs()); + sub_dir_selection_.assign(configured_dirs_.size(), 0); - if (options_.compression_type == arrow::Compression::FASTPFOR) { - ARROW_ASSIGN_OR_RAISE(ipc_write_options.codec, - arrow::util::Codec::CreateInt32(arrow::Compression::FASTPFOR)); + // Both data_file and shuffle_index_file should be set through jni. + // For test purpose, Create a temporary subdirectory in the system temporary + // dir with prefix "columnar-shuffle" + if (options_.data_file.length() == 0) { + ARROW_ASSIGN_OR_RAISE(options_.data_file, CreateTempShuffleFile(configured_dirs_[0])); + } + auto& ipc_write_options = options_.ipc_write_options; + ipc_write_options.memory_pool = options_.memory_pool; + ipc_write_options.use_threads = false; + if (options_.compression_type == arrow::Compression::FASTPFOR) { + ARROW_ASSIGN_OR_RAISE(ipc_write_options.codec, + arrow::util::Codec::CreateInt32(arrow::Compression::FASTPFOR)); + + } else if (options_.compression_type == arrow::Compression::LZ4_FRAME) { + ARROW_ASSIGN_OR_RAISE(ipc_write_options.codec, + arrow::util::Codec::Create(arrow::Compression::LZ4_FRAME)); + } else { + ARROW_ASSIGN_OR_RAISE(ipc_write_options.codec, arrow::util::Codec::CreateInt32( + arrow::Compression::UNCOMPRESSED)); + } - } else if (options_.compression_type == arrow::Compression::LZ4_FRAME) { - ARROW_ASSIGN_OR_RAISE(ipc_write_options.codec, - arrow::util::Codec::Create(arrow::Compression::LZ4_FRAME)); - } else { - ARROW_ASSIGN_OR_RAISE(ipc_write_options.codec, arrow::util::Codec::CreateInt32( - arrow::Compression::UNCOMPRESSED)); + // initialize tiny batch write options + tiny_bach_write_options_ = ipc_write_options; + ARROW_ASSIGN_OR_RAISE( + tiny_bach_write_options_.codec, + arrow::util::Codec::CreateInt32(arrow::Compression::UNCOMPRESSED)); } - // initialize tiny batch write options - tiny_bach_write_options_ = ipc_write_options; - ARROW_ASSIGN_OR_RAISE( - tiny_bach_write_options_.codec, - arrow::util::Codec::CreateInt32(arrow::Compression::UNCOMPRESSED)); - // Allocate first buffer for split reducer ARROW_ASSIGN_OR_RAISE(combine_buffer_, arrow::AllocateResizableBuffer(0, options_.memory_pool)); @@ -392,6 +393,63 @@ arrow::Status Splitter::Split(const arrow::RecordBatch& rb) { return arrow::Status::OK(); } +bool Splitter::hasNext() { + if (!output_rb_.empty()){ + next_partition_id = output_rb_.top().first; + next_batch = output_rb_.top().second; + } + return !output_rb_.empty(); +} + +std::shared_ptr Splitter::nextBatch() { + if (!output_rb_.empty()) { + output_rb_.pop(); + } +// #ifndef DEBUG +// std::cout << "Output partitionid is: " << next_partition_id << +// ", output_batch_rows: " << next_batch->num_rows() << std::endl; +// #endif + return next_batch; +} + +int32_t Splitter::nextPartitionId() { + return next_partition_id; +} + +/** +* Collect the rb after splitting. +*/ +arrow::Status Splitter::Collect() { + EVAL_START("close", options_.thread_id) + // collect buffers and collect metrics + for (auto pid = 0; pid < num_partitions_; ++pid) { + if (partition_buffer_idx_base_[pid] > 0) { + RETURN_NOT_OK(CacheRecordBatch(pid, true)); + } + } + EVAL_END("close", options_.thread_id, options_.task_attempt_id) + return arrow::Status::OK(); +} + + +arrow::Status Splitter::Clear() { + EVAL_START("close", options_.thread_id) + next_batch = nullptr; + for (auto pid = 0; pid < num_partitions_; ++pid) { + partition_lengths_[pid] = 0; + raw_partition_lengths_[pid] = 0; + } + if (output_rb_.size() > 0) { + std::cerr << "Dirty stack output_rb_" << std::endl; + output_rb_ = std::stack>>(); + } + this -> combine_buffer_.reset(); + this -> schema_payload_.reset(); + partition_buffers_.clear(); + EVAL_END("close", options_.thread_id, options_.task_attempt_id) + return arrow::Status::OK(); +} + arrow::Status Splitter::Stop() { EVAL_START("write", options_.thread_id) // open data file output stream @@ -582,26 +640,31 @@ arrow::Status Splitter::CacheRecordBatch(int32_t partition_id, bool reset_buffer int64_t raw_size = batch_nbytes(batch); raw_partition_lengths_[partition_id] += raw_size; - auto payload = std::make_shared(); + + if (!options_.data_file.empty()) { + auto payload = std::make_shared(); #ifndef SKIPCOMPRESS - if (num_rows <= options_.batch_compress_threshold) { - TIME_NANO_OR_RAISE(total_compress_time_, - arrow::ipc::GetRecordBatchPayload( - *batch, tiny_bach_write_options_, payload.get())); - } else { - TIME_NANO_OR_RAISE(total_compress_time_, - arrow::ipc::GetRecordBatchPayload( - *batch, options_.ipc_write_options, payload.get())); - } + if (num_rows <= options_.batch_compress_threshold) { + TIME_NANO_OR_RAISE(total_compress_time_, + arrow::ipc::GetRecordBatchPayload( + *batch, tiny_bach_write_options_, payload.get())); + } else { + TIME_NANO_OR_RAISE(total_compress_time_, + arrow::ipc::GetRecordBatchPayload( + *batch, options_.ipc_write_options, payload.get())); + } #else - // for test reason - TIME_NANO_OR_RAISE(total_compress_time_, - arrow::ipc::GetRecordBatchPayload(*batch, tiny_bach_write_options_, - payload.get())); + // for test reason + TIME_NANO_OR_RAISE(total_compress_time_, + arrow::ipc::GetRecordBatchPayload(*batch, tiny_bach_write_options_, + payload.get())); #endif - - partition_cached_recordbatch_size_[partition_id] += payload->body_length; - partition_cached_recordbatch_[partition_id].push_back(std::move(payload)); + partition_cached_recordbatch_size_[partition_id] += payload->body_length; + partition_cached_recordbatch_[partition_id].push_back(std::move(payload)); + } + std::pair> part_batch = std::make_pair(partition_id, batch); + output_rb_.emplace(part_batch); + // partition_cached_arb_[partition_id].push_back(batch); partition_buffer_idx_base_[partition_id] = 0; } return arrow::Status::OK(); diff --git a/native-sql-engine/cpp/src/shuffle/splitter.h b/native-sql-engine/cpp/src/shuffle/splitter.h index f27c061bf..9c4dfaa86 100644 --- a/native-sql-engine/cpp/src/shuffle/splitter.h +++ b/native-sql-engine/cpp/src/shuffle/splitter.h @@ -28,6 +28,7 @@ #include #include #include +#include #include "shuffle/type.h" #include "shuffle/utils.h" @@ -72,11 +73,38 @@ class Splitter { */ virtual arrow::Status Split(const arrow::RecordBatch&); + /** + * Iterator for splitting rb + */ + virtual bool hasNext(); + + /** + * Iterator for splitting rb + */ + virtual std::shared_ptr nextBatch(); + + /** + * Iterator for splitting rb + */ + virtual int32_t nextPartitionId(); + /** * Compute the compresse size of record batch. */ virtual int64_t CompressedSize(const arrow::RecordBatch&); + /** + * Collect the rb. + */ + arrow::Status Collect(); + + + /** + * Clear the buffer. And stop processing splitting + */ + arrow::Status Clear(); + + /** * For each partition, merge spilled file into shuffle data file and write any * cached record batch to shuffle data file. Close all resources and collect @@ -231,6 +259,12 @@ class Splitter { // page std::shared_ptr combine_buffer_; + + int32_t next_partition_id = -1; + std::shared_ptr next_batch = nullptr; + + std::stack>> output_rb_; + // partid std::vector>> partition_cached_recordbatch_; diff --git a/pom.xml b/pom.xml index 81e68f0f8..1b4f5d6ab 100644 --- a/pom.xml +++ b/pom.xml @@ -75,6 +75,12 @@ 3.2.0 + + hadoop-3.3 + + 3.3.1 + + dataproc-2.0