Skip to content

Commit

Permalink
[CELEBORN-1544][FOLLOWUP] ShuffleWriter needs to catch exception and …
Browse files Browse the repository at this point in the history
…call abort to avoid memory leaks

### What changes were proposed in this pull request?
This PR aims to fix a possible memory leak in ShuffleWriter.

Introduce a private abort method, which can be called to release memory when an exception occurs.

### Why are the changes needed?
PR #2661 calls the close method in the finally block, but close invokes `shuffleClient.mapperEnd`, which is dangerous for incomplete tasks and may report inaccurate data.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
GA

Closes #2663 from cxzl25/CELEBORN-1544-followup.

Authored-by: sychen <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
(cherry picked from commit bc3bd46)
Signed-off-by: zky.zhoukeyong <[email protected]>
  • Loading branch information
cxzl25 authored and waitinfuture committed Aug 21, 2024
1 parent 08716da commit f639be4
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -501,13 +501,15 @@ public long getPushSortMemoryThreshold() {
return this.pushSortMemoryThreshold;
}

public void close() throws IOException {
public void close(boolean throwTaskKilledOnInterruption) throws IOException {
cleanupResources();
try {
dataPusher.waitOnTermination();
sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
} catch (InterruptedException e) {
TaskInterruptedHelper.throwTaskKillException();
if (throwTaskKilledOnInterruption) {
TaskInterruptedHelper.throwTaskKillException();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public void testMemoryUsage() throws Exception {
!pusher.insertRecord(
row5k.getBaseObject(), row5k.getBaseOffset(), row5k.getSizeInBytes(), 0, true));

pusher.close();
pusher.close(true);

assertEquals(taskContext.taskMetrics().memoryBytesSpilled(), 2097152);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ public HashBasedShuffleWriter(

@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
boolean needCleanupPusher = true;
try {
if (canUseFastWrite()) {
fastWrite0(records);
Expand All @@ -174,13 +175,13 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOEx
} else {
write0(records);
}
close();
needCleanupPusher = false;
} catch (InterruptedException e) {
TaskInterruptedHelper.throwTaskKillException();
} finally {
try {
close();
} catch (InterruptedException e) {
TaskInterruptedHelper.throwTaskKillException();
if (needCleanupPusher) {
cleanupPusher();
}
}
}
Expand Down Expand Up @@ -319,6 +320,15 @@ private void flushSendBuffer(int partitionId, byte[] buffer, int size)
writeMetrics.incWriteTime(System.nanoTime() - start);
}

  /**
   * Best-effort cleanup path for a failed/aborted task: waits for the data pusher
   * thread to terminate and returns its idle task queue to the shared buffer pool
   * so the buffers are not leaked. Unlike {@code close()}, this deliberately does
   * NOT call {@code shuffleClient.mapperEnd}, so an incomplete task cannot commit
   * partial data.
   *
   * @throws IOException if terminating the pusher fails
   */
  private void cleanupPusher() throws IOException {
    try {
      dataPusher.waitOnTermination();
      sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
    } catch (InterruptedException e) {
      // NOTE(review): the interrupt flag is not re-asserted here; presumably
      // throwTaskKillException() raises Spark's TaskKilledException, which ends
      // the task anyway — confirm against TaskInterruptedHelper.
      TaskInterruptedHelper.throwTaskKillException();
    }
  }

private void close() throws IOException, InterruptedException {
// here we wait for all the in-flight batches to return which sent by dataPusher thread
dataPusher.waitOnTermination();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ public SortBasedShuffleWriter(

@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
boolean needCleanupPusher = true;
try {
if (canUseFastWrite()) {
fastWrite0(records);
Expand All @@ -155,8 +156,12 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOEx
} else {
write0(records);
}
} finally {
close();
needCleanupPusher = false;
} finally {
if (needCleanupPusher) {
cleanupPusher();
}
}
}

Expand Down Expand Up @@ -291,11 +296,17 @@ private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throw
writeMetrics.incBytesWritten(bytesWritten);
}

private void cleanupPusher() throws IOException {
if (pusher != null) {
pusher.close(false);
}
}

private void close() throws IOException {
logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
long pushStartTime = System.nanoTime();
pusher.pushData(false);
pusher.close();
pusher.close(true);
writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);

shuffleClient.pushMergedData(shuffleId, mapId, taskContext.attemptNumber());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ public HashBasedShuffleWriter(

@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
boolean needCleanupPusher = true;
try {
if (canUseFastWrite()) {
fastWrite0(records);
Expand All @@ -170,13 +171,13 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOEx
} else {
write0(records);
}
close();
needCleanupPusher = false;
} catch (InterruptedException e) {
TaskInterruptedHelper.throwTaskKillException();
} finally {
try {
close();
} catch (InterruptedException e) {
TaskInterruptedHelper.throwTaskKillException();
if (needCleanupPusher) {
cleanupPusher();
}
}
}
Expand Down Expand Up @@ -353,6 +354,15 @@ protected void mergeData(int partitionId, byte[] buffer, int offset, int length)
writeMetrics.incBytesWritten(bytesWritten);
}

  /**
   * Failure-path cleanup: waits for the data pusher thread to terminate and
   * recycles its idle task queue back to the send buffer pool to avoid a memory
   * leak. Intentionally skips {@code shuffleClient.mapperEnd} (done in
   * {@code close()}), so an unsuccessful task never reports completion.
   *
   * @throws IOException if terminating the pusher fails
   */
  private void cleanupPusher() throws IOException {
    try {
      dataPusher.waitOnTermination();
      sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
    } catch (InterruptedException e) {
      // NOTE(review): interrupt status is not restored before converting to a
      // task-kill exception — verify TaskInterruptedHelper handles this.
      TaskInterruptedHelper.throwTaskKillException();
    }
  }

private void close() throws IOException, InterruptedException {
// here we wait for all the in-flight batches to return which sent by dataPusher thread
long pushMergedDataTime = System.nanoTime();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,15 @@ void doWrite(scala.collection.Iterator<Product2<K, V>> records) throws IOExcepti

@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
boolean needCleanupPusher = true;
try {
doWrite(records);
} finally {
close();
needCleanupPusher = false;
} finally {
if (needCleanupPusher) {
cleanupPusher();
}
}
}

Expand Down Expand Up @@ -354,11 +359,17 @@ private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throw
writeMetrics.incBytesWritten(bytesWritten);
}

private void cleanupPusher() throws IOException {
if (pusher != null) {
pusher.close(false);
}
}

private void close() throws IOException {
logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
long pushStartTime = System.nanoTime();
pusher.pushData(false);
pusher.close();
pusher.close(true);

shuffleClient.pushMergedData(shuffleId, mapId, taskContext.attemptNumber());
writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
Expand Down

0 comments on commit f639be4

Please sign in to comment.