From c9fdad1dd173d1cb5111a3e5d66ec9218b00869f Mon Sep 17 00:00:00 2001 From: "Andy(Jingzhang)Chen" Date: Thu, 12 Sep 2024 19:55:42 +0800 Subject: [PATCH] add unit test protect ExtractEntityId can be shared safely (#1475) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add unit test protect ExtractEntityId can be shared safely Related with #1463 * chore: avoid the double evaluation of entityId in ClusterSharding (#1304) * chore: avoid the double evaluation of entityId in ClusterSharding * new cacheable partial function * optimized for review * fix the right type * Revert "chore: avoid the double evaluation of entityId in ClusterSharding (#1…" (#1464) This reverts commit b0e9886439bf216dace9e2979ea820521ddd2a63. * grammar fix * sort imports --------- Co-authored-by: PJ Fanning --- .../cluster/sharding/ShardRegionSpec.scala | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/cluster-sharding/src/test/scala/org/apache/pekko/cluster/sharding/ShardRegionSpec.scala b/cluster-sharding/src/test/scala/org/apache/pekko/cluster/sharding/ShardRegionSpec.scala index 095df727a85..95a7df4296b 100644 --- a/cluster-sharding/src/test/scala/org/apache/pekko/cluster/sharding/ShardRegionSpec.scala +++ b/cluster-sharding/src/test/scala/org/apache/pekko/cluster/sharding/ShardRegionSpec.scala @@ -21,9 +21,14 @@ import org.apache.pekko import pekko.actor.{ Actor, ActorLogging, ActorRef, ActorSystem, PoisonPill, Props } import pekko.cluster.{ Cluster, MemberStatus } import pekko.cluster.ClusterEvent.CurrentClusterState +import pekko.cluster.sharding.ShardRegion.MessageExtractor +import pekko.Done +import pekko.stream.scaladsl.{ Sink, Source } import pekko.testkit.{ DeadLettersFilter, PekkoSpec, TestProbe, WithLogCapturing } import pekko.testkit.TestEvent.Mute +import scala.concurrent.{ ExecutionContext, Future } + object ShardRegionSpec { val host = "127.0.0.1" val tempConfig = @@ -54,6 +59,7 @@ object ShardRegionSpec { val shardTypeName = "Caat" val numberOfShards = 3 + val largerShardNum = 20 val extractEntityId: ShardRegion.ExtractEntityId = { case msg: Int => (msg.toString, msg) @@ -66,11 +72,37 @@ object ShardRegionSpec { case _ => throw new IllegalArgumentException() } + val messageExtractor: MessageExtractor = new MessageExtractor { + override def entityId(message: Any): String = message match { + case msg: Int => msg.toString + case _ => throw new IllegalArgumentException() + } + + override def shardId(message: Any): String = message match { + case msg: Int => (msg % largerShardNum).toString + case _ => throw new IllegalArgumentException() + } + + override def entityMessage(message: Any): Any = message + } + class EntityActor extends Actor with ActorLogging { override def receive: Receive = { case msg => sender() ! msg } } + + class IDMatcherActor extends Actor with ActorLogging { + override def receive: Receive = { + case msg => + val selfEntityId = self.path.name + val msgEntityId = messageExtractor.entityId(msg) + if (selfEntityId != msgEntityId) { + throw new IllegalStateException(s"EntityId mismatch: $selfEntityId != $msgEntityId") + } + sender() ! msg + } + } } class ShardRegionSpec extends PekkoSpec(ShardRegionSpec.config) with WithLogCapturing { @@ -183,4 +215,43 @@ class ShardRegionSpec extends PekkoSpec(ShardRegionSpec.config) with WithLogCapt } } + "ExtractEntityId" must { + "be safely able to share multiple shards" in { + implicit val ec: ExecutionContext = system.dispatcher + + Cluster(sysA).join(Cluster(sysA).selfAddress) // coordinator on A + awaitAssert(Cluster(sysA).selfMember.status shouldEqual MemberStatus.Up, 1.second) + + within(10.seconds) { + awaitAssert { + Set(sysA).foreach { s => + Cluster(s).sendCurrentClusterState(testActor) + expectMsgType[CurrentClusterState].members.size shouldEqual 2 + } + } + } + + val shardTypeName = "Doog" + val region = ClusterSharding(sysA).start( + shardTypeName, + Props[IDMatcherActor](), + ClusterShardingSettings(system), + messageExtractor) + + val total = largerShardNum * 100 + val source = Source(1 to total) + + val flow = source.mapAsync(parallelism = largerShardNum) { i => + Future { + region.tell(i, p1.ref) + } + } + + val result = flow.runWith(Sink.ignore) + + result.futureValue shouldEqual Done + p1.receiveN(total, 10.seconds) + } + } + }