Skip to content

Commit

Permalink
Add Source.collect (#205)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw authored Sep 4, 2024
1 parent 575dc64 commit 4d8ad10
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
31 changes: 30 additions & 1 deletion core/src/main/scala/ox/channels/SourceOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ trait SourceOps[+T] { outer: Source[T] =>
}

/** Creates a view of this source, where the results of [[receive]] will be transformed on the consumer's thread using the given function
* `f`. If the function is not defined at an element, the element will be skipped.
* `f`. If the function is not defined at an element, the element will be skipped. For an eager, asynchronous version, see [[collect]].
*
* The same logic applies to receive clauses created using this source, which can be used in [[select]].
*
Expand Down Expand Up @@ -288,6 +288,35 @@ trait SourceOps[+T] { outer: Source[T] =>
}
c

/** Applies the given mapping function `f` to each element received from this source, for which the function is defined, and sends the
* results to the returned channel. If `f` is not defined at an element, the element will be skipped.
*
* Errors from this channel are propagated to the returned channel. Any exceptions that occur when invoking `f` are propagated as errors
* to the returned channel as well.
*
* Must be run within a scope, as a child fork is created, which receives from this source and sends the mapped values to the resulting
* one.
*
* For a lazily-evaluated version, see [[collectAsView]].
*
* @param f
* The mapping function.
* @return
* A source, onto which results of the mapping function will be sent.
*/
def collect[U](f: PartialFunction[T, U])(using Ox, StageCapacity): Source[U] =
val c2 = StageCapacity.newChannel[U]
forkPropagate(c2) {
repeatWhile {
receiveOrClosed() match
case ChannelClosed.Done => c2.doneOrClosed(); false
case ChannelClosed.Error(r) => c2.errorOrClosed(r); false
case t: T @unchecked if f.isDefinedAt(t) => c2.send(f(t)); true
case _ => true // f is not defined at t, skipping
}
}
c2

def take(n: Int)(using Ox, StageCapacity): Source[T] = transform(_.take(n))

/** Transform the source so that it returns elements as long as predicate `f` is satisfied (returns `true`). If `includeFirstFailing` is
Expand Down
20 changes: 20 additions & 0 deletions core/src/test/scala/ox/channels/SourceOpsCollectTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package ox.channels

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.*

class SourceOpsCollectTest extends AnyFlatSpec with Matchers:
behavior of "Source.collect"

it should "collect over a source" in {
supervised {
val c = Source.fromValues(1 to 10: _*)

val s = c.collect {
case i if i % 2 == 0 => i * 10
}

s.toList shouldBe (2 to 10 by 2).map(_ * 10)
}
}

0 comments on commit 4d8ad10

Please sign in to comment.