diff --git a/changelog/8623.bugfix.md b/changelog/8623.bugfix.md new file mode 100644 index 000000000000..2dda63050e22 --- /dev/null +++ b/changelog/8623.bugfix.md @@ -0,0 +1,5 @@ +When there are multiple entities in a user message, they will get sorted when creating a +representation of the current dialogue state. + +Previously, the ordering was random, leading to inconsistent state representations. This +would sometimes lead to memoization policies failing to recall a memorised action. \ No newline at end of file diff --git a/rasa/shared/core/domain.py b/rasa/shared/core/domain.py index 72378196c59a..2cd2bb0dd491 100644 --- a/rasa/shared/core/domain.py +++ b/rasa/shared/core/domain.py @@ -675,7 +675,9 @@ def _collect_overridden_default_intents( list(intent.keys())[0] if isinstance(intent, dict) else intent for intent in intents } - return sorted(intent_names & set(rasa.shared.core.constants.DEFAULT_INTENTS)) + return sorted( + intent_names.intersection(set(rasa.shared.core.constants.DEFAULT_INTENTS)) + ) @staticmethod def _initialize_forms( @@ -1047,6 +1049,11 @@ def input_states(self) -> List[Text]: ) def _get_featurized_entities(self, latest_message: UserUttered) -> Set[Text]: + """Gets the names of all entities that are present and wanted in the message. + + Wherever an entity has a role or group specified as well, an additional role- + or group-specific entity name is added. + """ intent_name = latest_message.intent.get( rasa.shared.nlu.constants.INTENT_NAME_KEY ) @@ -1057,33 +1064,36 @@ def _get_featurized_entities(self, latest_message: UserUttered) -> Set[Text]: # groups get featurized. We concatenate the entity label with the role/group # label using a special separator to make sure that the resulting label is # unique (as you can have the same role/group label for different entities). - entity_names = ( - set(entity["entity"] for entity in entities if "entity" in entity.keys()) - | set( - f"{entity['entity']}" - f"{rasa.shared.core.constants.ENTITY_LABEL_SEPARATOR}{entity['role']}" - for entity in entities - if "entity" in entity.keys() and "role" in entity.keys() - ) - | set( - f"{entity['entity']}" - f"{rasa.shared.core.constants.ENTITY_LABEL_SEPARATOR}{entity['group']}" - for entity in entities - if "entity" in entity.keys() and "group" in entity.keys() - ) + entity_names_basic = set( + entity["entity"] for entity in entities if "entity" in entity.keys() + ) + entity_names_roles = set( + f"{entity['entity']}" + f"{rasa.shared.core.constants.ENTITY_LABEL_SEPARATOR}{entity['role']}" + for entity in entities + if "entity" in entity.keys() and "role" in entity.keys() ) + entity_names_groups = set( + f"{entity['entity']}" + f"{rasa.shared.core.constants.ENTITY_LABEL_SEPARATOR}{entity['group']}" + for entity in entities + if "entity" in entity.keys() and "group" in entity.keys() + ) + entity_names = entity_names_basic.union(entity_names_roles, entity_names_groups) # the USED_ENTITIES_KEY of an intent also contains the entity labels and the # concatenated entity labels with their corresponding roles and groups labels wanted_entities = set(intent_config.get(USED_ENTITIES_KEY, entity_names)) - return entity_names & wanted_entities + return entity_names.intersection(wanted_entities) def _get_user_sub_state( self, tracker: "DialogueStateTracker" ) -> Dict[Text, Union[Text, Tuple[Text]]]: - """Turn latest UserUttered event into a substate containing intent, - text and set entities if present + """Turns latest UserUttered event into a substate. + + The substate will contain intent, text, and entities (if any are present). + Args: tracker: dialog state tracker containing the dialog so far Returns: @@ -1097,15 +1107,19 @@ def _get_user_sub_state( sub_state = latest_message.as_sub_state() - # filter entities based on intent config - # sub_state will be transformed to frozenset therefore we need to - # convert the set to the tuple - # sub_state is transformed to frozenset because we will later hash it - # for deduplication + # Filter entities based on intent config. We need to convert the set into a + # tuple because sub_state will be later transformed into a frozenset (so it can + # be hashed for deduplication). entities = tuple( - self._get_featurized_entities(latest_message) - & set(sub_state.get(rasa.shared.nlu.constants.ENTITIES, ())) + self._get_featurized_entities(latest_message).intersection( + set(sub_state.get(rasa.shared.nlu.constants.ENTITIES, ())) + ) ) + # Sort entities so that any derived state representation is consistent across + # runs and invariant to the order in which the entities for an utterance are + # listed in data files. + entities = tuple(sorted(entities)) + if entities: sub_state[rasa.shared.nlu.constants.ENTITIES] = entities else: @@ -1180,16 +1194,17 @@ def _clean_state(state: State) -> State: if sub_state } - def get_active_states( + def get_active_state( self, tracker: "DialogueStateTracker", omit_unset_slots: bool = False, ) -> State: - """Returns a bag of active states from the tracker state. + """Given a dialogue tracker, makes a representation of current dialogue state. Args: tracker: dialog state tracker containing the dialog so far omit_unset_slots: If `True` do not include the initial values of slots. - Returns `State` containing all active states. + Returns: + A representation of the dialogue's current state. """ state = { rasa.shared.core.constants.USER: self._get_user_sub_state(tracker), @@ -1286,7 +1301,7 @@ def states_for_tracker_history( if turn_was_hidden: continue - state = self.get_active_states(tr, omit_unset_slots=omit_unset_slots) + state = self.get_active_state(tr, omit_unset_slots=omit_unset_slots) if ignore_rule_only_turns: # clean state from only rule features diff --git a/rasa/shared/core/generator.py b/rasa/shared/core/generator.py index 910d8fe68358..937a1479f359 100644 --- a/rasa/shared/core/generator.py +++ b/rasa/shared/core/generator.py @@ -184,7 +184,7 @@ def _append_current_state(self) -> None: if self._states_for_hashing is None: self._states_for_hashing = self.past_states_for_hashing(self.domain) else: - state = self.domain.get_active_states(self) + state = self.domain.get_active_state(self) frozen_state = self.freeze_current_state(state) self._states_for_hashing.append(frozen_state) diff --git a/tests/shared/core/test_domain.py b/tests/shared/core/test_domain.py index 7c0ac2fafcb2..65815d16c8ea 100644 --- a/tests/shared/core/test_domain.py +++ b/tests/shared/core/test_domain.py @@ -1,6 +1,7 @@ import copy import json from pathlib import Path +import random from typing import Dict, List, Text, Any, Union, Set, Optional import pytest @@ -35,6 +36,8 @@ Domain, KEY_FORMS, KEY_E2E_ACTIONS, + KEY_INTENTS, + KEY_ENTITIES, ) from rasa.shared.core.trackers import DialogueStateTracker from rasa.shared.core.events import ActionExecuted, SlotSet, UserUttered @@ -1110,6 +1113,48 @@ def test_get_featurized_entities(): assert featurized_entities == {"GPE", f"GPE{ENTITY_LABEL_SEPARATOR}destination"} +def test_featurized_entities_ordered_consistently(): + """Check that entities get ordered -- needed for consistent state representations. + + Previously, no ordering was applied to entities, but they were ordered implicitly + due to how python sets work -- a set of all entity names was internally created, + which was ordered by the hashes of the entity names. Now, entities are sorted alpha- + betically. Since even sorting based on randomised hashing can produce alphabetical + ordering once in a while, we here check with a large number of entities, pushing to + ~0 the probability of correctly sorting the elements just by accident, without + actually doing proper sorting. + """ + # Create a sorted list of entity names from 'a' to 'z', and two randomly shuffled + # copies. + entity_names_sorted = [chr(i) for i in range(ord("a"), ord("z") + 1)] + entity_names_shuffled1 = entity_names_sorted.copy() + random.shuffle(entity_names_shuffled1) + entity_names_shuffled2 = entity_names_sorted.copy() + random.shuffle(entity_names_shuffled2) + + domain = Domain.from_dict( + {KEY_INTENTS: ["inform"], KEY_ENTITIES: entity_names_shuffled1} + ) + + tracker = DialogueStateTracker.from_events( + "story123", + [ + UserUttered( + text="hey there", + intent={"name": "inform", "confidence": 1.0}, + entities=[ + {"entity": e, "value": e.upper()} for e in entity_names_shuffled2 + ], + ) + ], + ) + state = domain.get_active_state(tracker) + + # Whatever order the entities were listed in, they should get sorted alphabetically + # so the states' representations are consistent and entity-order-agnostic. + assert state["user"]["entities"] == tuple(entity_names_sorted) + + @pytest.mark.parametrize( "domain_as_dict", [