Skip to content

Commit

Permalink
v5.x-control-max-tree-depth: Add maxTreeDepth check for TypeSerialize…
Browse files Browse the repository at this point in the history
…r, update tests.
  • Loading branch information
jozanek committed Jun 7, 2022
1 parent 8bee45d commit 129fdb0
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,95 +23,100 @@ object TypeSerializer {
embeddableIdToType(code)
}

def serialize(tpe: SType, w: SigmaByteWriter): Unit = tpe match {
case p: SEmbeddable => w.put(p.typeCode)
case SString => w.put(SString.typeCode)
case SAny => w.put(SAny.typeCode)
case SUnit => w.put(SUnit.typeCode)
case SBox => w.put(SBox.typeCode)
case SAvlTree => w.put(SAvlTree.typeCode)
case SContext => w.put(SContext.typeCode)
case SGlobal => w.put(SGlobal.typeCode)
case SHeader => w.put(SHeader.typeCode)
case SPreHeader => w.put(SPreHeader.typeCode)
case c: SCollectionType[a] => c.elemType match {
case p: SEmbeddable =>
val code = p.embedIn(CollectionTypeCode)
w.put(code)
case cn: SCollectionType[a] => cn.elemType match {
def serialize(tpe: SType, w: SigmaByteWriter): Unit = {
val depth = w.level
w.level = depth + 1
tpe match {
case p: SEmbeddable => w.put(p.typeCode)
case SString => w.put(SString.typeCode)
case SAny => w.put(SAny.typeCode)
case SUnit => w.put(SUnit.typeCode)
case SBox => w.put(SBox.typeCode)
case SAvlTree => w.put(SAvlTree.typeCode)
case SContext => w.put(SContext.typeCode)
case SGlobal => w.put(SGlobal.typeCode)
case SHeader => w.put(SHeader.typeCode)
case SPreHeader => w.put(SPreHeader.typeCode)
case c: SCollectionType[a] => c.elemType match {
case p: SEmbeddable =>
val code = p.embedIn(NestedCollectionTypeCode)
val code = p.embedIn(CollectionTypeCode)
w.put(code)
case _ =>
case cn: SCollectionType[a] => cn.elemType match {
case p: SEmbeddable =>
val code = p.embedIn(NestedCollectionTypeCode)
w.put(code)
case _ =>
w.put(CollectionTypeCode)
serialize(cn, w)
}
case t =>
w.put(CollectionTypeCode)
serialize(cn, w)
serialize(t, w)
}
case t =>
w.put(CollectionTypeCode)
serialize(t, w)
}
case o: SOption[a] => o.elemType match {
case p: SEmbeddable =>
val code = p.embedIn(SOption.OptionTypeCode)
w.put(code)
case c: SCollectionType[a] => c.elemType match {
case o: SOption[a] => o.elemType match {
case p: SEmbeddable =>
val code = p.embedIn(SOption.OptionCollectionTypeCode)
val code = p.embedIn(SOption.OptionTypeCode)
w.put(code)
case _ =>
case c: SCollectionType[a] => c.elemType match {
case p: SEmbeddable =>
val code = p.embedIn(SOption.OptionCollectionTypeCode)
w.put(code)
case _ =>
w.put(SOption.OptionTypeCode)
serialize(c, w)
}
case t =>
w.put(SOption.OptionTypeCode)
serialize(c, w)
serialize(t, w)
}
case t =>
w.put(SOption.OptionTypeCode)
serialize(t, w)
}
case _ @ STuple(Seq(t1, t2)) => (t1, t2) match {
case (p: SEmbeddable, _) =>
if (p == t2) {
// Symmetric pair of primitive types (`(Int, Int)`, `(Byte,Byte)`, etc.)
val code = p.embedIn(STuple.PairSymmetricTypeCode)
w.put(code)
} else {
// Pair of types where first is primitive (`(_, Int)`)
val code = p.embedIn(STuple.Pair1TypeCode)
case _ @ STuple(Seq(t1, t2)) => (t1, t2) match {
case (p: SEmbeddable, _) =>
if (p == t2) {
// Symmetric pair of primitive types (`(Int, Int)`, `(Byte,Byte)`, etc.)
val code = p.embedIn(STuple.PairSymmetricTypeCode)
w.put(code)
} else {
// Pair of types where first is primitive (`(_, Int)`)
val code = p.embedIn(STuple.Pair1TypeCode)
w.put(code)
serialize(t2, w)
}
case (_, p: SEmbeddable) =>
// Pair of types where second is primitive (`(Int, _)`)
val code = p.embedIn(STuple.Pair2TypeCode)
w.put(code)
serialize(t1, w)
case _ =>
// Pair of non-primitive types (`((Int, Byte), (Boolean,Box))`, etc.)
w.put(STuple.Pair1TypeCode)
serialize(t1, w)
serialize(t2, w)
}
case (_, p: SEmbeddable) =>
// Pair of types where second is primitive (`(Int, _)`)
val code = p.embedIn(STuple.Pair2TypeCode)
w.put(code)
serialize(t1, w)
case _ =>
// Pair of non-primitive types (`((Int, Byte), (Boolean,Box))`, etc.)
w.put(STuple.Pair1TypeCode)
serialize(t1, w)
serialize(t2, w)
}
case STuple(items) if items.length < 2 =>
sys.error(s"Invalid Tuple type with less than 2 items $items")
case tup: STuple => tup.items.length match {
case 3 =>
// Triple of types
w.put(STuple.TripleTypeCode)
for (i <- tup.items)
serialize(i, w)
case 4 =>
// Quadruple of types
w.put(STuple.QuadrupleTypeCode)
for (i <- tup.items)
serialize(i, w)
case _ =>
// `Tuple` type with more than 4 items `(Int, Byte, Box, Boolean, Int)`
serializeTuple(tup, w)
}
case typeIdent: STypeVar => {
w.put(typeIdent.typeCode)
val bytes = typeIdent.name.getBytes(StandardCharsets.UTF_8)
w.putUByte(bytes.length)
.putBytes(bytes)
}
case STuple(items) if items.length < 2 =>
sys.error(s"Invalid Tuple type with less than 2 items $items")
case tup: STuple => tup.items.length match {
case 3 =>
// Triple of types
w.put(STuple.TripleTypeCode)
for (i <- tup.items)
serialize(i, w)
case 4 =>
// Quadruple of types
w.put(STuple.QuadrupleTypeCode)
for (i <- tup.items)
serialize(i, w)
case _ =>
// `Tuple` type with more than 4 items `(Int, Byte, Box, Boolean, Int)`
serializeTuple(tup, w)
}
case typeIdent: STypeVar => {
w.put(typeIdent.typeCode)
val bytes = typeIdent.name.getBytes(StandardCharsets.UTF_8)
w.putUByte(bytes.length)
.putBytes(bytes)
}
}
w.level = depth
}

def deserialize(r: SigmaByteReader): SType = deserialize(r, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import org.ergoplatform.{ErgoBoxCandidate, Inputs, Outputs}
import org.scalacheck.Gen
import scalan.util.BenchmarkUtil
import scorex.util.ByteArrayBuilder
import scorex.util.serialization.{Reader, VLQByteBufferReader, VLQByteBufferWriter}
import scorex.util.serialization.{Reader, VLQByteBufferReader, VLQByteBufferWriter, Writer}
import sigmastate.Values.{BlockValue, GetVarInt, IntConstant, SValue, SigmaBoolean, SigmaPropValue, Tuple, ValDef, ValUse}
import sigmastate._
import sigmastate.eval.Extensions._
Expand Down Expand Up @@ -46,8 +46,8 @@ class DeserializationResilience extends SerializationSpecification
private def writer(maxTreeDepth: Int): SigmaByteWriter = {
val b = new ByteArrayBuilder()
val writer = new VLQByteBufferWriter(b)
val r = new SigmaByteWriter(writer, None, maxTreeDepth)
r
val w = new SigmaByteWriter(writer, None, maxTreeDepth)
w
}

@tailrec
Expand Down Expand Up @@ -196,7 +196,66 @@ class DeserializationResilience extends SerializationSpecification
val depth = stackTrace.count { se =>
(se.getClassName == ValueSerializer.getClass.getName && se.getMethodName == "deserialize") ||
(se.getClassName == DataSerializer.getClass.getName && se.getMethodName == "deserialize") ||
(se.getClassName == SigmaBoolean.serializer.getClass.getName && se.getMethodName == "parse")
(se.getClassName == SigmaBoolean.serializer.getClass.getName && se.getMethodName == "parse") ||
(se.getClassName == TypeSerializer.getClass.getName && se.getMethodName == "deserialize")
}
callDepthsBuilder += depth
}
}
(levels, callDepthsBuilder.result())
}

private def traceWriterCallDepth(expr: SValue): (IndexedSeq[Int], IndexedSeq[Int]) = {
class LoggingSigmaByteWriter(w: Writer) extends SigmaByteWriter(w, None, SigmaSerializer.MaxTreeDepth) {
val levels: mutable.ArrayBuilder[Int] = mutable.ArrayBuilder.make[Int]()
override def level_=(v: Int): Unit = {
if (v >= super.level) {
// going deeper (depth is increasing), save new depth to account added depth level by the caller
levels += v
} else {
// going up (depth is decreasing), save previous depth to account added depth level for the caller
levels += super.level
}
super.level_=(v)
}
}

class ProbeException extends Exception

class ThrowingSigmaByteWriter(w: Writer, levels: IndexedSeq[Int], throwOnNthLevelCall: Int)
extends SigmaByteWriter(w, None, SigmaSerializer.MaxTreeDepth)
{
private var levelCall: Int = 0
override def level_=(v: Int): Unit = {
if (throwOnNthLevelCall == levelCall) throw new ProbeException()
levelCall += 1
super.level_=(v)
}
}

val loggingW = new LoggingSigmaByteWriter(new VLQByteBufferWriter(new ByteArrayBuilder()))
val bytes = ValueSerializer.serialize(expr, loggingW)
val levels = loggingW.levels.result()
levels.nonEmpty shouldBe true

val callDepthsBuilder = mutable.ArrayBuilder.make[Int]()
levels.zipWithIndex.foreach { case (_, levelIndex) =>
val throwingW = new ThrowingSigmaByteWriter(
new VLQByteBufferWriter(new ByteArrayBuilder()),
levels,
throwOnNthLevelCall = levelIndex
)
try {
val _ = ValueSerializer.serialize(expr, throwingW)
} catch {
case e: Exception =>
e.isInstanceOf[ProbeException] shouldBe true
val stackTrace = e.getStackTrace
val depth = stackTrace.count { se =>
(se.getClassName == ValueSerializer.getClass.getName && se.getMethodName == "serialize") ||
(se.getClassName == DataSerializer.getClass.getName && se.getMethodName == "serialize") ||
(se.getClassName == SigmaBoolean.serializer.getClass.getName && se.getMethodName == "serialize") ||
(se.getClassName == TypeSerializer.getClass.getName && se.getMethodName == "serialize")
}
callDepthsBuilder += depth
}
Expand Down Expand Up @@ -227,6 +286,13 @@ class DeserializationResilience extends SerializationSpecification
callDepths shouldEqual IndexedSeq(1, 2, 2, 1)
}

property("writer.level is updated in ValueSerializer.serialize") {
val expr = SizeOf(Outputs)
val (callDepths, levels) = traceWriterCallDepth(expr)
callDepths shouldEqual levels
callDepths shouldEqual IndexedSeq(1, 2, 2, 1)
}

property("max recursive call depth is checked in reader.level for ValueSerializer calls") {
val expr = SizeOf(Outputs)
an[DeserializeCallDepthExceeded] should be thrownBy
Expand All @@ -246,6 +312,13 @@ class DeserializationResilience extends SerializationSpecification
callDepths shouldEqual IndexedSeq(1, 2, 2, 1)
}

property("writer.level is updated in DataSerializer.serialize") {
val expr = IntConstant(1)
val (callDepths, levels) = traceWriterCallDepth(expr)
callDepths shouldEqual levels
callDepths shouldEqual IndexedSeq(1, 2, 2, 2, 2, 1)
}

property("max recursive call depth is checked in reader.level for DataSerializer calls") {
val expr = IntConstant(1)
an[DeserializeCallDepthExceeded] should be thrownBy
Expand All @@ -265,6 +338,13 @@ class DeserializationResilience extends SerializationSpecification
callDepths shouldEqual IndexedSeq(1, 2, 3, 4, 4, 4, 4, 3, 2, 1)
}

property("writer.level is updated in SigmaBoolean.serializer.serialize") {
val expr = CAND(Seq(proveDlogGen.sample.get, proveDHTGen.sample.get))
val (callDepths, levels) = traceWriterCallDepth(expr)
callDepths shouldEqual levels
callDepths shouldEqual IndexedSeq(1, 2, 2, 2, 3, 4, 4, 4, 4, 3, 2, 1)
}

property("max recursive call depth is checked in reader.level for SigmaBoolean.serializer calls") {
val expr = CAND(Seq(proveDlogGen.sample.get, proveDHTGen.sample.get))
an[DeserializeCallDepthExceeded] should be thrownBy
Expand All @@ -284,6 +364,13 @@ class DeserializationResilience extends SerializationSpecification
callDepths shouldEqual IndexedSeq(1, 2, 3, 4, 4, 3, 3, 4, 4, 3, 2, 2, 3, 3, 2, 1)
}

property("writer.level is updated in TypeSerializer") {
val expr = Tuple(Tuple(IntConstant(1), IntConstant(1)), IntConstant(1))
val (callDepths, levels) = traceWriterCallDepth(expr)
callDepths shouldEqual levels
callDepths shouldEqual IndexedSeq(1, 2, 3, 4, 4, 4, 4, 3, 3, 4, 4, 4, 4, 3, 2, 2, 3, 3, 3, 3, 2, 1)
}

property("max recursive call depth is checked in reader.level for TypeSerializer") {
val expr = Tuple(Tuple(IntConstant(1), IntConstant(1)), IntConstant(1))
an[DeserializeCallDepthExceeded] should be thrownBy
Expand Down

0 comments on commit 129fdb0

Please sign in to comment.