Skip to content

Commit

Permalink
Add functionality to check if a model is fine-tunable
Browse files Browse the repository at this point in the history
  • Loading branch information
joejuzl committed Dec 1, 2020
1 parent 5d5f41a commit b365fb3
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 0 deletions.
52 changes: 52 additions & 0 deletions rasa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from pathlib import Path
from typing import Any, Text, Tuple, Union, Optional, List, Dict, NamedTuple

from packaging import version

from rasa.constants import MINIMUM_COMPATIBLE_VERSION
import rasa.shared.utils.io
import rasa.utils.io
from rasa.cli.utils import create_output_path
Expand Down Expand Up @@ -42,6 +45,7 @@
FINGERPRINT_CONFIG_KEY = "config"
FINGERPRINT_CONFIG_CORE_KEY = "core-config"
FINGERPRINT_CONFIG_NLU_KEY = "nlu-config"
FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY = "config-without-epochs"
FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY = "domain"
FINGERPRINT_NLG_KEY = "nlg"
FINGERPRINT_RASA_VERSION_KEY = "version"
Expand Down Expand Up @@ -80,6 +84,14 @@ class Section(NamedTuple):
)
SECTION_NLG = Section(name="NLG templates", relevant_keys=[FINGERPRINT_NLG_KEY])

SECTION_FINE_TUNE = Section(
name="Fine-tune",
relevant_keys=[
FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY,
FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY,
],
)


class FingerprintComparisonResult:
def __init__(
Expand Down Expand Up @@ -327,6 +339,9 @@ async def model_fingerprint(file_importer: "TrainingDataImporter") -> Fingerprin
FINGERPRINT_CONFIG_NLU_KEY: _get_fingerprint_of_config(
config, include_keys=CONFIG_KEYS_NLU
),
FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY: _get_fingerprint_of_config_without_epochs(
config
),
FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY: domain.fingerprint(),
FINGERPRINT_NLG_KEY: rasa.shared.utils.io.deep_container_fingerprint(responses),
FINGERPRINT_PROJECT: project_fingerprint(),
Expand All @@ -352,6 +367,23 @@ def _get_fingerprint_of_config(
return rasa.shared.utils.io.deep_container_fingerprint(sub_config)


# TODO: either generalise with _get_fingerprint_of_config, or make nicer
def _get_fingerprint_of_config_without_epochs(
config: Optional[Dict[Text, Any]],
) -> Text:
if not config:
return ""

copied_config = copy.deepcopy(config)

for key in ["pipeline", "policies"]:
for p in copied_config[key]:
if "epochs" in p:
del p["epochs"]

return rasa.shared.utils.io.deep_container_fingerprint(copied_config)


def fingerprint_from_path(model_path: Text) -> Fingerprint:
"""Load a persisted fingerprint.
Expand Down Expand Up @@ -468,6 +500,26 @@ def should_retrain(
return fingerprint_comparison


def can_fine_tune(last_fingerprint: Fingerprint, new_fingerprint: Fingerprint) -> bool:
"""Check which components of a model should be retrained.
Args:
last_fingerprint: The fingerprint of the old model to potentially be fine-tuned.
new_fingerprint: The fingerprint of the new model.
Returns:
`True` if the old model can be fine-tuned, `False` otherwise.
"""
fingerprint_changed = did_section_fingerprint_change(
last_fingerprint, new_fingerprint, SECTION_FINE_TUNE
)

old_model_above_min_version = version.parse(
last_fingerprint.get(FINGERPRINT_RASA_VERSION_KEY)
) >= version.parse(MINIMUM_COMPATIBLE_VERSION)
return old_model_above_min_version and not fingerprint_changed


def package_model(
fingerprint: Fingerprint,
output_directory: Text,
Expand Down
90 changes: 90 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import pytest

import rasa
import rasa.constants
from rasa.shared.importers.importer import TrainingDataImporter
from rasa.shared.importers.rasa import RasaFileImporter
from rasa.shared.constants import (
Expand All @@ -22,6 +24,7 @@
from rasa import model
from rasa.model import (
FINGERPRINT_CONFIG_KEY,
FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY,
FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY,
FINGERPRINT_NLG_KEY,
FINGERPRINT_FILE_PATH,
Expand All @@ -32,7 +35,9 @@
FINGERPRINT_CONFIG_CORE_KEY,
FINGERPRINT_CONFIG_NLU_KEY,
SECTION_CORE,
SECTION_FINE_TUNE,
SECTION_NLU,
can_fine_tune,
create_package_rasa,
get_latest_model,
get_model,
Expand Down Expand Up @@ -102,6 +107,7 @@ def _fingerprint(
config: Optional[Any] = None,
config_nlu: Optional[Any] = None,
config_core: Optional[Any] = None,
config_without_epochs: Optional[Any] = None,
domain: Optional[Any] = None,
nlg: Optional[Any] = None,
stories: Optional[Any] = None,
Expand All @@ -114,6 +120,9 @@ def _fingerprint(
if config_core is not None
else ["test"],
FINGERPRINT_CONFIG_NLU_KEY: config_nlu if config_nlu is not None else ["test"],
FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY: config_without_epochs
if config_without_epochs
else ["test"],
FINGERPRINT_DOMAIN_WITHOUT_NLG_KEY: domain if domain is not None else ["test"],
FINGERPRINT_NLG_KEY: nlg if nlg is not None else ["test"],
FINGERPRINT_TRAINED_AT_KEY: time.time(),
Expand Down Expand Up @@ -148,6 +157,7 @@ def test_persist_and_load_fingerprint():
(_fingerprint(nlg=["other"]), False),
(_fingerprint(nlu=["test", "other"]), False),
(_fingerprint(config_nlu=["other"]), False),
(_fingerprint(config_without_epochs=["other"]), False),
],
)
def test_core_fingerprint_changed(fingerprint2, changed):
Expand All @@ -168,6 +178,7 @@ def test_core_fingerprint_changed(fingerprint2, changed):
(_fingerprint(nlg=["other"]), False),
(_fingerprint(config_core=["other"]), False),
(_fingerprint(stories=["other"]), False),
(_fingerprint(config_without_epochs=["other"]), False),
],
)
def test_nlu_fingerprint_changed(fingerprint2, changed):
Expand All @@ -178,6 +189,28 @@ def test_nlu_fingerprint_changed(fingerprint2, changed):
)


@pytest.mark.parametrize(
"fingerprint2, changed",
[
(_fingerprint(config=["other"]), True),
(_fingerprint(config_without_epochs=["other"]), True),
(_fingerprint(domain=["other"]), True),
(_fingerprint(rasa_version="100"), False),
(_fingerprint(config_core=["other"]), False),
(_fingerprint(config_nlu=["other"]), False),
(_fingerprint(nlu=["other"]), False),
(_fingerprint(nlg=["other"]), False),
(_fingerprint(stories=["other"]), False),
],
)
def test_fine_tune_fingerprint_changed(fingerprint2, changed):
fingerprint1 = _fingerprint()
assert (
did_section_fingerprint_change(fingerprint1, fingerprint2, SECTION_FINE_TUNE)
is changed
)


def _project_files(
project: Text,
config_file: Text = DEFAULT_CONFIG_PATH,
Expand Down Expand Up @@ -233,6 +266,44 @@ async def test_fingerprinting_changed_response_text(project: Text):
assert old_fingerprint[FINGERPRINT_NLG_KEY] != new_fingerprint[FINGERPRINT_NLG_KEY]


async def test_fingerprinting_changing_config_epochs(project: Text):
importer = _project_files(project)

old_fingerprint = await model_fingerprint(importer)
config = await importer.get_config()

for key in ["pipeline", "policies"]:
for p in config[key]:
if "epochs" in p:
p["epochs"] += 10

importer.get_config = asyncio.coroutine(lambda: config)
new_fingerprint = await model_fingerprint(importer)

assert (
old_fingerprint[FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY]
== new_fingerprint[FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY]
)
assert (
old_fingerprint[FINGERPRINT_CONFIG_CORE_KEY]
!= new_fingerprint[FINGERPRINT_CONFIG_CORE_KEY]
)
assert (
old_fingerprint[FINGERPRINT_CONFIG_NLU_KEY]
!= new_fingerprint[FINGERPRINT_CONFIG_NLU_KEY]
)

config["pipeline"].pop()

importer.get_config = asyncio.coroutine(lambda: config)
new_fingerprint = await model_fingerprint(importer)

assert (
old_fingerprint[FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY]
!= new_fingerprint[FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY]
)


async def test_fingerprinting_additional_action(project: Text):
importer = _project_files(project)

Expand Down Expand Up @@ -418,3 +489,22 @@ async def get_domain() -> Domain:
actual = Domain.load(tmpdir / DEFAULT_CORE_SUBDIRECTORY_NAME / DEFAULT_DOMAIN_PATH)

assert actual.is_empty()


@pytest.mark.parametrize(
"min_compatible_version, old_model_version, can_tune",
[("2.1.0", "2.1.0", True), ("2.0.0", "2.1.0", True), ("2.1.0", "2.0.0", False),],
)
async def test_can_fine_tune_min_version(
project: Text, monkeypatch, old_model_version, min_compatible_version, can_tune
):
importer = _project_files(project)

monkeypatch.setattr(
rasa.constants, "MINIMUM_COMPATIBLE_VERSION", min_compatible_version
)
monkeypatch.setattr(rasa, "__version__", old_model_version)
old_fingerprint = await model_fingerprint(importer)
new_fingerprint = await model_fingerprint(importer)

assert can_fine_tune(old_fingerprint, new_fingerprint) == can_tune

0 comments on commit b365fb3

Please sign in to comment.