diff --git a/rasa/core/agent.py b/rasa/core/agent.py index 8c8e4164d8a0..28e5873e7d9f 100644 --- a/rasa/core/agent.py +++ b/rasa/core/agent.py @@ -411,6 +411,8 @@ def load( model_server: Optional[EndpointConfig] = None, remote_storage: Optional[Text] = None, path_to_model_archive: Optional[Text] = None, + new_config: Optional[Dict] = None, + finetuning_epoch_fraction: float = 1.0, ) -> "Agent": """Load a persisted model from the passed path.""" try: @@ -441,7 +443,15 @@ def load( if core_model: domain = Domain.load(os.path.join(core_model, DEFAULT_DOMAIN_PATH)) - ensemble = PolicyEnsemble.load(core_model) if core_model else None + ensemble = ( + PolicyEnsemble.load( + core_model, + new_config=new_config, + finetuning_epoch_fraction=finetuning_epoch_fraction, + ) + if core_model + else None + ) # ensures the domain hasn't changed between test and train domain.compare_with_specification(core_model) diff --git a/rasa/core/policies/ensemble.py b/rasa/core/policies/ensemble.py index e7b72d8778c1..dfbce47c3b9e 100644 --- a/rasa/core/policies/ensemble.py +++ b/rasa/core/policies/ensemble.py @@ -1,12 +1,13 @@ import importlib import json import logging +import math import os import sys from collections import defaultdict from datetime import datetime from pathlib import Path -from typing import Text, Optional, Any, List, Dict, Tuple, NamedTuple, Union +from typing import Text, Optional, Any, List, Dict, Tuple, Type, Union import rasa.core import rasa.core.training.training @@ -41,6 +42,7 @@ from rasa.shared.core.trackers import DialogueStateTracker from rasa.shared.core.generator import TrackerWithCachedStates from rasa.core import registry +from rasa.utils.tensorflow.constants import EPOCHS logger = logging.getLogger(__name__) @@ -302,10 +304,29 @@ def _ensure_loaded_policy(cls, policy, policy_cls, policy_name: Text): "".format(policy_name) ) - @classmethod - def load(cls, path: Union[Text, Path]) -> "PolicyEnsemble": - """Loads policy and domain specification from storage""" + @staticmethod + def _get_updated_epochs( + policy_cls: Type[Policy], + config_for_policy: Dict[Text, Any], + finetuning_epoch_fraction: float, + ) -> Optional[int]: + if EPOCHS in config_for_policy: + epochs = config_for_policy[EPOCHS] + else: + try: + epochs = policy_cls.defaults[EPOCHS] + except (KeyError, AttributeError): + return None + return math.ceil(epochs * finetuning_epoch_fraction) + @classmethod + def load( + cls, + path: Union[Text, Path], + new_config: Optional[Dict] = None, + finetuning_epoch_fraction: float = 1.0, + ) -> "PolicyEnsemble": + """Loads policy and domain specification from disk.""" metadata = cls.load_metadata(path) cls.ensure_model_compatibility(metadata) policies = [] @@ -313,9 +334,38 @@ def load(cls, path: Union[Text, Path]) -> "PolicyEnsemble": policy_cls = registry.policy_from_module_path(policy_name) dir_name = f"policy_{i}_{policy_cls.__name__}" policy_path = os.path.join(path, dir_name) - policy = policy_cls.load(policy_path) + + context = {} + if new_config: + context["should_finetune"] = True + + config_for_policy = new_config["policies"][i] + epochs = cls._get_updated_epochs( + policy_cls, config_for_policy, finetuning_epoch_fraction + ) + if epochs: + context["epoch_override"] = epochs + + if "kwargs" not in rasa.shared.utils.common.arguments_of(policy_cls.load): + if context: + raise UnsupportedDialogueModelError( + f"`{policy_cls.__name__}.{policy_cls.load.__name__}` does not " + f"accept `**kwargs`. Attempting to pass {context} to the " + f"policy. `**kwargs` should be added to all policies by " + f"Rasa Open Source 3.0.0." + ) + else: + rasa.shared.utils.io.raise_deprecation_warning( + f"`{policy_cls.__name__}.{policy_cls.load.__name__}` does not " + f"accept `**kwargs`. `**kwargs` are required for contextual " + f"information e.g. the flag `should_finetune`.", + warn_until_version="3.0.0", + ) + + policy = policy_cls.load(policy_path, **context) cls._ensure_loaded_policy(policy, policy_cls, policy_name) policies.append(policy) + ensemble_cls = rasa.shared.utils.common.class_from_module_path( metadata["ensemble_name"] ) diff --git a/rasa/core/policies/fallback.py b/rasa/core/policies/fallback.py index aeef4f834fe4..9ddb96b4ff84 100644 --- a/rasa/core/policies/fallback.py +++ b/rasa/core/policies/fallback.py @@ -39,10 +39,12 @@ def __init__( ambiguity_threshold: float = DEFAULT_NLU_FALLBACK_AMBIGUITY_THRESHOLD, core_threshold: float = DEFAULT_CORE_FALLBACK_THRESHOLD, fallback_action_name: Text = ACTION_DEFAULT_FALLBACK_NAME, + **kwargs: Any, ) -> None: """Create a new Fallback policy. Args: + priority: Fallback policy priority. core_threshold: if NLU confidence threshold is met, predict fallback action with confidence `core_threshold`. If this is the highest confidence in the ensemble, @@ -54,7 +56,7 @@ def __init__( between confidences of the top two predictions fallback_action_name: name of the action to execute as a fallback """ - super().__init__(priority=priority) + super().__init__(priority=priority, **kwargs) self.nlu_threshold = nlu_threshold self.ambiguity_threshold = ambiguity_threshold diff --git a/rasa/core/policies/form_policy.py b/rasa/core/policies/form_policy.py index ba75d8e20421..5fcb2f7a2db1 100644 --- a/rasa/core/policies/form_policy.py +++ b/rasa/core/policies/form_policy.py @@ -35,12 +35,17 @@ def __init__( featurizer: Optional[TrackerFeaturizer] = None, priority: int = FORM_POLICY_PRIORITY, lookup: Optional[Dict] = None, + **kwargs: Any, ) -> None: # max history is set to 2 in order to capture # previous meaningful action before action listen super().__init__( - featurizer=featurizer, priority=priority, max_history=2, lookup=lookup + featurizer=featurizer, + priority=priority, + max_history=2, + lookup=lookup, + **kwargs, ) rasa.shared.utils.io.raise_deprecation_warning( diff --git a/rasa/core/policies/mapping_policy.py b/rasa/core/policies/mapping_policy.py index 3bfa83c60d42..fac4492a0df7 100644 --- a/rasa/core/policies/mapping_policy.py +++ b/rasa/core/policies/mapping_policy.py @@ -43,10 +43,9 @@ class MappingPolicy(Policy): def _standard_featurizer() -> None: return None - def __init__(self, priority: int = MAPPING_POLICY_PRIORITY) -> None: + def __init__(self, priority: int = MAPPING_POLICY_PRIORITY, **kwargs: Any) -> None: """Create a new Mapping policy.""" - - super().__init__(priority=priority) + super().__init__(priority=priority, **kwargs) rasa.shared.utils.io.raise_deprecation_warning( f"'{MappingPolicy.__name__}' is deprecated and will be removed in " diff --git a/rasa/core/policies/memoization.py b/rasa/core/policies/memoization.py index 55876c4a07af..849678e3f4a1 100644 --- a/rasa/core/policies/memoization.py +++ b/rasa/core/policies/memoization.py @@ -71,6 +71,7 @@ def __init__( priority: int = MEMOIZATION_POLICY_PRIORITY, max_history: Optional[int] = MAX_HISTORY_NOT_SET, lookup: Optional[Dict] = None, + **kwargs: Any, ) -> None: """Initialize the policy. @@ -81,7 +82,6 @@ def __init__( lookup: a dictionary that stores featurized tracker states and predicted actions for them """ - if max_history == MAX_HISTORY_NOT_SET: max_history = OLD_DEFAULT_MAX_HISTORY # old default value rasa.shared.utils.io.raise_warning( @@ -97,7 +97,7 @@ def __init__( if not featurizer: featurizer = self._standard_featurizer(max_history) - super().__init__(featurizer, priority) + super().__init__(featurizer, priority, **kwargs) self.max_history = self.featurizer.max_history self.lookup = lookup if lookup is not None else {} diff --git a/rasa/core/policies/policy.py b/rasa/core/policies/policy.py index 8b54d84b9184..1d9305006dc9 100644 --- a/rasa/core/policies/policy.py +++ b/rasa/core/policies/policy.py @@ -16,6 +16,8 @@ TYPE_CHECKING, ) import numpy as np + +from rasa.core.exceptions import UnsupportedDialogueModelError from rasa.shared.core.events import Event import rasa.shared.utils.common @@ -34,6 +36,7 @@ from rasa.core.constants import DEFAULT_POLICY_PRIORITY from rasa.shared.core.constants import USER, SLOTS, PREVIOUS_ACTION, ACTIVE_LOOP from rasa.shared.nlu.constants import ENTITIES, INTENT, TEXT, ACTION_TEXT, ACTION_NAME +from rasa.utils.tensorflow.constants import EPOCHS if TYPE_CHECKING: from rasa.shared.nlu.training_data.features import Features @@ -110,12 +113,15 @@ def __init__( self, featurizer: Optional[TrackerFeaturizer] = None, priority: int = DEFAULT_POLICY_PRIORITY, + **kwargs: Any, ) -> None: + """Constructs a new Policy object.""" self.__featurizer = self._create_featurizer(featurizer) self.priority = priority @property def featurizer(self): + """Returns the policy's featurizer.""" return self.__featurizer @staticmethod @@ -272,7 +278,7 @@ def persist(self, path: Union[Text, Path]) -> None: rasa.shared.utils.io.dump_obj_as_json_to_file(file, self._metadata()) @classmethod - def load(cls, path: Union[Text, Path]) -> "Policy": + def load(cls, path: Union[Text, Path], **kwargs: Any) -> "Policy": """Loads a policy from path. Args: @@ -290,6 +296,25 @@ def load(cls, path: Union[Text, Path]) -> "Policy": featurizer = TrackerFeaturizer.load(path) data["featurizer"] = featurizer + data.update(kwargs) + + constructor_args = rasa.shared.utils.common.arguments_of(cls) + if "kwargs" not in constructor_args: + if set(data.keys()).issubset(set(constructor_args)): + rasa.shared.utils.io.raise_deprecation_warning( + f"`{cls.__name__}.__init__` does not accept `**kwargs` " + f"This is required for contextual information e.g. the flag " + f"`should_finetune`.", + warn_until_version="3.0.0", + ) + else: + raise UnsupportedDialogueModelError( + f"`{cls.__name__}.__init__` does not accept `**kwargs`. " + f"Attempting to pass {data} to the policy. " + f"This argument should be added to all policies by " + f"Rasa Open Source 3.0.0." + ) + return cls(**data) logger.info( diff --git a/rasa/core/policies/rule_policy.py b/rasa/core/policies/rule_policy.py index c13086ec44af..033ace294237 100644 --- a/rasa/core/policies/rule_policy.py +++ b/rasa/core/policies/rule_policy.py @@ -108,6 +108,7 @@ def __init__( enable_fallback_prediction: bool = True, restrict_rules: bool = True, check_for_contradictions: bool = True, + **kwargs: Any, ) -> None: """Create a `RulePolicy` object. @@ -124,6 +125,10 @@ def __init__( if no rule matched. enable_fallback_prediction: If `True` `core_fallback_action_name` is predicted in case no rule matched. + restrict_rules: If `True` rules are restricted to contain a maximum of 1 + user message. This is used to avoid that users build a state machine + using the rules. + check_for_contradictions: Check for contradictions. """ self._core_fallback_threshold = core_fallback_threshold self._fallback_action_name = core_fallback_action_name @@ -136,7 +141,11 @@ def __init__( # max history is set to `None` in order to capture any lengths of rule stories super().__init__( - featurizer=featurizer, priority=priority, max_history=None, lookup=lookup + featurizer=featurizer, + priority=priority, + max_history=None, + lookup=lookup, + **kwargs, ) @classmethod diff --git a/rasa/core/policies/sklearn_policy.py b/rasa/core/policies/sklearn_policy.py index ef616fcdc208..3c1340d3e4e2 100644 --- a/rasa/core/policies/sklearn_policy.py +++ b/rasa/core/policies/sklearn_policy.py @@ -27,7 +27,7 @@ from sklearn.preprocessing import LabelEncoder from rasa.shared.nlu.constants import ACTION_TEXT, TEXT from rasa.shared.nlu.training_data.features import Features -from rasa.utils.tensorflow.constants import SENTENCE +from rasa.utils.tensorflow.constants import EPOCHS, SENTENCE from rasa.utils.tensorflow.model_data import Data # noinspection PyProtectedMember @@ -72,6 +72,8 @@ def __init__( Args: featurizer: Featurizer used to convert the training data into vector format. + priority: Policy priority + max_history: Maximum history of the dialogs. model: The sklearn model or model pipeline. param_grid: If *param_grid* is not None and *cv* is given, a grid search on the given *param_grid* is performed @@ -85,7 +87,6 @@ def __init__( shuffle: Whether to shuffle training data. zero_state_features: Contains default feature values for attributes """ - if featurizer: if not isinstance(featurizer, MaxHistoryTrackerFeaturizer): raise TypeError( @@ -104,7 +105,7 @@ def __init__( ) featurizer = self._standard_featurizer(max_history) - super().__init__(featurizer, priority) + super().__init__(featurizer, priority, **kwargs) self.model = model or self._default_model() self.cv = cv @@ -302,7 +303,10 @@ def persist(self, path: Union[Text, Path]) -> None: ) @classmethod - def load(cls, path: Union[Text, Path]) -> Policy: + def load( + cls, path: Union[Text, Path], should_finetune: bool = False, **kwargs: Any + ) -> Policy: + """See the docstring for `Policy.load`.""" filename = Path(path) / "sklearn_model.pkl" zero_features_filename = Path(path) / "zero_state_features.pkl" if not Path(path).exists(): @@ -325,6 +329,7 @@ def load(cls, path: Union[Text, Path]) -> Policy: featurizer=featurizer, priority=meta["priority"], zero_state_features=zero_state_features, + should_finetune=should_finetune, ) state = io_utils.pickle_load(filename) diff --git a/rasa/core/policies/ted_policy.py b/rasa/core/policies/ted_policy.py index 7a7b60352327..16bc37cd32d6 100644 --- a/rasa/core/policies/ted_policy.py +++ b/rasa/core/policies/ted_policy.py @@ -220,11 +220,10 @@ def __init__( **kwargs: Any, ) -> None: """Declare instance variables with default values.""" - if not featurizer: featurizer = self._standard_featurizer(max_history) - super().__init__(featurizer, priority) + super().__init__(featurizer, priority, **kwargs) if isinstance(featurizer, FullDialogueTrackerFeaturizer): self.is_full_dialogue_featurizer_used = True else: @@ -437,8 +436,9 @@ def persist(self, path: Union[Text, Path]) -> None: ) @classmethod - def load(cls, path: Union[Text, Path]) -> "TEDPolicy": + def load(cls, path: Union[Text, Path], **kwargs: Any) -> "TEDPolicy": """Loads a policy from the storage. + **Needs to load its featurizer** """ model_path = Path(path) @@ -500,6 +500,10 @@ def load(cls, path: Union[Text, Path]) -> "TEDPolicy": ) model.build_for_predict(predict_data_example) + meta["should_finetune"] = kwargs.get("should_finetune", False) + if "epoch_override" in kwargs: + meta[EPOCHS] = kwargs["epoch_override"] + return cls( featurizer=featurizer, priority=priority, diff --git a/rasa/core/policies/two_stage_fallback.py b/rasa/core/policies/two_stage_fallback.py index 92b2a5b43445..ed3db549dc6f 100644 --- a/rasa/core/policies/two_stage_fallback.py +++ b/rasa/core/policies/two_stage_fallback.py @@ -62,10 +62,12 @@ def __init__( fallback_core_action_name: Text = ACTION_DEFAULT_FALLBACK_NAME, fallback_nlu_action_name: Text = ACTION_DEFAULT_FALLBACK_NAME, deny_suggestion_intent_name: Text = USER_INTENT_OUT_OF_SCOPE, + **kwargs: Any, ) -> None: """Create a new Two-stage Fallback policy. Args: + priority: The fallback policy priority. nlu_threshold: minimum threshold for NLU confidence. If intent prediction confidence is lower than this, predict fallback action with confidence 1.0. @@ -88,6 +90,7 @@ def __init__( ambiguity_threshold, core_threshold, fallback_core_action_name, + **kwargs, ) self.fallback_nlu_action_name = fallback_nlu_action_name diff --git a/rasa/core/train.py b/rasa/core/train.py index c3c109f2039e..44137f6bbd00 100644 --- a/rasa/core/train.py +++ b/rasa/core/train.py @@ -31,7 +31,6 @@ async def train( exclusion_percentage: Optional[int] = None, additional_arguments: Optional[Dict] = None, model_to_finetune: Optional["Agent"] = None, - finetuning_epoch_fraction: float = 1.0, ) -> "Agent": from rasa.core import config, utils from rasa.core.utils import AvailableEndpoints @@ -66,6 +65,8 @@ async def train( training_data = await agent.load_data( training_resource, exclusion_percentage=exclusion_percentage, **data_load_args ) + if model_to_finetune: + agent.policy_ensemble = model_to_finetune.policy_ensemble agent.train(training_data, **additional_arguments) agent.persist(output_path) diff --git a/rasa/train.py b/rasa/train.py index 48626e85ac26..6bd70e8210cd 100644 --- a/rasa/train.py +++ b/rasa/train.py @@ -446,7 +446,11 @@ async def _train_core_with_validated_data( ) if model_to_finetune: - model_to_finetune = _core_model_for_finetuning(model_to_finetune) + model_to_finetune = _core_model_for_finetuning( + model_to_finetune, + new_config=config, + finetuning_epoch_fraction=finetuning_epoch_fraction, + ) if not model_to_finetune: rasa.shared.utils.cli.print_warning( @@ -468,7 +472,6 @@ async def _train_core_with_validated_data( additional_arguments=additional_arguments, interpreter=interpreter, model_to_finetune=model_to_finetune, - finetuning_epoch_fraction=finetuning_epoch_fraction, ) rasa.shared.utils.cli.print_color( "Core model training completed.", color=rasa.shared.utils.io.bcolors.OKBLUE @@ -488,21 +491,26 @@ async def _train_core_with_validated_data( return _train_path -def _core_model_for_finetuning(model_to_finetune: Text) -> Optional[Agent]: +def _core_model_for_finetuning( + model_to_finetune: Text, + new_config: Optional[Dict] = None, + finetuning_epoch_fraction: float = 1.0, +) -> Optional[Agent]: path_to_archive = model.get_model_for_finetuning(model_to_finetune) if not path_to_archive: return None with model.unpack_model(path_to_archive) as unpacked: - try: - agent = Agent.load(unpacked) - # Agent might be empty if no underlying Core model was found. - if agent.domain is not None and agent.policy_ensemble is not None: - return agent - except Exception: - # Anything might go wrong. In that case we skip model finetuning. - pass - return None + agent = Agent.load( + unpacked, + new_config=new_config, + finetuning_epoch_fraction=finetuning_epoch_fraction, + ) + # Agent might be empty if no underlying Core model was found. + if agent.domain is not None and agent.policy_ensemble is not None: + return agent + + return None def train_nlu( diff --git a/tests/core/conftest.py b/tests/core/conftest.py index b230a4170b35..095342e288de 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -72,8 +72,8 @@ def as_feature(self): # noinspection PyAbstractClass,PyUnusedLocal,PyMissingConstructor class ExamplePolicy(Policy): - def __init__(self, example_arg): - super(ExamplePolicy, self).__init__() + def __init__(self, *args, **kwargs): + super(ExamplePolicy, self).__init__(*args, **kwargs) class MockedMongoTrackerStore(MongoTrackerStore): diff --git a/tests/core/test_ensemble.py b/tests/core/test_ensemble.py index 7f90f2591e93..050ad361c45f 100644 --- a/tests/core/test_ensemble.py +++ b/tests/core/test_ensemble.py @@ -1,9 +1,13 @@ from pathlib import Path -from typing import List, Any, Text, Optional +from typing import List, Any, Text, Optional, Union +from unittest.mock import Mock +from _pytest.capture import CaptureFixture +from _pytest.monkeypatch import MonkeyPatch import pytest import copy +from rasa.core.exceptions import UnsupportedDialogueModelError from rasa.core.policies.memoization import MemoizationPolicy, AugmentedMemoizationPolicy from rasa.shared.nlu.interpreter import NaturalLanguageInterpreter, RegexInterpreter @@ -37,7 +41,7 @@ class WorkingPolicy(Policy): @classmethod - def load(cls, _) -> Policy: + def load(cls, *args: Any, **kwargs: Any) -> Policy: return WorkingPolicy() def persist(self, _) -> None: @@ -74,6 +78,37 @@ def test_policy_loading_simple(tmp_path: Path): assert original_policy_ensemble.policies == loaded_policy_ensemble.policies +class PolicyWithoutLoadKwargs(Policy): + @classmethod + def load(cls, path: Union[Text, Path]) -> Policy: + return PolicyWithoutLoadKwargs() + + def persist(self, _) -> None: + pass + + +def test_policy_loading_no_kwargs_with_context(tmp_path: Path): + original_policy_ensemble = PolicyEnsemble([PolicyWithoutLoadKwargs()]) + original_policy_ensemble.train([], None, RegexInterpreter()) + original_policy_ensemble.persist(str(tmp_path)) + + with pytest.raises(UnsupportedDialogueModelError) as execinfo: + PolicyEnsemble.load(str(tmp_path), new_config={"policies": [{}]}) + assert "`PolicyWithoutLoadKwargs.load` does not accept `**kwargs`" in str( + execinfo.value + ) + + +def test_policy_loading_no_kwargs_with_no_context( + tmp_path: Path, capsys: CaptureFixture +): + original_policy_ensemble = PolicyEnsemble([PolicyWithoutLoadKwargs()]) + original_policy_ensemble.train([], None, RegexInterpreter()) + original_policy_ensemble.persist(str(tmp_path)) + with pytest.warns(FutureWarning): + PolicyEnsemble.load(str(tmp_path)) + + class ConstantPolicy(Policy): def __init__( self, @@ -83,8 +118,9 @@ def __init__( is_end_to_end_prediction: bool = False, events: Optional[List[Event]] = None, optional_events: Optional[List[Event]] = None, + **kwargs: Any, ) -> None: - super().__init__(priority=priority) + super().__init__(priority=priority, **kwargs) self.predict_index = predict_index self.confidence = confidence self.is_end_to_end_prediction = is_end_to_end_prediction @@ -92,7 +128,7 @@ def __init__( self.optional_events = optional_events or [] @classmethod - def load(cls, _) -> Policy: + def load(cls, args, **kwargs) -> Policy: pass def persist(self, _) -> None: @@ -304,7 +340,7 @@ def test_fallback_wins_over_mapping(): class LoadReturnsNonePolicy(Policy): @classmethod - def load(cls, _) -> None: + def load(cls, *args, **kwargs) -> None: return None def persist(self, _) -> None: @@ -340,7 +376,7 @@ def test_policy_loading_load_returns_none(tmp_path: Path): class LoadReturnsWrongTypePolicy(Policy): @classmethod - def load(cls, _) -> Text: + def load(cls, *args, **kwargs) -> Text: return "" def persist(self, _) -> None: diff --git a/tests/core/test_policies.py b/tests/core/test_policies.py index b1a65f7e5d0f..fe00581ab250 100644 --- a/tests/core/test_policies.py +++ b/tests/core/test_policies.py @@ -7,8 +7,10 @@ from _pytest.monkeypatch import MonkeyPatch from rasa.core.channels import OutputChannel +from rasa.core.exceptions import UnsupportedDialogueModelError from rasa.core.nlg import NaturalLanguageGenerator from rasa.shared.core.generator import TrackerWithCachedStates +import rasa.shared.utils.io from rasa.core import training import rasa.core.actions.action @@ -1229,3 +1231,34 @@ def test_get_training_trackers_for_policy( def test_deprecation_warnings_for_old_rule_like_policies(policy: Type[Policy]): with pytest.warns(FutureWarning): policy(None) + + +class PolicyWithoutInitKwargs(Policy): + def __init__(self, *args: Any) -> None: + pass + + def persist(self, _) -> None: + pass + + @classmethod + def _metadata_filename(cls) -> Text: + return "no_finetune_policy" + + +def test_loading_policy_with_no_constructor_kwargs(tmp_path: Path): + rasa.shared.utils.io.write_text_file( + "{}", tmp_path / PolicyWithoutInitKwargs._metadata_filename() + ) + with pytest.raises(UnsupportedDialogueModelError) as execinfo: + PolicyWithoutInitKwargs.load(str(tmp_path), should_finetune=True) + assert "`PolicyWithoutInitKwargs.__init__` does not accept `**kwargs`." in str( + execinfo.value + ) + + +def test_loading_policy_with_no_constructor_kwargs_but_required_args(tmp_path: Path): + rasa.shared.utils.io.write_text_file( + "{}", tmp_path / PolicyWithoutInitKwargs._metadata_filename() + ) + with pytest.warns(FutureWarning): + PolicyWithoutInitKwargs.load(str(tmp_path)) diff --git a/tests/test_train.py b/tests/test_train.py index 53e091b2c4e4..cb6ac099ca2f 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -3,12 +3,13 @@ import os from pathlib import Path from typing import Text, Dict, Any -from unittest.mock import Mock +from unittest.mock import Mock, create_autospec import pytest from _pytest.capture import CaptureFixture from _pytest.monkeypatch import MonkeyPatch +from rasa.core.policies.ted_policy import TEDPolicy import rasa.model import rasa.core import rasa.nlu @@ -399,32 +400,84 @@ def test_model_finetuning_core( tmp_path: Path, monkeypatch: MonkeyPatch, default_domain_path: Text, - default_stories_file: Text, - default_stack_config: Text, - trained_rasa_model: Text, + default_nlu_data: Text, + trained_moodbot_path: Text, use_latest_model: bool, ): mocked_core_training = AsyncMock() monkeypatch.setattr(rasa.core, rasa.core.train.__name__, mocked_core_training) + mock_agent_load = Mock(wraps=Agent.load) + monkeypatch.setattr(Agent, "load", mock_agent_load) + (tmp_path / "models").mkdir() output = str(tmp_path / "models") if use_latest_model: - trained_rasa_model = str(Path(trained_rasa_model).parent) + trained_moodbot_path = str(Path(trained_moodbot_path).parent) + + # Typically models will be fine-tuned with a smaller number of epochs than training + # from scratch. + # Fine-tuning will use the number of epochs in the new config. + old_config = rasa.shared.utils.io.read_yaml_file("examples/moodbot/config.yml") + old_config["policies"][0]["epochs"] = 20 + new_config_path = tmp_path / "new_config.yml" + rasa.shared.utils.io.write_yaml(old_config, new_config_path) train_core( - default_domain_path, - default_stack_config, - default_stories_file, + "examples/moodbot/domain.yml", + str(new_config_path), + "examples/moodbot/data/stories.yml", output=output, - model_to_finetune=trained_rasa_model, - finetuning_epoch_fraction=1, + model_to_finetune=trained_moodbot_path, + finetuning_epoch_fraction=0.5, ) mocked_core_training.assert_called_once() _, kwargs = mocked_core_training.call_args - assert isinstance(kwargs["model_to_finetune"], Agent) + model_to_finetune = kwargs["model_to_finetune"] + assert isinstance(model_to_finetune, Agent) + + ted = model_to_finetune.policy_ensemble.policies[0] + assert ted.config[EPOCHS] == 10 + assert ted.config["should_finetune"] is True + + +def test_model_finetuning_core_with_default_epochs( + tmp_path: Path, + monkeypatch: MonkeyPatch, + default_domain_path: Text, + default_nlu_data: Text, + trained_moodbot_path: Text, +): + mocked_core_training = AsyncMock() + monkeypatch.setattr(rasa.core, rasa.core.train.__name__, mocked_core_training) + + (tmp_path / "models").mkdir() + output = str(tmp_path / "models") + + # Providing a new config with no epochs will mean the default amount are used + # and then scaled by `finetuning_epoch_fraction`. + old_config = rasa.shared.utils.io.read_yaml_file("examples/moodbot/config.yml") + del old_config["policies"][0]["epochs"] + new_config_path = tmp_path / "new_config.yml" + rasa.shared.utils.io.write_yaml(old_config, new_config_path) + + train_core( + "examples/moodbot/domain.yml", + str(new_config_path), + "examples/moodbot/data/stories.yml", + output=output, + model_to_finetune=trained_moodbot_path, + finetuning_epoch_fraction=2, + ) + + mocked_core_training.assert_called_once() + _, kwargs = mocked_core_training.call_args + model_to_finetune = kwargs["model_to_finetune"] + + ted = model_to_finetune.policy_ensemble.policies[0] + assert ted.config[EPOCHS] == TEDPolicy.defaults[EPOCHS] * 2 @pytest.mark.parametrize("use_latest_model", [True, False])