Skip to content

Commit

Permalink
v5.x-control-max-tree-depth: Add maxTreeDepth check for TypeSerialize…
Browse files Browse the repository at this point in the history
…r, update tests.
  • Loading branch information
jozanek committed Jun 7, 2022
1 parent 8bee45d commit 91853dd
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ object TypeSerializer {
embeddableIdToType(code)
}

def serialize(tpe: SType, w: SigmaByteWriter): Unit = tpe match {
def serialize(tpe: SType, w: SigmaByteWriter): Unit = {
val depth = w.level
w.level = depth + 1
tpe match {
case p: SEmbeddable => w.put(p.typeCode)
case SString => w.put(SString.typeCode)
case SAny => w.put(SAny.typeCode)
Expand Down Expand Up @@ -112,6 +115,8 @@ object TypeSerializer {
w.putUByte(bytes.length)
.putBytes(bytes)
}
}
w.level = depth
}

def deserialize(r: SigmaByteReader): SType = deserialize(r, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import org.ergoplatform.{ErgoBoxCandidate, Inputs, Outputs}
import org.scalacheck.Gen
import scalan.util.BenchmarkUtil
import scorex.util.ByteArrayBuilder
import scorex.util.serialization.{Reader, VLQByteBufferReader, VLQByteBufferWriter}
import scorex.util.serialization.{Reader, VLQByteBufferReader, VLQByteBufferWriter, Writer}
import sigmastate.Values.{BlockValue, GetVarInt, IntConstant, SValue, SigmaBoolean, SigmaPropValue, Tuple, ValDef, ValUse}
import sigmastate._
import sigmastate.eval.Extensions._
Expand Down Expand Up @@ -46,8 +46,8 @@ class DeserializationResilience extends SerializationSpecification
private def writer(maxTreeDepth: Int): SigmaByteWriter = {
val b = new ByteArrayBuilder()
val writer = new VLQByteBufferWriter(b)
val r = new SigmaByteWriter(writer, None, maxTreeDepth)
r
val w = new SigmaByteWriter(writer, None, maxTreeDepth)
w
}

@tailrec
Expand Down Expand Up @@ -196,7 +196,66 @@ class DeserializationResilience extends SerializationSpecification
val depth = stackTrace.count { se =>
(se.getClassName == ValueSerializer.getClass.getName && se.getMethodName == "deserialize") ||
(se.getClassName == DataSerializer.getClass.getName && se.getMethodName == "deserialize") ||
(se.getClassName == SigmaBoolean.serializer.getClass.getName && se.getMethodName == "parse")
(se.getClassName == SigmaBoolean.serializer.getClass.getName && se.getMethodName == "parse") ||
(se.getClassName == TypeSerializer.getClass.getName && se.getMethodName == "deserialize")
}
callDepthsBuilder += depth
}
}
(levels, callDepthsBuilder.result())
}

private def traceWriterCallDepth(expr: SValue): (IndexedSeq[Int], IndexedSeq[Int]) = {
class LoggingSigmaByteWriter(w: Writer) extends SigmaByteWriter(w, None, SigmaSerializer.MaxTreeDepth) {
val levels: mutable.ArrayBuilder[Int] = mutable.ArrayBuilder.make[Int]()
override def level_=(v: Int): Unit = {
if (v >= super.level) {
// going deeper (depth is increasing), save new depth to account added depth level by the caller
levels += v
} else {
// going up (depth is decreasing), save previous depth to account added depth level for the caller
levels += super.level
}
super.level_=(v)
}
}

class ProbeException extends Exception

class ThrowingSigmaByteWriter(w: Writer, levels: IndexedSeq[Int], throwOnNthLevelCall: Int)
extends SigmaByteWriter(w, None, SigmaSerializer.MaxTreeDepth)
{
private var levelCall: Int = 0
override def level_=(v: Int): Unit = {
if (throwOnNthLevelCall == levelCall) throw new ProbeException()
levelCall += 1
super.level_=(v)
}
}

val loggingW = new LoggingSigmaByteWriter(new VLQByteBufferWriter(new ByteArrayBuilder()))
val bytes = ValueSerializer.serialize(expr, loggingW)
val levels = loggingW.levels.result()
levels.nonEmpty shouldBe true

val callDepthsBuilder = mutable.ArrayBuilder.make[Int]()
levels.zipWithIndex.foreach { case (_, levelIndex) =>
val throwingW = new ThrowingSigmaByteWriter(
new VLQByteBufferWriter(new ByteArrayBuilder()),
levels,
throwOnNthLevelCall = levelIndex
)
try {
val _ = ValueSerializer.serialize(expr, throwingW)
} catch {
case e: Exception =>
e.isInstanceOf[ProbeException] shouldBe true
val stackTrace = e.getStackTrace
val depth = stackTrace.count { se =>
(se.getClassName == ValueSerializer.getClass.getName && se.getMethodName == "serialize") ||
(se.getClassName == DataSerializer.getClass.getName && se.getMethodName == "serialize") ||
(se.getClassName == SigmaBoolean.serializer.getClass.getName && se.getMethodName == "serialize") ||
(se.getClassName == TypeSerializer.getClass.getName && se.getMethodName == "serialize")
}
callDepthsBuilder += depth
}
Expand Down Expand Up @@ -227,6 +286,13 @@ class DeserializationResilience extends SerializationSpecification
callDepths shouldEqual IndexedSeq(1, 2, 2, 1)
}

property("writer.level is updated in ValueSerializer.serialize") {
val expr = SizeOf(Outputs)
val (callDepths, levels) = traceWriterCallDepth(expr)
callDepths shouldEqual levels
callDepths shouldEqual IndexedSeq(1, 2, 2, 1)
}

property("max recursive call depth is checked in reader.level for ValueSerializer calls") {
val expr = SizeOf(Outputs)
an[DeserializeCallDepthExceeded] should be thrownBy
Expand All @@ -246,6 +312,13 @@ class DeserializationResilience extends SerializationSpecification
callDepths shouldEqual IndexedSeq(1, 2, 2, 1)
}

property("writer.level is updated in DataSerializer.serialize") {
val expr = IntConstant(1)
val (callDepths, levels) = traceWriterCallDepth(expr)
callDepths shouldEqual levels
callDepths shouldEqual IndexedSeq(1, 2, 2, 2, 2, 1)
}

property("max recursive call depth is checked in reader.level for DataSerializer calls") {
val expr = IntConstant(1)
an[DeserializeCallDepthExceeded] should be thrownBy
Expand All @@ -265,6 +338,13 @@ class DeserializationResilience extends SerializationSpecification
callDepths shouldEqual IndexedSeq(1, 2, 3, 4, 4, 4, 4, 3, 2, 1)
}

property("writer.level is updated in SigmaBoolean.serializer.serialize") {
val expr = CAND(Seq(proveDlogGen.sample.get, proveDHTGen.sample.get))
val (callDepths, levels) = traceWriterCallDepth(expr)
callDepths shouldEqual levels
callDepths shouldEqual IndexedSeq(1, 2, 2, 2, 3, 4, 4, 4, 4, 3, 2, 1)
}

property("max recursive call depth is checked in reader.level for SigmaBoolean.serializer calls") {
val expr = CAND(Seq(proveDlogGen.sample.get, proveDHTGen.sample.get))
an[DeserializeCallDepthExceeded] should be thrownBy
Expand All @@ -284,6 +364,13 @@ class DeserializationResilience extends SerializationSpecification
callDepths shouldEqual IndexedSeq(1, 2, 3, 4, 4, 3, 3, 4, 4, 3, 2, 2, 3, 3, 2, 1)
}

property("writer.level is updated in TypeSerializer") {
val expr = Tuple(Tuple(IntConstant(1), IntConstant(1)), IntConstant(1))
val (callDepths, levels) = traceWriterCallDepth(expr)
callDepths shouldEqual levels
callDepths shouldEqual IndexedSeq(1, 2, 3, 4, 4, 4, 4, 3, 3, 4, 4, 4, 4, 3, 2, 2, 3, 3, 3, 3, 2, 1)
}

property("max recursive call depth is checked in reader.level for TypeSerializer") {
val expr = Tuple(Tuple(IntConstant(1), IntConstant(1)), IntConstant(1))
an[DeserializeCallDepthExceeded] should be thrownBy
Expand Down

0 comments on commit 91853dd

Please sign in to comment.