diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRelease.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRelease.kt index ba7b5f7df39..3528a94a380 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRelease.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateModelRelease.kt @@ -29,6 +29,7 @@ import org.wfanet.measurement.gcloud.spanner.statement import org.wfanet.measurement.internal.kingdom.ModelRelease import org.wfanet.measurement.internal.kingdom.copy import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.ModelSuiteNotFoundException +import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.PopulationNotFoundException class CreateModelRelease(private val modelRelease: ModelRelease) : SpannerWriter() { @@ -45,7 +46,7 @@ class CreateModelRelease(private val modelRelease: ModelRelease) : val externalPopulationId = ExternalId(modelRelease.externalPopulationId) val populationData: Struct = readPopulationData(externalDataProviderId, externalPopulationId) - ?: throw ModelSuiteNotFoundException(externalDataProviderId, externalPopulationId) + ?: throw PopulationNotFoundException(externalDataProviderId, externalPopulationId) val internalModelReleaseId = idGenerator.generateInternalId() val externalModelReleaseId = idGenerator.generateExternalId() diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesService.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesService.kt index 4ce3c34323c..f0d65432db9 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesService.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesService.kt @@ -33,6 +33,7 @@ import org.wfanet.measurement.api.v2alpha.ModelRelease import org.wfanet.measurement.api.v2alpha.ModelReleaseKey import org.wfanet.measurement.api.v2alpha.ModelReleasesGrpcKt.ModelReleasesCoroutineImplBase as ModelReleasesCoroutineService import org.wfanet.measurement.api.v2alpha.ModelSuiteKey +import org.wfanet.measurement.api.v2alpha.PopulationKey import org.wfanet.measurement.api.v2alpha.copy import org.wfanet.measurement.api.v2alpha.listModelReleasesPageToken import org.wfanet.measurement.api.v2alpha.listModelReleasesResponse @@ -62,6 +63,10 @@ class ModelReleasesService(private val internalClient: ModelReleasesCoroutineStu grpcRequireNotNull(ModelSuiteKey.fromName(request.parent)) { "Parent is either unspecified or invalid" } + val populationKey = + grpcRequireNotNull(PopulationKey.fromName(request.modelRelease.population)) { + "Population is either unspecified or invalid" + } when (val principal: MeasurementPrincipal = principalFromCurrentContext) { is ModelProviderPrincipal -> { @@ -78,7 +83,7 @@ class ModelReleasesService(private val internalClient: ModelReleasesCoroutineStu } } - val createModelReleaseRequest = request.modelRelease.toInternal(parentKey) + val createModelReleaseRequest = request.modelRelease.toInternal(parentKey, populationKey) return try { internalClient.createModelRelease(createModelReleaseRequest).toModelRelease() } catch (e: StatusException) { diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt index 9f2b5a92a7e..f413bc1ddcc 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt @@ -554,10 +554,15 @@ fun ModelSuite.toInternal(modelProviderKey: ModelProviderKey): InternalModelSuit /** Converts a public [ModelRelease] to an internal [InternalModelRelease] */ @Suppress("UnusedReceiverParameter") // -fun ModelRelease.toInternal(modelSuiteKey: ModelSuiteKey): InternalModelRelease { +fun ModelRelease.toInternal( + modelSuiteKey: ModelSuiteKey, + populationKey: PopulationKey, +): InternalModelRelease { return internalModelRelease { externalModelProviderId = apiIdToExternalId(modelSuiteKey.modelProviderId) externalModelSuiteId = apiIdToExternalId(modelSuiteKey.modelSuiteId) + externalDataProviderId = apiIdToExternalId(populationKey.dataProviderId) + externalPopulationId = apiIdToExternalId(populationKey.populationId) } } @@ -574,6 +579,12 @@ fun InternalModelRelease.toModelRelease(): ModelRelease { ) .toName() createTime = source.createTime + population = + PopulationKey( + externalIdToApiId(source.externalDataProviderId), + externalIdToApiId(source.externalPopulationId), + ) + .toName() } } diff --git a/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesServiceTest.kt index 99a88a369b7..2dee1c13b9d 100644 --- a/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ModelReleasesServiceTest.kt @@ -33,12 +33,14 @@ import org.junit.runners.JUnit4 import org.mockito.kotlin.any import org.mockito.kotlin.stub import org.mockito.kotlin.verify +import org.wfanet.measurement.api.v2alpha.DataProviderKey import org.wfanet.measurement.api.v2alpha.ListModelReleasesPageTokenKt.previousPageEnd import org.wfanet.measurement.api.v2alpha.ListModelReleasesRequest import org.wfanet.measurement.api.v2alpha.ModelProviderKey import org.wfanet.measurement.api.v2alpha.ModelRelease import org.wfanet.measurement.api.v2alpha.ModelReleaseKey import org.wfanet.measurement.api.v2alpha.ModelSuiteKey +import org.wfanet.measurement.api.v2alpha.PopulationKey import org.wfanet.measurement.api.v2alpha.copy import org.wfanet.measurement.api.v2alpha.createModelReleaseRequest import org.wfanet.measurement.api.v2alpha.getModelReleaseRequest @@ -86,6 +88,8 @@ private const val MODEL_RELEASE_NAME = "$MODEL_SUITE_NAME/modelReleases/AAAAAAAA private const val MODEL_RELEASE_NAME_2 = "$MODEL_SUITE_NAME/modelReleases/AAAAAAAAAJs" private const val MODEL_RELEASE_NAME_3 = "$MODEL_SUITE_NAME/modelReleases/AAAAAAAAAKs" private const val MODEL_RELEASE_NAME_4 = "$MODEL_SUITE_NAME_2/modelReleases/AAAAAAAAAHs" +private val POPULATION_NAME = "$DATA_PROVIDER_NAME/populations/AAAAAAAAAHs" + private val EXTERNAL_MODEL_PROVIDER_ID = apiIdToExternalId(ModelProviderKey.fromName(MODEL_PROVIDER_NAME)!!.modelProviderId) private val EXTERNAL_MODEL_SUITE_ID = @@ -96,6 +100,10 @@ private val EXTERNAL_MODEL_RELEASE_ID_2 = apiIdToExternalId(ModelReleaseKey.fromName(MODEL_RELEASE_NAME_2)!!.modelReleaseId) private val EXTERNAL_MODEL_RELEASE_ID_3 = apiIdToExternalId(ModelReleaseKey.fromName(MODEL_RELEASE_NAME_3)!!.modelReleaseId) +private val EXTERNAL_DATA_PROVIDER_ID = + apiIdToExternalId(DataProviderKey.fromName(DATA_PROVIDER_NAME)!!.dataProviderId) +private val EXTERNAL_POPULATION_ID = + apiIdToExternalId(PopulationKey.fromName(POPULATION_NAME)!!.populationId) private val CREATE_TIME: Timestamp = Instant.ofEpochSecond(123).toProtoTime() @@ -104,11 +112,14 @@ private val INTERNAL_MODEL_RELEASE: InternalModelRelease = internalModelRelease externalModelSuiteId = EXTERNAL_MODEL_SUITE_ID externalModelReleaseId = EXTERNAL_MODEL_RELEASE_ID createTime = CREATE_TIME + externalDataProviderId = EXTERNAL_DATA_PROVIDER_ID + externalPopulationId = EXTERNAL_POPULATION_ID } private val MODEL_RELEASE: ModelRelease = modelRelease { name = MODEL_RELEASE_NAME createTime = CREATE_TIME + population = POPULATION_NAME } @RunWith(JUnit4::class)