Skip to content

Commit

Permalink
Update cross-media-measurement-api.
Browse files Browse the repository at this point in the history
This is a breaking change to the public API.
  • Loading branch information
SanjayVas committed May 25, 2023
1 parent 8324a9d commit acc05a6
Show file tree
Hide file tree
Showing 11 changed files with 222 additions and 158 deletions.
5 changes: 3 additions & 2 deletions build/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ def wfa_measurement_system_repositories():

wfa_repo_archive(
name = "wfa_measurement_proto",
# DO_NOT_SUBMIT(world-federation-of-advertisers/cross-media-measurement-api#135): Use version.
commit = "3752010e2bee1ae4c003f20342f1d19993404aee",
repo = "cross-media-measurement-api",
sha256 = "644fa51594fa183b65dbc0a5a064ddfaa16c807b573ec66e171a9de63b0f2b03",
version = "0.31.0",
sha256 = "d358faffcf39201b6053f4fdb848f376615bd6375ca62486d5e2e5d6afd88ea4",
)

wfa_repo_archive(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@
package org.wfanet.measurement.api.v2alpha.tools

import com.google.protobuf.ByteString
import com.google.protobuf.Timestamp
import com.google.protobuf.timestamp
import io.grpc.ManagedChannel
import java.io.File
import java.security.SecureRandom
Expand Down Expand Up @@ -447,13 +445,6 @@ private fun getDataProviderEntry(
}
}

private fun convertToTimestamp(instant: Instant): Timestamp {
return timestamp {
seconds = instant.epochSecond
nanos = instant.nano
}
}

private fun getReachAndFrequency(measurementTypeParams: MeasurementTypeParams): ReachAndFrequency {
return reachAndFrequency {
reachPrivacyParams = differentialPrivacyParams {
Expand Down Expand Up @@ -493,8 +484,7 @@ private fun getMeasurementResult(
privateKeyHandle: PrivateKeyHandle
): Measurement.Result {
val signedResult = decryptResult(resultPair.encryptedResult, privateKeyHandle)
val result = Measurement.Result.parseFrom(signedResult.data)
return result
return Measurement.Result.parseFrom(signedResult.data)
}

class Benchmark(
Expand Down Expand Up @@ -540,10 +530,10 @@ class Benchmark(
lateinit var result: Measurement.Result
}
/** List of tasks that have been submitted to the Kingdom. */
val taskList: MutableList<MeasurementTask> = Collections.synchronizedList(mutableListOf())
private val taskList: MutableList<MeasurementTask> = Collections.synchronizedList(mutableListOf())

/** List of tasks for which responses have been received or which have timed out. */
val completedTasks: MutableList<MeasurementTask> = mutableListOf()
private val completedTasks: MutableList<MeasurementTask> = mutableListOf()

/** Creates list of requests and sends them to the Kingdom. */
private fun generateRequests(
Expand Down Expand Up @@ -620,7 +610,12 @@ class Benchmark(
runBlocking(Dispatchers.IO) {
measurementStub
.withAuthenticationKey(apiAuthenticationKey)
.createMeasurement(createMeasurementRequest { this.measurement = measurement })
.createMeasurement(
createMeasurementRequest {
parent = measurementConsumer.name
this.measurement = measurement
}
)
}
println("Measurement Name: ${response.name}")

Expand All @@ -637,7 +632,7 @@ class Benchmark(
) {
var iTask = 0
while (iTask < taskList.size) {
val task = taskList.get(iTask)
val task = taskList[iTask]

print("${(Instant.now(clock).toEpochMilli() - firstInstant.toEpochMilli()) / 1000.0} ")
print("Trying to retrieve ${task.referenceId} ${task.measurementName}...")
Expand All @@ -660,16 +655,20 @@ class Benchmark(
(measurement.state == Measurement.State.FAILED) ||
timeoutOccurred
) {
if (measurement.state == Measurement.State.SUCCEEDED) {
val result = getMeasurementResult(measurement.resultsList.get(0), flags.privateKeyHandle)
task.result = result
// println ("Got result for task $iTask\n$measurement\n-----\n$result")
task.status = "success"
} else if (measurement.state == Measurement.State.FAILED) {
task.status = "failed"
task.errorMessage = measurement.failure.message
} else {
task.status = "timeout"
when (measurement.state) {
Measurement.State.SUCCEEDED -> {
val result = getMeasurementResult(measurement.resultsList[0], flags.privateKeyHandle)
task.result = result
// println ("Got result for task $iTask\n$measurement\n-----\n$result")
task.status = "success"
}
Measurement.State.FAILED -> {
task.status = "failed"
task.errorMessage = measurement.failure.message
}
else -> {
task.status = "timeout"
}
}

task.responseTime = Instant.now(clock)
Expand Down Expand Up @@ -712,7 +711,7 @@ class Benchmark(
reach = task.result.reach.value
}
out.print(reach)
var frequencies = arrayOf(0.0, 0.0, 0.0, 0.0, 0.0)
val frequencies = arrayOf(0.0, 0.0, 0.0, 0.0, 0.0)
if (task.status == "success" && task.result.hasFrequency()) {
task.result.frequency.relativeFrequencyDistributionMap.forEach {
if (it.key <= frequencies.size) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,7 @@ class CreateMeasurement : Runnable {
.withAuthenticationKey(parentCommand.apiAuthenticationKey)
.createMeasurement(
createMeasurementRequest {
parent = measurementConsumer.name
this.measurement = measurement
requestId = this@CreateMeasurement.requestId
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,22 +105,29 @@ class MeasurementsService(
override suspend fun createMeasurement(request: CreateMeasurementRequest): Measurement {
val authenticatedMeasurementConsumerKey = getAuthenticatedMeasurementConsumerKey()

val parentKey =
grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) {
"parent is either unspecified or invalid"
}

if (parentKey != authenticatedMeasurementConsumerKey) {
failGrpc(Status.PERMISSION_DENIED) {
"Cannot create a Measurement for another MeasurementConsumer"
}
}

val measurementConsumerCertificateKey =
grpcRequireNotNull(
MeasurementConsumerCertificateKey.fromName(
request.measurement.measurementConsumerCertificate
)
) {
"Measurement Consumer Certificate resource name is either unspecified or invalid"
"measurement_consumer_certificate is either unspecified or invalid"
}

if (
authenticatedMeasurementConsumerKey.measurementConsumerId !=
measurementConsumerCertificateKey.measurementConsumerId
grpcRequire(
measurementConsumerCertificateKey.measurementConsumerId == parentKey.measurementConsumerId
) {
failGrpc(Status.PERMISSION_DENIED) {
"Cannot create a Measurement for another MeasurementConsumer"
}
"measurement_consumer_certificate does not belong to ${request.parent}"
}

val measurementSpec = request.measurement.measurementSpec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ class FrontendSimulator(
}

val request = createMeasurementRequest {
parent = measurementConsumer.name
measurement = measurement {
measurementConsumerCertificate = measurementConsumer.certificate
measurementSpec =
Expand Down Expand Up @@ -474,7 +475,7 @@ class FrontendSimulator(
}

/** Gets the expected result of a [Measurement] using raw sketches. */
suspend fun getExpectedResult(
private suspend fun getExpectedResult(
measurementName: String,
protocolConfig: ProtocolConfig.LiquidLegionsV2
): Result {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1055,17 +1055,14 @@ class ReportsService(
apiAuthenticationKey: String,
signingConfig: SigningConfig,
): CreateMeasurementRequest {
grpcRequireNotNull(MeasurementConsumerKey.fromName(measurementConsumer.name)) {
"Invalid measurement consumer name [${measurementConsumer.name}]"
}

val measurementConsumerCertificate: X509Certificate =
readCertificate(signingConfig.signingCertificateDer)
val measurementConsumerSigningKey =
SigningKeyHandle(measurementConsumerCertificate, signingConfig.signingPrivateKey)
val measurementEncryptionPublicKey: ByteString = measurementConsumer.publicKey.data

return createMeasurementRequest {
parent = measurementConsumer.name
measurement = measurement {
this.measurementConsumerCertificate = signingConfig.signingCertificateName

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,30 +327,29 @@ class MetricsService(
val measurementConsumerSigningKey = getMeasurementConsumerSigningKey(principal)
val measurementEncryptionPublicKey = measurementConsumer.publicKey.data

val measurement = measurement {
this.measurementConsumerCertificate = principal.config.signingCertificateName

dataProviders +=
buildDataProviderEntries(
eventGroupEntriesByDataProvider,
measurementEncryptionPublicKey,
measurementConsumerSigningKey,
principal.config.apiKey,
)

val unsignedMeasurementSpec: MeasurementSpec =
buildUnsignedMeasurementSpec(
measurementEncryptionPublicKey,
dataProviders.map { it.value.nonceHash },
metricSpec
)
return createMeasurementRequest {
parent = measurementConsumer.name
measurement = measurement {
measurementConsumerCertificate = principal.config.signingCertificateName

dataProviders +=
buildDataProviderEntries(
eventGroupEntriesByDataProvider,
measurementEncryptionPublicKey,
measurementConsumerSigningKey,
principal.config.apiKey,
)

measurementSpec =
signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey)
}
val unsignedMeasurementSpec: MeasurementSpec =
buildUnsignedMeasurementSpec(
measurementEncryptionPublicKey,
dataProviders.map { it.value.nonceHash },
metricSpec
)

return createMeasurementRequest {
this.measurement = measurement
measurementSpec =
signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey)
}
requestId = internalMeasurement.cmmsCreateMeasurementRequestId
}
}
Expand Down Expand Up @@ -649,7 +648,6 @@ class MetricsService(
var anyUpdate = false

for ((newState, measurementsList) in newStateToCmmsMeasurements) {
@Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null.
when (newState) {
Measurement.State.SUCCEEDED -> {
syncSucceededInternalMeasurements(measurementsList, apiAuthenticationKey, principal)
Expand Down
Loading

0 comments on commit acc05a6

Please sign in to comment.