Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make rasa data validate check for duplicated intents, forms, responses and slots when using domains split between multiple files #10444

Merged
merged 17 commits into from
Dec 14, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog/10444.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Make `rasa data validate` check for duplicated intents, forms, responses
and slots when using domains split between multiple files.
38 changes: 38 additions & 0 deletions data/test_domains/test_domain_with_duplicates/domain1.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
version: "3.0"

intents:
- greet
- goodbye
- affirm
- deny
- mood_great
- mood_unhappy
- bot_challenge

responses:
utter_greet:
- text: "Hey! How are you?"

utter_cheer_up:
- text: "Here is something to cheer you up:"
image: "https://i.imgur.com/nGF1K8f.jpg"

utter_did_that_help:
- text: "Did that help you?"

utter_goodbye:
- text: "Bye"

utter_iamabot:
- text: "I am a bot, powered by Rasa."

slots:
mood:
type: bool
mappings:
- type: from_entity
entity: some_slot

session_config:
session_expiration_time: 60
carry_over_slots_to_new_session: true
22 changes: 22 additions & 0 deletions data/test_domains/test_domain_with_duplicates/domain2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
version: "3.0"

intents:
- greet
- test

responses:
utter_greet:
- text: "Hey! How are you?"

utter_did_that_help:
- text: "Did that help you?"

utter_happy:
- text: "Great, carry on!"

slots:
mood:
type: bool
mappings:
- type: from_entity
entity: some_slot
alwx marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions rasa/cli/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def validate_stories(args: argparse.Namespace) -> None:
def _validate_domain(validator: "Validator") -> bool:
return (
validator.verify_domain_validity()
and validator.verify_domain_duplicates()
and validator.verify_actions_in_stories_rules()
and validator.verify_forms_in_stories_rules()
and validator.verify_form_slots()
Expand Down
37 changes: 23 additions & 14 deletions rasa/shared/core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,15 @@ def from_yaml(cls, yaml: Text, original_filename: Text = "") -> "Domain":
raise e

@classmethod
def from_dict(cls, data: Dict) -> "Domain":
def from_dict(
cls, data: Dict, duplicates: Optional[Dict[Text, List[Text]]] = None
) -> "Domain":
"""Deserializes and creates domain.

Args:
data: The serialized domain.
duplicates: The dictionary of duplicated intents, slots, forms, etc when the
alwx marked this conversation as resolved.
Show resolved Hide resolved
domain is built from multiple files.

Returns:
The instantiated `Domain` object.
Expand All @@ -216,6 +220,7 @@ def from_dict(cls, data: Dict) -> "Domain":
data.get(KEY_FORMS, {}),
data.get(KEY_E2E_ACTIONS, []),
session_config=session_config,
duplicates=duplicates,
**additional_arguments,
)

Expand Down Expand Up @@ -280,15 +285,10 @@ def merge_dicts(
def merge_lists(list1: List[Any], list2: List[Any]) -> List[Any]:
return sorted(list(set(list1 + list2)))

def merge_lists_of_dicts(
dict_list1: List[Dict],
dict_list2: List[Dict],
override_existing_values: bool = False,
) -> List[Dict]:
dict1 = {list(i.keys())[0]: i for i in dict_list1}
dict2 = {list(i.keys())[0]: i for i in dict_list2}
merged_dicts = merge_dicts(dict1, dict2, override_existing_values)
return list(merged_dicts.values())
def extract_duplicates(
dict1: Dict[Text, Any], dict2: Dict[Text, Any],
) -> List[Text]:
return [value for value in dict1.keys() if value in dict2.keys()]

if override:
config = domain_dict["config"]
Expand All @@ -298,9 +298,12 @@ def merge_lists_of_dicts(
if override or self.session_config == SessionConfig.default():
combined[SESSION_CONFIG_KEY] = domain_dict[SESSION_CONFIG_KEY]

combined[KEY_INTENTS] = merge_lists_of_dicts(
combined[KEY_INTENTS], domain_dict[KEY_INTENTS], override
)
duplicates: Dict[Text, List[Text]] = {}
alwx marked this conversation as resolved.
Show resolved Hide resolved

dict1 = {list(i.keys())[0]: i for i in combined[KEY_INTENTS]}
dict2 = {list(i.keys())[0]: i for i in domain_dict[KEY_INTENTS]}
duplicates[KEY_INTENTS] = extract_duplicates(dict1, dict2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we consider just raising an error/warning here so that we don't need to store the duplicates on the class?

Copy link
Contributor Author

@alwx alwx Dec 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea was to show this information only when running rasa data validate, as it's written in the description of this issue. We can update it to show the warning every time the domain is getting merged but in this case I don't see the point in doing any changes to rasa data validate and even link this issue to the validator which means the whole task will be different.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe @TyDunn can take a look at it and say which approach we prefer.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the approach that @joejuzl describe happen when you run rasa data validate too?

Copy link
Contributor Author

@alwx alwx Dec 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TyDunn If we do it the way @joejuzl proposes then the warnings/exceptions will be raised every time the domain is getting merged (e.g. when just running rasa train or literally any other command that uses domain), not only when doing rasa data validate

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For context we do already have quite a lot of domain validation that happens just from loading it e.g. _check_domain_sanity in the __init__ looks for duplicates etc.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alwx I believe Joe is out for the rest of the year, and I am not going to be able to gain enough context atm. Can you make a judgement call here yourself?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TyDunn I think it makes sense to keep this one as it as and maybe create a new issue because what Joe was talking about makes the scope much bigger

combined[KEY_INTENTS] = list(merge_dicts(dict1, dict2, override).values())

# remove existing forms from new actions
for form in combined[KEY_FORMS]:
Expand All @@ -311,12 +314,14 @@ def merge_lists_of_dicts(
combined[key] = merge_lists(combined[key], domain_dict[key])

for key in [KEY_FORMS, KEY_RESPONSES, KEY_SLOTS]:
duplicates[key] = extract_duplicates(combined[key], domain_dict[key])
combined[key] = merge_dicts(combined[key], domain_dict[key], override)

return self.__class__.from_dict(combined)
return self.__class__.from_dict(combined, duplicates)

@staticmethod
def collect_slots(slot_dict: Dict[Text, Any]) -> List[Slot]:
"""Collects a list of slots from a dictionary."""
slots = []
# make a copy to not alter the input dictionary
slot_dict = copy.deepcopy(slot_dict)
Expand Down Expand Up @@ -572,6 +577,7 @@ def __init__(
action_texts: Optional[List[Text]] = None,
store_entities_as_slots: bool = True,
session_config: SessionConfig = SessionConfig.default(),
duplicates: Optional[Dict[Text, List[Text]]] = None,
) -> None:
"""Creates a `Domain`.

Expand All @@ -588,6 +594,8 @@ def __init__(
events for entities if there are slots with the same name as the entity.
session_config: Configuration for conversation sessions. Conversations are
restarted at the end of a session.
duplicates: The dictionary of duplicated intents, slots, forms, etc when the
alwx marked this conversation as resolved.
Show resolved Hide resolved
domain is built from multiple files.
"""
self.entities, self.roles, self.groups = self.collect_entity_properties(
entities
Expand All @@ -608,6 +616,7 @@ def __init__(

self.action_texts = action_texts or []
self.session_config = session_config
self.duplicates = duplicates

self._custom_actions = action_names

Expand Down
34 changes: 33 additions & 1 deletion rasa/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@
)
from rasa.shared.core import constants
from rasa.shared.core.constants import MAPPING_CONDITIONS, ACTIVE_LOOP
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import ActionExecuted, ActiveLoop
from rasa.shared.core.events import UserUttered
from rasa.shared.core.domain import (
KEY_INTENTS,
KEY_RESPONSES,
KEY_SLOTS,
KEY_FORMS,
Domain,
)
from rasa.shared.core.generator import TrainingDataGenerator
from rasa.shared.core.slot_mappings import SlotMapping
from rasa.shared.core.training_data.structures import StoryGraph
Expand Down Expand Up @@ -321,6 +327,32 @@ def verify_nlu(self, ignore_warnings: bool = True) -> bool:
stories_are_valid = self.verify_utterances_in_stories(ignore_warnings)
return intents_are_valid and stories_are_valid and there_is_no_duplication

def verify_domain_duplicates(self) -> bool:
"""Verifies that there are no duplicated dictionaries in multiple domain files.

Returns:
`True` if duplicates exist.
"""
logger.info("Checking duplicates across domain files...")

all_valid = True

if not self.domain.duplicates:
return True

for key in [KEY_INTENTS, KEY_FORMS, KEY_RESPONSES, KEY_SLOTS]:
duplicates = self.domain.duplicates.get(key)
if duplicates:
duplicates_str = ", ".join(duplicates)
rasa.shared.utils.io.raise_warning(
f"The following duplicated {key} has been found "
+ f"across multiple domain files: {duplicates_str}",
docs=DOCS_URL_DOMAINS,
)
all_valid = False

return all_valid

def verify_form_slots(self) -> bool:
"""Verifies that form slots match the slot mappings in domain."""
domain_slot_names = [slot.name for slot in self.domain.slots]
Expand Down
20 changes: 20 additions & 0 deletions tests/shared/core/test_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1611,6 +1611,26 @@ def test_domain_invalid_yml_in_folder():
Domain.from_directory("data/test_domains/test_domain_from_directory1/")


def test_domain_with_duplicates():
"""
Check if a domain with duplicated slots, responses and intents contains
a correct information in `duplicates` field.
"""
domain = Domain.from_directory("data/test_domains/test_domain_with_duplicates/")
assert domain.duplicates["slots"] == ["mood"]
assert domain.duplicates["responses"] == ["utter_greet", "utter_did_that_help"]
assert domain.duplicates["intents"] == ["greet"]


def test_domain_duplicates_when_one_domain_file():
alwx marked this conversation as resolved.
Show resolved Hide resolved
"""
Check if a domain with duplicated slots, responses and intents contains
a correct information in `duplicates` field.
"""
domain = Domain.from_file(path="data/test_domains/default.yml")
assert domain.duplicates is None


def test_domain_fingerprint_consistency_across_runs():
domain_yaml = """
version: "3.0"
Expand Down
76 changes: 75 additions & 1 deletion tests/test_validator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Text
from typing import Text, Any, Optional, List, Dict

import pytest
from _pytest.logging import LogCaptureFixture

from rasa.validator import Validator

from rasa.shared.importers.rasa import RasaFileImporter
from rasa.shared.importers.autoconfig import TrainingType
from rasa.shared.core.domain import Domain
from pathlib import Path


Expand Down Expand Up @@ -364,6 +366,78 @@ def test_verify_actions_in_rules_not_in_domain(tmp_path: Path, domain_path: Text
)


@pytest.mark.parametrize(
"duplicates,is_valid,warning_type,messages",
[
(None, True, None, []),
({}, True, None, []),
({"responses": []}, True, None, []),
(
{"responses": ["some_response"]},
False,
UserWarning,
[
"The following duplicated responses has been found across "
"multiple domain files: some_response"
],
),
(
{"slots": ["some_slot"]},
False,
UserWarning,
[
"The following duplicated slots has been found across "
"multiple domain files: some_slot"
],
),
(
{"forms": ["form1", "form2"]},
False,
UserWarning,
[
"The following duplicated forms has been found across "
"multiple domain files: form1, form2"
],
),
(
{"forms": ["form1", "form2"], "slots": []},
False,
UserWarning,
[
"The following duplicated forms has been found across "
"multiple domain files: form1, form2"
],
),
(
{"forms": ["form1", "form2"], "slots": ["slot1", "slot2", "slot3"]},
False,
UserWarning,
[
"The following duplicated forms has been found across "
"multiple domain files: form1, form2",
"The following duplicated slots has been found across "
"multiple domain files: slot1, slot2, slot3",
],
),
],
)
def test_verify_domain_with_duplicates(
duplicates: Optional[Dict[Text, List[Text]]],
is_valid: bool,
warning_type: Any,
messages: List[Text],
):
domain = Domain([], [], [], {}, [], {}, duplicates=duplicates)
validator = Validator(domain, None, None, None)

with pytest.warns(warning_type) as warning:
assert validator.verify_domain_duplicates() is is_valid

assert len(warning) == len(messages)
for i in range(len(messages)):
assert messages[i] in warning[i].message.args[0]


def test_verify_form_slots_invalid_domain(tmp_path: Path):
domain = tmp_path / "domain.yml"
domain.write_text(
Expand Down