Skip to content

Commit

Permalink
feat: Implement internal Access API services (#1932)
Browse files Browse the repository at this point in the history
  • Loading branch information
SanjayVas authored Dec 10, 2024
1 parent a121af7 commit 2dc095b
Show file tree
Hide file tree
Showing 26 changed files with 3,842 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/main/kotlin/org/wfanet/measurement/access/common/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library")

package(default_visibility = [
"//src/main/kotlin/org/wfanet/measurement/access:__subpackages__",
"//src/test/kotlin/org/wfanet/measurement/access:__subpackages__",
])

kt_jvm_library(
name = "tls_client_principal_mapping",
srcs = ["TlsClientPrincipalMapping.kt"],
deps = [
"//src/main/kotlin/org/wfanet/measurement/common/api:resource_ids",
"//src/main/proto/wfa/measurement/config:authority_key_to_principal_map_kt_jvm_proto",
"@wfa_common_jvm//imports/java/com/google/protobuf",
"@wfa_common_jvm//imports/kotlin/com/google/protobuf/kotlin",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright 2024 The Cross-Media Measurement Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.wfanet.measurement.access.common

import com.google.protobuf.ByteString
import org.wfanet.measurement.common.api.ResourceIds
import org.wfanet.measurement.config.AuthorityKeyToPrincipalMap

/** Mapping for TLS client principals. */
class TlsClientPrincipalMapping(config: AuthorityKeyToPrincipalMap) {
data class TlsClient(
/** ID of the Principal resource. */
val principalResourceId: String,
/** Name of the resource protected by the Policy. */
val protectedResourceName: String,
/** Authority key identifier (AKID) key ID of the certificate. */
val authorityKeyIdentifier: ByteString,
)

private val clientsByPrincipalResourceId: Map<String, TlsClient>
private val clientsByAuthorityKeyIdentifier: Map<ByteString, TlsClient>

init {
val clients =
config.entriesList.map {
val protectedResourceName = it.principalResourceName
val principalResourceId = protectedResourceName.replace("/", "-").takeLast(63)
check(ResourceIds.RFC_1034_REGEX.matches(principalResourceId)) {
"Invalid character in protected resource name $protectedResourceName"
}
TlsClient(principalResourceId, protectedResourceName, it.authorityKeyIdentifier)
}

clientsByPrincipalResourceId = clients.associateBy(TlsClient::principalResourceId)
clientsByAuthorityKeyIdentifier = clients.associateBy(TlsClient::authorityKeyIdentifier)
}

/** Returns the [TlsClient] for the specified [principalResourceId], or `null` if not found. */
fun getByPrincipalResourceId(principalResourceId: String): TlsClient? =
clientsByPrincipalResourceId[principalResourceId]

/** Returns the [TlsClient] for the specified [authorityKeyIdentifier], or `null` if not found. */
fun getByAuthorityKeyIdentifier(authorityKeyIdentifier: ByteString): TlsClient? =
clientsByAuthorityKeyIdentifier[authorityKeyIdentifier]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library")

package(
default_visibility = [
"//src/main/kotlin/org/wfanet/measurement/access/deploy:__subpackages__",
"//src/test/kotlin/org/wfanet/measurement/access/deploy:__subpackages__",
],
)

kt_jvm_library(
name = "spanner_principals_service",
srcs = ["SpannerPrincipalsService.kt"],
deps = [
"//src/main/kotlin/org/wfanet/measurement/access/common:tls_client_principal_mapping",
"//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db",
"//src/main/kotlin/org/wfanet/measurement/access/service/internal:id_generator",
"//src/main/proto/wfa/measurement/internal/access:principals_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner",
],
)

kt_jvm_library(
name = "spanner_permissions_service",
srcs = ["SpannerPermissionsService.kt"],
deps = [
"//src/main/kotlin/org/wfanet/measurement/access/common:tls_client_principal_mapping",
"//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db",
"//src/main/proto/wfa/measurement/internal/access:permissions_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner",
],
)

kt_jvm_library(
name = "spanner_roles_service",
srcs = ["SpannerRolesService.kt"],
deps = [
"//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db",
"//src/main/kotlin/org/wfanet/measurement/access/service/internal:errors",
"//src/main/kotlin/org/wfanet/measurement/access/service/internal:id_generator",
"//src/main/kotlin/org/wfanet/measurement/access/service/internal:permission_mapping",
"//src/main/kotlin/org/wfanet/measurement/common/api:etags",
"//src/main/proto/wfa/measurement/internal/access:roles_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner",
],
)

kt_jvm_library(
name = "spanner_policies_service",
srcs = ["SpannerPoliciesService.kt"],
deps = [
"//src/main/kotlin/org/wfanet/measurement/access/common:tls_client_principal_mapping",
"//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db",
"//src/main/kotlin/org/wfanet/measurement/access/service/internal:id_generator",
"//src/main/kotlin/org/wfanet/measurement/common/api:etags",
"//src/main/proto/wfa/measurement/internal/access:policies_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Copyright 2024 The Cross-Media Measurement Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.wfanet.measurement.access.deploy.gcloud.spanner

import io.grpc.Status
import kotlin.math.abs
import org.wfanet.measurement.access.common.TlsClientPrincipalMapping
import org.wfanet.measurement.access.deploy.gcloud.spanner.db.checkPermissions
import org.wfanet.measurement.access.deploy.gcloud.spanner.db.getPrincipalIdByResourceId
import org.wfanet.measurement.access.service.internal.InvalidFieldValueException
import org.wfanet.measurement.access.service.internal.PermissionMapping
import org.wfanet.measurement.access.service.internal.PermissionNotFoundException
import org.wfanet.measurement.access.service.internal.PrincipalNotFoundException
import org.wfanet.measurement.access.service.internal.RequiredFieldNotSetException
import org.wfanet.measurement.access.service.internal.toPermission
import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient
import org.wfanet.measurement.internal.access.CheckPermissionsRequest
import org.wfanet.measurement.internal.access.CheckPermissionsResponse
import org.wfanet.measurement.internal.access.GetPermissionRequest
import org.wfanet.measurement.internal.access.ListPermissionsPageTokenKt
import org.wfanet.measurement.internal.access.ListPermissionsRequest
import org.wfanet.measurement.internal.access.ListPermissionsResponse
import org.wfanet.measurement.internal.access.Permission
import org.wfanet.measurement.internal.access.PermissionsGrpcKt
import org.wfanet.measurement.internal.access.checkPermissionsResponse
import org.wfanet.measurement.internal.access.listPermissionsPageToken
import org.wfanet.measurement.internal.access.listPermissionsResponse

class SpannerPermissionsService(
private val databaseClient: AsyncDatabaseClient,
private val permissionMapping: PermissionMapping,
private val tlsClientMapping: TlsClientPrincipalMapping,
) : PermissionsGrpcKt.PermissionsCoroutineImplBase() {
override suspend fun getPermission(request: GetPermissionRequest): Permission {
if (request.permissionResourceId.isEmpty()) {
throw RequiredFieldNotSetException("permission_resource_id")
.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT)
}
val mappingPermission: PermissionMapping.Permission =
permissionMapping.getPermissionByResourceId(request.permissionResourceId)
?: throw PermissionNotFoundException(request.permissionResourceId)
.asStatusRuntimeException(Status.Code.NOT_FOUND)

return mappingPermission.toPermission()
}

override suspend fun listPermissions(request: ListPermissionsRequest): ListPermissionsResponse {
if (request.pageSize < 0) {
throw InvalidFieldValueException("page_size") { fieldName ->
"$fieldName must be non-negative"
}
.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT)
}
val pageSize: Int =
if (request.pageSize == 0) DEFAULT_PAGE_SIZE else request.pageSize.coerceAtMost(MAX_PAGE_SIZE)

val mappingPermissions: List<PermissionMapping.Permission> = permissionMapping.permissions
val fromIndex: Int =
if (request.hasPageToken()) {
val searchResult: Int =
mappingPermissions.binarySearchBy(request.pageToken.after.permissionResourceId) {
it.permissionResourceId
}
abs(searchResult + 1)
} else {
0
}
val toIndex: Int = (fromIndex + pageSize).coerceAtMost(mappingPermissions.size)
val permissions: List<Permission> =
mappingPermissions.subList(fromIndex, toIndex).map { it.toPermission() }

return listPermissionsResponse {
this.permissions += permissions
if (toIndex < mappingPermissions.size) {
nextPageToken = listPermissionsPageToken {
after =
ListPermissionsPageTokenKt.after {
permissionResourceId = permissions.last().permissionResourceId
}
}
}
}
}

override suspend fun checkPermissions(
request: CheckPermissionsRequest
): CheckPermissionsResponse {
if (request.principalResourceId.isEmpty()) {
throw RequiredFieldNotSetException("principal_resource_id")
.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT)
}
if (request.permissionResourceIdsList.isEmpty()) {
throw RequiredFieldNotSetException("permission_resource_ids")
.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT)
}
val permissionIds =
request.permissionResourceIdsList.map {
val mappingPermission =
permissionMapping.getPermissionByResourceId(it)
?: throw PermissionNotFoundException(it)
.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION)
mappingPermission.permissionId
}

val tlsClient: TlsClientPrincipalMapping.TlsClient? =
tlsClientMapping.getByPrincipalResourceId(request.principalResourceId)
if (tlsClient != null) {
val permissionResourceIds: List<String> =
if (request.protectedResourceName == tlsClient.protectedResourceName) {
request.permissionResourceIdsList
} else {
emptyList()
}
return checkPermissionsResponse { this.permissionResourceIds += permissionResourceIds }
}

return try {
val grantedPermissionIds: List<Long> =
databaseClient.readOnlyTransaction().use { txn ->
val principalId: Long = txn.getPrincipalIdByResourceId(request.principalResourceId)
txn.checkPermissions(request.protectedResourceName, principalId, permissionIds)
}
checkPermissionsResponse {
permissionResourceIds +=
grantedPermissionIds.map {
checkNotNull(permissionMapping.getPermissionById(it)).permissionResourceId
}
}
} catch (e: PrincipalNotFoundException) {
throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION)
}
}

companion object {
private const val MAX_PAGE_SIZE = 100
private const val DEFAULT_PAGE_SIZE = 50
}
}
Loading

0 comments on commit 2dc095b

Please sign in to comment.