Skip to content

Commit

Permalink
Specify request_id in CreateMeasurement and CreateEventGroup requests.
Browse files Browse the repository at this point in the history
  • Loading branch information
SanjayVas committed May 24, 2023
1 parent d4d03b2 commit 6c4025c
Show file tree
Hide file tree
Showing 7 changed files with 289 additions and 327 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,14 @@ class CreateMeasurement : Runnable {
)
private lateinit var measurementConsumer: String

@Option(
names = ["--request-id"],
description = ["ID of API request for idempotency"],
required = false,
defaultValue = "",
)
private lateinit var requestId: String

@Option(
names = ["--private-key-der-file"],
description = ["Private key for MeasurementConsumer"],
Expand All @@ -546,7 +554,7 @@ class CreateMeasurement : Runnable {
required = false,
defaultValue = ""
)
private lateinit var measurementIdempotencyKey: String
private lateinit var measurementReferenceId: String

@set:Option(
names = ["--vid-sampling-start"],
Expand Down Expand Up @@ -879,14 +887,19 @@ class CreateMeasurement : Runnable {

this.measurementSpec =
signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey)
measurementReferenceId = measurementIdempotencyKey
measurementReferenceId = this@CreateMeasurement.measurementReferenceId
}

val response =
runBlocking(parentCommand.parentCommand.rpcDispatcher) {
parentCommand.measurementStub
.withAuthenticationKey(parentCommand.apiAuthenticationKey)
.createMeasurement(createMeasurementRequest { this.measurement = measurement })
.createMeasurement(
createMeasurementRequest {
this.measurement = measurement
requestId = this@CreateMeasurement.requestId
}
)
}
println("Measurement Name: ${response.name}")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,13 @@ class EdpSimulator(
throw Exception("Error creating EventGroupMetadataDescriptor", e)
}

val eventGroupReferenceId =
"${TestIdentifiers.EVENT_GROUP_REFERENCE_ID_PREFIX}-${edpData.displayName}"
val request = createEventGroupRequest {
parent = edpData.name
eventGroup = eventGroup {
this.measurementConsumer = measurementConsumerName
eventGroupReferenceId =
"${TestIdentifiers.EVENT_GROUP_REFERENCE_ID_PREFIX}-${edpData.displayName}"
this.eventGroupReferenceId = eventGroupReferenceId
eventTemplates += eventTemplateNames.map { eventTemplate { type = it } }
measurementConsumerCertificate = measurementConsumer.certificate
measurementConsumerPublicKey = measurementConsumer.publicKey
Expand All @@ -227,6 +228,7 @@ class EdpSimulator(
EncryptionPublicKey.parseFrom(measurementConsumer.publicKey.data)
)
}
requestId = eventGroupReferenceId
}
val eventGroup =
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ class ReportsService(
signingConfig: SigningConfig,
internalReportingSetMap: Map<Long, InternalReportingSet>
) = coroutineScope {
val deferred = mutableListOf<Deferred<Measurement>>()
val deferredMeasurements = mutableListOf<Deferred<Measurement>>()
for (metric in request.report.metricsList) {
val internalMetricDetails = buildInternalMetricDetails(metric)

Expand All @@ -481,34 +481,35 @@ class ReportsService(
val setOperationResult: SetOperationResult =
namedSetOperationResults[setOperationId] ?: continue

setOperationResult.weightedMeasurementInfoList.forEach { weightedMeasurementInfo ->
deferred.add(
for (weightedMeasurementInfo in setOperationResult.weightedMeasurementInfoList) {
deferredMeasurements.add(
async {
createMeasurement(
weightedMeasurementInfo,
reportInfo,
setOperationResult.internalMetricDetails,
measurementConsumer,
apiAuthenticationKey,
signingConfig,
internalReportingSetMap
)
weightedMeasurementInfo,
reportInfo,
setOperationResult.internalMetricDetails,
measurementConsumer,
apiAuthenticationKey,
signingConfig,
internalReportingSetMap
)
.also {
weightedMeasurementInfo.kingdomMeasurementId =
checkNotNull(MeasurementKey.fromName(it.name)).measurementId
}
}
)
}
}
}

// map of kingdom measurementReferenceId to kingdom apiId
val measurementMap = mutableMapOf<String, String>()
deferred.awaitAll().forEach { measurement ->
for (measurement in deferredMeasurements.awaitAll()) {
try {
val apiId = MeasurementKey.fromName(measurement.name)!!.measurementId
measurementMap[measurement.measurementReferenceId] = apiId
val measurementKey = checkNotNull(MeasurementKey.fromName(measurement.name))
internalMeasurementsStub.createMeasurement(
internalMeasurement {
this.measurementConsumerReferenceId = reportInfo.measurementConsumerReferenceId
this.measurementReferenceId = apiId
this.measurementConsumerReferenceId = measurementKey.measurementConsumerId
this.measurementReferenceId = measurementKey.measurementId
state = InternalMeasurement.State.PENDING
}
)
Expand All @@ -522,27 +523,6 @@ class ReportsService(
}
}
}

for (metric in request.report.metricsList) {
val internalMetricDetails = buildInternalMetricDetails(metric)

for (namedSetOperation in metric.setOperationsList) {
val setOperationId =
buildSetOperationId(
reportInfo.reportIdempotencyKey,
internalMetricDetails,
namedSetOperation.uniqueName,
)

val setOperationResult: SetOperationResult =
namedSetOperationResults[setOperationId] ?: continue

setOperationResult.weightedMeasurementInfoList.forEach { weightedMeasurementInfo ->
weightedMeasurementInfo.kingdomMeasurementId =
measurementMap[weightedMeasurementInfo.reportingMeasurementId]
}
}
}
}

/** Creates a kingdom measurement for a [WeightedMeasurement]. */
Expand Down Expand Up @@ -579,7 +559,7 @@ class ReportsService(
.createMeasurement(createMeasurementRequest)
} catch (e: StatusException) {
throw Exception(
"Unable to create the measurement [${createMeasurementRequest.measurement.name}].",
"Unable to create Measurement with request ID ${createMeasurementRequest.requestId}",
e
)
}
Expand Down Expand Up @@ -1050,13 +1030,14 @@ class ReportsService(
.forEach { (reportTimeInterval, weightedMeasurementInfos) ->
measurementCalculations.add(
InternalMetricKt.measurementCalculation {
this.timeInterval = reportTimeInterval.toInternal()
timeInterval = reportTimeInterval.toInternal()

weightedMeasurementInfos.forEach {
for (weightedMeasurementInfo in weightedMeasurementInfos) {
weightedMeasurements +=
MeasurementCalculationKt.weightedMeasurement {
this.measurementReferenceId = it.kingdomMeasurementId!!
coefficient = it.weightedMeasurement.coefficient
measurementReferenceId =
checkNotNull(weightedMeasurementInfo.kingdomMeasurementId)
coefficient = weightedMeasurementInfo.weightedMeasurement.coefficient
}
}
}
Expand All @@ -1070,50 +1051,44 @@ class ReportsService(
measurementConsumer: MeasurementConsumer,
eventGroupEntriesByDataProvider: Map<DataProviderKey, List<EventGroupEntry>>,
internalMetricDetails: InternalMetricDetails,
measurementReferenceId: String,
requestId: String,
apiAuthenticationKey: String,
signingConfig: SigningConfig,
): CreateMeasurementRequest {
val measurementConsumerReferenceId =
grpcRequireNotNull(MeasurementConsumerKey.fromName(measurementConsumer.name)) {
"Invalid measurement consumer name [${measurementConsumer.name}]"
}
.measurementConsumerId
grpcRequireNotNull(MeasurementConsumerKey.fromName(measurementConsumer.name)) {
"Invalid measurement consumer name [${measurementConsumer.name}]"
}

val measurementConsumerCertificate = readCertificate(signingConfig.signingCertificateDer)
val measurementConsumerCertificate: X509Certificate =
readCertificate(signingConfig.signingCertificateDer)
val measurementConsumerSigningKey =
SigningKeyHandle(measurementConsumerCertificate, signingConfig.signingPrivateKey)
val measurementEncryptionPublicKey = measurementConsumer.publicKey.data

val measurementResourceName =
MeasurementKey(measurementConsumerReferenceId, measurementReferenceId).toName()

val measurement = measurement {
name = measurementResourceName
this.measurementConsumerCertificate = signingConfig.signingCertificateName
val measurementEncryptionPublicKey: ByteString = measurementConsumer.publicKey.data

dataProviders +=
buildDataProviderEntries(
eventGroupEntriesByDataProvider,
measurementEncryptionPublicKey,
measurementConsumerSigningKey,
apiAuthenticationKey,
)
return createMeasurementRequest {
measurement = measurement {
this.measurementConsumerCertificate = signingConfig.signingCertificateName

val unsignedMeasurementSpec: MeasurementSpec =
buildUnsignedMeasurementSpec(
measurementEncryptionPublicKey,
dataProviders.map { it.value.nonceHash },
internalMetricDetails,
)
dataProviders +=
buildDataProviderEntries(
eventGroupEntriesByDataProvider,
measurementEncryptionPublicKey,
measurementConsumerSigningKey,
apiAuthenticationKey,
)

this.measurementSpec =
signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey)
val unsignedMeasurementSpec: MeasurementSpec =
buildUnsignedMeasurementSpec(
measurementEncryptionPublicKey,
dataProviders.map { it.value.nonceHash },
internalMetricDetails,
)

this.measurementReferenceId = measurementReferenceId
measurementSpec =
signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey)
}
this.requestId = requestId
}

return createMeasurementRequest { this.measurement = measurement }
}

/**
Expand Down Expand Up @@ -1733,10 +1708,7 @@ private fun buildDurationMeasurementSpec(

/** Converts a public [SetOperation.Type] to an [InternalSetOperation.Type]. */
private fun SetOperation.Type.toInternal(): InternalSetOperation.Type {
val source = this

@Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null.
return when (source) {
return when (this) {
SetOperation.Type.UNION -> InternalSetOperation.Type.UNION
SetOperation.Type.INTERSECTION -> InternalSetOperation.Type.INTERSECTION
SetOperation.Type.DIFFERENCE -> InternalSetOperation.Type.DIFFERENCE
Expand Down Expand Up @@ -1984,7 +1956,6 @@ private fun InternalReport.toReport(): Report {

/** Converts an [InternalReport.State] to a public [Report.State]. */
private fun InternalReport.State.toState(): Report.State {
@Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null.
return when (this) {
InternalReport.State.RUNNING -> Report.State.RUNNING
InternalReport.State.SUCCEEDED -> Report.State.SUCCEEDED
Expand Down Expand Up @@ -2096,7 +2067,6 @@ private fun InternalOperand.toOperand(): Operand {

/** Converts an internal [InternalSetOperation.Type] to a public [SetOperation.Type]. */
private fun InternalSetOperation.Type.toType(): SetOperation.Type {
@Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null.
return when (this) {
InternalSetOperation.Type.UNION -> SetOperation.Type.UNION
InternalSetOperation.Type.INTERSECTION -> SetOperation.Type.INTERSECTION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,14 @@ class MetricsService(
metricSpec
)

this.measurementSpec =
measurementSpec =
signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey)

this.measurementReferenceId = internalMeasurement.cmmsCreateMeasurementRequestId
}

return createMeasurementRequest { this.measurement = measurement }
return createMeasurementRequest {
this.measurement = measurement
requestId = internalMeasurement.cmmsCreateMeasurementRequestId
}
}

/** Gets a [SigningKeyHandle] for a [MeasurementConsumerPrincipal]. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ class MeasurementSystemTest {

@Test
fun `measurements create calls CreateMeasurement with valid request`() {
val requestId = "foo"
val args =
commonArgs +
arrayOf(
Expand All @@ -574,6 +575,7 @@ class MeasurementSystemTest {
"--measurement-consumer=measurementConsumers/777",
"--private-key-der-file=$SECRETS_DIR/mc_cs_private.der",
"--measurement-ref-id=9999",
"--request-id=$requestId",
"--data-provider=dataProviders/1",
"--event-group=dataProviders/1/eventGroups/1",
"--event-filter=abcd",
Expand Down Expand Up @@ -606,6 +608,7 @@ class MeasurementSystemTest {
captureFirst<CreateMeasurementRequest> {
runBlocking { verify(measurementsServiceMock).createMeasurement(capture()) }
}
assertThat(request.requestId).isEqualTo(requestId)
val measurement = request.measurement
// measurementSpec matches
verifyMeasurementSpec(
Expand Down
Loading

0 comments on commit 6c4025c

Please sign in to comment.