diff --git a/cluster-sharding-typed/src/main/scala/org/apache/pekko/cluster/sharding/typed/internal/ClusterShardingImpl.scala b/cluster-sharding-typed/src/main/scala/org/apache/pekko/cluster/sharding/typed/internal/ClusterShardingImpl.scala index 026611912c7..525718d2277 100644 --- a/cluster-sharding-typed/src/main/scala/org/apache/pekko/cluster/sharding/typed/internal/ClusterShardingImpl.scala +++ b/cluster-sharding-typed/src/main/scala/org/apache/pekko/cluster/sharding/typed/internal/ClusterShardingImpl.scala @@ -20,6 +20,7 @@ import java.util.concurrent.CompletionStage import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future +import scala.runtime.AbstractPartialFunction import org.apache.pekko import pekko.actor.ActorRefProvider @@ -172,10 +173,16 @@ import pekko.util.JavaDurationConverters._ allocationStrategy: Option[ShardAllocationStrategy]): ActorRef[E] = { val extractorAdapter = new ExtractorAdapter(extractor) - val extractEntityId: ShardRegion.ExtractEntityId = { - // TODO is it possible to avoid the double evaluation of entityId - case message if extractorAdapter.entityId(message) != null => - (extractorAdapter.entityId(message), extractorAdapter.unwrapMessage(message)) + // !!!important is only applicable if you know that isDefinedAt(x) is always called before apply(x) (with the same x) + val extractEntityId: ShardRegion.ExtractEntityId = new AbstractPartialFunction[Any, (String, Any)] { + var cache: String = _ + + override def isDefinedAt(msg: Any): Boolean = { + cache = extractorAdapter.entityId(msg) + cache != null + } + + override def apply(x: Any): (String, Any) = (cache, extractorAdapter.unwrapMessage(x)) } val extractShardId: ShardRegion.ExtractShardId = { message => extractorAdapter.entityId(message) match { diff --git a/cluster-sharding/src/main/scala/org/apache/pekko/cluster/sharding/ClusterSharding.scala b/cluster-sharding/src/main/scala/org/apache/pekko/cluster/sharding/ClusterSharding.scala index 9278839da3d..835d2fcfbb2 100755 --- a/cluster-sharding/src/main/scala/org/apache/pekko/cluster/sharding/ClusterSharding.scala +++ b/cluster-sharding/src/main/scala/org/apache/pekko/cluster/sharding/ClusterSharding.scala @@ -19,6 +19,7 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.immutable import scala.concurrent.Await +import scala.runtime.AbstractPartialFunction import scala.util.control.NonFatal import org.apache.pekko @@ -429,15 +430,26 @@ class ClusterSharding(system: ExtendedActorSystem) extends Extension { typeName, _ => entityProps, settings, - extractEntityId = { - case msg if messageExtractor.entityId(msg) ne null => - (messageExtractor.entityId(msg), messageExtractor.entityMessage(msg)) - }, + extractEntityId = extractEntityIdFromExtractor(messageExtractor), extractShardId = msg => messageExtractor.shardId(msg), allocationStrategy = allocationStrategy, handOffStopMessage = handOffStopMessage) } + // !!!important is only applicable if you know that isDefinedAt(x) is always called before apply(x) (with the same x) + private def extractEntityIdFromExtractor( + messageExtractor: ShardRegion.MessageExtractor): ShardRegion.ExtractEntityId = + new AbstractPartialFunction[Any, (String, Any)] { + var cache: String = _ + + override def isDefinedAt(msg: Any): Boolean = { + cache = messageExtractor.entityId(msg) + cache != null + } + + override def apply(x: Any): (String, Any) = (cache, messageExtractor.entityMessage(x)) + } + /** * Java/Scala API: Register a named entity type by defining the [[pekko.actor.Props]] of the entity actor * and functions to extract entity and shard identifier from messages. The [[ShardRegion]] actor @@ -612,11 +624,12 @@ class ClusterSharding(system: ExtendedActorSystem) extends Extension { dataCenter: Optional[String], messageExtractor: ShardRegion.MessageExtractor): ActorRef = { - startProxy(typeName, Option(role.orElse(null)), Option(dataCenter.orElse(null)), - extractEntityId = { - case msg if messageExtractor.entityId(msg) ne null => - (messageExtractor.entityId(msg), messageExtractor.entityMessage(msg)) - }, extractShardId = msg => messageExtractor.shardId(msg)) + startProxy( + typeName, + Option(role.orElse(null)), + Option(dataCenter.orElse(null)), + extractEntityId = extractEntityIdFromExtractor(messageExtractor), + msg => messageExtractor.shardId(msg)) }