This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-569] CPU overhead on fine grain / concurrent off-heap acquire operations #590

Merged: 3 commits, Dec 1, 2021
@@ -20,24 +20,62 @@
import org.apache.arrow.memory.AllocationListener;

public class SparkManagedAllocationListener implements AllocationListener {
public static long BLOCK_SIZE = 8L * 1024 * 1024; // 8MB per block
Collaborator commented:
How about making this configurable?

Collaborator (author) replied:
That's a good idea.

Although at this time I think a fixed size is enough for most cases: 8 MB already reduces Spark acquire calls by 99% in my own measurement.

Also, similar logic was applied to the C++ Arrow code, so it might require further thinking to unify them somehow. On the whole I think we can open an individual topic for this suggestion. :)
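As a rough illustration of the suggestion, here is a minimal sketch of what a configurable block size could look like, assuming a hypothetical key spark.oap.sql.columnar.allocation.blockSize read from the active SparkConf (the key name, helper class, and default below are illustrative and not part of this PR):

import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;

public final class AllocationBlockSize {
  // Hypothetical config key; not defined by this PR.
  private static final String BLOCK_SIZE_KEY = "spark.oap.sql.columnar.allocation.blockSize";
  // Same 8 MB default as the hard-coded BLOCK_SIZE above.
  private static final long DEFAULT_BLOCK_SIZE = 8L * 1024 * 1024;

  private AllocationBlockSize() {}

  // Returns the block size in bytes, falling back to the default when no
  // SparkEnv is available (e.g. in unit tests) or the key is unset.
  public static long get() {
    SparkEnv env = SparkEnv.get();
    if (env == null) {
      return DEFAULT_BLOCK_SIZE;
    }
    SparkConf conf = env.conf();
    return conf.getSizeAsBytes(BLOCK_SIZE_KEY, String.valueOf(DEFAULT_BLOCK_SIZE));
  }
}

SparkManagedAllocationListener could then initialize BLOCK_SIZE from AllocationBlockSize.get() instead of hard-coding 8 MB.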


private final NativeSQLMemoryConsumer consumer;
private final NativeSQLMemoryMetrics metrics;

private long bytesReserved = 0L;
private long blocksReserved = 0L;

public SparkManagedAllocationListener(NativeSQLMemoryConsumer consumer, NativeSQLMemoryMetrics metrics) {
this.consumer = consumer;
this.metrics = metrics;
}

@Override
public void onPreAllocation(long size) {
- consumer.acquire(size);
- metrics.inc(size);
long requiredBlocks = updateReservation(size);
if (requiredBlocks < 0) {
throw new IllegalStateException();
}
if (requiredBlocks == 0) {
return;
}
long toBeAcquired = requiredBlocks * BLOCK_SIZE;
consumer.acquire(toBeAcquired);
metrics.inc(toBeAcquired);
}

@Override
public void onRelease(long size) {
- consumer.free(size);
- metrics.inc(-size);
long requiredBlocks = updateReservation(-size);
if (requiredBlocks > 0) {
throw new IllegalStateException();
}
if (requiredBlocks == 0) {
return;
}
long toBeReleased = -requiredBlocks * BLOCK_SIZE;
consumer.free(toBeReleased);
metrics.inc(-toBeReleased);
}

public long updateReservation(long bytesToAdd) {
synchronized (this) {
long newBytesReserved = bytesReserved + bytesToAdd;
final long newBlocksReserved;
// ceiling
if (newBytesReserved == 0L) {
// 0 is the special case in ceiling algorithm
newBlocksReserved = 0L;
} else {
newBlocksReserved = (newBytesReserved - 1L) / BLOCK_SIZE + 1L;
}
long requiredBlocks = newBlocksReserved - blocksReserved;
bytesReserved = newBytesReserved;
blocksReserved = newBlocksReserved;
return requiredBlocks;
}
}
}
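To make the effect of the block-granular accounting concrete, here is a small self-contained sketch (illustration only, not part of the PR) that mirrors the ceiling arithmetic of updateReservation and counts how many underlying acquire calls a stream of small allocations would trigger:

public class BlockReservationDemo {
  static final long BLOCK_SIZE = 8L * 1024 * 1024; // 8 MB, matching the listener above

  static long bytesReserved = 0L;
  static long blocksReserved = 0L;

  // Same ceiling arithmetic as SparkManagedAllocationListener#updateReservation.
  static long updateReservation(long bytesToAdd) {
    long newBytesReserved = bytesReserved + bytesToAdd;
    long newBlocksReserved =
        newBytesReserved == 0L ? 0L : (newBytesReserved - 1L) / BLOCK_SIZE + 1L;
    long requiredBlocks = newBlocksReserved - blocksReserved;
    bytesReserved = newBytesReserved;
    blocksReserved = newBlocksReserved;
    return requiredBlocks;
  }

  public static void main(String[] args) {
    long acquireCalls = 0;
    // 10,000 allocations of 64 KB each, 625 MB in total.
    for (int i = 0; i < 10_000; i++) {
      if (updateReservation(64L * 1024) > 0) {
        acquireCalls++; // only when an 8 MB block boundary is crossed
      }
    }
    // Per-allocation accounting would issue 10,000 acquire calls;
    // block-granular accounting issues ceil(625 MB / 8 MB) = 79.
    System.out.println("acquire calls: " + acquireCalls);
  }
}

The roughly 99% drop in acquire calls matches the measurement mentioned in the review discussion above.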
@@ -72,6 +72,12 @@ object SparkMemoryUtils extends Logging {
sparkManagedAllocationListener
}

val allocListenerForBufferImport: AllocationListener = if (isArrowAutoReleaseEnabled) {
MemoryChunkCleaner.gcTrigger()
} else {
AllocationListener.NOOP
}

private def collectStackForDebug = {
if (DEBUG) {
val out = new ByteOutputStream()
@@ -99,6 +105,10 @@ object SparkMemoryUtils extends Logging {
alloc
}

val taskDefaultAllocatorForBufferImport: BufferAllocator = taskDefaultAllocator
.newChildAllocator("CHILD-ALLOC-BUFFER-IMPORT", allocListenerForBufferImport, 0L,
Long.MaxValue)

val defaultMemoryPool: NativeMemoryPoolWrapper = {
val rl = new SparkManagedReservationListener(
new NativeSQLMemoryConsumer(getTaskMemoryManager(), Spiller.NO_OP),
@@ -283,6 +293,13 @@ object SparkMemoryUtils extends Logging {
getTaskMemoryResources().taskDefaultAllocator
}

def contextAllocatorForBufferImport(): BufferAllocator = {
if (!inSparkTask()) {
return globalAllocator()
}
getTaskMemoryResources().taskDefaultAllocatorForBufferImport
}
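// Illustration only (not part of this diff): call sites that import record batches from
// native code switch from contextAllocator() to contextAllocatorForBufferImport(), e.g.
//   val allocator = SparkMemoryUtils.contextAllocatorForBufferImport()
//   val batch = UnsafeRecordBatchSerializer.deserializeUnsafe(allocator, serializedBatch)
// so those buffers are owned by the CHILD-ALLOC-BUFFER-IMPORT child allocator created
// above, with allocListenerForBufferImport attached.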

def contextMemoryPool(): NativeMemoryPool = {
if (!inSparkTask()) {
return globalMemoryPool()
@@ -32,6 +32,7 @@
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils;

/** Parquet Reader Class. */
public class ParquetReader implements AutoCloseable {
@@ -41,7 +42,6 @@ public class ParquetReader implements AutoCloseable {
/** last read length of a record batch. */
private long lastReadLength;

- private BufferAllocator allocator;
private ParquetReaderJniWrapper jniWrapper;

/**
@@ -51,13 +51,11 @@
* @param rowGroupIndices An array to indicate which rowGroup to read.
* @param columnIndices An array to indicate which columns to read.
* @param batchSize number of rows expected to be read in one batch.
- * @param allocator A BufferAllocator reference.
* @throws IOException throws io exception in case of native failure.
*/
public ParquetReader(String path, int[] rowGroupIndices, int[] columnIndices,
- long batchSize, BufferAllocator allocator, String tmp_dir) throws IOException {
+ long batchSize, String tmp_dir) throws IOException {
this.jniWrapper = new ParquetReaderJniWrapper(tmp_dir);
- this.allocator = allocator;
this.nativeInstanceId = jniWrapper.nativeOpenParquetReader(path, batchSize);
jniWrapper.nativeInitParquetReader(nativeInstanceId, columnIndices, rowGroupIndices);
}
@@ -76,7 +74,6 @@ public ParquetReader(String path, int[] rowGroupIndices, int[] columnIndices,
public ParquetReader(String path, long startPos, long endPos, int[] columnIndices,
long batchSize, BufferAllocator allocator, String tmp_dir) throws IOException {
this.jniWrapper = new ParquetReaderJniWrapper(tmp_dir);
- this.allocator = allocator;
this.nativeInstanceId = jniWrapper.nativeOpenParquetReader(path, batchSize);
jniWrapper.nativeInitParquetReader2(
nativeInstanceId, columnIndices, startPos, endPos);
@@ -93,7 +90,7 @@ public Schema getSchema() throws IOException {

try (MessageChannelReader schemaReader = new MessageChannelReader(
new ReadChannel(new ByteArrayReadableSeekableByteChannel(schemaBytes)),
- allocator)) {
+ SparkMemoryUtils.contextAllocator())) {
MessageResult result = schemaReader.readNext();
if (result == null) {
throw new IOException("Unexpected end of input. Missing schema.");
@@ -115,8 +112,8 @@ public ArrowRecordBatch readNext() throws IOException {
if (serializedBatch == null) {
return null;
}
- ArrowRecordBatch batch = UnsafeRecordBatchSerializer.deserializeUnsafe(allocator,
- serializedBatch);
+ ArrowRecordBatch batch = UnsafeRecordBatchSerializer.deserializeUnsafe(
+ SparkMemoryUtils.contextAllocatorForBufferImport(), serializedBatch);
if (batch == null) {
throw new IllegalArgumentException("failed to build record batch");
}
@@ -68,7 +68,7 @@ public boolean hasNext() throws IOException {
}

public ArrowRecordBatch next() throws IOException {
- BufferAllocator allocator = SparkMemoryUtils.contextAllocator();
+ BufferAllocator allocator = SparkMemoryUtils.contextAllocatorForBufferImport();
if (nativeHandler == 0) {
return null;
}
@@ -132,7 +132,7 @@ public ArrowRecordBatch process(Schema schema, ArrowRecordBatch recordBatch,
if (nativeHandler == 0) {
return null;
}
- BufferAllocator allocator = SparkMemoryUtils.contextAllocator();
+ BufferAllocator allocator = SparkMemoryUtils.contextAllocatorForBufferImport();
byte[] serializedRecordBatch;
if (selectionVector != null) {
int selectionVectorRecordCount = selectionVector.getRecordCount();
@@ -159,7 +159,7 @@ public void evaluate(ColumnarNativeIterator batchItr)
public ArrowRecordBatch[] evaluate2(ArrowRecordBatch recordBatch) throws RuntimeException, IOException {
byte[] bytes = UnsafeRecordBatchSerializer.serializeUnsafe(recordBatch);
byte[][] serializedBatchArray = jniWrapper.nativeEvaluate2(nativeHandler, bytes);
- BufferAllocator allocator = SparkMemoryUtils.contextAllocator();
+ BufferAllocator allocator = SparkMemoryUtils.contextAllocatorForBufferImport();
ArrowRecordBatch[] recordBatchList = new ArrowRecordBatch[serializedBatchArray.length];
for (int i = 0; i < serializedBatchArray.length; i++) {
if (serializedBatchArray[i] == null) {
@@ -191,7 +191,7 @@ public ArrowRecordBatch[] evaluate(ArrowRecordBatch recordBatch, SelectionVector
bufSizes[idx++] = bufLayout.getSize();
}

- BufferAllocator allocator = SparkMemoryUtils.contextAllocator();
+ BufferAllocator allocator = SparkMemoryUtils.contextAllocatorForBufferImport();

byte[][] serializedBatchArray;
if (selectionVector != null) {
@@ -237,7 +237,7 @@ public void SetMember(ArrowRecordBatch recordBatch) throws RuntimeException, IOE
}

public ArrowRecordBatch[] finish() throws RuntimeException, IOException {
- BufferAllocator allocator = SparkMemoryUtils.contextAllocator();
+ BufferAllocator allocator = SparkMemoryUtils.contextAllocatorForBufferImport();
byte[][] serializedBatchArray = jniWrapper.nativeFinish(nativeHandler);
ArrowRecordBatch[] recordBatchList = new ArrowRecordBatch[serializedBatchArray.length];
for (int i = 0; i < serializedBatchArray.length; i++) {
@@ -67,7 +67,7 @@ private class ArrowColumnarBatchSerializerInstance(
SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)

private val allocator: BufferAllocator = SparkMemoryUtils
- .contextAllocator()
+ .contextAllocatorForBufferImport()
.newChildAllocator("ArrowColumnarBatch deserialize", 0, Long.MaxValue)

private var reader: ArrowStreamReader = _
@@ -100,7 +100,6 @@ class PartitioningSuite extends QueryTest with SharedSparkSession {
val df = spark.sql("SELECT COUNT(*) AS cnt FROM ltab, rtab WHERE ltab.id = rtab.id")
df.explain(true)
df.show()
- Thread.sleep(1000000)
}
}

@@ -132,7 +132,6 @@ class TPCDSSuite extends QueryTest with SharedSparkSession {
"ws_item_sk = i_item_sk LIMIT 10")
df.explain(true)
df.show()
- Thread.sleep(1000000)
}
}

@@ -142,7 +141,6 @@ class TPCDSSuite extends QueryTest with SharedSparkSession {
"web_sales) LIMIT 10")
df.explain()
df.show()
- Thread.sleep(1000000)
}
}

@@ -165,7 +163,6 @@ class TPCDSSuite extends QueryTest with SharedSparkSession {
)
df.explain(true)
df.show()
- Thread.sleep(1000000)
}
}
