Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove non-template fields from our test templates. #1150

Merged
merged 1 commit into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@ package(
],
)

kt_jvm_library(
name = "labeled_event",
srcs = ["LabeledEvent.kt"],
deps = ["@wfa_common_jvm//imports/java/com/google/protobuf"],
)

kt_jvm_library(
name = "event_query",
srcs = ["EventQuery.kt"],
deps = [
":labeled_event",
"//imports/java/org/projectnessie/cel",
"//src/main/kotlin/org/wfanet/measurement/eventdataprovider/eventfiltration:event_filters",
"//src/main/proto/wfa/measurement/api/v2alpha:event_group_kt_jvm_proto",
Expand All @@ -30,6 +37,7 @@ kt_jvm_library(
srcs = ["InMemoryEventQuery.kt"],
deps = [
":event_query",
":labeled_event",
"//imports/java/org/projectnessie/cel",
"//src/main/kotlin/org/wfanet/measurement/eventdataprovider/eventfiltration:event_filters",
"//src/main/proto/wfa/measurement/api/v2alpha/event_templates/testing:test_event_kt_jvm_proto",
Expand All @@ -42,6 +50,7 @@ kt_jvm_library(
srcs = ["CsvEventQuery.kt"],
deps = [
":in_memory_event_query",
":labeled_event",
"//src/main/proto/wfa/measurement/api/v2alpha/event_templates/testing:test_event_kt_jvm_proto",
"@wfa_common_jvm//imports/java/com/opencsv",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common",
Expand All @@ -53,10 +62,12 @@ kt_jvm_library(
srcs = ["BigQueryEventQuery.kt"],
deps = [
":event_query",
":labeled_event",
"//src/main/kotlin/org/wfanet/measurement/eventdataprovider/eventfiltration:event_filters",
"//src/main/proto/wfa/measurement/api/v2alpha/event_templates/testing:test_event_kt_jvm_proto",
"@wfa_common_jvm//imports/java/com/google/cloud/bigquery",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/common",
],
)

Expand All @@ -65,6 +76,7 @@ kt_jvm_library(
srcs = ["SyntheticGeneratorEventQuery.kt"],
deps = [
":event_query",
":labeled_event",
":synthetic_data_generation",
"//src/main/kotlin/org/wfanet/measurement/eventdataprovider/eventfiltration:event_filters",
"//src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing:simulator_synthetic_data_spec_kt_jvm_proto",
Expand All @@ -79,6 +91,7 @@ kt_jvm_library(
"SyntheticDataGeneration.kt",
],
deps = [
":labeled_event",
"//src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing:simulator_synthetic_data_spec_kt_jvm_proto",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,19 @@ import org.wfanet.measurement.api.v2alpha.event_templates.testing.video
import org.wfanet.measurement.common.OpenEndTimeRange
import org.wfanet.measurement.common.toRange
import org.wfanet.measurement.eventdataprovider.eventfiltration.EventFilters
import org.wfanet.measurement.gcloud.common.toInstant

/** Fulfill the query by querying the specified BigQuery table. */
abstract class BigQueryEventQuery(
private val bigQuery: BigQuery,
private val datasetName: String,
private val tableName: String,
) : EventQuery {

) : EventQuery<TestEvent> {
protected abstract fun getPublisherId(eventGroup: EventGroup): Int

override fun getUserVirtualIds(eventGroupSpec: EventQuery.EventGroupSpec): Sequence<Long> {
override fun getLabeledEvents(
eventGroupSpec: EventQuery.EventGroupSpec
): Sequence<LabeledEvent<TestEvent>> {
val timeRange: OpenEndTimeRange = eventGroupSpec.spec.collectionInterval.toRange()
val queryConfig: QueryJobConfiguration =
buildQueryConfig(getPublisherId(eventGroupSpec.eventGroup), timeRange)
Expand All @@ -70,9 +72,8 @@ abstract class BigQueryEventQuery(
.getQueryResults()
.iterateAll()
.asSequence()
.map { it.toTestEvent() }
.filter { EventFilters.matches(it, program) }
.map { it.person.vid }
.map { it.toLabeledEvent() }
.filter { EventFilters.matches(it.message, program) }
}

/** Builds a query based on the parameters given. */
Expand Down Expand Up @@ -101,7 +102,7 @@ abstract class BigQueryEventQuery(
.build()
}

private fun FieldValueList.toTestEvent(): TestEvent {
private fun FieldValueList.toLabeledEvent(): LabeledEvent<TestEvent> {
val gender: Person.Gender? =
when (get("sex").stringValue) {
"M" -> Person.Gender.MALE
Expand All @@ -126,10 +127,8 @@ abstract class BigQueryEventQuery(
0L -> false
else -> true
}
return testEvent {
time = Timestamp.ofTimeMicroseconds(get("time").timestampValue).toProto()
val message = testEvent {
person = person {
vid = get("vid").longValue
if (gender != null) {
this.gender = gender
}
Expand All @@ -149,6 +148,11 @@ abstract class BigQueryEventQuery(
}
}
}
return LabeledEvent(
Timestamp.ofTimeMicroseconds(get("time").timestampValue).toInstant(),
get("vid").longValue,
message
)
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@ import java.time.format.DateTimeFormatter
import java.util.Locale
import java.util.logging.Logger
import org.wfanet.measurement.api.v2alpha.event_templates.testing.Person
import org.wfanet.measurement.api.v2alpha.event_templates.testing.TestEvent
import org.wfanet.measurement.api.v2alpha.event_templates.testing.person
import org.wfanet.measurement.api.v2alpha.event_templates.testing.testEvent
import org.wfanet.measurement.api.v2alpha.event_templates.testing.video
import org.wfanet.measurement.common.toProtoTime

private const val EDP_ID_INDEX = 0
private const val GENDER_INDEX = 2
Expand All @@ -47,7 +45,7 @@ class CsvEventQuery(publisherId: Int, file: File) :
private val dateFormatter = DateTimeFormatter.ofPattern("dd/MM/yyyy", Locale.UK)

@Throws(IOException::class)
private fun readCsvFile(publisherId: Int, file: File): List<LabelledEvent> {
private fun readCsvFile(publisherId: Int, file: File): List<LabeledTestEvent> {
logger.info("Reading data from CSV file: $file...")

return file.reader().use { fileReader ->
Expand All @@ -56,16 +54,15 @@ class CsvEventQuery(publisherId: Int, file: File) :
.iterator()
.asSequence()
.filter { row -> row[EDP_ID_INDEX].toInt() == publisherId }
.map { row ->
val event: TestEvent = parseTestEvent(row)
LabelledEvent(event.person.vid, event)
}
.map { row -> parseLabeledEvent(row) }
.toList()
}
}

private fun parseTestEvent(row: Array<String>): TestEvent {
private fun parseLabeledEvent(row: Array<String>): LabeledTestEvent {
val vid = row[VID_INDEX].toLong()
val timestamp =
LocalDate.parse(row[DATE_INDEX], dateFormatter).atStartOfDay().toInstant(ZoneOffset.UTC)
val gender: Person.Gender? =
when (row[GENDER_INDEX]) {
"M" -> Person.Gender.MALE
Expand All @@ -91,14 +88,8 @@ class CsvEventQuery(publisherId: Int, file: File) :
1 -> true
else -> null
}
return testEvent {
time =
LocalDate.parse(row[DATE_INDEX], dateFormatter)
.atStartOfDay()
.toInstant(ZoneOffset.UTC)
.toProtoTime()
val message = testEvent {
person = person {
this.vid = vid
if (gender != null) {
this.gender = gender
}
Expand All @@ -120,6 +111,8 @@ class CsvEventQuery(publisherId: Int, file: File) :
}
}
}

return LabeledTestEvent(timestamp, vid, message)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class EdpSimulator(
private val eventGroupMetadataDescriptorsStub: EventGroupMetadataDescriptorsCoroutineStub,
private val requisitionsStub: RequisitionsCoroutineStub,
private val requisitionFulfillmentStub: RequisitionFulfillmentCoroutineStub,
private val eventQuery: EventQuery,
private val eventQuery: EventQuery<Message>,
private val throttler: Throttler,
private val privacyBudgetManager: PrivacyBudgetManager,
private val trustedCertificates: Map<ByteString, X509Certificate>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ abstract class EdpSimulatorRunner() : Runnable {
protected lateinit var flags: EdpSimulatorFlags
private set

protected fun run(eventQuery: EventQuery, eventGroupMetadata: Message) {
protected fun run(eventQuery: EventQuery<Message>, eventGroupMetadata: Message) {
val clientCerts =
SigningCerts.fromPemFiles(
certificateFile = flags.tlsFlags.certFile,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package org.wfanet.measurement.loadtest.dataprovider

import com.google.protobuf.Descriptors
import com.google.protobuf.Message
import org.projectnessie.cel.Program
import org.projectnessie.cel.common.types.BoolT
import org.wfanet.measurement.api.v2alpha.EventGroup
Expand All @@ -23,7 +24,7 @@ import org.wfanet.measurement.api.v2alpha.RequisitionSpec.EventFilter
import org.wfanet.measurement.eventdataprovider.eventfiltration.EventFilters

/** A query to get the list of user virtual IDs for a particular requisition. */
interface EventQuery {
interface EventQuery<out T : Message> {
/**
* An [EventGroup] with the specification of events from it.
*
Expand All @@ -34,13 +35,18 @@ interface EventQuery {
val spec: RequisitionSpec.EventGroupEntry.Value
)

/** Returns a [Sequence] of [LabeledEvent]. */
fun getLabeledEvents(eventGroupSpec: EventGroupSpec): Sequence<LabeledEvent<out T>>

/**
* Returns a [Sequence] of virtual person IDs for matching events.
*
* Each element in the returned value represents a single event. As a result, the same VID may be
* returned multiple times.
*/
fun getUserVirtualIds(eventGroupSpec: EventGroupSpec): Sequence<Long>
fun getUserVirtualIds(eventGroupSpec: EventGroupSpec): Sequence<Long> {
return getLabeledEvents(eventGroupSpec).map { it.vid }
}

companion object {
private val TRUE_EVAL_RESULT = Program.newEvalResult(BoolT.True, null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,25 @@ package org.wfanet.measurement.loadtest.dataprovider

import org.projectnessie.cel.Program
import org.wfanet.measurement.api.v2alpha.event_templates.testing.TestEvent
import org.wfanet.measurement.api.v2alpha.event_templates.testing.timeOrNull
import org.wfanet.measurement.common.OpenEndTimeRange
import org.wfanet.measurement.common.toInstant
import org.wfanet.measurement.common.toRange
import org.wfanet.measurement.eventdataprovider.eventfiltration.EventFilters

/** Fulfills the query with matching events using filters. */
open class InMemoryEventQuery(private val events: Iterable<LabelledEvent>) : EventQuery {
constructor(
events: Map<Long, List<TestEvent>>
) : this(events.flatMap { (vid, events) -> events.map { event -> LabelledEvent(vid, event) } })

data class LabelledEvent(val vid: Long, val event: TestEvent)
typealias LabeledTestEvent = LabeledEvent<TestEvent>

override fun getUserVirtualIds(eventGroupSpec: EventQuery.EventGroupSpec): Sequence<Long> {
/** Fulfills the query with matching events using filters. */
open class InMemoryEventQuery(
private val labeledEvents: List<LabeledTestEvent>,
) : EventQuery<TestEvent> {
override fun getLabeledEvents(
eventGroupSpec: EventQuery.EventGroupSpec
): Sequence<LabeledTestEvent> {
val timeRange: OpenEndTimeRange = eventGroupSpec.spec.collectionInterval.toRange()
val program: Program =
EventQuery.compileProgram(eventGroupSpec.spec.filter, TestEvent.getDescriptor())

return events
.asSequence()
.filter { (_, event) ->
checkNotNull(event.timeOrNull).toInstant() in timeRange &&
EventFilters.matches(event, program)
}
.map { it.vid }
return labeledEvents.asSequence().filter {
it.timestamp in timeRange && EventFilters.matches(it.message, program)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright 2023 The Cross-Media Measurement Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.wfanet.measurement.loadtest.dataprovider

import com.google.protobuf.Message
import java.time.Instant

/** An event [message] with [timestamp] and [vid] labels. */
data class LabeledEvent<T : Message>(val timestamp: Instant, val vid: Long, val message: T)
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.wfanet.measurement.loadtest.dataprovider

import com.google.protobuf.Message
import org.wfanet.anysketch.AnySketch
import org.wfanet.anysketch.Sketch
import org.wfanet.anysketch.SketchConfig
Expand All @@ -31,7 +32,7 @@ import org.wfanet.measurement.api.v2alpha.MeasurementSpec
import org.wfanet.measurement.loadtest.config.VidSampling

class SketchGenerator(
private val eventQuery: EventQuery,
private val eventQuery: EventQuery<Message>,
private val sketchConfig: SketchConfig,
private val vidSamplingInterval: MeasurementSpec.VidSamplingInterval
) {
Expand Down
Loading