Skip to content

Commit

Permalink
[DPP-628] Self service error codes in time service
Browse files Browse the repository at this point in the history
CHANGELOG_BEGIN
CHANGELOG_END
  • Loading branch information
pbatko-da committed Oct 19, 2021
1 parent 9b00a1a commit 8e0ab13
Show file tree
Hide file tree
Showing 12 changed files with 63 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import com.daml.ledger.api.validation.CommandsValidator.{Submitters, effectiveSu
import com.daml.lf.command._
import com.daml.lf.data._
import com.daml.lf.value.{Value => Lf}
import com.daml.platform.server.api.validation.ErrorFactories._
import com.daml.platform.server.api.validation.ErrorFactories.Default._
import com.daml.platform.server.api.validation.FieldValidations.{requirePresence, _}
import io.grpc.StatusRuntimeException
import scalaz.syntax.tag._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import com.daml.ledger.api.v1.command_completion_service.{
}
import com.daml.platform.server.api.validation.FieldValidations._
import io.grpc.StatusRuntimeException
import com.daml.platform.server.api.validation.ErrorFactories._
import com.daml.platform.server.api.validation.ErrorFactories.Default._

class CompletionServiceRequestValidator(ledgerId: LedgerId, partyNameChecker: PartyNameChecker) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.daml.error.ContextualizedErrorLogger
import com.daml.ledger.api.domain
import com.daml.ledger.api.v1.ledger_offset.LedgerOffset
import com.daml.ledger.api.v1.ledger_offset.LedgerOffset.LedgerBoundary
import com.daml.platform.server.api.validation.ErrorFactories.{
import com.daml.platform.server.api.validation.ErrorFactories.Default.{
invalidArgument,
missingField,
offsetAfterLedgerEnd,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package com.daml.ledger.api.validation

import com.daml.error.ContextualizedErrorLogger
import com.daml.lf.data.Ref.Party
import com.daml.platform.server.api.validation.ErrorFactories.invalidArgument
import com.daml.platform.server.api.validation.ErrorFactories.Default.invalidArgument
import com.daml.platform.server.api.validation.FieldValidations.requireParties
import io.grpc.StatusRuntimeException

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ object TransactionFilterValidator {
contextualizedErrorLogger: ContextualizedErrorLogger
): Either[StatusRuntimeException, domain.TransactionFilter] = {
if (txFilter.filtersByParty.isEmpty) {
Left(ErrorFactories.invalidArgument(None)("filtersByParty cannot be empty"))
Left(ErrorFactories.Default.invalidArgument(None)("filtersByParty cannot be empty"))
} else {
val convertedFilters =
txFilter.filtersByParty.toList.traverse { case (k, v) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import com.daml.ledger.api.v1.transaction_service.{
GetTransactionByIdRequest,
GetTransactionsRequest,
}
import com.daml.platform.server.api.validation.ErrorFactories._
import com.daml.platform.server.api.validation.ErrorFactories.Default._
import com.daml.platform.server.api.validation.FieldValidations._
import io.grpc.StatusRuntimeException

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import com.daml.ledger.api.domain
import com.daml.ledger.api.v1.value.Value.Sum
import com.daml.ledger.api.v1.{value => api}
import com.daml.lf.value.{Value => Lf}
import com.daml.platform.server.api.validation.ErrorFactories._
import com.daml.platform.server.api.validation.ErrorFactories.Default._
import com.daml.platform.server.api.validation.FieldValidations.{requirePresence, _}
import io.grpc.StatusRuntimeException

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,10 @@ class ErrorFactories private (errorCodesVersionSwitcher: ErrorCodesVersionSwitch
* TODO error codes: Remove default implementation once all Ledger API services
* output versioned error codes.
*/
object ErrorFactories extends ErrorFactories(new ErrorCodesVersionSwitcher(false)) {
object ErrorFactories {

val Default: ErrorFactories = apply(new ErrorCodesVersionSwitcher(false))

def apply(errorCodesVersionSwitcher: ErrorCodesVersionSwitcher): ErrorFactories =
new ErrorFactories(errorCodesVersionSwitcher)

Expand All @@ -298,4 +301,5 @@ object ErrorFactories extends ErrorFactories(new ErrorCodesVersionSwitcher(false
statusBuilder.addDetails(definiteAnswers(definiteAnswer))
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class FieldValidations private (errorFactories: ErrorFactories) {
/** Default implementation exposing field validations with the legacy error factories.
* TODO error codes: Remove default implementation once all consumers output versioned error codes.
*/
object FieldValidations extends FieldValidations(ErrorFactories) {
object FieldValidations extends FieldValidations(ErrorFactories.Default) {
def apply(errorFactories: ErrorFactories): FieldValidations =
new FieldValidations(errorFactories)
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class ErrorFactoriesSpec extends AnyWordSpec with Matchers with TableDrivenPrope
)

forEvery(testCases) { (definiteAnswer, expectedDetails) =>
val exception = aborted("my message", definiteAnswer)
val exception = ErrorFactories.Default.aborted("my message", definiteAnswer)
val status = StatusProto.fromThrowable(exception)
status.getCode shouldBe Code.ABORTED.value()
status.getMessage shouldBe "my message"
Expand Down Expand Up @@ -142,7 +142,7 @@ class ErrorFactoriesSpec extends AnyWordSpec with Matchers with TableDrivenPrope
}

"fail on creating a ledgerIdMismatch error due to a wrong definite answer" in {
an[IllegalArgumentException] should be thrownBy ledgerIdMismatch(
an[IllegalArgumentException] should be thrownBy ErrorFactories.Default.ledgerIdMismatch(
LedgerId("expected"),
LedgerId("received"),
definiteAnswer = Some(true),
Expand Down Expand Up @@ -238,7 +238,7 @@ class ErrorFactoriesSpec extends AnyWordSpec with Matchers with TableDrivenPrope

"should create an ApiException without the stack trace" in {
val status = Status.newBuilder().setCode(Code.INTERNAL.value()).build()
val exception = grpcError(status)
val exception = ErrorFactories.Default.grpcError(status)
exception.getStackTrace shouldBe Array.empty
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,10 @@ private[daml] object ApiServices {

val apiTimeServiceOpt =
optTimeServiceBackend.map(tsb =>
new TimeServiceAuthorization(ApiTimeService.create(ledgerId, tsb), authorizer)
new TimeServiceAuthorization(
ApiTimeService.create(ledgerId, tsb, errorsVersionsSwitcher),
authorizer,
)
)
val writeServiceBackedApiServices =
intitializeWriteServiceBackedApiServices(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ import akka.NotUsed
import akka.stream.Materializer
import akka.stream.scaladsl.Source
import com.daml.api.util.TimestampConversion._
import com.daml.error.{ContextualizedErrorLogger, DamlContextualizedErrorLogger}
import com.daml.error.{
ContextualizedErrorLogger,
DamlContextualizedErrorLogger,
ErrorCodesVersionSwitcher,
}
import com.daml.grpc.adapter.ExecutionSequencerFactory
import com.daml.ledger.api.domain.LedgerId
import com.daml.ledger.api.v1.testing.time_service.TimeServiceGrpc.TimeService
Expand All @@ -17,8 +21,7 @@ import com.daml.platform.akkastreams.dispatcher.SignalDispatcher
import com.daml.platform.api.grpc.GrpcApiService
import com.daml.platform.apiserver.TimeServiceBackend
import com.daml.platform.server.api.ValidationLogger
import com.daml.platform.server.api.validation.ErrorFactories
import com.daml.platform.server.api.validation.FieldValidations._
import com.daml.platform.server.api.validation.{ErrorFactories, FieldValidations}
import com.google.protobuf.empty.Empty
import io.grpc.{ServerServiceDefinition, StatusRuntimeException}
import scalaz.syntax.tag._
Expand All @@ -29,6 +32,7 @@ import scala.concurrent.{ExecutionContext, Future}
private[apiserver] final class ApiTimeService private (
val ledgerId: LedgerId,
backend: TimeServiceBackend,
errorCodesVersionSwitcher: ErrorCodesVersionSwitcher,
)(implicit
protected val mat: Materializer,
protected val esf: ExecutionSequencerFactory,
Expand All @@ -45,27 +49,32 @@ private[apiserver] final class ApiTimeService private (
s"${getClass.getSimpleName} initialized with ledger ID ${ledgerId.unwrap}, start time ${backend.getCurrentTime}"
)

private val errorFactories = ErrorFactories(errorCodesVersionSwitcher)
private val fieldValidation = FieldValidations(errorFactories)

private val dispatcher = SignalDispatcher[Instant]()

override protected def getTimeSource(request: GetTimeRequest): Source[GetTimeResponse, NotUsed] =
matchLedgerId(ledgerId)(LedgerId(request.ledgerId)).fold(
t => Source.failed(ValidationLogger.logFailureWithContext(request, t)),
{ ledgerId =>
logger.info(s"Received request for time with ledger ID $ledgerId")
dispatcher
.subscribe()
.map(_ => backend.getCurrentTime)
.scan[Option[Instant]](Some(backend.getCurrentTime)) {
case (Some(previousTime), currentTime) if previousTime == currentTime => None
case (_, currentTime) => Some(currentTime)
}
.mapConcat {
case None => Nil
case Some(t) => List(GetTimeResponse(Some(fromInstant(t))))
}
.via(logger.logErrorsOnStream)
},
)
fieldValidation
.matchLedgerId(ledgerId)(LedgerId(request.ledgerId))
.fold(
t => Source.failed(ValidationLogger.logFailureWithContext(request, t)),
{ ledgerId =>
logger.info(s"Received request for time with ledger ID $ledgerId")
dispatcher
.subscribe()
.map(_ => backend.getCurrentTime)
.scan[Option[Instant]](Some(backend.getCurrentTime)) {
case (Some(previousTime), currentTime) if previousTime == currentTime => None
case (_, currentTime) => Some(currentTime)
}
.mapConcat {
case None => Nil
case Some(t) => List(GetTimeResponse(Some(fromInstant(t))))
}
.via(logger.logErrorsOnStream)
},
)

@SuppressWarnings(Array("org.wartremover.warts.JavaSerializable"))
override def setTime(request: SetTimeRequest): Future[Empty] = {
Expand All @@ -80,23 +89,25 @@ private[apiserver] final class ApiTimeService private (
if (success) Right(requestedTime)
else
Left(
ErrorFactories.invalidArgument(None)(
errorFactories.invalidArgument(None)(
s"current_time mismatch. Provided: $expectedTime. Actual: ${backend.getCurrentTime}"
)
)
)
}

val result = for {
_ <- matchLedgerId(ledgerId)(LedgerId(request.ledgerId))
expectedTime <- requirePresence(request.currentTime, "current_time").map(toInstant)
requestedTime <- requirePresence(request.newTime, "new_time").map(toInstant)
_ <- fieldValidation.matchLedgerId(ledgerId)(LedgerId(request.ledgerId))
expectedTime <- fieldValidation
.requirePresence(request.currentTime, "current_time")
.map(toInstant)
requestedTime <- fieldValidation.requirePresence(request.newTime, "new_time").map(toInstant)
_ <- {
if (!requestedTime.isBefore(expectedTime))
Right(())
else
Left(
ErrorFactories.invalidArgument(None)(
errorFactories.invalidArgument(None)(
s"new_time [$requestedTime] is before current_time [$expectedTime]. Setting time backwards is not allowed."
)
)
Expand Down Expand Up @@ -131,11 +142,15 @@ private[apiserver] final class ApiTimeService private (
}

private[apiserver] object ApiTimeService {
def create(ledgerId: LedgerId, backend: TimeServiceBackend)(implicit
def create(
ledgerId: LedgerId,
backend: TimeServiceBackend,
errorCodesVersionSwitcher: ErrorCodesVersionSwitcher,
)(implicit
mat: Materializer,
esf: ExecutionSequencerFactory,
executionContext: ExecutionContext,
loggingContext: LoggingContext,
): TimeService with GrpcApiService =
new ApiTimeService(ledgerId, backend)
new ApiTimeService(ledgerId, backend, errorCodesVersionSwitcher)
}

0 comments on commit 8e0ab13

Please sign in to comment.