diff --git a/test_rasa_export.py b/test_rasa_export.py index 2a5e0380a919..c7a2fafef278 100644 --- a/test_rasa_export.py +++ b/test_rasa_export.py @@ -6,6 +6,7 @@ import pytest from _pytest.monkeypatch import MonkeyPatch from _pytest.pytester import RunResult +from ruamel.yaml.scalarstring import SingleQuotedScalarString import rasa.core.utils as rasa_core_utils from rasa.cli import export @@ -65,7 +66,14 @@ def test_validate_timestamp_options_with_invalid_timestamps(): def test_get_event_broker_and_tracker_store_from_endpoint_config(tmp_path: Path): # write valid config to file endpoints_path = write_endpoint_config_to_yaml( - tmp_path, {"event_broker": {"type": "sql"}, "tracker_store": {"type": "sql"}} + tmp_path, + { + "event_broker": { + "type": "sql", + "db": str(tmp_path / "rasa.db").replace("\\", "\\\\"), + }, + "tracker_store": {"type": "sql"}, + }, ) available_endpoints = rasa_core_utils.read_endpoints_from_path(endpoints_path) diff --git a/test_rasa_test.py b/test_rasa_test.py index 394475806f7e..5884d6a94246 100644 --- a/test_rasa_test.py +++ b/test_rasa_test.py @@ -21,6 +21,14 @@ def test_test_core_no_plot(run_in_simple_project: Callable[..., RunResult]): def test_test(run_in_simple_project_with_model: Callable[..., RunResult]): + write_yaml( + { + "pipeline": "KeywordIntentClassifier", + "policies": [{"name": "MemoizationPolicy"}], + }, + "config2.yml", + ) + run_in_simple_project_with_model("test") assert os.path.exists("results") @@ -61,14 +69,15 @@ def test_test_nlu_cross_validation(run_in_simple_project: Callable[..., RunResul def test_test_nlu_comparison(run_in_simple_project: Callable[..., RunResult]): - copyfile("config.yml", "config-1.yml") + write_yaml({"pipeline": "KeywordIntentClassifier"}, "config.yml") + write_yaml({"pipeline": "KeywordIntentClassifier"}, "config2.yml") run_in_simple_project( "test", "nlu", "--config", "config.yml", - "config-1.yml", + "config2.yml", "--run", "2", "--percentages", @@ -123,8 +132,6 @@ def test_test_core_comparison_after_train( "--percentages", "25", "75", - "--augmentation", - "5", "--out", "comparison_models", ) diff --git a/test_utils.py b/test_utils.py index 6257da93f10f..239c27e6b1fa 100644 --- a/test_utils.py +++ b/test_utils.py @@ -72,30 +72,29 @@ def test_validate_invalid_path(): get_validated_path("test test test", "out", "default") -def test_validate_valid_path(): - tempdir = tempfile.mkdtemp() - - assert get_validated_path(tempdir, "out", "default") == tempdir +def test_validate_valid_path(tmp_path: pathlib.Path): + assert get_validated_path(str(tmp_path), "out", "default") == str(tmp_path) def test_validate_if_none_is_valid(): assert get_validated_path(None, "out", "default", True) is None -def test_validate_with_none_if_default_is_valid(caplog: LogCaptureFixture): - tempdir = tempfile.mkdtemp() - +def test_validate_with_none_if_default_is_valid( + caplog: LogCaptureFixture, tmp_path: pathlib.Path +): with caplog.at_level(logging.WARNING, rasa.cli.utils.logger.name): - assert get_validated_path(None, "out", tempdir) == tempdir + assert get_validated_path(None, "out", str(tmp_path)) == str(tmp_path) assert caplog.records == [] -def test_validate_with_invalid_directory_if_default_is_valid(): - tempdir = tempfile.mkdtemp() +def test_validate_with_invalid_directory_if_default_is_valid(tmp_path: pathlib.Path): invalid_directory = "gcfhvjkb" with pytest.warns(UserWarning) as record: - assert get_validated_path(invalid_directory, "out", tempdir) == tempdir + assert get_validated_path(invalid_directory, "out", str(tmp_path)) == str( + tmp_path + ) assert len(record) == 1 assert "does not seem to exist" in record[0].message.args[0]