Skip to content

Commit

Permalink
Merge pull request #9173 from RasaHQ/9135-core-test
Browse files Browse the repository at this point in the history
Remove `processor` replicated logic in `rasa.core.test`
  • Loading branch information
ancalita authored Jul 22, 2021
2 parents ac1197b + b8b6d2d commit 7c2206d
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 47 deletions.
2 changes: 2 additions & 0 deletions changelog/9135.misc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Remove `MessageProcessor` logic when determining whether to predict another action in `rasa.core.test` module.
Adapt `MessageProcessor.predict_next_action()` method to raise `ActionLimitReached` exception instead.
64 changes: 41 additions & 23 deletions rasa/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
import rasa.core.utils
from rasa.core.policies.policy import PolicyPrediction
from rasa.exceptions import ActionLimitReached
from rasa.shared.core.constants import (
USER_INTENT_RESTART,
ACTION_LISTEN_NAME,
Expand All @@ -33,6 +34,7 @@
ReminderScheduled,
SlotSet,
UserUttered,
ActionExecuted,
)
from rasa.shared.core.slots import Slot
from rasa.shared.core.training_data.story_reader.yaml_story_reader import (
Expand Down Expand Up @@ -374,8 +376,23 @@ def predict_next_action(
"""Predicts the next action the bot should take after seeing x.
This should be overwritten by more advanced policies to use
ML to predict the action. Returns the index of the next action.
ML to predict the action.
Returns:
The index of the next action and prediction of the policy.
Raises:
ActionLimitReached if the limit of actions to predict has been reached.
"""
should_predict_another_action = self.should_predict_another_action(
tracker.latest_action_name
)

if self.is_action_limit_reached(tracker, should_predict_another_action):
raise ActionLimitReached(
"The limit of actions to predict has been reached."
)

prediction = self._get_next_action_probabilities(tracker)

action = rasa.core.actions.action.action_for_index(
Expand Down Expand Up @@ -623,18 +640,27 @@ def _should_handle_message(tracker: DialogueStateTracker) -> bool:
)

def is_action_limit_reached(
self, num_predicted_actions: int, should_predict_another_action: bool
self, tracker: DialogueStateTracker, should_predict_another_action: bool,
) -> bool:
"""Check whether the maximum number of predictions has been met.
Args:
num_predicted_actions: Number of predicted actions.
tracker: instance of DialogueStateTracker.
should_predict_another_action: Whether the last executed action allows
for more actions to be predicted or not.
Returns:
`True` if the limit of actions to predict has been reached.
"""
reversed_events = list(tracker.events)[::-1]
num_predicted_actions = 0

for e in reversed_events:
if isinstance(e, ActionExecuted):
if e.action_name in (ACTION_LISTEN_NAME, ACTION_SESSION_START_NAME):
break
num_predicted_actions += 1

return (
num_predicted_actions >= self.max_number_of_predictions
and should_predict_another_action
Expand All @@ -645,33 +671,25 @@ async def _predict_and_execute_next_action(
) -> None:
# keep taking actions decided by the policy until it chooses to 'listen'
should_predict_another_action = True
num_predicted_actions = 0

# action loop. predicts actions until we hit action listen
while (
should_predict_another_action
and self._should_handle_message(tracker)
and num_predicted_actions < self.max_number_of_predictions
):
while should_predict_another_action and self._should_handle_message(tracker):
# this actually just calls the policy's method by the same name
action, prediction = self.predict_next_action(tracker)
try:
action, prediction = self.predict_next_action(tracker)
except ActionLimitReached:
logger.warning(
"Circuit breaker tripped. Stopped predicting "
f"more actions for sender '{tracker.sender_id}'."
)
if self.on_circuit_break:
# call a registered callback
self.on_circuit_break(tracker, output_channel, self.nlg)
break

should_predict_another_action = await self._run_action(
action, tracker, output_channel, self.nlg, prediction
)
num_predicted_actions += 1

if self.is_action_limit_reached(
num_predicted_actions, should_predict_another_action
):
# circuit breaker was tripped
logger.warning(
"Circuit breaker tripped. Stopped predicting "
f"more actions for sender '{tracker.sender_id}'."
)
if self.on_circuit_break:
# call a registered callback
self.on_circuit_break(tracker, output_channel, self.nlg)

@staticmethod
def should_predict_another_action(action_name: Text) -> bool:
Expand Down
38 changes: 14 additions & 24 deletions rasa/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from rasa.shared.importers.importer import TrainingDataImporter
from rasa.shared.utils.io import DEFAULT_ENCODING
from rasa.utils.tensorflow.constants import QUERY_INTENT_KEY, SEVERITY_KEY
from rasa.exceptions import ActionLimitReached

if typing.TYPE_CHECKING:
from rasa.core.agent import Agent
Expand Down Expand Up @@ -637,7 +638,6 @@ def _collect_action_executed_predictions(
partial_tracker: DialogueStateTracker,
event: ActionExecuted,
fail_on_prediction_errors: bool,
circuit_breaker_tripped: bool,
) -> Tuple[EvaluationStore, PolicyPrediction, Optional[EntityEvaluationResult]]:

action_executed_eval_store = EvaluationStore()
Expand All @@ -649,13 +649,13 @@ def _collect_action_executed_predictions(
policy_entity_result = None
prev_action_unlikely_intent = False

if circuit_breaker_tripped:
prediction = PolicyPrediction([], policy_name=None)
predicted_action = "circuit breaker tripped"
else:
try:
predicted_action, prediction, policy_entity_result = _run_action_prediction(
processor, partial_tracker, expected_action
)
except ActionLimitReached:
prediction = PolicyPrediction([], policy_name=None)
predicted_action = "circuit breaker tripped"

predicted_action_unlikely_intent = predicted_action == ACTION_UNLIKELY_INTENT_NAME
if predicted_action_unlikely_intent and predicted_action != expected_action:
Expand All @@ -671,9 +671,14 @@ def _collect_action_executed_predictions(
)
)
prev_action_unlikely_intent = True
predicted_action, prediction, policy_entity_result = _run_action_prediction(
processor, partial_tracker, expected_action
)

try:
predicted_action, prediction, policy_entity_result = _run_action_prediction(
processor, partial_tracker, expected_action
)
except ActionLimitReached:
prediction = PolicyPrediction([], policy_name=None)
predicted_action = "circuit breaker tripped"

action_executed_eval_store.add_to_store(
action_predictions=[predicted_action], action_targets=[expected_action]
Expand Down Expand Up @@ -761,25 +766,16 @@ async def _predict_tracker_actions(
)

tracker_actions = []
should_predict_another_action = True
num_predicted_actions = 0
policy_entity_results = []

for event in events[1:]:
if isinstance(event, ActionExecuted):
circuit_breaker_tripped = processor.is_action_limit_reached(
num_predicted_actions, should_predict_another_action
)
(
action_executed_result,
prediction,
entity_result,
) = _collect_action_executed_predictions(
processor,
partial_tracker,
event,
fail_on_prediction_errors,
circuit_breaker_tripped,
processor, partial_tracker, event, fail_on_prediction_errors,
)

if entity_result:
Expand All @@ -795,10 +791,6 @@ async def _predict_tracker_actions(
"confidence": prediction.max_confidence,
}
)
should_predict_another_action = processor.should_predict_another_action(
action_executed_result.action_predictions[0]
)
num_predicted_actions += 1

elif use_e2e and isinstance(event, UserUttered):
# This means that user utterance didn't have a user message, only intent,
Expand All @@ -818,8 +810,6 @@ async def _predict_tracker_actions(
tracker_eval_store.merge_store(user_uttered_result)
else:
partial_tracker.update(event)
if isinstance(event, UserUttered):
num_predicted_actions = 0

return tracker_eval_store, partial_tracker, tracker_actions, policy_entity_results

Expand Down
4 changes: 4 additions & 0 deletions rasa/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ def __init__(self, timestamp: float) -> None:
def __str__(self) -> Text:
"""Returns string representation of exception."""
return str(self.timestamp)


class ActionLimitReached(RasaException):
"""Raised when predicted action limit is reached."""
31 changes: 31 additions & 0 deletions tests/core/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
UserMessage,
OutputChannel,
)
from rasa.exceptions import ActionLimitReached
from rasa.shared.core.domain import SessionConfig, Domain, KEY_ACTIONS
from rasa.shared.core.events import (
ActionExecuted,
Expand Down Expand Up @@ -1297,3 +1298,33 @@ def test_predict_next_action_with_hidden_rules():
action, prediction = processor.predict_next_action(tracker)
assert isinstance(action, ActionListen)
assert not prediction.hide_rule_turn


def test_predict_next_action_raises_limit_reached_exception(domain: Domain):
interpreter = RegexInterpreter()
ensemble = SimplePolicyEnsemble(policies=[RulePolicy(), MemoizationPolicy()])
tracker_store = InMemoryTrackerStore(domain)
lock_store = InMemoryLockStore()

processor = MessageProcessor(
interpreter,
ensemble,
domain,
tracker_store,
lock_store,
TemplatedNaturalLanguageGenerator(domain.responses),
max_number_of_predictions=1,
)

tracker = DialogueStateTracker.from_events(
"test",
evts=[
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered("Hi!"),
ActionExecuted("test_action"),
],
)
tracker.set_latest_action({"action_name": "test_action"})

with pytest.raises(ActionLimitReached):
processor.predict_next_action(tracker)

0 comments on commit 7c2206d

Please sign in to comment.