Skip to content

Commit

Permalink
Merge pull request #7136 from RasaHQ/end-to-end-policy-predictions
Browse files Browse the repository at this point in the history
Policy Predictions for End-to-End
  • Loading branch information
rasabot authored Nov 10, 2020
2 parents df7a5b9 + 2cc4ac8 commit a5dcad8
Show file tree
Hide file tree
Showing 24 changed files with 1,192 additions and 471 deletions.
3 changes: 3 additions & 0 deletions changelog/7136.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[Policies](policies.mdx) can now return obligatory and optional events as part of their
prediction. Obligatory events are always applied to the current conversation tracker.
Optional events are only applied to the conversation tracker if that policy's prediction is selected as the winning one.
26 changes: 26 additions & 0 deletions changelog/7136.removal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
The [`Policy`](policies.mdx) interface was changed to return a `PolicyPrediction` object when
`predict_action_probabilities` is called. Returning a list of probabilities directly
is deprecated and support for this will be removed in Rasa Open Source 3.0.

You can adapt your custom policy by wrapping your probabilities in a `PolicyPrediction`
object:

```python
from rasa.core.policies.policy import Policy, PolicyPrediction
# ... other imports

def predict_action_probabilities(
    self,
    tracker: DialogueStateTracker,
    domain: Domain,
    interpreter: NaturalLanguageInterpreter,
    **kwargs: Any,
) -> PolicyPrediction:
    probabilities = ...  # an action prediction of your policy
    return PolicyPrediction(probabilities, "policy_name", policy_priority=self.priority)
```

The same change was applied to the `PolicyEnsemble` interface. Instead of returning
a tuple of action probabilities and policy name, it now returns a
`PolicyPrediction` object. Support for the old `PolicyEnsemble` interface will be
removed in Rasa Open Source 3.0.
12 changes: 7 additions & 5 deletions rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from rasa.core.nlg import NaturalLanguageGenerator
from rasa.core.policies.ensemble import PolicyEnsemble, SimplePolicyEnsemble
from rasa.core.policies.memoization import MemoizationPolicy
from rasa.core.policies.policy import Policy
from rasa.core.policies.policy import Policy, PolicyPrediction
from rasa.core.processor import MessageProcessor
from rasa.core.tracker_store import (
FailSafeTrackerStore,
Expand Down Expand Up @@ -548,14 +548,16 @@ async def execute_action(
sender_id: Text,
action: Text,
output_channel: OutputChannel,
policy: Text,
confidence: float,
policy: Optional[Text],
confidence: Optional[float],
) -> Optional[DialogueStateTracker]:
"""Handle a single message."""

processor = self.create_processor()
prediction = PolicyPrediction.for_action_name(
self.domain, action, policy, confidence or 0.0
)
return await processor.execute_action(
sender_id, action, output_channel, self.nlg, policy, confidence
sender_id, action, output_channel, self.nlg, prediction
)

async def trigger_intent(
Expand Down
166 changes: 103 additions & 63 deletions rasa/core/policies/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import rasa.core
import rasa.core.training.training
from rasa.core.constants import FALLBACK_POLICY_PRIORITY
from rasa.shared.exceptions import RasaException
import rasa.shared.utils.common
import rasa.shared.utils.io
Expand All @@ -33,7 +34,7 @@
from rasa.core.exceptions import UnsupportedDialogueModelError
from rasa.core.featurizers.tracker_featurizers import MaxHistoryTrackerFeaturizer
from rasa.shared.nlu.interpreter import NaturalLanguageInterpreter, RegexInterpreter
from rasa.core.policies.policy import Policy, SupportedData
from rasa.core.policies.policy import Policy, SupportedData, PolicyPrediction
from rasa.core.policies.fallback import FallbackPolicy
from rasa.core.policies.memoization import MemoizationPolicy, AugmentedMemoizationPolicy
from rasa.core.policies.rule_policy import RulePolicy
Expand All @@ -55,7 +56,7 @@ def __init__(
self.policies = policies
self.date_trained = None

self.action_fingerprints = action_fingerprints
self.action_fingerprints = action_fingerprints or []

self._check_priorities()
self._check_for_important_policies()
Expand Down Expand Up @@ -203,7 +204,7 @@ def probabilities_using_best_policy(
domain: Domain,
interpreter: NaturalLanguageInterpreter,
**kwargs: Any,
) -> Tuple[List[float], Optional[Text]]:
) -> PolicyPrediction:
raise NotImplementedError

def _max_histories(self) -> List[Optional[int]]:
Expand Down Expand Up @@ -449,22 +450,36 @@ def _check_if_rule_policy_used_with_rule_like_policies(
)


class Prediction(NamedTuple):
"""Stores the probabilities and the priority of the prediction."""

probabilities: List[float]
priority: int


class SimplePolicyEnsemble(PolicyEnsemble):
"""Default implementation of a `Policy` ensemble."""

@staticmethod
def is_not_memo_policy(
policy_name: Text, max_confidence: Optional[float] = None
def is_not_in_training_data(
policy_name: Optional[Text], max_confidence: Optional[float] = None
) -> bool:
is_memo = policy_name.endswith("_" + MemoizationPolicy.__name__)
is_augmented = policy_name.endswith("_" + AugmentedMemoizationPolicy.__name__)
"""Checks if the prediction is by a policy which memoized the training data.
Args:
policy_name: The name of the policy.
max_confidence: The max confidence of the policy's prediction.
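Returns: `True` if no policy name is given, if the policy is not a `RulePolicy`, `MemoizationPolicy` or `AugmentedMemoizationPolicy`, or if `max_confidence` is 0; `False` otherwise.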
"""
if not policy_name:
return True

memorizing_policies = [
RulePolicy.__name__,
MemoizationPolicy.__name__,
AugmentedMemoizationPolicy.__name__,
]
is_memorized = any(
policy_name.endswith(f"_{memoizing_policy}")
for memoizing_policy in memorizing_policies
)

# also check if confidence is 0, then it cannot be counted as a prediction
return not (is_memo or is_augmented) or max_confidence == 0.0
return not is_memorized or max_confidence == 0.0

@staticmethod
def _is_not_mapping_policy(
Expand All @@ -483,22 +498,19 @@ def _is_form_policy(policy_name: Text) -> bool:
return policy_name.endswith("_" + FormPolicy.__name__)

def _pick_best_policy(
self, predictions: Dict[Text, Prediction]
) -> Tuple[List[float], Optional[Text]]:
self, predictions: Dict[Text, PolicyPrediction]
) -> PolicyPrediction:
"""Picks the best policy prediction based on probabilities and policy priority.
Args:
predictions: the dictionary containing policy name as keys
and predictions as values
Returns:
best_probabilities: the list of probabilities for the next actions
best_policy_name: the name of the picked policy
The best prediction.
"""

best_confidence = (-1, -1)
best_policy_name = None

# form and mapping policies are special:
# form should be above fallback
# mapping should be below fallback
Expand All @@ -507,9 +519,19 @@ def _pick_best_policy(

form_confidence = None
form_policy_name = None
# End-to-end predictions overrule all other predictions.
use_only_end_to_end = any(
prediction.is_end_to_end_prediction for prediction in predictions.values()
)
policy_events = []

for policy_name, prediction in predictions.items():
confidence = (max(prediction.probabilities), prediction.priority)
policy_events += prediction.events

if prediction.is_end_to_end_prediction != use_only_end_to_end:
continue

confidence = (prediction.max_confidence, prediction.policy_priority)
if self._is_form_policy(policy_name):
# store form prediction separately
form_confidence = confidence
Expand All @@ -526,14 +548,24 @@ def _pick_best_policy(
if form_confidence > best_confidence:
best_policy_name = form_policy_name

return predictions[best_policy_name].probabilities, best_policy_name
best_prediction = predictions[best_policy_name]

policy_events += best_prediction.optional_events

return PolicyPrediction(
best_prediction.probabilities,
best_policy_name,
best_prediction.policy_priority,
policy_events,
is_end_to_end_prediction=best_prediction.is_end_to_end_prediction,
)

def _best_policy_prediction(
self,
tracker: DialogueStateTracker,
domain: Domain,
interpreter: NaturalLanguageInterpreter,
) -> Tuple[List[float], Optional[Text]]:
) -> PolicyPrediction:
"""Finds the best policy prediction.
Args:
Expand All @@ -543,8 +575,7 @@ def _best_policy_prediction(
additional features.
Returns:
probabilities: the list of probabilities for the next actions
policy_name: the name of the picked policy
The winning policy prediction.
"""
# find rejected action before running the policies
# because some of them might add events
Expand Down Expand Up @@ -588,16 +619,17 @@ def _get_prediction(
tracker: DialogueStateTracker,
domain: Domain,
interpreter: NaturalLanguageInterpreter,
) -> Prediction:
) -> PolicyPrediction:
number_of_arguments_in_rasa_1_0 = 2
arguments = rasa.shared.utils.common.arguments_of(
policy.predict_action_probabilities
)

if (
len(arguments) > number_of_arguments_in_rasa_1_0
and "interpreter" in arguments
):
probabilities = policy.predict_action_probabilities(
prediction = policy.predict_action_probabilities(
tracker, domain, interpreter
)
else:
Expand All @@ -608,58 +640,69 @@ def _get_prediction(
"adapt your custom `Policy` implementation.",
category=DeprecationWarning,
)
probabilities = policy.predict_action_probabilities(
prediction = policy.predict_action_probabilities(
tracker, domain, RegexInterpreter()
)

return Prediction(probabilities, policy.priority)
if isinstance(prediction, list):
rasa.shared.utils.io.raise_deprecation_warning(
f"The function `predict_action_probabilities` of "
f"the `{Policy.__name__}` interface was changed to return "
f"a `{PolicyPrediction.__name__}` object. Please make sure to "
f"adapt your custom `{Policy.__name__}` implementation. Support for "
f"returning a list of floats will be removed in Rasa Open Source 3.0.0"
)
prediction = PolicyPrediction(
prediction, policy.__class__.__name__, policy_priority=policy.priority
)

return prediction

def _fallback_after_listen(
self, domain: Domain, probabilities: List[float], policy_name: Text
) -> Tuple[List[float], Text]:
self, domain: Domain, prediction: PolicyPrediction
) -> PolicyPrediction:
"""Triggers fallback if `action_listen` is predicted after a user utterance.
This is done on the condition that:
- a fallback policy is present,
- there was just a user message and the predicted
action is action_listen by a policy
other than the MemoizationPolicy
- we received a user message and the predicted action is `action_listen`
by a policy other than the `MemoizationPolicy` or one of its subclasses.
Args:
domain: the :class:`rasa.shared.core.domain.Domain`
probabilities: the list of probabilities for the next actions
policy_name: the name of the picked policy
prediction: The winning prediction.
Returns:
probabilities: the list of probabilities for the next actions
policy_name: the name of the picked policy
The prediction for the next action.
"""

fallback_idx_policy = [
(i, p) for i, p in enumerate(self.policies) if isinstance(p, FallbackPolicy)
]

if fallback_idx_policy:
fallback_idx, fallback_policy = fallback_idx_policy[0]
if not fallback_idx_policy:
return prediction

logger.debug(
f"Action 'action_listen' was predicted after "
f"a user message using {policy_name}. Predicting "
f"fallback action: {fallback_policy.fallback_action_name}"
)
fallback_idx, fallback_policy = fallback_idx_policy[0]

probabilities = fallback_policy.fallback_scores(domain)
policy_name = f"policy_{fallback_idx}_{type(fallback_policy).__name__}"
logger.debug(
f"Action '{ACTION_LISTEN_NAME}' was predicted after "
f"a user message using {prediction.policy_name}. Predicting "
f"fallback action: {fallback_policy.fallback_action_name}"
)

return probabilities, policy_name
return PolicyPrediction(
fallback_policy.fallback_scores(domain),
f"policy_{fallback_idx}_{type(fallback_policy).__name__}",
FALLBACK_POLICY_PRIORITY,
)

def probabilities_using_best_policy(
self,
tracker: DialogueStateTracker,
domain: Domain,
interpreter: NaturalLanguageInterpreter,
**kwargs: Any,
) -> Tuple[List[float], Optional[Text]]:
) -> PolicyPrediction:
"""Predicts the next action the bot should take after seeing the tracker.
Picks the best policy prediction based on probabilities and policy priority.
Expand All @@ -675,24 +718,21 @@ def probabilities_using_best_policy(
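The winning policy prediction. If `action_listen` was predicted right after a user
message by a policy that did not memoize the training data, the fallback prediction is returned instead (when a fallback policy is present).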
"""

probabilities, policy_name = self._best_policy_prediction(
tracker, domain, interpreter
)
winning_prediction = self._best_policy_prediction(tracker, domain, interpreter)

if (
tracker.latest_action_name == ACTION_LISTEN_NAME
and probabilities is not None
and probabilities.index(max(probabilities))
and winning_prediction.probabilities is not None
and winning_prediction.max_confidence_index
== domain.index_for_action(ACTION_LISTEN_NAME)
and self.is_not_memo_policy(policy_name, max(probabilities))
):
probabilities, policy_name = self._fallback_after_listen(
domain, probabilities, policy_name
and self.is_not_in_training_data(
winning_prediction.policy_name, winning_prediction.max_confidence
)
):
winning_prediction = self._fallback_after_listen(domain, winning_prediction)

logger.debug(f"Predicted next action using {policy_name}")
return probabilities, policy_name
logger.debug(f"Predicted next action using {winning_prediction.policy_name}.")
return winning_prediction


def _check_policy_for_forms_available(
Expand Down
Loading

0 comments on commit a5dcad8

Please sign in to comment.