diff --git a/src/main/kotlin/org/wfanet/measurement/access/common/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/access/common/BUILD.bazel new file mode 100644 index 00000000000..e760550192e --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/common/BUILD.bazel @@ -0,0 +1,17 @@ +load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") + +package(default_visibility = [ + "//src/main/kotlin/org/wfanet/measurement/access:__subpackages__", + "//src/test/kotlin/org/wfanet/measurement/access:__subpackages__", +]) + +kt_jvm_library( + name = "tls_client_principal_mapping", + srcs = ["TlsClientPrincipalMapping.kt"], + deps = [ + "//src/main/kotlin/org/wfanet/measurement/common/api:resource_ids", + "//src/main/proto/wfa/measurement/config:authority_key_to_principal_map_kt_jvm_proto", + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//imports/kotlin/com/google/protobuf/kotlin", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/access/common/TlsClientPrincipalMapping.kt b/src/main/kotlin/org/wfanet/measurement/access/common/TlsClientPrincipalMapping.kt new file mode 100644 index 00000000000..8b04ea3a753 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/common/TlsClientPrincipalMapping.kt @@ -0,0 +1,59 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.common + +import com.google.protobuf.ByteString +import org.wfanet.measurement.common.api.ResourceIds +import org.wfanet.measurement.config.AuthorityKeyToPrincipalMap + +/** Mapping for TLS client principals. */ +class TlsClientPrincipalMapping(config: AuthorityKeyToPrincipalMap) { + data class TlsClient( + /** ID of the Principal resource. */ + val principalResourceId: String, + /** Name of the resource protected by the Policy. */ + val protectedResourceName: String, + /** Authority key identifier (AKID) key ID of the certificate. */ + val authorityKeyIdentifier: ByteString, + ) + + private val clientsByPrincipalResourceId: Map + private val clientsByAuthorityKeyIdentifier: Map + + init { + val clients = + config.entriesList.map { + val protectedResourceName = it.principalResourceName + val principalResourceId = protectedResourceName.replace("/", "-").takeLast(63) + check(ResourceIds.RFC_1034_REGEX.matches(principalResourceId)) { + "Invalid character in protected resource name $protectedResourceName" + } + TlsClient(principalResourceId, protectedResourceName, it.authorityKeyIdentifier) + } + + clientsByPrincipalResourceId = clients.associateBy(TlsClient::principalResourceId) + clientsByAuthorityKeyIdentifier = clients.associateBy(TlsClient::authorityKeyIdentifier) + } + + /** Returns the [TlsClient] for the specified [principalResourceId], or `null` if not found. */ + fun getByPrincipalResourceId(principalResourceId: String): TlsClient? = + clientsByPrincipalResourceId[principalResourceId] + + /** Returns the [TlsClient] for the specified [authorityKeyIdentifier], or `null` if not found. */ + fun getByAuthorityKeyIdentifier(authorityKeyIdentifier: ByteString): TlsClient? = + clientsByAuthorityKeyIdentifier[authorityKeyIdentifier] +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/BUILD.bazel new file mode 100644 index 00000000000..8edc9bd4a50 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/BUILD.bazel @@ -0,0 +1,58 @@ +load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") + +package( + default_visibility = [ + "//src/main/kotlin/org/wfanet/measurement/access/deploy:__subpackages__", + "//src/test/kotlin/org/wfanet/measurement/access/deploy:__subpackages__", + ], +) + +kt_jvm_library( + name = "spanner_principals_service", + srcs = ["SpannerPrincipalsService.kt"], + deps = [ + "//src/main/kotlin/org/wfanet/measurement/access/common:tls_client_principal_mapping", + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:id_generator", + "//src/main/proto/wfa/measurement/internal/access:principals_service_kt_jvm_grpc_proto", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner", + ], +) + +kt_jvm_library( + name = "spanner_permissions_service", + srcs = ["SpannerPermissionsService.kt"], + deps = [ + "//src/main/kotlin/org/wfanet/measurement/access/common:tls_client_principal_mapping", + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db", + "//src/main/proto/wfa/measurement/internal/access:permissions_service_kt_jvm_grpc_proto", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner", + ], +) + +kt_jvm_library( + name = "spanner_roles_service", + srcs = ["SpannerRolesService.kt"], + deps = [ + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:errors", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:id_generator", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:permission_mapping", + "//src/main/kotlin/org/wfanet/measurement/common/api:etags", + "//src/main/proto/wfa/measurement/internal/access:roles_service_kt_jvm_grpc_proto", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner", + ], +) + +kt_jvm_library( + name = "spanner_policies_service", + srcs = ["SpannerPoliciesService.kt"], + deps = [ + "//src/main/kotlin/org/wfanet/measurement/access/common:tls_client_principal_mapping", + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:id_generator", + "//src/main/kotlin/org/wfanet/measurement/common/api:etags", + "//src/main/proto/wfa/measurement/internal/access:policies_service_kt_jvm_grpc_proto", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPermissionsService.kt b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPermissionsService.kt new file mode 100644 index 00000000000..52ce73665c1 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPermissionsService.kt @@ -0,0 +1,152 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.deploy.gcloud.spanner + +import io.grpc.Status +import kotlin.math.abs +import org.wfanet.measurement.access.common.TlsClientPrincipalMapping +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.checkPermissions +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.getPrincipalIdByResourceId +import org.wfanet.measurement.access.service.internal.InvalidFieldValueException +import org.wfanet.measurement.access.service.internal.PermissionMapping +import org.wfanet.measurement.access.service.internal.PermissionNotFoundException +import org.wfanet.measurement.access.service.internal.PrincipalNotFoundException +import org.wfanet.measurement.access.service.internal.RequiredFieldNotSetException +import org.wfanet.measurement.access.service.internal.toPermission +import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient +import org.wfanet.measurement.internal.access.CheckPermissionsRequest +import org.wfanet.measurement.internal.access.CheckPermissionsResponse +import org.wfanet.measurement.internal.access.GetPermissionRequest +import org.wfanet.measurement.internal.access.ListPermissionsPageTokenKt +import org.wfanet.measurement.internal.access.ListPermissionsRequest +import org.wfanet.measurement.internal.access.ListPermissionsResponse +import org.wfanet.measurement.internal.access.Permission +import org.wfanet.measurement.internal.access.PermissionsGrpcKt +import org.wfanet.measurement.internal.access.checkPermissionsResponse +import org.wfanet.measurement.internal.access.listPermissionsPageToken +import org.wfanet.measurement.internal.access.listPermissionsResponse + +class SpannerPermissionsService( + private val databaseClient: AsyncDatabaseClient, + private val permissionMapping: PermissionMapping, + private val tlsClientMapping: TlsClientPrincipalMapping, +) : PermissionsGrpcKt.PermissionsCoroutineImplBase() { + override suspend fun getPermission(request: GetPermissionRequest): Permission { + if (request.permissionResourceId.isEmpty()) { + throw RequiredFieldNotSetException("permission_resource_id") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + val mappingPermission: PermissionMapping.Permission = + permissionMapping.getPermissionByResourceId(request.permissionResourceId) + ?: throw PermissionNotFoundException(request.permissionResourceId) + .asStatusRuntimeException(Status.Code.NOT_FOUND) + + return mappingPermission.toPermission() + } + + override suspend fun listPermissions(request: ListPermissionsRequest): ListPermissionsResponse { + if (request.pageSize < 0) { + throw InvalidFieldValueException("page_size") { fieldName -> + "$fieldName must be non-negative" + } + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + val pageSize: Int = + if (request.pageSize == 0) DEFAULT_PAGE_SIZE else request.pageSize.coerceAtMost(MAX_PAGE_SIZE) + + val mappingPermissions: List = permissionMapping.permissions + val fromIndex: Int = + if (request.hasPageToken()) { + val searchResult: Int = + mappingPermissions.binarySearchBy(request.pageToken.after.permissionResourceId) { + it.permissionResourceId + } + abs(searchResult + 1) + } else { + 0 + } + val toIndex: Int = (fromIndex + pageSize).coerceAtMost(mappingPermissions.size) + val permissions: List = + mappingPermissions.subList(fromIndex, toIndex).map { it.toPermission() } + + return listPermissionsResponse { + this.permissions += permissions + if (toIndex < mappingPermissions.size) { + nextPageToken = listPermissionsPageToken { + after = + ListPermissionsPageTokenKt.after { + permissionResourceId = permissions.last().permissionResourceId + } + } + } + } + } + + override suspend fun checkPermissions( + request: CheckPermissionsRequest + ): CheckPermissionsResponse { + if (request.principalResourceId.isEmpty()) { + throw RequiredFieldNotSetException("principal_resource_id") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + if (request.permissionResourceIdsList.isEmpty()) { + throw RequiredFieldNotSetException("permission_resource_ids") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + val permissionIds = + request.permissionResourceIdsList.map { + val mappingPermission = + permissionMapping.getPermissionByResourceId(it) + ?: throw PermissionNotFoundException(it) + .asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + mappingPermission.permissionId + } + + val tlsClient: TlsClientPrincipalMapping.TlsClient? = + tlsClientMapping.getByPrincipalResourceId(request.principalResourceId) + if (tlsClient != null) { + val permissionResourceIds: List = + if (request.protectedResourceName == tlsClient.protectedResourceName) { + request.permissionResourceIdsList + } else { + emptyList() + } + return checkPermissionsResponse { this.permissionResourceIds += permissionResourceIds } + } + + return try { + val grantedPermissionIds: List = + databaseClient.readOnlyTransaction().use { txn -> + val principalId: Long = txn.getPrincipalIdByResourceId(request.principalResourceId) + txn.checkPermissions(request.protectedResourceName, principalId, permissionIds) + } + checkPermissionsResponse { + permissionResourceIds += + grantedPermissionIds.map { + checkNotNull(permissionMapping.getPermissionById(it)).permissionResourceId + } + } + } catch (e: PrincipalNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + } + } + + companion object { + private const val MAX_PAGE_SIZE = 100 + private const val DEFAULT_PAGE_SIZE = 50 + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPoliciesService.kt b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPoliciesService.kt new file mode 100644 index 00000000000..5001c9c998e --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPoliciesService.kt @@ -0,0 +1,329 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.deploy.gcloud.spanner + +import com.google.cloud.spanner.ErrorCode +import com.google.cloud.spanner.SpannerException +import com.google.protobuf.Timestamp +import io.grpc.Status +import org.wfanet.measurement.access.common.TlsClientPrincipalMapping +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.deletePolicyBinding +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.getPolicyByProtectedResourceName +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.getPolicyByResourceId +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.getPrincipalIdsByResourceIds +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.getRoleIdByResourceId +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.getRoleIdsByResourceIds +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.insertPolicy +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.insertPolicyBinding +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.policyExists +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.updatePolicy +import org.wfanet.measurement.access.service.internal.EtagMismatchException +import org.wfanet.measurement.access.service.internal.IdGenerator +import org.wfanet.measurement.access.service.internal.PolicyAlreadyExistsException +import org.wfanet.measurement.access.service.internal.PolicyBindingMembershipAlreadyExistsException +import org.wfanet.measurement.access.service.internal.PolicyBindingMembershipNotFoundException +import org.wfanet.measurement.access.service.internal.PolicyNotFoundException +import org.wfanet.measurement.access.service.internal.PolicyNotFoundForProtectedResourceException +import org.wfanet.measurement.access.service.internal.PrincipalNotFoundException +import org.wfanet.measurement.access.service.internal.PrincipalTypeNotSupportedException +import org.wfanet.measurement.access.service.internal.RequiredFieldNotSetException +import org.wfanet.measurement.access.service.internal.RoleNotFoundException +import org.wfanet.measurement.access.service.internal.generateNewId +import org.wfanet.measurement.common.api.ETags +import org.wfanet.measurement.common.toInstant +import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient +import org.wfanet.measurement.internal.access.AddPolicyBindingMembersRequest +import org.wfanet.measurement.internal.access.GetPolicyRequest +import org.wfanet.measurement.internal.access.LookupPolicyRequest +import org.wfanet.measurement.internal.access.PoliciesGrpcKt +import org.wfanet.measurement.internal.access.Policy +import org.wfanet.measurement.internal.access.PolicyKt +import org.wfanet.measurement.internal.access.Principal +import org.wfanet.measurement.internal.access.RemovePolicyBindingMembersRequest +import org.wfanet.measurement.internal.access.copy +import org.wfanet.measurement.internal.access.policy + +class SpannerPoliciesService( + private val databaseClient: AsyncDatabaseClient, + private val tlsClientMapping: TlsClientPrincipalMapping, + private val idGenerator: IdGenerator = IdGenerator.Default, +) : PoliciesGrpcKt.PoliciesCoroutineImplBase() { + override suspend fun getPolicy(request: GetPolicyRequest): Policy { + if (request.policyResourceId.isEmpty()) { + throw RequiredFieldNotSetException("policy_resource_id") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + + return try { + databaseClient.singleUse().use { txn -> + txn.getPolicyByResourceId(request.policyResourceId).policy + } + } catch (e: PolicyNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) + } + } + + override suspend fun lookupPolicy(request: LookupPolicyRequest): Policy { + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf oneof case enums cannot be null. + return when (request.lookupKeyCase) { + LookupPolicyRequest.LookupKeyCase.PROTECTED_RESOURCE_NAME -> { + try { + databaseClient.singleUse().use { txn -> + txn.getPolicyByProtectedResourceName(request.protectedResourceName).policy + } + } catch (e: PolicyNotFoundForProtectedResourceException) { + throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) + } + } + LookupPolicyRequest.LookupKeyCase.LOOKUPKEY_NOT_SET -> + throw RequiredFieldNotSetException("lookup_key") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + } + + override suspend fun createPolicy(request: Policy): Policy { + if (request.policyResourceId.isEmpty()) { + throw RequiredFieldNotSetException("policy_resource_id") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + val bindings: Map> = + request.bindingsMap.mapValues { (_, value) -> + value.memberPrincipalResourceIdsList.toSet().also { + if (it.isEmpty()) + throw RequiredFieldNotSetException("member_principal_resource_ids") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + } + val principalResourceIds = buildSet { bindings.values.forEach { addAll(it) } } + try { + checkPrincipalTypes(principalResourceIds) + } catch (e: PrincipalTypeNotSupportedException) { + throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + } + + val transactionRunner = databaseClient.readWriteTransaction() + return try { + transactionRunner.run { txn -> + val policyId = idGenerator.generateNewId { id -> txn.policyExists(id) } + txn.insertPolicy(policyId, request.policyResourceId, request.protectedResourceName) + val roleIdByResourceId: Map = txn.getRoleIdsByResourceIds(bindings.keys) + val principalIdByResourceId: Map = + txn.getPrincipalIdsByResourceIds(principalResourceIds) + for ((roleResourceId, memberPrincipalResourceIds) in bindings) { + val roleId = roleIdByResourceId.getValue(roleResourceId) + for (principalResourceId in memberPrincipalResourceIds) { + val principalId = principalIdByResourceId.getValue(principalResourceId) + txn.insertPolicyBinding(policyId, roleId, principalId) + } + } + } + val commitTimestamp: Timestamp = transactionRunner.getCommitTimestamp().toProto() + + policy { + policyResourceId = request.policyResourceId + protectedResourceName = request.protectedResourceName + createTime = commitTimestamp + updateTime = commitTimestamp + etag = ETags.computeETag(updateTime.toInstant()) + + for ((roleResourceId, memberPrincipalResourceIds) in bindings) { + this.bindings[roleResourceId] = + PolicyKt.members { this.memberPrincipalResourceIds += memberPrincipalResourceIds } + } + } + } catch (e: SpannerException) { + if (e.errorCode == ErrorCode.ALREADY_EXISTS) { + throw PolicyAlreadyExistsException(e).asStatusRuntimeException(Status.Code.ALREADY_EXISTS) + } else { + throw e + } + } catch (e: RoleNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + } catch (e: PrincipalNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + } + } + + override suspend fun addPolicyBindingMembers(request: AddPolicyBindingMembersRequest): Policy { + if (request.policyResourceId.isEmpty()) { + throw RequiredFieldNotSetException("policy_resource_id") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + if (request.roleResourceId.isEmpty()) { + throw RequiredFieldNotSetException("role_resource_id") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + if (request.memberPrincipalResourceIdsList.isEmpty()) { + throw RequiredFieldNotSetException("member_principal_resource_ids") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + val memberPrincipalResourceIds: Set = request.memberPrincipalResourceIdsList.toSet() + try { + checkPrincipalTypes(memberPrincipalResourceIds) + } catch (e: PrincipalTypeNotSupportedException) { + throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + } + + val transactionRunner = databaseClient.readWriteTransaction() + return try { + val policy = + transactionRunner.run { txn -> + val (policyId, policy) = txn.getPolicyByResourceId(request.policyResourceId) + if (request.etag.isNotEmpty()) { + EtagMismatchException.check(request.etag, policy.etag) + } + val existingPrincipalResourceIds = + policy.bindingsMap + .getOrDefault(request.roleResourceId, Policy.Members.getDefaultInstance()) + .memberPrincipalResourceIdsList + .toSet() + for (principalResourceId in memberPrincipalResourceIds) { + if (principalResourceId in existingPrincipalResourceIds) { + throw PolicyBindingMembershipAlreadyExistsException( + request.policyResourceId, + request.roleResourceId, + principalResourceId, + ) + } + } + + val roleId = txn.getRoleIdByResourceId(request.roleResourceId) + val principalIdByResourceId = txn.getPrincipalIdsByResourceIds(memberPrincipalResourceIds) + for (principalId in principalIdByResourceId.values) { + txn.insertPolicyBinding(policyId, roleId, principalId) + } + txn.updatePolicy(policyId) + policy + } + val commitTimestamp: Timestamp = transactionRunner.getCommitTimestamp().toProto() + + policy.copy { + updateTime = commitTimestamp + etag = ETags.computeETag(updateTime.toInstant()) + bindings[request.roleResourceId] = + bindings.getOrDefault(request.roleResourceId, Policy.Members.getDefaultInstance()).copy { + this.memberPrincipalResourceIds += memberPrincipalResourceIds + } + } + } catch (e: PolicyNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) + } catch (e: RoleNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + } catch (e: PrincipalNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + } catch (e: PolicyBindingMembershipAlreadyExistsException) { + throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + } catch (e: EtagMismatchException) { + throw e.asStatusRuntimeException(Status.Code.ABORTED) + } + } + + override suspend fun removePolicyBindingMembers( + request: RemovePolicyBindingMembersRequest + ): Policy { + if (request.policyResourceId.isEmpty()) { + throw RequiredFieldNotSetException("policy_resource_id") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + if (request.roleResourceId.isEmpty()) { + throw RequiredFieldNotSetException("role_resource_id") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + if (request.memberPrincipalResourceIdsList.isEmpty()) { + throw RequiredFieldNotSetException("member_principal_resource_ids") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + val memberPrincipalResourceIds: Set = request.memberPrincipalResourceIdsList.toSet() + + val transactionRunner = databaseClient.readWriteTransaction() + return try { + val policy: Policy = + transactionRunner.run { txn -> + val (policyId, policy) = txn.getPolicyByResourceId(request.policyResourceId) + if (request.etag.isNotEmpty()) { + EtagMismatchException.check(request.etag, policy.etag) + } + val existingPrincipalResourceIds = + policy.bindingsMap + .getOrDefault(request.roleResourceId, Policy.Members.getDefaultInstance()) + .memberPrincipalResourceIdsList + .toSet() + for (principalResourceId in memberPrincipalResourceIds) { + if (principalResourceId !in existingPrincipalResourceIds) { + throw PolicyBindingMembershipNotFoundException( + request.policyResourceId, + request.roleResourceId, + principalResourceId, + ) + } + } + + val roleId = txn.getRoleIdByResourceId(request.roleResourceId) + val principalIdByResourceId = txn.getPrincipalIdsByResourceIds(memberPrincipalResourceIds) + for (principalId in principalIdByResourceId.values) { + txn.deletePolicyBinding(policyId, roleId, principalId) + } + txn.updatePolicy(policyId) + policy + } + val commitTimestamp: Timestamp = transactionRunner.getCommitTimestamp().toProto() + + val updatedMemberPrincipalResourceIds: List = + policy.bindingsMap.getValue(request.roleResourceId).memberPrincipalResourceIdsList.filter { + it !in memberPrincipalResourceIds + } + policy.copy { + updateTime = commitTimestamp + etag = ETags.computeETag(updateTime.toInstant()) + if (updatedMemberPrincipalResourceIds.isEmpty()) { + bindings.remove(request.roleResourceId) + } else { + bindings[request.roleResourceId] = + PolicyKt.members { + this.memberPrincipalResourceIds += updatedMemberPrincipalResourceIds + } + } + } + } catch (e: PolicyNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) + } catch (e: RoleNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.INTERNAL) + } catch (e: PrincipalNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.INTERNAL) + } catch (e: PolicyBindingMembershipNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + } catch (e: EtagMismatchException) { + throw e.asStatusRuntimeException(Status.Code.ABORTED) + } + } + + /** + * Checks that each [Principal] represented by [principalResourceIds] has a supported type. + * + * @throws PrincipalTypeNotSupportedException + */ + private fun checkPrincipalTypes(principalResourceIds: Iterable) { + for (principalResourceId in principalResourceIds) { + if (tlsClientMapping.getByPrincipalResourceId(principalResourceId) != null) { + throw PrincipalTypeNotSupportedException( + principalResourceId, + Principal.IdentityCase.TLS_CLIENT, + ) + } + } + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPrincipalsService.kt b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPrincipalsService.kt new file mode 100644 index 00000000000..6e57a8e85a9 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPrincipalsService.kt @@ -0,0 +1,166 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.deploy.gcloud.spanner + +import com.google.cloud.spanner.ErrorCode +import com.google.cloud.spanner.SpannerException +import com.google.protobuf.ByteString +import com.google.protobuf.Empty +import com.google.protobuf.Timestamp +import io.grpc.Status +import org.wfanet.measurement.access.common.TlsClientPrincipalMapping +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.deletePrincipal +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.getPrincipalByResourceId +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.getPrincipalByUserKey +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.getPrincipalIdByResourceId +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.insertPrincipal +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.insertUserPrincipal +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.principalExists +import org.wfanet.measurement.access.service.internal.IdGenerator +import org.wfanet.measurement.access.service.internal.PrincipalAlreadyExistsException +import org.wfanet.measurement.access.service.internal.PrincipalNotFoundException +import org.wfanet.measurement.access.service.internal.PrincipalNotFoundForTlsClientException +import org.wfanet.measurement.access.service.internal.PrincipalNotFoundForUserException +import org.wfanet.measurement.access.service.internal.PrincipalTypeNotSupportedException +import org.wfanet.measurement.access.service.internal.RequiredFieldNotSetException +import org.wfanet.measurement.access.service.internal.generateNewId +import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient +import org.wfanet.measurement.internal.access.CreateUserPrincipalRequest +import org.wfanet.measurement.internal.access.DeletePrincipalRequest +import org.wfanet.measurement.internal.access.GetPrincipalRequest +import org.wfanet.measurement.internal.access.LookupPrincipalRequest +import org.wfanet.measurement.internal.access.Principal +import org.wfanet.measurement.internal.access.PrincipalKt.tlsClient +import org.wfanet.measurement.internal.access.PrincipalsGrpcKt +import org.wfanet.measurement.internal.access.principal + +class SpannerPrincipalsService( + private val databaseClient: AsyncDatabaseClient, + private val tlsClientMapping: TlsClientPrincipalMapping, + private val idGenerator: IdGenerator = IdGenerator.Default, +) : PrincipalsGrpcKt.PrincipalsCoroutineImplBase() { + override suspend fun getPrincipal(request: GetPrincipalRequest): Principal { + val tlsClient = tlsClientMapping.getByPrincipalResourceId(request.principalResourceId) + if (tlsClient != null) { + return tlsClient.toPrincipal() + } + + try { + databaseClient.singleUse().use { txn -> + return txn.getPrincipalByResourceId(request.principalResourceId).principal + } + } catch (e: PrincipalNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) + } + } + + override suspend fun createUserPrincipal(request: CreateUserPrincipalRequest): Principal { + val runner: AsyncDatabaseClient.TransactionRunner = databaseClient.readWriteTransaction() + try { + runner.run { txn -> + val principalId: Long = idGenerator.generateNewId { id -> txn.principalExists(id) } + txn.insertPrincipal(principalId, request.principalResourceId) + txn.insertUserPrincipal(principalId, request.user.issuer, request.user.subject) + } + val commitTimestamp: Timestamp = runner.getCommitTimestamp().toProto() + return principal { + principalResourceId = request.principalResourceId + user = request.user + createTime = commitTimestamp + updateTime = commitTimestamp + } + } catch (e: SpannerException) { + if (e.errorCode == ErrorCode.ALREADY_EXISTS) { + throw PrincipalAlreadyExistsException(e) + .asStatusRuntimeException(Status.Code.ALREADY_EXISTS) + } else { + throw e + } + } + } + + override suspend fun deletePrincipal(request: DeletePrincipalRequest): Empty { + if (request.principalResourceId.isEmpty()) { + throw RequiredFieldNotSetException("principal_resource_id") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + if (tlsClientMapping.getByPrincipalResourceId(request.principalResourceId) != null) { + throw PrincipalTypeNotSupportedException( + request.principalResourceId, + Principal.IdentityCase.TLS_CLIENT, + ) + .asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + } + + try { + databaseClient.readWriteTransaction().run { txn -> + val principalId: Long = txn.getPrincipalIdByResourceId(request.principalResourceId) + txn.deletePrincipal(principalId) + } + } catch (e: PrincipalNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) + } + + return Empty.getDefaultInstance() + } + + override suspend fun lookupPrincipal(request: LookupPrincipalRequest): Principal { + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf case enums cannot be null. + return when (request.lookupKeyCase) { + LookupPrincipalRequest.LookupKeyCase.TLS_CLIENT -> { + val authorityKeyIdentifier: ByteString = request.tlsClient.authorityKeyIdentifier + try { + lookupTlsClientPrincipal(authorityKeyIdentifier) + } catch (e: PrincipalNotFoundForTlsClientException) { + throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) + } + } + LookupPrincipalRequest.LookupKeyCase.USER -> { + try { + return databaseClient.singleUse().use { txn -> + txn.getPrincipalByUserKey(request.user.issuer, request.user.subject).principal + } + } catch (e: PrincipalNotFoundForUserException) { + throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) + } + } + LookupPrincipalRequest.LookupKeyCase.LOOKUPKEY_NOT_SET -> + throw RequiredFieldNotSetException("lookup_key") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + } + + /** + * Looks up a TLS client [Principal] by [authorityKeyIdentifier]. + * + * @throws PrincipalNotFoundForTlsClientException + */ + private fun lookupTlsClientPrincipal(authorityKeyIdentifier: ByteString): Principal { + val tlsClient: TlsClientPrincipalMapping.TlsClient = + tlsClientMapping.getByAuthorityKeyIdentifier(authorityKeyIdentifier) + ?: throw PrincipalNotFoundForTlsClientException(authorityKeyIdentifier) + return tlsClient.toPrincipal() + } +} + +private fun TlsClientPrincipalMapping.TlsClient.toPrincipal(): Principal { + val source = this + return principal { + principalResourceId = source.principalResourceId + tlsClient = tlsClient { authorityKeyIdentifier = source.authorityKeyIdentifier } + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerRolesService.kt b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerRolesService.kt new file mode 100644 index 00000000000..7efcd248c87 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerRolesService.kt @@ -0,0 +1,320 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.deploy.gcloud.spanner + +import com.google.cloud.spanner.ErrorCode +import com.google.cloud.spanner.SpannerException +import com.google.protobuf.Empty +import com.google.protobuf.Timestamp +import io.grpc.Status +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.collectIndexed +import kotlinx.coroutines.flow.map +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.RoleResult +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.deleteRole +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.deleteRolePermission +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.deleteRoleResourceType +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.getRoleByResourceId +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.getRoleIdByResourceId +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.insertRole +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.insertRolePermission +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.insertRoleResourceType +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.readRoles +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.roleExists +import org.wfanet.measurement.access.deploy.gcloud.spanner.db.updateRole +import org.wfanet.measurement.access.service.internal.EtagMismatchException +import org.wfanet.measurement.access.service.internal.IdGenerator +import org.wfanet.measurement.access.service.internal.InvalidFieldValueException +import org.wfanet.measurement.access.service.internal.PermissionMapping +import org.wfanet.measurement.access.service.internal.PermissionNotFoundException +import org.wfanet.measurement.access.service.internal.PermissionNotFoundForRoleException +import org.wfanet.measurement.access.service.internal.RequiredFieldNotSetException +import org.wfanet.measurement.access.service.internal.ResourceTypeNotFoundInPermissionException +import org.wfanet.measurement.access.service.internal.RoleAlreadyExistsException +import org.wfanet.measurement.access.service.internal.RoleNotFoundException +import org.wfanet.measurement.access.service.internal.generateNewId +import org.wfanet.measurement.common.api.ETags +import org.wfanet.measurement.common.toInstant +import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient +import org.wfanet.measurement.internal.access.DeleteRoleRequest +import org.wfanet.measurement.internal.access.GetRoleRequest +import org.wfanet.measurement.internal.access.ListRolesPageTokenKt +import org.wfanet.measurement.internal.access.ListRolesRequest +import org.wfanet.measurement.internal.access.ListRolesResponse +import org.wfanet.measurement.internal.access.Role +import org.wfanet.measurement.internal.access.RolesGrpcKt +import org.wfanet.measurement.internal.access.copy +import org.wfanet.measurement.internal.access.listRolesPageToken +import org.wfanet.measurement.internal.access.listRolesResponse +import org.wfanet.measurement.internal.access.role + +class SpannerRolesService( + private val databaseClient: AsyncDatabaseClient, + private val permissionMapping: PermissionMapping, + private val idGenerator: IdGenerator = IdGenerator.Default, +) : RolesGrpcKt.RolesCoroutineImplBase() { + override suspend fun getRole(request: GetRoleRequest): Role { + if (request.roleResourceId.isEmpty()) { + throw RequiredFieldNotSetException("role_resource_id") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + + val roleResult: RoleResult = + try { + databaseClient.singleUse().use { txn -> + txn.getRoleByResourceId(permissionMapping, request.roleResourceId) + } + } catch (e: PermissionNotFoundForRoleException) { + // This means that an expected Permission is missing from the mapping. + throw e.asStatusRuntimeException(Status.Code.INTERNAL) + } catch (e: RoleNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) + } + + return roleResult.role + } + + override suspend fun listRoles(request: ListRolesRequest): ListRolesResponse { + if (request.pageSize < 0) { + throw InvalidFieldValueException("max_page_size") { fieldName -> + "$fieldName must be non-negative" + } + } + val pageSize = + if (request.pageSize == 0) { + DEFAULT_PAGE_SIZE + } else { + request.pageSize.coerceAtMost(MAX_PAGE_SIZE) + } + val after = if (request.hasPageToken()) request.pageToken.after else null + + return try { + databaseClient.singleUse().use { txn -> + val roles: Flow = + txn.readRoles(permissionMapping, pageSize + 1, after).map { it.role } + listRolesResponse { + roles.collectIndexed { index, role -> + if (index == pageSize) { + nextPageToken = listRolesPageToken { + this.after = + ListRolesPageTokenKt.after { + roleResourceId = this@listRolesResponse.roles.last().roleResourceId + } + } + } else { + this.roles += role + } + } + } + } + } catch (e: PermissionNotFoundForRoleException) { + throw e.asStatusRuntimeException(Status.Code.INTERNAL) + } + } + + override suspend fun createRole(request: Role): Role { + try { + validateRequestRole(request) + } catch (e: RequiredFieldNotSetException) { + throw e.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + + val permissionIds = + request.permissionResourceIdsList.map { permissionResourceId -> + try { + getPermissionByResourceId(permissionResourceId) + .also { it.checkResourceTypes(request.resourceTypesList) } + .permissionId + } catch (e: PermissionNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + } catch (e: ResourceTypeNotFoundInPermissionException) { + throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + } + } + + val transactionRunner = databaseClient.readWriteTransaction() + val commitTimestamp: Timestamp = + try { + transactionRunner.run { txn -> + val roleId = idGenerator.generateNewId { id -> txn.roleExists(id) } + txn.insertRole(roleId, request.roleResourceId) + for (permissionId in permissionIds) { + txn.insertRolePermission(roleId, permissionId) + } + for (resourceType in request.resourceTypesList) { + txn.insertRoleResourceType(roleId, resourceType) + } + } + transactionRunner.getCommitTimestamp().toProto() + } catch (e: SpannerException) { + if (e.errorCode == ErrorCode.ALREADY_EXISTS) { + throw RoleAlreadyExistsException(e).asStatusRuntimeException(Status.Code.ALREADY_EXISTS) + } else { + throw e + } + } + + return role { + roleResourceId = request.roleResourceId + createTime = commitTimestamp + updateTime = commitTimestamp + permissionResourceIds += request.permissionResourceIdsList + resourceTypes += request.resourceTypesList + etag = ETags.computeETag(updateTime.toInstant()) + } + } + + override suspend fun updateRole(request: Role): Role { + try { + validateRequestRole(request) + } catch (e: RequiredFieldNotSetException) { + throw e.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + if (request.etag.isEmpty()) { + throw RequiredFieldNotSetException("etag") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + + val transactionRunner: AsyncDatabaseClient.TransactionRunner = + databaseClient.readWriteTransaction() + transactionRunner.run { txn -> + val (roleId: Long, role: Role) = + try { + txn.getRoleByResourceId(permissionMapping, request.roleResourceId) + } catch (e: RoleNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) + } catch (e: PermissionNotFoundForRoleException) { + throw e.asStatusRuntimeException(Status.Code.INTERNAL) + } + try { + EtagMismatchException.check(request.etag, role.etag) + } catch (e: EtagMismatchException) { + throw e.asStatusRuntimeException(Status.Code.ABORTED) + } + + txn.updateRole(roleId) + + val requestResourceTypes = request.resourceTypesList.toSet() + val existingResourceTypes = role.resourceTypesList.toSet() + for (resourceType in requestResourceTypes.subtract(existingResourceTypes)) { + txn.insertRoleResourceType(roleId, resourceType) + } + for (resourceType in existingResourceTypes.subtract(requestResourceTypes)) { + txn.deleteRoleResourceType(roleId, resourceType) + } + + val requestPermissionResourceIds = request.permissionResourceIdsList.toSet() + val existingPermissionResourceIds = role.permissionResourceIdsList.toSet() + for (permissionResourceId in + requestPermissionResourceIds.union(existingPermissionResourceIds)) { + val permissionId = + try { + getPermissionByResourceId(permissionResourceId) + .also { it.checkResourceTypes(requestResourceTypes) } + .permissionId + } catch (e: PermissionNotFoundException) { + val statusCode = + if (permissionResourceId in existingPermissionResourceIds) { + Status.Code.INTERNAL // Existing permission should be found. + } else { + Status.Code.FAILED_PRECONDITION + } + throw e.asStatusRuntimeException(statusCode) + } catch (e: ResourceTypeNotFoundInPermissionException) { + throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + } + + if (permissionResourceId !in requestPermissionResourceIds) { + txn.deleteRolePermission(roleId, permissionId) + } else if (permissionResourceId !in existingPermissionResourceIds) { + txn.insertRolePermission(roleId, permissionId) + } + } + } + val commitTimestamp: Timestamp = transactionRunner.getCommitTimestamp().toProto() + + return request.copy { + updateTime = commitTimestamp + etag = ETags.computeETag(updateTime.toInstant()) + } + } + + override suspend fun deleteRole(request: DeleteRoleRequest): Empty { + if (request.roleResourceId.isEmpty()) { + throw RequiredFieldNotSetException("role_resource_id") + .asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + + try { + databaseClient.readWriteTransaction().run { txn -> + val roleId = txn.getRoleIdByResourceId(request.roleResourceId) + txn.deleteRole(roleId) + } + } catch (e: RoleNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) + } + + return Empty.getDefaultInstance() + } + + /** + * Returns the [PermissionMapping.Permission] with the specified [permissionResourceId]. + * + * @throws PermissionNotFoundException + */ + private fun getPermissionByResourceId( + permissionResourceId: String + ): PermissionMapping.Permission { + return permissionMapping.getPermissionByResourceId(permissionResourceId) + ?: throw PermissionNotFoundException(permissionResourceId) + } + + /** + * Checks whether the specified request role is valid. + * + * @throws RequiredFieldNotSetException + */ + private fun validateRequestRole(role: Role) { + if (role.roleResourceId.isEmpty()) { + throw RequiredFieldNotSetException("role_resource_id") + } + if (role.permissionResourceIdsList.isEmpty()) { + throw RequiredFieldNotSetException("permission_resource_ids") + } + if (role.resourceTypesList.isEmpty()) { + throw RequiredFieldNotSetException("resource_types") + } + } + + /** + * Checks whether this permission has all the resource types in [resourceTypes]. + * + * @throws ResourceTypeNotFoundInPermissionException + */ + private fun PermissionMapping.Permission.checkResourceTypes(resourceTypes: Iterable) { + for (resourceType in resourceTypes) { + if (!protectedResourceTypes.contains(resourceType)) { + throw ResourceTypeNotFoundInPermissionException(resourceType, permissionResourceId) + } + } + } + + companion object { + private const val MAX_PAGE_SIZE = 100 + private const val DEFAULT_PAGE_SIZE = 50 + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/BUILD.bazel new file mode 100644 index 00000000000..97ade177451 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/BUILD.bazel @@ -0,0 +1,24 @@ +load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") + +package( + default_visibility = [ + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner:__pkg__", + ], +) + +kt_jvm_library( + name = "db", + srcs = glob(["*.kt"]), + deps = [ + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:errors", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:permission_mapping", + "//src/main/kotlin/org/wfanet/measurement/common:flows", + "//src/main/kotlin/org/wfanet/measurement/common/api:etags", + "//src/main/proto/wfa/measurement/internal/access:policy_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/access:principal_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/access:role_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/access:roles_service_kt_jvm_proto", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/Permissions.kt b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/Permissions.kt new file mode 100644 index 00000000000..b5d85a10f73 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/Permissions.kt @@ -0,0 +1,53 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.deploy.gcloud.spanner.db + +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.toList +import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient +import org.wfanet.measurement.gcloud.spanner.statement + +/** Returns the subset of [permissionIds] that the principal has on the protected resource. */ +suspend fun AsyncDatabaseClient.ReadContext.checkPermissions( + protectedResourceName: String, + principalId: Long, + permissionIds: Iterable, +): List { + val sql = + """ + SELECT + PermissionId + FROM + PolicyBindings + JOIN Policies USING (PolicyId) + JOIN RolePermissions USING (RoleId) + WHERE + PrincipalId = @principalId + AND ProtectedResourceName = @protectedResourceName + AND PermissionId IN UNNEST(@permissionIds) + """ + .trimIndent() + return executeQuery( + statement(sql) { + bind("principalId").to(principalId) + bind("protectedResourceName").to(protectedResourceName) + bind("permissionIds").toInt64Array(permissionIds) + } + ) + .map { it.getLong("PermissionId") } + .toList() +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/Policies.kt b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/Policies.kt new file mode 100644 index 00000000000..b21d36ee590 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/Policies.kt @@ -0,0 +1,183 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.deploy.gcloud.spanner.db + +import com.google.cloud.spanner.Key +import com.google.cloud.spanner.Mutation +import com.google.cloud.spanner.Struct +import com.google.cloud.spanner.Value +import kotlinx.coroutines.flow.toList +import org.wfanet.measurement.access.service.internal.PolicyNotFoundException +import org.wfanet.measurement.access.service.internal.PolicyNotFoundForProtectedResourceException +import org.wfanet.measurement.common.api.ETags +import org.wfanet.measurement.common.toInstant +import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient +import org.wfanet.measurement.gcloud.spanner.bufferInsertMutation +import org.wfanet.measurement.gcloud.spanner.bufferUpdateMutation +import org.wfanet.measurement.gcloud.spanner.getNullableString +import org.wfanet.measurement.gcloud.spanner.statement +import org.wfanet.measurement.internal.access.Policy +import org.wfanet.measurement.internal.access.PolicyKt.members +import org.wfanet.measurement.internal.access.policy + +data class PolicyResult(val policyId: Long, val policy: Policy) + +/** Returns whether the [Policy] with the specified [policyId] exists. */ +suspend fun AsyncDatabaseClient.ReadContext.policyExists(policyId: Long): Boolean { + return readRow("Policies", Key.of(policyId), listOf("PolicyId")) != null +} + +/** + * Reads a [Policy] by its [policyResourceId]. + * + * @throws PolicyNotFoundException + */ +suspend fun AsyncDatabaseClient.ReadContext.getPolicyByResourceId( + policyResourceId: String +): PolicyResult { + val sql = buildString { + appendLine(Policies.BASE_SQL) + appendLine("WHERE PolicyResourceId = @policyResourceId") + } + val rows: List = + executeQuery(statement(sql) { bind("policyResourceId").to(policyResourceId) }).toList() + + if (rows.isEmpty()) { + throw PolicyNotFoundException(policyResourceId) + } + + return Policies.buildPolicyResult(rows) +} + +/** + * Reads a [Policy] by its [protectedResourceName]. + * + * @throws PolicyNotFoundForProtectedResourceException + */ +suspend fun AsyncDatabaseClient.ReadContext.getPolicyByProtectedResourceName( + protectedResourceName: String +): PolicyResult { + val sql = buildString { + appendLine(Policies.BASE_SQL) + appendLine("WHERE ProtectedResourceName = @protectedResourceName") + } + val rows: List = + executeQuery(statement(sql) { bind("protectedResourceName").to(protectedResourceName) }) + .toList() + + if (rows.isEmpty()) { + throw PolicyNotFoundForProtectedResourceException(protectedResourceName) + } + + return Policies.buildPolicyResult(rows) +} + +/** Buffers an insert mutation for the Policies table. */ +fun AsyncDatabaseClient.TransactionContext.insertPolicy( + policyId: Long, + policyResourceId: String, + protectedResourceName: String, +) { + bufferInsertMutation("Policies") { + set("PolicyId").to(policyId) + set("PolicyResourceId").to(policyResourceId) + set("ProtectedResourceName").to(protectedResourceName) + set("CreateTime").to(Value.COMMIT_TIMESTAMP) + set("UpdateTime").to(Value.COMMIT_TIMESTAMP) + } +} + +/** Buffers an update mutation for the Policies table. */ +fun AsyncDatabaseClient.TransactionContext.updatePolicy(policyId: Long) { + bufferUpdateMutation("Policies") { + set("PolicyId").to(policyId) + set("UpdateTime").to(Value.COMMIT_TIMESTAMP) + } +} + +/** Buffers an insert mutation for the PolicyBindings table. */ +fun AsyncDatabaseClient.TransactionContext.insertPolicyBinding( + policyId: Long, + roleId: Long, + principalId: Long, +) { + bufferInsertMutation("PolicyBindings") { + set("PolicyId").to(policyId) + set("RoleId").to(roleId) + set("PrincipalId").to(principalId) + } +} + +/** Buffers a delete mutation for the PolicyBindings table. */ +fun AsyncDatabaseClient.TransactionContext.deletePolicyBinding( + policyId: Long, + roleId: Long, + principalId: Long, +) { + buffer(Mutation.delete("PolicyBindings", Key.of(policyId, roleId, principalId))) +} + +private object Policies { + val BASE_SQL = + """ + SELECT + Policies.*, + RoleResourceId, + PrincipalResourceId + FROM + Policies + LEFT JOIN ( + PolicyBindings + JOIN Roles USING (RoleId) + LEFT JOIN Principals USING (PrincipalId) + ) USING (PolicyId) + """ + .trimIndent() + + fun buildPolicyResult(rows: Iterable): PolicyResult { + val firstRow = rows.first() + val membersByRole: Map> = buildMembersByRole(rows) + return PolicyResult( + firstRow.getLong("PolicyId"), + policy { + policyResourceId = firstRow.getString("PolicyResourceId") + protectedResourceName = firstRow.getString("ProtectedResourceName") + createTime = firstRow.getTimestamp("CreateTime").toProto() + updateTime = firstRow.getTimestamp("UpdateTime").toProto() + etag = ETags.computeETag(updateTime.toInstant()) + + for ((roleResourceId, memberPrincipalResourceIds) in membersByRole) { + bindings[roleResourceId] = members { + this.memberPrincipalResourceIds += memberPrincipalResourceIds + } + } + }, + ) + } + + private fun buildMembersByRole(rows: Iterable): Map> { + return mutableMapOf>().apply { + for (row in rows) { + val roleResourceId: String = row.getNullableString("RoleResourceId") ?: continue + val principalResourceId: String = row.getNullableString("PrincipalResourceId") ?: continue + val memberPrincipalResourceIds: MutableList = + getOrPut(roleResourceId, ::mutableListOf) + memberPrincipalResourceIds.add(principalResourceId) + } + } + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/Principals.kt b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/Principals.kt new file mode 100644 index 00000000000..353aa78ddbe --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/Principals.kt @@ -0,0 +1,215 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.deploy.gcloud.spanner.db + +import com.google.cloud.spanner.Key +import com.google.cloud.spanner.KeySet +import com.google.cloud.spanner.Mutation +import com.google.cloud.spanner.Struct +import com.google.cloud.spanner.Value +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.toList +import org.wfanet.measurement.access.service.internal.PrincipalNotFoundException +import org.wfanet.measurement.access.service.internal.PrincipalNotFoundForUserException +import org.wfanet.measurement.common.singleOrNullIfEmpty +import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient +import org.wfanet.measurement.gcloud.spanner.bufferInsertMutation +import org.wfanet.measurement.gcloud.spanner.statement +import org.wfanet.measurement.internal.access.Principal +import org.wfanet.measurement.internal.access.PrincipalKt.oAuthUser +import org.wfanet.measurement.internal.access.principal + +data class PrincipalResult(val principalId: Long, val principal: Principal) + +/** Returns whether the Principal with the specified [principalId] exists. */ +suspend fun AsyncDatabaseClient.ReadContext.principalExists(principalId: Long): Boolean { + return readRow("Principals", Key.of(principalId), listOf("PrincipalId")) != null +} + +/** + * Reads the principal ID by its resource ID. + * + * @throws PrincipalNotFoundException + */ +suspend fun AsyncDatabaseClient.ReadContext.getPrincipalIdByResourceId( + principalResourceId: String +): Long { + val struct = + readRowUsingIndex( + "Principals", + "PrincipalsByResourceId", + Key.of(principalResourceId), + "PrincipalId", + ) ?: throw PrincipalNotFoundException(principalResourceId) + return struct.getLong("PrincipalId") +} + +/** + * Reads the [Principal] by its resource ID. + * + * @throws PrincipalNotFoundException + */ +suspend fun AsyncDatabaseClient.ReadContext.getPrincipalByResourceId( + principalResourceId: String +): PrincipalResult { + val sql = + """ + SELECT + Principals.*, + Subject, + Issuer + FROM + Principals + JOIN UserPrincipals USING (PrincipalId) + WHERE + PrincipalResourceId = @principalResourceId + """ + .trimIndent() + val struct = + executeQuery(statement(sql) { bind("principalResourceId").to(principalResourceId) }) + .singleOrNullIfEmpty() ?: throw PrincipalNotFoundException(principalResourceId) + + return PrincipalResult( + struct.getLong("PrincipalId"), + principal { + this.principalResourceId = principalResourceId + createTime = struct.getTimestamp("CreateTime").toProto() + updateTime = struct.getTimestamp("UpdateTime").toProto() + user = oAuthUser { + issuer = struct.getString("Issuer") + subject = struct.getString("Subject") + } + }, + ) +} + +/** + * Reads the principal IDs for the specified resource IDs. + * + * @throws PrincipalNotFoundException + */ +suspend fun AsyncDatabaseClient.ReadContext.getPrincipalIdsByResourceIds( + principalResourceIds: Collection +): Map { + if (principalResourceIds.isEmpty()) { + return emptyMap() + } + + val keys = + KeySet.newBuilder() + .apply { + for (principalResourceId in principalResourceIds) { + addKey(Key.of(principalResourceId)) + } + } + .build() + + val principalIdByResourceId: Map = + readUsingIndex( + "Principals", + "PrincipalsByResourceId", + keys, + listOf("PrincipalResourceId", "PrincipalId"), + ) + .map { it.getString(0) to it.getLong(1) } + .toList() + .toMap() + + if (principalIdByResourceId.size != principalResourceIds.size) { + for (principalResourceId in principalResourceIds) { + if (!principalIdByResourceId.containsKey(principalResourceId)) { + throw PrincipalNotFoundException(principalResourceId) + } + } + } + + return principalIdByResourceId +} + +/** + * Reads the [Principal] by its user lookup key. + * + * @throws PrincipalNotFoundForUserException + */ +suspend fun AsyncDatabaseClient.ReadContext.getPrincipalByUserKey( + issuer: String, + subject: String, +): PrincipalResult { + val sql = + """ + SELECT + Principals.* + FROM + Principals + JOIN UserPrincipals USING (PrincipalId) + WHERE + Issuer = @issuer + AND Subject = @subject + """ + .trimIndent() + val row: Struct = + executeQuery( + statement(sql) { + bind("issuer").to(issuer) + bind("subject").to(subject) + } + ) + .singleOrNullIfEmpty() ?: throw PrincipalNotFoundForUserException(issuer, subject) + return PrincipalResult( + row.getLong("PrincipalId"), + principal { + principalResourceId = row.getString("PrincipalResourceId") + user = oAuthUser { + this.issuer = issuer + this.subject = subject + } + createTime = row.getTimestamp("CreateTime").toProto() + updateTime = row.getTimestamp("UpdateTime").toProto() + }, + ) +} + +/** Buffers an insert mutation to the Principals table. */ +fun AsyncDatabaseClient.TransactionContext.insertPrincipal( + principalId: Long, + principalResourceId: String, +) { + bufferInsertMutation("Principals") { + set("PrincipalId").to(principalId) + set("PrincipalResourceId").to(principalResourceId) + set("CreateTime").to(Value.COMMIT_TIMESTAMP) + set("UpdateTime").to(Value.COMMIT_TIMESTAMP) + } +} + +/** Buffers an insert mutation to the UserPrincipals table. */ +fun AsyncDatabaseClient.TransactionContext.insertUserPrincipal( + principalId: Long, + issuer: String, + subject: String, +) { + bufferInsertMutation("UserPrincipals") { + set("PrincipalId").to(principalId) + set("Issuer").to(issuer) + set("Subject").to(subject) + } +} + +/** Buffers a delete mutation to the Principals table. */ +fun AsyncDatabaseClient.TransactionContext.deletePrincipal(principalId: Long) { + buffer(Mutation.delete("Principals", Key.of(principalId))) +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/Roles.kt b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/Roles.kt new file mode 100644 index 00000000000..e0e6a94ff58 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/db/Roles.kt @@ -0,0 +1,267 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.deploy.gcloud.spanner.db + +import com.google.cloud.spanner.Key +import com.google.cloud.spanner.KeySet +import com.google.cloud.spanner.Mutation +import com.google.cloud.spanner.Struct +import com.google.cloud.spanner.Value +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.toList +import org.wfanet.measurement.access.service.internal.PermissionMapping +import org.wfanet.measurement.access.service.internal.PermissionNotFoundForRoleException +import org.wfanet.measurement.access.service.internal.RoleNotFoundException +import org.wfanet.measurement.common.api.ETags +import org.wfanet.measurement.common.singleOrNullIfEmpty +import org.wfanet.measurement.common.toInstant +import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient +import org.wfanet.measurement.gcloud.spanner.bufferInsertMutation +import org.wfanet.measurement.gcloud.spanner.bufferUpdateMutation +import org.wfanet.measurement.gcloud.spanner.statement +import org.wfanet.measurement.internal.access.ListRolesPageToken +import org.wfanet.measurement.internal.access.Role +import org.wfanet.measurement.internal.access.role + +data class RoleResult(val roleId: Long, val role: Role) + +/** + * Reads a [Role] by its resource ID. + * + * @throws RoleNotFoundException + * @throws PermissionNotFoundForRoleException + */ +suspend fun AsyncDatabaseClient.ReadContext.getRoleByResourceId( + permissionMapping: PermissionMapping, + roleResourceId: String, +): RoleResult { + val sql = + """ + SELECT + RoleId, + CreateTime, + UpdateTime, + ARRAY( + SELECT ResourceType FROM RoleResourceTypes WHERE RoleResourceTypes.RoleId = Roles.RoleId + ) AS ResourceTypes, + ARRAY( + SELECT PermissionId FROM RolePermissions WHERE RolePermissions.RoleId = Roles.RoleId + ) AS PermissionIds + FROM + Roles + WHERE + RoleResourceId = @roleResourceId + """ + .trimIndent() + val row: Struct = + executeQuery(statement(sql) { bind("roleResourceId").to(roleResourceId) }).singleOrNullIfEmpty() + ?: throw RoleNotFoundException(roleResourceId) + + val permissionResourceIds = + row.getLongList("PermissionIds").map { + val mappingPermission: PermissionMapping.Permission = + permissionMapping.getPermissionById(it) + ?: throw PermissionNotFoundForRoleException(roleResourceId) + mappingPermission.permissionResourceId + } + + return RoleResult( + row.getLong("RoleId"), + role { + this.roleResourceId = roleResourceId + createTime = row.getTimestamp("CreateTime").toProto() + updateTime = row.getTimestamp("UpdateTime").toProto() + resourceTypes += row.getStringList("ResourceTypes") + this.permissionResourceIds += permissionResourceIds + etag = ETags.computeETag(updateTime.toInstant()) + }, + ) +} + +/** + * Reads the ID of the [Role] with the specified resource ID. + * + * @throws RoleNotFoundException + */ +suspend fun AsyncDatabaseClient.ReadContext.getRoleIdByResourceId(roleResourceId: String): Long { + val row = + readRowUsingIndex("Roles", "RolesByResourceId", Key.of(roleResourceId), "RoleId") + ?: throw RoleNotFoundException(roleResourceId) + return row.getLong("RoleId") +} + +/** + * Reads the role IDs for the specified resource IDs. + * + * @throws RoleNotFoundException if no role ID is found for a specified resource ID + */ +suspend fun AsyncDatabaseClient.ReadContext.getRoleIdsByResourceIds( + roleResourceIds: Collection +): Map { + if (roleResourceIds.isEmpty()) { + return emptyMap() + } + + val keys = + KeySet.newBuilder() + .apply { + for (roleResourceId in roleResourceIds) { + addKey(Key.of(roleResourceId)) + } + } + .build() + + val roleIdByResourceId: Map = + readUsingIndex("Roles", "RolesByResourceId", keys, listOf("RoleResourceId", "RoleId")) + .map { it.getString(0) to it.getLong(1) } + .toList() + .toMap() + + if (roleIdByResourceId.size != roleResourceIds.size) { + for (roleResourceId in roleResourceIds) { + if (!roleIdByResourceId.containsKey(roleResourceId)) { + throw RoleNotFoundException(roleResourceId) + } + } + } + + return roleIdByResourceId +} + +/** Returns whether a [Role] with the specified [roleId] exists. */ +suspend fun AsyncDatabaseClient.ReadContext.roleExists(roleId: Long): Boolean { + return readRow("Roles", Key.of(roleId), listOf("RoleId")) != null +} + +/** + * Reads [Role]s ordered by resource ID. + * + * @throws PermissionNotFoundForRoleException + */ +fun AsyncDatabaseClient.ReadContext.readRoles( + permissionMapping: PermissionMapping, + limit: Int, + after: ListRolesPageToken.After? = null, +): Flow { + val sql = buildString { + appendLine( + """ + SELECT + Roles.*, + ARRAY( + SELECT ResourceType FROM RoleResourceTypes WHERE RoleResourceTypes.RoleId = Roles.RoleId + ) AS ResourceTypes, + ARRAY( + SELECT PermissionId FROM RolePermissions WHERE RolePermissions.RoleId = Roles.RoleId + ) AS PermissionIds + FROM + Roles + """ + .trimIndent() + ) + if (after != null) { + appendLine("WHERE RoleResourceId > @afterRoleResourceId") + } + appendLine("ORDER BY RoleResourceId") + appendLine("LIMIT @limit") + } + val query = + statement(sql) { + if (after != null) { + bind("afterRoleResourceId").to(after.roleResourceId) + } + bind("limit").to(limit.toLong()) + } + + return executeQuery(query).map { row -> + val roleResourceId = row.getString("RoleResourceId") + val permissionResourceIds = + row.getLongList("PermissionIds").map { + val mappingPermission: PermissionMapping.Permission = + permissionMapping.getPermissionById(it) + ?: throw PermissionNotFoundForRoleException(roleResourceId) + mappingPermission.permissionResourceId + } + RoleResult( + row.getLong("RoleId"), + role { + this.roleResourceId = roleResourceId + createTime = row.getTimestamp("CreateTime").toProto() + updateTime = row.getTimestamp("UpdateTime").toProto() + resourceTypes += row.getStringList("ResourceTypes") + this.permissionResourceIds += permissionResourceIds + etag = ETags.computeETag(updateTime.toInstant()) + }, + ) + } +} + +/** Buffers an insert mutation for the Roles table. */ +fun AsyncDatabaseClient.TransactionContext.insertRole(roleId: Long, roleResourceId: String) { + bufferInsertMutation("Roles") { + set("RoleId").to(roleId) + set("RoleResourceId").to(roleResourceId) + set("CreateTime").to(Value.COMMIT_TIMESTAMP) + set("UpdateTime").to(Value.COMMIT_TIMESTAMP) + } +} + +/** Buffers an update mutation for the Roles table. */ +fun AsyncDatabaseClient.TransactionContext.updateRole(roleId: Long) { + bufferUpdateMutation("Roles") { + set("RoleId").to(roleId) + set("UpdateTime").to(Value.COMMIT_TIMESTAMP) + } +} + +/** Buffers a delete mutation for the Roles table. */ +fun AsyncDatabaseClient.TransactionContext.deleteRole(roleId: Long) { + buffer(Mutation.delete("Roles", Key.of(roleId))) +} + +/** Buffers an insert mutation for the RoleResourceTypes table. */ +fun AsyncDatabaseClient.TransactionContext.insertRoleResourceType( + roleId: Long, + resourceType: String, +) { + bufferInsertMutation("RoleResourceTypes") { + set("RoleId").to(roleId) + set("ResourceType").to(resourceType) + } +} + +/** Buffers a delete mutation for the RoleResourceTypes table. */ +fun AsyncDatabaseClient.TransactionContext.deleteRoleResourceType( + roleId: Long, + resourceType: String, +) { + buffer(Mutation.delete("RoleResourceTypes", Key.of(roleId, resourceType))) +} + +/** Buffers an insert mutation for the RolePermissions table. */ +fun AsyncDatabaseClient.TransactionContext.insertRolePermission(roleId: Long, permissionId: Long) { + bufferInsertMutation("RolePermissions") { + set("RoleId").to(roleId) + set("PermissionId").to(permissionId) + } +} + +/** Buffers a delete mutation for the RolePermissions table. */ +fun AsyncDatabaseClient.TransactionContext.deleteRolePermission(roleId: Long, permissionId: Long) { + buffer(Mutation.delete("RolePermissions", Key.of(roleId, permissionId))) +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/service/internal/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/access/service/internal/BUILD.bazel index 0a1e8bb6134..c74e29102c5 100644 --- a/src/main/kotlin/org/wfanet/measurement/access/service/internal/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/access/service/internal/BUILD.bazel @@ -18,3 +18,18 @@ kt_jvm_library( "@wfa_common_jvm//imports/java/io/grpc:api", ], ) + +kt_jvm_library( + name = "id_generator", + srcs = ["IdGenerator.kt"], +) + +kt_jvm_library( + name = "permission_mapping", + srcs = ["PermissionMapping.kt"], + deps = [ + "//src/main/proto/wfa/measurement/config/access:permissions_config_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/access:permission_kt_jvm_proto", + "@wfa_common_jvm//imports/java/com/google/common:guava", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/access/service/internal/IdGenerator.kt b/src/main/kotlin/org/wfanet/measurement/access/service/internal/IdGenerator.kt new file mode 100644 index 00000000000..686f2da578c --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/service/internal/IdGenerator.kt @@ -0,0 +1,46 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.service.internal + +import kotlin.random.Random + +fun interface IdGenerator { + /** Returns a non-zero ID. */ + fun generateId(): Long + + companion object { + val Default: IdGenerator = RandomIdGenerator() + } +} + +class RandomIdGenerator(private val random: Random = Random.Default) : IdGenerator { + override fun generateId(): Long { + var nextId = 0L + while (nextId == 0L) { + nextId = random.nextLong() + } + return nextId + } +} + +inline fun IdGenerator.generateNewId(idExists: (id: Long) -> Boolean): Long { + var id = generateId() + while (idExists(id)) { + id = generateId() + } + return id +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/service/internal/PermissionMapping.kt b/src/main/kotlin/org/wfanet/measurement/access/service/internal/PermissionMapping.kt new file mode 100644 index 00000000000..8163f57bd3e --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/service/internal/PermissionMapping.kt @@ -0,0 +1,86 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.service.internal + +import com.google.common.hash.Hashing +import org.wfanet.measurement.config.access.PermissionsConfig +import org.wfanet.measurement.internal.access.Permission +import org.wfanet.measurement.internal.access.permission + +class PermissionMapping(config: PermissionsConfig) { + data class Permission( + val permissionId: Long, + val permissionResourceId: String, + val protectedResourceTypes: Set, + ) + + /** Permissions sorted by resource ID. */ + val permissions: List = + buildList { + for ((permissionResourceId, configPermission) in config.permissionsMap) { + check(PERMISSION_RESOURCE_ID_REGEX.matches(permissionResourceId)) { + "Invalid permission resource ID $permissionResourceId" + } + val permissionId = fingerprint(permissionResourceId) + add( + Permission( + permissionId, + permissionResourceId, + configPermission.protectedResourceTypesList.toSet(), + ) + ) + } + } + .sortedBy { it.permissionResourceId } + + private val permissionsById: Map = + buildMap(permissions.size) { + for (permission in permissions) { + val existingPermission = get(permission.permissionId) + if (existingPermission != null) { + error( + "Fingerprinting collision between permissions " + + "${existingPermission.permissionResourceId} and ${permission.permissionResourceId}" + ) + } + put(permission.permissionId, permission) + } + } + + private val permissionsByResourceId: Map = + permissions.associateBy { it.permissionResourceId } + + fun getPermissionById(permissionId: Long) = permissionsById[permissionId] + + fun getPermissionByResourceId(permissionResourceId: String) = + permissionsByResourceId[permissionResourceId] + + companion object { + private fun fingerprint(input: String): Long = + Hashing.farmHashFingerprint64().hashString(input, Charsets.UTF_8).asLong() + + private val PERMISSION_RESOURCE_ID_REGEX = Regex("^[a-zA-Z]([a-zA-Z0-9.-]{0,61}[a-zA-Z0-9])?$") + } +} + +fun PermissionMapping.Permission.toPermission(): Permission { + val source = this + return permission { + permissionResourceId = source.permissionResourceId + resourceTypes += source.protectedResourceTypes + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/BUILD.bazel new file mode 100644 index 00000000000..b7ffa2263e8 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/BUILD.bazel @@ -0,0 +1,99 @@ +load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") + +package( + default_testonly = True, + default_visibility = [ + "//src/test/kotlin/org/wfanet/measurement/access/deploy:__subpackages__", + ], +) + +kt_jvm_library( + name = "test_config", + srcs = ["TestConfig.kt"], + deps = [ + "//src/main/kotlin/org/wfanet/measurement/access/common:tls_client_principal_mapping", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:permission_mapping", + "//src/main/proto/wfa/measurement/config/access:permissions_config_kt_jvm_proto", + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//imports/kotlin/com/google/protobuf/kotlin", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", + ], +) + +kt_jvm_library( + name = "principals_service_test", + srcs = ["PrincipalsServiceTest.kt"], + deps = [ + ":test_config", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:errors", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:id_generator", + "//src/main/kotlin/org/wfanet/measurement/common/grpc:error_info", + "//src/main/proto/wfa/measurement/internal/access:principals_service_kt_jvm_grpc_proto", + "@wfa_common_jvm//imports/java/com/google/common/truth", + "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", + "@wfa_common_jvm//imports/java/org/junit", + "@wfa_common_jvm//imports/kotlin/kotlin/test", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_common_jvm//imports/kotlin/org/mockito/kotlin", + ], +) + +kt_jvm_library( + name = "permissions_service_test", + srcs = ["PermissionsServiceTest.kt"], + deps = [ + ":test_config", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:errors", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:id_generator", + "//src/main/kotlin/org/wfanet/measurement/common/grpc:error_info", + "//src/main/proto/wfa/measurement/internal/access:permissions_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/internal/access:policies_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/internal/access:principals_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/internal/access:roles_service_kt_jvm_grpc_proto", + "@wfa_common_jvm//imports/java/com/google/common/truth", + "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", + "@wfa_common_jvm//imports/java/org/junit", + "@wfa_common_jvm//imports/kotlin/kotlin/test", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + ], +) + +kt_jvm_library( + name = "roles_service_test", + srcs = ["RolesServiceTest.kt"], + deps = [ + ":test_config", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:errors", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:id_generator", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:permission_mapping", + "//src/main/kotlin/org/wfanet/measurement/common/grpc:error_info", + "//src/main/proto/wfa/measurement/internal/access:roles_service_kt_jvm_grpc_proto", + "@wfa_common_jvm//imports/java/com/google/common/truth", + "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", + "@wfa_common_jvm//imports/java/org/junit", + "@wfa_common_jvm//imports/kotlin/kotlin/test", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_common_jvm//imports/kotlin/org/mockito/kotlin", + ], +) + +kt_jvm_library( + name = "policies_service_test", + srcs = ["PoliciesServiceTest.kt"], + deps = [ + ":test_config", + "//src/main/kotlin/org/wfanet/measurement/access/common:tls_client_principal_mapping", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:errors", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:id_generator", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal:permission_mapping", + "//src/main/kotlin/org/wfanet/measurement/common/grpc:error_info", + "//src/main/proto/wfa/measurement/internal/access:policies_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/internal/access:principals_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/internal/access:roles_service_kt_jvm_grpc_proto", + "@wfa_common_jvm//imports/java/com/google/common/truth", + "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", + "@wfa_common_jvm//imports/java/org/junit", + "@wfa_common_jvm//imports/kotlin/kotlin/test", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/PermissionsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/PermissionsServiceTest.kt new file mode 100644 index 00000000000..7eb7fd2d030 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/PermissionsServiceTest.kt @@ -0,0 +1,264 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.service.internal.testing + +import com.google.common.truth.Truth.assertThat +import com.google.common.truth.extensions.proto.ProtoTruth.assertThat +import com.google.protobuf.Descriptors +import com.google.rpc.errorInfo +import io.grpc.Status +import io.grpc.StatusRuntimeException +import kotlin.test.assertFailsWith +import kotlinx.coroutines.runBlocking +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.wfanet.measurement.access.common.TlsClientPrincipalMapping +import org.wfanet.measurement.access.service.internal.Errors +import org.wfanet.measurement.access.service.internal.IdGenerator +import org.wfanet.measurement.access.service.internal.PermissionMapping +import org.wfanet.measurement.access.service.internal.toPermission +import org.wfanet.measurement.common.grpc.errorInfo +import org.wfanet.measurement.internal.access.CheckPermissionsResponse +import org.wfanet.measurement.internal.access.ListPermissionsPageTokenKt +import org.wfanet.measurement.internal.access.ListPermissionsRequest +import org.wfanet.measurement.internal.access.ListPermissionsResponse +import org.wfanet.measurement.internal.access.Permission +import org.wfanet.measurement.internal.access.PermissionsGrpcKt +import org.wfanet.measurement.internal.access.PoliciesGrpcKt +import org.wfanet.measurement.internal.access.PolicyKt +import org.wfanet.measurement.internal.access.PrincipalKt.oAuthUser +import org.wfanet.measurement.internal.access.PrincipalsGrpcKt +import org.wfanet.measurement.internal.access.RolesGrpcKt +import org.wfanet.measurement.internal.access.checkPermissionsRequest +import org.wfanet.measurement.internal.access.checkPermissionsResponse +import org.wfanet.measurement.internal.access.createUserPrincipalRequest +import org.wfanet.measurement.internal.access.getPermissionRequest +import org.wfanet.measurement.internal.access.listPermissionsPageToken +import org.wfanet.measurement.internal.access.listPermissionsRequest +import org.wfanet.measurement.internal.access.listPermissionsResponse +import org.wfanet.measurement.internal.access.policy +import org.wfanet.measurement.internal.access.role + +@RunWith(JUnit4::class) +abstract class PermissionsServiceTest { + protected data class Services( + /** Service under test. */ + val service: PermissionsGrpcKt.PermissionsCoroutineImplBase, + val principalsService: PrincipalsGrpcKt.PrincipalsCoroutineImplBase, + val rolesServices: RolesGrpcKt.RolesCoroutineImplBase, + val policiesService: PoliciesGrpcKt.PoliciesCoroutineImplBase, + ) + + protected abstract fun initServices( + permissionMapping: PermissionMapping, + tlsClientMapping: TlsClientPrincipalMapping, + idGenerator: IdGenerator, + ): Services + + private fun initServices(idGenerator: IdGenerator = IdGenerator.Default) = + initServices(TestConfig.PERMISSION_MAPPING, TestConfig.TLS_CLIENT_MAPPING, idGenerator) + + @Test + fun `getPermission returns Permission`() = runBlocking { + val service = initServices().service + + val response: Permission = + service.getPermission( + getPermissionRequest { permissionResourceId = PERMISSIONS.first().permissionResourceId } + ) + + assertThat(response) + .ignoringRepeatedFieldOrderOfFieldDescriptors(RESOURCE_TYPES_FIELD) + .isEqualTo(PERMISSIONS.first()) + } + + @Test + fun `getPermission throws NOT_FOUND when Permission not found`() = runBlocking { + val service = initServices().service + + val request = getPermissionRequest { permissionResourceId = "not-found" } + val exception = assertFailsWith { service.getPermission(request) } + + assertThat(exception.status.code).isEqualTo(Status.Code.NOT_FOUND) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.PERMISSION_NOT_FOUND.name + metadata[Errors.Metadata.PERMISSION_RESOURCE_ID.key] = request.permissionResourceId + } + ) + } + + @Test + fun `listPermissions returns Permissions ordered by resource ID`() { + runBlocking { + val service = initServices().service + + val response: ListPermissionsResponse = + service.listPermissions(ListPermissionsRequest.getDefaultInstance()) + + assertThat(response) + .ignoringRepeatedFieldOrderOfFieldDescriptors(RESOURCE_TYPES_FIELD) + .isEqualTo(listPermissionsResponse { permissions += PERMISSIONS }) + } + } + + @Test + fun `listPermissions returns Permissions when page size is specified`() = runBlocking { + val service = initServices().service + + val response: ListPermissionsResponse = + service.listPermissions(listPermissionsRequest { pageSize = PERMISSIONS.size }) + + assertThat(response) + .ignoringRepeatedFieldOrderOfFieldDescriptors(RESOURCE_TYPES_FIELD) + .isEqualTo(listPermissionsResponse { permissions += PERMISSIONS }) + } + + @Test + fun `listPermissions returns next page token when there are more results`() = runBlocking { + val service = initServices().service + + val response: ListPermissionsResponse = + service.listPermissions(listPermissionsRequest { pageSize = 2 }) + + assertThat(response) + .ignoringRepeatedFieldOrderOfFieldDescriptors(RESOURCE_TYPES_FIELD) + .isEqualTo( + listPermissionsResponse { + permissions += PERMISSIONS.take(2) + nextPageToken = listPermissionsPageToken { + after = + ListPermissionsPageTokenKt.after { + permissionResourceId = + this@listPermissionsResponse.permissions.last().permissionResourceId + } + } + } + ) + } + + @Test + fun `listPermissions returns Permissions after page token`() = runBlocking { + val service = initServices().service + + val response: ListPermissionsResponse = + service.listPermissions( + listPermissionsRequest { + pageToken = listPermissionsPageToken { + after = + ListPermissionsPageTokenKt.after { + permissionResourceId = PERMISSIONS[1].permissionResourceId + } + } + } + ) + + assertThat(response) + .ignoringRepeatedFieldOrderOfFieldDescriptors(RESOURCE_TYPES_FIELD) + .isEqualTo(listPermissionsResponse { permissions += PERMISSIONS.drop(2) }) + } + + @Test + fun `checkPermissions returns all requested permissions for TLS client principal`() { + val service = initServices().service + val request = checkPermissionsRequest { + protectedResourceName = TestConfig.TLS_CLIENT_PROTECTED_RESOURCE_NAME + principalResourceId = TestConfig.TLS_CLIENT_PRINCIPAL_RESOURCE_ID + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_CREATE + } + + val response: CheckPermissionsResponse = runBlocking { service.checkPermissions(request) } + + assertThat(response) + .isEqualTo( + checkPermissionsResponse { permissionResourceIds += request.permissionResourceIdsList } + ) + } + + @Test + fun `checkPermissions returns no permissions for TLS client principal with wrong protected resource`() { + val service = initServices().service + val request = checkPermissionsRequest { + protectedResourceName = "shelves/404" + principalResourceId = TestConfig.TLS_CLIENT_PRINCIPAL_RESOURCE_ID + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_CREATE + } + + val response: CheckPermissionsResponse = runBlocking { service.checkPermissions(request) } + + assertThat(response).isEqualTo(CheckPermissionsResponse.getDefaultInstance()) + } + + @Test + fun `checkPermissions returns permissions for user Principal`(): Unit = runBlocking { + val (service, principalsService, rolesService, policiesService) = initServices() + val principal = + principalsService.createUserPrincipal( + createUserPrincipalRequest { + principalResourceId = "user-1" + user = oAuthUser { + issuer = "example-issuer" + subject = "user@example.com" + } + } + ) + val role = + rolesService.createRole( + role { + roleResourceId = "shelfBookReader" + resourceTypes += TestConfig.ResourceType.SHELF + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_LIST + } + ) + val policy = + policiesService.createPolicy( + policy { + policyResourceId = "fantasy-shelf-policy" + protectedResourceName = "shelves/fantasy" + bindings[role.roleResourceId] = + PolicyKt.members { memberPrincipalResourceIds += principal.principalResourceId } + } + ) + + val response = + service.checkPermissions( + checkPermissionsRequest { + protectedResourceName = policy.protectedResourceName + principalResourceId = principal.principalResourceId + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_LIST + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_CREATE + } + ) + + assertThat(response.permissionResourceIdsList) + .containsExactly(TestConfig.PermissionResourceId.BOOKS_LIST) + } + + companion object { + private val RESOURCE_TYPES_FIELD: Descriptors.FieldDescriptor = + Permission.getDescriptor().findFieldByNumber(Permission.RESOURCE_TYPES_FIELD_NUMBER) + + private val PERMISSIONS: List = + TestConfig.PERMISSION_MAPPING.permissions.map { it.toPermission() } + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/PoliciesServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/PoliciesServiceTest.kt new file mode 100644 index 00000000000..6786ada0ba4 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/PoliciesServiceTest.kt @@ -0,0 +1,444 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.service.internal.testing + +import com.google.common.truth.Truth.assertThat +import com.google.common.truth.extensions.proto.ProtoTruth.assertThat +import com.google.rpc.errorInfo +import io.grpc.Status +import io.grpc.StatusRuntimeException +import java.time.Instant +import kotlin.test.assertFailsWith +import kotlinx.coroutines.runBlocking +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.wfanet.measurement.access.common.TlsClientPrincipalMapping +import org.wfanet.measurement.access.service.internal.Errors +import org.wfanet.measurement.access.service.internal.IdGenerator +import org.wfanet.measurement.access.service.internal.PermissionMapping +import org.wfanet.measurement.common.grpc.errorInfo +import org.wfanet.measurement.common.toInstant +import org.wfanet.measurement.internal.access.PoliciesGrpcKt +import org.wfanet.measurement.internal.access.Policy +import org.wfanet.measurement.internal.access.PolicyKt +import org.wfanet.measurement.internal.access.PrincipalKt.oAuthUser +import org.wfanet.measurement.internal.access.PrincipalsGrpcKt +import org.wfanet.measurement.internal.access.RolesGrpcKt +import org.wfanet.measurement.internal.access.addPolicyBindingMembersRequest +import org.wfanet.measurement.internal.access.copy +import org.wfanet.measurement.internal.access.createUserPrincipalRequest +import org.wfanet.measurement.internal.access.getPolicyRequest +import org.wfanet.measurement.internal.access.lookupPolicyRequest +import org.wfanet.measurement.internal.access.policy +import org.wfanet.measurement.internal.access.removePolicyBindingMembersRequest +import org.wfanet.measurement.internal.access.role + +@RunWith(JUnit4::class) +abstract class PoliciesServiceTest { + protected data class Services( + /** Service under test. */ + val service: PoliciesGrpcKt.PoliciesCoroutineImplBase, + val principalsService: PrincipalsGrpcKt.PrincipalsCoroutineImplBase, + val rolesServices: RolesGrpcKt.RolesCoroutineImplBase, + ) + + protected abstract fun initServices( + permissionMapping: PermissionMapping, + tlsClientMapping: TlsClientPrincipalMapping, + idGenerator: IdGenerator, + ): Services + + private fun initServices(idGenerator: IdGenerator = IdGenerator.Default) = + initServices(TestConfig.PERMISSION_MAPPING, TestConfig.TLS_CLIENT_MAPPING, idGenerator) + + @Test + fun `createPolicy returns created Policy`() = runBlocking { + val (service, principalsService, rolesService) = initServices() + val startTime = Instant.now() + val principal = principalsService.createUserPrincipal(CREATE_USER_PRINCIPAL_REQUEST) + val role = rolesService.createRole(CREATE_ROLE_REQUEST) + + val request = policy { + policyResourceId = "fantasy-shelf-policy" + protectedResourceName = "shelves/fantasy" + bindings[role.roleResourceId] = + PolicyKt.members { memberPrincipalResourceIds += principal.principalResourceId } + } + val response: Policy = service.createPolicy(request) + + assertThat(response) + .ignoringRepeatedFieldOrder() + .ignoringFields( + Policy.CREATE_TIME_FIELD_NUMBER, + Policy.UPDATE_TIME_FIELD_NUMBER, + Policy.ETAG_FIELD_NUMBER, + ) + .isEqualTo(request) + assertThat(response.createTime.toInstant()).isGreaterThan(startTime) + assertThat(response.updateTime).isEqualTo(response.createTime) + assertThat(response.etag).isNotEmpty() + assertThat(response) + .ignoringRepeatedFieldOrder() + .isEqualTo( + service.getPolicy(getPolicyRequest { policyResourceId = request.policyResourceId }) + ) + } + + @Test + fun `createPolicy throws INVALID_ARGUMENT if binding has no members`() = runBlocking { + val (service, _, rolesService) = initServices() + val role = rolesService.createRole(CREATE_ROLE_REQUEST) + + val exception = + assertFailsWith { + service.createPolicy( + policy { + policyResourceId = "fantasy-shelf-policy" + protectedResourceName = "shelves/fantasy" + bindings[role.roleResourceId] = Policy.Members.getDefaultInstance() + } + ) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.REQUIRED_FIELD_NOT_SET.name + metadata[Errors.Metadata.FIELD_NAME.key] = "member_principal_resource_ids" + } + ) + } + + @Test + fun `lookupPolicy returns Policy`() = runBlocking { + val (service, principalsService, rolesService) = initServices() + val principal = principalsService.createUserPrincipal(CREATE_USER_PRINCIPAL_REQUEST) + val role = rolesService.createRole(CREATE_ROLE_REQUEST) + val policy: Policy = + service.createPolicy( + policy { + policyResourceId = "fantasy-shelf-policy" + protectedResourceName = "shelves/fantasy" + bindings[role.roleResourceId] = + PolicyKt.members { memberPrincipalResourceIds += principal.principalResourceId } + } + ) + + val response: Policy = + service.lookupPolicy( + lookupPolicyRequest { protectedResourceName = policy.protectedResourceName } + ) + + assertThat(response).isEqualTo(policy) + } + + @Test + fun `addPolicyBindingMembers adds members to Policy binding`() = runBlocking { + val (service, principalsService, rolesService) = initServices() + val principal = principalsService.createUserPrincipal(CREATE_USER_PRINCIPAL_REQUEST) + val role = rolesService.createRole(CREATE_ROLE_REQUEST) + val policy: Policy = + service.createPolicy( + policy { + policyResourceId = "fantasy-shelf-policy" + protectedResourceName = "shelves/fantasy" + } + ) + + val response = + service.addPolicyBindingMembers( + addPolicyBindingMembersRequest { + policyResourceId = policy.policyResourceId + roleResourceId = role.roleResourceId + memberPrincipalResourceIds += principal.principalResourceId + } + ) + + assertThat(response) + .ignoringRepeatedFieldOrder() + .ignoringFields(Policy.UPDATE_TIME_FIELD_NUMBER, Policy.ETAG_FIELD_NUMBER) + .isEqualTo( + policy.copy { + bindings[role.roleResourceId] = + PolicyKt.members { memberPrincipalResourceIds += principal.principalResourceId } + } + ) + assertThat(response.updateTime.toInstant()).isGreaterThan(policy.updateTime.toInstant()) + assertThat(response.etag).isNotEqualTo(policy.etag) + assertThat(response) + .ignoringRepeatedFieldOrder() + .isEqualTo(service.getPolicy(getPolicyRequest { policyResourceId = policy.policyResourceId })) + } + + @Test + fun `addPolicyBindingMembers throws FAILED_PRECONDITION if binding already has member`() = + runBlocking { + val (service, principalsService, rolesService) = initServices() + val principal = principalsService.createUserPrincipal(CREATE_USER_PRINCIPAL_REQUEST) + val role = rolesService.createRole(CREATE_ROLE_REQUEST) + val policy: Policy = + service.createPolicy( + policy { + policyResourceId = "fantasy-shelf-policy" + protectedResourceName = "shelves/fantasy" + bindings[role.roleResourceId] = + PolicyKt.members { memberPrincipalResourceIds += principal.principalResourceId } + } + ) + + val exception = + assertFailsWith { + service.addPolicyBindingMembers( + addPolicyBindingMembersRequest { + policyResourceId = policy.policyResourceId + roleResourceId = role.roleResourceId + memberPrincipalResourceIds += principal.principalResourceId + } + ) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.FAILED_PRECONDITION) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.POLICY_BINDING_MEMBERSHIP_ALREADY_EXISTS.name + metadata[Errors.Metadata.POLICY_RESOURCE_ID.key] = policy.policyResourceId + metadata[Errors.Metadata.ROLE_RESOURCE_ID.key] = role.roleResourceId + metadata[Errors.Metadata.PRINCIPAL_RESOURCE_ID.key] = principal.principalResourceId + } + ) + } + + @Test + fun `addPolicyBindingMembers throws ABORTED on etag mismatch`() = runBlocking { + val (service, principalsService, rolesService) = initServices() + val principal = principalsService.createUserPrincipal(CREATE_USER_PRINCIPAL_REQUEST) + val role = rolesService.createRole(CREATE_ROLE_REQUEST) + val policy: Policy = + service.createPolicy( + policy { + policyResourceId = "fantasy-shelf-policy" + protectedResourceName = "shelves/fantasy" + } + ) + + val request = addPolicyBindingMembersRequest { + policyResourceId = policy.policyResourceId + roleResourceId = role.roleResourceId + memberPrincipalResourceIds += principal.principalResourceId + etag = "invalid" + } + val exception = + assertFailsWith { service.addPolicyBindingMembers(request) } + + assertThat(exception.status.code).isEqualTo(Status.Code.ABORTED) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.ETAG_MISMATCH.name + metadata[Errors.Metadata.ETAG.key] = policy.etag + metadata[Errors.Metadata.REQUEST_ETAG.key] = request.etag + } + ) + } + + @Test + fun `removePolicyBindingMembers removes members from Policy binding`() = runBlocking { + val (service, principalsService, rolesService) = initServices() + val principal = principalsService.createUserPrincipal(CREATE_USER_PRINCIPAL_REQUEST) + val principal2 = principalsService.createUserPrincipal(CREATE_USER_PRINCIPAL_2_REQUEST) + val role = rolesService.createRole(CREATE_ROLE_REQUEST) + val policy: Policy = + service.createPolicy( + policy { + policyResourceId = "fantasy-shelf-policy" + protectedResourceName = "shelves/fantasy" + bindings[role.roleResourceId] = + PolicyKt.members { + memberPrincipalResourceIds += principal.principalResourceId + memberPrincipalResourceIds += principal2.principalResourceId + } + } + ) + + val response: Policy = + service.removePolicyBindingMembers( + removePolicyBindingMembersRequest { + policyResourceId = policy.policyResourceId + roleResourceId = role.roleResourceId + memberPrincipalResourceIds += principal2.principalResourceId + } + ) + + assertThat(response) + .ignoringRepeatedFieldOrder() + .ignoringFields(Policy.UPDATE_TIME_FIELD_NUMBER, Policy.ETAG_FIELD_NUMBER) + .isEqualTo( + policy.copy { + bindings[role.roleResourceId] = + PolicyKt.members { memberPrincipalResourceIds += principal.principalResourceId } + } + ) + assertThat(response.updateTime.toInstant()).isGreaterThan(policy.updateTime.toInstant()) + assertThat(response.etag).isNotEqualTo(policy.etag) + assertThat(response) + .ignoringRepeatedFieldOrder() + .isEqualTo(service.getPolicy(getPolicyRequest { policyResourceId = policy.policyResourceId })) + } + + @Test + fun `removePolicyBindingMembers removes all members from Policy binding`() = runBlocking { + val (service, principalsService, rolesService) = initServices() + val principal = principalsService.createUserPrincipal(CREATE_USER_PRINCIPAL_REQUEST) + val principal2 = principalsService.createUserPrincipal(CREATE_USER_PRINCIPAL_2_REQUEST) + val role = rolesService.createRole(CREATE_ROLE_REQUEST) + val policy: Policy = + service.createPolicy( + policy { + policyResourceId = "fantasy-shelf-policy" + protectedResourceName = "shelves/fantasy" + bindings[role.roleResourceId] = + PolicyKt.members { + memberPrincipalResourceIds += principal.principalResourceId + memberPrincipalResourceIds += principal2.principalResourceId + } + } + ) + + val response: Policy = + service.removePolicyBindingMembers( + removePolicyBindingMembersRequest { + policyResourceId = policy.policyResourceId + roleResourceId = role.roleResourceId + memberPrincipalResourceIds += principal.principalResourceId + memberPrincipalResourceIds += principal2.principalResourceId + } + ) + + assertThat(response) + .ignoringRepeatedFieldOrder() + .ignoringFields(Policy.UPDATE_TIME_FIELD_NUMBER, Policy.ETAG_FIELD_NUMBER) + .isEqualTo(policy.copy { bindings.clear() }) + assertThat(response.updateTime.toInstant()).isGreaterThan(policy.updateTime.toInstant()) + assertThat(response.etag).isNotEqualTo(policy.etag) + assertThat(response) + .ignoringRepeatedFieldOrder() + .isEqualTo(service.getPolicy(getPolicyRequest { policyResourceId = policy.policyResourceId })) + } + + @Test + fun `removePolicyBindingMembers throws FAILED_PRECONDITION if binding does not have member`() = + runBlocking { + val (service, principalsService, rolesService) = initServices() + val principal = principalsService.createUserPrincipal(CREATE_USER_PRINCIPAL_REQUEST) + val role = rolesService.createRole(CREATE_ROLE_REQUEST) + val policy: Policy = + service.createPolicy( + policy { + policyResourceId = "fantasy-shelf-policy" + protectedResourceName = "shelves/fantasy" + } + ) + + val exception = + assertFailsWith { + service.removePolicyBindingMembers( + removePolicyBindingMembersRequest { + policyResourceId = policy.policyResourceId + roleResourceId = role.roleResourceId + memberPrincipalResourceIds += principal.principalResourceId + } + ) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.FAILED_PRECONDITION) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.POLICY_BINDING_MEMBERSHIP_NOT_FOUND.name + metadata[Errors.Metadata.POLICY_RESOURCE_ID.key] = policy.policyResourceId + metadata[Errors.Metadata.ROLE_RESOURCE_ID.key] = role.roleResourceId + metadata[Errors.Metadata.PRINCIPAL_RESOURCE_ID.key] = principal.principalResourceId + } + ) + } + + @Test + fun `removePolicyBindingMembers throws ABORTED on etag mistmatch`() = runBlocking { + val (service, principalsService, rolesService) = initServices() + val principal = principalsService.createUserPrincipal(CREATE_USER_PRINCIPAL_REQUEST) + val role = rolesService.createRole(CREATE_ROLE_REQUEST) + val policy: Policy = + service.createPolicy( + policy { + policyResourceId = "fantasy-shelf-policy" + protectedResourceName = "shelves/fantasy" + bindings[role.roleResourceId] = + PolicyKt.members { memberPrincipalResourceIds += principal.principalResourceId } + } + ) + + val request = removePolicyBindingMembersRequest { + policyResourceId = policy.policyResourceId + roleResourceId = role.roleResourceId + memberPrincipalResourceIds += principal.principalResourceId + etag = "invalid" + } + val exception = + assertFailsWith { service.removePolicyBindingMembers(request) } + + assertThat(exception.status.code).isEqualTo(Status.Code.ABORTED) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.ETAG_MISMATCH.name + metadata[Errors.Metadata.ETAG.key] = policy.etag + metadata[Errors.Metadata.REQUEST_ETAG.key] = request.etag + } + ) + } + + companion object { + private val CREATE_USER_PRINCIPAL_REQUEST = createUserPrincipalRequest { + principalResourceId = "user-1" + user = oAuthUser { + issuer = "example-issuer" + subject = "user@example.com" + } + } + + private val CREATE_USER_PRINCIPAL_2_REQUEST = + CREATE_USER_PRINCIPAL_REQUEST.copy { + principalResourceId = "user-2" + user = user.copy { subject = "user2@example.com" } + } + + private val CREATE_ROLE_REQUEST = role { + roleResourceId = "shelfBookReader" + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_LIST + resourceTypes += TestConfig.ResourceType.SHELF + } + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/PrincipalsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/PrincipalsServiceTest.kt new file mode 100644 index 00000000000..5da328f3a38 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/PrincipalsServiceTest.kt @@ -0,0 +1,259 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.service.internal.testing + +import com.google.common.truth.Truth.assertThat +import com.google.common.truth.extensions.proto.ProtoTruth.assertThat +import com.google.rpc.errorInfo +import io.grpc.Status +import io.grpc.StatusRuntimeException +import java.time.Instant +import kotlin.test.assertFailsWith +import kotlinx.coroutines.runBlocking +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.mockito.kotlin.mock +import org.mockito.kotlin.times +import org.mockito.kotlin.verify +import org.wfanet.measurement.access.common.TlsClientPrincipalMapping +import org.wfanet.measurement.access.service.internal.Errors +import org.wfanet.measurement.access.service.internal.IdGenerator +import org.wfanet.measurement.common.grpc.errorInfo +import org.wfanet.measurement.common.toInstant +import org.wfanet.measurement.internal.access.Principal +import org.wfanet.measurement.internal.access.PrincipalKt.oAuthUser +import org.wfanet.measurement.internal.access.PrincipalKt.tlsClient +import org.wfanet.measurement.internal.access.PrincipalsGrpcKt +import org.wfanet.measurement.internal.access.copy +import org.wfanet.measurement.internal.access.createUserPrincipalRequest +import org.wfanet.measurement.internal.access.deletePrincipalRequest +import org.wfanet.measurement.internal.access.getPrincipalRequest +import org.wfanet.measurement.internal.access.lookupPrincipalRequest +import org.wfanet.measurement.internal.access.principal + +@RunWith(JUnit4::class) +abstract class PrincipalsServiceTest { + /** Initializes the service under test. */ + abstract fun initService( + tlsClientMapping: TlsClientPrincipalMapping, + idGenerator: IdGenerator, + ): PrincipalsGrpcKt.PrincipalsCoroutineImplBase + + private fun initService(idGenerator: IdGenerator = IdGenerator.Default) = + initService(TestConfig.TLS_CLIENT_MAPPING, idGenerator) + + @Test + fun `getPrincipal returns TLS client principal`() { + val service = initService() + val request = getPrincipalRequest { + principalResourceId = TestConfig.TLS_CLIENT_PRINCIPAL_RESOURCE_ID + } + + val response: Principal = runBlocking { service.getPrincipal(request) } + + assertThat(response) + .isEqualTo( + principal { + principalResourceId = TestConfig.TLS_CLIENT_PRINCIPAL_RESOURCE_ID + tlsClient = tlsClient { authorityKeyIdentifier = TestConfig.AUTHORITY_KEY_IDENTIFIER } + } + ) + } + + @Test + fun `createUserPrincipal returns created principal`() = runBlocking { + val service = initService() + val request = createUserPrincipalRequest { + principalResourceId = "user-1" + user = oAuthUser { + issuer = "example-issuer" + subject = "user@example.com" + } + } + + val principal: Principal = service.createUserPrincipal(request) + + assertThat(principal) + .ignoringFields(Principal.CREATE_TIME_FIELD_NUMBER, Principal.UPDATE_TIME_FIELD_NUMBER) + .isEqualTo( + principal { + principalResourceId = request.principalResourceId + user = request.user + } + ) + assertThat(principal.createTime.toInstant()).isGreaterThan(Instant.now().minusSeconds(10)) + assertThat(principal.createTime).isEqualTo(principal.updateTime) + assertThat(principal) + .isEqualTo( + service.getPrincipal( + getPrincipalRequest { principalResourceId = request.principalResourceId } + ) + ) + } + + @Test + fun `createUserPrincipal retries ID generation if ID already in use`(): Unit = runBlocking { + val principalId1 = 1234L + val principalId2 = 2345L + val idGeneratorMock = + mock { on { generateId() }.thenReturn(principalId1, principalId1, principalId2) } + val service = initService(idGeneratorMock) + service.createUserPrincipal( + createUserPrincipalRequest { + principalResourceId = "user-1" + user = oAuthUser { + issuer = "example-issuer" + subject = "user@example.com" + } + } + ) + + service.createUserPrincipal( + createUserPrincipalRequest { + principalResourceId = "user-2" + user = oAuthUser { + issuer = "example-issuer" + subject = "user2@example.com" + } + } + ) + + verify(idGeneratorMock, times(3)).generateId() + } + + @Test + fun `createUserPrincipal throws ALREADY_EXISTS if Principal with resource ID already exists`() = + runBlocking { + val service = initService() + val request = createUserPrincipalRequest { + principalResourceId = "user-1" + user = oAuthUser { + issuer = "example-issuer" + subject = "user@example.com" + } + } + service.createUserPrincipal(request) + + val exception = + assertFailsWith { + service.createUserPrincipal( + request.copy { user = user.copy { subject = "user2@example.com" } } + ) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.ALREADY_EXISTS) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.PRINCIPAL_ALREADY_EXISTS.name + } + ) + } + + @Test + fun `lookupPrincipal returns TLS client principal`() { + val service = initService() + val request = lookupPrincipalRequest { + tlsClient = tlsClient { authorityKeyIdentifier = TestConfig.AUTHORITY_KEY_IDENTIFIER } + } + + val response: Principal = runBlocking { service.lookupPrincipal(request) } + + assertThat(response) + .isEqualTo( + principal { + principalResourceId = TestConfig.TLS_CLIENT_PRINCIPAL_RESOURCE_ID + tlsClient = request.tlsClient + } + ) + } + + @Test + fun `lookupPrincipal returns user principal`() = runBlocking { + val service = initService() + val principal = + service.createUserPrincipal( + createUserPrincipalRequest { + principalResourceId = "user-1" + user = oAuthUser { + issuer = "example-issuer" + subject = "user@example.com" + } + } + ) + val request = lookupPrincipalRequest { user = principal.user } + + val response: Principal = runBlocking { service.lookupPrincipal(request) } + + assertThat(response).isEqualTo(principal) + } + + @Test + fun `deletePrincipal deletes user Principal`() = runBlocking { + val service = initService() + val principal = + service.createUserPrincipal( + createUserPrincipalRequest { + principalResourceId = "user-1" + user = oAuthUser { + issuer = "example-issuer" + subject = "user@example.com" + } + } + ) + + service.deletePrincipal( + deletePrincipalRequest { principalResourceId = principal.principalResourceId } + ) + + val exception = + assertFailsWith { + service.getPrincipal( + getPrincipalRequest { principalResourceId = principal.principalResourceId } + ) + } + assertThat(exception.status.code).isEqualTo(Status.Code.NOT_FOUND) + } + + @Test + fun `deletePrincipal throws FAILED_PRECONDITION for TLS client Principal`() = runBlocking { + val service = initService() + + val exception = + assertFailsWith { + service.deletePrincipal( + deletePrincipalRequest { + principalResourceId = TestConfig.TLS_CLIENT_PRINCIPAL_RESOURCE_ID + } + ) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.FAILED_PRECONDITION) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.PRINCIPAL_TYPE_NOT_SUPPORTED.name + metadata[Errors.Metadata.PRINCIPAL_TYPE.key] = "TLS_CLIENT" + metadata[Errors.Metadata.PRINCIPAL_RESOURCE_ID.key] = + TestConfig.TLS_CLIENT_PRINCIPAL_RESOURCE_ID + } + ) + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/RolesServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/RolesServiceTest.kt new file mode 100644 index 00000000000..5f1a489b1af --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/RolesServiceTest.kt @@ -0,0 +1,457 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.service.internal.testing + +import com.google.common.truth.Truth.assertThat +import com.google.common.truth.extensions.proto.ProtoTruth.assertThat +import com.google.rpc.errorInfo +import io.grpc.Status +import io.grpc.StatusRuntimeException +import java.time.Instant +import kotlin.test.assertFailsWith +import kotlinx.coroutines.runBlocking +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.mockito.kotlin.mock +import org.mockito.kotlin.times +import org.mockito.kotlin.verify +import org.wfanet.measurement.access.service.internal.Errors +import org.wfanet.measurement.access.service.internal.IdGenerator +import org.wfanet.measurement.access.service.internal.PermissionMapping +import org.wfanet.measurement.common.grpc.errorInfo +import org.wfanet.measurement.common.toInstant +import org.wfanet.measurement.internal.access.ListRolesPageTokenKt +import org.wfanet.measurement.internal.access.ListRolesRequest +import org.wfanet.measurement.internal.access.ListRolesResponse +import org.wfanet.measurement.internal.access.Role +import org.wfanet.measurement.internal.access.RolesGrpcKt +import org.wfanet.measurement.internal.access.copy +import org.wfanet.measurement.internal.access.deleteRoleRequest +import org.wfanet.measurement.internal.access.getRoleRequest +import org.wfanet.measurement.internal.access.listRolesPageToken +import org.wfanet.measurement.internal.access.listRolesRequest +import org.wfanet.measurement.internal.access.listRolesResponse +import org.wfanet.measurement.internal.access.role + +@RunWith(JUnit4::class) +abstract class RolesServiceTest { + /** Initializes the service under test. */ + abstract fun initService( + permissionMapping: PermissionMapping, + idGenerator: IdGenerator, + ): RolesGrpcKt.RolesCoroutineImplBase + + private fun initService(idGenerator: IdGenerator = IdGenerator.Default) = + initService(TestConfig.PERMISSION_MAPPING, idGenerator) + + @Test + fun `getRole throws NOT_FOUND when Role not found`() = runBlocking { + val service = initService() + val request = getRoleRequest { roleResourceId = "not-found" } + + val exception = assertFailsWith { service.getRole(request) } + + assertThat(exception.status.code).isEqualTo(Status.Code.NOT_FOUND) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.ROLE_NOT_FOUND.name + metadata[Errors.Metadata.ROLE_RESOURCE_ID.key] = request.roleResourceId + } + ) + } + + @Test + fun `createRole returns created Role`() = runBlocking { + val service = initService() + val request = role { + roleResourceId = "shelfBookReader" + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_LIST + resourceTypes += TestConfig.ResourceType.SHELF + } + + val response: Role = service.createRole(request) + + assertThat(response) + .ignoringFields( + Role.CREATE_TIME_FIELD_NUMBER, + Role.UPDATE_TIME_FIELD_NUMBER, + Role.ETAG_FIELD_NUMBER, + ) + .ignoringRepeatedFieldOrder() + .isEqualTo(request) + assertThat(response.createTime.toInstant()).isGreaterThan(Instant.now().minusSeconds(10)) + assertThat(response.updateTime).isEqualTo(response.createTime) + assertThat(response.etag).isNotEmpty() + assertThat(response) + .ignoringRepeatedFieldOrder() + .isEqualTo(service.getRole(getRoleRequest { roleResourceId = request.roleResourceId })) + } + + @Test + fun `createRole throws ALREADY_EXISTS if Role with resource ID already exists`() = runBlocking { + val service = initService() + val request = role { + roleResourceId = "bookReader" + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + resourceTypes += TestConfig.ResourceType.BOOK + } + service.createRole(request) + + val exception = assertFailsWith { service.createRole(request) } + + assertThat(exception.status.code).isEqualTo(Status.Code.ALREADY_EXISTS) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.ROLE_ALREADY_EXISTS.name + } + ) + } + + @Test + fun `createRole retries ID generation if ID already in use`(): Unit = runBlocking { + val principalId1 = 1234L + val principalId2 = 2345L + val idGeneratorMock = + mock { on { generateId() }.thenReturn(principalId1, principalId1, principalId2) } + val service = initService(idGeneratorMock) + service.createRole( + role { + roleResourceId = "shelfBookReader" + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_LIST + resourceTypes += TestConfig.ResourceType.SHELF + } + ) + + service.createRole( + role { + roleResourceId = "bookReader" + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + resourceTypes += TestConfig.ResourceType.BOOK + } + ) + + verify(idGeneratorMock, times(3)).generateId() + } + + @Test + fun `createRole throws FAILED_PRECONDITION if resource type not found in Permission`() = + runBlocking { + val service = initService() + val request = role { + roleResourceId = "bookWriter" + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_CREATE + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + resourceTypes += TestConfig.ResourceType.BOOK + resourceTypes += TestConfig.ResourceType.SHELF + } + + val exception = assertFailsWith { service.createRole(request) } + + assertThat(exception.status.code).isEqualTo(Status.Code.FAILED_PRECONDITION) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.RESOURCE_TYPE_NOT_FOUND_IN_PERMISSION.name + metadata[Errors.Metadata.RESOURCE_TYPE.key] = TestConfig.ResourceType.BOOK + metadata[Errors.Metadata.PERMISSION_RESOURCE_ID.key] = + TestConfig.PermissionResourceId.BOOKS_CREATE + } + ) + } + + @Test + fun `updateRole returns updated Role`() = runBlocking { + val service = initService() + val role = + service.createRole( + role { + roleResourceId = "bookUser" + resourceTypes += TestConfig.ResourceType.BOOK + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_DELETE + } + ) + + val request = + role.copy { + resourceTypes.clear() + resourceTypes += TestConfig.ResourceType.SHELF + + permissionResourceIds.clear() + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_LIST + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_CREATE + } + val response: Role = service.updateRole(request) + + assertThat(response) + .ignoringRepeatedFieldOrder() + .ignoringFields(Role.UPDATE_TIME_FIELD_NUMBER, Role.ETAG_FIELD_NUMBER) + .isEqualTo(request) + assertThat(response.updateTime.toInstant()).isGreaterThan(role.updateTime.toInstant()) + assertThat(response.etag).isNotEqualTo(role.etag) + assertThat(response) + .ignoringRepeatedFieldOrder() + .isEqualTo(service.getRole(getRoleRequest { roleResourceId = role.roleResourceId })) + } + + @Test + fun `updateRole throws INVALID_ARGUMENT if etag not set`() = runBlocking { + val service = initService() + val role = + service.createRole( + role { + roleResourceId = "bookUser" + resourceTypes += TestConfig.ResourceType.BOOK + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_DELETE + } + ) + + val exception = + assertFailsWith { service.updateRole(role.copy { clearEtag() }) } + + assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.REQUIRED_FIELD_NOT_SET.name + metadata[Errors.Metadata.FIELD_NAME.key] = "etag" + } + ) + } + + @Test + fun `updateRole throws ABORTED if etag does not match`() = runBlocking { + val service = initService() + val role = + service.createRole( + role { + roleResourceId = "bookUser" + resourceTypes += TestConfig.ResourceType.BOOK + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + } + ) + + val request = role.copy { etag = "W/\"foo\"" } + val exception = assertFailsWith { service.updateRole(request) } + + assertThat(exception.status.code).isEqualTo(Status.Code.ABORTED) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.ETAG_MISMATCH.name + metadata[Errors.Metadata.ETAG.key] = role.etag + metadata[Errors.Metadata.REQUEST_ETAG.key] = request.etag + } + ) + } + + @Test + fun `updateRole throws FAILED_PRECONDITION if resource type not found in new Permission`() = + runBlocking { + val service = initService() + val role = + service.createRole( + role { + roleResourceId = "bookUser" + resourceTypes += TestConfig.ResourceType.BOOK + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_DELETE + } + ) + + val exception = + assertFailsWith { + service.updateRole( + role.copy { permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_CREATE } + ) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.FAILED_PRECONDITION) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.RESOURCE_TYPE_NOT_FOUND_IN_PERMISSION.name + metadata[Errors.Metadata.RESOURCE_TYPE.key] = TestConfig.ResourceType.BOOK + metadata[Errors.Metadata.PERMISSION_RESOURCE_ID.key] = + TestConfig.PermissionResourceId.BOOKS_CREATE + } + ) + } + + @Test + fun `updateRole throws FAILED_PRECONDITION if resource type not found in existing Permission`() = + runBlocking { + val service = initService() + val role = + service.createRole( + role { + roleResourceId = "bookAdmin" + resourceTypes += TestConfig.ResourceType.SHELF + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_CREATE + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_DELETE + } + ) + + val exception = + assertFailsWith { + service.updateRole(role.copy { resourceTypes += TestConfig.ResourceType.BOOK }) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.FAILED_PRECONDITION) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.RESOURCE_TYPE_NOT_FOUND_IN_PERMISSION.name + metadata[Errors.Metadata.RESOURCE_TYPE.key] = TestConfig.ResourceType.BOOK + metadata[Errors.Metadata.PERMISSION_RESOURCE_ID.key] = + TestConfig.PermissionResourceId.BOOKS_CREATE + } + ) + } + + @Test + fun `deleteRole deletes Role`() = runBlocking { + val service = initService() + val role = + service.createRole( + role { + roleResourceId = "bookReader" + resourceTypes += TestConfig.ResourceType.BOOK + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + } + ) + + service.deleteRole(deleteRoleRequest { roleResourceId = role.roleResourceId }) + + val exception = + assertFailsWith { + service.getRole(getRoleRequest { roleResourceId = role.roleResourceId }) + } + assertThat(exception.status.code).isEqualTo(Status.Code.NOT_FOUND) + } + + @Test + fun `deleteRole throws NOT_FOUND when role not found`() = runBlocking { + val service = initService() + val request = deleteRoleRequest { roleResourceId = "not-found" } + + val exception = assertFailsWith { service.deleteRole(request) } + + assertThat(exception.status.code).isEqualTo(Status.Code.NOT_FOUND) + assertThat(exception.errorInfo) + .isEqualTo( + errorInfo { + domain = Errors.DOMAIN + reason = Errors.Reason.ROLE_NOT_FOUND.name + metadata[Errors.Metadata.ROLE_RESOURCE_ID.key] = request.roleResourceId + } + ) + } + + @Test + fun `listRoles returns roles ordered by resource ID`() = runBlocking { + val service = initService() + val roles: List = createRoles(service, 10) + + val response: ListRolesResponse = service.listRoles(ListRolesRequest.getDefaultInstance()) + + assertThat(response).isEqualTo(listRolesResponse { this.roles += roles }) + } + + @Test + fun `listRoles returns roles when page size is specified`() = runBlocking { + val service = initService() + val roles: List = createRoles(service, 10) + + val response: ListRolesResponse = service.listRoles(listRolesRequest { pageSize = 10 }) + + assertThat(response).isEqualTo(listRolesResponse { this.roles += roles }) + } + + @Test + fun `listRoles returns next page token when there are more results`() = runBlocking { + val service = initService() + val roles: List = createRoles(service, 10) + + val request = listRolesRequest { pageSize = 5 } + val response: ListRolesResponse = service.listRoles(request) + + assertThat(response) + .isEqualTo( + listRolesResponse { + this.roles += roles.take(request.pageSize) + nextPageToken = listRolesPageToken { + after = ListRolesPageTokenKt.after { roleResourceId = "role-0000000005" } + } + } + ) + } + + @Test + fun `listRoles returns results after page token`() = runBlocking { + val service = initService() + val roles: List = createRoles(service, 10) + + val request = listRolesRequest { + pageSize = 2 + pageToken = listRolesPageToken { + after = ListRolesPageTokenKt.after { roleResourceId = "role-0000000005" } + } + } + val response: ListRolesResponse = service.listRoles(request) + + assertThat(response) + .isEqualTo( + listRolesResponse { + this.roles += roles.subList(5, 7) + nextPageToken = listRolesPageToken { + after = ListRolesPageTokenKt.after { roleResourceId = "role-0000000007" } + } + } + ) + } + + private suspend fun createRoles( + service: RolesGrpcKt.RolesCoroutineImplBase, + count: Int, + ): List { + return (1..count).map { + service.createRole( + role { + roleResourceId = String.format("role-%010d", it) + resourceTypes += TestConfig.ResourceType.BOOK + permissionResourceIds += TestConfig.PermissionResourceId.BOOKS_GET + } + ) + } + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/TestConfig.kt b/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/TestConfig.kt new file mode 100644 index 00000000000..79339176de0 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/access/service/internal/testing/TestConfig.kt @@ -0,0 +1,99 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.service.internal.testing + +import com.google.protobuf.ByteString +import org.wfanet.measurement.access.common.TlsClientPrincipalMapping +import org.wfanet.measurement.access.service.internal.PermissionMapping +import org.wfanet.measurement.common.byteStringOf +import org.wfanet.measurement.config.AuthorityKeyToPrincipalMapKt +import org.wfanet.measurement.config.access.PermissionsConfigKt.permission +import org.wfanet.measurement.config.access.permissionsConfig +import org.wfanet.measurement.config.authorityKeyToPrincipalMap + +object TestConfig { + val AUTHORITY_KEY_IDENTIFIER: ByteString = + byteStringOf( + 0x7C, + 0xE6, + 0x3F, + 0xEA, + 0x65, + 0xED, + 0x71, + 0x3D, + 0x9E, + 0x59, + 0x79, + 0xA0, + 0xC8, + 0x08, + 0xC9, + 0x57, + 0xAA, + 0xC6, + 0xB1, + 0x6A, + ) + + const val TLS_CLIENT_PROTECTED_RESOURCE_NAME = "shelves/fantasy" + const val TLS_CLIENT_PRINCIPAL_RESOURCE_ID = "shelves-fantasy" + + val TLS_CLIENT_MAPPING = + TlsClientPrincipalMapping( + authorityKeyToPrincipalMap { + entries += + AuthorityKeyToPrincipalMapKt.entry { + authorityKeyIdentifier = AUTHORITY_KEY_IDENTIFIER + principalResourceName = TLS_CLIENT_PROTECTED_RESOURCE_NAME + } + } + ) + + object PermissionResourceId { + const val BOOKS_GET = "books.get" + const val BOOKS_LIST = "books.list" + const val BOOKS_CREATE = "books.create" + const val BOOKS_DELETE = "books.delete" + } + + object ResourceType { + private const val DOMAIN = "library.googleapis.com" + const val BOOK = "$DOMAIN/Book" + const val SHELF = "$DOMAIN/Shelf" + } + + val PERMISSION_MAPPING = + PermissionMapping( + permissionsConfig { + permissions[PermissionResourceId.BOOKS_GET] = permission { + protectedResourceTypes += ResourceType.SHELF + protectedResourceTypes += ResourceType.BOOK + } + permissions[PermissionResourceId.BOOKS_DELETE] = permission { + protectedResourceTypes += ResourceType.SHELF + protectedResourceTypes += ResourceType.BOOK + } + permissions[PermissionResourceId.BOOKS_LIST] = permission { + protectedResourceTypes += ResourceType.SHELF + } + permissions[PermissionResourceId.BOOKS_CREATE] = permission { + protectedResourceTypes += ResourceType.SHELF + } + } + ) +} diff --git a/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/BUILD.bazel index bbe08b24e41..6474ad778fc 100644 --- a/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/BUILD.bazel @@ -13,3 +13,56 @@ spanner_emulator_test( "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner/testing", ], ) + +spanner_emulator_test( + name = "SpannerPrincipalsServiceTest", + srcs = ["SpannerPrincipalsServiceTest.kt"], + test_class = "org.wfanet.measurement.access.deploy.gcloud.spanner.SpannerPrincipalsServiceTest", + deps = [ + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner:spanner_principals_service", + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/testing:schemata", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal/testing:principals_service_test", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner/testing", + ], +) + +spanner_emulator_test( + name = "SpannerPermissionsServiceTest", + srcs = ["SpannerPermissionsServiceTest.kt"], + test_class = "org.wfanet.measurement.access.deploy.gcloud.spanner.SpannerPermissionsServiceTest", + deps = [ + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner:spanner_permissions_service", + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner:spanner_policies_service", + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner:spanner_principals_service", + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner:spanner_roles_service", + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/testing:schemata", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal/testing:permissions_service_test", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner/testing", + ], +) + +spanner_emulator_test( + name = "SpannerRolesServiceTest", + srcs = ["SpannerRolesServiceTest.kt"], + test_class = "org.wfanet.measurement.access.deploy.gcloud.spanner.SpannerRolesServiceTest", + deps = [ + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner:spanner_roles_service", + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/testing:schemata", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal/testing:roles_service_test", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner/testing", + ], +) + +spanner_emulator_test( + name = "SpannerPoliciesServiceTest", + srcs = ["SpannerPoliciesServiceTest.kt"], + test_class = "org.wfanet.measurement.access.deploy.gcloud.spanner.SpannerPoliciesServiceTest", + deps = [ + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner:spanner_policies_service", + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner:spanner_principals_service", + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner:spanner_roles_service", + "//src/main/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/testing:schemata", + "//src/main/kotlin/org/wfanet/measurement/access/service/internal/testing:policies_service_test", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/gcloud/spanner/testing", + ], +) diff --git a/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPermissionsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPermissionsServiceTest.kt new file mode 100644 index 00000000000..35b5446b334 --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPermissionsServiceTest.kt @@ -0,0 +1,51 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.deploy.gcloud.spanner + +import org.junit.ClassRule +import org.junit.Rule +import org.wfanet.measurement.access.common.TlsClientPrincipalMapping +import org.wfanet.measurement.access.deploy.gcloud.spanner.testing.Schemata +import org.wfanet.measurement.access.service.internal.IdGenerator +import org.wfanet.measurement.access.service.internal.PermissionMapping +import org.wfanet.measurement.access.service.internal.testing.PermissionsServiceTest +import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient +import org.wfanet.measurement.gcloud.spanner.testing.SpannerEmulatorDatabaseRule +import org.wfanet.measurement.gcloud.spanner.testing.SpannerEmulatorRule + +class SpannerPermissionsServiceTest : PermissionsServiceTest() { + @get:Rule + val spannerDatabase = SpannerEmulatorDatabaseRule(spannerEmulator, Schemata.ACCESS_CHANGELOG_PATH) + + override fun initServices( + permissionMapping: PermissionMapping, + tlsClientMapping: TlsClientPrincipalMapping, + idGenerator: IdGenerator, + ): Services { + val databaseClient: AsyncDatabaseClient = spannerDatabase.databaseClient + return Services( + SpannerPermissionsService(databaseClient, permissionMapping, tlsClientMapping), + SpannerPrincipalsService(databaseClient, tlsClientMapping, idGenerator), + SpannerRolesService(databaseClient, permissionMapping, idGenerator), + SpannerPoliciesService(databaseClient, tlsClientMapping, idGenerator), + ) + } + + companion object { + @get:ClassRule @JvmStatic val spannerEmulator = SpannerEmulatorRule() + } +} diff --git a/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPoliciesServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPoliciesServiceTest.kt new file mode 100644 index 00000000000..db8ee2d8f1e --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPoliciesServiceTest.kt @@ -0,0 +1,50 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.deploy.gcloud.spanner + +import org.junit.ClassRule +import org.junit.Rule +import org.wfanet.measurement.access.common.TlsClientPrincipalMapping +import org.wfanet.measurement.access.deploy.gcloud.spanner.testing.Schemata +import org.wfanet.measurement.access.service.internal.IdGenerator +import org.wfanet.measurement.access.service.internal.PermissionMapping +import org.wfanet.measurement.access.service.internal.testing.PoliciesServiceTest +import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient +import org.wfanet.measurement.gcloud.spanner.testing.SpannerEmulatorDatabaseRule +import org.wfanet.measurement.gcloud.spanner.testing.SpannerEmulatorRule + +class SpannerPoliciesServiceTest : PoliciesServiceTest() { + @get:Rule + val spannerDatabase = SpannerEmulatorDatabaseRule(spannerEmulator, Schemata.ACCESS_CHANGELOG_PATH) + + override fun initServices( + permissionMapping: PermissionMapping, + tlsClientMapping: TlsClientPrincipalMapping, + idGenerator: IdGenerator, + ): Services { + val databaseClient: AsyncDatabaseClient = spannerDatabase.databaseClient + return Services( + SpannerPoliciesService(databaseClient, tlsClientMapping, idGenerator), + SpannerPrincipalsService(databaseClient, tlsClientMapping, idGenerator), + SpannerRolesService(databaseClient, permissionMapping, idGenerator), + ) + } + + companion object { + @get:ClassRule @JvmStatic val spannerEmulator = SpannerEmulatorRule() + } +} diff --git a/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPrincipalsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPrincipalsServiceTest.kt new file mode 100644 index 00000000000..7896815e0d1 --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerPrincipalsServiceTest.kt @@ -0,0 +1,38 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.deploy.gcloud.spanner + +import org.junit.ClassRule +import org.junit.Rule +import org.wfanet.measurement.access.common.TlsClientPrincipalMapping +import org.wfanet.measurement.access.deploy.gcloud.spanner.testing.Schemata +import org.wfanet.measurement.access.service.internal.IdGenerator +import org.wfanet.measurement.access.service.internal.testing.PrincipalsServiceTest +import org.wfanet.measurement.gcloud.spanner.testing.SpannerEmulatorDatabaseRule +import org.wfanet.measurement.gcloud.spanner.testing.SpannerEmulatorRule + +class SpannerPrincipalsServiceTest : PrincipalsServiceTest() { + @get:Rule + val spannerDatabase = SpannerEmulatorDatabaseRule(spannerEmulator, Schemata.ACCESS_CHANGELOG_PATH) + + override fun initService(tlsClientMapping: TlsClientPrincipalMapping, idGenerator: IdGenerator) = + SpannerPrincipalsService(spannerDatabase.databaseClient, tlsClientMapping, idGenerator) + + companion object { + @get:ClassRule @JvmStatic val spannerEmulator = SpannerEmulatorRule() + } +} diff --git a/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerRolesServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerRolesServiceTest.kt new file mode 100644 index 00000000000..c2e14fd0d50 --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/access/deploy/gcloud/spanner/SpannerRolesServiceTest.kt @@ -0,0 +1,38 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.access.deploy.gcloud.spanner + +import org.junit.ClassRule +import org.junit.Rule +import org.wfanet.measurement.access.deploy.gcloud.spanner.testing.Schemata +import org.wfanet.measurement.access.service.internal.IdGenerator +import org.wfanet.measurement.access.service.internal.PermissionMapping +import org.wfanet.measurement.access.service.internal.testing.RolesServiceTest +import org.wfanet.measurement.gcloud.spanner.testing.SpannerEmulatorDatabaseRule +import org.wfanet.measurement.gcloud.spanner.testing.SpannerEmulatorRule + +class SpannerRolesServiceTest : RolesServiceTest() { + @get:Rule + val spannerDatabase = SpannerEmulatorDatabaseRule(spannerEmulator, Schemata.ACCESS_CHANGELOG_PATH) + + override fun initService(permissionMapping: PermissionMapping, idGenerator: IdGenerator) = + SpannerRolesService(spannerDatabase.databaseClient, permissionMapping, idGenerator) + + companion object { + @get:ClassRule @JvmStatic val spannerEmulator = SpannerEmulatorRule() + } +}