Skip to content

Commit

Permalink
Add getMetric and batchGetMetric. (#959)
Browse files Browse the repository at this point in the history
  • Loading branch information
riemanli authored May 1, 2023
1 parent a99bf24 commit cee629b
Show file tree
Hide file tree
Showing 2 changed files with 1,184 additions and 519 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ import org.wfanet.measurement.internal.reporting.v2.metricSpec as internalMetric
import org.wfanet.measurement.reporting.service.api.EncryptionKeyPairStore
import org.wfanet.measurement.reporting.v2alpha.BatchCreateMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.BatchCreateMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.BatchGetMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.BatchGetMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.CreateMetricRequest
import org.wfanet.measurement.reporting.v2alpha.GetMetricRequest
import org.wfanet.measurement.reporting.v2alpha.ListMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.ListMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.Metric
Expand All @@ -137,6 +140,7 @@ import org.wfanet.measurement.reporting.v2alpha.MetricResultKt.watchDurationResu
import org.wfanet.measurement.reporting.v2alpha.MetricSpec
import org.wfanet.measurement.reporting.v2alpha.MetricsGrpcKt.MetricsCoroutineImplBase
import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.listMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.metric
import org.wfanet.measurement.reporting.v2alpha.metricResult
Expand All @@ -148,19 +152,19 @@ private const val MAX_PAGE_SIZE = 1000
private const val NANOS_PER_SECOND = 1_000_000_000

class MetricsService(
private val metricSpecConfig: MetricSpecConfig,
private val internalReportingSetsStub: InternalReportingSetsCoroutineStub,
private val internalMeasurementsStub: InternalMeasurementsCoroutineStub,
private val internalMetricsStub: InternalMetricsCoroutineStub,
private val dataProvidersStub: DataProvidersCoroutineStub,
private val measurementsStub: MeasurementsCoroutineStub,
private val certificatesStub: CertificatesCoroutineStub,
private val measurementConsumersStub: MeasurementConsumersCoroutineStub,
private val encryptionKeyPairStore: EncryptionKeyPairStore,
private val secureRandom: SecureRandom,
private val signingPrivateKeyDir: File,
private val trustedCertificates: Map<ByteString, X509Certificate>,
private val metricSpecConfig: MetricSpecConfig,
private val coroutineContext: @BlockingExecutor CoroutineContext = Dispatchers.IO,
internalMeasurementsStub: InternalMeasurementsCoroutineStub,
dataProvidersStub: DataProvidersCoroutineStub,
measurementsStub: MeasurementsCoroutineStub,
certificatesStub: CertificatesCoroutineStub,
measurementConsumersStub: MeasurementConsumersCoroutineStub,
encryptionKeyPairStore: EncryptionKeyPairStore,
secureRandom: SecureRandom,
signingPrivateKeyDir: File,
trustedCertificates: Map<ByteString, X509Certificate>,
coroutineContext: @BlockingExecutor CoroutineContext = Dispatchers.IO,
) : MetricsCoroutineImplBase() {

private val measurementSupplier =
Expand Down Expand Up @@ -294,7 +298,21 @@ class MetricsService(
.withAuthenticationKey(principal.config.apiKey)
.createMeasurement(createMeasurementRequest)
} catch (e: StatusException) {
throw Exception("Unable to create a CMMS measurement.", e)
throw when (e.status.code) {
Status.Code.INVALID_ARGUMENT ->
Status.INVALID_ARGUMENT.withDescription("Required field unspecified or invalid.")
Status.Code.PERMISSION_DENIED ->
Status.PERMISSION_DENIED.withDescription(
"Cannot create a CMMS Measurement for another MeasurementConsumer."
)
Status.Code.FAILED_PRECONDITION ->
Status.FAILED_PRECONDITION.withDescription("Failed precondition.")
Status.Code.NOT_FOUND ->
Status.NOT_FOUND.withDescription("${measurementConsumer.name} is not found.")
else -> Status.UNKNOWN.withDescription("Unable to create a CMMS measurement.")
}
.withCause(e)
.asRuntimeException()
}
}

Expand Down Expand Up @@ -405,8 +423,8 @@ class MetricsService(
} catch (e: StatusException) {
throw when (e.status.code) {
Status.Code.NOT_FOUND ->
Status.FAILED_PRECONDITION.withDescription("$dataProviderName not found")
else -> Status.UNKNOWN.withDescription("Unable to retrieve $dataProviderName")
Status.FAILED_PRECONDITION.withDescription("$dataProviderName not found.")
else -> Status.UNKNOWN.withDescription("Unable to retrieve $dataProviderName.")
}
.withCause(e)
.asRuntimeException()
Expand All @@ -418,7 +436,16 @@ class MetricsService(
.withAuthenticationKey(apiAuthenticationKey)
.getCertificate(getCertificateRequest { name = dataProvider.certificate })
} catch (e: StatusException) {
throw Exception("Unable to retrieve Certificate ${dataProvider.certificate}", e)
throw when (e.status.code) {
Status.Code.NOT_FOUND ->
Status.NOT_FOUND.withDescription("${dataProvider.certificate} not found.")
else ->
Status.UNKNOWN.withDescription(
"Unable to retrieve Certificate ${dataProvider.certificate}."
)
}
.withCause(e)
.asRuntimeException()
}
if (
certificate.revocationState != Certificate.RevocationState.REVOCATION_STATE_UNSPECIFIED
Expand Down Expand Up @@ -535,10 +562,16 @@ class MetricsService(
getMeasurementConsumerRequest { name = principal.resourceKey.toName() }
)
} catch (e: StatusException) {
throw Exception(
"Unable to retrieve the measurement consumer " + "[${principal.resourceKey.toName()}].",
e
)
throw when (e.status.code) {
Status.Code.NOT_FOUND ->
Status.NOT_FOUND.withDescription("${principal.resourceKey.toName()} not found.")
else ->
Status.UNKNOWN.withDescription(
"Unable to retrieve the measurement consumer [${principal.resourceKey.toName()}]."
)
}
.withCause(e)
.asRuntimeException()
}
}

Expand Down Expand Up @@ -580,11 +613,19 @@ class MetricsService(
.getCertificate(getCertificateRequest { name = principal.config.signingCertificateName })
.x509Der
} catch (e: StatusException) {
throw Exception(
"Unable to retrieve the signing certificate for the measurement consumer " +
"[$principal.config.signingCertificateName].",
e
)
throw when (e.status.code) {
Status.Code.NOT_FOUND ->
Status.NOT_FOUND.withDescription(
"${principal.config.signingCertificateName} not found."
)
else ->
Status.UNKNOWN.withDescription(
"Unable to retrieve the signing certificate " +
"[${principal.config.signingCertificateName}] for the measurement consumer."
)
}
.withCause(e)
.asRuntimeException()
}
}

Expand Down Expand Up @@ -720,7 +761,20 @@ class MetricsService(
.withAuthenticationKey(apiAuthenticationKey)
.getMeasurement(getMeasurementRequest { name = measurementResourceName })
} catch (e: StatusException) {
throw Exception("Unable to retrieve the measurement [$measurementResourceName].", e)
throw when (e.status.code) {
Status.Code.NOT_FOUND ->
Status.NOT_FOUND.withDescription("$measurementResourceName not found.")
Status.Code.PERMISSION_DENIED ->
Status.PERMISSION_DENIED.withDescription(
"Doesn't have permission to get $measurementResourceName."
)
else ->
Status.UNKNOWN.withDescription(
"Unable to retrieve the measurement [$measurementResourceName]."
)
}
.withCause(e)
.asRuntimeException()
}
}
}
Expand Down Expand Up @@ -766,10 +820,17 @@ class MetricsService(
.withAuthenticationKey(apiAuthenticationKey)
.getCertificate(getCertificateRequest { name = measurementResultPair.certificate })
} catch (e: StatusException) {
throw Exception(
"Unable to retrieve the certificate [${measurementResultPair.certificate}].",
e
)
throw when (e.status.code) {
Status.Code.NOT_FOUND ->
Status.NOT_FOUND.withDescription("${measurementResultPair.certificate} not found.")
else ->
Status.UNKNOWN.withDescription(
"Unable to retrieve the certificate " +
"[${measurementResultPair.certificate}] for the measurement consumer."
)
}
.withCause(e)
.asRuntimeException()
}

val signedResult =
Expand Down Expand Up @@ -853,6 +914,119 @@ class MetricsService(
}
}

override suspend fun getMetric(request: GetMetricRequest): Metric {
val metricKey =
grpcRequireNotNull(MetricKey.fromName(request.name)) {
"Metric name is either unspecified or invalid."
}

val principal: ReportingPrincipal = principalFromCurrentContext
when (principal) {
is MeasurementConsumerPrincipal -> {
if (metricKey.cmmsMeasurementConsumerId != principal.resourceKey.measurementConsumerId) {
failGrpc(Status.PERMISSION_DENIED) {
"Cannot get a Metric for another MeasurementConsumer."
}
}
}
}

val internalMetric: InternalMetric =
getInternalMetric(metricKey.cmmsMeasurementConsumerId, apiIdToExternalId(metricKey.metricId))

// Early exit when the metric is at a terminal state.
if (internalMetric.state != Metric.State.RUNNING) {
return internalMetric.toMetric()
}

// Only syncs pending measurements which can only be in metrics that are still running.
val toBeSyncedInternalMeasurements: List<InternalMeasurement> =
internalMetric.weightedMeasurementsList
.map { weightedMeasurement -> weightedMeasurement.measurement }
.filter { internalMeasurement ->
internalMeasurement.state == InternalMeasurement.State.PENDING
}

val anyMeasurementUpdated: Boolean =
measurementSupplier.syncInternalMeasurements(
toBeSyncedInternalMeasurements,
principal.config.apiKey,
principal,
)

return if (anyMeasurementUpdated) {
getInternalMetric(metricKey.cmmsMeasurementConsumerId, apiIdToExternalId(metricKey.metricId))
.toMetric()
} else {
internalMetric.toMetric()
}
}

override suspend fun batchGetMetrics(request: BatchGetMetricsRequest): BatchGetMetricsResponse {
grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) {
"Parent is either unspecified or invalid."
}

val principal: ReportingPrincipal = principalFromCurrentContext

when (principal) {
is MeasurementConsumerPrincipal -> {
if (request.parent != principal.resourceKey.toName()) {
failGrpc(Status.PERMISSION_DENIED) {
"Cannot get Metrics for another MeasurementConsumer."
}
}
}
}

grpcRequire(request.namesList.isNotEmpty()) { "No metric name is provided." }
grpcRequire(request.namesList.size <= MAX_BATCH_SIZE) {
"At most $MAX_BATCH_SIZE metrics can be supported in a batch."
}

val externalMetricIds: List<Long> =
request.namesList.map { metricName ->
val metricKey =
grpcRequireNotNull(MetricKey.fromName(metricName)) {
"Metric name is either unspecified or invalid."
}
apiIdToExternalId(metricKey.metricId)
}

val internalMetrics: List<InternalMetric> =
batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, externalMetricIds)

// Only syncs pending measurements which can only be in metrics that are still running.
val toBeSyncedInternalMeasurements: List<InternalMeasurement> =
internalMetrics
.filter { internalMetric -> internalMetric.state == Metric.State.RUNNING }
.flatMap { internalMetric -> internalMetric.weightedMeasurementsList }
.map { weightedMeasurement -> weightedMeasurement.measurement }
.filter { internalMeasurement ->
internalMeasurement.state == InternalMeasurement.State.PENDING
}

val anyMeasurementUpdated: Boolean =
measurementSupplier.syncInternalMeasurements(
toBeSyncedInternalMeasurements,
principal.config.apiKey,
principal,
)

return batchGetMetricsResponse {
metrics +=
/**
* TODO(@riemanli): a potential improvement can be done by only getting the metrics whose
* measurements are updated. Re-evaluate when a load-test is ready after deployment.
*/
if (anyMeasurementUpdated) {
batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, externalMetricIds)
.map { it.toMetric() }
} else {
internalMetrics.map { it.toMetric() }
}
}
}
override suspend fun listMetrics(request: ListMetricsRequest): ListMetricsResponse {
val listMetricsPageToken: ListMetricsPageToken = request.toListMetricsPageToken()

Expand Down Expand Up @@ -914,8 +1088,13 @@ class MetricsService(
principal,
)

// If any measurement got updated, pull the list of the up-to-date internal metrics. Otherwise,
// use the original list.
/**
* If any measurement got updated, pull the list of the up-to-date internal metrics. Otherwise,
* use the original list.
*
* TODO(@riemanli): a potential improvement can be done by only getting the metrics whose
* measurements are updated. Re-evaluate when a load-test is ready after deployment.
*/
val internalMetrics: List<InternalMetric> =
if (anyMeasurementUpdated) {
batchGetInternalMetrics(
Expand All @@ -935,7 +1114,7 @@ class MetricsService(
}
}

/** Gets a batch of [InternalMetric]. */
/** Gets a batch of [InternalMetric]s. */
private suspend fun batchGetInternalMetrics(
cmmsMeasurementConsumerId: String,
externalMetricIds: List<Long>,
Expand All @@ -952,6 +1131,23 @@ class MetricsService(
}
}

/** Gets an [InternalMetric]. */
private suspend fun getInternalMetric(
cmmsMeasurementConsumerId: String,
externalMetricId: Long,
): InternalMetric {
return try {
batchGetInternalMetrics(cmmsMeasurementConsumerId, listOf(externalMetricId)).first()
} catch (e: StatusException) {
val metricName =
MetricKey(cmmsMeasurementConsumerId, externalIdToApiId(externalMetricId)).toName()
throw Exception(
"Unable to get the metric with name = [${metricName}] from the reporting database.",
e
)
}
}

override suspend fun createMetric(request: CreateMetricRequest): Metric {
grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) {
"Parent is either unspecified or invalid."
Expand Down
Loading

0 comments on commit cee629b

Please sign in to comment.