From 684948d04042141c3e99db6ee7aa2f92ff9c9d49 Mon Sep 17 00:00:00 2001 From: jojijac0b Date: Wed, 13 Nov 2024 15:17:09 -0500 Subject: [PATCH] fix: Ensure population is persisted in ModelRelease upon creation (#1914) There is a bug where population is not added to the model release upon creation. This change fixes that bug by 1) including the population field when a new model release is created, 2) populating the externalDataProviderId(population data provider) and externalPopulationId fields when converting from external to internal model release, and 3) populating the population field when converting from internal to external model release. --------- Co-authored-by: jojijac0b --- .../gcloud/spanner/writers/CreateModelRelease.kt | 3 ++- .../service/api/v2alpha/ModelReleasesService.kt | 7 ++++++- .../kingdom/service/api/v2alpha/ProtoConversions.kt | 13 ++++++++++++- .../service/api/v2alpha/ModelReleasesServiceTest.kt | 11 +++++++++++ 4 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRelease.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRelease.kt index ba7b5f7df39..3528a94a380 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRelease.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRelease.kt @@ -29,6 +29,7 @@ import org.wfanet.measurement.gcloud.spanner.statement import org.wfanet.measurement.internal.kingdom.ModelRelease import org.wfanet.measurement.internal.kingdom.copy import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.ModelSuiteNotFoundException +import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.PopulationNotFoundException class CreateModelRelease(private val modelRelease: ModelRelease) : SpannerWriter() { @@ -45,7 +46,7 @@ class CreateModelRelease(private val modelRelease: ModelRelease) : val externalPopulationId = ExternalId(modelRelease.externalPopulationId) val populationData: Struct = readPopulationData(externalDataProviderId, externalPopulationId) - ?: throw ModelSuiteNotFoundException(externalDataProviderId, externalPopulationId) + ?: throw PopulationNotFoundException(externalDataProviderId, externalPopulationId) val internalModelReleaseId = idGenerator.generateInternalId() val externalModelReleaseId = idGenerator.generateExternalId() diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesService.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesService.kt index 4ce3c34323c..f0d65432db9 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesService.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesService.kt @@ -33,6 +33,7 @@ import org.wfanet.measurement.api.v2alpha.ModelRelease import org.wfanet.measurement.api.v2alpha.ModelReleaseKey import org.wfanet.measurement.api.v2alpha.ModelReleasesGrpcKt.ModelReleasesCoroutineImplBase as ModelReleasesCoroutineService import org.wfanet.measurement.api.v2alpha.ModelSuiteKey +import org.wfanet.measurement.api.v2alpha.PopulationKey import org.wfanet.measurement.api.v2alpha.copy import org.wfanet.measurement.api.v2alpha.listModelReleasesPageToken import org.wfanet.measurement.api.v2alpha.listModelReleasesResponse @@ -62,6 +63,10 @@ class ModelReleasesService(private val internalClient: ModelReleasesCoroutineStu grpcRequireNotNull(ModelSuiteKey.fromName(request.parent)) { "Parent is either unspecified or invalid" } + val populationKey = + grpcRequireNotNull(PopulationKey.fromName(request.modelRelease.population)) { + "Population is either unspecified or invalid" + } when (val principal: MeasurementPrincipal = principalFromCurrentContext) { is ModelProviderPrincipal -> { @@ -78,7 +83,7 @@ class ModelReleasesService(private val internalClient: ModelReleasesCoroutineStu } } - val createModelReleaseRequest = request.modelRelease.toInternal(parentKey) + val createModelReleaseRequest = request.modelRelease.toInternal(parentKey, populationKey) return try { internalClient.createModelRelease(createModelReleaseRequest).toModelRelease() } catch (e: StatusException) { diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt index 9f2b5a92a7e..f413bc1ddcc 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt @@ -554,10 +554,15 @@ fun ModelSuite.toInternal(modelProviderKey: ModelProviderKey): InternalModelSuit /** Converts a public [ModelRelease] to an internal [InternalModelRelease] */ @Suppress("UnusedReceiverParameter") // -fun ModelRelease.toInternal(modelSuiteKey: ModelSuiteKey): InternalModelRelease { +fun ModelRelease.toInternal( + modelSuiteKey: ModelSuiteKey, + populationKey: PopulationKey, +): InternalModelRelease { return internalModelRelease { externalModelProviderId = apiIdToExternalId(modelSuiteKey.modelProviderId) externalModelSuiteId = apiIdToExternalId(modelSuiteKey.modelSuiteId) + externalDataProviderId = apiIdToExternalId(populationKey.dataProviderId) + externalPopulationId = apiIdToExternalId(populationKey.populationId) } } @@ -574,6 +579,12 @@ fun InternalModelRelease.toModelRelease(): ModelRelease { ) .toName() createTime = source.createTime + population = + PopulationKey( + externalIdToApiId(source.externalDataProviderId), + externalIdToApiId(source.externalPopulationId), + ) + .toName() } } diff --git a/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesServiceTest.kt index 99a88a369b7..2dee1c13b9d 100644 --- a/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesServiceTest.kt @@ -33,12 +33,14 @@ import org.junit.runners.JUnit4 import org.mockito.kotlin.any import org.mockito.kotlin.stub import org.mockito.kotlin.verify +import org.wfanet.measurement.api.v2alpha.DataProviderKey import org.wfanet.measurement.api.v2alpha.ListModelReleasesPageTokenKt.previousPageEnd import org.wfanet.measurement.api.v2alpha.ListModelReleasesRequest import org.wfanet.measurement.api.v2alpha.ModelProviderKey import org.wfanet.measurement.api.v2alpha.ModelRelease import org.wfanet.measurement.api.v2alpha.ModelReleaseKey import org.wfanet.measurement.api.v2alpha.ModelSuiteKey +import org.wfanet.measurement.api.v2alpha.PopulationKey import org.wfanet.measurement.api.v2alpha.copy import org.wfanet.measurement.api.v2alpha.createModelReleaseRequest import org.wfanet.measurement.api.v2alpha.getModelReleaseRequest @@ -86,6 +88,8 @@ private const val MODEL_RELEASE_NAME = "$MODEL_SUITE_NAME/modelReleases/AAAAAAAA private const val MODEL_RELEASE_NAME_2 = "$MODEL_SUITE_NAME/modelReleases/AAAAAAAAAJs" private const val MODEL_RELEASE_NAME_3 = "$MODEL_SUITE_NAME/modelReleases/AAAAAAAAAKs" private const val MODEL_RELEASE_NAME_4 = "$MODEL_SUITE_NAME_2/modelReleases/AAAAAAAAAHs" +private val POPULATION_NAME = "$DATA_PROVIDER_NAME/populations/AAAAAAAAAHs" + private val EXTERNAL_MODEL_PROVIDER_ID = apiIdToExternalId(ModelProviderKey.fromName(MODEL_PROVIDER_NAME)!!.modelProviderId) private val EXTERNAL_MODEL_SUITE_ID = @@ -96,6 +100,10 @@ private val EXTERNAL_MODEL_RELEASE_ID_2 = apiIdToExternalId(ModelReleaseKey.fromName(MODEL_RELEASE_NAME_2)!!.modelReleaseId) private val EXTERNAL_MODEL_RELEASE_ID_3 = apiIdToExternalId(ModelReleaseKey.fromName(MODEL_RELEASE_NAME_3)!!.modelReleaseId) +private val EXTERNAL_DATA_PROVIDER_ID = + apiIdToExternalId(DataProviderKey.fromName(DATA_PROVIDER_NAME)!!.dataProviderId) +private val EXTERNAL_POPULATION_ID = + apiIdToExternalId(PopulationKey.fromName(POPULATION_NAME)!!.populationId) private val CREATE_TIME: Timestamp = Instant.ofEpochSecond(123).toProtoTime() @@ -104,11 +112,14 @@ private val INTERNAL_MODEL_RELEASE: InternalModelRelease = internalModelRelease externalModelSuiteId = EXTERNAL_MODEL_SUITE_ID externalModelReleaseId = EXTERNAL_MODEL_RELEASE_ID createTime = CREATE_TIME + externalDataProviderId = EXTERNAL_DATA_PROVIDER_ID + externalPopulationId = EXTERNAL_POPULATION_ID } private val MODEL_RELEASE: ModelRelease = modelRelease { name = MODEL_RELEASE_NAME createTime = CREATE_TIME + population = POPULATION_NAME } @RunWith(JUnit4::class)