Skip to content

Commit

Permalink
Add exception infrastructure to Scala impl of DAML-LF AST
Browse files Browse the repository at this point in the history
This advances states of #8020.

CHANGELOG_BEGIN
CHANGELOG_END
  • Loading branch information
hurryabit authored and remyhaemmerle-da committed Dec 16, 2020
1 parent 78a468d commit 3ce8659
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ private[archive] class DecodeV1(minor: LV.Minor) extends Decode.OfPackage[PLF.Pa
private def decodeModuleWithName(lfModule: PLF.Module, moduleName: ModuleName) = {
val defs = mutable.ArrayBuffer[(DottedName, Definition)]()
val templates = mutable.ArrayBuffer[(DottedName, Template)]()
val exceptions = mutable.ArrayBuffer[(DottedName, Unit)]()

if (versionIsOlderThan(LV.Features.typeSynonyms)) {
assertEmpty(lfModule.getSynonymsList, "Module.synonyms")
Expand Down Expand Up @@ -265,7 +266,17 @@ private[archive] class DecodeV1(minor: LV.Minor) extends Decode.OfPackage[PLF.Pa
templates += ((defName, decodeTemplate(defn)))
}

Module(moduleName, defs, templates, decodeFeatureFlags(lfModule.getFlags))
if (versionIsOlderThan(LV.Features.exceptions)) {
assertEmpty(lfModule.getExceptionsList, "Module.exceptions")
} else if (!onlySerializableDataDefs) {
lfModule.getExceptionsList.asScala
.foreach { defn =>
val defName = getInternedDottedName(defn.getNameInternedDname)
exceptions += ((defName, ()))
}
}

Module(moduleName, defs, templates, exceptions, decodeFeatureFlags(lfModule.getFlags))
}

// -----------------------------------------------------------------------
Expand Down Expand Up @@ -987,13 +998,21 @@ private[archive] class DecodeV1(minor: LV.Minor) extends Decode.OfPackage[PLF.Pa
ETypeRep(decodeType(lfExpr.getTypeRep))

case PLF.Expr.SumCase.TO_ANY_EXCEPTION =>
throw ParseError("Expr.TO_ANY_EXCEPTION") // TODO https://github.com/digital-asset/daml/issues/8020
assertSince(LV.Features.exceptions, "Expr.to_any_exception")
val makeAnyException = lfExpr.getMakeAnyException
EMakeAnyException(
typ = decodeType(makeAnyException.getType),
message = decodeExpr(makeAnyException.getMessage, definition),
value = decodeExpr(makeAnyException.getExpr, definition),
)

case PLF.Expr.SumCase.FROM_ANY_EXCEPTION =>
throw ParseError("Expr.FROM_ANY_EXCEPTION") // TODO https://github.com/digital-asset/daml/issues/8020

case PLF.Expr.SumCase.THROW =>
throw ParseError("Expr.THROW") // TODO https://github.com/digital-asset/daml/issues/8020
assertSince(LV.Features.exceptions, "Expr.from_any_exception")
val fromAnyException = lfExpr.getFromAnyException
EFromAnyException(
typ = decodeType(fromAnyException.getType),
value = decodeExpr(fromAnyException.getExpr, definition),
)

case PLF.Expr.SumCase.SUM_NOT_SET =>
throw ParseError("Expr.SUM_NOT_SET")
Expand Down Expand Up @@ -1160,7 +1179,14 @@ private[archive] class DecodeV1(minor: LV.Minor) extends Decode.OfPackage[PLF.Pa
UpdateEmbedExpr(decodeType(embedExpr.getType), decodeExpr(embedExpr.getBody, definition))

case PLF.Update.SumCase.TRY_CATCH =>
throw ParseError("Update.TRY_CATCH") // TODO #8020
assertSince(LV.Features.exceptions, "Update.try_catch")
val tryCatch = lfUpdate.getTryCatch
UpdateTryCatch(
typ = decodeType(tryCatch.getReturnType),
body = decodeExpr(tryCatch.getTryExpr, definition),
binder = toName(internedStrings(tryCatch.getVarInternedStr)),
handler = decodeExpr(tryCatch.getCatchExpr, definition),
)

case PLF.Update.SumCase.SUM_NOT_SET =>
throw ParseError("Update.SUM_NOT_SET")
Expand Down Expand Up @@ -1376,11 +1402,10 @@ private[lf] object DecodeV1 {
BuiltinTypeInfo(NUMERIC, BTNumeric, minVersion = numeric),
BuiltinTypeInfo(ANY, BTAny, minVersion = anyType),
BuiltinTypeInfo(TYPE_REP, BTTypeRep, minVersion = typeRep),
// FIXME: https://github.com/digital-asset/daml/issues/8020
// BuiltinTypeInfo(ANY_EXCEPTION, ???, minVersion = exceptions),
// BuiltinTypeInfo(GENERAL_ERROR, ???, minVersion = exceptions),
// BuiltinTypeInfo(ARITHMETIC_ERROR, ???, minVersion = exceptions),
// BuiltinTypeInfo(CONTRACT_ERROR, ???, minVersion = exceptions)
BuiltinTypeInfo(ANY_EXCEPTION, BTAnyException, minVersion = exceptions),
BuiltinTypeInfo(GENERAL_ERROR, BTGeneralError, minVersion = exceptions),
BuiltinTypeInfo(ARITHMETIC_ERROR, BTArithmeticError, minVersion = exceptions),
BuiltinTypeInfo(CONTRACT_ERROR, BTContractError, minVersion = exceptions)
)
}

Expand Down Expand Up @@ -1700,13 +1725,17 @@ private[lf] object DecodeV1 {
BuiltinFunctionInfo(EQUAL_CONTRACT_ID, BEqualContractId, maxVersion = Some(genMap)),
BuiltinFunctionInfo(TRACE, BTrace),
BuiltinFunctionInfo(COERCE_CONTRACT_ID, BCoerceContractId),
BuiltinFunctionInfo(MAKE_GENERAL_ERROR, BTextToUpper, minVersion = exceptions), // TODO #8020
BuiltinFunctionInfo(MAKE_ARITHMETIC_ERROR, BTextToUpper, minVersion = exceptions), // TODO #8020
BuiltinFunctionInfo(MAKE_CONTRACT_ERROR, BTextToUpper, minVersion = exceptions), // TODO #8020
BuiltinFunctionInfo(ANY_EXCEPTION_MESSAGE, BTextToUpper, minVersion = exceptions), // TODO #8020
BuiltinFunctionInfo(GENERAL_ERROR_MESSAGE, BTextToUpper, minVersion = exceptions), // TODO #8020
BuiltinFunctionInfo(ARITHMETIC_ERROR_MESSAGE, BTextToUpper, minVersion = exceptions), // TODO #8020
BuiltinFunctionInfo(CONTRACT_ERROR_MESSAGE, BTextToUpper, minVersion = exceptions), // TODO #8020
BuiltinFunctionInfo(THROW, BThrow, minVersion = exceptions),
BuiltinFunctionInfo(MAKE_GENERAL_ERROR, BMakeGeneralError, minVersion = exceptions),
BuiltinFunctionInfo(MAKE_ARITHMETIC_ERROR, BMakeArithmeticError, minVersion = exceptions),
BuiltinFunctionInfo(MAKE_CONTRACT_ERROR, BMakeContractError, minVersion = exceptions),
BuiltinFunctionInfo(ANY_EXCEPTION_MESSAGE, BAnyExceptionMessage, minVersion = exceptions),
BuiltinFunctionInfo(GENERAL_ERROR_MESSAGE, BGeneralErrorMessage, minVersion = exceptions),
BuiltinFunctionInfo(
ARITHMETIC_ERROR_MESSAGE,
BArithmeticErrorMessage,
minVersion = exceptions),
BuiltinFunctionInfo(CONTRACT_ERROR_MESSAGE, BContractErrorMessage, minVersion = exceptions),
BuiltinFunctionInfo(TEXT_TO_UPPER, BTextToUpper, minVersion = unstable),
BuiltinFunctionInfo(TEXT_TO_LOWER, BTextToLower, minVersion = unstable),
BuiltinFunctionInfo(TEXT_SLICE, BTextSlice, minVersion = unstable),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ object InterfaceReader {
case Ast.BTOptional => \/-((1, PrimType.Optional))
case Ast.BTTextMap => \/-((1, PrimType.TextMap))
case Ast.BTGenMap => \/-((2, PrimType.GenMap))
case Ast.BTAnyException | Ast.BTGeneralError | Ast.BTArithmeticError | Ast.BTContractError =>
// TODO #8020 Add exception types to the interface reader
unserializableDataType(
ctx,
"Exception types are still under implementation, see issue #8020"
)
case Ast.BTNumeric =>
unserializableDataType(
ctx,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ object Ast {
/** Unique textual representation of template Id **/
final case class ETypeRep(typ: Type) extends Expr

/** Construct an AnyException from its message and payload */
final case class EMakeAnyException(typ: Type, message: Expr, value: Expr) extends Expr

/** Extract the payload from an AnyException if it matches the given exception type */
final case class EFromAnyException(typ: Type, value: Expr) extends Expr

//
// Kinds
//
Expand Down Expand Up @@ -292,6 +298,10 @@ object Ast {
case object BTArrow extends BuiltinType
case object BTAny extends BuiltinType
case object BTTypeRep extends BuiltinType
case object BTAnyException extends BuiltinType
case object BTGeneralError extends BuiltinType
case object BTArithmeticError extends BuiltinType
case object BTContractError extends BuiltinType

//
// Primitive literals
Expand Down Expand Up @@ -414,6 +424,16 @@ object Ast {

final case object BCoerceContractId extends BuiltinFunction(1) // : ∀a b. ContractId a -> ContractId b

// Exceptions
final case object BThrow extends BuiltinFunction(1) // : ∀a. AnyException -> a
final case object BMakeGeneralError extends BuiltinFunction(1) // Text -> GeneralError
final case object BMakeArithmeticError extends BuiltinFunction(1) // Text -> ArithmeticError
final case object BMakeContractError extends BuiltinFunction(1) // Text -> ContractError
final case object BAnyExceptionMessage extends BuiltinFunction(1) // AnyException -> Text
final case object BGeneralErrorMessage extends BuiltinFunction(1) // GeneralError -> Text
final case object BArithmeticErrorMessage extends BuiltinFunction(1) // ArithmeticError -> Text
final case object BContractErrorMessage extends BuiltinFunction(1) // ContractError -> Text

// Unstable Text Primitives
final case object BTextToUpper extends BuiltinFunction(1) // Text → Text
final case object BTextToLower extends BuiltinFunction(1) // : Text → Text
Expand Down Expand Up @@ -452,6 +472,12 @@ object Ast {
final case class UpdateFetchByKey(rbk: RetrieveByKey) extends Update
final case class UpdateLookupByKey(rbk: RetrieveByKey) extends Update
final case class UpdateEmbedExpr(typ: Type, body: Expr) extends Update
final case class UpdateTryCatch(
typ: Type,
body: Expr,
binder: ExprVarName,
handler: Expr,
) extends Update

//
// Scenario expressions
Expand Down Expand Up @@ -701,6 +727,7 @@ object Ast {
name: ModuleName,
definitions: Map[DottedName, GenDefinition[E]],
templates: Map[DottedName, GenTemplate[E]],
exceptions: Map[DottedName, Unit],
featureFlags: FeatureFlags
) extends NoCopy

Expand All @@ -713,6 +740,7 @@ object Ast {
name: ModuleName,
definitions: Iterable[(DottedName, GenDefinition[E])],
templates: Iterable[(DottedName, GenTemplate[E])],
exceptions: Iterable[(DottedName, Unit)],
featureFlags: FeatureFlags
): GenModule[E] = {

Expand All @@ -724,13 +752,18 @@ object Ast {
throw PackageError(s"Collision on template name ${templName.toString}")
}

new GenModule(name, definitions.toMap, templates.toMap, featureFlags)
findDuplicate(exceptions).foreach { exnName =>
throw PackageError(s"Collision on exception name ${exnName.toString}")
}

new GenModule(name, definitions.toMap, templates.toMap, exceptions.toMap, featureFlags)
}

def unapply(arg: GenModule[E]): Option[(
ModuleName,
Map[DottedName, GenDefinition[E]],
Map[DottedName, GenTemplate[E]],
Map[DottedName, Unit],
FeatureFlags)] =
GenModule.unapply(arg)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ object Util {
val TDecimalScale = TNat(Decimal.scale)
val TDecimal = TNumeric(TDecimalScale)

val TAnyException = TBuiltin(BTAnyException)
val TGeneralError = TBuiltin(BTGeneralError)
val TArithmeticError = TBuiltin(BTArithmeticError)
val TContractError = TBuiltin(BTContractError)

val EUnit = EPrimCon(PCUnit)
val ETrue = EPrimCon(PCTrue)
val EFalse = EPrimCon(PCFalse)
Expand Down Expand Up @@ -164,7 +169,7 @@ object Util {

private[this] def toSignature(module: Module): ModuleSignature =
module match {
case Module(name, definitions, templates, featureFlags) =>
case Module(name, definitions, templates, exceptions, featureFlags) =>
ModuleSignature(
name = name,
definitions = definitions.transform {
Expand All @@ -173,6 +178,7 @@ object Util {
case (_, typeSyn: DTypeSyn) => typeSyn
},
templates = templates.transform((_, template) => toSignature(template)),
exceptions = exceptions,
featureFlags = featureFlags,
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ private[validation] object Serializability {
checkType(targ)
case TBuiltin(builtinType) =>
builtinType match {
case BTInt64 | BTText | BTTimestamp | BTDate | BTParty | BTBool | BTUnit =>
case BTInt64 | BTText | BTTimestamp | BTDate | BTParty | BTBool | BTUnit |
BTAnyException | BTGeneralError | BTArithmeticError | BTContractError => ()
case BTNumeric =>
unserializable(URNumeric)
case BTList =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ private[validation] object Typing {
}

private def kindOfBuiltin(bType: BuiltinType): Kind = bType match {
case BTInt64 | BTText | BTTimestamp | BTParty | BTBool | BTDate | BTUnit | BTAny | BTTypeRep =>
case BTInt64 | BTText | BTTimestamp | BTParty | BTBool | BTDate | BTUnit | BTAny | BTTypeRep |
BTAnyException | BTGeneralError | BTArithmeticError | BTContractError =>
KStar
case BTNumeric => KArrow(KNat, KStar)
case BTList | BTUpdate | BTScenario | BTContractId | BTOptional | BTTextMap =>
Expand Down Expand Up @@ -211,6 +212,15 @@ private[validation] object Typing {
TForall(
alpha.name -> KStar,
TForall(beta.name -> KStar, TContractId(alpha) ->: TContractId(beta))),
// Exception functions
BThrow -> TForall(alpha.name -> KStar, TAnyException ->: alpha),
BMakeGeneralError -> (TText ->: TGeneralError),
BMakeArithmeticError -> (TText ->: TArithmeticError),
BMakeContractError -> (TText ->: TContractError),
BAnyExceptionMessage -> (TAnyException ->: TText),
BGeneralErrorMessage -> (TGeneralError ->: TText),
BArithmeticErrorMessage -> (TArithmeticError ->: TText),
BContractErrorMessage -> (TContractError ->: TText),
// Unstable text functions
BTextToUpper -> (TText ->: TText),
BTextToLower -> (TText ->: TText),
Expand Down Expand Up @@ -266,6 +276,15 @@ private[validation] object Typing {
throw EExpectedTemplatableType(env.ctx, tyConName)
}
}
mod.exceptions.foreach {
case (exnName, ()) =>
val tyConName = TypeConName(pkgId, QualifiedName(mod.name, exnName))
val ctx = ContextDefException(tyConName)
world.lookupDataType(ctx, tyConName) match {
case DDataType(_, ImmArray(), DataRecord(_)) => ()
case _ => throw EExpectedExceptionableType(ctx, tyConName)
}
}
}

case class Env(
Expand Down Expand Up @@ -854,6 +873,12 @@ private[validation] object Typing {
case UpdateLookupByKey(retrieveByKey) =>
checkByKey(retrieveByKey.templateId, retrieveByKey.key)
TUpdate(TOptional(TContractId(TTyCon(retrieveByKey.templateId))))
case UpdateTryCatch(typ, body, binder, handler) =>
checkType(typ, KStar)
val updTyp = TUpdate(typ)
checkExpr(body, updTyp)
introExprVar(binder, TAnyException).checkExpr(handler, TOptional(updTyp))
updTyp
}

private def typeOfCommit(typ: Type, party: Expr, update: Expr): Type = {
Expand Down Expand Up @@ -907,6 +932,14 @@ private[validation] object Typing {
checkType(typ, KStar)
}

private def checkExceptionType(typ: Type): Unit = {
typ match {
case TGeneralError | TArithmeticError | TContractError => ()
case TTyCon(tyCon) if lookupException(ctx, tyCon) => ()
case _ => throw EExpectedExceptionType(ctx, typ)
}
}

def typeOf(expr: Expr): Type = {
val typ0 = typeOf_(expr)
expandTypeSynonyms(typ0)
Expand Down Expand Up @@ -984,6 +1017,15 @@ private[validation] object Typing {
case ETypeRep(typ) =>
checkAnyType(typ)
TTypeRep
case EMakeAnyException(typ, message, value) =>
checkExceptionType(typ)
checkExpr(message, TText)
checkExpr(value, typ)
TAnyException
case EFromAnyException(typ, value) =>
checkExceptionType(typ)
checkExpr(value, TAnyException)
TOptional(typ)
}

def checkExpr(expr: Expr, typ0: Type): Type = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ final case class ContextDefDataType(tycon: TypeConName) extends Context {
final case class ContextTemplate(tycon: TypeConName) extends Context {
def pretty: String = s"data type ${tycon.qualifiedName}"
}
final case class ContextDefException(tycon: TypeConName) extends Context {
def pretty: String = s"exception type ${tycon.qualifiedName}"
}
final case class ContextDefValue(ref: ValueRef) extends Context {
def pretty: String = s"value type ${ref.qualifiedName}"
}
Expand Down Expand Up @@ -294,6 +297,10 @@ final case class EExpectedAnyType(context: Context, typ: Type) extends Validatio
protected def prettyInternal: String =
s"expected a type containing neither type variables nor quantifiers, but found: ${typ.pretty}"
}
final case class EExpectedExceptionType(context: Context, typ: Type) extends ValidationError {
protected def prettyInternal: String =
s"expected an exception type, but found: ${typ.pretty}"
}
final case class EExpectedHigherKind(context: Context, kind: Kind) extends ValidationError {
protected def prettyInternal: String = s"expected higher kinded type, but found: ${kind.pretty}"
}
Expand Down Expand Up @@ -353,6 +360,11 @@ final case class EExpectedTemplatableType(context: Context, conName: TypeConName
protected def prettyInternal: String =
s"expected monomorphic record type in template definition, but found: ${conName.qualifiedName}"
}
final case class EExpectedExceptionableType(context: Context, conName: TypeConName)
extends ValidationError {
protected def prettyInternal: String =
s"expected monomorphic record type in exception definition, but found: ${conName.qualifiedName}"
}
final case class EImportCycle(context: Context, modName: List[ModuleName]) extends ValidationError {
protected def prettyInternal: String = s"cycle in module dependency ${modName.mkString(" -> ")}"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,6 @@ private[validation] class World(packages: PartialFunction[PackageId, Ast.GenPack
throw EUnknownDefinition(ctx, LEValue(name))
}

def lookupException(ctx: => Context, name: TypeConName): Boolean =
lookupModule(ctx, name.packageId, name.qualifiedName.module).exceptions.contains(name.qualifiedName.name)
}
Loading

0 comments on commit 3ce8659

Please sign in to comment.