Skip to content

Commit

Permalink
Fix unit tests. Add one unit test.
Browse files Browse the repository at this point in the history
  • Loading branch information
riemanli committed Jun 2, 2023
1 parent 9b484f2 commit d1f9ace
Showing 1 changed file with 213 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ import org.wfanet.measurement.reporting.v2alpha.Report
import org.wfanet.measurement.reporting.v2alpha.ReportKt
import org.wfanet.measurement.reporting.v2alpha.ReportingSet
import org.wfanet.measurement.reporting.v2alpha.ReportingSetKt
import org.wfanet.measurement.reporting.v2alpha.TimeInterval
import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.copy
Expand Down Expand Up @@ -663,6 +664,212 @@ class ReportsServiceTest {
)
}

@Test
fun `createReport returns report when multiple timeIntervals, groupings, and metricSpecs`() =
runBlocking {
val displayName = DISPLAY_NAME
val targetReportingSet = PRIMITIVE_REPORTING_SETS.first()

// Metric specs
val metricSpecs = listOf(REACH_METRIC_SPEC, FREQUENCY_HISTOGRAM_METRIC_SPEC)
val internalMetricSpecs =
listOf(INTERNAL_REACH_METRIC_SPEC, INTERNAL_FREQUENCY_HISTOGRAM_METRIC_SPEC)

// Time intervals
val timeIntervalsList =
listOf(
timeInterval {
startTime = START_TIME
endTime = END_TIME
},
timeInterval {
startTime = END_TIME
endTime = END_INSTANT.plus(Duration.ofDays(1)).toProtoTime()
}
)
val internalTimeIntervals =
listOf(
internalTimeInterval {
startTime = START_TIME
endTime = END_TIME
},
internalTimeInterval {
startTime = END_TIME
endTime = END_INSTANT.plus(Duration.ofDays(1)).toProtoTime()
}
)

// Groupings
val predicates1 = listOf("gender == MALE", "gender == FEMALE")
val predicates2 = listOf("age == 18_34", "age == 55_PLUS")
val groupings =
listOf(
ReportKt.grouping { predicates += predicates1 },
ReportKt.grouping { predicates += predicates2 }
)
val internalGroupings =
listOf(
InternalReportKt.MetricCalculationSpecKt.grouping { predicates += predicates1 },
InternalReportKt.MetricCalculationSpecKt.grouping { predicates += predicates2 }
)
val groupingsCartesianProduct: List<List<String>> =
predicates1.flatMap { filter1 -> predicates2.map { filter2 -> listOf(filter1, filter2) } }

// Metric configs for internal and public
data class MetricConfig(
val reportingSet: String,
val metricSpec: MetricSpec,
val timeInterval: TimeInterval,
val filters: List<String>
)
val metricConfigs =
timeIntervalsList.flatMap { timeInterval ->
metricSpecs.flatMap { metricSpec ->
groupingsCartesianProduct.map { predicateGroup ->
MetricConfig(targetReportingSet.name, metricSpec, timeInterval, predicateGroup)
}
}
}

data class ReportingMetricConfig(
val externalReportingSetId: Long,
val metricSpec: InternalMetricSpec,
val timeInterval: InternalTimeInterval,
val filters: List<String>
)
val reportingMetricConfigs =
internalTimeIntervals.flatMap { timeInterval ->
internalMetricSpecs.flatMap { metricSpec ->
groupingsCartesianProduct.map { predicateGroup ->
ReportingMetricConfig(
targetReportingSet.externalId,
metricSpec,
timeInterval,
predicateGroup
)
}
}
}

val initialReportingMetrics: List<InternalReport.ReportingMetric> =
reportingMetricConfigs.map { reportingMetricConfig ->
buildInitialReportingMetric(
reportingMetricConfig.externalReportingSetId,
reportingMetricConfig.timeInterval,
reportingMetricConfig.metricSpec,
reportingMetricConfig.filters
)
}

val (internalRequestingReport, internalInitialReport, internalPendingReport) =
buildInternalReports(
MEASUREMENT_CONSUMER_KEYS.first().measurementConsumerId,
internalTimeIntervals,
targetReportingSet.externalId,
initialReportingMetrics,
internalGroupings,
)

whenever(
internalReportsMock.createReport(
eq(internalCreateReportRequest { report = internalRequestingReport })
)
)
.thenReturn(internalInitialReport)

whenever(
internalReportsMock.getReport(
eq(
internalGetReportRequest {
cmmsMeasurementConsumerId = internalInitialReport.cmmsMeasurementConsumerId
externalReportId = internalInitialReport.externalReportId
}
)
)
)
.thenReturn(internalPendingReport)

val requestingMetrics: List<Metric> =
metricConfigs.map { metricConfig ->
metric {
reportingSet = metricConfig.reportingSet
timeInterval = metricConfig.timeInterval
metricSpec = metricConfig.metricSpec
filters += metricConfig.filters
}
}

whenever(metricsMock.batchCreateMetrics(any()))
.thenReturn(
batchCreateMetricsResponse {
metrics +=
requestingMetrics.mapIndexed { index, metric ->
metric.copy {
name =
MetricKey(
MEASUREMENT_CONSUMER_KEYS.first().measurementConsumerId,
ExternalId(EXTERNAL_METRIC_ID_BASE + index).apiId.value
)
.toName()
state = Metric.State.RUNNING
createTime = Instant.now().toProtoTime()
}
}
}
)

val requestingReport = report {
reportingMetricEntries +=
ReportKt.reportingMetricEntry {
key = targetReportingSet.name
value =
ReportKt.reportingMetricCalculationSpec {
metricCalculationSpecs +=
ReportKt.metricCalculationSpec {
this.displayName = displayName
this.metricSpecs += metricSpecs
this.groupings += groupings
cumulative = false
}
}
}
timeIntervals = timeIntervals { timeIntervals += timeIntervalsList }
}

val request = createReportRequest {
parent = MEASUREMENT_CONSUMER_KEYS.first().toName()
report = requestingReport
}
val result =
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.createReport(request) }
}

verifyProtoArgument(metricsMock, MetricsGrpcKt.MetricsCoroutineImplBase::batchCreateMetrics)
.isEqualTo(
batchCreateMetricsRequest {
parent = MEASUREMENT_CONSUMER_KEYS.first().toName()
requests +=
requestingMetrics.mapIndexed { requestId, metric ->
createMetricRequest {
parent = MEASUREMENT_CONSUMER_KEYS.first().toName()
this.metric = metric
this.requestId = requestId.toString()
}
}
}
)

assertThat(result)
.isEqualTo(
requestingReport.copy {
name = internalPendingReport.resourceName
state = Report.State.RUNNING
createTime = internalPendingReport.createTime
}
)
}

@Test
fun `createReport returns report with 2 metrics generated when there are 2 reporting sets`() =
runBlocking {
Expand Down Expand Up @@ -1225,27 +1432,6 @@ class ReportsServiceTest {
val metricSpecWithoutVidSamplingInterval = REACH_METRIC_SPEC.copy { clearVidSamplingInterval() }
val requestId = "0"

val (internalRequestingReport, internalInitialReport, internalPendingReport) = INTERNAL_REPORTS

whenever(
internalReportsMock.createReport(
eq(internalCreateReportRequest { report = internalRequestingReport })
)
)
.thenReturn(internalInitialReport)

whenever(
internalReportsMock.getReport(
eq(
internalGetReportRequest {
cmmsMeasurementConsumerId = internalInitialReport.cmmsMeasurementConsumerId
externalReportId = internalInitialReport.externalReportId
}
)
)
)
.thenReturn(internalPendingReport)

val request = createReportRequest {
parent = MEASUREMENT_CONSUMER_KEYS.first().toName()
report =
Expand Down Expand Up @@ -1325,26 +1511,6 @@ class ReportsServiceTest {
assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED)
}

@Test
fun `createReport throws PERMISSION_DENIED when report doesn't belong to caller`() {
val request = createReportRequest {
parent = MEASUREMENT_CONSUMER_KEYS.last().toName()
report =
PENDING_REACH_REPORT.copy {
clearName()
clearCreateTime()
clearState()
}
}
val exception =
assertFailsWith<StatusRuntimeException> {
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.createReport(request) }
}
}
assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED)
}

@Test
fun `createReport throws UNAUTHENTICATED when the caller is not MeasurementConsumer`() {
val request = createReportRequest {
Expand Down Expand Up @@ -1429,6 +1595,12 @@ class ReportsServiceTest {
clearName()
clearCreateTime()
clearState()
timeIntervals = timeIntervals {
timeIntervals += timeInterval {
startTime = START_TIME
endTime = END_TIME
}
}
reportingMetricEntries.clear()
reportingMetricEntries +=
ReportKt.reportingMetricEntry {
Expand Down

0 comments on commit d1f9ace

Please sign in to comment.