Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LF: Make DarReader ZipEntries immulatble #10243

Merged
merged 6 commits into from
Jul 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,93 +3,74 @@

package com.daml.lf.archive

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileInputStream, InputStream}
import java.util.zip.ZipInputStream

import com.daml.daml_lf_dev.DamlLf
import com.daml.lf.data.Bytes
import com.daml.lf.data.TryOps.Bracket.bracket
import com.daml.lf.data.TryOps.sequence

import scala.annotation.tailrec
import java.io.{File, FileInputStream, IOException, InputStream}
import java.util.zip.ZipInputStream
import scala.util.control.NonFatal
import scala.util.{Failure, Success, Try}

class GenDarReader[A](
// The `Long` is the dalf size in bytes.
parseDalf: (Long, InputStream) => Try[A]
) {
class GenDarReader[A](parseDalf: Bytes => Try[A]) {

import GenDarReader._

/** Reads an archive from a File. */
def readArchiveFromFile(darFile: File): Try[Dar[A]] =
readArchive(darFile.getName, new ZipInputStream(new FileInputStream(darFile)))
bracket(Try(new ZipInputStream(new FileInputStream(darFile))))(zis => Try(zis.close))
.flatMap(readArchive(darFile.getName, _))

/** Reads an archive from a ZipInputStream. The stream will be closed by this function! */
def readArchive(
name: String,
darStream: ZipInputStream,
entrySizeThreshold: Int = EntrySizeThreshold,
): Try[Dar[A]] = {
): Try[Dar[A]] =
for {
entries <- bracket(Try(darStream))(zis => Try(zis.close())).flatMap(zis =>
loadZipEntries(name, zis, entrySizeThreshold)
)
names <- entries.readDalfNames(DarManifestReader.dalfNames): Try[Dar[String]]
main <- parseOne(entries.getInputStreamFor)(names.main): Try[A]
deps <- parseAll(entries.getInputStreamFor)(names.dependencies): Try[List[A]]
entries <- loadZipEntries(name, darStream, entrySizeThreshold)
names <- entries.readDalfNames
main <- parseOne(entries.get)(names.main)
deps <- parseAll(entries.get)(names.dependencies)
} yield Dar(main, deps)
}

// Fails if a zip bomb is detected
private def slurpWithCaution(
@throws[Error.ZipBomb]
@throws[IOException]
private[this] def slurpWithCaution(
name: String,
zip: ZipInputStream,
entrySizeThreshold: Int,
): Try[(Long, InputStream)] =
Try {
val output = new ByteArrayOutputStream()
val buffer = new Array[Byte](4096)
for (n <- Iterator.continually(zip.read(buffer)).takeWhile(_ >= 0) if n > 0) {
output.write(buffer, 0, n)
if (output.size >= entrySizeThreshold) throw Error.ZipBomb()
}
(output.size.toLong, new ByteArrayInputStream(output.toByteArray))
): (String, Bytes) = {
val buffSize = 4 * 1024 // 4k
val buffer = new Array[Byte](buffSize)
var output = Bytes.Empty
Iterator.continually(zip.read(buffer)).takeWhile(_ >= 0).foreach { size =>
output = output ++ Bytes.fromByteArray(buffer, 0, size)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

concatenation is constant time.

if (output.length >= entrySizeThreshold) throw Error.ZipBomb()
}
name -> output
}

private def loadZipEntries(
private[this] def loadZipEntries(
name: String,
darStream: ZipInputStream,
entrySizeThreshold: Int,
): Try[ZipEntries] = {
@tailrec
def go(accT: Try[Map[String, (Long, InputStream)]]): Try[Map[String, (Long, InputStream)]] =
Option(darStream.getNextEntry) match {
case Some(entry) =>
go(
accT.flatMap { acc =>
bracket(slurpWithCaution(darStream, entrySizeThreshold))(_ =>
Try(darStream.closeEntry())
)
.map { sizedBytes =>
acc + (entry.getName -> sizedBytes)
}
}
)
case None => accT
}
): Try[ZipEntries] =
Try(
Iterator
.continually(darStream.getNextEntry)
.takeWhile(_ != null)
.map(entry => slurpWithCaution(entry.getName, darStream, entrySizeThreshold))
.toMap
).map(ZipEntries(name, _))

go(Success(Map.empty)).map(ZipEntries(name, _))
}

private def parseAll(getInputStreamFor: String => Try[(Long, InputStream)])(
names: List[String]
): Try[List[A]] =
sequence(names.map(parseOne(getInputStreamFor)))
private[this] def parseAll(getPayload: String => Try[Bytes])(names: List[String]): Try[List[A]] =
sequence(names.map(parseOne(getPayload)))

private def parseOne(getInputStreamFor: String => Try[(Long, InputStream)])(s: String): Try[A] =
bracket(getInputStreamFor(s))({ case (_, is) => Try(is.close()) }).flatMap({ case (size, is) =>
parseDalf(size, is)
})
private[this] def parseOne(getPayload: String => Try[Bytes])(s: String): Try[A] =
getPayload(s).flatMap(parseDalf)

}

Expand All @@ -98,53 +79,24 @@ object GenDarReader {
private val ManifestName = "META-INF/MANIFEST.MF"
private[archive] val EntrySizeThreshold = 1024 * 1024 * 1024 // 1 GB

private[archive] case class ZipEntries(name: String, entries: Map[String, (Long, InputStream)]) {
private[archive] case class ZipEntry(size: Long, getStream: () => InputStream)

def getInputStreamFor(entryName: String): Try[(Long, InputStream)] = {
private[archive] case class ZipEntries(name: String, entries: Map[String, Bytes]) {
private[GenDarReader] def get(entryName: String): Try[Bytes] = {
entries.get(entryName) match {
case Some((size, is)) => Success(size -> is)
case Some(is) => Success(is)
case None => Failure(Error.InvalidZipEntry(entryName, this))
}
}

def readDalfNames(
readDalfNamesFromManifest: InputStream => Try[Dar[String]]
): Try[Dar[String]] =
parseDalfNamesFromManifest(readDalfNamesFromManifest).recoverWith { case NonFatal(e1) =>
findLegacyDalfNames().recoverWith { case NonFatal(_) =>
Failure(Error.InvalidDar(this, e1))
}
}

private def parseDalfNamesFromManifest(
readDalfNamesFromManifest: InputStream => Try[Dar[String]]
): Try[Dar[String]] =
bracket(getInputStreamFor(ManifestName)) { case (_, is) => Try(is.close()) }
.flatMap { case (_, is) => readDalfNamesFromManifest(is) }

// There are three cases:
// 1. if it's only one .dalf, then that's the main one
// 2. if it's two .dalfs, where one of them has -prim in the name, the one without -prim is the main dalf.
// 3. parse error in all other cases
private def findLegacyDalfNames(): Try[Dar[String]] = {
val dalfs: List[String] = entries.keys.filter(isDalf).toList

dalfs.partition(isPrimDalf) match {
case (List(prim), List(main)) => Success(Dar(main, List(prim)))
case (List(prim), Nil) => Success(Dar(prim, List.empty))
case (Nil, List(main)) => Success(Dar(main, List.empty))
case _ => Failure(Error.InvalidLegacyDar(this))
}
}

private def isDalf(s: String): Boolean = s.toLowerCase.endsWith(".dalf")

private def isPrimDalf(s: String): Boolean = s.toLowerCase.contains("-prim") && isDalf(s)
remyhaemmerle-da marked this conversation as resolved.
Show resolved Hide resolved
private[GenDarReader] def readDalfNames: Try[Dar[String]] =
bracket(get(ManifestName).map(_.toInputStream))(is => Try(is.close()))
.flatMap(DarManifestReader.dalfNames)
.recoverWith { case NonFatal(e1) => Failure(Error.InvalidDar(this, e1)) }
}
}

object DarReader
extends GenDarReader[ArchivePayload]({ case (_, is) => Try(Reader.readArchive(is)) })
object DarReader extends GenDarReader[ArchivePayload](is => Try(Reader.readArchive(is)))

object RawDarReader
extends GenDarReader[DamlLf.Archive]({ case (_, is) => Try(DamlLf.Archive.parseFrom(is)) })
extends GenDarReader[DamlLf.Archive](is => Try(DamlLf.Archive.parseFrom(is.toByteString)))
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ package archive

import java.io.InputStream
import java.security.MessageDigest

import com.daml.lf.data.Ref.PackageId
import com.daml.lf.language.{LanguageMajorVersion, LanguageVersion}
import com.daml.daml_lf_dev.DamlLf
import com.daml.lf.data.Bytes
import com.google.protobuf.CodedInputStream

object Reader {
Expand All @@ -20,6 +20,9 @@ object Reader {
readArchive(DamlLf.Archive.parser().parseFrom(cos))
}

def readArchive(bytes: Bytes): ArchivePayload =
readArchive(bytes.toInputStream)

@throws[Error.Parsing]
def readArchive(lf: DamlLf.Archive): ArchivePayload = {
lf.getHashFunction match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ object Bytes {
new Bytes(value)

def fromByteArray(a: Array[Byte]): Bytes =
new Bytes(ByteString.copyFrom(a))
fromByteArray(a, 0, a.length)

def fromByteArray(a: Array[Byte], offset: Int, size: Int) =
new Bytes(ByteString.copyFrom(a, offset, size))

def fromByteBuffer(a: ByteBuffer): Bytes =
new Bytes(ByteString.copyFrom(a))
Expand Down