Skip to content

Commit

Permalink
Merge pull request #8388 from RasaHQ/rasa_test_missing_intent_warning
Browse files Browse the repository at this point in the history
Fix for missing intent warnings when running rasa test
  • Loading branch information
b-quachtran authored May 18, 2021
2 parents 6b10890 + 51736f7 commit f1d119d
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 3 deletions.
1 change: 1 addition & 0 deletions changelog/8388.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed bug where missing intent warnings appear when running `rasa test`
3 changes: 2 additions & 1 deletion rasa/cli/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
DEFAULT_MODELS_PATH,
DEFAULT_DATA_PATH,
DEFAULT_RESULTS_PATH,
DEFAULT_DOMAIN_PATH,
)
import rasa.shared.utils.validation as validation_utils
import rasa.cli.utils
Expand Down Expand Up @@ -158,7 +159,7 @@ async def run_nlu_test_async(

data_path = rasa.cli.utils.get_validated_path(data_path, "nlu", DEFAULT_DATA_PATH)
test_data_importer = TrainingDataImporter.load_from_dict(
training_data_paths=[data_path]
training_data_paths=[data_path], domain_path=DEFAULT_DOMAIN_PATH,
)
nlu_data = await test_data_importer.get_nlu_data()

Expand Down
13 changes: 12 additions & 1 deletion rasa/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,20 @@ async def _create_data_generator(
use_conversation_test_files: bool = False,
) -> "TrainingDataGenerator":
from rasa.shared.core.generator import TrainingDataGenerator
from rasa.shared.constants import DEFAULT_DOMAIN_PATH
from rasa.model import get_model_subdirectories

core_model = None
if agent.model_directory:
core_model, _ = get_model_subdirectories(agent.model_directory)

if core_model and os.path.exists(os.path.join(core_model, DEFAULT_DOMAIN_PATH)):
domain_path = os.path.join(core_model, DEFAULT_DOMAIN_PATH)
else:
domain_path = None

test_data_importer = TrainingDataImporter.load_from_dict(
training_data_paths=[resource_name]
training_data_paths=[resource_name], domain_path=domain_path
)
if use_conversation_test_files:
story_graph = await test_data_importer.get_conversation_tests()
Expand Down
3 changes: 2 additions & 1 deletion rasa/nlu/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1439,13 +1439,14 @@ async def run_evaluation(
Returns: dictionary containing evaluation results
"""
import rasa.shared.nlu.training_data.loading
from rasa.shared.constants import DEFAULT_DOMAIN_PATH

# get the metadata config from the package data
interpreter = Interpreter.load(model_path, component_builder)

interpreter.pipeline = remove_pretrained_extractors(interpreter.pipeline)
test_data_importer = TrainingDataImporter.load_from_dict(
training_data_paths=[data_path]
training_data_paths=[data_path], domain_path=DEFAULT_DOMAIN_PATH,
)
test_data = await test_data_importer.get_nlu_data()

Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ async def _trained_default_agent(
agent = Agent(
"data/test_domains/default_with_slots.yml",
policies=[AugmentedMemoizationPolicy(max_history=3), RulePolicy()],
model_directory=model_path,
)

training_data = await agent.load_data(stories_path)
Expand Down
16 changes: 16 additions & 0 deletions tests/core/test_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pathlib import Path
from typing import Text

import pytest

import rasa.core.test
from _pytest.capture import CaptureFixture
Expand All @@ -17,3 +20,16 @@ async def test_testing_warns_if_action_unknown(
assert "Test story" in output
assert "contains the bot utterance" in output
assert "which is not part of the training data / domain" in output


async def test_testing_does_not_warn_if_intent_in_domain(
default_agent: Agent, stories_path: Text,
):
with pytest.warns(UserWarning) as record:
await rasa.core.test.test(Path(stories_path), default_agent)

assert not any("Found intent" in r.message.args[0] for r in record)
assert all(
"in stories which is not part of the domain" not in r.message.args[0]
for r in record
)

0 comments on commit f1d119d

Please sign in to comment.