Skip to content

Commit

Permalink
feat: Add flatMapConcat with parallelism support.
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed Jan 10, 2025
1 parent 11e9547 commit 4b92618
Show file tree
Hide file tree
Showing 9 changed files with 729 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package org.apache.pekko.stream
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit

import scala.concurrent.Await
import scala.concurrent.{ Await, Future }
import scala.concurrent.duration._

import com.typesafe.config.ConfigFactory
Expand Down Expand Up @@ -76,6 +76,16 @@ class FlatMapConcatBenchmark {
awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def sourceDotSingleP1(): Unit = {
val latch = new CountDownLatch(1)

testSource.flatMapConcat(1, Source.single).runWith(new LatchSink(OperationsPerInvocation, latch))

awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def internalSingleSource(): Unit = {
Expand All @@ -88,6 +98,18 @@ class FlatMapConcatBenchmark {
awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def internalSingleSourceP1(): Unit = {
val latch = new CountDownLatch(1)

testSource
.flatMapConcat(1, elem => new GraphStages.SingleSource(elem))
.runWith(new LatchSink(OperationsPerInvocation, latch))

awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def oneElementList(): Unit = {
Expand All @@ -98,6 +120,64 @@ class FlatMapConcatBenchmark {
awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def oneElementListP1(): Unit = {
val latch = new CountDownLatch(1)

testSource.flatMapConcat(1, n => Source(n :: Nil)).runWith(new LatchSink(OperationsPerInvocation, latch))

awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def completedFuture(): Unit = {
val latch = new CountDownLatch(1)

testSource
.flatMapConcat(n => Source.future(Future.successful(n)))
.runWith(new LatchSink(OperationsPerInvocation, latch))

awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def completedFutureP1(): Unit = {
val latch = new CountDownLatch(1)

testSource
.flatMapConcat(1, n => Source.future(Future.successful(n)))
.runWith(new LatchSink(OperationsPerInvocation, latch))

awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def normalFuture(): Unit = {
val latch = new CountDownLatch(1)

testSource
.flatMapConcat(n => Source.future(Future(n)(system.dispatcher)))
.runWith(new LatchSink(OperationsPerInvocation, latch))

awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def normalFutureP1(): Unit = {
val latch = new CountDownLatch(1)

testSource
.flatMapConcat(1, n => Source.future(Future(n)(system.dispatcher)))
.runWith(new LatchSink(OperationsPerInvocation, latch))

awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def mapBaseline(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.pekko.stream.scaladsl

import java.util.concurrent.ThreadLocalRandom
import java.util.concurrent.atomic.AtomicInteger
import org.apache.pekko
import pekko.pattern.FutureTimeoutSupport
import pekko.NotUsed
import pekko.stream._
import pekko.stream.testkit.{ ScriptedTest, StreamSpec }
import pekko.stream.testkit.scaladsl.TestSink

import java.util.Collections
import scala.annotation.switch
import scala.concurrent.duration.DurationInt
import scala.concurrent.Future
import scala.util.control.NoStackTrace

class FlowFlatMapConcatParallelismSpec extends StreamSpec("""
pekko.stream.materializer.initial-input-buffer-size = 2
""") with ScriptedTest with FutureTimeoutSupport {
val toSeq = Flow[Int].grouped(1000).toMat(Sink.head)(Keep.right)

class BoomException extends RuntimeException("BOOM~~") with NoStackTrace
"A flatMapConcat" must {

for (i <- 1 until 129) {
s"work with value presented sources with parallelism: $i" in {
Source(
List(
Source.empty[Int],
Source.single(1),
Source.empty[Int],
Source(List(2, 3, 4)),
Source.future(Future.successful(5)),
Source.lazyFuture(() => Future.successful(6)),
Source.future(after(1.millis)(Future.successful(7)))))
.flatMapConcat(i, identity)
.runWith(toSeq)
.futureValue should ===(1 to 7)
}
}

def generateRandomValuePresentedSources(nums: Int): (Int, List[Source[Int, NotUsed]]) = {
val seq = List.tabulate(nums) { _ =>
val random = ThreadLocalRandom.current().nextInt(1, 10)
(random: @switch) match {
case 1 => Source.single(1)
case 2 => Source(List(1))
case 3 => Source.fromJavaStream(() => Collections.singleton(1).stream())
case 4 => Source.future(Future.successful(1))
case 5 => Source.future(after(1.millis)(Future.successful(1)))
case _ => Source.empty[Int]
}
}
val sum = seq.filterNot(_.eq(Source.empty[Int])).size
(sum, seq)
}

def generateSequencedValuePresentedSources(nums: Int): (Int, List[Source[Int, NotUsed]]) = {
val seq = List.tabulate(nums) { index =>
val random = ThreadLocalRandom.current().nextInt(1, 6)
(random: @switch) match {
case 1 => Source.single(index)
case 2 => Source(List(index))
case 3 => Source.fromJavaStream(() => Collections.singleton(index).stream())
case 4 => Source.future(Future.successful(index))
case 5 => Source.future(after(1.millis)(Future.successful(index)))
case _ => throw new IllegalStateException("unexpected")
}
}
val sum = (0 until nums).sum
(sum, seq)
}

for (i <- 1 until 129) {
s"work with generated value presented sources with parallelism: $i " in {
val (sum, sources @ _) = generateRandomValuePresentedSources(100000)
Source(sources)
.flatMapConcat(i, identity(_)) // scala 2.12 can't infer the type of identity
.runWith(Sink.seq)
.map(_.sum)(pekko.dispatch.ExecutionContexts.parasitic)
.futureValue shouldBe sum
}
}

for (i <- 1 until 129) {
s"work with generated value sequenced sources with parallelism: $i " in {
val (sum, sources @ _) = generateSequencedValuePresentedSources(100000)
Source(sources)
.flatMapConcat(i, identity(_)) // scala 2.12 can't infer the type of identity
// check the order
.statefulMap(() => -1)((pre, current) => {
if (pre + 1 != current) {
throw new IllegalStateException(s"expected $pre + 1 == $current")
}
(current, current)
}, _ => None)
.runWith(Sink.seq)
.map(_.sum)(pekko.dispatch.ExecutionContexts.parasitic)
.futureValue shouldBe sum
}
}

"work with value presented failed sources" in {
val ex = new BoomException
Source(
List(
Source.empty[Int],
Source.single(1),
Source.empty[Int],
Source(List(2, 3, 4)),
Source.future(Future.failed(ex)),
Source.lazyFuture(() => Future.successful(5))))
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
.onErrorComplete[BoomException]()
.runWith(toSeq)
.futureValue should ===(1 to 4)
}

"work with value presented sources when demands slow" in {
val prob = Source(
List(Source.empty[Int], Source.single(1), Source(List(2, 3, 4)), Source.lazyFuture(() => Future.successful(5))))
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
.runWith(TestSink())

prob.request(1)
prob.expectNext(1)
prob.expectNoMessage(1.seconds)
prob.request(2)
prob.expectNext(2, 3)
prob.expectNoMessage(1.seconds)
prob.request(2)
prob.expectNext(4, 5)
prob.expectComplete()
}

val parallelism = ThreadLocalRandom.current().nextInt(4, 65)
s"can do pre materialization when parallelism > 1, parallelism is $parallelism" in {
val materializationCounter = new AtomicInteger(0)
val prob = Source(1 to (parallelism * 3))
.flatMapConcat(
parallelism,
value => {
Source
.lazySingle(() => {
materializationCounter.incrementAndGet()
value
})
.buffer(1, overflowStrategy = OverflowStrategy.backpressure)
})
.runWith(TestSink())

expectNoMessage(1.seconds)
materializationCounter.get() shouldBe 0

prob.request(1)
prob.expectNext(1.seconds, 1)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe (parallelism + 1)
materializationCounter.set(0)

prob.request(2)
prob.expectNextN(List(2, 3))
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe 2
materializationCounter.set(0)

prob.request(parallelism - 3)
prob.expectNextN(4 to parallelism)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe (parallelism - 3)
materializationCounter.set(0)

prob.request(parallelism)
prob.expectNextN(parallelism + 1 to parallelism * 2)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe parallelism
materializationCounter.set(0)

prob.request(parallelism)
prob.expectNextN(parallelism * 2 + 1 to parallelism * 3)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe 0
prob.expectComplete()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ import pekko.stream.Attributes._
val mergePreferred = name("mergePreferred")
val mergePrioritized = name("mergePrioritized")
val flattenMerge = name("flattenMerge")
val flattenConcat = name("flattenConcat")
val recoverWith = name("recoverWith")
val onErrorComplete = name("onErrorComplete")
val broadcast = name("broadcast")
Expand Down
Loading

0 comments on commit 4b92618

Please sign in to comment.