Skip to content

Commit

Permalink
Merge pull request #8162 from RasaHQ/train-model-tests
Browse files Browse the repository at this point in the history
Improve tests which train model
  • Loading branch information
alwx authored Mar 25, 2021
2 parents 8e5ec4d + ae76458 commit d6ee251
Show file tree
Hide file tree
Showing 15 changed files with 284 additions and 305 deletions.
1 change: 1 addition & 0 deletions changelog/8117.misc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve tests which train models. This is supposed to speed up the CI as model trainings take quite some time.
1 change: 0 additions & 1 deletion data/test_config/no_max_hist_config.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
policies:
- name: MemoizationPolicy
- name: RulePolicy
- name: TEDPolicy
6 changes: 0 additions & 6 deletions data/test_config/ted_random_seed.yaml

This file was deleted.

1 change: 0 additions & 1 deletion data/test_selectors/nlu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ nlu:
- can I pay with credit card
- book flight
- Deliver a pizza
- How long will it take a letter to get from [new milford]{"entity": "city", "role": "from"} pa to [huntingdon]{"entity": "city", "role": "to"} pa
- how to implement faq
- What would you suggest where to go for a romance places
- Can you fyling him here please i love him so much
Expand Down
130 changes: 0 additions & 130 deletions rasa/core/restore.py

This file was deleted.

19 changes: 16 additions & 3 deletions tests/cli/test_rasa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def test_test_core_comparison(
def test_test_core_comparison_after_train(
run_in_simple_project: Callable[..., RunResult]
):
temp_dir = os.getcwd()

write_yaml(
{"language": "en", "policies": [{"name": "MemoizationPolicy"}]}, "config_1.yml"
)
Expand All @@ -175,9 +177,20 @@ def test_test_core_comparison_after_train(
"comparison_models",
)

assert os.path.exists("comparison_models")
assert os.path.exists("comparison_models/run_1")
assert os.path.exists("comparison_models/run_2")
import rasa.shared.utils.io

assert os.path.exists(os.path.join(temp_dir, "comparison_models"))
assert os.path.exists(os.path.join(temp_dir, "comparison_models", "run_1"))
assert os.path.exists(os.path.join(temp_dir, "comparison_models", "run_2"))
run_directories = rasa.shared.utils.io.list_subdirectories(
os.path.join(temp_dir, "comparison_models")
)
assert len(run_directories) == 2
model_files = rasa.shared.utils.io.list_files(
os.path.join(temp_dir, "comparison_models", run_directories[0])
)
assert len(model_files) == 4
assert model_files[0].endswith("tar.gz")

run_in_simple_project(
"test",
Expand Down
44 changes: 0 additions & 44 deletions tests/cli/test_rasa_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,50 +100,6 @@ def test_train_persist_nlu_data(run_in_simple_project: Callable[..., RunResult])
)


def test_train_core_compare(run_in_simple_project: Callable[..., RunResult]):
temp_dir = os.getcwd()

rasa.shared.utils.io.write_yaml(
{"language": "en", "policies": [{"name": "MemoizationPolicy"}],},
"config_1.yml",
)

rasa.shared.utils.io.write_yaml(
{"language": "en", "policies": [{"name": "MemoizationPolicy"}],},
"config_2.yml",
)

run_in_simple_project(
"train",
"core",
"-c",
"config_1.yml",
"config_2.yml",
"--stories",
"data/stories.yml",
"--out",
"core_comparison_results",
"--runs",
"2",
"--percentages",
"25",
"75",
"--augmentation",
"5",
)

assert os.path.exists(os.path.join(temp_dir, "core_comparison_results"))
run_directories = rasa.shared.utils.io.list_subdirectories(
os.path.join(temp_dir, "core_comparison_results")
)
assert len(run_directories) == 2
model_files = rasa.shared.utils.io.list_files(
os.path.join(temp_dir, "core_comparison_results", run_directories[0])
)
assert len(model_files) == 4
assert model_files[0].endswith("tar.gz")


def test_train_no_domain_exists(
run_in_simple_project: Callable[..., RunResult]
) -> None:
Expand Down
23 changes: 21 additions & 2 deletions tests/core/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from rasa.core.channels.channel import UserMessage
from rasa.shared.core.domain import InvalidDomain, Domain
from rasa.shared.constants import INTENT_MESSAGE_PREFIX
from rasa.core.policies.ensemble import PolicyEnsemble
from rasa.core.policies.memoization import AugmentedMemoizationPolicy
from rasa.core.policies.ensemble import PolicyEnsemble, SimplePolicyEnsemble
from rasa.core.policies.memoization import AugmentedMemoizationPolicy, MemoizationPolicy
from rasa.utils.endpoints import EndpointConfig


Expand Down Expand Up @@ -72,6 +72,25 @@ async def test_training_data_is_reproducible():
assert str(x.as_dialogue()) == str(same_training_data[i].as_dialogue())


async def test_agent_train(trained_rasa_model: Text):
domain = Domain.load("data/test_domains/default_with_slots.yml")
loaded = Agent.load(trained_rasa_model)

# test domain
assert loaded.domain.action_names_or_texts == domain.action_names_or_texts
assert loaded.domain.intents == domain.intents
assert loaded.domain.entities == domain.entities
assert loaded.domain.templates == domain.templates
assert [s.name for s in loaded.domain.slots] == [s.name for s in domain.slots]

# test policies
assert isinstance(loaded.policy_ensemble, SimplePolicyEnsemble)
assert [type(p) for p in loaded.policy_ensemble.policies] == [
MemoizationPolicy,
RulePolicy,
]


@pytest.mark.parametrize(
"text_message_data, expected",
[
Expand Down
31 changes: 0 additions & 31 deletions tests/core/test_restore.py

This file was deleted.

Loading

0 comments on commit d6ee251

Please sign in to comment.