Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add getReport. #1029

Merged
merged 2 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,21 @@ import org.wfanet.measurement.internal.reporting.v2.createReportRequest as inter
import org.wfanet.measurement.internal.reporting.v2.getReportRequest as internalGetReportRequest
import org.wfanet.measurement.internal.reporting.v2.report as internalReport
import org.wfanet.measurement.reporting.v2alpha.BatchCreateMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.BatchGetMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.CreateMetricRequest
import org.wfanet.measurement.reporting.v2alpha.CreateReportRequest
import org.wfanet.measurement.reporting.v2alpha.GetReportRequest
import org.wfanet.measurement.reporting.v2alpha.Metric
import org.wfanet.measurement.reporting.v2alpha.MetricsGrpcKt.MetricsCoroutineStub
import org.wfanet.measurement.reporting.v2alpha.Report
import org.wfanet.measurement.reporting.v2alpha.ReportKt
import org.wfanet.measurement.reporting.v2alpha.ReportsGrpcKt.ReportsCoroutineImplBase
import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.report

private const val MAX_BATCH_SIZE_FOR_BATCH_CREATE_METRICS = 1000
private const val MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS = 100

private typealias InternalReportingMetricEntries =
Map<Long, InternalReport.ReportingMetricCalculationSpec>
Expand All @@ -71,6 +75,90 @@ class ReportsService(
val internalTimeRange: InternalTimeRange,
)

override suspend fun getReport(request: GetReportRequest): Report {
val reportKey =
grpcRequireNotNull(ReportKey.fromName(request.name)) {
"Report name is either unspecified or invalid"
}

val principal: ReportingPrincipal = principalFromCurrentContext
when (principal) {
is MeasurementConsumerPrincipal -> {
if (reportKey.cmmsMeasurementConsumerId != principal.resourceKey.measurementConsumerId) {
failGrpc(Status.PERMISSION_DENIED) {
"Cannot get Report belonging to other MeasurementConsumers."
}
}
}
}

val internalReport =
try {
internalReportsStub.getReport(
internalGetReportRequest {
cmmsMeasurementConsumerId = reportKey.cmmsMeasurementConsumerId
externalReportId = apiIdToExternalId(reportKey.reportId)
}
)
} catch (e: StatusException) {
throw Exception("Unable to get the report from the reporting database.", e)
}

// Create metrics.
val metricNames: List<String> =
internalReport.reportingMetricEntriesMap.flatMap { (_, reportingMetricCalculationSpec) ->
reportingMetricCalculationSpec.metricCalculationSpecsList.flatMap { metricCalculationSpec ->
metricCalculationSpec.reportingMetricsList.map { reportingMetric ->
MetricKey(
principal.resourceKey.measurementConsumerId,
externalIdToApiId(reportingMetric.externalMetricId)
)
.toName()
}
}
}
val metrics: List<Metric> =
batchGetMetrics(principal.resourceKey.toName(), principal.config.apiKey, metricNames)

// Convert the internal report to public and return.
return convertInternalReportToPublic(internalReport, metrics)
}

private suspend fun batchGetMetrics(
parent: String,
apiAuthenticationKey: String,
metricNames: List<String>,
): List<Metric> {
val batchGetMetricsRequests = mutableListOf<BatchGetMetricsRequest>()

while (batchGetMetricsRequests.size * MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS < metricNames.size) {
val fromIndex = batchGetMetricsRequests.size * MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS
val toIndex = min(fromIndex + MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS, metricNames.size)

batchGetMetricsRequests += batchGetMetricsRequest {
this.parent = parent
names += metricNames.slice(fromIndex until toIndex)
}
}

return batchGetMetricsRequests.flatMap { batchGetMetricsRequest ->
try {
metricsStub
.withAuthenticationKey(apiAuthenticationKey)
.batchGetMetrics(batchGetMetricsRequest)
.metricsList
} catch (e: StatusException) {
throw when (e.status.code) {
Status.Code.INVALID_ARGUMENT -> Status.INVALID_ARGUMENT.withDescription(e.message)
Status.Code.PERMISSION_DENIED -> Status.PERMISSION_DENIED.withDescription(e.message)
else -> Status.UNKNOWN.withDescription("Unable to create metrics.")
}
.withCause(e)
.asRuntimeException()
}
}
}

override suspend fun createReport(request: CreateReportRequest): Report {
grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) {
"Parent is either unspecified or invalid."
Expand Down Expand Up @@ -252,10 +340,10 @@ class ReportsService(
.metricsList
} catch (e: StatusException) {
throw when (e.status.code) {
Status.Code.NOT_FOUND ->
Status.NOT_FOUND.withDescription("Reporting set used in the metric not found.")
Status.Code.FAILED_PRECONDITION ->
Status.FAILED_PRECONDITION.withDescription("Measurement Consumer not found.")
Status.Code.PERMISSION_DENIED -> Status.PERMISSION_DENIED.withDescription(e.message)
Status.Code.INVALID_ARGUMENT -> Status.INVALID_ARGUMENT.withDescription(e.message)
Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription(e.message)
Status.Code.FAILED_PRECONDITION -> Status.FAILED_PRECONDITION.withDescription(e.message)
else -> Status.UNKNOWN.withDescription("Unable to create metrics.")
}
.withCause(e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule
import org.wfanet.measurement.common.grpc.testing.mockService
import org.wfanet.measurement.common.identity.ExternalId
import org.wfanet.measurement.common.identity.apiIdToExternalId
import org.wfanet.measurement.common.identity.externalIdToApiId
import org.wfanet.measurement.common.testing.verifyProtoArgument
import org.wfanet.measurement.common.toProtoDuration
import org.wfanet.measurement.common.toProtoTime
Expand All @@ -67,6 +68,7 @@ import org.wfanet.measurement.internal.reporting.v2.report as internalReport
import org.wfanet.measurement.internal.reporting.v2.timeInterval as internalTimeInterval
import org.wfanet.measurement.internal.reporting.v2.timeIntervals as internalTimeIntervals
import org.wfanet.measurement.reporting.v2alpha.BatchCreateMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.BatchGetMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.Metric
import org.wfanet.measurement.reporting.v2alpha.MetricResultKt.reachResult
import org.wfanet.measurement.reporting.v2alpha.MetricSpec
Expand All @@ -79,9 +81,12 @@ import org.wfanet.measurement.reporting.v2alpha.ReportingSetKt
import org.wfanet.measurement.reporting.v2alpha.TimeInterval
import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsRequest
import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsResponse
import org.wfanet.measurement.reporting.v2alpha.copy
import org.wfanet.measurement.reporting.v2alpha.createMetricRequest
import org.wfanet.measurement.reporting.v2alpha.createReportRequest
import org.wfanet.measurement.reporting.v2alpha.getReportRequest
import org.wfanet.measurement.reporting.v2alpha.metric
import org.wfanet.measurement.reporting.v2alpha.metricResult
import org.wfanet.measurement.reporting.v2alpha.metricSpec
Expand Down Expand Up @@ -119,6 +124,7 @@ private const val MAXIMUM_WATCH_DURATION_PER_USER = 4000
private const val DIFFERENTIAL_PRIVACY_DELTA = 1e-12

private const val MAX_BATCH_SIZE_FOR_BATCH_CREATE_METRICS = 1000
private const val MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS = 100

@RunWith(JUnit4::class)
class ReportsServiceTest {
Expand Down Expand Up @@ -2085,6 +2091,232 @@ class ReportsServiceTest {
assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT)
}

@Test
fun `getReport returns the report with SUCCEEDED when all metrics are SUCCEEDED`() = runBlocking {
whenever(
metricsMock.batchGetMetrics(
eq(
batchGetMetricsRequest {
parent = MEASUREMENT_CONSUMER_KEYS.first().toName()
names += SUCCEEDED_REACH_METRIC.name
}
)
)
)
.thenReturn(batchGetMetricsResponse { metrics += SUCCEEDED_REACH_METRIC })

val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val report =
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}

assertThat(report).isEqualTo(SUCCEEDED_REACH_REPORT)
}

@Test
fun `getReport returns the report with FAILED when any metric FAILED`() = runBlocking {
val failedReachMetric = RUNNING_REACH_METRIC.copy { state = Metric.State.FAILED }

whenever(
metricsMock.batchGetMetrics(
eq(
batchGetMetricsRequest {
parent = MEASUREMENT_CONSUMER_KEYS.first().toName()
names += failedReachMetric.name
}
)
)
)
.thenReturn(batchGetMetricsResponse { metrics += failedReachMetric })

val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val report =
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}

assertThat(report).isEqualTo(PENDING_REACH_REPORT.copy { state = Report.State.FAILED })
}

@Test
fun `getReport returns the report with RUNNING when metric is pending`(): Unit = runBlocking {
val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val report =
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}

assertThat(report).isEqualTo(PENDING_REACH_REPORT)
}

@Test
fun `getReport returns the report with RUNNING when there are more than max batch size metrics`():
Unit = runBlocking {
val startSec = 10L
val incrementSec = 1L
val intervalCount = MAX_BATCH_SIZE_FOR_BATCH_GET_METRICS + 1

val endTimesList: List<Long> =
(startSec + incrementSec until startSec + incrementSec + intervalCount).toList()
val internalTimeIntervals: List<InternalTimeInterval> =
endTimesList.map { end ->
internalTimeInterval {
startTime = timestamp { seconds = startSec }
endTime = timestamp { seconds = end }
}
}

val reportingCreateMetricRequests =
internalTimeIntervals.map { timeInterval ->
buildInitialReportingMetric(
PRIMITIVE_REPORTING_SETS.first().externalId,
timeInterval,
INTERNAL_REACH_METRIC_SPEC,
listOf()
)
}

val internalPendingReport = internalReport {
cmmsMeasurementConsumerId = MEASUREMENT_CONSUMER_KEYS.first().measurementConsumerId

this.periodicTimeInterval = internalPeriodicTimeInterval {
startTime = timestamp { seconds = startSec }
increment = duration { seconds = incrementSec }
this.intervalCount = intervalCount
}

externalReportId = 330L
createTime = Instant.now().toProtoTime()

val updatedReportingCreateMetricRequests =
reportingCreateMetricRequests.mapIndexed { requestId, request ->
request.copy {
this.createMetricRequestId = requestId.toString()
externalMetricId = EXTERNAL_METRIC_ID_BASE + requestId
}
}

reportingMetricEntries.putAll(
buildInternalReportingMetricEntryWithOneMetricCalculationSpec(
PRIMITIVE_REPORTING_SETS.first().externalId,
updatedReportingCreateMetricRequests,
DISPLAY_NAME,
listOf(),
true
)
)
}

whenever(
internalReportsMock.getReport(
eq(
internalGetReportRequest {
cmmsMeasurementConsumerId = internalPendingReport.cmmsMeasurementConsumerId
externalReportId = internalPendingReport.externalReportId
}
)
)
)
.thenReturn(internalPendingReport)

whenever(metricsMock.batchGetMetrics(any()))
.thenReturn(
batchGetMetricsResponse {
metrics +=
endTimesList.mapIndexed { index, end ->
metric {
name =
MetricKey(
MEASUREMENT_CONSUMER_KEYS.first().measurementConsumerId,
ExternalId(EXTERNAL_METRIC_ID_BASE + index).apiId.value
)
.toName()
reportingSet = PRIMITIVE_REPORTING_SETS.first().name
timeInterval = timeInterval {
startTime = timestamp { seconds = startSec }
endTime = timestamp { seconds = end }
}
metricSpec = REACH_METRIC_SPEC
state = Metric.State.RUNNING
createTime = Instant.now().toProtoTime()
}
}
}
)

val request = getReportRequest { name = internalPendingReport.resourceName }

withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}

val batchGetMetricsCaptor: KArgumentCaptor<BatchGetMetricsRequest> = argumentCaptor()
verifyBlocking(metricsMock, times(2)) { batchGetMetrics(batchGetMetricsCaptor.capture()) }
}

@Test
fun `getReport throws INVALID_ARGUMENT when Report name is invalid`() {
val request = getReportRequest { name = "INVALID_REPORT_NAME" }

val exception =
assertFailsWith<StatusRuntimeException> {
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}
}

assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT)
}

@Test
fun `getReport throws PERMISSION_DENIED when Report name is not accessible`() {
val inaccessibleReportName =
ReportKey(MEASUREMENT_CONSUMER_KEYS.last().measurementConsumerId, externalIdToApiId(330L))
.toName()
val request = getReportRequest { name = inaccessibleReportName }

val exception =
assertFailsWith<StatusRuntimeException> {
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}
}

assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED)
}

@Test
fun `getReport throws PERMISSION_DENIED when MeasurementConsumer's identity does not match`() {
val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val exception =
assertFailsWith<StatusRuntimeException> {
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.last().toName(), CONFIG) {
runBlocking { service.getReport(request) }
}
}

assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED)
}

@Test
fun `getReport throws UNAUTHENTICATED when the caller is not a MeasurementConsumer`() {
val request = getReportRequest { name = PENDING_REACH_REPORT.name }

val exception =
assertFailsWith<StatusRuntimeException> {
withDataProviderPrincipal(DataProviderKey(ExternalId(550L).apiId.value).toName()) {
runBlocking { service.getReport(request) }
}
}

assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED)
}

companion object {
private fun buildInitialReportingMetric(
externalReportingSetId: Long,
Expand Down