Skip to content

Commit

Permalink
v5.x-control-max-tree-depth: Add maxTreeDepth check for TypeSerialize…
Browse files Browse the repository at this point in the history
…r, update tests.
  • Loading branch information
jozanek committed Jun 7, 2022
1 parent 8bee45d commit 595ef21
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,62 +27,62 @@ object DataSerializer {
val depth = w.level
w.level = depth + 1
tpe match {
case SUnit => // don't need to save anything
case SBoolean => w.putBoolean(v.asInstanceOf[Boolean])
case SByte => w.put(v.asInstanceOf[Byte])
case SShort => w.putShort(v.asInstanceOf[Short])
case SInt => w.putInt(v.asInstanceOf[Int])
case SLong => w.putLong(v.asInstanceOf[Long])
case SString =>
val bytes = v.asInstanceOf[String].getBytes(StandardCharsets.UTF_8)
w.putUInt(bytes.length)
w.putBytes(bytes)
case SBigInt =>
val data = SigmaDsl.toBigInteger(v.asInstanceOf[BigInt]).toByteArray
w.putUShort(data.length)
w.putBytes(data)
case SGroupElement =>
GroupElementSerializer.serialize(groupElementToECPoint(v.asInstanceOf[GroupElement]), w)
case SSigmaProp =>
val p = v.asInstanceOf[SigmaProp]
SigmaBoolean.serializer.serialize(sigmaPropToSigmaBoolean(p), w)
case SBox =>
val b = v.asInstanceOf[Box]
ErgoBox.sigmaSerializer.serialize(boxToErgoBox(b), w)
case SAvlTree =>
AvlTreeData.serializer.serialize(avlTreeToAvlTreeData(v.asInstanceOf[AvlTree]), w)
case tColl: SCollectionType[a] =>
val coll = v.asInstanceOf[tColl.WrappedType]
w.putUShort(coll.length)
tColl.elemType match {
case SBoolean =>
w.putBits(coll.asInstanceOf[Coll[Boolean]].toArray)
case SByte =>
w.putBytes(coll.asInstanceOf[Coll[Byte]].toArray)
case _ =>
val arr = coll.toArray
cfor(0)(_ < arr.length, _ + 1) { i =>
val x = arr(i)
serialize(x, tColl.elemType, w)
}
}
case SUnit => // don't need to save anything
case SBoolean => w.putBoolean(v.asInstanceOf[Boolean])
case SByte => w.put(v.asInstanceOf[Byte])
case SShort => w.putShort(v.asInstanceOf[Short])
case SInt => w.putInt(v.asInstanceOf[Int])
case SLong => w.putLong(v.asInstanceOf[Long])
case SString =>
val bytes = v.asInstanceOf[String].getBytes(StandardCharsets.UTF_8)
w.putUInt(bytes.length)
w.putBytes(bytes)
case SBigInt =>
val data = SigmaDsl.toBigInteger(v.asInstanceOf[BigInt]).toByteArray
w.putUShort(data.length)
w.putBytes(data)
case SGroupElement =>
GroupElementSerializer.serialize(groupElementToECPoint(v.asInstanceOf[GroupElement]), w)
case SSigmaProp =>
val p = v.asInstanceOf[SigmaProp]
SigmaBoolean.serializer.serialize(sigmaPropToSigmaBoolean(p), w)
case SBox =>
val b = v.asInstanceOf[Box]
ErgoBox.sigmaSerializer.serialize(boxToErgoBox(b), w)
case SAvlTree =>
AvlTreeData.serializer.serialize(avlTreeToAvlTreeData(v.asInstanceOf[AvlTree]), w)
case tColl: SCollectionType[a] =>
val coll = v.asInstanceOf[tColl.WrappedType]
w.putUShort(coll.length)
tColl.elemType match {
case SBoolean =>
w.putBits(coll.asInstanceOf[Coll[Boolean]].toArray)
case SByte =>
w.putBytes(coll.asInstanceOf[Coll[Byte]].toArray)
case _ =>
val arr = coll.toArray
cfor(0)(_ < arr.length, _ + 1) { i =>
val x = arr(i)
serialize(x, tColl.elemType, w)
}
}

case t: STuple =>
val arr = Evaluation.fromDslTuple(v, t).asInstanceOf[t.WrappedType]
val len = arr.length
assert(arr.length == t.items.length, s"Type $t doesn't correspond to value $arr")
if (len > 0xFFFF)
sys.error(s"Length of tuple ${arr.length} exceeds ${0xFFFF} limit.")
var i = 0
while (i < arr.length) {
serialize[SType](arr(i), t.items(i), w)
i += 1
}
case t: STuple =>
val arr = Evaluation.fromDslTuple(v, t).asInstanceOf[t.WrappedType]
val len = arr.length
assert(arr.length == t.items.length, s"Type $t doesn't correspond to value $arr")
if (len > 0xFFFF)
sys.error(s"Length of tuple ${arr.length} exceeds ${0xFFFF} limit.")
var i = 0
while (i < arr.length) {
serialize[SType](arr(i), t.items(i), w)
i += 1
}

// TODO v6.0 (3h): support Option[T] (see https://github.com/ScorexFoundation/sigmastate-interpreter/issues/659)
case _ =>
CheckSerializableTypeCode(tpe.typeCode)
throw new SerializerException(s"Don't know how to serialize ($v, $tpe)")
// TODO v6.0 (3h): support Option[T] (see https://github.com/ScorexFoundation/sigmastate-interpreter/issues/659)
case _ =>
CheckSerializableTypeCode(tpe.typeCode)
throw new SerializerException(s"Don't know how to serialize ($v, $tpe)")
}
w.level = depth
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ object TypeSerializer {
embeddableIdToType(code)
}

def serialize(tpe: SType, w: SigmaByteWriter): Unit = tpe match {
def serialize(tpe: SType, w: SigmaByteWriter): Unit = {
val depth = w.level
w.level = depth + 1
tpe match {
case p: SEmbeddable => w.put(p.typeCode)
case SString => w.put(SString.typeCode)
case SAny => w.put(SAny.typeCode)
Expand Down Expand Up @@ -112,6 +115,8 @@ object TypeSerializer {
w.putUByte(bytes.length)
.putBytes(bytes)
}
}
w.level = depth
}

def deserialize(r: SigmaByteReader): SType = deserialize(r, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,37 +351,37 @@ object ValueSerializer extends SigmaSerializerCompanion[Value[SType]] {
val depth = w.level
w.level = depth + 1
serializable(v) match {
case c: Constant[SType] =>
w.constantExtractionStore match {
case Some(constantStore) =>
val ph = constantStore.put(c)(DeserializationSigmaBuilder)
w.put(ph.opCode)
constantPlaceholderSerializer.serialize(ph, w)
case c: Constant[SType] =>
w.constantExtractionStore match {
case Some(constantStore) =>
val ph = constantStore.put(c)(DeserializationSigmaBuilder)
w.put(ph.opCode)
constantPlaceholderSerializer.serialize(ph, w)
case None =>
constantSerializer.serialize(c, w)
}
case _ =>
val opCode = v.opCode
// help compiler recognize the type
val ser = getSerializer(opCode).asInstanceOf[ValueSerializer[v.type]]
if (collectSerInfo) {
val scope = serializerInfo.get(opCode) match {
case None =>
constantSerializer.serialize(c, w)
}
case _ =>
val opCode = v.opCode
// help compiler recognize the type
val ser = getSerializer(opCode).asInstanceOf[ValueSerializer[v.type]]
if (collectSerInfo) {
val scope = serializerInfo.get(opCode) match {
case None =>
val newScope = SerScope(opCode, mutable.ArrayBuffer.empty)
serializerInfo += (opCode -> newScope)
println(s"Added: ${ser.opDesc}")
newScope
case Some(scope) => scope
}
w.put(opCode)

scopeStack ::= scope
ser.serialize(v, w)
scopeStack = scopeStack.tail
} else {
w.put(opCode)
ser.serialize(v, w)
val newScope = SerScope(opCode, mutable.ArrayBuffer.empty)
serializerInfo += (opCode -> newScope)
println(s"Added: ${ser.opDesc}")
newScope
case Some(scope) => scope
}
w.put(opCode)

scopeStack ::= scope
ser.serialize(v, w)
scopeStack = scopeStack.tail
} else {
w.put(opCode)
ser.serialize(v, w)
}
}
w.level = depth
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import org.ergoplatform.{ErgoBoxCandidate, Inputs, Outputs}
import org.scalacheck.Gen
import scalan.util.BenchmarkUtil
import scorex.util.ByteArrayBuilder
import scorex.util.serialization.{Reader, VLQByteBufferReader, VLQByteBufferWriter}
import scorex.util.serialization.{Reader, VLQByteBufferReader, VLQByteBufferWriter, Writer}
import sigmastate.Values.{BlockValue, GetVarInt, IntConstant, SValue, SigmaBoolean, SigmaPropValue, Tuple, ValDef, ValUse}
import sigmastate._
import sigmastate.eval.Extensions._
Expand Down Expand Up @@ -46,8 +46,8 @@ class DeserializationResilience extends SerializationSpecification
private def writer(maxTreeDepth: Int): SigmaByteWriter = {
val b = new ByteArrayBuilder()
val writer = new VLQByteBufferWriter(b)
val r = new SigmaByteWriter(writer, None, maxTreeDepth)
r
val w = new SigmaByteWriter(writer, None, maxTreeDepth)
w
}

@tailrec
Expand Down Expand Up @@ -196,7 +196,66 @@ class DeserializationResilience extends SerializationSpecification
val depth = stackTrace.count { se =>
(se.getClassName == ValueSerializer.getClass.getName && se.getMethodName == "deserialize") ||
(se.getClassName == DataSerializer.getClass.getName && se.getMethodName == "deserialize") ||
(se.getClassName == SigmaBoolean.serializer.getClass.getName && se.getMethodName == "parse")
(se.getClassName == SigmaBoolean.serializer.getClass.getName && se.getMethodName == "parse") ||
(se.getClassName == TypeSerializer.getClass.getName && se.getMethodName == "deserialize")
}
callDepthsBuilder += depth
}
}
(levels, callDepthsBuilder.result())
}

private def traceWriterCallDepth(expr: SValue): (IndexedSeq[Int], IndexedSeq[Int]) = {
class LoggingSigmaByteWriter(w: Writer) extends SigmaByteWriter(w, None, SigmaSerializer.MaxTreeDepth) {
val levels: mutable.ArrayBuilder[Int] = mutable.ArrayBuilder.make[Int]()
override def level_=(v: Int): Unit = {
if (v >= super.level) {
// going deeper (depth is increasing), save new depth to account added depth level by the caller
levels += v
} else {
// going up (depth is decreasing), save previous depth to account added depth level for the caller
levels += super.level
}
super.level_=(v)
}
}

class ProbeException extends Exception

class ThrowingSigmaByteWriter(w: Writer, levels: IndexedSeq[Int], throwOnNthLevelCall: Int)
extends SigmaByteWriter(w, None, SigmaSerializer.MaxTreeDepth)
{
private var levelCall: Int = 0
override def level_=(v: Int): Unit = {
if (throwOnNthLevelCall == levelCall) throw new ProbeException()
levelCall += 1
super.level_=(v)
}
}

val loggingW = new LoggingSigmaByteWriter(new VLQByteBufferWriter(new ByteArrayBuilder()))
val bytes = ValueSerializer.serialize(expr, loggingW)
val levels = loggingW.levels.result()
levels.nonEmpty shouldBe true

val callDepthsBuilder = mutable.ArrayBuilder.make[Int]()
levels.zipWithIndex.foreach { case (_, levelIndex) =>
val throwingW = new ThrowingSigmaByteWriter(
new VLQByteBufferWriter(new ByteArrayBuilder()),
levels,
throwOnNthLevelCall = levelIndex
)
try {
val _ = ValueSerializer.serialize(expr, throwingW)
} catch {
case e: Exception =>
e.isInstanceOf[ProbeException] shouldBe true
val stackTrace = e.getStackTrace
val depth = stackTrace.count { se =>
(se.getClassName == ValueSerializer.getClass.getName && se.getMethodName == "serialize") ||
(se.getClassName == DataSerializer.getClass.getName && se.getMethodName == "serialize") ||
(se.getClassName == SigmaBoolean.serializer.getClass.getName && se.getMethodName == "serialize") ||
(se.getClassName == TypeSerializer.getClass.getName && se.getMethodName == "serialize")
}
callDepthsBuilder += depth
}
Expand Down Expand Up @@ -227,6 +286,13 @@ class DeserializationResilience extends SerializationSpecification
callDepths shouldEqual IndexedSeq(1, 2, 2, 1)
}

property("writer.level is updated in ValueSerializer.serialize") {
val expr = SizeOf(Outputs)
val (callDepths, levels) = traceWriterCallDepth(expr)
callDepths shouldEqual levels
callDepths shouldEqual IndexedSeq(1, 2, 2, 1)
}

property("max recursive call depth is checked in reader.level for ValueSerializer calls") {
val expr = SizeOf(Outputs)
an[DeserializeCallDepthExceeded] should be thrownBy
Expand All @@ -246,6 +312,13 @@ class DeserializationResilience extends SerializationSpecification
callDepths shouldEqual IndexedSeq(1, 2, 2, 1)
}

property("writer.level is updated in DataSerializer.serialize") {
val expr = IntConstant(1)
val (callDepths, levels) = traceWriterCallDepth(expr)
callDepths shouldEqual levels
callDepths shouldEqual IndexedSeq(1, 2, 2, 2, 2, 1)
}

property("max recursive call depth is checked in reader.level for DataSerializer calls") {
val expr = IntConstant(1)
an[DeserializeCallDepthExceeded] should be thrownBy
Expand All @@ -265,6 +338,13 @@ class DeserializationResilience extends SerializationSpecification
callDepths shouldEqual IndexedSeq(1, 2, 3, 4, 4, 4, 4, 3, 2, 1)
}

property("writer.level is updated in SigmaBoolean.serializer.serialize") {
val expr = CAND(Seq(proveDlogGen.sample.get, proveDHTGen.sample.get))
val (callDepths, levels) = traceWriterCallDepth(expr)
callDepths shouldEqual levels
callDepths shouldEqual IndexedSeq(1, 2, 2, 2, 3, 4, 4, 4, 4, 3, 2, 1)
}

property("max recursive call depth is checked in reader.level for SigmaBoolean.serializer calls") {
val expr = CAND(Seq(proveDlogGen.sample.get, proveDHTGen.sample.get))
an[DeserializeCallDepthExceeded] should be thrownBy
Expand All @@ -284,6 +364,13 @@ class DeserializationResilience extends SerializationSpecification
callDepths shouldEqual IndexedSeq(1, 2, 3, 4, 4, 3, 3, 4, 4, 3, 2, 2, 3, 3, 2, 1)
}

property("writer.level is updated in TypeSerializer") {
val expr = Tuple(Tuple(IntConstant(1), IntConstant(1)), IntConstant(1))
val (callDepths, levels) = traceWriterCallDepth(expr)
callDepths shouldEqual levels
callDepths shouldEqual IndexedSeq(1, 2, 3, 4, 4, 4, 4, 3, 3, 4, 4, 4, 4, 3, 2, 2, 3, 3, 3, 3, 2, 1)
}

property("max recursive call depth is checked in reader.level for TypeSerializer") {
val expr = Tuple(Tuple(IntConstant(1), IntConstant(1)), IntConstant(1))
an[DeserializeCallDepthExceeded] should be thrownBy
Expand Down

0 comments on commit 595ef21

Please sign in to comment.