From b365fb321d1b487e035fd338bd17c933f6d422f1 Mon Sep 17 00:00:00 2001 From: Joseph Juzl Date: Tue, 1 Dec 2020 16:34:49 +0100 Subject: [PATCH] Add functionality to check if a model is fine-tunable --- rasa/model.py | 52 +++++++++++++++++++++++ tests/core/test_model.py | 90 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+) diff --git a/rasa/model.py b/rasa/model.py index 5f09ad29cb2f..fd9061518dcf 100644 --- a/rasa/model.py +++ b/rasa/model.py @@ -10,6 +10,9 @@ from pathlib import Path from typing import Any, Text, Tuple, Union, Optional, List, Dict, NamedTuple +from packaging import version + +from rasa.constants import MINIMUM_COMPATIBLE_VERSION import rasa.shared.utils.io import rasa.utils.io from rasa.cli.utils import create_output_path @@ -42,6 +45,7 @@ FINGERPRINT_CONFIG_KEY = "config" FINGERPRINT_CONFIG_CORE_KEY = "core-config" FINGERPRINT_CONFIG_NLU_KEY = "nlu-config" +FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY = "config-without-epochs" FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY = "domain" FINGERPRINT_NLG_KEY = "nlg" FINGERPRINT_RASA_VERSION_KEY = "version" @@ -80,6 +84,14 @@ class Section(NamedTuple): ) SECTION_NLG = Section(name="NLG templates", relevant_keys=[FINGERPRINT_NLG_KEY]) +SECTION_FINE_TUNE = Section( + name="Fine-tune", + relevant_keys=[ + FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY, + FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY, + ], +) + class FingerprintComparisonResult: def __init__( @@ -327,6 +339,9 @@ async def model_fingerprint(file_importer: "TrainingDataImporter") -> Fingerprin FINGERPRINT_CONFIG_NLU_KEY: _get_fingerprint_of_config( config, include_keys=CONFIG_KEYS_NLU ), + FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY: _get_fingerprint_of_config_without_epochs( + config + ), FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY: domain.fingerprint(), FINGERPRINT_NLG_KEY: rasa.shared.utils.io.deep_container_fingerprint(responses), FINGERPRINT_PROJECT: project_fingerprint(), @@ -352,6 +367,23 @@ def _get_fingerprint_of_config( return rasa.shared.utils.io.deep_container_fingerprint(sub_config) +# TODO: either generalise with _get_fingerprint_of_config, or make nicer +def _get_fingerprint_of_config_without_epochs( + config: Optional[Dict[Text, Any]], +) -> Text: + if not config: + return "" + + copied_config = copy.deepcopy(config) + + for key in ["pipeline", "policies"]: + for p in copied_config[key]: + if "epochs" in p: + del p["epochs"] + + return rasa.shared.utils.io.deep_container_fingerprint(copied_config) + + def fingerprint_from_path(model_path: Text) -> Fingerprint: """Load a persisted fingerprint. @@ -468,6 +500,26 @@ def should_retrain( return fingerprint_comparison +def can_fine_tune(last_fingerprint: Fingerprint, new_fingerprint: Fingerprint) -> bool: + """Check which components of a model should be retrained. + + Args: + last_fingerprint: The fingerprint of the old model to potentially be fine-tuned. + new_fingerprint: The fingerprint of the new model. + + Returns: + `True` if the old model can be fine-tuned, `False` otherwise. + """ + fingerprint_changed = did_section_fingerprint_change( + last_fingerprint, new_fingerprint, SECTION_FINE_TUNE + ) + + old_model_above_min_version = version.parse( + last_fingerprint.get(FINGERPRINT_RASA_VERSION_KEY) + ) >= version.parse(MINIMUM_COMPATIBLE_VERSION) + return old_model_above_min_version and not fingerprint_changed + + def package_model( fingerprint: Fingerprint, output_directory: Text, diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 20092001e8a8..b87333861216 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -9,6 +9,8 @@ import pytest +import rasa +import rasa.constants from rasa.shared.importers.importer import TrainingDataImporter from rasa.shared.importers.rasa import RasaFileImporter from rasa.shared.constants import ( @@ -22,6 +24,7 @@ from rasa import model from rasa.model import ( FINGERPRINT_CONFIG_KEY, + FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY, FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY, FINGERPRINT_NLG_KEY, FINGERPRINT_FILE_PATH, @@ -32,7 +35,9 @@ FINGERPRINT_CONFIG_CORE_KEY, FINGERPRINT_CONFIG_NLU_KEY, SECTION_CORE, + SECTION_FINE_TUNE, SECTION_NLU, + can_fine_tune, create_package_rasa, get_latest_model, get_model, @@ -102,6 +107,7 @@ def _fingerprint( config: Optional[Any] = None, config_nlu: Optional[Any] = None, config_core: Optional[Any] = None, + config_without_epochs: Optional[Any] = None, domain: Optional[Any] = None, nlg: Optional[Any] = None, stories: Optional[Any] = None, @@ -114,6 +120,9 @@ def _fingerprint( if config_core is not None else ["test"], FINGERPRINT_CONFIG_NLU_KEY: config_nlu if config_nlu is not None else ["test"], + FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY: config_without_epochs + if config_without_epochs + else ["test"], FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY: domain if domain is not None else ["test"], FINGERPRINT_NLG_KEY: nlg if nlg is not None else ["test"], FINGERPRINT_TRAINED_AT_KEY: time.time(), @@ -148,6 +157,7 @@ def test_persist_and_load_fingerprint(): (_fingerprint(nlg=["other"]), False), (_fingerprint(nlu=["test", "other"]), False), (_fingerprint(config_nlu=["other"]), False), + (_fingerprint(config_without_epochs=["other"]), False), ], ) def test_core_fingerprint_changed(fingerprint2, changed): @@ -168,6 +178,7 @@ def test_core_fingerprint_changed(fingerprint2, changed): (_fingerprint(nlg=["other"]), False), (_fingerprint(config_core=["other"]), False), (_fingerprint(stories=["other"]), False), + (_fingerprint(config_without_epochs=["other"]), False), ], ) def test_nlu_fingerprint_changed(fingerprint2, changed): @@ -178,6 +189,28 @@ def test_nlu_fingerprint_changed(fingerprint2, changed): ) +@pytest.mark.parametrize( + "fingerprint2, changed", + [ + (_fingerprint(config=["other"]), True), + (_fingerprint(config_without_epochs=["other"]), True), + (_fingerprint(domain=["other"]), True), + (_fingerprint(rasa_version="100"), False), + (_fingerprint(config_core=["other"]), False), + (_fingerprint(config_nlu=["other"]), False), + (_fingerprint(nlu=["other"]), False), + (_fingerprint(nlg=["other"]), False), + (_fingerprint(stories=["other"]), False), + ], +) +def test_fine_tune_fingerprint_changed(fingerprint2, changed): + fingerprint1 = _fingerprint() + assert ( + did_section_fingerprint_change(fingerprint1, fingerprint2, SECTION_FINE_TUNE) + is changed + ) + + def _project_files( project: Text, config_file: Text = DEFAULT_CONFIG_PATH, @@ -233,6 +266,44 @@ async def test_fingerprinting_changed_response_text(project: Text): assert old_fingerprint[FINGERPRINT_NLG_KEY] != new_fingerprint[FINGERPRINT_NLG_KEY] +async def test_fingerprinting_changing_config_epochs(project: Text): + importer = _project_files(project) + + old_fingerprint = await model_fingerprint(importer) + config = await importer.get_config() + + for key in ["pipeline", "policies"]: + for p in config[key]: + if "epochs" in p: + p["epochs"] += 10 + + importer.get_config = asyncio.coroutine(lambda: config) + new_fingerprint = await model_fingerprint(importer) + + assert ( + old_fingerprint[FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY] + == new_fingerprint[FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY] + ) + assert ( + old_fingerprint[FINGERPRINT_CONFIG_CORE_KEY] + != new_fingerprint[FINGERPRINT_CONFIG_CORE_KEY] + ) + assert ( + old_fingerprint[FINGERPRINT_CONFIG_NLU_KEY] + != new_fingerprint[FINGERPRINT_CONFIG_NLU_KEY] + ) + + config["pipeline"].pop() + + importer.get_config = asyncio.coroutine(lambda: config) + new_fingerprint = await model_fingerprint(importer) + + assert ( + old_fingerprint[FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY] + != new_fingerprint[FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY] + ) + + async def test_fingerprinting_additional_action(project: Text): importer = _project_files(project) @@ -418,3 +489,22 @@ async def get_domain() -> Domain: actual = Domain.load(tmpdir / DEFAULT_CORE_SUBDIRECTORY_NAME / DEFAULT_DOMAIN_PATH) assert actual.is_empty() + + +@pytest.mark.parametrize( + "min_compatible_version, old_model_version, can_tune", + [("2.1.0", "2.1.0", True), ("2.0.0", "2.1.0", True), ("2.1.0", "2.0.0", False),], +) +async def test_can_fine_tune_min_version( + project: Text, monkeypatch, old_model_version, min_compatible_version, can_tune +): + importer = _project_files(project) + + monkeypatch.setattr( + rasa.constants, "MINIMUM_COMPATIBLE_VERSION", min_compatible_version + ) + monkeypatch.setattr(rasa, "__version__", old_model_version) + old_fingerprint = await model_fingerprint(importer) + new_fingerprint = await model_fingerprint(importer) + + assert can_fine_tune(old_fingerprint, new_fingerprint) == can_tune