Skip to content

Commit

Permalink
Add cancel subcommand to measurements in MeasurementSystem CLI. (#1113)
Browse files Browse the repository at this point in the history
  • Loading branch information
SanjayVas authored and ple13 committed Aug 16, 2024
1 parent 0588d4a commit e0f7515
Showing 1 changed file with 37 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.eventGroupEntry
import org.wfanet.measurement.api.v2alpha.activateAccountRequest
import org.wfanet.measurement.api.v2alpha.apiKey
import org.wfanet.measurement.api.v2alpha.authenticateRequest
import org.wfanet.measurement.api.v2alpha.cancelMeasurementRequest
import org.wfanet.measurement.api.v2alpha.certificate
import org.wfanet.measurement.api.v2alpha.createApiKeyRequest
import org.wfanet.measurement.api.v2alpha.createCertificateRequest
Expand Down Expand Up @@ -541,6 +542,7 @@ private class MeasurementConsumers {
CreateMeasurement::class,
ListMeasurements::class,
GetMeasurement::class,
CancelMeasurement::class,
]
)
private class Measurements {
Expand Down Expand Up @@ -568,6 +570,18 @@ private class Measurements {
val certificateStub: CertificatesCoroutineStub by lazy {
CertificatesCoroutineStub(parentCommand.kingdomChannel)
}

companion object {
fun printState(measurement: Measurement) {
if (measurement.state == Measurement.State.FAILED) {
println(
"State: FAILED - " + measurement.failure.reason + ": " + measurement.failure.message
)
} else {
println("State: ${measurement.state}")
}
}
}
}

@Command(name = "create", description = ["Creates a Single Measurement"])
Expand Down Expand Up @@ -1002,14 +1016,6 @@ class GetMeasurement : Runnable {

private val privateKeyHandle: PrivateKeyHandle by lazy { loadPrivateKey(privateKeyDerFile) }

private fun printMeasurementState(measurement: Measurement) {
if (measurement.state == Measurement.State.FAILED) {
println("State: FAILED - " + measurement.failure.reason + ": " + measurement.failure.message)
} else {
println("State: ${measurement.state}")
}
}

private fun getMeasurementResult(
resultPair: Measurement.ResultPair,
): Measurement.Result {
Expand Down Expand Up @@ -1065,7 +1071,7 @@ class GetMeasurement : Runnable {
.getMeasurement(getMeasurementRequest { name = measurementName })
}

printMeasurementState(measurement)
Measurements.printState(measurement)
if (measurement.state == Measurement.State.SUCCEEDED) {
measurement.resultsList.forEach {
val result = getMeasurementResult(it)
Expand All @@ -1075,6 +1081,28 @@ class GetMeasurement : Runnable {
}
}

@Command(name = "cancel", description = ["Cancels a Measurement"])
class CancelMeasurement : Runnable {
@ParentCommand private lateinit var parentCommand: Measurements

@Parameters(
index = "0",
description = ["API resource name of the Measurement"],
)
private lateinit var measurementName: String

override fun run() {
val measurement =
runBlocking(parentCommand.parentCommand.rpcDispatcher) {
parentCommand.measurementStub
.withAuthenticationKey(parentCommand.apiAuthenticationKey)
.cancelMeasurement(cancelMeasurementRequest { name = measurementName })
}

Measurements.printState(measurement)
}
}

@Command(
name = "data-providers",
subcommands = [CommandLine.HelpCommand::class],
Expand Down

0 comments on commit e0f7515

Please sign in to comment.