Skip to content

Commit

Permalink
#7329 load models in finetune mode core (#7458)
Browse files: browse the repository at this point in the history
* Load core model in fine-tuning mode

* Core finetune loading test

* Test and PR comments

* Fallback to default epochs

* Test policy and ensemble fine-tuning exception cases

* Remove epoch_override from Policy.load

* use kwargs

* fix

* fix train tests

* More test fixes

* Apply suggestions from code review

Co-authored-by: Daksh Varshneya <[email protected]>

* remove unneeded sklearn epochs

* Apply suggestions from code review

Co-authored-by: Tobias Wochinger <[email protected]>

* PR comments for warning strings

* Add typing

* add back invalid model tests

* small comments

Co-authored-by: Daksh Varshneya <[email protected]>
Co-authored-by: Tobias Wochinger <[email protected]>
  • Loading branch information
3 people authored Dec 9, 2020
1 parent ca76810 commit 241a075
Show file tree
Hide file tree
Showing 17 changed files with 297 additions and 54 deletions.
12 changes: 11 additions & 1 deletion rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,8 @@ def load(
model_server: Optional[EndpointConfig] = None,
remote_storage: Optional[Text] = None,
path_to_model_archive: Optional[Text] = None,
new_config: Optional[Dict] = None,
finetuning_epoch_fraction: float = 1.0,
) -> "Agent":
"""Load a persisted model from the passed path."""
try:
Expand Down Expand Up @@ -441,7 +443,15 @@ def load(

if core_model:
domain = Domain.load(os.path.join(core_model, DEFAULT_DOMAIN_PATH))
ensemble = PolicyEnsemble.load(core_model) if core_model else None
ensemble = (
PolicyEnsemble.load(
core_model,
new_config=new_config,
finetuning_epoch_fraction=finetuning_epoch_fraction,
)
if core_model
else None
)

# ensures the domain hasn't changed between test and train
domain.compare_with_specification(core_model)
Expand Down
60 changes: 55 additions & 5 deletions rasa/core/policies/ensemble.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import importlib
import json
import logging
import math
import os
import sys
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Text, Optional, Any, List, Dict, Tuple, NamedTuple, Union
from typing import Text, Optional, Any, List, Dict, Tuple, Type, Union

import rasa.core
import rasa.core.training.training
Expand Down Expand Up @@ -41,6 +42,7 @@
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.generator import TrackerWithCachedStates
from rasa.core import registry
from rasa.utils.tensorflow.constants import EPOCHS

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -302,20 +304,68 @@ def _ensure_loaded_policy(cls, policy, policy_cls, policy_name: Text):
"".format(policy_name)
)

@classmethod
def load(cls, path: Union[Text, Path]) -> "PolicyEnsemble":
"""Loads policy and domain specification from storage"""
@staticmethod
def _get_updated_epochs(
    policy_cls: Type[Policy],
    config_for_policy: Dict[Text, Any],
    finetuning_epoch_fraction: float,
) -> Optional[int]:
    """Scales a policy's epoch count by the fine-tuning fraction.

    The base epoch count is read from `config_for_policy` when it contains
    the `EPOCHS` key; otherwise it is looked up in `policy_cls.defaults`.

    Args:
        policy_cls: Policy class whose `defaults` are used as a fallback.
        config_for_policy: Configuration entry for this policy.
        finetuning_epoch_fraction: Factor the base epoch count is
            multiplied by.

    Returns:
        The scaled epoch count, rounded up to the next integer, or `None`
        if neither the configuration nor the policy defaults define epochs.
    """
    if EPOCHS not in config_for_policy:
        try:
            default_epochs = policy_cls.defaults[EPOCHS]
        except (KeyError, AttributeError):
            # The policy has no `defaults` attribute, or its defaults carry
            # no epochs entry: there is nothing to scale.
            return None
        return math.ceil(default_epochs * finetuning_epoch_fraction)

    configured_epochs = config_for_policy[EPOCHS]
    return math.ceil(configured_epochs * finetuning_epoch_fraction)

@classmethod
def load(
cls,
path: Union[Text, Path],
new_config: Optional[Dict] = None,
finetuning_epoch_fraction: float = 1.0,
) -> "PolicyEnsemble":
"""Loads policy and domain specification from disk."""
metadata = cls.load_metadata(path)
cls.ensure_model_compatibility(metadata)
policies = []
for i, policy_name in enumerate(metadata["policy_names"]):
policy_cls = registry.policy_from_module_path(policy_name)
dir_name = f"policy_{i}_{policy_cls.__name__}"
policy_path = os.path.join(path, dir_name)
policy = policy_cls.load(policy_path)

context = {}
if new_config:
context["should_finetune"] = True

config_for_policy = new_config["policies"][i]
epochs = cls._get_updated_epochs(
policy_cls, config_for_policy, finetuning_epoch_fraction
)
if epochs:
context["epoch_override"] = epochs

if "kwargs" not in rasa.shared.utils.common.arguments_of(policy_cls.load):
if context:
raise UnsupportedDialogueModelError(
f"`{policy_cls.__name__}.{policy_cls.load.__name__}` does not "
f"accept `**kwargs`. Attempting to pass {context} to the "
f"policy. `**kwargs` should be added to all policies by "
f"Rasa Open Source 3.0.0."
)
else:
rasa.shared.utils.io.raise_deprecation_warning(
f"`{policy_cls.__name__}.{policy_cls.load.__name__}` does not "
f"accept `**kwargs`. `**kwargs` are required for contextual "
f"information e.g. the flag `should_finetune`.",
warn_until_version="3.0.0",
)

policy = policy_cls.load(policy_path, **context)
cls._ensure_loaded_policy(policy, policy_cls, policy_name)
policies.append(policy)

ensemble_cls = rasa.shared.utils.common.class_from_module_path(
metadata["ensemble_name"]
)
Expand Down
4 changes: 3 additions & 1 deletion rasa/core/policies/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ def __init__(
ambiguity_threshold: float = DEFAULT_NLU_FALLBACK_AMBIGUITY_THRESHOLD,
core_threshold: float = DEFAULT_CORE_FALLBACK_THRESHOLD,
fallback_action_name: Text = ACTION_DEFAULT_FALLBACK_NAME,
**kwargs: Any,
) -> None:
"""Create a new Fallback policy.
Args:
priority: Fallback policy priority.
core_threshold: if NLU confidence threshold is met,
predict fallback action with confidence `core_threshold`.
If this is the highest confidence in the ensemble,
Expand All @@ -54,7 +56,7 @@ def __init__(
between confidences of the top two predictions
fallback_action_name: name of the action to execute as a fallback
"""
super().__init__(priority=priority)
super().__init__(priority=priority, **kwargs)

self.nlu_threshold = nlu_threshold
self.ambiguity_threshold = ambiguity_threshold
Expand Down
7 changes: 6 additions & 1 deletion rasa/core/policies/form_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,17 @@ def __init__(
featurizer: Optional[TrackerFeaturizer] = None,
priority: int = FORM_POLICY_PRIORITY,
lookup: Optional[Dict] = None,
**kwargs: Any,
) -> None:

# max history is set to 2 in order to capture
# previous meaningful action before action listen
super().__init__(
featurizer=featurizer, priority=priority, max_history=2, lookup=lookup
featurizer=featurizer,
priority=priority,
max_history=2,
lookup=lookup,
**kwargs,
)

rasa.shared.utils.io.raise_deprecation_warning(
Expand Down
5 changes: 2 additions & 3 deletions rasa/core/policies/mapping_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,9 @@ class MappingPolicy(Policy):
def _standard_featurizer() -> None:
return None

def __init__(self, priority: int = MAPPING_POLICY_PRIORITY) -> None:
def __init__(self, priority: int = MAPPING_POLICY_PRIORITY, **kwargs: Any) -> None:
"""Create a new Mapping policy."""

super().__init__(priority=priority)
super().__init__(priority=priority, **kwargs)

rasa.shared.utils.io.raise_deprecation_warning(
f"'{MappingPolicy.__name__}' is deprecated and will be removed in "
Expand Down
4 changes: 2 additions & 2 deletions rasa/core/policies/memoization.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
priority: int = MEMOIZATION_POLICY_PRIORITY,
max_history: Optional[int] = MAX_HISTORY_NOT_SET,
lookup: Optional[Dict] = None,
**kwargs: Any,
) -> None:
"""Initialize the policy.
Expand All @@ -81,7 +82,6 @@ def __init__(
lookup: a dictionary that stores featurized tracker states and
predicted actions for them
"""

if max_history == MAX_HISTORY_NOT_SET:
max_history = OLD_DEFAULT_MAX_HISTORY # old default value
rasa.shared.utils.io.raise_warning(
Expand All @@ -97,7 +97,7 @@ def __init__(
if not featurizer:
featurizer = self._standard_featurizer(max_history)

super().__init__(featurizer, priority)
super().__init__(featurizer, priority, **kwargs)

self.max_history = self.featurizer.max_history
self.lookup = lookup if lookup is not None else {}
Expand Down
27 changes: 26 additions & 1 deletion rasa/core/policies/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
TYPE_CHECKING,
)
import numpy as np

from rasa.core.exceptions import UnsupportedDialogueModelError
from rasa.shared.core.events import Event

import rasa.shared.utils.common
Expand All @@ -34,6 +36,7 @@
from rasa.core.constants import DEFAULT_POLICY_PRIORITY
from rasa.shared.core.constants import USER, SLOTS, PREVIOUS_ACTION, ACTIVE_LOOP
from rasa.shared.nlu.constants import ENTITIES, INTENT, TEXT, ACTION_TEXT, ACTION_NAME
from rasa.utils.tensorflow.constants import EPOCHS

if TYPE_CHECKING:
from rasa.shared.nlu.training_data.features import Features
Expand Down Expand Up @@ -110,12 +113,15 @@ def __init__(
self,
featurizer: Optional[TrackerFeaturizer] = None,
priority: int = DEFAULT_POLICY_PRIORITY,
**kwargs: Any,
) -> None:
"""Constructs a new Policy object."""
self.__featurizer = self._create_featurizer(featurizer)
self.priority = priority

@property
def featurizer(self):
    """Returns the policy's featurizer.

    This exposes the private `__featurizer` attribute, which `__init__`
    assigns from the result of `_create_featurizer(featurizer)`.
    """
    return self.__featurizer

@staticmethod
Expand Down Expand Up @@ -272,7 +278,7 @@ def persist(self, path: Union[Text, Path]) -> None:
rasa.shared.utils.io.dump_obj_as_json_to_file(file, self._metadata())

@classmethod
def load(cls, path: Union[Text, Path]) -> "Policy":
def load(cls, path: Union[Text, Path], **kwargs: Any) -> "Policy":
"""Loads a policy from path.
Args:
Expand All @@ -290,6 +296,25 @@ def load(cls, path: Union[Text, Path]) -> "Policy":
featurizer = TrackerFeaturizer.load(path)
data["featurizer"] = featurizer

data.update(kwargs)

constructor_args = rasa.shared.utils.common.arguments_of(cls)
if "kwargs" not in constructor_args:
if set(data.keys()).issubset(set(constructor_args)):
rasa.shared.utils.io.raise_deprecation_warning(
f"`{cls.__name__}.__init__` does not accept `**kwargs` "
f"This is required for contextual information e.g. the flag "
f"`should_finetune`.",
warn_until_version="3.0.0",
)
else:
raise UnsupportedDialogueModelError(
f"`{cls.__name__}.__init__` does not accept `**kwargs`. "
f"Attempting to pass {data} to the policy. "
f"This argument should be added to all policies by "
f"Rasa Open Source 3.0.0."
)

return cls(**data)

logger.info(
Expand Down
11 changes: 10 additions & 1 deletion rasa/core/policies/rule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
enable_fallback_prediction: bool = True,
restrict_rules: bool = True,
check_for_contradictions: bool = True,
**kwargs: Any,
) -> None:
"""Create a `RulePolicy` object.
Expand All @@ -124,6 +125,10 @@ def __init__(
if no rule matched.
enable_fallback_prediction: If `True` `core_fallback_action_name` is
predicted in case no rule matched.
restrict_rules: If `True` rules are restricted to contain a maximum of 1
user message. This is used to avoid that users build a state machine
using the rules.
check_for_contradictions: Check for contradictions.
"""
self._core_fallback_threshold = core_fallback_threshold
self._fallback_action_name = core_fallback_action_name
Expand All @@ -136,7 +141,11 @@ def __init__(

# max history is set to `None` in order to capture any lengths of rule stories
super().__init__(
featurizer=featurizer, priority=priority, max_history=None, lookup=lookup
featurizer=featurizer,
priority=priority,
max_history=None,
lookup=lookup,
**kwargs,
)

@classmethod
Expand Down
13 changes: 9 additions & 4 deletions rasa/core/policies/sklearn_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from sklearn.preprocessing import LabelEncoder
from rasa.shared.nlu.constants import ACTION_TEXT, TEXT
from rasa.shared.nlu.training_data.features import Features
from rasa.utils.tensorflow.constants import SENTENCE
from rasa.utils.tensorflow.constants import EPOCHS, SENTENCE
from rasa.utils.tensorflow.model_data import Data

# noinspection PyProtectedMember
Expand Down Expand Up @@ -72,6 +72,8 @@ def __init__(
Args:
featurizer: Featurizer used to convert the training data into
vector format.
priority: Policy priority
max_history: Maximum history of the dialogs.
model: The sklearn model or model pipeline.
param_grid: If *param_grid* is not None and *cv* is given,
a grid search on the given *param_grid* is performed
Expand All @@ -85,7 +87,6 @@ def __init__(
shuffle: Whether to shuffle training data.
zero_state_features: Contains default feature values for attributes
"""

if featurizer:
if not isinstance(featurizer, MaxHistoryTrackerFeaturizer):
raise TypeError(
Expand All @@ -104,7 +105,7 @@ def __init__(
)
featurizer = self._standard_featurizer(max_history)

super().__init__(featurizer, priority)
super().__init__(featurizer, priority, **kwargs)

self.model = model or self._default_model()
self.cv = cv
Expand Down Expand Up @@ -302,7 +303,10 @@ def persist(self, path: Union[Text, Path]) -> None:
)

@classmethod
def load(cls, path: Union[Text, Path]) -> Policy:
def load(
cls, path: Union[Text, Path], should_finetune: bool = False, **kwargs: Any
) -> Policy:
"""See the docstring for `Policy.load`."""
filename = Path(path) / "sklearn_model.pkl"
zero_features_filename = Path(path) / "zero_state_features.pkl"
if not Path(path).exists():
Expand All @@ -325,6 +329,7 @@ def load(cls, path: Union[Text, Path]) -> Policy:
featurizer=featurizer,
priority=meta["priority"],
zero_state_features=zero_state_features,
should_finetune=should_finetune,
)

state = io_utils.pickle_load(filename)
Expand Down
10 changes: 7 additions & 3 deletions rasa/core/policies/ted_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,10 @@ def __init__(
**kwargs: Any,
) -> None:
"""Declare instance variables with default values."""

if not featurizer:
featurizer = self._standard_featurizer(max_history)

super().__init__(featurizer, priority)
super().__init__(featurizer, priority, **kwargs)
if isinstance(featurizer, FullDialogueTrackerFeaturizer):
self.is_full_dialogue_featurizer_used = True
else:
Expand Down Expand Up @@ -437,8 +436,9 @@ def persist(self, path: Union[Text, Path]) -> None:
)

@classmethod
def load(cls, path: Union[Text, Path]) -> "TEDPolicy":
def load(cls, path: Union[Text, Path], **kwargs: Any) -> "TEDPolicy":
"""Loads a policy from the storage.
**Needs to load its featurizer**
"""
model_path = Path(path)
Expand Down Expand Up @@ -500,6 +500,10 @@ def load(cls, path: Union[Text, Path]) -> "TEDPolicy":
)
model.build_for_predict(predict_data_example)

meta["should_finetune"] = kwargs.get("should_finetune", False)
if "epoch_override" in kwargs:
meta[EPOCHS] = kwargs["epoch_override"]

return cls(
featurizer=featurizer,
priority=priority,
Expand Down
Loading

0 comments on commit 241a075

Please sign in to comment.