Skip to content

Commit

Permalink
add unfold method
Browse files Browse the repository at this point in the history
  • Loading branch information
xuwei-k committed Jul 10, 2023
1 parent 65cf59e commit faa6777
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 2 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ lazy val compat = new MultiScalaCrossProject(
sharedSourceDir / "scala-2.11_2.12"
}
},
versionPolicyIntention := Compatibility.BinaryAndSourceCompatible,
versionPolicyIntention := Compatibility.BinaryCompatible,
mimaBinaryIssueFilters ++= {
import com.typesafe.tools.mima.core._
import com.typesafe.tools.mima.core.ProblemFilters._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

package scala.collection.compat

import scala.annotation.nowarn
import scala.annotation.{nowarn, tailrec}
import scala.collection.generic._
import scala.reflect.ClassTag
import scala.collection.{
Expand Down Expand Up @@ -249,6 +249,10 @@ private[compat] trait PackageShared {
fact.apply(source.toSeq: _*)
}

implicit def toGenericCompanionExtensionMethods[CC[X] <: GenTraversable[X]](
companion: GenericCompanion[CC]
): GenericCompanionExtensionMethods[CC] = new GenericCompanionExtensionMethods[CC](companion)

implicit class MapFactoryExtensionMethods[CC[A, B] <: Map[A, B] with MapLike[A, B, CC[A, B]]](
private val fact: MapFactory[CC]) {
def from[K, V](source: TraversableOnce[(K, V)]): CC[K, V] =
Expand Down Expand Up @@ -663,3 +667,25 @@ class OptionCompanionExtensionMethods(private val fact: Option.type) extends Any

@inline def unless[A](cond: Boolean)(a: => A): Option[A] = when(!cond)(a)
}

class GenericCompanionExtensionMethods[CC[X] <: GenTraversable[X]](
private val companion: GenericCompanion[CC]) extends AnyVal {
def unfold[A, S](init: S)(f: S => Option[(A, S)])(
implicit cbf: CanBuildFrom[CC[A], A, CC[A]]
): CC[A] = {
val builder = cbf()

@tailrec
def loop(s1: S): Unit = {
f(s1) match {
case Some((a, s2)) =>
builder += a
loop(s2)
case None =>
}
}

loop(init)
builder.result()
}
}
26 changes: 26 additions & 0 deletions compat/src/test/scala/test/scala/collection/CollectionTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,30 @@ class CollectionTest {
assertEquals(List(3, 1, 2).distinctBy(_ % 2 == 0), List(3, 2))
assertEquals(List.empty[Int].distinctBy(_ % 2 == 0), List.empty)
}

@Test
def testUnfold(): Unit = {
def typed[A](x: A): Unit = ()

val list = List.unfold(1)(x => if (x <= 5) Some((x.toString, x + 1)) else None)
typed[List[String]](list)
assertEquals(list, List("1", "2", "3", "4", "5"))

val vector = Vector.unfold(1)(x => if (x <= 100) Some((x, x * 3)) else None)
typed[Vector[Int]](vector)
assertEquals(vector, Vector(1, 3, 9, 27, 81))

val seq = collection.Seq.unfold(1L)(x => if (x <= 10L) Some(x, x + 2L) else None)
typed[collection.Seq[Long]](seq)
assertEquals(seq, collection.Seq(1L, 3L, 5L, 7L, 9L))

val iterable = Iterable.unfold(4)(x => if (x > 0) Some(("a" * x, x - 1)) else None)
typed[Iterable[String]](iterable)
assertEquals(iterable, Iterable("aaaa", "aaa", "aa", "a"))

val arrayBuffer =
collection.mutable.ArrayBuffer.unfold(1)(x => if (x < 3) Some((x, x + 1)) else None)
typed[collection.mutable.ArrayBuffer[Int]](arrayBuffer)
assertEquals(arrayBuffer, collection.mutable.ArrayBuffer(1, 2))
}
}

0 comments on commit faa6777

Please sign in to comment.