diff --git a/build.sbt b/build.sbt index 57f13dab..3484f321 100644 --- a/build.sbt +++ b/build.sbt @@ -82,7 +82,7 @@ lazy val compat = new MultiScalaCrossProject( sharedSourceDir / "scala-2.11_2.12" } }, - versionPolicyIntention := Compatibility.BinaryAndSourceCompatible, + versionPolicyIntention := Compatibility.BinaryCompatible, mimaBinaryIssueFilters ++= { import com.typesafe.tools.mima.core._ import com.typesafe.tools.mima.core.ProblemFilters._ diff --git a/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala b/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala index 86a737c9..ace79efb 100644 --- a/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala +++ b/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala @@ -12,7 +12,7 @@ package scala.collection.compat -import scala.annotation.nowarn +import scala.annotation.{nowarn, tailrec} import scala.collection.generic._ import scala.reflect.ClassTag import scala.collection.{ @@ -249,6 +249,10 @@ private[compat] trait PackageShared { fact.apply(source.toSeq: _*) } + implicit def toGenericCompanionExtensionMethods[CC[X] <: GenTraversable[X]]( + companion: GenericCompanion[CC] + ): GenericCompanionExtensionMethods[CC] = new GenericCompanionExtensionMethods[CC](companion) + implicit class MapFactoryExtensionMethods[CC[A, B] <: Map[A, B] with MapLike[A, B, CC[A, B]]]( private val fact: MapFactory[CC]) { def from[K, V](source: TraversableOnce[(K, V)]): CC[K, V] = @@ -663,3 +667,25 @@ class OptionCompanionExtensionMethods(private val fact: Option.type) extends Any @inline def unless[A](cond: Boolean)(a: => A): Option[A] = when(!cond)(a) } + +class GenericCompanionExtensionMethods[CC[X] <: GenTraversable[X]]( + private val companion: GenericCompanion[CC]) extends AnyVal { + def unfold[A, S](init: S)(f: S => Option[(A, S)])( + implicit cbf: CanBuildFrom[CC[A], A, CC[A]] + ): CC[A] = { + val builder = cbf() + + @tailrec + def loop(s1: S): Unit = { + f(s1) match { + case Some((a, s2)) => + builder += a + loop(s2) + case None => + } + } + + loop(init) + builder.result() + } +} diff --git a/compat/src/test/scala/test/scala/collection/CollectionTest.scala b/compat/src/test/scala/test/scala/collection/CollectionTest.scala index a3be5367..2bd6955b 100644 --- a/compat/src/test/scala/test/scala/collection/CollectionTest.scala +++ b/compat/src/test/scala/test/scala/collection/CollectionTest.scala @@ -193,4 +193,30 @@ class CollectionTest { assertEquals(List(3, 1, 2).distinctBy(_ % 2 == 0), List(3, 2)) assertEquals(List.empty[Int].distinctBy(_ % 2 == 0), List.empty) } + + @Test + def testUnfold(): Unit = { + def typed[A](x: A): Unit = () + + val list = List.unfold(1)(x => if (x <= 5) Some((x.toString, x + 1)) else None) + typed[List[String]](list) + assertEquals(list, List("1", "2", "3", "4", "5")) + + val vector = Vector.unfold(1)(x => if (x <= 100) Some((x, x * 3)) else None) + typed[Vector[Int]](vector) + assertEquals(vector, Vector(1, 3, 9, 27, 81)) + + val seq = collection.Seq.unfold(1L)(x => if (x <= 10L) Some(x, x + 2L) else None) + typed[collection.Seq[Long]](seq) + assertEquals(seq, collection.Seq(1L, 3L, 5L, 7L, 9L)) + + val iterable = Iterable.unfold(4)(x => if (x > 0) Some(("a" * x, x - 1)) else None) + typed[Iterable[String]](iterable) + assertEquals(iterable, Iterable("aaaa", "aaa", "aa", "a")) + + val arrayBuffer = + collection.mutable.ArrayBuffer.unfold(1)(x => if (x < 3) Some((x, x + 1)) else None) + typed[collection.mutable.ArrayBuffer[Int]](arrayBuffer) + assertEquals(arrayBuffer, collection.mutable.ArrayBuffer(1, 2)) + } }