Skip to content

Commit

Permalink
Add getReport (#1029)
Browse files Browse the repository at this point in the history
  • Loading branch information
riemanli authored and ple13 committed Aug 16, 2024
1 parent f769ea7 commit 972c43d
Show file tree
Hide file tree
Showing 2 changed files with 324 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,21 @@ import org.wfanet.measurement.internal.reporting.v2.createReportRequest as inter
import org.wfanet.measurement.internal.reporting.v2.getReportRequest as internalGetReportRequest
import org.wfanet.measurement.internal.reporting.v2.report as internalReport
import org.wfanet.measurement.reporting.v2alpha.BatchCreateMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.BatchGetMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.CreateMetricRequest
import org.wfanet.measurement.reporting.v2alpha.CreateReportRequest
import org.wfanet.measurement.reporting.v2alpha.GetReportRequest
import org.wfanet.measurement.reporting.v2alpha.Metric
import org.wfanet.measurement.reporting.v2alpha.MetricsGrpcKt.MetricsCoroutineStub
import org.wfanet.measurement.reporting.v2alpha.Report
import org.wfanet.measurement.reporting.v2alpha.ReportKt
import org.wfanet.measurement.reporting.v2alpha.ReportsGrpcKt.ReportsCoroutineImplBase
import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.report

private const val MAX_BATCH_SIZE_FOR_BATCH_CREATE_METRICS = 1000
private const val MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS = 100

private typealias InternalReportingMetricEntries =
Map<Long, InternalReport.ReportingMetricCalculationSpec>
Expand All @@ -71,6 +75,90 @@ class ReportsService(
val internalTimeRange: InternalTimeRange,
)

override suspend fun getReport(request: GetReportRequest): Report {
val reportKey =
grpcRequireNotNull(ReportKey.fromName(request.name)) {
"Report name is either unspecified or invalid"
}

val principal: ReportingPrincipal = principalFromCurrentContext
when (principal) {
is MeasurementConsumerPrincipal -> {
if (reportKey.cmmsMeasurementConsumerId != principal.resourceKey.measurementConsumerId) {
failGrpc(Status.PERMISSION_DENIED) {
"Cannot get Report belonging to other MeasurementConsumers."
}
}
}
}

val internalReport =
try {
internalReportsStub.getReport(
internalGetReportRequest {
cmmsMeasurementConsumerId = reportKey.cmmsMeasurementConsumerId
externalReportId = apiIdToExternalId(reportKey.reportId)
}
)
} catch (e: StatusException) {
throw Exception("Unable to get the report from the reporting database.", e)
}

// Create metrics.
val metricNames: List<String> =
internalReport.reportingMetricEntriesMap.flatMap { (_, reportingMetricCalculationSpec) ->
reportingMetricCalculationSpec.metricCalculationSpecsList.flatMap { metricCalculationSpec ->
metricCalculationSpec.reportingMetricsList.map { reportingMetric ->
MetricKey(
principal.resourceKey.measurementConsumerId,
externalIdToApiId(reportingMetric.externalMetricId)
)
.toName()
}
}
}
val metrics: List<Metric> =
batchGetMetrics(principal.resourceKey.toName(), principal.config.apiKey, metricNames)

// Convert the internal report to public and return.
return convertInternalReportToPublic(internalReport, metrics)
}

private suspend fun batchGetMetrics(
parent: String,
apiAuthenticationKey: String,
metricNames: List<String>,
): List<Metric> {
val batchGetMetricsRequests = mutableListOf<BatchGetMetricsRequest>()

while (batchGetMetricsRequests.size * MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS < metricNames.size) {
val fromIndex = batchGetMetricsRequests.size * MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS
val toIndex = min(fromIndex + MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS, metricNames.size)

batchGetMetricsRequests += batchGetMetricsRequest {
this.parent = parent
names += metricNames.slice(fromIndex until toIndex)
}
}

return batchGetMetricsRequests.flatMap { batchGetMetricsRequest ->
try {
metricsStub
.withAuthenticationKey(apiAuthenticationKey)
.batchGetMetrics(batchGetMetricsRequest)
.metricsList
} catch (e: StatusException) {
throw when (e.status.code) {
Status.Code.INVALID_ARGUMENT -> Status.INVALID_ARGUMENT.withDescription(e.message)
Status.Code.PERMISSION_DENIED -> Status.PERMISSION_DENIED.withDescription(e.message)
else -> Status.UNKNOWN.withDescription("Unable to create metrics.")
}
.withCause(e)
.asRuntimeException()
}
}
}

override suspend fun createReport(request: CreateReportRequest): Report {
grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) {
"Parent is either unspecified or invalid."
Expand Down Expand Up @@ -252,10 +340,10 @@ class ReportsService(
.metricsList
} catch (e: StatusException) {
throw when (e.status.code) {
Status.Code.NOT_FOUND ->
Status.NOT_FOUND.withDescription("Reporting set used in the metric not found.")
Status.Code.FAILED_PRECONDITION ->
Status.FAILED_PRECONDITION.withDescription("Measurement Consumer not found.")
Status.Code.PERMISSION_DENIED -> Status.PERMISSION_DENIED.withDescription(e.message)
Status.Code.INVALID_ARGUMENT -> Status.INVALID_ARGUMENT.withDescription(e.message)
Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription(e.message)
Status.Code.FAILED_PRECONDITION -> Status.FAILED_PRECONDITION.withDescription(e.message)
else -> Status.UNKNOWN.withDescription("Unable to create metrics.")
}
.withCause(e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule
import org.wfanet.measurement.common.grpc.testing.mockService
import org.wfanet.measurement.common.identity.ExternalId
import org.wfanet.measurement.common.identity.apiIdToExternalId
import org.wfanet.measurement.common.identity.externalIdToApiId
import org.wfanet.measurement.common.testing.verifyProtoArgument
import org.wfanet.measurement.common.toProtoDuration
import org.wfanet.measurement.common.toProtoTime
Expand All @@ -67,6 +68,7 @@ import org.wfanet.measurement.internal.reporting.v2.report as internalReport
import org.wfanet.measurement.internal.reporting.v2.timeInterval as internalTimeInterval
import org.wfanet.measurement.internal.reporting.v2.timeIntervals as internalTimeIntervals
import org.wfanet.measurement.reporting.v2alpha.BatchCreateMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.BatchGetMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.Metric
import org.wfanet.measurement.reporting.v2alpha.MetricResultKt.reachResult
import org.wfanet.measurement.reporting.v2alpha.MetricSpec
Expand All @@ -79,9 +81,12 @@ import org.wfanet.measurement.reporting.v2alpha.ReportingSetKt
import org.wfanet.measurement.reporting.v2alpha.TimeInterval
import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.copy
import org.wfanet.measurement.reporting.v2alpha.createMetricRequest
import org.wfanet.measurement.reporting.v2alpha.createReportRequest
import org.wfanet.measurement.reporting.v2alpha.getReportRequest
import org.wfanet.measurement.reporting.v2alpha.metric
import org.wfanet.measurement.reporting.v2alpha.metricResult
import org.wfanet.measurement.reporting.v2alpha.metricSpec
Expand Down Expand Up @@ -119,6 +124,7 @@ private const val MAXIMUM_WATCH_DURATION_PER_USER = 4000
private const val DIFFERENTIAL_PRIVACY_DELTA = 1e-12

private const val MAX_BATCH_SIZE_FOR_BATCH_CREATE_METRICS = 1000
private const val MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS = 100

@RunWith(JUnit4::class)
class ReportsServiceTest {
Expand Down Expand Up @@ -2085,6 +2091,232 @@ class ReportsServiceTest {
assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT)
}

@Test
fun `getReport returns the report with SUCCEEDED when all metrics are SUCCEEDED`() = runBlocking {
whenever(
metricsMock.batchGetMetrics(
eq(
batchGetMetricsRequest {
parent = MEASUREMENT_CONSUMER_KEYS.first().toName()
names += SUCCEEDED_REACH_METRIC.name
}
)
)
)
.thenReturn(batchGetMetricsResponse { metrics += SUCCEEDED_REACH_METRIC })

val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val report =
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}

assertThat(report).isEqualTo(SUCCEEDED_REACH_REPORT)
}

@Test
fun `getReport returns the report with FAILED when any metric FAILED`() = runBlocking {
val failedReachMetric = RUNNING_REACH_METRIC.copy { state = Metric.State.FAILED }

whenever(
metricsMock.batchGetMetrics(
eq(
batchGetMetricsRequest {
parent = MEASUREMENT_CONSUMER_KEYS.first().toName()
names += failedReachMetric.name
}
)
)
)
.thenReturn(batchGetMetricsResponse { metrics += failedReachMetric })

val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val report =
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}

assertThat(report).isEqualTo(PENDING_REACH_REPORT.copy { state = Report.State.FAILED })
}

@Test
fun `getReport returns the report with RUNNING when metric is pending`(): Unit = runBlocking {
val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val report =
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}

assertThat(report).isEqualTo(PENDING_REACH_REPORT)
}

@Test
fun `getReport returns the report with RUNNING when there are more than max batch size metrics`():
Unit = runBlocking {
val startSec = 10L
val incrementSec = 1L
val intervalCount = MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS + 1

val endTimesList: List<Long> =
(startSec + incrementSec until startSec + incrementSec + intervalCount).toList()
val internalTimeIntervals: List<InternalTimeInterval> =
endTimesList.map { end ->
internalTimeInterval {
startTime = timestamp { seconds = startSec }
endTime = timestamp { seconds = end }
}
}

val reportingCreateMetricRequests =
internalTimeIntervals.map { timeInterval ->
buildInitialReportingMetric(
PRIMITIVE_REPORTING_SETS.first().externalId,
timeInterval,
INTERNAL_REACH_METRIC_SPEC,
listOf()
)
}

val internalPendingReport = internalReport {
cmmsMeasurementConsumerId = MEASUREMENT_CONSUMER_KEYS.first().measurementConsumerId

this.periodicTimeInterval = internalPeriodicTimeInterval {
startTime = timestamp { seconds = startSec }
increment = duration { seconds = incrementSec }
this.intervalCount = intervalCount
}

externalReportId = 330L
createTime = Instant.now().toProtoTime()

val updatedReportingCreateMetricRequests =
reportingCreateMetricRequests.mapIndexed { requestId, request ->
request.copy {
this.createMetricRequestId = requestId.toString()
externalMetricId = EXTERNAL_METRIC_ID_BASE + requestId
}
}

reportingMetricEntries.putAll(
buildInternalReportingMetricEntryWithOneMetricCalculationSpec(
PRIMITIVE_REPORTING_SETS.first().externalId,
updatedReportingCreateMetricRequests,
DISPLAY_NAME,
listOf(),
true
)
)
}

whenever(
internalReportsMock.getReport(
eq(
internalGetReportRequest {
cmmsMeasurementConsumerId = internalPendingReport.cmmsMeasurementConsumerId
externalReportId = internalPendingReport.externalReportId
}
)
)
)
.thenReturn(internalPendingReport)

whenever(metricsMock.batchGetMetrics(any()))
.thenReturn(
batchGetMetricsResponse {
metrics +=
endTimesList.mapIndexed { index, end ->
metric {
name =
MetricKey(
MEASUREMENT_CONSUMER_KEYS.first().measurementConsumerId,
ExternalId(EXTERNAL_METRIC_ID_BASE + index).apiId.value
)
.toName()
reportingSet = PRIMITIVE_REPORTING_SETS.first().name
timeInterval = timeInterval {
startTime = timestamp { seconds = startSec }
endTime = timestamp { seconds = end }
}
metricSpec = REACH_METRIC_SPEC
state = Metric.State.RUNNING
createTime = Instant.now().toProtoTime()
}
}
}
)

val request = getReportRequest { name = internalPendingReport.resourceName }

withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}

val batchGetMetricsCaptor: KArgumentCaptor<BatchGetMetricsRequest> = argumentCaptor()
verifyBlocking(metricsMock, times(2)) { batchGetMetrics(batchGetMetricsCaptor.capture()) }
}

@Test
fun `getReport throws INVALID_ARGUMENT when Report name is invalid`() {
val request = getReportRequest { name = "INVALID_REPORT_NAME" }

val exception =
assertFailsWith<StatusRuntimeException> {
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}
}

assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT)
}

@Test
fun `getReport throws PERMISSION_DENIED when Report name is not accessible`() {
val inaccessibleReportName =
ReportKey(MEASUREMENT_CONSUMER_KEYS.last().measurementConsumerId, externalIdToApiId(330L))
.toName()
val request = getReportRequest { name = inaccessibleReportName }

val exception =
assertFailsWith<StatusRuntimeException> {
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}
}

assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED)
}

@Test
fun `getReport throws PERMISSION_DENIED when MeasurementConsumer's identity does not match`() {
val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val exception =
assertFailsWith<StatusRuntimeException> {
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.last().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}
}

assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED)
}

@Test
fun `getReport throws UNAUTHENTICATED when the caller is not a MeasurementConsumer`() {
val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val exception =
assertFailsWith<StatusRuntimeException> {
withDataProviderPrincipal(DataProviderKey(ExternalId(550L).apiId.value).toName()) {
runBlocking { service.getReport(request) }
}
}

assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED)
}

companion object {
private fun buildInitialReportingMetric(
externalReportingSetId: Long,
Expand Down

0 comments on commit 972c43d

Please sign in to comment.