Skip to content

Commit

Permalink
Add unit test.
Browse files Browse the repository at this point in the history
  • Loading branch information
riemanli committed Jun 5, 2023
1 parent 7cb62f5 commit 2a9246d
Showing 1 changed file with 232 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule
import org.wfanet.measurement.common.grpc.testing.mockService
import org.wfanet.measurement.common.identity.ExternalId
import org.wfanet.measurement.common.identity.apiIdToExternalId
import org.wfanet.measurement.common.identity.externalIdToApiId
import org.wfanet.measurement.common.testing.verifyProtoArgument
import org.wfanet.measurement.common.toProtoDuration
import org.wfanet.measurement.common.toProtoTime
Expand All @@ -67,6 +68,7 @@ import org.wfanet.measurement.internal.reporting.v2.report as internalReport
import org.wfanet.measurement.internal.reporting.v2.timeInterval as internalTimeInterval
import org.wfanet.measurement.internal.reporting.v2.timeIntervals as internalTimeIntervals
import org.wfanet.measurement.reporting.v2alpha.BatchCreateMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.BatchGetMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.Metric
import org.wfanet.measurement.reporting.v2alpha.MetricResultKt.reachResult
import org.wfanet.measurement.reporting.v2alpha.MetricSpec
Expand All @@ -79,9 +81,12 @@ import org.wfanet.measurement.reporting.v2alpha.ReportingSetKt
import org.wfanet.measurement.reporting.v2alpha.TimeInterval
import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.copy
import org.wfanet.measurement.reporting.v2alpha.createMetricRequest
import org.wfanet.measurement.reporting.v2alpha.createReportRequest
import org.wfanet.measurement.reporting.v2alpha.getReportRequest
import org.wfanet.measurement.reporting.v2alpha.metric
import org.wfanet.measurement.reporting.v2alpha.metricResult
import org.wfanet.measurement.reporting.v2alpha.metricSpec
Expand Down Expand Up @@ -119,6 +124,7 @@ private const val MAXIMUM_WATCH_DURATION_PER_USER = 4000
private const val DIFFERENTIAL_PRIVACY_DELTA = 1e-12

private const val MAX_BATCH_SIZE_FOR_BATCH_CREATE_METRICS = 1000
private const val MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS = 100

@RunWith(JUnit4::class)
class ReportsServiceTest {
Expand Down Expand Up @@ -2085,6 +2091,232 @@ class ReportsServiceTest {
assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT)
}

@Test
fun `getReport returns the report with SUCCEEDED when all metrics are SUCCEEDED`() = runBlocking {
whenever(
metricsMock.batchGetMetrics(
eq(
batchGetMetricsRequest {
parent = MEASUREMENT_CONSUMER_KEYS.first().toName()
names += SUCCEEDED_REACH_METRIC.name
}
)
)
)
.thenReturn(batchGetMetricsResponse { metrics += SUCCEEDED_REACH_METRIC })

val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val report =
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}

assertThat(report).isEqualTo(SUCCEEDED_REACH_REPORT)
}

@Test
fun `getReport returns the report with FAILED when any metric FAILED`() = runBlocking {
val failedReachMetric = RUNNING_REACH_METRIC.copy { state = Metric.State.FAILED }

whenever(
metricsMock.batchGetMetrics(
eq(
batchGetMetricsRequest {
parent = MEASUREMENT_CONSUMER_KEYS.first().toName()
names += failedReachMetric.name
}
)
)
)
.thenReturn(batchGetMetricsResponse { metrics += failedReachMetric })

val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val report =
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}

assertThat(report).isEqualTo(PENDING_REACH_REPORT.copy { state = Report.State.FAILED })
}

@Test
fun `getReport returns the report with RUNNING when metric is pending`(): Unit = runBlocking {
val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val report =
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}

assertThat(report).isEqualTo(PENDING_REACH_REPORT)
}

@Test
fun `getReport returns the report with RUNNING when there are more than max batch size metrics`():
Unit = runBlocking {
val startSec = 10L
val incrementSec = 1L
val intervalCount = MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS + 1

val endTimesList: List<Long> =
(startSec + incrementSec until startSec + incrementSec + intervalCount).toList()
val internalTimeIntervals: List<InternalTimeInterval> =
endTimesList.map { end ->
internalTimeInterval {
startTime = timestamp { seconds = startSec }
endTime = timestamp { seconds = end }
}
}

val reportingCreateMetricRequests =
internalTimeIntervals.map { timeInterval ->
buildInitialReportingMetric(
PRIMITIVE_REPORTING_SETS.first().externalId,
timeInterval,
INTERNAL_REACH_METRIC_SPEC,
listOf()
)
}

val internalPendingReport = internalReport {
cmmsMeasurementConsumerId = MEASUREMENT_CONSUMER_KEYS.first().measurementConsumerId

this.periodicTimeInterval = internalPeriodicTimeInterval {
startTime = timestamp { seconds = startSec }
increment = duration { seconds = incrementSec }
this.intervalCount = intervalCount
}

externalReportId = 330L
createTime = Instant.now().toProtoTime()

val updatedReportingCreateMetricRequests =
reportingCreateMetricRequests.mapIndexed { requestId, request ->
request.copy {
this.createMetricRequestId = requestId.toString()
externalMetricId = EXTERNAL_METRIC_ID_BASE + requestId
}
}

reportingMetricEntries.putAll(
buildInternalReportingMetricEntryWithOneMetricCalculationSpec(
PRIMITIVE_REPORTING_SETS.first().externalId,
updatedReportingCreateMetricRequests,
DISPLAY_NAME,
listOf(),
true
)
)
}

whenever(
internalReportsMock.getReport(
eq(
internalGetReportRequest {
cmmsMeasurementConsumerId = internalPendingReport.cmmsMeasurementConsumerId
externalReportId = internalPendingReport.externalReportId
}
)
)
)
.thenReturn(internalPendingReport)

whenever(metricsMock.batchGetMetrics(any()))
.thenReturn(
batchGetMetricsResponse {
metrics +=
endTimesList.mapIndexed { index, end ->
metric {
name =
MetricKey(
MEASUREMENT_CONSUMER_KEYS.first().measurementConsumerId,
ExternalId(EXTERNAL_METRIC_ID_BASE + index).apiId.value
)
.toName()
reportingSet = PRIMITIVE_REPORTING_SETS.first().name
timeInterval = timeInterval {
startTime = timestamp { seconds = startSec }
endTime = timestamp { seconds = end }
}
metricSpec = REACH_METRIC_SPEC
state = Metric.State.RUNNING
createTime = Instant.now().toProtoTime()
}
}
}
)

val request = getReportRequest { name = internalPendingReport.resourceName }

withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}

val batchGetMetricsCaptor: KArgumentCaptor<BatchGetMetricsRequest> = argumentCaptor()
verifyBlocking(metricsMock, times(2)) { batchGetMetrics(batchGetMetricsCaptor.capture()) }
}

@Test
fun `getReport throws INVALID_ARGUMENT when Report name is invalid`() {
val request = getReportRequest { name = "INVALID_REPORT_NAME" }

val exception =
assertFailsWith<StatusRuntimeException> {
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}
}

assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT)
}

@Test
fun `getReport throws PERMISSION_DENIED when Report name is not accessible`() {
val inaccessibleReportName =
ReportKey(MEASUREMENT_CONSUMER_KEYS.last().measurementConsumerId, externalIdToApiId(330L))
.toName()
val request = getReportRequest { name = inaccessibleReportName }

val exception =
assertFailsWith<StatusRuntimeException> {
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}
}

assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED)
}

@Test
fun `getReport throws PERMISSION_DENIED when MeasurementConsumer's identity does not match`() {
val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val exception =
assertFailsWith<StatusRuntimeException> {
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.last().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}
}

assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED)
}

@Test
fun `getReport throws UNAUTHENTICATED when the caller is not a MeasurementConsumer`() {
val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val exception =
assertFailsWith<StatusRuntimeException> {
withDataProviderPrincipal(DataProviderKey(ExternalId(550L).apiId.value).toName()) {
runBlocking { service.getReport(request) }
}
}

assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED)
}

companion object {
private fun buildInitialReportingMetric(
externalReportingSetId: Long,
Expand Down

0 comments on commit 2a9246d

Please sign in to comment.