Skip to content

Commit

Permalink
[CELEBORN-1601] Support revise lost shuffles
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
To support revising lost shuffle IDs in a long-running job such as flink batch jobs.

### Why are the changes needed?
1. To support revise lost shuffles.
2. To add an HTTP endpoint to revise lost shuffles manually.

### Does this PR introduce _any_ user-facing change?
NO.

### How was this patch tested?
Cluster tests.

Closes apache#2746 from FMX/b1600.

Lead-authored-by: mingji <[email protected]>
Co-authored-by: Ethan Feng <[email protected]>
Signed-off-by: SteNicholas <[email protected]>
  • Loading branch information
FMX authored and zaynt4606 committed Oct 21, 2024
1 parent 8f70daf commit aa97292
Show file tree
Hide file tree
Showing 32 changed files with 652 additions and 203 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,11 @@ class CommonOptions {
paramLabel = "username",
description = Array("The username of the TENANT_USER level."))
private[cli] var configName: String = _

@Option(
names = Array("--apps"),
paramLabel = "appId",
description = Array("The application Id list seperated by comma."))
private[cli] var apps: String = _

}
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,14 @@ final class MasterOptions {
names = Array("--remove-workers-unavailable-info"),
description = Array("Remove the workers unavailable info from the master."))
private[master] var removeWorkersUnavailableInfo: Boolean = _

@Option(
names = Array("--revise-lost-shuffles"),
description = Array("Revise lost shuffles or remove shuffles for an application."))
private[master] var reviseLostShuffles: Boolean = _

@Option(
names = Array("--delete-apps"),
description = Array("Delete resource of an application."))
private[master] var deleteApps: Boolean = _
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ trait MasterSubcommand extends CliLogging {
@ArgGroup(exclusive = true, multiplicity = "1")
private[master] var masterOptions: MasterOptions = _

@ArgGroup(exclusive = false)
private[master] var reviseLostShuffleOptions: ReviseLostShuffleOptions = _

@Mixin
private[master] var commonOptions: CommonOptions = _

Expand Down Expand Up @@ -110,4 +113,8 @@ trait MasterSubcommand extends CliLogging {

private[master] def runShowThreadDump: ThreadStackResponse

private[master] def reviseLostShuffles: HandleResponse

private[master] def deleteApps: HandleResponse

}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class MasterSubcommandImpl extends Runnable with MasterSubcommand {
if (masterOptions.showContainerInfo) log(runShowContainerInfo)
if (masterOptions.showDynamicConf) log(runShowDynamicConf)
if (masterOptions.showThreadDump) log(runShowThreadDump)
if (masterOptions.reviseLostShuffles) log(reviseLostShuffles)
if (masterOptions.deleteApps) log(deleteApps)
if (masterOptions.addClusterAlias != null && masterOptions.addClusterAlias.nonEmpty)
runAddClusterAlias
if (masterOptions.removeClusterAlias != null && masterOptions.removeClusterAlias.nonEmpty)
Expand Down Expand Up @@ -220,4 +222,20 @@ class MasterSubcommandImpl extends Runnable with MasterSubcommand {
}

private[master] def runShowContainerInfo: ContainerInfo = defaultApi.getContainerInfo

override private[master] def reviseLostShuffles: HandleResponse = {
val app = commonOptions.apps
if (app.contains(",")) {
throw new ParameterException(
spec.commandLine(),
"Only one application id can be provided for this command.")
}
val shuffleIds = reviseLostShuffleOptions.shuffleIds
applicationApi.reviseLostShuffles(app, shuffleIds)
}

override private[master] def deleteApps: HandleResponse = {
val apps = commonOptions.apps
applicationApi.deleteApps(apps)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.celeborn.cli.master

import picocli.CommandLine.Option

final class ReviseLostShuffleOptions {

@Option(
names = Array("--shuffleIds"),
description = Array("The shuffle ids to manipulate."))
private[master] var shuffleIds: String = _

}
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,32 @@ class TestCelebornCliCommands extends CelebornFunSuite with MiniClusterFeature {
captureOutputAndValidateResponse(args, "success: true")
}

test("master --delete-apps case1") {
val args = prepareMasterArgs() ++ Array(
"--delete-apps",
"--apps",
"app1")
captureOutputAndValidateResponse(args, "success: true")
}

test("master --delete-apps case2") {
val args = prepareMasterArgs() ++ Array(
"--delete-apps",
"--apps",
"app1,app2")
captureOutputAndValidateResponse(args, "success: true")
}

test("master --revise-lost-shuffles case1") {
val args = prepareMasterArgs() ++ Array(
"--revise-lost-shuffles",
"--apps",
"app1",
"--shuffleIds",
"1,2,3,4,5,6")
captureOutputAndValidateResponse(args, "success: true")
}

private def prepareMasterArgs(): Array[String] = {
Array(
"master",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@

package org.apache.celeborn.client

import java.util.concurrent.{ScheduledFuture, TimeUnit}
import java.util
import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit}
import java.util.function.Consumer

import scala.collection.JavaConverters._

import org.apache.commons.lang3.StringUtils

import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.client.MasterClient
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.protocol.message.ControlMessages.{ApplicationLost, ApplicationLostResponse, HeartbeatFromApplication, HeartbeatFromApplicationResponse, ZERO_UUID}
import org.apache.celeborn.common.protocol.PbReviseLostShufflesResponse
import org.apache.celeborn.common.protocol.message.ControlMessages.{ApplicationLost, ApplicationLostResponse, HeartbeatFromApplication, HeartbeatFromApplicationResponse, ReviseLostShuffles, ZERO_UUID}
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.util.{ThreadUtils, Utils}

Expand All @@ -33,9 +38,11 @@ class ApplicationHeartbeater(
conf: CelebornConf,
masterClient: MasterClient,
shuffleMetrics: () => (Long, Long),
workerStatusTracker: WorkerStatusTracker) extends Logging {
workerStatusTracker: WorkerStatusTracker,
registeredShuffles: ConcurrentHashMap.KeySetView[Int, java.lang.Boolean]) extends Logging {

private var stopped = false
private val reviseLostShuffles = conf.reviseLostShufflesEnabled

// Use independent app heartbeat threads to avoid being blocked by other operations.
private val appHeartbeatIntervalMs = conf.appHeartbeatIntervalMs
Expand Down Expand Up @@ -70,6 +77,30 @@ class ApplicationHeartbeater(
if (response.statusCode == StatusCode.SUCCESS) {
logDebug("Successfully send app heartbeat.")
workerStatusTracker.handleHeartbeatResponse(response)
// revise shuffle id if there are lost shuffles
if (reviseLostShuffles) {
val masterRecordedShuffleIds = response.registeredShuffles
val localOnlyShuffles = new util.ArrayList[Integer]()
registeredShuffles.forEach(new Consumer[Int] {
override def accept(key: Int): Unit = {
localOnlyShuffles.add(key)
}
})
localOnlyShuffles.removeAll(masterRecordedShuffleIds)
if (!localOnlyShuffles.isEmpty) {
logWarning(
s"There are lost shuffle found ${StringUtils.join(localOnlyShuffles, ",")}, revise lost shuffles.")
val reviseLostShufflesResponse = masterClient.askSync(
ReviseLostShuffles.apply(appId, localOnlyShuffles, MasterClient.genRequestId()),
classOf[PbReviseLostShufflesResponse])
if (!reviseLostShufflesResponse.getSuccess) {
logWarning(
s"Revise lost shuffles failed. Error message :${reviseLostShufflesResponse.getMessage}")
} else {
logInfo("Revise lost shuffles succeed.")
}
}
}
}
} catch {
case it: InterruptedException =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
conf,
masterClient,
() => commitManager.commitMetrics(),
workerStatusTracker)
workerStatusTracker,
registeredShuffle)
private val changePartitionManager = new ChangePartitionManager(conf, this)
private val releasePartitionManager = new ReleasePartitionManager(conf, this)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,97 +23,24 @@ import org.junit.Assert

import org.apache.celeborn.CelebornFunSuite
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.CelebornConf.{APPLICATION_HEARTBEAT_WITH_AVAILABLE_WORKERS_ENABLE, CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT}
import org.apache.celeborn.common.CelebornConf.CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT
import org.apache.celeborn.common.meta.WorkerInfo
import org.apache.celeborn.common.protocol.message.ControlMessages.HeartbeatFromApplicationResponse
import org.apache.celeborn.common.protocol.message.StatusCode

class WorkerStatusTrackerSuite extends CelebornFunSuite {
test("handleHeartbeatResponse without availableWorkers") {
val celebornConf = new CelebornConf()
celebornConf.set(CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT, 2000L)
celebornConf.set(APPLICATION_HEARTBEAT_WITH_AVAILABLE_WORKERS_ENABLE, false)
val statusTracker = new WorkerStatusTracker(celebornConf, null)

val registerTime = System.currentTimeMillis()
statusTracker.excludedWorkers.put(mock("host1"), (StatusCode.WORKER_UNKNOWN, registerTime))
statusTracker.excludedWorkers.put(mock("host2"), (StatusCode.WORKER_SHUTDOWN, registerTime))

// test reserve (only statusCode list in handleHeartbeatResponse)
val empty = buildResponse(Array.empty, Array.empty, Array.empty, Array.empty)
statusTracker.handleHeartbeatResponse(empty)

// only reserve host1
Assert.assertEquals(
statusTracker.excludedWorkers.get(mock("host1")),
(StatusCode.WORKER_UNKNOWN, registerTime))
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host2")))

// add shutdown/excluded worker
val response1 =
buildResponse(Array("host0"), Array("host1", "host3"), Array("host4"), Array.empty)
statusTracker.handleHeartbeatResponse(response1)

// test keep Unknown register time
Assert.assertEquals(
statusTracker.excludedWorkers.get(mock("host1")),
(StatusCode.WORKER_UNKNOWN, registerTime))

// test new added shutdown/excluded workers
Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host0")))
Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host3")))
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))

// test re heartbeat with shutdown workers
val response2 = buildResponse(Array.empty, Array.empty, Array("host4"), Array.empty)
statusTracker.handleHeartbeatResponse(response2)
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))

// test remove
val workers = new util.HashSet[WorkerInfo]
workers.add(mock("host3"))
statusTracker.removeFromExcludedWorkers(workers)
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host3")))

// test register time elapsed
Thread.sleep(3000)
val response3 = buildResponse(Array.empty, Array("host5", "host6"), Array.empty, Array.empty)
statusTracker.handleHeartbeatResponse(response3)
Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host1")))

// test available workers
Assert.assertEquals(statusTracker.availableWorkers.size(), 0)
val response4 = buildResponse(
Array.empty,
Array.empty,
Array.empty,
Array("host5", "host6", "host7", "host8"))
statusTracker.handleHeartbeatResponse(response4)

// availableWorkers wont update through heartbeat
// when APPLICATION_HEARTBEAT_WITH_AVAILABLE_WORKERS_ENABLE set to false
Assert.assertEquals(statusTracker.availableWorkers.size(), 0)
// available workers won't overwrite excluded workers
Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host5")))
Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host6")))
}

test("handleHeartbeatResponse with availableWorkers") {
test("handleHeartbeatResponse") {
val celebornConf = new CelebornConf()
celebornConf.set(CLIENT_EXCLUDED_WORKER_EXPIRE_TIMEOUT, 2000L)
celebornConf.set(APPLICATION_HEARTBEAT_WITH_AVAILABLE_WORKERS_ENABLE, true)
val statusTracker = new WorkerStatusTracker(celebornConf, null)

val registerTime = System.currentTimeMillis()
statusTracker.excludedWorkers.put(mock("host1"), (StatusCode.WORKER_UNKNOWN, registerTime))
statusTracker.excludedWorkers.put(mock("host2"), (StatusCode.WORKER_SHUTDOWN, registerTime))

// test reserve (only statusCode list in handleHeartbeatResponse)
val empty = buildResponse(Array.empty, Array.empty, Array.empty, Array.empty)
val empty = buildResponse(Array.empty, Array.empty, Array.empty)
statusTracker.handleHeartbeatResponse(empty)

// only reserve host1
Expand All @@ -123,23 +50,23 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host2")))

// add shutdown/excluded worker
val response1 =
buildResponse(Array("host0"), Array("host1", "host3"), Array("host4"), Array.empty)
val response1 = buildResponse(Array("host0"), Array("host1", "host3"), Array("host4"))
statusTracker.handleHeartbeatResponse(response1)

// test keep Unknown register time
Assert.assertEquals(
statusTracker.excludedWorkers.get(mock("host1")),
(StatusCode.WORKER_UNKNOWN, registerTime))

// test new added workers
Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host0")))
Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host3")))
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))

// test re heartbeat with shutdown workers
val response2 = buildResponse(Array.empty, Array.empty, Array("host4"), Array.empty)
statusTracker.handleHeartbeatResponse(response2)
val response3 = buildResponse(Array.empty, Array.empty, Array("host4"))
statusTracker.handleHeartbeatResponse(response3)
Assert.assertTrue(!statusTracker.excludedWorkers.containsKey(mock("host4")))
Assert.assertTrue(statusTracker.shuttingWorkers.contains(mock("host4")))

Expand All @@ -151,49 +78,25 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {

// test register time elapsed
Thread.sleep(3000)
val response3 = buildResponse(Array.empty, Array("host5", "host6"), Array.empty, Array.empty)
statusTracker.handleHeartbeatResponse(response3)
val response2 = buildResponse(Array.empty, Array("host5", "host6"), Array.empty)
statusTracker.handleHeartbeatResponse(response2)
Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
Assert.assertFalse(statusTracker.excludedWorkers.containsKey(mock("host1")))

// test available workers
Assert.assertEquals(statusTracker.availableWorkers.size(), 0)
val response4 = buildResponse(
Array.empty,
Array.empty,
Array.empty,
Array("host5", "host6", "host7", "host8"))
statusTracker.handleHeartbeatResponse(response4)
Assert.assertEquals(statusTracker.availableWorkers.size(), 2)
// available workers won't overwrite excluded workers
Assert.assertEquals(statusTracker.excludedWorkers.size(), 2)
Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host5")))
Assert.assertTrue(statusTracker.excludedWorkers.containsKey(mock("host6")))

// test re heartbeat with available workers
val response5 = buildResponse(Array.empty, Array.empty, Array.empty, Array("host8", "host9"))
statusTracker.handleHeartbeatResponse(response5)
Assert.assertEquals(statusTracker.availableWorkers.size(), 2)
Assert.assertFalse(statusTracker.availableWorkers.contains(mock("host7")))
Assert.assertTrue(statusTracker.availableWorkers.contains(mock("host8")))
Assert.assertTrue(statusTracker.availableWorkers.contains(mock("host9")))
}

private def buildResponse(
excludedWorkerHosts: Array[String],
unknownWorkerHosts: Array[String],
shuttingWorkerHosts: Array[String],
availableWorkerHosts: Array[String]): HeartbeatFromApplicationResponse = {
shuttingWorkerHosts: Array[String]): HeartbeatFromApplicationResponse = {
val excludedWorkers = mockWorkers(excludedWorkerHosts)
val unknownWorkers = mockWorkers(unknownWorkerHosts)
val shuttingWorkers = mockWorkers(shuttingWorkerHosts)
val availableWorkers = mockWorkers(availableWorkerHosts)
HeartbeatFromApplicationResponse(
StatusCode.SUCCESS,
excludedWorkers,
unknownWorkers,
shuttingWorkers,
availableWorkers)
new util.ArrayList[Integer]())
}

private def mockWorkers(workerHosts: Array[String]): util.ArrayList[WorkerInfo] = {
Expand Down
Loading

0 comments on commit aa97292

Please sign in to comment.