Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Locking Mechanism to Reminders Handler #8001

Merged
merged 5 commits into from
Feb 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/8001.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed bug where the conversation does not lock before handling a reminder event.
1 change: 1 addition & 0 deletions rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,7 @@ def create_processor(
self.policy_ensemble,
self.domain,
self.tracker_store,
self.lock_store,
self.nlg,
action_endpoint=self.action_endpoint,
message_preprocessor=preprocessor,
Expand Down
37 changes: 21 additions & 16 deletions rasa/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
UTTER_PREFIX,
)
from rasa.core.nlg import NaturalLanguageGenerator
from rasa.core.lock_store import LockStore
from rasa.core.policies.ensemble import PolicyEnsemble
import rasa.core.tracker_store
import rasa.shared.core.trackers
Expand All @@ -63,6 +64,7 @@ def __init__(
policy_ensemble: PolicyEnsemble,
domain: Domain,
tracker_store: rasa.core.tracker_store.TrackerStore,
lock_store: LockStore,
generator: NaturalLanguageGenerator,
action_endpoint: Optional[EndpointConfig] = None,
max_number_of_predictions: int = MAX_NUMBER_OF_PREDICTIONS,
Expand All @@ -74,6 +76,7 @@ def __init__(
self.policy_ensemble = policy_ensemble
self.domain = domain
self.tracker_store = tracker_store
self.lock_store = lock_store
self.max_number_of_predictions = max_number_of_predictions
self.message_preprocessor = message_preprocessor
self.on_circuit_break = on_circuit_break
Expand Down Expand Up @@ -418,24 +421,26 @@ async def handle_reminder(
output_channel: OutputChannel,
) -> None:
"""Handle a reminder that is triggered asynchronously."""

tracker = await self.fetch_tracker_and_update_session(sender_id, output_channel)

if (
reminder_event.kill_on_user_message
and self._has_message_after_reminder(tracker, reminder_event)
or not self._is_reminder_still_valid(tracker, reminder_event)
):
logger.debug(
f"Canceled reminder because it is outdated ({reminder_event})."
)
else:
intent = reminder_event.intent
entities = reminder_event.entities or {}
await self.trigger_external_user_uttered(
intent, entities, tracker, output_channel
async with self.lock_store.lock(sender_id):
tracker = await self.fetch_tracker_and_update_session(
sender_id, output_channel
)

if (
reminder_event.kill_on_user_message
and self._has_message_after_reminder(tracker, reminder_event)
or not self._is_reminder_still_valid(tracker, reminder_event)
):
logger.debug(
f"Canceled reminder because it is outdated ({reminder_event})."
)
else:
intent = reminder_event.intent
entities = reminder_event.entities or {}
await self.trigger_external_user_uttered(
intent, entities, tracker, output_channel
)

async def trigger_external_user_uttered(
self,
intent_name: Text,
Expand Down
2 changes: 2 additions & 0 deletions tests/core/actions/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rasa.core.policies.policy import PolicyPrediction
from rasa.core.processor import MessageProcessor
from rasa.core.tracker_store import InMemoryTrackerStore
from rasa.core.lock_store import InMemoryLockStore
from rasa.core.actions import action
from rasa.core.actions.action import ActionExecutionRejection
from rasa.shared.core.constants import ACTION_LISTEN_NAME, REQUESTED_SLOT
Expand Down Expand Up @@ -144,6 +145,7 @@ async def test_switch_forms_with_same_slot(default_agent: Agent):
default_agent.policy_ensemble,
domain,
InMemoryTrackerStore(domain),
InMemoryLockStore(),
TemplatedNaturalLanguageGenerator(domain.templates),
)

Expand Down
3 changes: 3 additions & 0 deletions tests/core/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from rasa.core.processor import MessageProcessor
from rasa.shared.core.slots import Slot
from rasa.core.tracker_store import InMemoryTrackerStore, MongoTrackerStore
from rasa.core.lock_store import LockStore, InMemoryLockStore
from rasa.shared.core.trackers import DialogueStateTracker

DEFAULT_DOMAIN_PATH_WITH_SLOTS = "data/test_domains/default_with_slots.yml"
Expand Down Expand Up @@ -142,11 +143,13 @@ def default_channel() -> OutputChannel:
@pytest.fixture
async def default_processor(default_agent: Agent) -> MessageProcessor:
tracker_store = InMemoryTrackerStore(default_agent.domain)
lock_store = InMemoryLockStore()
return MessageProcessor(
default_agent.interpreter,
default_agent.policy_ensemble,
default_agent.domain,
tracker_store,
lock_store,
TemplatedNaturalLanguageGenerator(default_agent.domain.templates),
)

Expand Down
45 changes: 42 additions & 3 deletions tests/core/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import uuid
import json
from _pytest.monkeypatch import MonkeyPatch
from _pytest.logging import LogCaptureFixture
from aioresponses import aioresponses
from typing import Optional, Text, List, Callable, Type, Any, Tuple
from unittest.mock import patch, Mock
Expand Down Expand Up @@ -50,6 +51,7 @@
from rasa.core.processor import MessageProcessor
from rasa.shared.core.slots import Slot, AnySlot
from rasa.core.tracker_store import InMemoryTrackerStore
from rasa.core.lock_store import InMemoryLockStore
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.nlu.constants import INTENT_NAME_KEY
from rasa.utils.endpoints import EndpointConfig
Expand Down Expand Up @@ -134,7 +136,9 @@ async def test_http_parsing():

inter = RasaNLUHttpInterpreter(endpoint_config=endpoint)
try:
await MessageProcessor(inter, None, None, None, None).parse_message(message)
await MessageProcessor(inter, None, None, None, None, None).parse_message(
message
)
except KeyError:
pass # logger looks for intent and entities, so we except

Expand Down Expand Up @@ -204,6 +208,29 @@ async def test_reminder_scheduled(
)


async def test_reminder_lock(
default_channel: CollectingOutputChannel,
default_processor: MessageProcessor,
caplog: LogCaptureFixture,
):
caplog.clear()
with caplog.at_level(logging.DEBUG):
sender_id = uuid.uuid4().hex

reminder = ReminderScheduled("remind", datetime.datetime.now())
tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)

tracker.update(UserUttered("test"))
tracker.update(ActionExecuted("action_schedule_reminder"))
tracker.update(reminder)

default_processor.tracker_store.save(tracker)

await default_processor.handle_reminder(reminder, sender_id, default_channel)

assert f"Deleted lock for conversation '{sender_id}'." in caplog.text


async def test_trigger_external_latest_input_channel(
default_channel: CollectingOutputChannel, default_processor: MessageProcessor
):
Expand Down Expand Up @@ -853,7 +880,12 @@ def predict_action_probabilities(
domain = Domain.empty()

processor = MessageProcessor(
test_interpreter, ensemble, domain, InMemoryTrackerStore(domain), Mock()
test_interpreter,
ensemble,
domain,
InMemoryTrackerStore(domain),
InMemoryLockStore(),
Mock(),
)

# This should not raise
Expand Down Expand Up @@ -883,7 +915,12 @@ def test_get_next_action_probabilities_pass_policy_predictions_without_interpret
domain = Domain.empty()

processor = MessageProcessor(
interpreter, ensemble, domain, InMemoryTrackerStore(domain), Mock()
interpreter,
ensemble,
domain,
InMemoryTrackerStore(domain),
InMemoryLockStore(),
Mock(),
)

with pytest.warns(DeprecationWarning):
Expand Down Expand Up @@ -1173,11 +1210,13 @@ def probabilities_using_best_policy(
return PolicyPrediction.for_action_name(domain, ACTION_LISTEN_NAME)

tracker_store = InMemoryTrackerStore(domain)
lock_store = InMemoryLockStore()
processor = MessageProcessor(
RegexInterpreter(),
ConstantEnsemble(),
domain,
tracker_store,
lock_store,
NaturalLanguageGenerator.create(None, domain),
)

Expand Down