Skip to content

Commit

Permalink
Merge pull request #569 from scala/marissa/cve-2022-36944
Browse files Browse the repository at this point in the history
Fix CVE-2022-36944 for `LazyList`
  • Loading branch information
lrytz authored Nov 25, 2022
2 parents 53b8c17 + 366d7a1 commit 7030af3
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import scala.collection.generic.{
SeqFactory
}
import scala.collection.immutable.{LinearSeq, NumericRange}
import scala.collection.mutable.{ArrayBuffer, Builder, StringBuilder}
import scala.collection.mutable.{Builder, StringBuilder}
import scala.language.implicitConversions

/** This class implements an immutable linked list that evaluates elements
Expand Down Expand Up @@ -516,10 +516,6 @@ final class LazyList[+A] private (private[this] var lazyState: () => LazyList.St
else newLL(stateFromIteratorConcatSuffix(prefix.toIterator)(state))
} else super.++:(prefix)(bf)

private def prependedAllToLL[B >: A](prefix: Traversable[B]): LazyList[B] =
if (knownIsEmpty) LazyList.from(prefix)
else newLL(stateFromIteratorConcatSuffix(prefix.toIterator)(state))

/** @inheritdoc
*
* $preservesLaziness
Expand Down Expand Up @@ -1512,14 +1508,17 @@ object LazyList extends SeqFactory[LazyList] {

private[this] def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
val init = new ArrayBuffer[A]
val init = new mutable.ListBuffer[A]
var initRead = false
while (!initRead) in.readObject match {
case SerializeEnd => initRead = true
case a => init += a.asInstanceOf[A]
}
val tail = in.readObject().asInstanceOf[LazyList[A]]
coll = tail.prependedAllToLL(init)
// scala/scala#10118: caution that no code path can evaluate `tail.state`
// before the resulting LazyList is returned
val it = init.toList.iterator
coll = newLL(stateFromIteratorConcatSuffix(it)(tail.state))
}

private[this] def readResolve(): Any = coll
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,86 @@ class LazyListGCTest {
def tapEach_takeRight_headOption_allowsGC(): Unit = {
assertLazyListOpAllowsGC(_.tapEach(_).takeRight(2).headOption, _ => ())
}

@Test
def serialization(): Unit =
if (scala.util.Properties.releaseVersion.exists(_.startsWith("2.12"))) {
import java.io._

def serialize(obj: AnyRef): Array[Byte] = {
val buffer = new ByteArrayOutputStream
val out = new ObjectOutputStream(buffer)
out.writeObject(obj)
buffer.toByteArray
}

def deserialize(a: Array[Byte]): AnyRef = {
val in = new ObjectInputStream(new ByteArrayInputStream(a))
in.readObject
}

def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T]

val l = LazyList.from(10)

val ld1 = serializeDeserialize(l)
assertEquals(l.take(10).toList, ld1.take(10).toList)

l.tail.head
val ld2 = serializeDeserialize(l)
assertEquals(l.take(10).toList, ld2.take(10).toList)

LazyListGCTest.serializationForceCount = 0
val u = LazyList
.from(10)
.map(x => {
LazyListGCTest.serializationForceCount += 1; x
})

def printDiff(): Unit = {
val a = serialize(u)
classOf[LazyList[_]]
.getDeclaredField("scala$collection$compat$immutable$LazyList$$stateEvaluated")
.setBoolean(u, true)
val b = serialize(u)
val i = a.zip(b).indexWhere(p => p._1 != p._2)
println("difference: ")
println(s"val from = ${a.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
println(s"val to = ${b.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
}

// to update this test, comment-out `LazyList.writeReplace` and run `printDiff`
// printDiff()

val from = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 0, 115, 114, 0, 33, 106, 97,
118, 97, 46)
val to = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 1, 115, 114, 0, 33, 106, 97,
118, 97, 46)

assertEquals(LazyListGCTest.serializationForceCount, 0)

u.head
assertEquals(LazyListGCTest.serializationForceCount, 1)

val data = serialize(u)
var i = data.indexOfSlice(from)
to.foreach(x => {
data(i) = x; i += 1
})

val ud1 = deserialize(data).asInstanceOf[LazyList[Int]]

// this check failed before scala/scala#10118, deserialization triggered evaluation
assertEquals(LazyListGCTest.serializationForceCount, 1)

ud1.tail.head
assertEquals(LazyListGCTest.serializationForceCount, 2)

u.tail.head
assertEquals(LazyListGCTest.serializationForceCount, 3)
}
}

object LazyListGCTest {
var serializationForceCount = 0
}

0 comments on commit 7030af3

Please sign in to comment.