Skip to content

Commit

Permalink
fix: Ensure population is persisted in ModelRelease upon creation (#1914
Browse files Browse the repository at this point in the history
)

There is a bug where population is not added to the model release upon
creation. This change fixes that bug by 1) including the population
field when a new model release is created, 2) populating the
externalDataProviderId(population data provider) and
externalPopulationId fields when converting from external to internal
model release, and 3) populating the population field when converting
from internal to external model release.

---------

Co-authored-by: jojijac0b <[email protected]>
  • Loading branch information
jojijac0b and jojijac0b authored Nov 13, 2024
1 parent a1d5e05 commit 684948d
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.wfanet.measurement.gcloud.spanner.statement
import org.wfanet.measurement.internal.kingdom.ModelRelease
import org.wfanet.measurement.internal.kingdom.copy
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.ModelSuiteNotFoundException
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.PopulationNotFoundException

class CreateModelRelease(private val modelRelease: ModelRelease) :
SpannerWriter<ModelRelease, ModelRelease>() {
Expand All @@ -45,7 +46,7 @@ class CreateModelRelease(private val modelRelease: ModelRelease) :
val externalPopulationId = ExternalId(modelRelease.externalPopulationId)
val populationData: Struct =
readPopulationData(externalDataProviderId, externalPopulationId)
?: throw ModelSuiteNotFoundException(externalDataProviderId, externalPopulationId)
?: throw PopulationNotFoundException(externalDataProviderId, externalPopulationId)

val internalModelReleaseId = idGenerator.generateInternalId()
val externalModelReleaseId = idGenerator.generateExternalId()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import org.wfanet.measurement.api.v2alpha.ModelRelease
import org.wfanet.measurement.api.v2alpha.ModelReleaseKey
import org.wfanet.measurement.api.v2alpha.ModelReleasesGrpcKt.ModelReleasesCoroutineImplBase as ModelReleasesCoroutineService
import org.wfanet.measurement.api.v2alpha.ModelSuiteKey
import org.wfanet.measurement.api.v2alpha.PopulationKey
import org.wfanet.measurement.api.v2alpha.copy
import org.wfanet.measurement.api.v2alpha.listModelReleasesPageToken
import org.wfanet.measurement.api.v2alpha.listModelReleasesResponse
Expand Down Expand Up @@ -62,6 +63,10 @@ class ModelReleasesService(private val internalClient: ModelReleasesCoroutineStu
grpcRequireNotNull(ModelSuiteKey.fromName(request.parent)) {
"Parent is either unspecified or invalid"
}
val populationKey =
grpcRequireNotNull(PopulationKey.fromName(request.modelRelease.population)) {
"Population is either unspecified or invalid"
}

when (val principal: MeasurementPrincipal = principalFromCurrentContext) {
is ModelProviderPrincipal -> {
Expand All @@ -78,7 +83,7 @@ class ModelReleasesService(private val internalClient: ModelReleasesCoroutineStu
}
}

val createModelReleaseRequest = request.modelRelease.toInternal(parentKey)
val createModelReleaseRequest = request.modelRelease.toInternal(parentKey, populationKey)
return try {
internalClient.createModelRelease(createModelReleaseRequest).toModelRelease()
} catch (e: StatusException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -554,10 +554,15 @@ fun ModelSuite.toInternal(modelProviderKey: ModelProviderKey): InternalModelSuit

/** Converts a public [ModelRelease] to an internal [InternalModelRelease] */
@Suppress("UnusedReceiverParameter") //
fun ModelRelease.toInternal(modelSuiteKey: ModelSuiteKey): InternalModelRelease {
fun ModelRelease.toInternal(
modelSuiteKey: ModelSuiteKey,
populationKey: PopulationKey,
): InternalModelRelease {
return internalModelRelease {
externalModelProviderId = apiIdToExternalId(modelSuiteKey.modelProviderId)
externalModelSuiteId = apiIdToExternalId(modelSuiteKey.modelSuiteId)
externalDataProviderId = apiIdToExternalId(populationKey.dataProviderId)
externalPopulationId = apiIdToExternalId(populationKey.populationId)
}
}

Expand All @@ -574,6 +579,12 @@ fun InternalModelRelease.toModelRelease(): ModelRelease {
)
.toName()
createTime = source.createTime
population =
PopulationKey(
externalIdToApiId(source.externalDataProviderId),
externalIdToApiId(source.externalPopulationId),
)
.toName()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ import org.junit.runners.JUnit4
import org.mockito.kotlin.any
import org.mockito.kotlin.stub
import org.mockito.kotlin.verify
import org.wfanet.measurement.api.v2alpha.DataProviderKey
import org.wfanet.measurement.api.v2alpha.ListModelReleasesPageTokenKt.previousPageEnd
import org.wfanet.measurement.api.v2alpha.ListModelReleasesRequest
import org.wfanet.measurement.api.v2alpha.ModelProviderKey
import org.wfanet.measurement.api.v2alpha.ModelRelease
import org.wfanet.measurement.api.v2alpha.ModelReleaseKey
import org.wfanet.measurement.api.v2alpha.ModelSuiteKey
import org.wfanet.measurement.api.v2alpha.PopulationKey
import org.wfanet.measurement.api.v2alpha.copy
import org.wfanet.measurement.api.v2alpha.createModelReleaseRequest
import org.wfanet.measurement.api.v2alpha.getModelReleaseRequest
Expand Down Expand Up @@ -86,6 +88,8 @@ private const val MODEL_RELEASE_NAME = "$MODEL_SUITE_NAME/modelReleases/AAAAAAAA
private const val MODEL_RELEASE_NAME_2 = "$MODEL_SUITE_NAME/modelReleases/AAAAAAAAAJs"
private const val MODEL_RELEASE_NAME_3 = "$MODEL_SUITE_NAME/modelReleases/AAAAAAAAAKs"
private const val MODEL_RELEASE_NAME_4 = "$MODEL_SUITE_NAME_2/modelReleases/AAAAAAAAAHs"
private val POPULATION_NAME = "$DATA_PROVIDER_NAME/populations/AAAAAAAAAHs"

private val EXTERNAL_MODEL_PROVIDER_ID =
apiIdToExternalId(ModelProviderKey.fromName(MODEL_PROVIDER_NAME)!!.modelProviderId)
private val EXTERNAL_MODEL_SUITE_ID =
Expand All @@ -96,6 +100,10 @@ private val EXTERNAL_MODEL_RELEASE_ID_2 =
apiIdToExternalId(ModelReleaseKey.fromName(MODEL_RELEASE_NAME_2)!!.modelReleaseId)
private val EXTERNAL_MODEL_RELEASE_ID_3 =
apiIdToExternalId(ModelReleaseKey.fromName(MODEL_RELEASE_NAME_3)!!.modelReleaseId)
private val EXTERNAL_DATA_PROVIDER_ID =
apiIdToExternalId(DataProviderKey.fromName(DATA_PROVIDER_NAME)!!.dataProviderId)
private val EXTERNAL_POPULATION_ID =
apiIdToExternalId(PopulationKey.fromName(POPULATION_NAME)!!.populationId)

private val CREATE_TIME: Timestamp = Instant.ofEpochSecond(123).toProtoTime()

Expand All @@ -104,11 +112,14 @@ private val INTERNAL_MODEL_RELEASE: InternalModelRelease = internalModelRelease
externalModelSuiteId = EXTERNAL_MODEL_SUITE_ID
externalModelReleaseId = EXTERNAL_MODEL_RELEASE_ID
createTime = CREATE_TIME
externalDataProviderId = EXTERNAL_DATA_PROVIDER_ID
externalPopulationId = EXTERNAL_POPULATION_ID
}

private val MODEL_RELEASE: ModelRelease = modelRelease {
name = MODEL_RELEASE_NAME
createTime = CREATE_TIME
population = POPULATION_NAME
}

@RunWith(JUnit4::class)
Expand Down

0 comments on commit 684948d

Please sign in to comment.