Skip to content

Commit

Permalink
Fix that only allows creation of ModelRollouts with RolloutPeriodStar…
Browse files Browse the repository at this point in the history
…tTime after that of previous ModelRollout (#1076)

Update query in readModelRolloutData to return the ModelRollout for a
given ModelLine with the latest RolloutPeriodStartTime in
CreateModelRollout.kt. Update CreateModelRollout.kt to throw
ModelRolloutInvalidArgsException if the RolloutPeriodStartTime of the
ModelRollout to be created precedes that of the last ModelRollout.
Create tests in ModelRolloutServiceTest.kt. Update 'createModelRollout
succeeds with multiple model releases' test to adhere to new
RolloutPeriodStartTime limitations.
  • Loading branch information
jojijac0b authored Jul 7, 2023
1 parent 8ffb126 commit ed63c3a
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,20 @@

package org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers

import com.google.cloud.Timestamp
import com.google.cloud.spanner.Statement
import com.google.cloud.spanner.Struct
import com.google.cloud.spanner.Value
import com.google.protobuf.Timestamp
import com.google.protobuf.util.Timestamps
import java.time.Clock
import kotlinx.coroutines.flow.firstOrNull
import kotlinx.coroutines.flow.singleOrNull
import org.wfanet.measurement.common.identity.ExternalId
import org.wfanet.measurement.common.identity.InternalId
import org.wfanet.measurement.common.toInstant
import org.wfanet.measurement.common.toProtoTime
import org.wfanet.measurement.gcloud.common.toGcloudTimestamp
import org.wfanet.measurement.gcloud.common.toInstant
import org.wfanet.measurement.gcloud.spanner.bind
import org.wfanet.measurement.gcloud.spanner.bufferInsertMutation
import org.wfanet.measurement.gcloud.spanner.set
Expand All @@ -41,6 +44,11 @@ import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.readers.ModelRelease
class CreateModelRollout(private val modelRollout: ModelRollout, private val clock: Clock) :
SpannerWriter<ModelRollout, ModelRollout>() {

private data class LatestModelRolloutResult(
val modelRolloutId: Long,
val externalModelRolloutId: Long,
val rolloutPeriodStartTime: Timestamp?
)
override suspend fun TransactionScope.runTransaction(): ModelRollout {
val now = clock.instant().toProtoTime()
if (Timestamps.compare(now, modelRollout.rolloutPeriodStartTime) >= 0) {
Expand All @@ -52,7 +60,6 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo
"RolloutPeriodStartTime must be in the future."
}
}

if (
Timestamps.compare(modelRollout.rolloutPeriodStartTime, modelRollout.rolloutPeriodEndTime) > 0
) {
Expand All @@ -65,18 +72,29 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo
}
}

val previousModelRolloutData =
if (modelRollout.rolloutPeriodStartTime != modelRollout.rolloutPeriodEndTime) {
readModelRolloutData(
ExternalId(modelRollout.externalModelProviderId),
ExternalId(modelRollout.externalModelSuiteId),
ExternalId(modelRollout.externalModelLineId),
modelRollout.rolloutPeriodStartTime
)
} else {
null
}
val latestModelRolloutData =
readLatestModelRolloutData(
ExternalId(modelRollout.externalModelProviderId),
ExternalId(modelRollout.externalModelSuiteId),
ExternalId(modelRollout.externalModelLineId),
)

val latestModelRolloutStartTime = latestModelRolloutData?.rolloutPeriodStartTime

// New ModelRollout cannot have a RolloutStartTime that precedes the RolloutStartTime of the
// current PreviousModelRollout
if (
latestModelRolloutStartTime != null &&
latestModelRolloutStartTime.toInstant() > modelRollout.rolloutPeriodStartTime.toInstant()
) {
throw ModelRolloutInvalidArgsException(
ExternalId(modelRollout.externalModelProviderId),
ExternalId(modelRollout.externalModelSuiteId),
ExternalId(modelRollout.externalModelLineId)
) {
"RolloutPeriodStartTime cannot precede that of previous ModelRollout."
}
}
val modelLineData =
readModelLineData(
ExternalId(modelRollout.externalModelProviderId),
Expand Down Expand Up @@ -114,10 +132,8 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo
set("ExternalModelRolloutId" to externalModelRolloutId)
set("RolloutPeriodStartTime" to modelRollout.rolloutPeriodStartTime.toGcloudTimestamp())
set("RolloutPeriodEndTime" to modelRollout.rolloutPeriodEndTime.toGcloudTimestamp())
if (previousModelRolloutData != null) {
set(
"PreviousModelRolloutId" to InternalId(previousModelRolloutData.getLong("ModelRolloutId"))
)
if (latestModelRolloutData != null) {
set("PreviousModelRolloutId" to InternalId(latestModelRolloutData.modelRolloutId))
}
set("ModelReleaseId" to modelReleaseResult.modelReleaseId)
set("CreateTime" to Value.COMMIT_TIMESTAMP)
Expand All @@ -126,9 +142,8 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo

return modelRollout.copy {
this.externalModelRolloutId = externalModelRolloutId.value
if (previousModelRolloutData != null) {
this.externalPreviousModelRolloutId =
previousModelRolloutData.getLong("ExternalModelRolloutId")
if (latestModelRolloutData != null) {
this.externalPreviousModelRolloutId = latestModelRolloutData.externalModelRolloutId
}
}
}
Expand Down Expand Up @@ -166,17 +181,18 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo
return transactionContext.executeQuery(statement).singleOrNull()
}

private suspend fun TransactionScope.readModelRolloutData(
// Reads the ModelRollout for a given ModelLine with the latest RolloutPeriodStartTime.
private suspend fun TransactionScope.readLatestModelRolloutData(
externalModelProviderId: ExternalId,
externalModelSuiteId: ExternalId,
externalModelLineId: ExternalId,
rolloutStartTime: Timestamp
): Struct? {
): LatestModelRolloutResult? {
val sql =
"""
SELECT
ModelRollouts.ModelRolloutId,
ModelRollouts.ExternalModelRolloutId
ModelRollouts.ExternalModelRolloutId,
ModelRollouts.RolloutPeriodStartTime
FROM ModelRollouts JOIN ModelLines
USING (ModelProviderId, ModelSuiteId, ModelLineId)
JOIN ModelSuites
Expand All @@ -185,7 +201,6 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo
WHERE ModelProviders.ExternalModelProviderId = @externalModelProviderId AND
ModelSuites.ExternalModelSuiteId = @externalModelSuiteId AND
ModelLines.ExternalModelLineId = @externalModelLineId
AND ModelRollouts.RolloutPeriodStartTime < @rolloutPeriodStartTime
ORDER BY ModelRollouts.RolloutPeriodStartTime DESC
"""
.trimIndent()
Expand All @@ -195,10 +210,17 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo
bind("externalModelProviderId" to externalModelProviderId.value)
bind("externalModelSuiteId" to externalModelSuiteId.value)
bind("externalModelLineId" to externalModelLineId.value)
bind("rolloutPeriodStartTime" to rolloutStartTime.toGcloudTimestamp())
}

return transactionContext.executeQuery(statement).singleOrNull()
val result = transactionContext.executeQuery(statement).firstOrNull()

return if (result == null) null
else
LatestModelRolloutResult(
result.getLong("ModelRolloutId"),
result.getLong("ExternalModelRolloutId"),
result.getTimestamp("RolloutPeriodStartTime")
)
}

override fun ResultScope<ModelRollout>.buildResult(): ModelRollout {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ abstract class ModelRolloutsServiceTest<T : ModelRolloutsCoroutineImplBase> {
}

@Test
fun `createModelRollout correctly leave previous model rollout unset when rollout start time is equal to rollout end time`() =
fun `createModelRollout correctly sets previous model rollout when rollout start time is equal to rollout end time`() =
runBlocking {
val modelLine =
population.createModelLine(modelProvidersService, modelSuitesService, modelLinesService)
Expand Down Expand Up @@ -234,7 +234,8 @@ abstract class ModelRolloutsServiceTest<T : ModelRolloutsCoroutineImplBase> {

val createdModelRollout2 = modelRolloutsService.createModelRollout(modelRollout2)

assertThat(createdModelRollout2.externalPreviousModelRolloutId).isEqualTo(0L)
assertThat(createdModelRollout2.externalPreviousModelRolloutId)
.isEqualTo(createdModelRollout.externalModelRolloutId)
}

@Test
Expand Down Expand Up @@ -826,11 +827,11 @@ abstract class ModelRolloutsServiceTest<T : ModelRolloutsCoroutineImplBase> {
}
val modelRollout1 =
modelRolloutsService.createModelRollout(
modelRollout.copy { rolloutPeriodStartTime = Instant.now().plusSeconds(150L).toProtoTime() }
modelRollout.copy { rolloutPeriodStartTime = Instant.now().plusSeconds(50L).toProtoTime() }
)
val modelRollout2 =
modelRolloutsService.createModelRollout(
modelRollout.copy { rolloutPeriodStartTime = Instant.now().plusSeconds(50L).toProtoTime() }
modelRollout.copy { rolloutPeriodStartTime = Instant.now().plusSeconds(150L).toProtoTime() }
)

val modelRollouts: List<ModelRollout> =
Expand All @@ -848,7 +849,7 @@ abstract class ModelRolloutsServiceTest<T : ModelRolloutsCoroutineImplBase> {
.toList()

assertThat(modelRollouts).hasSize(1)
assertThat(modelRollouts).contains(modelRollout2)
assertThat(modelRollouts).contains(modelRollout1)

val modelRollouts2: List<ModelRollout> =
modelRolloutsService
Expand All @@ -871,7 +872,7 @@ abstract class ModelRolloutsServiceTest<T : ModelRolloutsCoroutineImplBase> {
.toList()

assertThat(modelRollouts2).hasSize(1)
assertThat(modelRollouts2).contains(modelRollout1)
assertThat(modelRollouts2).contains(modelRollout2)
}

@Test
Expand Down Expand Up @@ -1046,4 +1047,106 @@ abstract class ModelRolloutsServiceTest<T : ModelRolloutsCoroutineImplBase> {
assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT)
assertThat(exception).hasMessageThat().contains("Missing RolloutPeriod fields")
}

@Test
fun `createModelRollout succeeds with multiple model rollouts`() = runBlocking {
val modelLine =
population.createModelLine(modelProvidersService, modelSuitesService, modelLinesService)
val modelRelease =
population.createModelRelease(
modelSuite {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
},
modelReleasesService
)

val modelRollout = modelRollout {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
externalModelLineId = modelLine.externalModelLineId
rolloutPeriodStartTime = Instant.now().plusSeconds(100L).toProtoTime()
rolloutPeriodEndTime = Instant.now().plusSeconds(100L).toProtoTime()
externalModelReleaseId = modelRelease.externalModelReleaseId
}
modelRolloutsService.createModelRollout(modelRollout)

val modelRollout2 = modelRollout {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
externalModelLineId = modelLine.externalModelLineId
rolloutPeriodStartTime = Instant.now().plusSeconds(200L).toProtoTime()
rolloutPeriodEndTime = Instant.now().plusSeconds(300L).toProtoTime()
externalModelReleaseId = modelRelease.externalModelReleaseId
}
val createdModelRollout2 = modelRolloutsService.createModelRollout(modelRollout2)

val modelRollout3 = modelRollout {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
externalModelLineId = modelLine.externalModelLineId
rolloutPeriodStartTime = Instant.now().plusSeconds(900L).toProtoTime()
rolloutPeriodEndTime = Instant.now().plusSeconds(1100L).toProtoTime()
externalModelReleaseId = modelRelease.externalModelReleaseId
}
val createdModelRollout3 = modelRolloutsService.createModelRollout(modelRollout3)

assertThat(createdModelRollout3.externalPreviousModelRolloutId)
.isEqualTo(createdModelRollout2.externalModelRolloutId)
}

@Test
fun `createModelRollout fails when new model rollout start time precedes that of previous model rollout`() =
runBlocking {
val modelLine =
population.createModelLine(modelProvidersService, modelSuitesService, modelLinesService)
val modelRelease =
population.createModelRelease(
modelSuite {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
},
modelReleasesService
)

val modelRollout = modelRollout {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
externalModelLineId = modelLine.externalModelLineId
rolloutPeriodStartTime = Instant.now().plusSeconds(100L).toProtoTime()
rolloutPeriodEndTime = Instant.now().plusSeconds(100L).toProtoTime()
externalModelReleaseId = modelRelease.externalModelReleaseId
}
modelRolloutsService.createModelRollout(modelRollout)

val modelRollout2 = modelRollout {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
externalModelLineId = modelLine.externalModelLineId
rolloutPeriodStartTime = Instant.now().plusSeconds(300L).toProtoTime()
rolloutPeriodEndTime = Instant.now().plusSeconds(400L).toProtoTime()
externalModelReleaseId = modelRelease.externalModelReleaseId
}

modelRolloutsService.createModelRollout(modelRollout2)

val modelRollout3 = modelRollout {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
externalModelLineId = modelLine.externalModelLineId
rolloutPeriodStartTime = Instant.now().plusSeconds(200L).toProtoTime()
rolloutPeriodEndTime = Instant.now().plusSeconds(300L).toProtoTime()
externalModelReleaseId = modelRelease.externalModelReleaseId
}

val exception =
assertFailsWith<StatusRuntimeException> {
modelRolloutsService.createModelRollout(modelRollout3)
}

assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT)
assertThat(exception)
.hasMessageThat()
.contains("RolloutPeriodStartTime cannot precede that of previous ModelRollout.")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ option java_multiple_files = true;
// Internal service for persistence of ModelRollout entities.
service ModelRollouts {
rpc CreateModelRollout(ModelRollout) returns (ModelRollout);

// Streams `ModelRollout`s.
rpc StreamModelRollouts(StreamModelRolloutsRequest)
returns (stream ModelRollout);
Expand Down

0 comments on commit ed63c3a

Please sign in to comment.