Skip to content

Commit

Permalink
Add internal BatchGetDataProviders.
Browse files Browse the repository at this point in the history
  • Loading branch information
SanjayVas committed Mar 13, 2024
1 parent 0359a08 commit 40dab02
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,21 @@ import org.wfanet.measurement.common.grpc.grpcRequireNotNull
import org.wfanet.measurement.common.identity.ExternalId
import org.wfanet.measurement.common.identity.IdGenerator
import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient
import org.wfanet.measurement.internal.kingdom.BatchGetDataProvidersRequest
import org.wfanet.measurement.internal.kingdom.BatchGetDataProvidersResponse
import org.wfanet.measurement.internal.kingdom.DataProvider
import org.wfanet.measurement.internal.kingdom.DataProvidersGrpcKt.DataProvidersCoroutineImplBase
import org.wfanet.measurement.internal.kingdom.GetDataProviderRequest
import org.wfanet.measurement.internal.kingdom.ReplaceDataAvailabilityIntervalRequest
import org.wfanet.measurement.internal.kingdom.ReplaceDataProviderRequiredDuchiesRequest
import org.wfanet.measurement.internal.kingdom.batchGetDataProvidersResponse
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.DataProviderNotFoundException
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.KingdomInternalException
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.readers.DataProviderReader
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers.CreateDataProvider
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers.ReplaceDataAvailabilityInterval
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers.ReplaceDataProviderRequiredDuchies

// TODO(@marcopremier): Add method to update data provider required duchies list.
class SpannerDataProvidersService(
private val idGenerator: IdGenerator,
private val client: AsyncDatabaseClient,
Expand All @@ -54,6 +57,26 @@ class SpannerDataProvidersService(
?.dataProvider ?: failGrpc(Status.NOT_FOUND) { "DataProvider not found" }
}

override suspend fun batchGetDataProviders(
request: BatchGetDataProvidersRequest
): BatchGetDataProvidersResponse {
val dataProviders =
try {
DataProviderReader()
.readByExternalDataProviderIds(
client.singleUse(),
request.externalDataProviderIdsList.map(::ExternalId),
)
.map(DataProviderReader.Result::dataProvider)
} catch (e: DataProviderNotFoundException) {
throw e.asStatusRuntimeException(Status.Code.NOT_FOUND)
} catch (e: KingdomInternalException) {
throw e.asStatusRuntimeException(Status.Code.INTERNAL)
}

return batchGetDataProvidersResponse { this.dataProviders += dataProviders }
}

override suspend fun replaceDataProviderRequiredDuchies(
request: ReplaceDataProviderRequiredDuchiesRequest
): DataProvider {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package org.wfanet.measurement.kingdom.deploy.gcloud.spanner.readers

import com.google.cloud.spanner.Key
import com.google.cloud.spanner.Struct
import kotlinx.coroutines.flow.toList
import org.wfanet.measurement.common.identity.ExternalId
import org.wfanet.measurement.common.identity.InternalId
import org.wfanet.measurement.common.singleOrNullIfEmpty
Expand All @@ -25,6 +26,7 @@ import org.wfanet.measurement.gcloud.spanner.getInternalId
import org.wfanet.measurement.gcloud.spanner.getProtoMessage
import org.wfanet.measurement.internal.kingdom.DataProvider
import org.wfanet.measurement.kingdom.deploy.common.DuchyIds
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.DataProviderNotFoundException

class DataProviderReader : SpannerReader<DataProviderReader.Result>() {
data class Result(val dataProvider: DataProvider, val dataProviderId: Long)
Expand All @@ -46,8 +48,9 @@ class DataProviderReader : SpannerReader<DataProviderReader.Result>() {
SELECT AS STRUCT
DataProviderRequiredDuchies.DuchyId
FROM
DataProviders
JOIN DataProviderRequiredDuchies USING (DataProviderId)
DataProviderRequiredDuchies
WHERE
DataProviderRequiredDuchies.DataProviderId = DataProviders.DataProviderId
) AS DataProviderRequiredDuchies,
FROM DataProviders
JOIN DataProviderCertificates ON (
Expand All @@ -73,6 +76,31 @@ class DataProviderReader : SpannerReader<DataProviderReader.Result>() {
.singleOrNullIfEmpty()
}

/**
* Reads the [DataProvider]s by [externalDataProviderIds].
*
* @return list of [Result] in the same iteration order as [externalDataProviderIds]
* @throws DataProviderNotFoundException if no [DataProvider] is found for a specified external ID
*/
suspend fun readByExternalDataProviderIds(
readContext: AsyncDatabaseClient.ReadContext,
externalDataProviderIds: Iterable<ExternalId>,
): List<Result> {
val resultsByExternalId: Map<ExternalId, Result> =
fillStatementBuilder {
appendClause("WHERE ExternalDataProviderId IN UNNEST(@externalDataProviderIds)")
bind("externalDataProviderIds")
.toInt64Array(externalDataProviderIds.map(ExternalId::value))
}
.execute(readContext)
.toList()
.associateBy { ExternalId(it.dataProvider.externalDataProviderId) }

return externalDataProviderIds.map {
resultsByExternalId[it] ?: throw DataProviderNotFoundException(it)
}
}

private fun buildDataProvider(struct: Struct): DataProvider =
DataProvider.newBuilder()
.apply {
Expand Down
Loading

0 comments on commit 40dab02

Please sign in to comment.