Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix that only allows creation of ModelRollouts with RolloutPeriodStartTime after that of previous ModelRollout #1076

Merged
merged 5 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,20 @@

package org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers

import com.google.cloud.Timestamp
import com.google.cloud.spanner.Statement
import com.google.cloud.spanner.Struct
import com.google.cloud.spanner.Value
import com.google.protobuf.Timestamp
import com.google.protobuf.util.Timestamps
import java.time.Clock
import kotlinx.coroutines.flow.firstOrNull
import kotlinx.coroutines.flow.singleOrNull
import org.wfanet.measurement.common.identity.ExternalId
import org.wfanet.measurement.common.identity.InternalId
import org.wfanet.measurement.common.toInstant
import org.wfanet.measurement.common.toProtoTime
import org.wfanet.measurement.gcloud.common.toGcloudTimestamp
import org.wfanet.measurement.gcloud.common.toInstant
import org.wfanet.measurement.gcloud.spanner.bind
import org.wfanet.measurement.gcloud.spanner.bufferInsertMutation
import org.wfanet.measurement.gcloud.spanner.set
Expand All @@ -41,6 +44,11 @@ import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.readers.ModelRelease
class CreateModelRollout(private val modelRollout: ModelRollout, private val clock: Clock) :
SpannerWriter<ModelRollout, ModelRollout>() {

private data class LatestModelRolloutResult(
val modelRolloutId: Long,
val externalModelRolloutId: Long,
val rolloutPeriodStartTime: Timestamp?
)
override suspend fun TransactionScope.runTransaction(): ModelRollout {
val now = clock.instant().toProtoTime()
if (Timestamps.compare(now, modelRollout.rolloutPeriodStartTime) >= 0) {
Expand All @@ -52,7 +60,6 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo
"RolloutPeriodStartTime must be in the future."
}
}

if (
Timestamps.compare(modelRollout.rolloutPeriodStartTime, modelRollout.rolloutPeriodEndTime) > 0
) {
Expand All @@ -65,18 +72,29 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo
}
}

val previousModelRolloutData =
if (modelRollout.rolloutPeriodStartTime != modelRollout.rolloutPeriodEndTime) {
readModelRolloutData(
ExternalId(modelRollout.externalModelProviderId),
ExternalId(modelRollout.externalModelSuiteId),
ExternalId(modelRollout.externalModelLineId),
modelRollout.rolloutPeriodStartTime
)
} else {
null
}
val latestModelRolloutData =
readLatestModelRolloutData(
ExternalId(modelRollout.externalModelProviderId),
ExternalId(modelRollout.externalModelSuiteId),
ExternalId(modelRollout.externalModelLineId),
)

val latestModelRolloutStartTime = latestModelRolloutData?.rolloutPeriodStartTime

// New ModelRollout cannot have a RolloutStartTime that precedes the RolloutStartTime of the
// current PreviousModelRollout
if (
latestModelRolloutStartTime != null &&
latestModelRolloutStartTime.toInstant() > modelRollout.rolloutPeriodStartTime.toInstant()
) {
throw ModelRolloutInvalidArgsException(
ExternalId(modelRollout.externalModelProviderId),
ExternalId(modelRollout.externalModelSuiteId),
ExternalId(modelRollout.externalModelLineId)
) {
"RolloutPeriodStartTime cannot precede that of previous ModelRollout."
}
}
val modelLineData =
readModelLineData(
ExternalId(modelRollout.externalModelProviderId),
Expand Down Expand Up @@ -114,10 +132,8 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo
set("ExternalModelRolloutId" to externalModelRolloutId)
set("RolloutPeriodStartTime" to modelRollout.rolloutPeriodStartTime.toGcloudTimestamp())
set("RolloutPeriodEndTime" to modelRollout.rolloutPeriodEndTime.toGcloudTimestamp())
if (previousModelRolloutData != null) {
set(
"PreviousModelRolloutId" to InternalId(previousModelRolloutData.getLong("ModelRolloutId"))
)
if (latestModelRolloutData != null) {
set("PreviousModelRolloutId" to InternalId(latestModelRolloutData.modelRolloutId))
}
set("ModelReleaseId" to modelReleaseResult.modelReleaseId)
set("CreateTime" to Value.COMMIT_TIMESTAMP)
Expand All @@ -126,9 +142,8 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo

return modelRollout.copy {
this.externalModelRolloutId = externalModelRolloutId.value
if (previousModelRolloutData != null) {
this.externalPreviousModelRolloutId =
previousModelRolloutData.getLong("ExternalModelRolloutId")
if (latestModelRolloutData != null) {
this.externalPreviousModelRolloutId = latestModelRolloutData.externalModelRolloutId
}
}
}
Expand Down Expand Up @@ -166,17 +181,18 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo
return transactionContext.executeQuery(statement).singleOrNull()
}

private suspend fun TransactionScope.readModelRolloutData(
// Reads the ModelRollout for a given ModelLine with the latest RolloutPeriodStartTime.
private suspend fun TransactionScope.readLatestModelRolloutData(
externalModelProviderId: ExternalId,
externalModelSuiteId: ExternalId,
externalModelLineId: ExternalId,
rolloutStartTime: Timestamp
): Struct? {
): LatestModelRolloutResult? {
val sql =
"""
SELECT
ModelRollouts.ModelRolloutId,
ModelRollouts.ExternalModelRolloutId
ModelRollouts.ExternalModelRolloutId,
ModelRollouts.RolloutPeriodStartTime
FROM ModelRollouts JOIN ModelLines
USING (ModelProviderId, ModelSuiteId, ModelLineId)
JOIN ModelSuites
Expand All @@ -185,7 +201,6 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo
WHERE ModelProviders.ExternalModelProviderId = @externalModelProviderId AND
ModelSuites.ExternalModelSuiteId = @externalModelSuiteId AND
ModelLines.ExternalModelLineId = @externalModelLineId
AND ModelRollouts.RolloutPeriodStartTime < @rolloutPeriodStartTime
ORDER BY ModelRollouts.RolloutPeriodStartTime DESC
"""
.trimIndent()
Expand All @@ -195,10 +210,17 @@ class CreateModelRollout(private val modelRollout: ModelRollout, private val clo
bind("externalModelProviderId" to externalModelProviderId.value)
bind("externalModelSuiteId" to externalModelSuiteId.value)
bind("externalModelLineId" to externalModelLineId.value)
bind("rolloutPeriodStartTime" to rolloutStartTime.toGcloudTimestamp())
}

return transactionContext.executeQuery(statement).singleOrNull()
val result = transactionContext.executeQuery(statement).firstOrNull()

return if (result == null) null
else
LatestModelRolloutResult(
result.getLong("ModelRolloutId"),
result.getLong("ExternalModelRolloutId"),
result.getTimestamp("RolloutPeriodStartTime")
)
}

override fun ResultScope<ModelRollout>.buildResult(): ModelRollout {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ abstract class ModelRolloutsServiceTest<T : ModelRolloutsCoroutineImplBase> {
}

@Test
fun `createModelRollout correctly leave previous model rollout unset when rollout start time is equal to rollout end time`() =
fun `createModelRollout correctly sets previous model rollout when rollout start time is equal to rollout end time`() =
runBlocking {
val modelLine =
population.createModelLine(modelProvidersService, modelSuitesService, modelLinesService)
Expand Down Expand Up @@ -234,7 +234,8 @@ abstract class ModelRolloutsServiceTest<T : ModelRolloutsCoroutineImplBase> {

val createdModelRollout2 = modelRolloutsService.createModelRollout(modelRollout2)

assertThat(createdModelRollout2.externalPreviousModelRolloutId).isEqualTo(0L)
assertThat(createdModelRollout2.externalPreviousModelRolloutId)
.isEqualTo(createdModelRollout.externalModelRolloutId)
}

@Test
Expand Down Expand Up @@ -826,11 +827,11 @@ abstract class ModelRolloutsServiceTest<T : ModelRolloutsCoroutineImplBase> {
}
val modelRollout1 =
modelRolloutsService.createModelRollout(
modelRollout.copy { rolloutPeriodStartTime = Instant.now().plusSeconds(150L).toProtoTime() }
modelRollout.copy { rolloutPeriodStartTime = Instant.now().plusSeconds(50L).toProtoTime() }
)
val modelRollout2 =
modelRolloutsService.createModelRollout(
modelRollout.copy { rolloutPeriodStartTime = Instant.now().plusSeconds(50L).toProtoTime() }
modelRollout.copy { rolloutPeriodStartTime = Instant.now().plusSeconds(150L).toProtoTime() }
)

val modelRollouts: List<ModelRollout> =
Expand All @@ -848,7 +849,7 @@ abstract class ModelRolloutsServiceTest<T : ModelRolloutsCoroutineImplBase> {
.toList()

assertThat(modelRollouts).hasSize(1)
assertThat(modelRollouts).contains(modelRollout2)
assertThat(modelRollouts).contains(modelRollout1)

val modelRollouts2: List<ModelRollout> =
modelRolloutsService
Expand All @@ -871,7 +872,7 @@ abstract class ModelRolloutsServiceTest<T : ModelRolloutsCoroutineImplBase> {
.toList()

assertThat(modelRollouts2).hasSize(1)
assertThat(modelRollouts2).contains(modelRollout1)
assertThat(modelRollouts2).contains(modelRollout2)
}

@Test
Expand Down Expand Up @@ -1046,4 +1047,106 @@ abstract class ModelRolloutsServiceTest<T : ModelRolloutsCoroutineImplBase> {
assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT)
assertThat(exception).hasMessageThat().contains("Missing RolloutPeriod fields")
}

@Test
fun `createModelRollout succeeds with multiple model rollouts`() = runBlocking {
val modelLine =
population.createModelLine(modelProvidersService, modelSuitesService, modelLinesService)
val modelRelease =
population.createModelRelease(
modelSuite {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
},
modelReleasesService
)

val modelRollout = modelRollout {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
externalModelLineId = modelLine.externalModelLineId
rolloutPeriodStartTime = Instant.now().plusSeconds(100L).toProtoTime()
rolloutPeriodEndTime = Instant.now().plusSeconds(100L).toProtoTime()
externalModelReleaseId = modelRelease.externalModelReleaseId
}
modelRolloutsService.createModelRollout(modelRollout)

val modelRollout2 = modelRollout {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
externalModelLineId = modelLine.externalModelLineId
rolloutPeriodStartTime = Instant.now().plusSeconds(200L).toProtoTime()
rolloutPeriodEndTime = Instant.now().plusSeconds(300L).toProtoTime()
externalModelReleaseId = modelRelease.externalModelReleaseId
}
val createdModelRollout2 = modelRolloutsService.createModelRollout(modelRollout2)

val modelRollout3 = modelRollout {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
externalModelLineId = modelLine.externalModelLineId
rolloutPeriodStartTime = Instant.now().plusSeconds(900L).toProtoTime()
rolloutPeriodEndTime = Instant.now().plusSeconds(1100L).toProtoTime()
externalModelReleaseId = modelRelease.externalModelReleaseId
}
val createdModelRollout3 = modelRolloutsService.createModelRollout(modelRollout3)

assertThat(createdModelRollout3.externalPreviousModelRolloutId)
.isEqualTo(createdModelRollout2.externalModelRolloutId)
}

@Test
fun `createModelRollout fails when new model rollout start time precedes that of previous model rollout`() =
runBlocking {
val modelLine =
population.createModelLine(modelProvidersService, modelSuitesService, modelLinesService)
val modelRelease =
population.createModelRelease(
modelSuite {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
},
modelReleasesService
)

val modelRollout = modelRollout {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
externalModelLineId = modelLine.externalModelLineId
rolloutPeriodStartTime = Instant.now().plusSeconds(100L).toProtoTime()
rolloutPeriodEndTime = Instant.now().plusSeconds(100L).toProtoTime()
externalModelReleaseId = modelRelease.externalModelReleaseId
}
modelRolloutsService.createModelRollout(modelRollout)

val modelRollout2 = modelRollout {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
externalModelLineId = modelLine.externalModelLineId
rolloutPeriodStartTime = Instant.now().plusSeconds(300L).toProtoTime()
rolloutPeriodEndTime = Instant.now().plusSeconds(400L).toProtoTime()
externalModelReleaseId = modelRelease.externalModelReleaseId
}

modelRolloutsService.createModelRollout(modelRollout2)

val modelRollout3 = modelRollout {
externalModelProviderId = modelLine.externalModelProviderId
externalModelSuiteId = modelLine.externalModelSuiteId
externalModelLineId = modelLine.externalModelLineId
rolloutPeriodStartTime = Instant.now().plusSeconds(200L).toProtoTime()
rolloutPeriodEndTime = Instant.now().plusSeconds(300L).toProtoTime()
externalModelReleaseId = modelRelease.externalModelReleaseId
}

val exception =
assertFailsWith<StatusRuntimeException> {
modelRolloutsService.createModelRollout(modelRollout3)
}

assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT)
assertThat(exception)
.hasMessageThat()
.contains("RolloutPeriodStartTime cannot precede that of previous ModelRollout.")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ option java_multiple_files = true;
// Internal service for persistence of ModelRollout entities.
service ModelRollouts {
rpc CreateModelRollout(ModelRollout) returns (ModelRollout);

// Streams `ModelRollout`s.
rpc StreamModelRollouts(StreamModelRolloutsRequest)
returns (stream ModelRollout);
Expand Down