
init RSS framework (#551)
Co-authored-by: zhangli20 <[email protected]>
richox and zhangli20 authored Aug 21, 2024
1 parent 565c025 commit a69338b
Showing 7 changed files with 301 additions and 7 deletions.
2 changes: 1 addition & 1 deletion native-engine/blaze-jni-bridge/src/jni_bridge.rs
@@ -1251,7 +1251,7 @@ impl<'a> BlazeRssPartitionWriterBase<'_> {
Ok(BlazeRssPartitionWriterBase {
class,
method_write: env
-.get_method_id(class, "write", "(ILjava/nio/ByteBuffer;I)V")
+.get_method_id(class, "write", "(ILjava/nio/ByteBuffer;)V")
.unwrap(),
method_write_ret: ReturnType::Primitive(Primitive::Void),
method_flush: env.get_method_id(class, "flush", "()V").unwrap(),
2 changes: 1 addition & 1 deletion native-engine/datafusion-ext-plans/src/shuffle/rss.rs
@@ -37,7 +37,7 @@ impl Write for RssWriter {
let buf = jni_new_direct_byte_buffer!(&buf)?;
jni_call!(
BlazeRssPartitionWriterBase(self.rss_partition_writer.as_obj())
-.write(self.partition_id as i32, buf.as_obj(), buf_len as i32) -> ()
+.write(self.partition_id as i32, buf.as_obj()) -> ()
)?;
Ok(buf_len)
}
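These two hunks are the same change seen from both sides of the JNI boundary: the explicit length argument is dropped from write, since a direct ByteBuffer already carries its payload length via its position and limit. In the JNI method descriptor, I is int, Ljava/nio/ByteBuffer; is the ByteBuffer object type, and the trailing V is the void return, so (ILjava/nio/ByteBuffer;I)V becomes (ILjava/nio/ByteBuffer;)V. A minimal Scala sketch of the JVM-side method the new descriptor resolves to (the trait name here is hypothetical; it mirrors the real RssPartitionWriterBase trait updated at the end of this diff):

import java.nio.ByteBuffer

trait PartitionWriter {
  // Descriptor "(ILjava/nio/ByteBuffer;)V":
  //   I                     -> partitionId: Int
  //   Ljava/nio/ByteBuffer; -> buffer: ByteBuffer
  //   V                     -> Unit (void return)
  // The payload length is implied by buffer.remaining(); no separate Int needed.
  def write(partitionId: Int, buffer: ByteBuffer): Unit
}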
169 changes: 169 additions & 0 deletions BlazeRssShuffleManagerBase.scala
@@ -0,0 +1,169 @@
/*
* Copyright 2022 The Blaze Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.blaze.shuffle

import org.apache.spark.ShuffleDependency
import org.apache.spark.SparkConf
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle._
import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleDependency.isArrowShuffle

import com.thoughtworks.enableIf

abstract class BlazeRssShuffleManagerBase(conf: SparkConf) extends ShuffleManager with Logging {
override def registerShuffle[K, V, C](
shuffleId: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle

override def unregisterShuffle(shuffleId: Int): Boolean

def getBlazeRssShuffleReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): BlazeRssShuffleReaderBase[K, C]

def getBlazeRssShuffleReader[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): BlazeRssShuffleReaderBase[K, C]

def getRssShuffleReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]

def getRssShuffleReader[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]

def getBlazeRssShuffleWriter[K, V](
handle: ShuffleHandle,
mapId: Long,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): BlazeRssShuffleWriterBase[K, V]

def getRssShuffleWriter[K, V](
handle: ShuffleHandle,
mapId: Long,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V]

@enableIf(
Seq("spark320", "spark324", "spark333", "spark351").contains(
System.getProperty("blaze.shim")))
override def getReader[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {

if (isArrowShuffle(handle)) {
getBlazeRssShuffleReader(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition,
context,
metrics)
} else {
getRssShuffleReader(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition,
context,
metrics)
}
}

@enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
override def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {

if (isArrowShuffle(handle)) {
getBlazeRssShuffleReader(handle, startPartition, endPartition, context, metrics)
} else {
getRssShuffleReader(handle, startPartition, endPartition, context, metrics)
}
}

@enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
override def getReaderForRange[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {

if (isArrowShuffle(handle)) {
getBlazeRssShuffleReader(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition,
context,
metrics)
} else {
getRssShuffleReader(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition,
context,
metrics)
}
}

override def getWriter[K, V](
handle: ShuffleHandle,
mapId: Long,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {

if (isArrowShuffle(handle)) {
getBlazeRssShuffleWriter(handle, mapId, context, metrics)
} else {
getRssShuffleWriter(handle, mapId, context, metrics)
}
}
}
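The @enableIf annotation from the com.thoughtworks enableIf macro library keeps a method only when its predicate holds at compile time. That is how one source file serves Spark versions with incompatible ShuffleManager signatures: Spark 3.2+ declares getReader with a map-index range, while Spark 3.0.3 has the plain overload plus a separate getReaderForRange. A minimal, self-contained illustration of the pattern (class and method names are hypothetical, not part of this commit):

import com.thoughtworks.enableIf

class VersionedReaderApi {
  // Compiled in only when building with -Dblaze.shim=spark320 or a later shim;
  // for other shims the method is erased entirely, so the class still compiles
  // against Spark versions whose ShuffleManager lacks this signature.
  @enableIf(
    Seq("spark320", "spark324", "spark333", "spark351")
      .contains(System.getProperty("blaze.shim")))
  def describeReader(startMapIndex: Int, endMapIndex: Int): String =
    s"range reader [$startMapIndex, $endMapIndex)"

  // Compiled in only for the spark303 shim, which predates map-index ranges.
  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
  def describeReader(): String = "full reader"
}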
24 changes: 20 additions & 4 deletions NativeConverters.scala
@@ -569,7 +569,11 @@ object NativeConverters extends Logging {
}
val resultType = (lhs.dataType, rhs.dataType) match {
case (lhsType: DecimalType, rhsType: DecimalType) =>
-resultDecimalType(lhsType.precision, lhsType.scale, rhsType.precision, rhsType.scale)
+resultDecimalType(
+  lhsType.precision,
+  lhsType.scale,
+  rhsType.precision,
+  rhsType.scale)
}

buildExprNode {
@@ -606,7 +610,11 @@
}
val resultType = (lhs.dataType, rhs.dataType) match {
case (lhsType: DecimalType, rhsType: DecimalType) =>
-resultDecimalType(lhsType.precision, lhsType.scale, rhsType.precision, rhsType.scale)
+resultDecimalType(
+  lhsType.precision,
+  lhsType.scale,
+  rhsType.precision,
+  rhsType.scale)
}

buildExprNode {
@@ -642,7 +650,11 @@
}
val resultType = (lhs.dataType, rhs.dataType) match {
case (lhsType: DecimalType, rhsType: DecimalType) =>
-resultDecimalType(lhsType.precision, lhsType.scale, rhsType.precision, rhsType.scale)
+resultDecimalType(
+  lhsType.precision,
+  lhsType.scale,
+  rhsType.precision,
+  rhsType.scale)
}

buildExprNode {
@@ -686,7 +698,11 @@
}
val resultType = (lhs.dataType, rhs.dataType) match {
case (lhsType: DecimalType, rhsType: DecimalType) =>
-resultDecimalType(lhsType.precision, lhsType.scale, rhsType.precision, rhsType.scale)
+resultDecimalType(
+  lhsType.precision,
+  lhsType.scale,
+  rhsType.precision,
+  rhsType.scale)
}

buildExprNode {
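All four hunks above make the same cosmetic change: the resultDecimalType call is wrapped to one argument per line; behavior is unchanged. The helper itself is not shown in this diff. For orientation, a sketch of the standard Spark-style rule for decimal addition/subtraction — an assumption about what such a helper computes, not Blaze's actual implementation, and it ignores precision-overflow scale adjustment:

import org.apache.spark.sql.types.DecimalType

// For d1 + d2: scale = max(s1, s2), integral digits = max(p1 - s1, p2 - s2),
// plus one carry digit; precision is capped at DecimalType.MAX_PRECISION (38).
def addResultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
  val scale = math.max(s1, s2)
  val precision = math.max(p1 - s1, p2 - s2) + scale + 1
  DecimalType(math.min(precision, DecimalType.MAX_PRECISION), scale)
}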
24 changes: 24 additions & 0 deletions BlazeRssShuffleReaderBase.scala
@@ -0,0 +1,24 @@
/*
* Copyright 2022 The Blaze Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.blaze.shuffle

import org.apache.spark.TaskContext
import org.apache.spark.shuffle.BaseShuffleHandle

abstract class BlazeRssShuffleReaderBase[K, C](
handle: BaseShuffleHandle[K, _, C],
context: TaskContext)
extends BlazeBlockStoreShuffleReaderBase[K, C](handle, context) {}
85 changes: 85 additions & 0 deletions BlazeRssShuffleWriterBase.scala
@@ -0,0 +1,85 @@
/*
* Copyright 2022 The Blaze Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.blaze.shuffle

import java.util.UUID

import org.apache.spark.SparkEnv
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.sql.blaze.JniBridge
import org.apache.spark.sql.blaze.NativeHelper
import org.apache.spark.sql.blaze.NativeRDD
import org.apache.spark.sql.blaze.Shims
import org.apache.spark.Partition
import org.apache.spark.ShuffleDependency
import org.apache.spark.TaskContext
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.ShuffleHandle
import org.blaze.protobuf.PhysicalPlanNode
import org.blaze.protobuf.RssShuffleWriterExecNode

abstract class BlazeRssShuffleWriterBase[K, V](metrics: ShuffleWriteMetricsReporter)
extends BlazeShuffleWriterBase[K, V](metrics) {

def getRssPartitionWriter(
handle: ShuffleHandle,
mapId: Int,
metrics: ShuffleWriteMetricsReporter,
numPartitions: Int): RssPartitionWriterBase

def nativeRssShuffleWrite(
nativeShuffleRDD: NativeRDD,
dep: ShuffleDependency[_, _, _],
mapId: Int,
context: TaskContext,
partition: Partition,
numPartitions: Int): MapStatus = {

val rssShuffleWriterObject =
getRssPartitionWriter(dep.shuffleHandle, mapId, metrics, numPartitions)
if (rssShuffleWriterObject == null) {
throw new RuntimeException("cannot get RssPartitionWriter")
}

try {
val jniResourceId = s"RssPartitionWriter:${UUID.randomUUID().toString}"
JniBridge.resourcesMap.put(jniResourceId, rssShuffleWriterObject)
val nativeRssShuffleWriterExec = PhysicalPlanNode
.newBuilder()
.setRssShuffleWriter(
RssShuffleWriterExecNode
.newBuilder(nativeShuffleRDD.nativePlan(partition, context).getRssShuffleWriter)
.setRssPartitionWriterResourceId(jniResourceId)
.build())
.build()

val iterator = NativeHelper.executeNativePlan(
nativeRssShuffleWriterExec,
nativeShuffleRDD.metrics,
partition,
Some(context))
assert(iterator.toArray.isEmpty)
} finally {
rssShuffleWriterObject.close()
}

val mapStatus = Shims.get.getMapStatus(
SparkEnv.get.blockManager.shuffleServerId,
rssShuffleWriterObject.getPartitionLengthMap,
mapId)
mapStatus
}
}
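nativeRssShuffleWrite already handles the generic plumbing — it registers the partition writer in JniBridge.resourcesMap under a fresh UUID, rewrites the native plan so the RssShuffleWriterExecNode carries that resource id, runs the plan, then closes the writer — so a concrete integration only has to supply getRssPartitionWriter. A hypothetical sketch (class names are assumptions; HypotheticalRssPartitionWriter is sketched after the trait below):

import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteMetricsReporter}

class HypotheticalRssShuffleWriter[K, V](metrics: ShuffleWriteMetricsReporter)
    extends BlazeRssShuffleWriterBase[K, V](metrics) {

  // The only RSS-specific hook: build a partition writer bound to this map
  // task. Everything else is inherited from BlazeRssShuffleWriterBase.
  override def getRssPartitionWriter(
      handle: ShuffleHandle,
      mapId: Int,
      metrics: ShuffleWriteMetricsReporter,
      numPartitions: Int): RssPartitionWriterBase =
    new HypotheticalRssPartitionWriter(handle.shuffleId, mapId, numPartitions)
}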
2 changes: 1 addition & 1 deletion RssPartitionWriterBase.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.blaze.shuffle
import java.nio.ByteBuffer

trait RssPartitionWriterBase {
-def write(partitionId: Int, buffer: ByteBuffer, length: Int): Unit
+def write(partitionId: Int, buffer: ByteBuffer): Unit
def flush(): Unit
def close(): Unit
def getPartitionLengthMap: Array[Long]
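To make the trait's contract concrete, a minimal in-memory implementation (purely illustrative — a real integration would stream bytes to the remote shuffle service and commit on close; the per-partition lengths feed the MapStatus built in nativeRssShuffleWrite):

import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer

class HypotheticalRssPartitionWriter(
    shuffleId: Int,
    mapId: Int,
    numPartitions: Int)
    extends RssPartitionWriterBase {

  private val buffers = Array.fill(numPartitions)(new ByteArrayOutputStream())
  private val lengths = new Array[Long](numPartitions)

  // Called from native code via the "(ILjava/nio/ByteBuffer;)V" method id;
  // the payload length is whatever the buffer currently exposes.
  override def write(partitionId: Int, buffer: ByteBuffer): Unit = {
    val bytes = new Array[Byte](buffer.remaining())
    buffer.get(bytes)
    buffers(partitionId).write(bytes)
    lengths(partitionId) += bytes.length
  }

  override def flush(): Unit = ()  // nothing buffered downstream in-memory

  override def close(): Unit = ()  // a real client would commit shuffle data here

  override def getPartitionLengthMap: Array[Long] = lengths
}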
