From bb3172a249bb1e666aee23c109644c97880e567b Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Thu, 1 Jun 2023 00:55:00 +0000 Subject: [PATCH] Add unit test. --- .../service/api/v2alpha/ReportsServiceTest.kt | 232 ++++++++++++++++++ 1 file changed, 232 insertions(+) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsServiceTest.kt index c6efad1e1de..164552d0651 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsServiceTest.kt @@ -45,6 +45,7 @@ import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule import org.wfanet.measurement.common.grpc.testing.mockService import org.wfanet.measurement.common.identity.ExternalId import org.wfanet.measurement.common.identity.apiIdToExternalId +import org.wfanet.measurement.common.identity.externalIdToApiId import org.wfanet.measurement.common.testing.verifyProtoArgument import org.wfanet.measurement.common.toProtoDuration import org.wfanet.measurement.common.toProtoTime @@ -67,6 +68,7 @@ import org.wfanet.measurement.internal.reporting.v2.report as internalReport import org.wfanet.measurement.internal.reporting.v2.timeInterval as internalTimeInterval import org.wfanet.measurement.internal.reporting.v2.timeIntervals as internalTimeIntervals import org.wfanet.measurement.reporting.v2alpha.BatchCreateMetricsRequest +import org.wfanet.measurement.reporting.v2alpha.BatchGetMetricsRequest import org.wfanet.measurement.reporting.v2alpha.Metric import org.wfanet.measurement.reporting.v2alpha.MetricResultKt.reachResult import org.wfanet.measurement.reporting.v2alpha.MetricSpec @@ -78,9 +80,12 @@ import org.wfanet.measurement.reporting.v2alpha.ReportingSet import org.wfanet.measurement.reporting.v2alpha.ReportingSetKt import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsRequest import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsResponse +import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsRequest +import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsResponse import org.wfanet.measurement.reporting.v2alpha.copy import org.wfanet.measurement.reporting.v2alpha.createMetricRequest import org.wfanet.measurement.reporting.v2alpha.createReportRequest +import org.wfanet.measurement.reporting.v2alpha.getReportRequest import org.wfanet.measurement.reporting.v2alpha.metric import org.wfanet.measurement.reporting.v2alpha.metricResult import org.wfanet.measurement.reporting.v2alpha.metricSpec @@ -118,6 +123,7 @@ private const val MAXIMUM_WATCH_DURATION_PER_USER = 4000 private const val DIFFERENTIAL_PRIVACY_DELTA = 1e-12 private const val MAX_BATCH_SIZE_FOR_BATCH_CREATE_METRICS = 1000 +private const val MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS = 100 @RunWith(JUnit4::class) class ReportsServiceTest { @@ -1913,6 +1919,232 @@ class ReportsServiceTest { assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) } + @Test + fun `getReport returns the report with SUCCEEDED when all metrics are SUCCEEDED`() = runBlocking { + whenever( + metricsMock.batchGetMetrics( + eq( + batchGetMetricsRequest { + parent = MEASUREMENT_CONSUMER_KEYS.first().toName() + names += SUCCEEDED_REACH_METRIC.name + } + ) + ) + ) + .thenReturn(batchGetMetricsResponse { metrics += SUCCEEDED_REACH_METRIC }) + + val request = getReportRequest { name = PENDING_REACH_REPORT.name } + + val report = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) { + runBlocking { service.getReport(request) } + } + + assertThat(report).isEqualTo(SUCCEEDED_REACH_REPORT) + } + + @Test + fun `getReport returns the report with FAILED when any metric FAILED`() = runBlocking { + val failedReachMetric = RUNNING_REACH_METRIC.copy { state = Metric.State.FAILED } + + whenever( + metricsMock.batchGetMetrics( + eq( + batchGetMetricsRequest { + parent = MEASUREMENT_CONSUMER_KEYS.first().toName() + names += failedReachMetric.name + } + ) + ) + ) + .thenReturn(batchGetMetricsResponse { metrics += failedReachMetric }) + + val request = getReportRequest { name = PENDING_REACH_REPORT.name } + + val report = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) { + runBlocking { service.getReport(request) } + } + + assertThat(report).isEqualTo(PENDING_REACH_REPORT.copy { state = Report.State.FAILED }) + } + + @Test + fun `getReport returns the report with RUNNING when metric is pending`(): Unit = runBlocking { + val request = getReportRequest { name = PENDING_REACH_REPORT.name } + + val report = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) { + runBlocking { service.getReport(request) } + } + + assertThat(report).isEqualTo(PENDING_REACH_REPORT) + } + + @Test + fun `getReport returns the report with RUNNING when there are more than max batch size metrics`(): + Unit = runBlocking { + val startSec = 10L + val incrementSec = 1L + val intervalCount = MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS + 1 + + val endTimesList: List = + (startSec + incrementSec until startSec + incrementSec + intervalCount).toList() + val internalTimeIntervals: List = + endTimesList.map { end -> + internalTimeInterval { + startTime = timestamp { seconds = startSec } + endTime = timestamp { seconds = end } + } + } + + val reportingCreateMetricRequests = + internalTimeIntervals.map { timeInterval -> + buildInitialReportingMetric( + PRIMITIVE_REPORTING_SETS.first().externalId, + timeInterval, + INTERNAL_REACH_METRIC_SPEC, + listOf() + ) + } + + val internalPendingReport = internalReport { + cmmsMeasurementConsumerId = MEASUREMENT_CONSUMER_KEYS.first().measurementConsumerId + + this.periodicTimeInterval = internalPeriodicTimeInterval { + startTime = timestamp { seconds = startSec } + increment = duration { seconds = incrementSec } + this.intervalCount = intervalCount + } + + externalReportId = 330L + createTime = Instant.now().toProtoTime() + + val updatedReportingCreateMetricRequests = + reportingCreateMetricRequests.mapIndexed { requestId, request -> + request.copy { + this.requestId = requestId.toString() + externalMetricId = EXTERNAL_METRIC_ID_BASE + requestId + } + } + + reportingMetricEntries.putAll( + buildInternalReportingMetricEntryWithOneMetricCalculationSpec( + PRIMITIVE_REPORTING_SETS.first().externalId, + updatedReportingCreateMetricRequests, + DISPLAY_NAME, + listOf(), + true + ) + ) + } + + whenever( + internalReportsMock.getReport( + eq( + internalGetReportRequest { + cmmsMeasurementConsumerId = internalPendingReport.cmmsMeasurementConsumerId + externalReportId = internalPendingReport.externalReportId + } + ) + ) + ) + .thenReturn(internalPendingReport) + + whenever(metricsMock.batchGetMetrics(any())) + .thenReturn( + batchGetMetricsResponse { + metrics += + endTimesList.mapIndexed { index, end -> + metric { + name = + MetricKey( + MEASUREMENT_CONSUMER_KEYS.first().measurementConsumerId, + ExternalId(EXTERNAL_METRIC_ID_BASE + index).apiId.value + ) + .toName() + reportingSet = PRIMITIVE_REPORTING_SETS.first().name + timeInterval = timeInterval { + startTime = timestamp { seconds = startSec } + endTime = timestamp { seconds = end } + } + metricSpec = REACH_METRIC_SPEC + state = Metric.State.RUNNING + createTime = Instant.now().toProtoTime() + } + } + } + ) + + val request = getReportRequest { name = internalPendingReport.resourceName } + + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) { + runBlocking { service.getReport(request) } + } + + val batchGetMetricsCaptor: KArgumentCaptor = argumentCaptor() + verifyBlocking(metricsMock, times(2)) { batchGetMetrics(batchGetMetricsCaptor.capture()) } + } + + @Test + fun `getReport throws INVALID_ARGUMENT when Report name is invalid`() { + val request = getReportRequest { name = "INVALID_REPORT_NAME" } + + val exception = + assertFailsWith { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) { + runBlocking { service.getReport(request) } + } + } + + assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) + } + + @Test + fun `getReport throws PERMISSION_DENIED when Report name is not accessible`() { + val inaccessibleReportName = + ReportKey(MEASUREMENT_CONSUMER_KEYS.last().measurementConsumerId, externalIdToApiId(330L)) + .toName() + val request = getReportRequest { name = inaccessibleReportName } + + val exception = + assertFailsWith { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) { + runBlocking { service.getReport(request) } + } + } + + assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) + } + + @Test + fun `getReport throws PERMISSION_DENIED when MeasurementConsumer's identity does not match`() { + val request = getReportRequest { name = PENDING_REACH_REPORT.name } + + val exception = + assertFailsWith { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.last().toName(), CONFIG) { + runBlocking { service.getReport(request) } + } + } + + assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) + } + + @Test + fun `getReport throws UNAUTHENTICATED when the caller is not a MeasurementConsumer`() { + val request = getReportRequest { name = PENDING_REACH_REPORT.name } + + val exception = + assertFailsWith { + withDataProviderPrincipal(DataProviderKey(ExternalId(550L).apiId.value).toName()) { + runBlocking { service.getReport(request) } + } + } + + assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) + } + companion object { private fun buildInitialReportingMetric( externalReportingSetId: Long,