diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt index c8fe3ebe00f..87bd0a21667 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt @@ -38,17 +38,21 @@ import org.wfanet.measurement.internal.reporting.v2.createReportRequest as inter import org.wfanet.measurement.internal.reporting.v2.getReportRequest as internalGetReportRequest import org.wfanet.measurement.internal.reporting.v2.report as internalReport import org.wfanet.measurement.reporting.v2alpha.BatchCreateMetricsRequest +import org.wfanet.measurement.reporting.v2alpha.BatchGetMetricsRequest import org.wfanet.measurement.reporting.v2alpha.CreateMetricRequest import org.wfanet.measurement.reporting.v2alpha.CreateReportRequest +import org.wfanet.measurement.reporting.v2alpha.GetReportRequest import org.wfanet.measurement.reporting.v2alpha.Metric import org.wfanet.measurement.reporting.v2alpha.MetricsGrpcKt.MetricsCoroutineStub import org.wfanet.measurement.reporting.v2alpha.Report import org.wfanet.measurement.reporting.v2alpha.ReportKt import org.wfanet.measurement.reporting.v2alpha.ReportsGrpcKt.ReportsCoroutineImplBase import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsRequest +import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsRequest import org.wfanet.measurement.reporting.v2alpha.report private const val MAX_BATCH_SIZE_FOR_BATCH_CREATE_METRICS = 1000 +private const val MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS = 100 private typealias InternalReportingMetricEntries = Map @@ -71,6 +75,90 @@ class ReportsService( val internalTimeRange: InternalTimeRange, ) + override suspend fun getReport(request: GetReportRequest): Report { + val reportKey = + grpcRequireNotNull(ReportKey.fromName(request.name)) { + "Report name is either unspecified or invalid" + } + + val principal: ReportingPrincipal = principalFromCurrentContext + when (principal) { + is MeasurementConsumerPrincipal -> { + if (reportKey.cmmsMeasurementConsumerId != principal.resourceKey.measurementConsumerId) { + failGrpc(Status.PERMISSION_DENIED) { + "Cannot get Report belonging to other MeasurementConsumers." + } + } + } + } + + val internalReport = + try { + internalReportsStub.getReport( + internalGetReportRequest { + cmmsMeasurementConsumerId = reportKey.cmmsMeasurementConsumerId + externalReportId = apiIdToExternalId(reportKey.reportId) + } + ) + } catch (e: StatusException) { + throw Exception("Unable to get the report from the reporting database.", e) + } + + // Create metrics. + val metricNames: List = + internalReport.reportingMetricEntriesMap.flatMap { (_, reportingMetricCalculationSpec) -> + reportingMetricCalculationSpec.metricCalculationSpecsList.flatMap { metricCalculationSpec -> + metricCalculationSpec.reportingMetricsList.map { reportingMetric -> + MetricKey( + principal.resourceKey.measurementConsumerId, + externalIdToApiId(reportingMetric.externalMetricId) + ) + .toName() + } + } + } + val metrics: List = + batchGetMetrics(principal.resourceKey.toName(), principal.config.apiKey, metricNames) + + // Convert the internal report to public and return. + return convertInternalReportToPublic(internalReport, metrics) + } + + private suspend fun batchGetMetrics( + parent: String, + apiAuthenticationKey: String, + metricNames: List, + ): List { + val batchGetMetricsRequests = mutableListOf() + + while (batchGetMetricsRequests.size * MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS < metricNames.size) { + val fromIndex = batchGetMetricsRequests.size * MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS + val toIndex = min(fromIndex + MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS, metricNames.size) + + batchGetMetricsRequests += batchGetMetricsRequest { + this.parent = parent + names += metricNames.slice(fromIndex until toIndex) + } + } + + return batchGetMetricsRequests.flatMap { batchGetMetricsRequest -> + try { + metricsStub + .withAuthenticationKey(apiAuthenticationKey) + .batchGetMetrics(batchGetMetricsRequest) + .metricsList + } catch (e: StatusException) { + throw when (e.status.code) { + Status.Code.INVALID_ARGUMENT -> Status.INVALID_ARGUMENT.withDescription(e.message) + Status.Code.PERMISSION_DENIED -> Status.PERMISSION_DENIED.withDescription(e.message) + else -> Status.UNKNOWN.withDescription("Unable to create metrics.") + } + .withCause(e) + .asRuntimeException() + } + } + } + override suspend fun createReport(request: CreateReportRequest): Report { grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { "Parent is either unspecified or invalid." @@ -252,10 +340,10 @@ class ReportsService( .metricsList } catch (e: StatusException) { throw when (e.status.code) { - Status.Code.NOT_FOUND -> - Status.NOT_FOUND.withDescription("Reporting set used in the metric not found.") - Status.Code.FAILED_PRECONDITION -> - Status.FAILED_PRECONDITION.withDescription("Measurement Consumer not found.") + Status.Code.PERMISSION_DENIED -> Status.PERMISSION_DENIED.withDescription(e.message) + Status.Code.INVALID_ARGUMENT -> Status.INVALID_ARGUMENT.withDescription(e.message) + Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription(e.message) + Status.Code.FAILED_PRECONDITION -> Status.FAILED_PRECONDITION.withDescription(e.message) else -> Status.UNKNOWN.withDescription("Unable to create metrics.") } .withCause(e) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsServiceTest.kt index 338b8c21e39..cc79e538acb 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsServiceTest.kt @@ -45,6 +45,7 @@ import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule import org.wfanet.measurement.common.grpc.testing.mockService import org.wfanet.measurement.common.identity.ExternalId import org.wfanet.measurement.common.identity.apiIdToExternalId +import org.wfanet.measurement.common.identity.externalIdToApiId import org.wfanet.measurement.common.testing.verifyProtoArgument import org.wfanet.measurement.common.toProtoDuration import org.wfanet.measurement.common.toProtoTime @@ -67,6 +68,7 @@ import org.wfanet.measurement.internal.reporting.v2.report as internalReport import org.wfanet.measurement.internal.reporting.v2.timeInterval as internalTimeInterval import org.wfanet.measurement.internal.reporting.v2.timeIntervals as internalTimeIntervals import org.wfanet.measurement.reporting.v2alpha.BatchCreateMetricsRequest +import org.wfanet.measurement.reporting.v2alpha.BatchGetMetricsRequest import org.wfanet.measurement.reporting.v2alpha.Metric import org.wfanet.measurement.reporting.v2alpha.MetricResultKt.reachResult import org.wfanet.measurement.reporting.v2alpha.MetricSpec @@ -79,9 +81,12 @@ import org.wfanet.measurement.reporting.v2alpha.ReportingSetKt import org.wfanet.measurement.reporting.v2alpha.TimeInterval import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsRequest import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsResponse +import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsRequest +import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsResponse import org.wfanet.measurement.reporting.v2alpha.copy import org.wfanet.measurement.reporting.v2alpha.createMetricRequest import org.wfanet.measurement.reporting.v2alpha.createReportRequest +import org.wfanet.measurement.reporting.v2alpha.getReportRequest import org.wfanet.measurement.reporting.v2alpha.metric import org.wfanet.measurement.reporting.v2alpha.metricResult import org.wfanet.measurement.reporting.v2alpha.metricSpec @@ -119,6 +124,7 @@ private const val MAXIMUM_WATCH_DURATION_PER_USER = 4000 private const val DIFFERENTIAL_PRIVACY_DELTA = 1e-12 private const val MAX_BATCH_SIZE_FOR_BATCH_CREATE_METRICS = 1000 +private const val MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS = 100 @RunWith(JUnit4::class) class ReportsServiceTest { @@ -2085,6 +2091,232 @@ class ReportsServiceTest { assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) } + @Test + fun `getReport returns the report with SUCCEEDED when all metrics are SUCCEEDED`() = runBlocking { + whenever( + metricsMock.batchGetMetrics( + eq( + batchGetMetricsRequest { + parent = MEASUREMENT_CONSUMER_KEYS.first().toName() + names += SUCCEEDED_REACH_METRIC.name + } + ) + ) + ) + .thenReturn(batchGetMetricsResponse { metrics += SUCCEEDED_REACH_METRIC }) + + val request = getReportRequest { name = PENDING_REACH_REPORT.name } + + val report = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) { + runBlocking { service.getReport(request) } + } + + assertThat(report).isEqualTo(SUCCEEDED_REACH_REPORT) + } + + @Test + fun `getReport returns the report with FAILED when any metric FAILED`() = runBlocking { + val failedReachMetric = RUNNING_REACH_METRIC.copy { state = Metric.State.FAILED } + + whenever( + metricsMock.batchGetMetrics( + eq( + batchGetMetricsRequest { + parent = MEASUREMENT_CONSUMER_KEYS.first().toName() + names += failedReachMetric.name + } + ) + ) + ) + .thenReturn(batchGetMetricsResponse { metrics += failedReachMetric }) + + val request = getReportRequest { name = PENDING_REACH_REPORT.name } + + val report = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) { + runBlocking { service.getReport(request) } + } + + assertThat(report).isEqualTo(PENDING_REACH_REPORT.copy { state = Report.State.FAILED }) + } + + @Test + fun `getReport returns the report with RUNNING when metric is pending`(): Unit = runBlocking { + val request = getReportRequest { name = PENDING_REACH_REPORT.name } + + val report = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) { + runBlocking { service.getReport(request) } + } + + assertThat(report).isEqualTo(PENDING_REACH_REPORT) + } + + @Test + fun `getReport returns the report with RUNNING when there are more than max batch size metrics`(): + Unit = runBlocking { + val startSec = 10L + val incrementSec = 1L + val intervalCount = MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS + 1 + + val endTimesList: List = + (startSec + incrementSec until startSec + incrementSec + intervalCount).toList() + val internalTimeIntervals: List = + endTimesList.map { end -> + internalTimeInterval { + startTime = timestamp { seconds = startSec } + endTime = timestamp { seconds = end } + } + } + + val reportingCreateMetricRequests = + internalTimeIntervals.map { timeInterval -> + buildInitialReportingMetric( + PRIMITIVE_REPORTING_SETS.first().externalId, + timeInterval, + INTERNAL_REACH_METRIC_SPEC, + listOf() + ) + } + + val internalPendingReport = internalReport { + cmmsMeasurementConsumerId = MEASUREMENT_CONSUMER_KEYS.first().measurementConsumerId + + this.periodicTimeInterval = internalPeriodicTimeInterval { + startTime = timestamp { seconds = startSec } + increment = duration { seconds = incrementSec } + this.intervalCount = intervalCount + } + + externalReportId = 330L + createTime = Instant.now().toProtoTime() + + val updatedReportingCreateMetricRequests = + reportingCreateMetricRequests.mapIndexed { requestId, request -> + request.copy { + this.createMetricRequestId = requestId.toString() + externalMetricId = EXTERNAL_METRIC_ID_BASE + requestId + } + } + + reportingMetricEntries.putAll( + buildInternalReportingMetricEntryWithOneMetricCalculationSpec( + PRIMITIVE_REPORTING_SETS.first().externalId, + updatedReportingCreateMetricRequests, + DISPLAY_NAME, + listOf(), + true + ) + ) + } + + whenever( + internalReportsMock.getReport( + eq( + internalGetReportRequest { + cmmsMeasurementConsumerId = internalPendingReport.cmmsMeasurementConsumerId + externalReportId = internalPendingReport.externalReportId + } + ) + ) + ) + .thenReturn(internalPendingReport) + + whenever(metricsMock.batchGetMetrics(any())) + .thenReturn( + batchGetMetricsResponse { + metrics += + endTimesList.mapIndexed { index, end -> + metric { + name = + MetricKey( + MEASUREMENT_CONSUMER_KEYS.first().measurementConsumerId, + ExternalId(EXTERNAL_METRIC_ID_BASE + index).apiId.value + ) + .toName() + reportingSet = PRIMITIVE_REPORTING_SETS.first().name + timeInterval = timeInterval { + startTime = timestamp { seconds = startSec } + endTime = timestamp { seconds = end } + } + metricSpec = REACH_METRIC_SPEC + state = Metric.State.RUNNING + createTime = Instant.now().toProtoTime() + } + } + } + ) + + val request = getReportRequest { name = internalPendingReport.resourceName } + + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) { + runBlocking { service.getReport(request) } + } + + val batchGetMetricsCaptor: KArgumentCaptor = argumentCaptor() + verifyBlocking(metricsMock, times(2)) { batchGetMetrics(batchGetMetricsCaptor.capture()) } + } + + @Test + fun `getReport throws INVALID_ARGUMENT when Report name is invalid`() { + val request = getReportRequest { name = "INVALID_REPORT_NAME" } + + val exception = + assertFailsWith { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) { + runBlocking { service.getReport(request) } + } + } + + assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) + } + + @Test + fun `getReport throws PERMISSION_DENIED when Report name is not accessible`() { + val inaccessibleReportName = + ReportKey(MEASUREMENT_CONSUMER_KEYS.last().measurementConsumerId, externalIdToApiId(330L)) + .toName() + val request = getReportRequest { name = inaccessibleReportName } + + val exception = + assertFailsWith { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) { + runBlocking { service.getReport(request) } + } + } + + assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) + } + + @Test + fun `getReport throws PERMISSION_DENIED when MeasurementConsumer's identity does not match`() { + val request = getReportRequest { name = PENDING_REACH_REPORT.name } + + val exception = + assertFailsWith { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.last().toName(), CONFIG) { + runBlocking { service.getReport(request) } + } + } + + assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) + } + + @Test + fun `getReport throws UNAUTHENTICATED when the caller is not a MeasurementConsumer`() { + val request = getReportRequest { name = PENDING_REACH_REPORT.name } + + val exception = + assertFailsWith { + withDataProviderPrincipal(DataProviderKey(ExternalId(550L).apiId.value).toName()) { + runBlocking { service.getReport(request) } + } + } + + assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) + } + companion object { private fun buildInitialReportingMetric( externalReportingSetId: Long,