Skip to content

Commit

Permalink
Introduce --shard-split and --shard-all (#1955)
Browse files Browse the repository at this point in the history
Also deprecate `--shards

---------

Co-authored-by: Bartek Pacia <[email protected]>
  • Loading branch information
tokou and bartekpacia authored Sep 3, 2024
1 parent 022840f commit 9ad6e39
Showing 1 changed file with 36 additions and 8 deletions.
44 changes: 36 additions & 8 deletions maestro-cli/src/main/java/maestro/cli/command/TestCommand.kt
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,22 @@ class TestCommand : Callable<Int> {

@Option(
names = ["-s", "--shards"],
description = ["Number of parallel shards to distribute tests across"]
description = ["Number of parallel shards to distribute tests across"],
)
private var shards: Int = 1
@Deprecated("Use --shard-split or --shard-all instead")
private var legacyShardCount: Int? = null

@Option(
names = ["--shard-split"],
description = ["Splits the tests across N connected devices"],
)
private var shardSplit: Int? = null

@Option(
names = ["--shard-all"],
description = ["Replicates all the tests across N connected devices"],
)
private var shardAll: Int? = null

@Option(names = ["-c", "--continuous"])
private var continuous: Boolean = false
Expand Down Expand Up @@ -152,6 +165,14 @@ class TestCommand : Callable<Int> {
}

override fun call(): Int {
if (shardSplit != null && shardAll != null) {
throw CliError("Options --shard-split and --shard-all are mutually exclusive.")
}

if (legacyShardCount != null) {
PrintUtils.warn("--shards option is deprecated and will be removed in the next Maestro version. Use --shard-split or --shard-all instead.")
shardSplit = legacyShardCount
}
val executionPlan = try {
WorkspaceExecutionPlanner.plan(
flowFile.toPath().toAbsolutePath(),
Expand All @@ -177,7 +198,6 @@ class TestCommand : Callable<Int> {
}

private fun handleSessions(debugOutputPath: Path, plan: ExecutionPlan): Int = runBlocking(Dispatchers.IO) {
val sharded = shards > 1

runCatching {
val deviceIds = (if (isWebFlow())
Expand All @@ -193,10 +213,18 @@ class TestCommand : Callable<Int> {
initialActiveDevices.addAll(DeviceService.listConnectedDevices().map {
it.instanceId
}.toMutableSet())

val shards = shardSplit ?: shardAll ?: 1

val availableDevices = if (deviceIds.isNotEmpty()) deviceIds.size else initialActiveDevices.size
val effectiveShards = shards.coerceAtMost(plan.flowsToRun.size)
val chunkPlans = plan.flowsToRun
val sharded = effectiveShards > 1

val chunkPlans =
if (shardAll != null) (0 until effectiveShards).map { plan.copy() }
else plan.flowsToRun
.withIndex()
.groupBy { it.index % shards }
.groupBy { it.index % effectiveShards }
.map { (shardIndex, files) ->
ExecutionPlan(
files.map { it.value },
Expand All @@ -208,12 +236,12 @@ class TestCommand : Callable<Int> {
}

// Collect device configurations for missing shards, if any
val missing = effectiveShards - if (deviceIds.isNotEmpty()) deviceIds.size else initialActiveDevices.size
val allDeviceConfigs = (0 until missing).map { shardIndex ->
val missing = effectiveShards - availableDevices
val allDeviceConfigs = if (shardAll == null) (0 until missing).map { shardIndex ->
PrintUtils.message("------------------ Shard ${shardIndex + 1} ------------------")
// Collect device configurations here, one per shard
PickDeviceView.requestDeviceOptions()
}.toMutableList()
}.toMutableList() else mutableListOf()

val barrier = CountDownLatch(effectiveShards)

Expand Down

0 comments on commit 9ad6e39

Please sign in to comment.