Skip to content

Commit

Permalink
Merge pull request #5446 from RasaHQ/append-events-without-session-start
Browse files Browse the repository at this point in the history
No session start when appending events via API
  • Loading branch information
wochinge authored Mar 20, 2020
2 parents 85ff081 + 5f9453d commit a39bc32
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 33 deletions.
6 changes: 6 additions & 0 deletions changelog/5446.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
The endpoint ``PUT /conversations/<conversation_id>/tracker/events`` no longer
adds session start events (to learn more about conversation sessions, please
see :ref:`session_config`) in addition to the events which were sent in the request
payload. To achieve the old behavior send a
``GET /conversations/<conversation_id>/tracker``
request before appending events.
28 changes: 21 additions & 7 deletions rasa/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,34 @@ async def get_tracker_with_session_start(
Tracker for `sender_id` if available, `None` otherwise.
"""

tracker = self._get_tracker(sender_id)
tracker = self.get_tracker(sender_id)
if not tracker:
return None

await self._update_tracker_session(tracker, output_channel)

return tracker

def get_tracker(self, conversation_id: Text) -> Optional[DialogueStateTracker]:
"""Get the tracker for a conversation.
In contrast to `get_tracker_with_session_start` this does not add any
`action_session_start` or `session_start` events at the beginning of a
conversation.
Args:
conversation_id: The ID of the conversation for which the history should be
retrieved.
Returns:
Tracker for the conversation. Creates an empty tracker in case it's a new
conversation.
"""
conversation_id = conversation_id or UserMessage.DEFAULT_SENDER_ID
return self.tracker_store.get_or_create_tracker(
conversation_id, append_action_listen=False
)

async def log_message(
self, message: UserMessage, should_save_tracker: bool = True
) -> Optional[DialogueStateTracker]:
Expand Down Expand Up @@ -713,12 +733,6 @@ def _has_session_expired(self, tracker: DialogueStateTracker) -> bool:

return has_expired

def _get_tracker(self, sender_id: Text) -> Optional[DialogueStateTracker]:
sender_id = sender_id or UserMessage.DEFAULT_SENDER_ID
return self.tracker_store.get_or_create_tracker(
sender_id, append_action_listen=False
)

def _save_tracker(self, tracker: DialogueStateTracker) -> None:
self.tracker_store.save(tracker)

Expand Down
31 changes: 14 additions & 17 deletions rasa/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,17 +214,25 @@ def event_verbosity_parameter(

async def get_tracker(
processor: "MessageProcessor", conversation_id: Text
) -> Optional[DialogueStateTracker]:
) -> DialogueStateTracker:
"""Get tracker object from `MessageProcessor`."""
tracker = await processor.get_tracker_with_session_start(conversation_id)
_validate_tracker(tracker, conversation_id)

# `_validate_tracker` ensures we can't return `None` so `Optional` is not needed
return tracker # pytype: disable=bad-return-type


def _validate_tracker(
tracker: Optional[DialogueStateTracker], conversation_id: Text
) -> None:
if not tracker:
raise ErrorResponse(
409,
"Conflict",
f"Could not retrieve tracker with id '{conversation_id}'. Most likely "
f"Could not retrieve tracker with ID '{conversation_id}'. Most likely "
f"because there is no domain set on the agent.",
)
return tracker


def validate_request_body(request: Request, error_message: Text):
Expand Down Expand Up @@ -479,12 +487,10 @@ async def append_events(request: Request, conversation_id: Text):

try:
async with app.agent.lock_store.lock(conversation_id):
tracker = await get_tracker(
app.agent.create_processor(), conversation_id
)
processor = app.agent.create_processor()
tracker = processor.get_tracker(conversation_id)
_validate_tracker(tracker, conversation_id)

# Get events after tracker initialization to ensure that generated
# timestamps are after potential session events.
events = _get_events_from_request_body(request)

for event in events:
Expand Down Expand Up @@ -590,15 +596,6 @@ async def execute_action(request: Request, conversation_id: Text):
{"parameter": "name", "in": "body"},
)

# Deprecation warning
raise_warning(
"Triggering actions via the execute endpoint is deprecated. "
"Trigger an intent via the "
"`/conversations/<conversation_id>/trigger_intent` "
"endpoint instead.",
FutureWarning,
)

policy = request_params.get("policy", None)
confidence = request_params.get("confidence", None)
verbosity = event_verbosity_parameter(request, EventVerbosity.AFTER_RESTART)
Expand Down
18 changes: 9 additions & 9 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,14 +627,15 @@ def test_requesting_non_existent_tracker(rasa_app: SanicTestClient):


@pytest.mark.parametrize("event", test_events)
def test_pushing_event(rasa_app, event):
def test_pushing_event(rasa_app: SanicTestClient, event: Event):
sender_id = str(uuid.uuid1())
conversation = f"/conversations/{sender_id}"

serialized_event = event.as_dict()
# Remove timestamp so that a new one is assigned on the server
serialized_event.pop("timestamp")

time_before_adding_events = time.time()
_, response = rasa_app.post(
f"{conversation}/tracker/events",
json=serialized_event,
Expand All @@ -647,17 +648,17 @@ def test_pushing_event(rasa_app, event):
tracker = tracker_response.json
assert tracker is not None

assert len(tracker.get("events")) == 4
assert len(tracker.get("events")) == 1

evt = tracker.get("events")[3]
evt = tracker.get("events")[0]
deserialised_event = Event.from_parameters(evt)
assert deserialised_event == event
assert deserialised_event.timestamp > tracker.get("events")[2]["timestamp"]
assert deserialised_event.timestamp > time_before_adding_events


def test_push_multiple_events(rasa_app: SanicTestClient):
cid = str(uuid.uuid1())
conversation = f"/conversations/{cid}"
conversation_id = str(uuid.uuid1())
conversation = f"/conversations/{conversation_id}"

events = [e.as_dict() for e in test_events]
_, response = rasa_app.post(
Expand All @@ -668,13 +669,12 @@ def test_push_multiple_events(rasa_app: SanicTestClient):
assert response.json is not None
assert response.status == 200

_, tracker_response = rasa_app.get(f"/conversations/{cid}/tracker")
_, tracker_response = rasa_app.get(f"/conversations/{conversation_id}/tracker")
tracker = tracker_response.json
assert tracker is not None

# there is also an `ACTION_LISTEN` event at the start
assert len(tracker.get("events")) == len(test_events) + 3
assert tracker.get("events")[3:] == events
assert tracker.get("events") == events


def test_put_tracker(rasa_app: SanicTestClient):
Expand Down

0 comments on commit a39bc32

Please sign in to comment.