Skip to content

Commit

Permalink
apply policy events in processor
Browse files Browse the repository at this point in the history
  • Loading branch information
wochinge committed Nov 4, 2020
1 parent 7d1ed27 commit 20e64bd
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 11 deletions.
7 changes: 5 additions & 2 deletions rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from rasa.core.nlg import NaturalLanguageGenerator
from rasa.core.policies.ensemble import PolicyEnsemble, SimplePolicyEnsemble
from rasa.core.policies.memoization import MemoizationPolicy
from rasa.core.policies.policy import Policy
from rasa.core.policies.policy import Policy, PolicyPrediction
from rasa.core.processor import MessageProcessor
from rasa.core.tracker_store import (
FailSafeTrackerStore,
Expand Down Expand Up @@ -554,8 +554,11 @@ async def execute_action(
"""Handle a single message."""

processor = self.create_processor()
prediction = PolicyPrediction.for_action_name(
self.domain, action, policy, confidence
)
return await processor.execute_action(
sender_id, action, output_channel, self.nlg, policy, confidence
sender_id, action, output_channel, self.nlg, prediction
)

async def trigger_intent(
Expand Down
1 change: 0 additions & 1 deletion rasa/core/policies/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,6 @@ def _pick_best_policy(

best_prediction = predictions[best_policy_name]

# Apply policy events to tracker
policy_events += best_prediction.optional_events

return PolicyPrediction(
Expand Down
5 changes: 4 additions & 1 deletion rasa/core/policies/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,10 @@ def __init__(

@staticmethod
def for_action_name(
domain: Domain, action_name: Text, policy_name: Text, confidence: float = 1
domain: Domain,
action_name: Text,
policy_name: Optional[Text] = None,
confidence: float = 1,
) -> "PolicyPrediction":
"""Create a prediction for a given action.
Expand Down
27 changes: 20 additions & 7 deletions rasa/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ async def _update_tracker_session(
output_channel=output_channel,
nlg=self.nlg,
metadata=metadata,
prediction=PolicyPrediction.for_action_name(
self.domain, ACTION_SESSION_START_NAME
),
)

async def fetch_tracker_and_update_session(
Expand Down Expand Up @@ -284,16 +287,16 @@ async def execute_action(
action_name: Text,
output_channel: OutputChannel,
nlg: NaturalLanguageGenerator,
policy: Text,
confidence: float,
prediction: PolicyPrediction,
) -> Optional[DialogueStateTracker]:

# we have a Tracker instance for each user
# which maintains conversation state
tracker = await self.fetch_tracker_and_update_session(sender_id, output_channel)

action = self._get_action(action_name)
await self._run_action(action, tracker, output_channel, nlg, policy, confidence)

await self._run_action(action, tracker, output_channel, nlg, prediction)

# save tracker state to continue conversation from this state
self._save_tracker(tracker)
Expand Down Expand Up @@ -699,6 +702,7 @@ async def _run_action(
# be passed to the SessionStart event. Otherwise the metadata will be lost.
if action.name() == ACTION_SESSION_START_NAME:
action.metadata = metadata
# TODO: Needs temporary events with policy events
events = await action.run(output_channel, nlg, tracker, self.domain)
except rasa.core.actions.action.ActionExecutionRejection:
events = [
Expand Down Expand Up @@ -770,23 +774,32 @@ def _log_action_on_tracker(
if events is None:
events = []

logger.debug(
f"Action '{action_name}' ended with events '{[e for e in events]}'."
)

self._warn_about_new_slots(tracker, action_name, events)

action_was_rejected_manually = any(
isinstance(event, ActionExecutionRejected) for event in events
)
if action_name is not None and not action_was_rejected_manually:
logger.debug(f"Policy prediction ended with events '{prediction.events}'.")

# Apply events from policy predictions
# TODO: Test this works and think about location
for e in prediction.events:
# this makes sure the events are ordered by timestamp -
# since the event objects are created somewhere else,
# the timestamp would indicate a time before the time
# of the action executed
e.timestamp = time.time()
tracker.update(e, self.domain)

# log the action and its produced events
tracker.update(
ActionExecuted(
action_name, prediction.policy_name, prediction.max_confidence
)
)

logger.debug(f"Action '{action_name}' ended with events '{events}'.")
for e in events:
# this makes sure the events are ordered by timestamp -
# since the event objects are created somewhere else,
Expand Down

0 comments on commit 20e64bd

Please sign in to comment.