From 23211c77941a5d4b16f6f6aca8b38450764ab35d Mon Sep 17 00:00:00 2001 From: Craig Wright Date: Fri, 20 Sep 2024 14:22:18 -0600 Subject: [PATCH] feat: Add incrementBy to FrequencyVectorBuilder (#1817) --- .../v2alpha/FrequencyVectorBuilder.kt | 25 ++++++++++++++++--- .../v2alpha/FrequencyVectorBuilderTest.kt | 13 +++++++--- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/eventdataprovider/shareshuffle/v2alpha/FrequencyVectorBuilder.kt b/src/main/kotlin/org/wfanet/measurement/eventdataprovider/shareshuffle/v2alpha/FrequencyVectorBuilder.kt index ec7c78bf4d8..9c6a0ec8841 100644 --- a/src/main/kotlin/org/wfanet/measurement/eventdataprovider/shareshuffle/v2alpha/FrequencyVectorBuilder.kt +++ b/src/main/kotlin/org/wfanet/measurement/eventdataprovider/shareshuffle/v2alpha/FrequencyVectorBuilder.kt @@ -172,6 +172,17 @@ class FrequencyVectorBuilder( * supported by this builder */ fun increment(globalIndex: Int) { + incrementBy(globalIndex, 1) + } + + /** + * Increment the frequency vector for the VID at globalIndex by amount. + * + * See [increment] for additional information. + */ + fun incrementBy(globalIndex: Int, amount: Int) { + require(amount > 0) { "amount must be > 0 got ${amount}" } + if (!(globalIndex in primaryRange || globalIndex in wrappedRange)) { if (strict) { require(globalIndex in primaryRange || globalIndex in wrappedRange) { @@ -188,15 +199,23 @@ class FrequencyVectorBuilder( } else { primaryRange.count() + globalIndex } - frequencyData[localIndex] = minOf(frequencyData[localIndex] + 1, maxFrequency) + frequencyData[localIndex] = minOf(frequencyData[localIndex] + amount, maxFrequency) } /** * Add each globalIndex in the input Collection to the [FrequencyVector] according to the criteria - * described by [addVid] + * described by [increment] */ fun incrementAll(globalIndexes: Collection) { - globalIndexes.map { increment(it) } + globalIndexes.map { incrementBy(it, 1) } + } + + /** + * Add each globalIndex in the input Collection to the [FrequencyVector] according to the criteria + * described by [incrementBy] + */ + fun incrementAllBy(globalIndexes: Collection, amount: Int) { + globalIndexes.map { incrementBy(it, amount) } } /** diff --git a/src/test/kotlin/org/wfanet/measurement/eventdataprovider/shareshuffle/v2alpha/FrequencyVectorBuilderTest.kt b/src/test/kotlin/org/wfanet/measurement/eventdataprovider/shareshuffle/v2alpha/FrequencyVectorBuilderTest.kt index acc9497dd72..fa1771eec1a 100644 --- a/src/test/kotlin/org/wfanet/measurement/eventdataprovider/shareshuffle/v2alpha/FrequencyVectorBuilderTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/eventdataprovider/shareshuffle/v2alpha/FrequencyVectorBuilderTest.kt @@ -227,11 +227,18 @@ class FrequencyVectorBuilderTest { assertThat(builder.build()).isEqualTo(frequencyVector { data += listOf(0, 0, 0, 0, 0) }) } + @Test + fun `incrementBy fails when amount is 0`() { + val builder = + FrequencyVectorBuilder(SMALL_POPULATION_SPEC, PARTIAL_NON_WRAPPING_REACH_MEASUREMENT_SPEC) + assertFailsWith("expected exception") { builder.incrementBy(5, 0) } + } + @Test fun `build returns a frequency vector for frequency over full interval`() { val frequencyMeasurementSpec = measurementSpec { vidSamplingInterval = FULL_SAMPLING_INTERVAL - reachAndFrequency = reachAndFrequency { maximumFrequency = 2 } + reachAndFrequency = reachAndFrequency { maximumFrequency = 3 } } val frequencyVector = @@ -245,11 +252,11 @@ class FrequencyVectorBuilderTest { STARTING_VID + 8, STARTING_VID + 8, ) - .map { increment(SMALL_POPULATION_VID_INDEX_MAP[it]) } + .map { incrementBy(SMALL_POPULATION_VID_INDEX_MAP[it], 2) } } assertThat(frequencyVector) - .isEqualTo(frequencyVector { data += listOf(2, 1, 0, 0, 0, 0, 0, 0, 2, 0) }) + .isEqualTo(frequencyVector { data += listOf(3, 2, 0, 0, 0, 0, 0, 0, 3, 0) }) } @Test