From a70396a438f88c884bd010bfffa5b450bc9ff719 Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Thu, 22 Jun 2023 12:42:06 -0700 Subject: [PATCH 1/4] Update query in readModelRolloutData to return the ModelRollout for a given ModelLine with the latest RolloutPeriodStartTime in CreateModelRollout.kt. Update CreateModelRollout.kt to throw ModelRolloutInvalidArgsException if the RolloutPeriodStartTime of the ModelRollout to be created precedes that of the last ModelRollout. Create tests in ModelRolloutServiceTest.kt --- .../spanner/writers/CreateModelRollout.kt | 28 +++-- .../testing/ModelRolloutsServiceTest.kt | 110 +++++++++++++++++- 2 files changed, 126 insertions(+), 12 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRollout.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRollout.kt index 31f5a72e561..592799b23d9 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRollout.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRollout.kt @@ -19,14 +19,16 @@ package org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers import com.google.cloud.spanner.Statement import com.google.cloud.spanner.Struct import com.google.cloud.spanner.Value -import com.google.protobuf.Timestamp import com.google.protobuf.util.Timestamps import java.time.Clock +import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.flow.singleOrNull import org.wfanet.measurement.common.identity.ExternalId import org.wfanet.measurement.common.identity.InternalId +import org.wfanet.measurement.common.toInstant import org.wfanet.measurement.common.toProtoTime import org.wfanet.measurement.gcloud.common.toGcloudTimestamp +import org.wfanet.measurement.gcloud.common.toInstant import org.wfanet.measurement.gcloud.spanner.bind import org.wfanet.measurement.gcloud.spanner.bufferInsertMutation import org.wfanet.measurement.gcloud.spanner.set @@ -52,7 +54,6 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo "RolloutPeriodStartTime must be in the future." } } - if ( Timestamps.compare(modelRollout.rolloutPeriodStartTime, modelRollout.rolloutPeriodEndTime) > 0 ) { @@ -71,12 +72,25 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo ExternalId(modelRollout.externalModelProviderId), ExternalId(modelRollout.externalModelSuiteId), ExternalId(modelRollout.externalModelLineId), - modelRollout.rolloutPeriodStartTime ) } else { null } + val previousModelRolloutStartTime = + previousModelRolloutData?.getTimestamp("RolloutPeriodStartTime") + if ( + previousModelRolloutStartTime != null && + previousModelRolloutStartTime.toInstant() > modelRollout.rolloutPeriodStartTime.toInstant() + ) { + throw ModelRolloutInvalidArgsException( + ExternalId(modelRollout.externalModelProviderId), + ExternalId(modelRollout.externalModelSuiteId), + ExternalId(modelRollout.externalModelLineId) + ) { + "RolloutPeriodStartTime cannot precede that of previous ModelRollout." + } + } val modelLineData = readModelLineData( ExternalId(modelRollout.externalModelProviderId), @@ -170,13 +184,13 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo externalModelProviderId: ExternalId, externalModelSuiteId: ExternalId, externalModelLineId: ExternalId, - rolloutStartTime: Timestamp ): Struct? { val sql = """ SELECT ModelRollouts.ModelRolloutId, - ModelRollouts.ExternalModelRolloutId + ModelRollouts.ExternalModelRolloutId, + ModelRollouts.RolloutPeriodStartTime FROM ModelRollouts JOIN ModelLines USING (ModelProviderId, ModelSuiteId, ModelLineId) JOIN ModelSuites @@ -185,7 +199,6 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo WHERE ModelProviders.ExternalModelProviderId = @externalModelProviderId AND ModelSuites.ExternalModelSuiteId = @externalModelSuiteId AND ModelLines.ExternalModelLineId = @externalModelLineId - AND ModelRollouts.RolloutPeriodStartTime < @rolloutPeriodStartTime ORDER BY ModelRollouts.RolloutPeriodStartTime DESC """ .trimIndent() @@ -195,10 +208,9 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo bind("externalModelProviderId" to externalModelProviderId.value) bind("externalModelSuiteId" to externalModelSuiteId.value) bind("externalModelLineId" to externalModelLineId.value) - bind("rolloutPeriodStartTime" to rolloutStartTime.toGcloudTimestamp()) } - return transactionContext.executeQuery(statement).singleOrNull() + return transactionContext.executeQuery(statement).firstOrNull() } override fun ResultScope.buildResult(): ModelRollout { diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ModelRolloutsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ModelRolloutsServiceTest.kt index a2436cfd32d..f92ce70dafc 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ModelRolloutsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ModelRolloutsServiceTest.kt @@ -826,11 +826,11 @@ abstract class ModelRolloutsServiceTest { } val modelRollout1 = modelRolloutsService.createModelRollout( - modelRollout.copy { rolloutPeriodStartTime = Instant.now().plusSeconds(150L).toProtoTime() } + modelRollout.copy { rolloutPeriodStartTime = Instant.now().plusSeconds(50L).toProtoTime() } ) val modelRollout2 = modelRolloutsService.createModelRollout( - modelRollout.copy { rolloutPeriodStartTime = Instant.now().plusSeconds(50L).toProtoTime() } + modelRollout.copy { rolloutPeriodStartTime = Instant.now().plusSeconds(150L).toProtoTime() } ) val modelRollouts: List = @@ -848,7 +848,7 @@ abstract class ModelRolloutsServiceTest { .toList() assertThat(modelRollouts).hasSize(1) - assertThat(modelRollouts).contains(modelRollout2) + assertThat(modelRollouts).contains(modelRollout1) val modelRollouts2: List = modelRolloutsService @@ -871,7 +871,7 @@ abstract class ModelRolloutsServiceTest { .toList() assertThat(modelRollouts2).hasSize(1) - assertThat(modelRollouts2).contains(modelRollout1) + assertThat(modelRollouts2).contains(modelRollout2) } @Test @@ -1046,4 +1046,106 @@ abstract class ModelRolloutsServiceTest { assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) assertThat(exception).hasMessageThat().contains("Missing RolloutPeriod fields") } + + @Test + fun `createModelRollout succeeds with multiple model releases`() = runBlocking { + val modelLine = + population.createModelLine(modelProvidersService, modelSuitesService, modelLinesService) + val modelRelease = + population.createModelRelease( + modelSuite { + externalModelProviderId = modelLine.externalModelProviderId + externalModelSuiteId = modelLine.externalModelSuiteId + }, + modelReleasesService + ) + + val modelRollout = modelRollout { + externalModelProviderId = modelLine.externalModelProviderId + externalModelSuiteId = modelLine.externalModelSuiteId + externalModelLineId = modelLine.externalModelLineId + rolloutPeriodStartTime = Instant.now().plusSeconds(100L).toProtoTime() + rolloutPeriodEndTime = Instant.now().plusSeconds(100L).toProtoTime() + externalModelReleaseId = modelRelease.externalModelReleaseId + } + modelRolloutsService.createModelRollout(modelRollout) + + val modelRollout2 = modelRollout { + externalModelProviderId = modelLine.externalModelProviderId + externalModelSuiteId = modelLine.externalModelSuiteId + externalModelLineId = modelLine.externalModelLineId + rolloutPeriodStartTime = Instant.now().plusSeconds(200L).toProtoTime() + rolloutPeriodEndTime = Instant.now().plusSeconds(300L).toProtoTime() + externalModelReleaseId = modelRelease.externalModelReleaseId + } + val createdModelRollout2 = modelRolloutsService.createModelRollout(modelRollout2) + + val modelRollout3 = modelRollout { + externalModelProviderId = modelLine.externalModelProviderId + externalModelSuiteId = modelLine.externalModelSuiteId + externalModelLineId = modelLine.externalModelLineId + rolloutPeriodStartTime = Instant.now().plusSeconds(900L).toProtoTime() + rolloutPeriodEndTime = Instant.now().plusSeconds(1100L).toProtoTime() + externalModelReleaseId = modelRelease.externalModelReleaseId + } + val createdModelRollout3 = modelRolloutsService.createModelRollout(modelRollout3) + + assertThat(createdModelRollout3.externalPreviousModelRolloutId) + .isEqualTo(createdModelRollout2.externalModelRolloutId) + } + + @Test + fun `createModelRollout fails when new model rollout start time precedes that of previous model rollout`() = + runBlocking { + val modelLine = + population.createModelLine(modelProvidersService, modelSuitesService, modelLinesService) + val modelRelease = + population.createModelRelease( + modelSuite { + externalModelProviderId = modelLine.externalModelProviderId + externalModelSuiteId = modelLine.externalModelSuiteId + }, + modelReleasesService + ) + + val modelRollout = modelRollout { + externalModelProviderId = modelLine.externalModelProviderId + externalModelSuiteId = modelLine.externalModelSuiteId + externalModelLineId = modelLine.externalModelLineId + rolloutPeriodStartTime = Instant.now().plusSeconds(100L).toProtoTime() + rolloutPeriodEndTime = Instant.now().plusSeconds(100L).toProtoTime() + externalModelReleaseId = modelRelease.externalModelReleaseId + } + modelRolloutsService.createModelRollout(modelRollout) + + val modelRollout2 = modelRollout { + externalModelProviderId = modelLine.externalModelProviderId + externalModelSuiteId = modelLine.externalModelSuiteId + externalModelLineId = modelLine.externalModelLineId + rolloutPeriodStartTime = Instant.now().plusSeconds(300L).toProtoTime() + rolloutPeriodEndTime = Instant.now().plusSeconds(400L).toProtoTime() + externalModelReleaseId = modelRelease.externalModelReleaseId + } + + modelRolloutsService.createModelRollout(modelRollout2) + + val modelRollout3 = modelRollout { + externalModelProviderId = modelLine.externalModelProviderId + externalModelSuiteId = modelLine.externalModelSuiteId + externalModelLineId = modelLine.externalModelLineId + rolloutPeriodStartTime = Instant.now().plusSeconds(200L).toProtoTime() + rolloutPeriodEndTime = Instant.now().plusSeconds(300L).toProtoTime() + externalModelReleaseId = modelRelease.externalModelReleaseId + } + + val exception = + assertFailsWith { + modelRolloutsService.createModelRollout(modelRollout3) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) + assertThat(exception) + .hasMessageThat() + .contains("RolloutPeriodStartTime cannot precede that of previous ModelRollout.") + } } From 37edadf08bdb7009577350370d81be881beecaef Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Fri, 23 Jun 2023 12:42:30 -0700 Subject: [PATCH 2/4] Rename readModelRolloutData to readLatestModelRolloutData to more closely reflect what is done in the function. Create data type to be returned by readLatestModelRolloutData. Add comments to explain why the error is thrown. Add comments to API definition to surface this limitation. --- .../spanner/writers/CreateModelRollout.kt | 46 ++++++++++++------- .../kingdom/model_rollouts_service.proto | 3 ++ 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRollout.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRollout.kt index 592799b23d9..7839003ed06 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRollout.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRollout.kt @@ -16,6 +16,7 @@ package org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers +import com.google.cloud.Timestamp import com.google.cloud.spanner.Statement import com.google.cloud.spanner.Struct import com.google.cloud.spanner.Value @@ -43,6 +44,11 @@ import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.readers.ModelRelease class CreateModelRollout(private val modelRollout: ModelRollout, private val clock: Clock) : SpannerWriter() { + private data class LatestModelRolloutResult( + val modelRolloutId: Long, + val externalModelRolloutId: Long, + val rolloutPeriodStartTime: Timestamp? + ) override suspend fun TransactionScope.runTransaction(): ModelRollout { val now = clock.instant().toProtoTime() if (Timestamps.compare(now, modelRollout.rolloutPeriodStartTime) >= 0) { @@ -66,9 +72,9 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo } } - val previousModelRolloutData = + val latestModelRolloutData = if (modelRollout.rolloutPeriodStartTime != modelRollout.rolloutPeriodEndTime) { - readModelRolloutData( + readLatestModelRolloutData( ExternalId(modelRollout.externalModelProviderId), ExternalId(modelRollout.externalModelSuiteId), ExternalId(modelRollout.externalModelLineId), @@ -77,11 +83,13 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo null } - val previousModelRolloutStartTime = - previousModelRolloutData?.getTimestamp("RolloutPeriodStartTime") + val latestModelRolloutStartTime = latestModelRolloutData?.rolloutPeriodStartTime + + // New ModelRollout cannot have a RolloutStartTime that precedes the RolloutStartTime of the + // current PreviousModelRollout if ( - previousModelRolloutStartTime != null && - previousModelRolloutStartTime.toInstant() > modelRollout.rolloutPeriodStartTime.toInstant() + latestModelRolloutStartTime != null && + latestModelRolloutStartTime.toInstant() > modelRollout.rolloutPeriodStartTime.toInstant() ) { throw ModelRolloutInvalidArgsException( ExternalId(modelRollout.externalModelProviderId), @@ -128,10 +136,8 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo set("ExternalModelRolloutId" to externalModelRolloutId) set("RolloutPeriodStartTime" to modelRollout.rolloutPeriodStartTime.toGcloudTimestamp()) set("RolloutPeriodEndTime" to modelRollout.rolloutPeriodEndTime.toGcloudTimestamp()) - if (previousModelRolloutData != null) { - set( - "PreviousModelRolloutId" to InternalId(previousModelRolloutData.getLong("ModelRolloutId")) - ) + if (latestModelRolloutData != null) { + set("PreviousModelRolloutId" to InternalId(latestModelRolloutData.modelRolloutId)) } set("ModelReleaseId" to modelReleaseResult.modelReleaseId) set("CreateTime" to Value.COMMIT_TIMESTAMP) @@ -140,9 +146,8 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo return modelRollout.copy { this.externalModelRolloutId = externalModelRolloutId.value - if (previousModelRolloutData != null) { - this.externalPreviousModelRolloutId = - previousModelRolloutData.getLong("ExternalModelRolloutId") + if (latestModelRolloutData != null) { + this.externalPreviousModelRolloutId = latestModelRolloutData.externalModelRolloutId } } } @@ -180,11 +185,12 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo return transactionContext.executeQuery(statement).singleOrNull() } - private suspend fun TransactionScope.readModelRolloutData( + // Reads the ModelRollout for a given ModelLine with the latest RolloutPeriodStartTime. + private suspend fun TransactionScope.readLatestModelRolloutData( externalModelProviderId: ExternalId, externalModelSuiteId: ExternalId, externalModelLineId: ExternalId, - ): Struct? { + ): LatestModelRolloutResult? { val sql = """ SELECT @@ -210,7 +216,15 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo bind("externalModelLineId" to externalModelLineId.value) } - return transactionContext.executeQuery(statement).firstOrNull() + val result = transactionContext.executeQuery(statement).firstOrNull() + + return if (result == null) null + else + LatestModelRolloutResult( + result.getLong("ModelRolloutId"), + result.getLong("ExternalModelRolloutId"), + result.getTimestamp("RolloutPeriodStartTime") + ) } override fun ResultScope.buildResult(): ModelRollout { diff --git a/src/main/proto/wfa/measurement/internal/kingdom/model_rollouts_service.proto b/src/main/proto/wfa/measurement/internal/kingdom/model_rollouts_service.proto index 5a1b31f56f1..c0aba6c0f66 100644 --- a/src/main/proto/wfa/measurement/internal/kingdom/model_rollouts_service.proto +++ b/src/main/proto/wfa/measurement/internal/kingdom/model_rollouts_service.proto @@ -24,7 +24,10 @@ option java_multiple_files = true; // Internal service for persistence of ModelRollout entities. service ModelRollouts { + // rollout_period_start_time of given `ModelRollout` must be later than that + // of the current previous `ModelRollout` if it is set. rpc CreateModelRollout(ModelRollout) returns (ModelRollout); + // Streams `ModelRollout`s. rpc StreamModelRollouts(StreamModelRolloutsRequest) returns (stream ModelRollout); From c5a8c664cb8e81b216dc4db6943a064d7c6a994d Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Mon, 26 Jun 2023 08:47:20 -0700 Subject: [PATCH 3/4] Update name of test to reflect that it is testing createModelRollout with multiple model rollouts. Remove comment in API definition(will create another PR in public API). --- .../service/internal/testing/ModelRolloutsServiceTest.kt | 2 +- .../measurement/internal/kingdom/model_rollouts_service.proto | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ModelRolloutsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ModelRolloutsServiceTest.kt index f92ce70dafc..f164d80544f 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ModelRolloutsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ModelRolloutsServiceTest.kt @@ -1048,7 +1048,7 @@ abstract class ModelRolloutsServiceTest { } @Test - fun `createModelRollout succeeds with multiple model releases`() = runBlocking { + fun `createModelRollout succeeds with multiple model rollouts`() = runBlocking { val modelLine = population.createModelLine(modelProvidersService, modelSuitesService, modelLinesService) val modelRelease = diff --git a/src/main/proto/wfa/measurement/internal/kingdom/model_rollouts_service.proto b/src/main/proto/wfa/measurement/internal/kingdom/model_rollouts_service.proto index c0aba6c0f66..2be6905564a 100644 --- a/src/main/proto/wfa/measurement/internal/kingdom/model_rollouts_service.proto +++ b/src/main/proto/wfa/measurement/internal/kingdom/model_rollouts_service.proto @@ -24,8 +24,6 @@ option java_multiple_files = true; // Internal service for persistence of ModelRollout entities. service ModelRollouts { - // rollout_period_start_time of given `ModelRollout` must be later than that - // of the current previous `ModelRollout` if it is set. rpc CreateModelRollout(ModelRollout) returns (ModelRollout); // Streams `ModelRollout`s. From 9570fd33b604610669c62e8968d8454088cb73b0 Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Mon, 26 Jun 2023 12:56:07 -0700 Subject: [PATCH 4/4] Remove logic that sets an instant(rollout start time === rollout end time) ModelRollout's previous rollout to null. This will be treated as regular ModelRollout. Update test to reflect this change. --- .../gcloud/spanner/writers/CreateModelRollout.kt | 14 +++++--------- .../internal/testing/ModelRolloutsServiceTest.kt | 5 +++-- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRollout.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRollout.kt index 7839003ed06..feddeabc964 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRollout.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRollout.kt @@ -73,15 +73,11 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo } val latestModelRolloutData = - if (modelRollout.rolloutPeriodStartTime != modelRollout.rolloutPeriodEndTime) { - readLatestModelRolloutData( - ExternalId(modelRollout.externalModelProviderId), - ExternalId(modelRollout.externalModelSuiteId), - ExternalId(modelRollout.externalModelLineId), - ) - } else { - null - } + readLatestModelRolloutData( + ExternalId(modelRollout.externalModelProviderId), + ExternalId(modelRollout.externalModelSuiteId), + ExternalId(modelRollout.externalModelLineId), + ) val latestModelRolloutStartTime = latestModelRolloutData?.rolloutPeriodStartTime diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ModelRolloutsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ModelRolloutsServiceTest.kt index f164d80544f..a00529fe93c 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ModelRolloutsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/ModelRolloutsServiceTest.kt @@ -199,7 +199,7 @@ abstract class ModelRolloutsServiceTest { } @Test - fun `createModelRollout correctly leave previous model rollout unset when rollout start time is equal to rollout end time`() = + fun `createModelRollout correctly sets previous model rollout when rollout start time is equal to rollout end time`() = runBlocking { val modelLine = population.createModelLine(modelProvidersService, modelSuitesService, modelLinesService) @@ -234,7 +234,8 @@ abstract class ModelRolloutsServiceTest { val createdModelRollout2 = modelRolloutsService.createModelRollout(modelRollout2) - assertThat(createdModelRollout2.externalPreviousModelRolloutId).isEqualTo(0L) + assertThat(createdModelRollout2.externalPreviousModelRolloutId) + .isEqualTo(createdModelRollout.externalModelRolloutId) } @Test