diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java index 01d8520179d..8cd2a4874aa 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java @@ -501,13 +501,15 @@ public long getPushSortMemoryThreshold() { return this.pushSortMemoryThreshold; } - public void close() throws IOException { + public void close(boolean throwTaskKilledOnInterruption) throws IOException { cleanupResources(); try { dataPusher.waitOnTermination(); sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue()); } catch (InterruptedException e) { - TaskInterruptedHelper.throwTaskKillException(); + if (throwTaskKilledOnInterruption) { + TaskInterruptedHelper.throwTaskKillException(); + } } } diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java index 0962c98c4aa..73c15bb70d6 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java @@ -127,7 +127,7 @@ public void testMemoryUsage() throws Exception { !pusher.insertRecord( row5k.getBaseObject(), row5k.getBaseOffset(), row5k.getSizeInBytes(), 0, true)); - pusher.close(); + pusher.close(true); assertEquals(taskContext.taskMetrics().memoryBytesSpilled(), 2097152); } diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java index 0986c23ef28..fd83ec2e1c5 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java @@ -162,6 +162,7 @@ public HashBasedShuffleWriter( @Override public void write(scala.collection.Iterator> records) throws IOException { + boolean needCleanupPusher = true; try { if (canUseFastWrite()) { fastWrite0(records); @@ -174,13 +175,13 @@ public void write(scala.collection.Iterator> records) throws IOEx } else { write0(records); } + close(); + needCleanupPusher = false; } catch (InterruptedException e) { TaskInterruptedHelper.throwTaskKillException(); } finally { - try { - close(); - } catch (InterruptedException e) { - TaskInterruptedHelper.throwTaskKillException(); + if (needCleanupPusher) { + cleanupPusher(); } } } @@ -319,6 +320,15 @@ private void flushSendBuffer(int partitionId, byte[] buffer, int size) writeMetrics.incWriteTime(System.nanoTime() - start); } + private void cleanupPusher() throws IOException { + try { + dataPusher.waitOnTermination(); + sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue()); + } catch (InterruptedException e) { + TaskInterruptedHelper.throwTaskKillException(); + } + } + private void close() throws IOException, InterruptedException { // here we wait for all the in-flight batches to return which sent by dataPusher thread dataPusher.waitOnTermination(); diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java index 5ecf1edba5c..7b8baaf0601 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java @@ -143,6 +143,7 @@ public SortBasedShuffleWriter( @Override public void write(scala.collection.Iterator> records) throws IOException { + boolean needCleanupPusher = true; try { if (canUseFastWrite()) { fastWrite0(records); @@ -155,8 +156,12 @@ public void write(scala.collection.Iterator> records) throws IOEx } else { write0(records); } - } finally { close(); + needCleanupPusher = false; + } finally { + if (needCleanupPusher) { + cleanupPusher(); + } } } @@ -291,11 +296,17 @@ private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throw writeMetrics.incBytesWritten(bytesWritten); } + private void cleanupPusher() throws IOException { + if (pusher != null) { + pusher.close(false); + } + } + private void close() throws IOException { logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed())); long pushStartTime = System.nanoTime(); pusher.pushData(false); - pusher.close(); + pusher.close(true); writeMetrics.incWriteTime(System.nanoTime() - pushStartTime); shuffleClient.pushMergedData(shuffleId, mapId, taskContext.attemptNumber()); diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java index 5a9b0455293..bd1eeb16c28 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java @@ -158,6 +158,7 @@ public HashBasedShuffleWriter( @Override public void write(scala.collection.Iterator> records) throws IOException { + boolean needCleanupPusher = true; try { if (canUseFastWrite()) { fastWrite0(records); @@ -170,13 +171,13 @@ public void write(scala.collection.Iterator> records) throws IOEx } else { write0(records); } + close(); + needCleanupPusher = false; } catch (InterruptedException e) { TaskInterruptedHelper.throwTaskKillException(); } finally { - try { - close(); - } catch (InterruptedException e) { - TaskInterruptedHelper.throwTaskKillException(); + if (needCleanupPusher) { + cleanupPusher(); } } } @@ -353,6 +354,15 @@ protected void mergeData(int partitionId, byte[] buffer, int offset, int length) writeMetrics.incBytesWritten(bytesWritten); } + private void cleanupPusher() throws IOException { + try { + dataPusher.waitOnTermination(); + sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue()); + } catch (InterruptedException e) { + TaskInterruptedHelper.throwTaskKillException(); + } + } + private void close() throws IOException, InterruptedException { // here we wait for all the in-flight batches to return which sent by dataPusher thread long pushMergedDataTime = System.nanoTime(); diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java index f3b856394d5..1c8900bb22d 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java @@ -226,10 +226,15 @@ void doWrite(scala.collection.Iterator> records) throws IOExcepti @Override public void write(scala.collection.Iterator> records) throws IOException { + boolean needCleanupPusher = true; try { doWrite(records); - } finally { close(); + needCleanupPusher = false; + } finally { + if (needCleanupPusher) { + cleanupPusher(); + } } } @@ -354,11 +359,17 @@ private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throw writeMetrics.incBytesWritten(bytesWritten); } + private void cleanupPusher() throws IOException { + if (pusher != null) { + pusher.close(false); + } + } + private void close() throws IOException { logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed())); long pushStartTime = System.nanoTime(); pusher.pushData(false); - pusher.close(); + pusher.close(true); shuffleClient.pushMergedData(shuffleId, mapId, taskContext.attemptNumber()); writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);