diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel index d15497f13d6..2860048d384 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/BUILD.bazel @@ -204,6 +204,7 @@ kt_jvm_library( srcs = ["ReportsService.kt"], deps = [ ":metadata_principal_server_interceptor", + "//imports/java/org/projectnessie/cel", "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", "//src/main/kotlin/org/wfanet/measurement/api:public_api_version", "//src/main/kotlin/org/wfanet/measurement/reporting/service/api:encryption_key_pair_store", diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt index 2d7fc1b1f64..4cf0543ced7 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt @@ -24,6 +24,12 @@ import kotlin.math.min import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.toList +import org.projectnessie.cel.Env +import org.projectnessie.cel.EnvOption +import org.projectnessie.cel.checker.Decls +import org.projectnessie.cel.common.types.Err +import org.projectnessie.cel.common.types.pb.ProtoTypeRegistry +import org.projectnessie.cel.common.types.ref.Val import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey import org.wfanet.measurement.common.base64UrlDecode import org.wfanet.measurement.common.base64UrlEncode @@ -156,9 +162,12 @@ class ReportsService( return listReportsResponse { reports += - subResults.map { internalReport -> - convertInternalReportToPublic(internalReport, externalIdToMetricMap) - } + filterReports( + subResults.map { internalReport -> + convertInternalReportToPublic(internalReport, externalIdToMetricMap) + }, + request.filter + ) if (nextPageToken != null) { this.nextPageToken = nextPageToken.toByteString().base64UrlEncode() @@ -676,8 +685,75 @@ class ReportsService( return result } + private fun filterReports(reports: List, filter: String): List { + if (filter.isEmpty()) { + return reports + } + + val astAndIssues = + try { + ENV.compile(filter) + } catch (_: NullPointerException) { + // NullPointerException is thrown when an operator in the filter is not a CEL operator. + throw Status.INVALID_ARGUMENT.withDescription("filter is not a valid CEL expression") + .asRuntimeException() + } + if (astAndIssues.hasIssues()) { + throw Status.INVALID_ARGUMENT.withDescription( + "filter is not a valid CEL expression: ${astAndIssues.issues}" + ) + .asRuntimeException() + } + val program = ENV.program(astAndIssues.ast) + + return reports.filter { report -> + val variables: Map = + mutableMapOf().apply { + for (fieldDescriptor in report.descriptorForType.fields) { + put(fieldDescriptor.name, report.getField(fieldDescriptor)) + } + } + val result: Val = program.eval(variables).`val` + if (result is Err) { + throw result.toRuntimeException() + } + + if (result.value() !is Boolean) { + throw Status.INVALID_ARGUMENT.withDescription("filter does not evaluate to boolean") + .asRuntimeException() + } + + result.booleanValue() + } + } + companion object { private val RESOURCE_ID_REGEX = Regex("^[a-z]([a-z0-9-]{0,61}[a-z0-9])?$") + private val ENV: Env = buildCelEnvironment() + + private fun buildCelEnvironment(): Env { + // Build CEL ProtoTypeRegistry. + val celTypeRegistry = ProtoTypeRegistry.newRegistry() + celTypeRegistry.registerMessage(Report.getDefaultInstance()) + + // Build CEL Env. + val reportDescriptor = Report.getDescriptor() + val env = + Env.newEnv( + EnvOption.container(reportDescriptor.fullName), + EnvOption.customTypeProvider(celTypeRegistry), + EnvOption.customTypeAdapter(celTypeRegistry), + EnvOption.declarations( + reportDescriptor.fields.map { + Decls.newVar( + it.name, + celTypeRegistry.findFieldType(reportDescriptor.fullName, it.name).type + ) + } + ) + ) + return env + } } } diff --git a/src/main/proto/wfa/measurement/reporting/v2alpha/reports_service.proto b/src/main/proto/wfa/measurement/reporting/v2alpha/reports_service.proto index 60b7b4907d8..5163ed8759b 100644 --- a/src/main/proto/wfa/measurement/reporting/v2alpha/reports_service.proto +++ b/src/main/proto/wfa/measurement/reporting/v2alpha/reports_service.proto @@ -86,6 +86,9 @@ message ListReportsRequest { // When paginating, all other parameters provided to `ListReports` must match // the call that provided the page token. string page_token = 3; + + // Result filter. Raw CEL expression that is applied to the message. + string filter = 4; } // Response message for `ListReports` method. diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsServiceTest.kt index ee6bbe2998f..c7f8e40d392 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsServiceTest.kt @@ -2880,6 +2880,36 @@ class ReportsServiceTest { assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) } + @Test + fun `listReports with a filter returns filtered results`() = runBlocking { + val pageSize = 2 + val request = listReportsRequest { + parent = MEASUREMENT_CONSUMER_KEYS.first().toName() + this.pageSize = pageSize + filter = "name != '${PENDING_WATCH_DURATION_REPORT.name}'" + } + + val result = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_KEYS.first().toName(), CONFIG) { + runBlocking { service.listReports(request) } + } + + val expected = listReportsResponse { reports.add(PENDING_REACH_REPORT) } + + assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) + + verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::streamReports) + .isEqualTo( + streamReportsRequest { + limit = pageSize + 1 + this.filter = + StreamReportsRequestKt.filter { + cmmsMeasurementConsumerId = MEASUREMENT_CONSUMER_KEYS.first().measurementConsumerId + } + } + ) + } + @Test fun `listReports throws UNAUTHENTICATED when no principal is found`() { val request = listReportsRequest { parent = MEASUREMENT_CONSUMER_KEYS.first().toName() } @@ -3358,6 +3388,7 @@ class ReportsServiceTest { PRIMITIVE_REPORTING_SETS.first().resourceId, listOf(INITIAL_REACH_REPORTING_METRIC), listOf(), + "reach-" ) private val INITIAL_WATCH_DURATION_REPORTING_METRIC = @@ -3382,6 +3413,7 @@ class ReportsServiceTest { PRIMITIVE_REPORTING_SETS.first().resourceId, listOf(INITIAL_WATCH_DURATION_REPORTING_METRIC), listOf(), + "duration-", metricIdBaseLong = WATCH_DURATION_METRIC_ID_BASE_LONG )