diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/SyntheticDataGeneration.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/SyntheticDataGeneration.kt index 9999422bcb9..67f31374d0c 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/SyntheticDataGeneration.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/SyntheticDataGeneration.kt @@ -19,9 +19,15 @@ package org.wfanet.measurement.loadtest.dataprovider import com.google.protobuf.Descriptors.FieldDescriptor import com.google.protobuf.Message import java.time.ZoneOffset +import java.util.Random +import kotlin.math.max +import kotlin.math.min +import kotlin.random.Random as KotlinRandom +import kotlin.random.asKotlinRandom import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.FieldValue import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.SimulatorSyntheticDataSpec import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.SyntheticEventGroupSpec +import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.SyntheticEventGroupSpec.FrequencySpec.VidRangeSpec import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.SyntheticPopulationSpec import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.SyntheticPopulationSpec.SubPopulation import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.VidRange @@ -46,14 +52,39 @@ object SyntheticDataGeneration { populationSpec: SyntheticPopulationSpec, syntheticEventGroupSpec: SyntheticEventGroupSpec, ): Sequence> { + var samplingRequired = false + val vidRangeSpecs = + syntheticEventGroupSpec.dateSpecsList + .flatMap { it.frequencySpecsList } + .flatMap { it.vidRangeSpecsList } + + for (vidRangeSpec in vidRangeSpecs) { + val vidRangeWidth = vidRangeSpec.vidRange.endExclusive - vidRangeSpec.vidRange.start + check(vidRangeWidth >= vidRangeSpec.sampleSize) { + "all vidRange widths should be larger than 
sampleSizes" + } + if (vidRangeSpec.sampleSize > 0) { + samplingRequired = true + } + } + + if (samplingRequired) { + check(syntheticEventGroupSpec.rngType == SyntheticEventGroupSpec.RngType.JAVA_UTIL_RANDOM) { + "Expecting JAVA_UTIL_RANDOM rng type, got ${syntheticEventGroupSpec.rngType}" + } + } + val subPopulations = populationSpec.subPopulationsList return sequence { for (dateSpec: SyntheticEventGroupSpec.DateSpec in syntheticEventGroupSpec.dateSpecsList) { val dateProgression = dateSpec.dateRange.toProgression() for (frequencySpec: SyntheticEventGroupSpec.FrequencySpec in dateSpec.frequencySpecsList) { - for (vidRangeSpec: SyntheticEventGroupSpec.FrequencySpec.VidRangeSpec in - frequencySpec.vidRangeSpecsList) { + + check(!frequencySpec.hasOverlaps()) { "The VID ranges should be non-overlapping." } + + for (vidRangeSpec: VidRangeSpec in frequencySpec.vidRangeSpecsList) { + val random = Random(vidRangeSpec.randomSeed).asKotlinRandom() val subPopulation: SubPopulation = vidRangeSpec.vidRange.findSubPopulation(subPopulations) ?: throw IllegalArgumentException() @@ -77,9 +108,10 @@ object SyntheticDataGeneration { @Suppress("UNCHECKED_CAST") // Safe per protobuf API. val message = builder.build() as T - for (vid in vidRangeSpec.vidRange.start until vidRangeSpec.vidRange.endExclusive) { - for (date in dateProgression) { - for (i in 0 until frequencySpec.frequency) { + val sampledVids = sampleVids(vidRangeSpec, random) + for (date in dateProgression) { + for (i in 0 until frequencySpec.frequency) { + for (vid in sampledVids) { yield(LabeledEvent(date.atStartOfDay().toInstant(ZoneOffset.UTC), vid, message)) } } @@ -90,6 +122,19 @@ object SyntheticDataGeneration { } } + /** + * Returns the VIDs sampled from [vidRangeSpec] using [random]. The sample is drawn once, so the + * returned sequence yields the same VIDs each time it is iterated. Returns all VIDs if sample size is 0. 
+ */ + private fun sampleVids(vidRangeSpec: VidRangeSpec, random: KotlinRandom): Sequence { + val vidRangeSequence = + (vidRangeSpec.vidRange.start until vidRangeSpec.vidRange.endExclusive).asSequence() + if (vidRangeSpec.sampleSize == 0) { + return vidRangeSequence + } + return vidRangeSequence.shuffled(random).take(vidRangeSpec.sampleSize).toList().asSequence() + } + /** * Returns the [SubPopulation] from a list of [SubPopulation] that contains the [VidRange] in its * range. @@ -154,3 +199,15 @@ object SyntheticDataGeneration { private fun SyntheticEventGroupSpec.DateSpec.DateRange.toProgression(): LocalDateProgression { return start.toLocalDate()..endExclusive.toLocalDate().minusDays(1) } + +// Sort the ranges by their start. If there are any consecutive ranges where +// the previous has a larger end than the latter's start, then there is an overlap. +private fun SyntheticEventGroupSpec.FrequencySpec.hasOverlaps() = + vidRangeSpecsList + .map { it.vidRange } + .sortedBy { it.start } + .zipWithNext() + .any { (first, second) -> first.overlaps(second) } + +private fun VidRange.overlaps(other: VidRange) = + max(start, other.start) < min(endExclusive, other.endExclusive) diff --git a/src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing/simulator_synthetic_data_spec.proto b/src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing/simulator_synthetic_data_spec.proto index a878c488935..df90eb0fe87 100644 --- a/src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing/simulator_synthetic_data_spec.proto +++ b/src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing/simulator_synthetic_data_spec.proto @@ -110,6 +110,14 @@ message SyntheticEventGroupSpec { // A map of `non_population_fields` from `SyntheticPopulationSpec` to // their values. map non_population_field_values = 2; + + // Number of vids sampled uniformly without replacement from vid_range. + // If this is 0, no sampling is done and all the vids in range are taken. 
+ int32 sample_size = 3; + + // Random seed to be fed into the random number generator to sample vids. + // Required if this VidRangeSpec specifies a sample_size. + int64 random_seed = 4; } // The VID ranges should be non-overlapping sub-ranges of SubPopulations. repeated VidRangeSpec vid_range_specs = 2; @@ -130,4 +138,16 @@ message SyntheticEventGroupSpec { } // `DateSpec`s should describe non-overlapping date ranges. repeated DateSpec date_specs = 2; + + // Type of random number generator to sample vids. + enum RngType { + // Default value used if the rng type is omitted. + RNG_TYPE_UNSPECIFIED = 0; + // Signals java.util.Random should be used for sampling. + JAVA_UTIL_RANDOM = 1; + } + + // Random Number Generator type for this `SyntheticEventGroupSpec`. + // Required if any VidRangeSpec specifies a sample_size. + RngType rng_type = 4; } diff --git a/src/test/kotlin/org/wfanet/measurement/loadtest/dataprovider/SyntheticDataGenerationTest.kt b/src/test/kotlin/org/wfanet/measurement/loadtest/dataprovider/SyntheticDataGenerationTest.kt index e516e848057..f4e85040fb9 100644 --- a/src/test/kotlin/org/wfanet/measurement/loadtest/dataprovider/SyntheticDataGenerationTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/loadtest/dataprovider/SyntheticDataGenerationTest.kt @@ -25,6 +25,7 @@ import kotlin.test.assertFailsWith import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 +import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.SyntheticEventGroupSpec import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.SyntheticEventGroupSpecKt import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.SyntheticPopulationSpecKt import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.fieldValue @@ -249,6 +250,478 @@ class SyntheticDataGenerationTest { assertThat(labeledEvents).containsExactlyElementsIn(expectedTestEvents) } + @Test + fun `generateEvents returns a sequence of sampled events 
when sample size specified`() { + + val sampleSizeForFreqOne = 2 + val firstsampleSizeForFreqTwo = 5 + val secondSampleSizeForFreqTwo = 10 + + val population = syntheticPopulationSpec { + vidRange = vidRange { + start = 0L + endExclusive = 100L + } + + populationFields += "person.gender" + populationFields += "person.age_group" + + nonPopulationFields += "banner_ad.viewable" + nonPopulationFields += "video_ad.viewed_fraction" + + subPopulations += + SyntheticPopulationSpecKt.subPopulation { + vidSubRange = vidRange { + start = 0L + endExclusive = 50L + } + + populationFieldsValues["person.gender"] = fieldValue { + enumValue = Person.Gender.MALE_VALUE + } + populationFieldsValues["person.age_group"] = fieldValue { + enumValue = Person.AgeGroup.YEARS_18_TO_34_VALUE + } + } + subPopulations += + SyntheticPopulationSpecKt.subPopulation { + vidSubRange = vidRange { + start = 50L + endExclusive = 100L + } + + populationFieldsValues["person.gender"] = fieldValue { + enumValue = Person.Gender.FEMALE_VALUE + } + populationFieldsValues["person.age_group"] = fieldValue { + enumValue = Person.AgeGroup.YEARS_18_TO_34_VALUE + } + } + } + val eventGroupSpec = syntheticEventGroupSpec { + description = "event group 1" + rngType = SyntheticEventGroupSpec.RngType.JAVA_UTIL_RANDOM + + dateSpecs += + SyntheticEventGroupSpecKt.dateSpec { + dateRange = + SyntheticEventGroupSpecKt.DateSpecKt.dateRange { + start = date { + year = 2023 + month = 6 + day = 27 + } + endExclusive = date { + year = 2023 + month = 6 + day = 28 + } + } + + frequencySpecs += + SyntheticEventGroupSpecKt.frequencySpec { + frequency = 2 + + vidRangeSpecs += + SyntheticEventGroupSpecKt.FrequencySpecKt.vidRangeSpec { + randomSeed = 42L + vidRange = vidRange { + start = 0L + endExclusive = 25L + } + + sampleSize = firstsampleSizeForFreqTwo + + nonPopulationFieldValues["banner_ad.viewable"] = fieldValue { boolValue = true } + nonPopulationFieldValues["video_ad.viewed_fraction"] = fieldValue { + doubleValue = 0.5 + } + } 
+ vidRangeSpecs += + SyntheticEventGroupSpecKt.FrequencySpecKt.vidRangeSpec { + randomSeed = 42L + vidRange = vidRange { + start = 25L + endExclusive = 50L + } + + sampleSize = secondSampleSizeForFreqTwo + + nonPopulationFieldValues["banner_ad.viewable"] = fieldValue { boolValue = false } + nonPopulationFieldValues["video_ad.viewed_fraction"] = fieldValue { + doubleValue = 0.7 + } + } + } + frequencySpecs += + SyntheticEventGroupSpecKt.frequencySpec { + frequency = 1 + + vidRangeSpecs += + SyntheticEventGroupSpecKt.FrequencySpecKt.vidRangeSpec { + randomSeed = 42L + vidRange = vidRange { + start = 50L + endExclusive = 75L + } + + sampleSize = sampleSizeForFreqOne + + nonPopulationFieldValues["banner_ad.viewable"] = fieldValue { boolValue = true } + nonPopulationFieldValues["video_ad.viewed_fraction"] = fieldValue { + doubleValue = 0.8 + } + } + } + } + } + + val labeledEvents: List> = + SyntheticDataGeneration.generateEvents( + TestEvent.getDefaultInstance(), + population, + eventGroupSpec, + ) + .toList() + val expectedNumberOfEvents = + sampleSizeForFreqOne + 2 * (firstsampleSizeForFreqTwo + secondSampleSizeForFreqTwo) + assertThat(labeledEvents.size).isEqualTo(expectedNumberOfEvents) + } + + @Test + fun `generateEvents throws IllegalStateException for sample size larger than vidRange`() { + + val population = syntheticPopulationSpec { + vidRange = vidRange { + start = 0L + endExclusive = 100L + } + + populationFields += "person.gender" + populationFields += "person.age_group" + + nonPopulationFields += "banner_ad.viewable" + nonPopulationFields += "video_ad.viewed_fraction" + + subPopulations += + SyntheticPopulationSpecKt.subPopulation { + vidSubRange = vidRange { + start = 0L + endExclusive = 50L + } + + populationFieldsValues["person.gender"] = fieldValue { + enumValue = Person.Gender.MALE_VALUE + } + populationFieldsValues["person.age_group"] = fieldValue { + enumValue = Person.AgeGroup.YEARS_18_TO_34_VALUE + } + } + subPopulations += + 
SyntheticPopulationSpecKt.subPopulation { + vidSubRange = vidRange { + start = 50L + endExclusive = 100L + } + + populationFieldsValues["person.gender"] = fieldValue { + enumValue = Person.Gender.FEMALE_VALUE + } + populationFieldsValues["person.age_group"] = fieldValue { + enumValue = Person.AgeGroup.YEARS_18_TO_34_VALUE + } + } + } + val eventGroupSpec = syntheticEventGroupSpec { + description = "event group 1" + rngType = SyntheticEventGroupSpec.RngType.JAVA_UTIL_RANDOM + + dateSpecs += + SyntheticEventGroupSpecKt.dateSpec { + dateRange = + SyntheticEventGroupSpecKt.DateSpecKt.dateRange { + start = date { + year = 2023 + month = 6 + day = 27 + } + endExclusive = date { + year = 2023 + month = 6 + day = 28 + } + } + + frequencySpecs += + SyntheticEventGroupSpecKt.frequencySpec { + frequency = 2 + + vidRangeSpecs += + SyntheticEventGroupSpecKt.FrequencySpecKt.vidRangeSpec { + randomSeed = 42L + vidRange = vidRange { + start = 0L + endExclusive = 25L + } + + sampleSize = 50 + + nonPopulationFieldValues["banner_ad.viewable"] = fieldValue { boolValue = true } + nonPopulationFieldValues["video_ad.viewed_fraction"] = fieldValue { + doubleValue = 0.5 + } + } + } + } + } + + assertFailsWith { + SyntheticDataGeneration.generateEvents( + TestEvent.getDefaultInstance(), + population, + eventGroupSpec, + ) + } + } + + @Test + fun `generateEvents throws IllegalStateException for RNG not specified when sampling enabled`() { + + val sampleSizeForFreqOne = 2 + val firstsampleSizeForFreqTwo = 5 + val secondSampleSizeForFreqTwo = 10 + + val population = syntheticPopulationSpec { + vidRange = vidRange { + start = 0L + endExclusive = 100L + } + + populationFields += "person.gender" + populationFields += "person.age_group" + + nonPopulationFields += "banner_ad.viewable" + nonPopulationFields += "video_ad.viewed_fraction" + + subPopulations += + SyntheticPopulationSpecKt.subPopulation { + vidSubRange = vidRange { + start = 0L + endExclusive = 50L + } + + 
populationFieldsValues["person.gender"] = fieldValue { + enumValue = Person.Gender.MALE_VALUE + } + populationFieldsValues["person.age_group"] = fieldValue { + enumValue = Person.AgeGroup.YEARS_18_TO_34_VALUE + } + } + subPopulations += + SyntheticPopulationSpecKt.subPopulation { + vidSubRange = vidRange { + start = 50L + endExclusive = 100L + } + + populationFieldsValues["person.gender"] = fieldValue { + enumValue = Person.Gender.FEMALE_VALUE + } + populationFieldsValues["person.age_group"] = fieldValue { + enumValue = Person.AgeGroup.YEARS_18_TO_34_VALUE + } + } + } + val eventGroupSpec = syntheticEventGroupSpec { + description = "event group 1" + + dateSpecs += + SyntheticEventGroupSpecKt.dateSpec { + dateRange = + SyntheticEventGroupSpecKt.DateSpecKt.dateRange { + start = date { + year = 2023 + month = 6 + day = 27 + } + endExclusive = date { + year = 2023 + month = 6 + day = 28 + } + } + + frequencySpecs += + SyntheticEventGroupSpecKt.frequencySpec { + frequency = 2 + + vidRangeSpecs += + SyntheticEventGroupSpecKt.FrequencySpecKt.vidRangeSpec { + randomSeed = 42L + vidRange = vidRange { + start = 0L + endExclusive = 25L + } + + sampleSize = firstsampleSizeForFreqTwo + + nonPopulationFieldValues["banner_ad.viewable"] = fieldValue { boolValue = true } + nonPopulationFieldValues["video_ad.viewed_fraction"] = fieldValue { + doubleValue = 0.5 + } + } + vidRangeSpecs += + SyntheticEventGroupSpecKt.FrequencySpecKt.vidRangeSpec { + randomSeed = 42L + vidRange = vidRange { + start = 25L + endExclusive = 50L + } + + sampleSize = secondSampleSizeForFreqTwo + + nonPopulationFieldValues["banner_ad.viewable"] = fieldValue { boolValue = false } + nonPopulationFieldValues["video_ad.viewed_fraction"] = fieldValue { + doubleValue = 0.7 + } + } + } + frequencySpecs += + SyntheticEventGroupSpecKt.frequencySpec { + frequency = 1 + + vidRangeSpecs += + SyntheticEventGroupSpecKt.FrequencySpecKt.vidRangeSpec { + randomSeed = 42L + vidRange = vidRange { + start = 50L + endExclusive = 
75L + } + + sampleSize = sampleSizeForFreqOne + + nonPopulationFieldValues["banner_ad.viewable"] = fieldValue { boolValue = true } + nonPopulationFieldValues["video_ad.viewed_fraction"] = fieldValue { + doubleValue = 0.8 + } + } + } + } + } + + assertFailsWith { + SyntheticDataGeneration.generateEvents( + TestEvent.getDefaultInstance(), + population, + eventGroupSpec, + ) + } + } + + @Test fun `generateEvents throws IllegalStateException when vid ranges overlap`() { + val population = syntheticPopulationSpec { + vidRange = vidRange { + start = 0L + endExclusive = 100L + } + + populationFields += "person.gender" + populationFields += "person.age_group" + + nonPopulationFields += "banner_ad.viewable" + nonPopulationFields += "video_ad.viewed_fraction" + + subPopulations += + SyntheticPopulationSpecKt.subPopulation { + vidSubRange = vidRange { + start = 0L + endExclusive = 50L + } + + populationFieldsValues["person.gender"] = fieldValue { + enumValue = Person.Gender.MALE_VALUE + } + populationFieldsValues["person.age_group"] = fieldValue { + enumValue = Person.AgeGroup.YEARS_18_TO_34_VALUE + } + } + subPopulations += + SyntheticPopulationSpecKt.subPopulation { + vidSubRange = vidRange { + start = 50L + endExclusive = 100L + } + + populationFieldsValues["person.gender"] = fieldValue { + enumValue = Person.Gender.FEMALE_VALUE + } + populationFieldsValues["person.age_group"] = fieldValue { + enumValue = Person.AgeGroup.YEARS_18_TO_34_VALUE + } + } + } + val eventGroupSpec = syntheticEventGroupSpec { + description = "event group 1" + + dateSpecs += + SyntheticEventGroupSpecKt.dateSpec { + dateRange = + SyntheticEventGroupSpecKt.DateSpecKt.dateRange { + start = date { + year = 2023 + month = 6 + day = 27 + } + endExclusive = date { + year = 2023 + month = 6 + day = 28 + } + } + + frequencySpecs += + SyntheticEventGroupSpecKt.frequencySpec { + frequency = 2 + + vidRangeSpecs += + SyntheticEventGroupSpecKt.FrequencySpecKt.vidRangeSpec { + vidRange = vidRange { + start = 0L + 
endExclusive = 25L + } + + nonPopulationFieldValues["banner_ad.viewable"] = fieldValue { boolValue = true } + nonPopulationFieldValues["video_ad.viewed_fraction"] = fieldValue { + doubleValue = 0.5 + } + } + vidRangeSpecs += + SyntheticEventGroupSpecKt.FrequencySpecKt.vidRangeSpec { + vidRange = vidRange { + // 20 is in between 0 and 25, the previous range. + start = 20L + endExclusive = 50L + } + + nonPopulationFieldValues["banner_ad.viewable"] = fieldValue { boolValue = false } + nonPopulationFieldValues["video_ad.viewed_fraction"] = fieldValue { + doubleValue = 0.7 + } + } + } + } + } + + assertFailsWith { + SyntheticDataGeneration.generateEvents( + TestEvent.getDefaultInstance(), + population, + eventGroupSpec, + ) + .toList() + } + } + @Test fun `generateEvents returns messages with a Duration field`() { val populationSpec = syntheticPopulationSpec {