Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement internal Access API services #1932

Merged
merged 1 commit into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/main/kotlin/org/wfanet/measurement/access/common/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library")

package(default_visibility = [
"//src/main/kotlin/org/wfanet/measurement/access:__subpackages__",
"//src/test/kotlin/org/wfanet/measurement/access:__subpackages__",
])

kt_jvm_library(
name = "tls_client_principal_mapping",
srcs = ["TlsClientPrincipalMapping.kt"],
deps = [
"//src/main/kotlin/org/wfanet/measurement/common/api:resource_ids",
"//src/main/proto/wfa/measurement/config:authority_key_to_principal_map_kt_jvm_proto",
"@wfa_common_jvm//imports/java/com/google/protobuf",
"@wfa_common_jvm//imports/kotlin/com/google/protobuf/kotlin",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright 2024 The Cross-Media Measurement Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.wfanet.measurement.access.common

import com.google.protobuf.ByteString
import org.wfanet.measurement.common.api.ResourceIds
import org.wfanet.measurement.config.AuthorityKeyToPrincipalMap

/** Mapping for TLS client principals. */
class TlsClientPrincipalMapping(config: AuthorityKeyToPrincipalMap) {
data class TlsClient(
/** ID of the Principal resource. */
val principalResourceId: String,
/** Name of the resource protected by the Policy. */
val protectedResourceName: String,
/** Authority key identifier (AKID) key ID of the certificate. */
val authorityKeyIdentifier: ByteString,
)

private val clientsByPrincipalResourceId: Map<String, TlsClient>
private val clientsByAuthorityKeyIdentifier: Map<ByteString, TlsClient>

init {
val clients =
config.entriesList.map {
val protectedResourceName = it.principalResourceName
val principalResourceId = protectedResourceName.replace("/", "-").takeLast(63)
check(ResourceIds.RFC_1034_REGEX.matches(principalResourceId)) {
"Invalid character in protected resource name $protectedResourceName"
}
TlsClient(principalResourceId, protectedResourceName, it.authorityKeyIdentifier)
}

clientsByPrincipalResourceId = clients.associateBy(TlsClient::principalResourceId)
clientsByAuthorityKeyIdentifier = clients.associateBy(TlsClient::authorityKeyIdentifier)
}

/** Returns the [TlsClient] for the specified [principalResourceId], or `null` if not found. */
fun getByPrincipalResourceId(principalResourceId: String): TlsClient? =
clientsByPrincipalResourceId[principalResourceId]

/** Returns the [TlsClient] for the specified [authorityKeyIdentifier], or `null` if not found. */
fun getByAuthorityKeyIdentifier(authorityKeyIdentifier: ByteString): TlsClient? =
clientsByAuthorityKeyIdentifier[authorityKeyIdentifier]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library")

package(
default_visibility = [
"//src/main/kotlin/org/wfanet/measurement/access/deploy:__subpackages__",
"//src/test/kotlin/org/wfanet/measurement/access/deploy:__subpackages__",
],
)

kt_jvm_library(
name = "spanner_principals_service",
srcs = ["SpannerPrincipalsService.kt"],
deps = [
"//src/main/kotlin/org/wfanet/measurement/access/common:tls_client_principal_mapping",
"//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db",
"//src/main/kotlin/org/wfanet/measurement/access/service/internal:id_generator",
"//src/main/proto/wfa/measurement/internal/access:principals_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner",
],
)

kt_jvm_library(
name = "spanner_permissions_service",
srcs = ["SpannerPermissionsService.kt"],
deps = [
"//src/main/kotlin/org/wfanet/measurement/access/common:tls_client_principal_mapping",
"//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db",
"//src/main/proto/wfa/measurement/internal/access:permissions_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner",
],
)

kt_jvm_library(
name = "spanner_roles_service",
srcs = ["SpannerRolesService.kt"],
deps = [
"//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db",
"//src/main/kotlin/org/wfanet/measurement/access/service/internal:errors",
"//src/main/kotlin/org/wfanet/measurement/access/service/internal:id_generator",
"//src/main/kotlin/org/wfanet/measurement/access/service/internal:permission_mapping",
"//src/main/kotlin/org/wfanet/measurement/common/api:etags",
"//src/main/proto/wfa/measurement/internal/access:roles_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner",
],
)

kt_jvm_library(
name = "spanner_policies_service",
srcs = ["SpannerPoliciesService.kt"],
deps = [
"//src/main/kotlin/org/wfanet/measurement/access/common:tls_client_principal_mapping",
"//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db",
"//src/main/kotlin/org/wfanet/measurement/access/service/internal:id_generator",
"//src/main/kotlin/org/wfanet/measurement/common/api:etags",
"//src/main/proto/wfa/measurement/internal/access:policies_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Copyright 2024 The Cross-Media Measurement Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.wfanet.measurement.access.deploy.gcloud.spanner

import io.grpc.Status
import kotlin.math.abs
import org.wfanet.measurement.access.common.TlsClientPrincipalMapping
import org.wfanet.measurement.access.deploy.gcloud.spanner.db.checkPermissions
import org.wfanet.measurement.access.deploy.gcloud.spanner.db.getPrincipalIdByResourceId
import org.wfanet.measurement.access.service.internal.InvalidFieldValueException
import org.wfanet.measurement.access.service.internal.PermissionMapping
import org.wfanet.measurement.access.service.internal.PermissionNotFoundException
import org.wfanet.measurement.access.service.internal.PrincipalNotFoundException
import org.wfanet.measurement.access.service.internal.RequiredFieldNotSetException
import org.wfanet.measurement.access.service.internal.toPermission
import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient
import org.wfanet.measurement.internal.access.CheckPermissionsRequest
import org.wfanet.measurement.internal.access.CheckPermissionsResponse
import org.wfanet.measurement.internal.access.GetPermissionRequest
import org.wfanet.measurement.internal.access.ListPermissionsPageTokenKt
import org.wfanet.measurement.internal.access.ListPermissionsRequest
import org.wfanet.measurement.internal.access.ListPermissionsResponse
import org.wfanet.measurement.internal.access.Permission
import org.wfanet.measurement.internal.access.PermissionsGrpcKt
import org.wfanet.measurement.internal.access.checkPermissionsResponse
import org.wfanet.measurement.internal.access.listPermissionsPageToken
import org.wfanet.measurement.internal.access.listPermissionsResponse

class SpannerPermissionsService(
private val databaseClient: AsyncDatabaseClient,
private val permissionMapping: PermissionMapping,
private val tlsClientMapping: TlsClientPrincipalMapping,
) : PermissionsGrpcKt.PermissionsCoroutineImplBase() {
override suspend fun getPermission(request: GetPermissionRequest): Permission {
if (request.permissionResourceId.isEmpty()) {
throw RequiredFieldNotSetException("permission_resource_id")
.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT)
}
val mappingPermission: PermissionMapping.Permission =
permissionMapping.getPermissionByResourceId(request.permissionResourceId)
?: throw PermissionNotFoundException(request.permissionResourceId)
.asStatusRuntimeException(Status.Code.NOT_FOUND)

return mappingPermission.toPermission()
}

override suspend fun listPermissions(request: ListPermissionsRequest): ListPermissionsResponse {
if (request.pageSize < 0) {
throw InvalidFieldValueException("page_size") { fieldName ->
"$fieldName must be non-negative"
}
.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT)
}
val pageSize: Int =
if (request.pageSize == 0) DEFAULT_PAGE_SIZE else request.pageSize.coerceAtMost(MAX_PAGE_SIZE)

val mappingPermissions: List<PermissionMapping.Permission> = permissionMapping.permissions
val fromIndex: Int =
if (request.hasPageToken()) {
val searchResult: Int =
mappingPermissions.binarySearchBy(request.pageToken.after.permissionResourceId) {
it.permissionResourceId
}
abs(searchResult + 1)
} else {
0
}
val toIndex: Int = (fromIndex + pageSize).coerceAtMost(mappingPermissions.size)
val permissions: List<Permission> =
mappingPermissions.subList(fromIndex, toIndex).map { it.toPermission() }

return listPermissionsResponse {
this.permissions += permissions
if (toIndex < mappingPermissions.size) {
nextPageToken = listPermissionsPageToken {
after =
ListPermissionsPageTokenKt.after {
permissionResourceId = permissions.last().permissionResourceId
}
}
}
}
}

override suspend fun checkPermissions(
request: CheckPermissionsRequest
): CheckPermissionsResponse {
if (request.principalResourceId.isEmpty()) {
throw RequiredFieldNotSetException("principal_resource_id")
.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT)
}
if (request.permissionResourceIdsList.isEmpty()) {
throw RequiredFieldNotSetException("permission_resource_ids")
.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT)
}
val permissionIds =
request.permissionResourceIdsList.map {
val mappingPermission =
permissionMapping.getPermissionByResourceId(it)
?: throw PermissionNotFoundException(it)
.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION)
mappingPermission.permissionId
}

val tlsClient: TlsClientPrincipalMapping.TlsClient? =
tlsClientMapping.getByPrincipalResourceId(request.principalResourceId)
if (tlsClient != null) {
val permissionResourceIds: List<String> =
if (request.protectedResourceName == tlsClient.protectedResourceName) {
request.permissionResourceIdsList
} else {
emptyList()
}
return checkPermissionsResponse { this.permissionResourceIds += permissionResourceIds }
}

return try {
val grantedPermissionIds: List<Long> =
databaseClient.readOnlyTransaction().use { txn ->
val principalId: Long = txn.getPrincipalIdByResourceId(request.principalResourceId)
txn.checkPermissions(request.protectedResourceName, principalId, permissionIds)
}
checkPermissionsResponse {
permissionResourceIds +=
grantedPermissionIds.map {
checkNotNull(permissionMapping.getPermissionById(it)).permissionResourceId
}
}
} catch (e: PrincipalNotFoundException) {
throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION)
}
}

companion object {
private const val MAX_PAGE_SIZE = 100
private const val DEFAULT_PAGE_SIZE = 50
}
}
Loading