Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LF: Improve safety of proto message Serialization. #12686

Merged
merged 1 commit into from
Feb 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ package archive

import java.io.File

sealed abstract class Error(val msg: String) extends RuntimeException(msg)
sealed abstract class Error(val msg: String)
extends RuntimeException(msg)
with Product
with Serializable

object Error {

Expand Down Expand Up @@ -42,4 +45,6 @@ object Error {
extends Error(s"Unsupported file extension: ${file.getAbsolutePath}")

final case class Parsing(override val msg: String) extends Error(msg)

final case class Encoding(override val msg: String) extends Error(msg)
}
1 change: 1 addition & 0 deletions daml-lf/encoder/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ da_scala_library(
"//daml-lf/archive:daml_lf_archive_reader",
"//daml-lf/data",
"//daml-lf/language",
"//libs-scala/safe-proto",
"@maven//:com_google_protobuf_protobuf_java",
],
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.lf.archive.testing
package com.daml.lf
package archive.testing

import java.security.MessageDigest
import com.daml.SafeProto

import java.security.MessageDigest
import com.daml.lf.data.Ref.PackageId
import com.daml.lf.language.Ast.Package
import com.daml.lf.language.{LanguageMajorVersion, LanguageVersion}
Expand Down Expand Up @@ -35,7 +37,7 @@ object Encode {

final def encodeArchive(pkg: (PackageId, Package), version: LanguageVersion): PLF.Archive = {

val payload = encodePayloadOfVersion(pkg, version).toByteString
val payload = data.assertRight(SafeProto.toByteString(encodePayloadOfVersion(pkg, version)))
val hash = PackageId.assertFromString(
MessageDigest.getInstance("SHA-256").digest(payload.toByteArray).map("%02x" format _).mkString
)
Expand Down
1 change: 1 addition & 0 deletions daml-lf/kv-support/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ da_scala_library(
"//daml-lf/transaction",
"//daml-lf/transaction:transaction_proto_java",
"//daml-lf/transaction:value_proto_java",
"//libs-scala/safe-proto",
"@maven//:com_google_protobuf_protobuf_java",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ object ConversionError {
extends ConversionError(errorMessage)
final case class DecodeError(cause: ValueCoder.DecodeError)
extends ConversionError(cause.errorMessage)
final case class EncodeError(cause: ValueCoder.EncodeError)
extends ConversionError(cause.errorMessage)
final case class InternalError(override val errorMessage: String)
extends ConversionError(errorMessage)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

package com.daml.lf.kv.archives

import com.daml.SafeProto
import com.daml.lf.archive.{ArchiveParser, Decode, Error => ArchiveError}
import com.daml.lf.data.Ref
import com.daml.lf.data.Ref.PackageId
Expand All @@ -21,12 +22,15 @@ object ArchiveConversions {

def parsePackageIdsAndRawArchives(
archives: List[com.daml.daml_lf_dev.DamlLf.Archive]
): Either[ArchiveError.Parsing, Map[Ref.PackageId, RawArchive]] =
): Either[ArchiveError, Map[Ref.PackageId, RawArchive]] =
archives.partitionMap { archive =>
Ref.PackageId.fromString(archive.getHash).map(_ -> RawArchive(archive.toByteString))
for {
pkgId <- Ref.PackageId.fromString(archive.getHash).left.map(ArchiveError.Parsing)
bytes <- SafeProto.toByteString(archive).left.map(ArchiveError.Encoding)
} yield pkgId -> RawArchive(bytes)
} match {
case (Nil, hashesAndRawArchives) => Right(hashesAndRawArchives.toMap)
case (errors, _) => Left(ArchiveError.Parsing(errors.head))
case (errors, _) => Left(errors.head)
}

def decodePackages(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

package com.daml.lf.kv.contracts

import com.daml.SafeProto
import com.daml.lf.kv.ConversionError
import com.daml.lf.transaction.{TransactionCoder, TransactionOuterClass}
import com.daml.lf.value.{Value, ValueCoder}
Expand All @@ -14,9 +15,10 @@ object ContractConversions {
def encodeContractInstance(
coinst: Value.VersionedContractInstance
): Either[ValueCoder.EncodeError, RawContractInstance] =
TransactionCoder
.encodeContractInstance(ValueCoder.CidEncoder, coinst)
.map(contractInstance => RawContractInstance(contractInstance.toByteString))
for {
message <- TransactionCoder.encodeContractInstance(ValueCoder.CidEncoder, coinst)
bytes <- SafeProto.toByteString(message).left.map(ValueCoder.EncodeError(_))
} yield RawContractInstance(bytes)

def decodeContractInstance(
rawContractInstance: RawContractInstance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

package com.daml.lf.kv.transactions

import com.daml.SafeProto
import com.daml.lf.data.{FrontStack, FrontStackCons, ImmArray}
import com.daml.lf.kv.ConversionError
import com.daml.lf.transaction.TransactionOuterClass.Node.NodeTypeCase
Expand All @@ -25,9 +26,11 @@ object TransactionConversions {
def encodeTransaction(
tx: VersionedTransaction
): Either[ValueCoder.EncodeError, RawTransaction] =
TransactionCoder
.encodeTransaction(TransactionCoder.NidEncoder, ValueCoder.CidEncoder, tx)
.map(transaction => RawTransaction(transaction.toByteString))
for {
msg <-
TransactionCoder.encodeTransaction(TransactionCoder.NidEncoder, ValueCoder.CidEncoder, tx)
bytes <- SafeProto.toByteString(msg).left.map(ValueCoder.EncodeError(_))
} yield RawTransaction(bytes)

def decodeTransaction(
rawTx: RawTransaction
Expand Down Expand Up @@ -63,7 +66,7 @@ object TransactionConversions {
def reconstructTransaction(
transactionVersion: String,
nodesWithIds: Seq[TransactionNodeIdWithNode],
): Either[ConversionError.ParseError, RawTransaction] = {
): Either[ConversionError, RawTransaction] = {
import scalaz.std.either._
import scalaz.std.list._
import scalaz.syntax.traverse._
Expand Down Expand Up @@ -94,7 +97,14 @@ object TransactionConversions {
}
.toList
.sequence_
.map(_ => RawTransaction(transactionBuilder.build.toByteString))
.flatMap(_ =>
SafeProto.toByteString(transactionBuilder.build()) match {
case Right(bytes) =>
Right(RawTransaction(bytes))
case Left(msg) =>
Left(ConversionError.EncodeError(ValueCoder.EncodeError(msg)))
}
)
}

/** Decodes and extracts outputs of a submitted transaction, that is the IDs and keys of contracts created or updated
Expand Down Expand Up @@ -210,7 +220,7 @@ object TransactionConversions {
}
}

goNodesToKeep(transaction.getRootsList.asScala.to(FrontStack), Set.empty).map {
goNodesToKeep(transaction.getRootsList.asScala.to(FrontStack), Set.empty).flatMap {
nodesToKeep =>
val filteredRoots = transaction.getRootsList.asScala.filter(nodesToKeep)

Expand Down Expand Up @@ -239,7 +249,14 @@ object TransactionConversions {
.addAllNodes(filteredNodes.asJavaCollection)
.setVersion(transaction.getVersion)
.build()
RawTransaction(newTransaction.toByteString)

SafeProto.toByteString(newTransaction) match {
case Right(bytes) =>
Right(RawTransaction(bytes))
case Left(msg) =>
// Should not happen as removing nodes should results into a smaller transaction.
Left(ConversionError.InternalError(msg))
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ object TransactionTraversal {
case Left(error) => Left(ConversionError.DecodeError(error))
case Right(nodeWitnesses) =>
val witnesses = parentWitnesses union nodeWitnesses
// Here node.toByteString is safe.
// Indeed node is a submessage of the transaction `rawTx` we got serialized
// as input of `traverseTransactionWithWitnesses` and successfully decoded, i.e.
// `rawTx` requires less than 2GB to be serialized, so does <node`.
remyhaemmerle-da marked this conversation as resolved.
Show resolved Hide resolved
// See com.daml.SafeProto for more details about issues with the toByteString method.
f(nodeId, RawTransaction.Node(node.toByteString), witnesses)
// Recurse into children (if any).
node.getNodeTypeCase match {
Expand All @@ -62,7 +67,7 @@ object TransactionTraversal {
}
}

private def informeesOfNode(
private[this] def informeesOfNode(
txVersion: TransactionVersion,
node: TransactionOuterClass.Node,
): Either[ValueCoder.DecodeError, Set[Ref.Party]] =
Expand Down
2 changes: 1 addition & 1 deletion libs-scala/safe-proto/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ da_scala_library(
da_scala_test_suite(
name = "safe-protot-test",
srcs = glob(["src/test/scala/**/*.scala"]),
max_heap_size = "4g",
max_heap_size = "3g",
deps = [
":safe-proto",
"@maven//:com_google_protobuf_protobuf_java",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ object SafeProto {
case e: RuntimeException
if e.isInstanceOf[NegativeArraySizeException] ||
e.getCause != null && e.getCause.isInstanceOf[CodedOutputStream.OutOfSpaceException] =>
Left(s"the ${message.getClass.getName} message is too big to be serialized")
Left(s"the ${message.getClass.getName} message is too large to be serialized")
}

def toByteString(message: AbstractMessageLite[_, _]): Either[String, ByteString] =
Expand Down