Skip to content

Commit

Permalink
Merge pull request #7200 from RasaHQ/fix_all_retrieval_intents
Browse files Browse the repository at this point in the history
Fix `all_retrieval_intents` key in response selector output
  • Loading branch information
rasabot authored Nov 9, 2020
2 parents 7275f33 + e9e3ab7 commit 5fd4f71
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
1 change: 1 addition & 0 deletions changelog/7200.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a bug that caused only one retrieval intent to be present in the `all_retrieval_intents` key of the `ResponseSelector` output, even when the training data contained multiple retrieval intents.
6 changes: 5 additions & 1 deletion rasa/nlu/selectors/response_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,12 @@ def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData:
"""Prepares data for training.
Performs sanity checks on training data, extracts encodings for labels.
Args:
training_data: training data to preprocessed.
"""
# Collect all retrieval intents present in the data before filtering
self.all_retrieval_intents = list(training_data.retrieval_intents)

if self.retrieval_intent:
training_data = training_data.filter_training_examples(
Expand All @@ -321,7 +326,6 @@ def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData:
)

self.responses = training_data.responses
self.all_retrieval_intents = list(training_data.retrieval_intents)

if not label_id_index_mapping:
# no labels are present to train
Expand Down
30 changes: 30 additions & 0 deletions tests/nlu/selectors/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
CHECKPOINT_MODEL,
)
from rasa.nlu.selectors.response_selector import ResponseSelector
from rasa.shared.nlu.training_data.message import Message
from rasa.shared.nlu.training_data.training_data import TrainingData


@pytest.mark.parametrize(
Expand Down Expand Up @@ -95,6 +97,34 @@ def test_train_selector(pipeline, component_builder, tmpdir):
assert rank.get("intent_response_key") is not None


def test_preprocess_selector_multiple_retrieval_intents():
    """Check that preprocessing records every retrieval intent in the data.

    The training data is assembled from the two bundled demo files plus two
    hand-built messages whose intents are ``faq/q1`` and ``faq/q2``. After
    ``preprocess_train_data`` runs on a default ``ResponseSelector``, its
    ``all_retrieval_intents`` attribute must list both ``chitchat`` and
    ``faq``.
    """
    load_data = rasa.shared.nlu.training_data.loading.load_data

    # Start from data that already ships with the repository.
    base_data = load_data("data/examples/rasa/demo-rasa.md")
    response_data = load_data("data/examples/rasa/demo-rasa-responses.md")

    # Add examples for a second retrieval intent on top of the demo data.
    faq_messages = [
        Message.build(text="Is it possible to detect the version?", intent="faq/q1"),
        Message.build(text="How can I get a new virtual env", intent="faq/q2"),
    ]
    faq_data = TrainingData(faq_messages)

    merged_data = base_data.merge(response_data)
    merged_data = merged_data.merge(faq_data)

    selector = ResponseSelector()
    selector.preprocess_train_data(merged_data)

    # Sort before comparing so the check does not depend on list order.
    assert sorted(selector.all_retrieval_intents) == ["chitchat", "faq"]


@pytest.mark.parametrize(
"use_text_as_label, label_values",
[
Expand Down

0 comments on commit 5fd4f71

Please sign in to comment.