diff --git a/cluster-sharding/src/test/scala/org/apache/pekko/cluster/sharding/ShardRegionSpec.scala b/cluster-sharding/src/test/scala/org/apache/pekko/cluster/sharding/ShardRegionSpec.scala index 095df727a85..95a7df4296b 100644 --- a/cluster-sharding/src/test/scala/org/apache/pekko/cluster/sharding/ShardRegionSpec.scala +++ b/cluster-sharding/src/test/scala/org/apache/pekko/cluster/sharding/ShardRegionSpec.scala @@ -21,9 +21,14 @@ import org.apache.pekko import pekko.actor.{ Actor, ActorLogging, ActorRef, ActorSystem, PoisonPill, Props } import pekko.cluster.{ Cluster, MemberStatus } import pekko.cluster.ClusterEvent.CurrentClusterState +import pekko.cluster.sharding.ShardRegion.MessageExtractor +import pekko.Done +import pekko.stream.scaladsl.{ Sink, Source } import pekko.testkit.{ DeadLettersFilter, PekkoSpec, TestProbe, WithLogCapturing } import pekko.testkit.TestEvent.Mute +import scala.concurrent.{ ExecutionContext, Future } + object ShardRegionSpec { val host = "127.0.0.1" val tempConfig = @@ -54,6 +59,7 @@ object ShardRegionSpec { val shardTypeName = "Caat" val numberOfShards = 3 + val largerShardNum = 20 val extractEntityId: ShardRegion.ExtractEntityId = { case msg: Int => (msg.toString, msg) @@ -66,11 +72,37 @@ object ShardRegionSpec { case _ => throw new IllegalArgumentException() } + val messageExtractor: MessageExtractor = new MessageExtractor { + override def entityId(message: Any): String = message match { + case msg: Int => msg.toString + case _ => throw new IllegalArgumentException() + } + + override def shardId(message: Any): String = message match { + case msg: Int => (msg % largerShardNum).toString + case _ => throw new IllegalArgumentException() + } + + override def entityMessage(message: Any): Any = message + } + class EntityActor extends Actor with ActorLogging { override def receive: Receive = { case msg => sender() ! msg } } + + class IDMatcherActor extends Actor with ActorLogging { + override def receive: Receive = { + case msg => + val selfEntityId = self.path.name + val msgEntityId = messageExtractor.entityId(msg) + if (selfEntityId != msgEntityId) { + throw new IllegalStateException(s"EntityId mismatch: $selfEntityId != $msgEntityId") + } + sender() ! msg + } + } } class ShardRegionSpec extends PekkoSpec(ShardRegionSpec.config) with WithLogCapturing { @@ -183,4 +215,43 @@ class ShardRegionSpec extends PekkoSpec(ShardRegionSpec.config) with WithLogCapt } } + "ExtractEntityId" must { + "be safely able to share multiple shards" in { + implicit val ec: ExecutionContext = system.dispatcher + + Cluster(sysA).join(Cluster(sysA).selfAddress) // coordinator on A + awaitAssert(Cluster(sysA).selfMember.status shouldEqual MemberStatus.Up, 1.second) + + within(10.seconds) { + awaitAssert { + Set(sysA).foreach { s => + Cluster(s).sendCurrentClusterState(testActor) + expectMsgType[CurrentClusterState].members.size shouldEqual 2 + } + } + } + + val shardTypeName = "Doog" + val region = ClusterSharding(sysA).start( + shardTypeName, + Props[IDMatcherActor](), + ClusterShardingSettings(system), + messageExtractor) + + val total = largerShardNum * 100 + val source = Source(1 to total) + + val flow = source.mapAsync(parallelism = largerShardNum) { i => + Future { + region.tell(i, p1.ref) + } + } + + val result = flow.runWith(Sink.ignore) + + result.futureValue shouldEqual Done + p1.receiveN(total, 10.seconds) + } + } + }