Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make entity ordering in state representations consistent #8646

Merged
merged 13 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog/8623.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
When there are multiple entities in a user message, they will get sorted when creating a
representation of the current dialogue state.

Previously, the ordering was random, leading to inconsistent state representations. This
would sometimes lead to memoization policies failing to recall a memorised action.
73 changes: 44 additions & 29 deletions rasa/shared/core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,9 @@ def _collect_overridden_default_intents(
list(intent.keys())[0] if isinstance(intent, dict) else intent
for intent in intents
}
return sorted(intent_names & set(rasa.shared.core.constants.DEFAULT_INTENTS))
return sorted(
intent_names.intersection(set(rasa.shared.core.constants.DEFAULT_INTENTS))
)

@staticmethod
def _initialize_forms(
Expand Down Expand Up @@ -1047,6 +1049,11 @@ def input_states(self) -> List[Text]:
)

def _get_featurized_entities(self, latest_message: UserUttered) -> Set[Text]:
"""Gets the names of all entities that are present and wanted in the message.

Wherever an entity has a role or group specified as well, an additional role-
or group-specific entity name is added.
"""
intent_name = latest_message.intent.get(
rasa.shared.nlu.constants.INTENT_NAME_KEY
)
Expand All @@ -1057,33 +1064,36 @@ def _get_featurized_entities(self, latest_message: UserUttered) -> Set[Text]:
# groups get featurized. We concatenate the entity label with the role/group
# label using a special separator to make sure that the resulting label is
# unique (as you can have the same role/group label for different entities).
entity_names = (
set(entity["entity"] for entity in entities if "entity" in entity.keys())
| set(
f"{entity['entity']}"
f"{rasa.shared.core.constants.ENTITY_LABEL_SEPARATOR}{entity['role']}"
for entity in entities
if "entity" in entity.keys() and "role" in entity.keys()
)
| set(
f"{entity['entity']}"
f"{rasa.shared.core.constants.ENTITY_LABEL_SEPARATOR}{entity['group']}"
for entity in entities
if "entity" in entity.keys() and "group" in entity.keys()
)
entity_names_basic = set(
entity["entity"] for entity in entities if "entity" in entity.keys()
)
entity_names_roles = set(
f"{entity['entity']}"
f"{rasa.shared.core.constants.ENTITY_LABEL_SEPARATOR}{entity['role']}"
for entity in entities
if "entity" in entity.keys() and "role" in entity.keys()
)
entity_names_groups = set(
f"{entity['entity']}"
f"{rasa.shared.core.constants.ENTITY_LABEL_SEPARATOR}{entity['group']}"
for entity in entities
if "entity" in entity.keys() and "group" in entity.keys()
)
entity_names = entity_names_basic.union(entity_names_roles, entity_names_groups)

# the USED_ENTITIES_KEY of an intent also contains the entity labels and the
# concatenated entity labels with their corresponding roles and groups labels
wanted_entities = set(intent_config.get(USED_ENTITIES_KEY, entity_names))

return entity_names & wanted_entities
return entity_names.intersection(wanted_entities)

def _get_user_sub_state(
self, tracker: "DialogueStateTracker"
) -> Dict[Text, Union[Text, Tuple[Text]]]:
"""Turn latest UserUttered event into a substate containing intent,
text and set entities if present
"""Turns latest UserUttered event into a substate.

The substate will contain intent, text, and entities (if any are present).

Args:
tracker: dialog state tracker containing the dialog so far
Returns:
Expand All @@ -1097,15 +1107,19 @@ def _get_user_sub_state(

sub_state = latest_message.as_sub_state()

# filter entities based on intent config
# sub_state will be transformed to frozenset therefore we need to
# convert the set to the tuple
# sub_state is transformed to frozenset because we will later hash it
# for deduplication
# Filter entities based on intent config. We need to convert the set into a
# tuple because sub_state will be later transformed into a frozenset (so it can
# be hashed for deduplication).
entities = tuple(
self._get_featurized_entities(latest_message)
& set(sub_state.get(rasa.shared.nlu.constants.ENTITIES, ()))
self._get_featurized_entities(latest_message).intersection(
set(sub_state.get(rasa.shared.nlu.constants.ENTITIES, ()))
)
)
# Sort entities so that any derived state representation is consistent across
# runs and invariant to the order in which the entities for an utterance are
# listed in data files.
entities = tuple(sorted(entities))

if entities:
sub_state[rasa.shared.nlu.constants.ENTITIES] = entities
else:
Expand Down Expand Up @@ -1180,16 +1194,17 @@ def _clean_state(state: State) -> State:
if sub_state
}

def get_active_states(
def get_active_state(
self, tracker: "DialogueStateTracker", omit_unset_slots: bool = False,
) -> State:
"""Returns a bag of active states from the tracker state.
"""Given a dialogue tracker, makes a representation of current dialogue state.

Args:
tracker: dialog state tracker containing the dialog so far
omit_unset_slots: If `True` do not include the initial values of slots.

Returns `State` containing all active states.
Returns:
A representation of the dialogue's current state.
"""
state = {
rasa.shared.core.constants.USER: self._get_user_sub_state(tracker),
Expand Down Expand Up @@ -1286,7 +1301,7 @@ def states_for_tracker_history(
if turn_was_hidden:
continue

state = self.get_active_states(tr, omit_unset_slots=omit_unset_slots)
state = self.get_active_state(tr, omit_unset_slots=omit_unset_slots)

if ignore_rule_only_turns:
# clean state from only rule features
Expand Down
2 changes: 1 addition & 1 deletion rasa/shared/core/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _append_current_state(self) -> None:
if self._states_for_hashing is None:
self._states_for_hashing = self.past_states_for_hashing(self.domain)
else:
state = self.domain.get_active_states(self)
state = self.domain.get_active_state(self)
frozen_state = self.freeze_current_state(state)
self._states_for_hashing.append(frozen_state)

Expand Down
45 changes: 45 additions & 0 deletions tests/shared/core/test_domain.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import json
from pathlib import Path
import random
from typing import Dict, List, Text, Any, Union, Set, Optional

import pytest
Expand Down Expand Up @@ -35,6 +36,8 @@
Domain,
KEY_FORMS,
KEY_E2E_ACTIONS,
KEY_INTENTS,
KEY_ENTITIES,
)
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.events import ActionExecuted, SlotSet, UserUttered
Expand Down Expand Up @@ -1110,6 +1113,48 @@ def test_get_featurized_entities():
assert featurized_entities == {"GPE", f"GPE{ENTITY_LABEL_SEPARATOR}destination"}


def test_featurized_entities_ordered_consistently():
"""Check that entities get ordered -- needed for consistent state representations.

Previously, no ordering was applied to entities, but they were ordered implicitly
due to how python sets work -- a set of all entity names was internally created,
which was ordered by the hashes of the entity names. Now, entities are sorted alpha-
betically. Since even sorting based on randomised hashing can produce alphabetical
ordering once in a while, we here check with a large number of entities, pushing to
~0 the probability of correctly sorting the elements just by accident, without
actually doing proper sorting.
"""
# Create a sorted list of entity names from 'a' to 'z', and two randomly shuffled
# copies.
entity_names_sorted = [chr(i) for i in range(ord("a"), ord("z") + 1)]
entity_names_shuffled1 = entity_names_sorted.copy()
random.shuffle(entity_names_shuffled1)
entity_names_shuffled2 = entity_names_sorted.copy()
random.shuffle(entity_names_shuffled2)

domain = Domain.from_dict(
{KEY_INTENTS: ["inform"], KEY_ENTITIES: entity_names_shuffled1}
)

tracker = DialogueStateTracker.from_events(
"story123",
[
UserUttered(
text="hey there",
intent={"name": "inform", "confidence": 1.0},
entities=[
{"entity": e, "value": e.upper()} for e in entity_names_shuffled2
],
)
],
)
state = domain.get_active_state(tracker)

# Whatever order the entities were listed in, they should get sorted alphabetically
# so the states' representations are consistent and entity-order-agnostic.
assert state["user"]["entities"] == tuple(entity_names_sorted)


@pytest.mark.parametrize(
"domain_as_dict",
[
Expand Down