Skip to content

Commit

Permalink
support but deprecate returning floats
Browse files Browse the repository at this point in the history
  • Loading branch information
wochinge committed Oct 30, 2020
1 parent b116835 commit 955ad22
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 14 deletions.
42 changes: 28 additions & 14 deletions rasa/core/policies/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,25 +613,39 @@ def _get_prediction(
arguments = rasa.shared.utils.common.arguments_of(
policy.predict_action_probabilities
)

if (
len(arguments) > number_of_arguments_in_rasa_1_0
and "interpreter" in arguments
):
return policy.predict_action_probabilities(tracker, domain, interpreter)

# TODO: Deprecation warning if list of floats is returned
rasa.shared.utils.io.raise_warning(
"The function `predict_action_probabilities` of "
"the `Policy` interface was changed to support "
"additional parameters. Please make sure to "
"adapt your custom `Policy` implementation.",
category=DeprecationWarning,
)
probabilities = policy.predict_action_probabilities(
tracker, domain, RegexInterpreter()
)
prediction = policy.predict_action_probabilities(
tracker, domain, interpreter
)
else:
# TODO: Deprecation warning if list of floats is returned
rasa.shared.utils.io.raise_warning(
"The function `predict_action_probabilities` of "
"the `Policy` interface was changed to support "
"additional parameters. Please make sure to "
"adapt your custom `Policy` implementation.",
category=DeprecationWarning,
)
prediction = policy.predict_action_probabilities(
tracker, domain, RegexInterpreter()
)

if isinstance(prediction, list):
rasa.shared.utils.io.raise_warning(
f"The function `predict_action_probabilities` of "
f"the `Policy` interface was changed to return "
f"a `{PolicyPrediction.__name__}` object. Please make sure to "
"adapt your custom `Policy` implementation. Support for returning "
"for return a list of floats will be removed in Rasa Open Source 3.0.",
category=DeprecationWarning,
)
prediction = PolicyPrediction(prediction, policy_priority=policy.priority)

return probabilities
return prediction

def _fallback_after_listen(
self, domain: Domain, probabilities: List[float], policy_name: Text
Expand Down
29 changes: 29 additions & 0 deletions tests/core/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,3 +532,32 @@ def test_prediction_applies_optional_policy_events(default_domain: Domain):
assert len(tracker.events) == len(optional_events) + len(must_have_events)
assert all(event in tracker.events for event in optional_events)
assert all(event in tracker.events for event in must_have_events)


def test_with_float_returning_policy(default_domain: Domain):
expected_index = 3

class OldPolicy(Policy):
def predict_action_probabilities(
self,
tracker: DialogueStateTracker,
domain: Domain,
interpreter: NaturalLanguageInterpreter,
**kwargs: Any,
) -> List[float]:
prediction = [0.0] * default_domain.num_actions
prediction[expected_index] = 3
return prediction

ensemble = SimplePolicyEnsemble(
[ConstantPolicy(priority=1, predict_index=1), OldPolicy(priority=1)]
)
tracker = DialogueStateTracker.from_events("test", evts=[])

with pytest.warns(DeprecationWarning):
prediction, winning_policy = ensemble.probabilities_using_best_policy(
tracker, default_domain, RegexInterpreter()
)

assert winning_policy == f"policy_1_{OldPolicy.__name__}"
assert prediction.index(max(prediction)) == expected_index

0 comments on commit 955ad22

Please sign in to comment.