Skip to content

Commit

Permalink
Merge pull request #4544 from RasaHQ/non-binary-tracker-serialisation
Browse files Browse the repository at this point in the history
Non-binary tracker serialisation
  • Loading branch information
ricwo authored Oct 9, 2019
2 parents e39a0d3 + 8e5cfa1 commit 39300da
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 29 deletions.
10 changes: 8 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ This project adheres to `Semantic Versioning`_ starting with version 1.0.

Added
-----
- log a warning if the ``Interpreter`` picks up an intent or an entity that does not exist in the domain file.
- log a warning if the ``Interpreter`` picks up an intent or an entity that does not
exist in the domain file.
- added ``DynamoTrackerStore`` to support persistence of agents running on AWS
- added docstrings for ``TrackerStore`` classes
- added buttons and images to mattermost.
- `CRFEntityExtractor` updated to accept arbitrary token-level features like word vectors (issues/4214)
- `CRFEntityExtractor` updated to accept arbitrary token-level features like word
vectors (issues/4214)
- `SpacyFeaturizer` updated to add `ner_features` for `CRFEntityExtractor`
- Sanitizing incoming messages from slack to remove slack formatting like <mailto:name@example.com|name@example.com>
or <http://url.com|url.com> and substitute it with original content
Expand All @@ -33,6 +35,10 @@ Changed
You can do so by overwriting the method ``get_metadata``. The return value of this
method will be passed to the ``UserMessage`` object.
- Tests can now be run in parallel
- Serialise ``DialogueStateTracker`` as json instead of pickle. **DEPRECATION warning**:
Deserialisation of pickled trackers will be deprecated in version 2.0. For now,
trackers are still loaded from pickle but will be dumped as json in any subsequent
save operations.

Removed
-------
Expand Down
29 changes: 23 additions & 6 deletions rasa/core/conversation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import typing
from typing import Dict, List, Text
from typing import Dict, List, Text, Any

if typing.TYPE_CHECKING:
from rasa.core.events import Event
from rasa.core.events import Event


class Dialogue(object):
"""A dialogue comprises a list of Turn objects"""

def __init__(self, name: Text, events: List["Event"]) -> None:
"""This function initialises the dialogue with the dialogue name and the event list."""
"""This function initialises the dialogue with the dialogue name and the event
list."""
self.name = name
self.events = events

Expand All @@ -20,5 +19,23 @@ def __str__(self) -> Text:
)

def as_dict(self) -> Dict:
"""This function returns the dialogue as a dictionary to assist in serialization"""
"""This function returns the dialogue as a dictionary to assist in
serialization."""
return {"events": [event.as_dict() for event in self.events], "name": self.name}

@classmethod
def from_parameters(cls, parameters: Dict[Text, Any]) -> "Dialogue":
    """Build a `Dialogue` from its serialised dictionary form.

    Every entry under 'events' is converted with `Event.from_parameters`.

    Args:
        parameters: Serialised dialogue; expected to provide the keys
            'name' and 'events'. Both are read with `dict.get`, so a
            missing 'name' yields `None`, while a missing 'events' makes
            the iteration below fail with a `TypeError`.

    Returns:
        The deserialised `Dialogue`.
    """
    dialogue_name = parameters.get("name")
    serialised_events = parameters.get("events")

    restored_events = [
        Event.from_parameters(serialised) for serialised in serialised_events
    ]

    return cls(dialogue_name, restored_events)
58 changes: 44 additions & 14 deletions rasa/core/tracker_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
import typing
from datetime import datetime, timezone
from typing import Iterator, Optional, Text, Iterable, Union, Dict

import itertools
from boto3.dynamodb.conditions import Key

# noinspection PyPep8Naming
from time import sleep

import boto3
from boto3.dynamodb.conditions import Key

from rasa.core.actions.action import ACTION_LISTEN_NAME
from rasa.core.brokers.event_channel import EventChannel
from rasa.core.conversation import Dialogue
from rasa.core.domain import Domain
from rasa.core.trackers import ActionExecuted, DialogueStateTracker, EventVerbosity
from rasa.core.utils import replace_floats_with_decimals
Expand All @@ -26,6 +26,7 @@
from sqlalchemy.engine.url import URL
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session
import boto3

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -156,21 +157,45 @@ def keys(self) -> Iterable[Text]:
raise NotImplementedError()

@staticmethod
def serialise_tracker(tracker):
"""Serializes the tracker, returns representation of the tracker"""
def serialise_tracker(tracker: DialogueStateTracker) -> Text:
"""Serializes the tracker, returns representation of the tracker."""
dialogue = tracker.as_dialogue()
return pickle.dumps(dialogue)

def deserialise_tracker(self, sender_id, _json) -> Optional[DialogueStateTracker]:
"""Deserializes the tracker and returns it"""
dialogue = pickle.loads(_json)
return json.dumps(dialogue.as_dict())

@staticmethod
def _deserialise_dialogue_from_pickle(
    sender_id: Text, serialised_tracker: bytes
) -> Dialogue:
    """Restore a dialogue from its legacy pickle serialisation.

    A deprecation warning naming the conversation is logged before the
    payload is unpickled.

    Args:
        sender_id: Conversation ID the pickled tracker belongs to; only
            used in the warning message.
        serialised_tracker: Pickled payload to load.

    Returns:
        Whatever object `pickle.loads` reconstructs from the payload
        (annotated as a `Dialogue`).
    """
    deprecation_notice = (
        f"DEPRECATION warning: Found pickled tracker for "
        f"conversation ID '{sender_id}'. Deserialisation of pickled "
        f"trackers will be deprecated in version 2.0. Rasa will perform any "
        f"future save operations of this tracker using json serialisation."
    )
    logger.warning(deprecation_notice)

    # NOTE(review): `pickle.loads` can execute arbitrary code embedded in the
    # payload, so this is only safe while the tracker store contents are
    # trusted -- confirm for each backend.
    return pickle.loads(serialised_tracker)

def deserialise_tracker(
self, sender_id: Text, serialised_tracker: Union[Text, bytes]
) -> Optional[DialogueStateTracker]:
"""Deserializes the tracker and returns it."""

tracker = self.init_tracker(sender_id)
if tracker:
tracker.recreate_from_dialogue(dialogue)
return tracker
else:
if not tracker:
return None

try:
dialogue = Dialogue.from_parameters(json.loads(serialised_tracker))
except UnicodeDecodeError:
dialogue = self._deserialise_dialogue_from_pickle(
sender_id, serialised_tracker
)

tracker.recreate_from_dialogue(dialogue)

return tracker


class InMemoryTrackerStore(TrackerStore):
"""Stores conversation history in memory"""
Expand Down Expand Up @@ -274,6 +299,8 @@ def __init__(
table_name: The name of the DynamoDb table, does not need to be present a priori.
event_broker:
"""
import boto3

self.client = boto3.client("dynamodb", region_name=region)
self.region = region
self.table_name = table_name
Expand All @@ -284,6 +311,8 @@ def get_or_create_table(
self, table_name: Text
) -> "boto3.resources.factory.dynamodb.Table":
"""Returns table or creates one if the table name is not in the table list"""
import boto3

dynamo = boto3.resource("dynamodb", region_name=self.region)
if self.table_name not in self.client.list_tables()["TableNames"]:
table = dynamo.create_table(
Expand Down Expand Up @@ -323,7 +352,8 @@ def serialise_tracker(self, tracker: "DialogueStateTracker") -> Dict:
def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
"""Create a tracker from all previously stored events."""

# Retrieve dialogues for a sender_id in reverse chronological order based on the session_date sort key
# Retrieve dialogues for a sender_id in reverse chronological order based on
# the session_date sort key
dialogues = self.db.query(
KeyConditionExpression=Key("sender_id").eq(sender_id),
Limit=1,
Expand Down
23 changes: 23 additions & 0 deletions tests/cli/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
parse_last_positional_argument_as_model_path,
get_validated_path,
)
from tests.conftest import assert_log_emitted


@pytest.mark.parametrize(
Expand Down Expand Up @@ -84,3 +85,25 @@ def test_validate_with_invalid_directory_if_default_is_valid(caplog: LogCaptureF
def test_print_error_and_exit():
with pytest.raises(SystemExit):
rasa.cli.utils.print_error_and_exit("")


def test_logging_capture(caplog: LogCaptureFixture):
    """Check that `assert_log_emitted` accepts a message that was logged."""
    test_logger = logging.getLogger(__name__)
    expected_message = "SOME INFO"

    # emit an INFO record; the helper must find it and therefore not raise
    test_logger.info(expected_message)

    with assert_log_emitted(
        caplog, test_logger.name, logging.INFO, expected_message
    ):
        pass


def test_logging_capture_failure(caplog: LogCaptureFixture):
    """Check that `assert_log_emitted` raises for text that was never logged."""
    test_logger = logging.getLogger(__name__)

    # the only record emitted is "SOME INFO"
    test_logger.info("SOME INFO")

    # looking for a different string must end in an `AssertionError`
    missing_text_check = assert_log_emitted(
        caplog, test_logger.name, logging.INFO, "NONONO"
    )
    with pytest.raises(AssertionError), missing_text_check:
        pass
49 changes: 48 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from contextlib import contextmanager
from typing import Text, List

import pytest
Expand Down Expand Up @@ -218,5 +219,51 @@ async def rasa_server_without_api() -> Sanic:
def get_test_client(server):
    """Return `server.test_client` after resetting its `port` to `None`.

    Args:
        server: Object exposing a `test_client` attribute.

    Returns:
        The object read from `server.test_client`, with `port` set to `None`.
    """
    # read the attribute exactly once so the object that gets modified is the
    # one that is returned
    client = server.test_client
    client.port = None
    return client


@contextmanager
def assert_log_emitted(
    _caplog: LogCaptureFixture, logger_name: Text, log_level: int, text: Text = None
) -> None:
    """Context manager asserting that a matching log record exists.

    The wrapped block runs first. Once it exits, every record tuple held by
    `_caplog` at that moment is inspected, and an `AssertionError` is raised
    if none of them matches.

    Example usage:

    ```
    with assert_log_emitted(caplog, LOGGER_NAME, LOGGING_LEVEL, TEXT):
        <method supposed to emit TEXT at level LOGGING_LEVEL>
    ```

    Args:
        _caplog: `LogCaptureFixture` used to capture logs.
        logger_name: Name of the logger being examined.
        log_level: Log level to be tested.
        text: Substring expected in the log message (optional). If left
            blank, only `logger_name` and `log_level` are compared.

    Yields:
        `None`

    Raises:
        AssertionError: If no captured record matches the expectation.
    """

    yield

    captured = _caplog.record_tuples

    for name, level, message in captured:
        if name != logger_name or level != log_level:
            continue
        # an empty or missing `text` places no constraint on the message
        if not text or text in message:
            return

    raise AssertionError(
        f"Did not detect expected logging output.\nExpected output is (logger "
        f"name, log level, text): ({logger_name}, {log_level}, {text})\n"
        f"Instead found records:\n{captured}"
    )
12 changes: 11 additions & 1 deletion tests/core/test_dialogues.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import pytest

import rasa.utils.io
from rasa.core.conversation import Dialogue
from rasa.core.domain import Domain
from rasa.core.tracker_store import InMemoryTrackerStore
from tests.core.utilities import tracker_from_dialogue_file
from tests.core.conftest import TEST_DIALOGUES, EXAMPLE_DOMAINS
from tests.core.utilities import tracker_from_dialogue_file


@pytest.mark.parametrize("filename", TEST_DIALOGUES)
Expand Down Expand Up @@ -36,3 +37,12 @@ def test_tracker_restaurant():
tracker = tracker_from_dialogue_file(filename, domain)
assert tracker.get_slot("price") == "lo"
assert tracker.get_slot("name") is None # slot doesn't exist!


def test_dialogue_from_parameters():
    """Round-trip a tracker's dialogue through JSON and `Dialogue.from_parameters`."""
    tracker = tracker_from_dialogue_file(
        "data/test_dialogues/restaurantbot.json",
        Domain.load("examples/restaurantbot/domain.yml"),
    )

    # serialise via the tracker store, then rebuild the dialogue from the JSON
    serialised = InMemoryTrackerStore.serialise_tracker(tracker)
    restored = Dialogue.from_parameters(json.loads(serialised))

    assert tracker.as_dialogue().as_dict() == restored.as_dict()
46 changes: 41 additions & 5 deletions tests/core/test_tracker_stores.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import logging
import tempfile
from typing import Tuple

import pytest
from _pytest.logging import LogCaptureFixture
from moto import mock_dynamodb2

from rasa.core.channels.channel import UserMessage
Expand All @@ -13,8 +16,10 @@
SQLTrackerStore,
DynamoTrackerStore,
)

import rasa.core.tracker_store
from rasa.core.trackers import DialogueStateTracker
from rasa.utils.endpoints import EndpointConfig, read_endpoint_config
from tests.conftest import assert_log_emitted
from tests.core.conftest import DEFAULT_ENDPOINTS_FILE

domain = Domain.load("data/test_domains/default.yml")
Expand Down Expand Up @@ -139,23 +144,54 @@ def test_tracker_store_from_invalid_string(default_domain):
assert isinstance(tracker_store, InMemoryTrackerStore)


def test_tracker_serialisation():
slot_key = "location"
slot_val = "Easter Island"
def _tracker_store_and_tracker_with_slot_set() -> Tuple[
InMemoryTrackerStore, DialogueStateTracker
]:
# returns an InMemoryTrackerStore containing a tracker with a slot set

store = InMemoryTrackerStore(domain)
slot_key = "cuisine"
slot_val = "French"

store = InMemoryTrackerStore(domain)
tracker = store.get_or_create_tracker(UserMessage.DEFAULT_SENDER_ID)
ev = SlotSet(slot_key, slot_val)
tracker.update(ev)

return store, tracker


def test_tracker_serialisation():
store, tracker = _tracker_store_and_tracker_with_slot_set()
serialised = store.serialise_tracker(tracker)

assert tracker == store.deserialise_tracker(
UserMessage.DEFAULT_SENDER_ID, serialised
)


def test_deprecated_pickle_deserialisation(caplog: LogCaptureFixture):
    """Check that pickled trackers still load and log a deprecation warning."""
    import pickle

    store, tracker = _tracker_store_and_tracker_with_slot_set()

    # stand-in for the deprecated `TrackerStore.serialise_tracker()`, which
    # pickled the tracker's dialogue
    pickled_dialogue = pickle.dumps(tracker.as_dialogue())

    # deprecation warning should be emitted
    with assert_log_emitted(
        caplog, rasa.core.tracker_store.logger.name, logging.WARNING, "DEPRECATION"
    ):
        restored = store.deserialise_tracker(
            UserMessage.DEFAULT_SENDER_ID, pickled_dialogue
        )
        assert tracker == restored


@pytest.mark.parametrize(
"full_url",
[
Expand Down

0 comments on commit 39300da

Please sign in to comment.