From a59b9e8bb627949237c94104068087bd4e865768 Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Mon, 24 Apr 2023 19:20:04 +0000 Subject: [PATCH] Add more info to different error status. --- .../service/api/v2alpha/MetricsService.kt | 93 +++++++++++++++---- .../service/api/v2alpha/MetricsServiceTest.kt | 17 ++-- 2 files changed, 84 insertions(+), 26 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 4f9bf3bb3b8..3aa7ee4d953 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -298,7 +298,21 @@ class MetricsService( .withAuthenticationKey(principal.config.apiKey) .createMeasurement(createMeasurementRequest) } catch (e: StatusException) { - throw Exception("Unable to create a CMMS measurement.", e) + throw when (e.status.code) { + Status.Code.INVALID_ARGUMENT -> + Status.INVALID_ARGUMENT.withDescription("Required field unspecified or invalid.") + Status.Code.PERMISSION_DENIED -> + Status.PERMISSION_DENIED.withDescription( + "Cannot create a CMMS Measurement for another MeasurementConsumer." + ) + Status.Code.FAILED_PRECONDITION -> + Status.FAILED_PRECONDITION.withDescription("Failed precondition.") + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("${measurementConsumer.name} is not found.") + else -> Status.UNKNOWN.withDescription("Unable to create a CMMS measurement.") + } + .withCause(e) + .asRuntimeException() } } @@ -409,8 +423,8 @@ class MetricsService( } catch (e: StatusException) { throw when (e.status.code) { Status.Code.NOT_FOUND -> - Status.FAILED_PRECONDITION.withDescription("$dataProviderName not found") - else -> Status.UNKNOWN.withDescription("Unable to retrieve $dataProviderName") + Status.FAILED_PRECONDITION.withDescription("$dataProviderName not found.") + else -> Status.UNKNOWN.withDescription("Unable to retrieve $dataProviderName.") } .withCause(e) .asRuntimeException() @@ -422,7 +436,16 @@ class MetricsService( .withAuthenticationKey(apiAuthenticationKey) .getCertificate(getCertificateRequest { name = dataProvider.certificate }) } catch (e: StatusException) { - throw Exception("Unable to retrieve Certificate ${dataProvider.certificate}", e) + throw when (e.status.code) { + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("${dataProvider.certificate} not found.") + else -> + Status.UNKNOWN.withDescription( + "Unable to retrieve Certificate ${dataProvider.certificate}." + ) + } + .withCause(e) + .asRuntimeException() } if ( certificate.revocationState != Certificate.RevocationState.REVOCATION_STATE_UNSPECIFIED @@ -539,10 +562,16 @@ class MetricsService( getMeasurementConsumerRequest { name = principal.resourceKey.toName() } ) } catch (e: StatusException) { - throw Exception( - "Unable to retrieve the measurement consumer " + "[${principal.resourceKey.toName()}].", - e - ) + throw when (e.status.code) { + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("${principal.resourceKey.toName()} not found.") + else -> + Status.UNKNOWN.withDescription( + "Unable to retrieve the measurement consumer [${principal.resourceKey.toName()}]." + ) + } + .withCause(e) + .asRuntimeException() } } @@ -584,11 +613,19 @@ class MetricsService( .getCertificate(getCertificateRequest { name = principal.config.signingCertificateName }) .x509Der } catch (e: StatusException) { - throw Exception( - "Unable to retrieve the signing certificate for the measurement consumer " + - "[$principal.config.signingCertificateName].", - e - ) + throw when (e.status.code) { + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription( + "${principal.config.signingCertificateName} not found." + ) + else -> + Status.UNKNOWN.withDescription( + "Unable to retrieve the signing certificate " + + "[${principal.config.signingCertificateName}] for the measurement consumer." + ) + } + .withCause(e) + .asRuntimeException() } } @@ -724,7 +761,20 @@ class MetricsService( .withAuthenticationKey(apiAuthenticationKey) .getMeasurement(getMeasurementRequest { name = measurementResourceName }) } catch (e: StatusException) { - throw Exception("Unable to retrieve the measurement [$measurementResourceName].", e) + throw when (e.status.code) { + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("$measurementResourceName not found.") + Status.Code.PERMISSION_DENIED -> + Status.PERMISSION_DENIED.withDescription( + "Doesn't have permission to get $measurementResourceName." + ) + else -> + Status.UNKNOWN.withDescription( + "Unable to retrieve the measurement [$measurementResourceName]." + ) + } + .withCause(e) + .asRuntimeException() } } } @@ -770,10 +820,17 @@ class MetricsService( .withAuthenticationKey(apiAuthenticationKey) .getCertificate(getCertificateRequest { name = measurementResultPair.certificate }) } catch (e: StatusException) { - throw Exception( - "Unable to retrieve the certificate [${measurementResultPair.certificate}].", - e - ) + throw when (e.status.code) { + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("${measurementResultPair.certificate} not found.") + else -> + Status.UNKNOWN.withDescription( + "Unable to retrieve the certificate " + + "[${measurementResultPair.certificate}] for the measurement consumer." + ) + } + .withCause(e) + .asRuntimeException() } val signedResult = diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt index 89b06c42caf..ff5c4b01747 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt @@ -102,6 +102,7 @@ import org.wfanet.measurement.common.crypto.subjectKeyIdentifier import org.wfanet.measurement.common.crypto.testing.loadSigningKey import org.wfanet.measurement.common.crypto.tink.loadPrivateKey import org.wfanet.measurement.common.getRuntimePath +import org.wfanet.measurement.common.grpc.grpcStatusCode import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule import org.wfanet.measurement.common.grpc.testing.mockService import org.wfanet.measurement.common.identity.ExternalId @@ -2737,7 +2738,7 @@ class MetricsServiceTest { } @Test - fun `createMetric throws exception when the CMMs createMeasurement throws exception`() = + fun `createMetric throws exception when the CMMs createMeasurement throws INVALID_ARGUMENT`() = runBlocking { whenever(measurementsMock.createMeasurement(any())) .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) @@ -2753,8 +2754,9 @@ class MetricsServiceTest { runBlocking { service.createMetric(request) } } } - val expectedExceptionDescription = "Unable to create a CMMS measurement." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) + assertThat(exception.grpcStatusCode()).isEqualTo(Status.Code.INVALID_ARGUMENT) + val expectedExceptionDescription = "Required field unspecified or invalid." + assertThat(exception.message).contains(expectedExceptionDescription) } @Test @@ -2780,9 +2782,9 @@ class MetricsServiceTest { } @Test - fun `createMetric throws exception when getMeasurementConsumer throws exception`() = runBlocking { + fun `createMetric throws exception when getMeasurementConsumer throws NOT_FOUND`() = runBlocking { whenever(measurementConsumersMock.getMeasurementConsumer(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) + .thenThrow(StatusRuntimeException(Status.NOT_FOUND)) val request = createMetricRequest { parent = MEASUREMENT_CONSUMERS.values.first().name @@ -2795,9 +2797,8 @@ class MetricsServiceTest { runBlocking { service.createMetric(request) } } } - val expectedExceptionDescription = - "Unable to retrieve the measurement consumer [${MEASUREMENT_CONSUMERS.values.first().name}]." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) + assertThat(exception.grpcStatusCode()).isEqualTo(Status.Code.NOT_FOUND) + assertThat(exception.message).contains(MEASUREMENT_CONSUMERS.values.first().name) } @Test