Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add getMetric and batchGetMetrics. #959

Merged
merged 12 commits into from
May 1, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ import org.wfanet.measurement.internal.reporting.v2.metricSpec as internalMetric
import org.wfanet.measurement.reporting.service.api.EncryptionKeyPairStore
import org.wfanet.measurement.reporting.v2alpha.BatchCreateMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.BatchCreateMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.BatchGetMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.BatchGetMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.CreateMetricRequest
import org.wfanet.measurement.reporting.v2alpha.GetMetricRequest
import org.wfanet.measurement.reporting.v2alpha.ListMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.ListMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.Metric
Expand All @@ -137,6 +140,7 @@ import org.wfanet.measurement.reporting.v2alpha.MetricResultKt.watchDurationResu
import org.wfanet.measurement.reporting.v2alpha.MetricSpec
import org.wfanet.measurement.reporting.v2alpha.MetricsGrpcKt.MetricsCoroutineImplBase
import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.listMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.metric
import org.wfanet.measurement.reporting.v2alpha.metricResult
Expand All @@ -148,19 +152,19 @@ private const val MAX_PAGE_SIZE = 1000
private const val NANOS_PER_SECOND = 1_000_000_000

class MetricsService(
private val metricSpecConfig: MetricSpecConfig,
private val internalReportingSetsStub: InternalReportingSetsCoroutineStub,
private val internalMeasurementsStub: InternalMeasurementsCoroutineStub,
private val internalMetricsStub: InternalMetricsCoroutineStub,
private val dataProvidersStub: DataProvidersCoroutineStub,
private val measurementsStub: MeasurementsCoroutineStub,
private val certificatesStub: CertificatesCoroutineStub,
private val measurementConsumersStub: MeasurementConsumersCoroutineStub,
private val encryptionKeyPairStore: EncryptionKeyPairStore,
private val secureRandom: SecureRandom,
private val signingPrivateKeyDir: File,
private val trustedCertificates: Map<ByteString, X509Certificate>,
private val metricSpecConfig: MetricSpecConfig,
private val coroutineContext: @BlockingExecutor CoroutineContext = Dispatchers.IO,
internalMeasurementsStub: InternalMeasurementsCoroutineStub,
dataProvidersStub: DataProvidersCoroutineStub,
measurementsStub: MeasurementsCoroutineStub,
certificatesStub: CertificatesCoroutineStub,
measurementConsumersStub: MeasurementConsumersCoroutineStub,
encryptionKeyPairStore: EncryptionKeyPairStore,
secureRandom: SecureRandom,
signingPrivateKeyDir: File,
trustedCertificates: Map<ByteString, X509Certificate>,
coroutineContext: @BlockingExecutor CoroutineContext = Dispatchers.IO,
) : MetricsCoroutineImplBase() {

private val measurementSupplier =
Expand Down Expand Up @@ -294,7 +298,21 @@ class MetricsService(
.withAuthenticationKey(principal.config.apiKey)
.createMeasurement(createMeasurementRequest)
} catch (e: StatusException) {
throw Exception("Unable to create a CMMS measurement.", e)
throw when (e.status.code) {
Status.Code.INVALID_ARGUMENT ->
Status.INVALID_ARGUMENT.withDescription("Required field unspecified or invalid.")
Status.Code.PERMISSION_DENIED ->
Status.PERMISSION_DENIED.withDescription(
"Cannot create a CMMS Measurement for another MeasurementConsumer."
)
Status.Code.FAILED_PRECONDITION ->
Status.FAILED_PRECONDITION.withDescription("Failed precondition.")
Status.Code.NOT_FOUND ->
Status.NOT_FOUND.withDescription("${measurementConsumer.name} is not found.")
else -> Status.UNKNOWN.withDescription("Unable to create a CMMS measurement.")
}
.withCause(e)
.asRuntimeException()
}
}

Expand Down Expand Up @@ -405,8 +423,8 @@ class MetricsService(
} catch (e: StatusException) {
throw when (e.status.code) {
Status.Code.NOT_FOUND ->
Status.FAILED_PRECONDITION.withDescription("$dataProviderName not found")
else -> Status.UNKNOWN.withDescription("Unable to retrieve $dataProviderName")
Status.FAILED_PRECONDITION.withDescription("$dataProviderName not found.")
else -> Status.UNKNOWN.withDescription("Unable to retrieve $dataProviderName.")
}
.withCause(e)
.asRuntimeException()
Expand All @@ -418,7 +436,16 @@ class MetricsService(
.withAuthenticationKey(apiAuthenticationKey)
.getCertificate(getCertificateRequest { name = dataProvider.certificate })
} catch (e: StatusException) {
throw Exception("Unable to retrieve Certificate ${dataProvider.certificate}", e)
throw when (e.status.code) {
Status.Code.NOT_FOUND ->
Status.NOT_FOUND.withDescription("${dataProvider.certificate} not found.")
else ->
Status.UNKNOWN.withDescription(
"Unable to retrieve Certificate ${dataProvider.certificate}."
)
}
.withCause(e)
.asRuntimeException()
}
if (
certificate.revocationState != Certificate.RevocationState.REVOCATION_STATE_UNSPECIFIED
Expand Down Expand Up @@ -535,10 +562,16 @@ class MetricsService(
getMeasurementConsumerRequest { name = principal.resourceKey.toName() }
)
} catch (e: StatusException) {
throw Exception(
"Unable to retrieve the measurement consumer " + "[${principal.resourceKey.toName()}].",
e
)
throw when (e.status.code) {
Status.Code.NOT_FOUND ->
Status.NOT_FOUND.withDescription("${principal.resourceKey.toName()} not found.")
else ->
Status.UNKNOWN.withDescription(
"Unable to retrieve the measurement consumer [${principal.resourceKey.toName()}]."
)
}
.withCause(e)
.asRuntimeException()
}
}

Expand Down Expand Up @@ -580,11 +613,19 @@ class MetricsService(
.getCertificate(getCertificateRequest { name = principal.config.signingCertificateName })
.x509Der
} catch (e: StatusException) {
throw Exception(
"Unable to retrieve the signing certificate for the measurement consumer " +
"[$principal.config.signingCertificateName].",
e
)
throw when (e.status.code) {
Status.Code.NOT_FOUND ->
Status.NOT_FOUND.withDescription(
"${principal.config.signingCertificateName} not found."
)
else ->
Status.UNKNOWN.withDescription(
"Unable to retrieve the signing certificate " +
"[${principal.config.signingCertificateName}] for the measurement consumer."
)
}
.withCause(e)
.asRuntimeException()
}
}

Expand Down Expand Up @@ -720,7 +761,20 @@ class MetricsService(
.withAuthenticationKey(apiAuthenticationKey)
.getMeasurement(getMeasurementRequest { name = measurementResourceName })
} catch (e: StatusException) {
throw Exception("Unable to retrieve the measurement [$measurementResourceName].", e)
throw when (e.status.code) {
Status.Code.NOT_FOUND ->
Status.NOT_FOUND.withDescription("$measurementResourceName not found.")
Status.Code.PERMISSION_DENIED ->
Status.PERMISSION_DENIED.withDescription(
"Doesn't have permission to get $measurementResourceName."
)
else ->
Status.UNKNOWN.withDescription(
"Unable to retrieve the measurement [$measurementResourceName]."
)
}
.withCause(e)
.asRuntimeException()
}
}
}
Expand Down Expand Up @@ -766,10 +820,17 @@ class MetricsService(
.withAuthenticationKey(apiAuthenticationKey)
.getCertificate(getCertificateRequest { name = measurementResultPair.certificate })
} catch (e: StatusException) {
throw Exception(
"Unable to retrieve the certificate [${measurementResultPair.certificate}].",
e
)
throw when (e.status.code) {
Status.Code.NOT_FOUND ->
Status.NOT_FOUND.withDescription("${measurementResultPair.certificate} not found.")
else ->
Status.UNKNOWN.withDescription(
"Unable to retrieve the certificate " +
"[${measurementResultPair.certificate}] for the measurement consumer."
)
}
.withCause(e)
.asRuntimeException()
}

val signedResult =
Expand Down Expand Up @@ -853,6 +914,119 @@ class MetricsService(
}
}

/**
 * Returns the [Metric] identified by the resource name in [request].
 *
 * The request is rejected with `PERMISSION_DENIED` when the metric belongs to a
 * MeasurementConsumer other than the calling principal. If the stored metric is still running, its
 * pending measurements are synced first, and the metric is read again when the sync reports that
 * any measurement was updated.
 */
override suspend fun getMetric(request: GetMetricRequest): Metric {
  val key =
    grpcRequireNotNull(MetricKey.fromName(request.name)) {
      "Metric name is either unspecified or invalid."
    }

  val principal: ReportingPrincipal = principalFromCurrentContext
  when (principal) {
    is MeasurementConsumerPrincipal -> {
      val ownsMetric =
        key.cmmsMeasurementConsumerId == principal.resourceKey.measurementConsumerId
      if (!ownsMetric) {
        failGrpc(Status.PERMISSION_DENIED) {
          "Cannot get a Metric for another MeasurementConsumer."
        }
      }
    }
  }

  val externalMetricId = apiIdToExternalId(key.metricId)
  val storedMetric: InternalMetric =
    getInternalMetric(key.cmmsMeasurementConsumerId, externalMetricId)

  // A metric that is no longer running is returned as stored, without any sync.
  if (storedMetric.state != Metric.State.RUNNING) {
    return storedMetric.toMetric()
  }

  // Only pending measurements need syncing, and they only occur in metrics that are still running.
  val pendingMeasurements: List<InternalMeasurement> =
    storedMetric.weightedMeasurementsList
      .filter { weighted -> weighted.measurement.state == InternalMeasurement.State.PENDING }
      .map { weighted -> weighted.measurement }

  val wasUpdated: Boolean =
    measurementSupplier.syncInternalMeasurements(
      pendingMeasurements,
      principal.config.apiKey,
      principal,
    )

  // Read the metric again only when the sync changed something; otherwise reuse the stored copy.
  val latestMetric: InternalMetric =
    if (wasUpdated) {
      getInternalMetric(key.cmmsMeasurementConsumerId, externalMetricId)
    } else {
      storedMetric
    }
  return latestMetric.toMetric()
}

/**
 * Returns the [Metric]s named in [request], all of which must live under `request.parent`.
 *
 * Validation performed here, in order:
 * - `parent` must parse as a MeasurementConsumer resource name.
 * - The calling principal must be that MeasurementConsumer, otherwise `PERMISSION_DENIED`.
 * - `names` must be non-empty and contain at most [MAX_BATCH_SIZE] entries.
 * - Every entry in `names` must parse as a Metric resource name and belong to the same
 *   MeasurementConsumer as `parent`.
 *
 * Pending measurements of metrics that are still running are synced before responding. If the
 * sync reports any update, the whole batch is read again so the response reflects it.
 */
override suspend fun batchGetMetrics(request: BatchGetMetricsRequest): BatchGetMetricsResponse {
  grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) {
    "Parent is either unspecified or invalid."
  }

  val principal: ReportingPrincipal = principalFromCurrentContext

  when (principal) {
    is MeasurementConsumerPrincipal -> {
      if (request.parent != principal.resourceKey.toName()) {
        failGrpc(Status.PERMISSION_DENIED) {
          "Cannot get Metrics for another MeasurementConsumer."
        }
      }
    }
  }

  grpcRequire(request.namesList.isNotEmpty()) { "No metric name is provided." }
  grpcRequire(request.namesList.size <= MAX_BATCH_SIZE) {
    "At most $MAX_BATCH_SIZE metrics can be supported in a batch."
  }

  // The check above established that the principal's resource name equals `request.parent`, so
  // this ID identifies the parent MeasurementConsumer.
  val cmmsMeasurementConsumerId: String = principal.resourceKey.measurementConsumerId

  val externalMetricIds: List<Long> =
    request.namesList.map { metricName ->
      val metricKey =
        grpcRequireNotNull(MetricKey.fromName(metricName)) {
          "Metric name is either unspecified or invalid."
        }
      // The lookup below is keyed by the parent's MeasurementConsumer only. Without this check a
      // name under a different parent would be resolved against the wrong MeasurementConsumer
      // instead of being rejected.
      grpcRequire(metricKey.cmmsMeasurementConsumerId == cmmsMeasurementConsumerId) {
        "Metric name $metricName does not belong to parent ${request.parent}."
      }
      apiIdToExternalId(metricKey.metricId)
    }

  val internalMetrics: List<InternalMetric> =
    batchGetInternalMetrics(cmmsMeasurementConsumerId, externalMetricIds)

  // Only syncs pending measurements which can only be in metrics that are still running.
  val toBeSyncedInternalMeasurements: List<InternalMeasurement> =
    internalMetrics
      .filter { internalMetric -> internalMetric.state == Metric.State.RUNNING }
      .flatMap { internalMetric -> internalMetric.weightedMeasurementsList }
      .map { weightedMeasurement -> weightedMeasurement.measurement }
      .filter { internalMeasurement ->
        internalMeasurement.state == InternalMeasurement.State.PENDING
      }

  val anyMeasurementUpdated: Boolean =
    measurementSupplier.syncInternalMeasurements(
      toBeSyncedInternalMeasurements,
      principal.config.apiKey,
      principal,
    )

  // TODO(@riemanli): a potential improvement can be done by only getting the metrics whose
  // measurements are updated. Re-evaluate when a load-test is ready after deployment.
  val upToDateInternalMetrics: List<InternalMetric> =
    if (anyMeasurementUpdated) {
      batchGetInternalMetrics(cmmsMeasurementConsumerId, externalMetricIds)
    } else {
      internalMetrics
    }

  return batchGetMetricsResponse { metrics += upToDateInternalMetrics.map { it.toMetric() } }
}
override suspend fun listMetrics(request: ListMetricsRequest): ListMetricsResponse {
val listMetricsPageToken: ListMetricsPageToken = request.toListMetricsPageToken()

Expand Down Expand Up @@ -914,8 +1088,13 @@ class MetricsService(
principal,
)

// If any measurement got updated, pull the list of the up-to-date internal metrics. Otherwise,
// use the original list.
/**
* If any measurement got updated, pull the list of the up-to-date internal metrics. Otherwise,
* use the original list.
*
* TODO(@riemanli): a potential improvement can be done by only getting the metrics whose
* measurements are updated. Re-evaluate when a load-test is ready after deployment.
*/
val internalMetrics: List<InternalMetric> =
if (anyMeasurementUpdated) {
batchGetInternalMetrics(
Expand All @@ -935,7 +1114,7 @@ class MetricsService(
}
}

/** Gets a batch of [InternalMetric]. */
/** Gets a batch of [InternalMetric]s. */
private suspend fun batchGetInternalMetrics(
cmmsMeasurementConsumerId: String,
externalMetricIds: List<Long>,
Expand All @@ -952,6 +1131,23 @@ class MetricsService(
}
}

/**
 * Gets a single [InternalMetric] by delegating to [batchGetInternalMetrics] with a one-element
 * list.
 *
 * Failures are reported as gRPC status exceptions, like the other backend calls in this class:
 * - an empty result, or a `NOT_FOUND` from the backend, becomes `NOT_FOUND`;
 * - any other [StatusException] becomes `UNKNOWN`, with the original exception as the cause.
 *
 * @param cmmsMeasurementConsumerId ID of the MeasurementConsumer that owns the metric
 * @param externalMetricId external ID of the metric to read
 */
private suspend fun getInternalMetric(
  cmmsMeasurementConsumerId: String,
  externalMetricId: Long,
): InternalMetric {
  // Built on demand: the resource name is only needed for error descriptions.
  fun metricName(): String =
    MetricKey(cmmsMeasurementConsumerId, externalIdToApiId(externalMetricId)).toName()

  return try {
    // An empty batch is reported as NOT_FOUND rather than letting `first()` throw
    // NoSuchElementException.
    batchGetInternalMetrics(cmmsMeasurementConsumerId, listOf(externalMetricId)).firstOrNull()
      ?: throw Status.NOT_FOUND.withDescription("${metricName()} not found.").asRuntimeException()
  } catch (e: StatusException) {
    throw when (e.status.code) {
        Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription("${metricName()} not found.")
        else ->
          Status.UNKNOWN.withDescription(
            "Unable to get the metric with name = [${metricName()}] from the reporting database."
          )
      }
      .withCause(e)
      .asRuntimeException()
  }
}

override suspend fun createMetric(request: CreateMetricRequest): Metric {
grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) {
"Parent is either unspecified or invalid."
Expand Down
Loading