diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 42240b78b262..0e406f2f6313 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -12,11 +12,13 @@ This project adheres to `Semantic Versioning`_ starting with version 1.0. Added ----- -- log a warning if the ``Interpreter`` picks up an intent or an entity that does not exist in the domain file. +- log a warning if the ``Interpreter`` picks up an intent or an entity that does not + exist in the domain file. - added ``DynamoTrackerStore`` to support persistence of agents running on AWS - added docstrings for ``TrackerStore`` classes - added buttons and images to mattermost. -- `CRFEntityExtractor` updated to accept arbitrary token-level features like word vectors (issues/4214) +- `CRFEntityExtractor` updated to accept arbitrary token-level features like word + vectors (issues/4214) - `SpacyFeaturizer` updated to add `ner_features` for `CRFEntityExtractor` - Sanitizing incoming messages from slack to remove slack formatting like or and substitute it with original content @@ -33,6 +35,10 @@ Changed You can do so by overwriting the method ``get_metadata``. The return value of this method will be passed to the ``UserMessage`` object. - Tests can now be run in parallel +- Serialise ``DialogueStateTracker`` as json instead of pickle. **DEPRECATION warning**: + Deserialisation of pickled trackers will be deprecated in version 2.0. For now, + trackers are still loaded from pickle but will be dumped as json in any subsequent + save operations. Removed ------- diff --git a/rasa/core/conversation.py b/rasa/core/conversation.py index 072dad833ca8..8c8742243efa 100644 --- a/rasa/core/conversation.py +++ b/rasa/core/conversation.py @@ -1,15 +1,14 @@ -import typing -from typing import Dict, List, Text +from typing import Dict, List, Text, Any -if typing.TYPE_CHECKING: - from rasa.core.events import Event +from rasa.core.events import Event class Dialogue(object): """A dialogue comprises a list of Turn objects""" def __init__(self, name: Text, events: List["Event"]) -> None: - """This function initialises the dialogue with the dialogue name and the event list.""" + """This function initialises the dialogue with the dialogue name and the event + list.""" self.name = name self.events = events @@ -20,5 +19,23 @@ def __str__(self) -> Text: ) def as_dict(self) -> Dict: - """This function returns the dialogue as a dictionary to assist in serialization""" + """This function returns the dialogue as a dictionary to assist in + serialization.""" return {"events": [event.as_dict() for event in self.events], "name": self.name} + + @classmethod + def from_parameters(cls, parameters: Dict[Text, Any]) -> "Dialogue": + """Create `Dialogue` from parameters. + + Args: + parameters: Serialised dialogue, should contain keys 'name' and 'events'. + + Returns: + Deserialised `Dialogue`. + + """ + + return cls( + parameters.get("name"), + [Event.from_parameters(evt) for evt in parameters.get("events")], + ) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index d591af269956..ea65db072d15 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -6,16 +6,16 @@ import typing from datetime import datetime, timezone from typing import Iterator, Optional, Text, Iterable, Union, Dict + import itertools +from boto3.dynamodb.conditions import Key # noinspection PyPep8Naming from time import sleep -import boto3 -from boto3.dynamodb.conditions import Key - from rasa.core.actions.action import ACTION_LISTEN_NAME from rasa.core.brokers.event_channel import EventChannel +from rasa.core.conversation import Dialogue from rasa.core.domain import Domain from rasa.core.trackers import ActionExecuted, DialogueStateTracker, EventVerbosity from rasa.core.utils import replace_floats_with_decimals @@ -26,6 +26,7 @@ from sqlalchemy.engine.url import URL from sqlalchemy.engine.base import Engine from sqlalchemy.orm import Session + import boto3 logger = logging.getLogger(__name__) @@ -156,21 +157,45 @@ def keys(self) -> Iterable[Text]: raise NotImplementedError() @staticmethod - def serialise_tracker(tracker): - """Serializes the tracker, returns representation of the tracker""" + def serialise_tracker(tracker: DialogueStateTracker) -> Text: + """Serializes the tracker, returns representation of the tracker.""" dialogue = tracker.as_dialogue() - return pickle.dumps(dialogue) - def deserialise_tracker(self, sender_id, _json) -> Optional[DialogueStateTracker]: - """Deserializes the tracker and returns it""" - dialogue = pickle.loads(_json) + return json.dumps(dialogue.as_dict()) + + @staticmethod + def _deserialise_dialogue_from_pickle( + sender_id: Text, serialised_tracker: bytes + ) -> Dialogue: + + logger.warning( + f"DEPRECATION warning: Found pickled tracker for " + f"conversation ID '{sender_id}'. Deserialisation of pickled " + f"trackers will be deprecated in version 2.0. Rasa will perform any " + f"future save operations of this tracker using json serialisation." + ) + return pickle.loads(serialised_tracker) + + def deserialise_tracker( + self, sender_id: Text, serialised_tracker: Union[Text, bytes] + ) -> Optional[DialogueStateTracker]: + """Deserializes the tracker and returns it.""" + tracker = self.init_tracker(sender_id) - if tracker: - tracker.recreate_from_dialogue(dialogue) - return tracker - else: + if not tracker: return None + try: + dialogue = Dialogue.from_parameters(json.loads(serialised_tracker)) + except UnicodeDecodeError: + dialogue = self._deserialise_dialogue_from_pickle( + sender_id, serialised_tracker + ) + + tracker.recreate_from_dialogue(dialogue) + + return tracker + class InMemoryTrackerStore(TrackerStore): """Stores conversation history in memory""" @@ -274,6 +299,8 @@ def __init__( table_name: The name of the DynamoDb table, does not need to be present a priori. event_broker: """ + import boto3 + self.client = boto3.client("dynamodb", region_name=region) self.region = region self.table_name = table_name @@ -284,6 +311,8 @@ def get_or_create_table( self, table_name: Text ) -> "boto3.resources.factory.dynamodb.Table": """Returns table or creates one if the table name is not in the table list""" + import boto3 + dynamo = boto3.resource("dynamodb", region_name=self.region) if self.table_name not in self.client.list_tables()["TableNames"]: table = dynamo.create_table( @@ -323,7 +352,8 @@ def serialise_tracker(self, tracker: "DialogueStateTracker") -> Dict: def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]: """Create a tracker from all previously stored events.""" - # Retrieve dialogues for a sender_id in reverse chronological order based on the session_date sort key + # Retrieve dialogues for a sender_id in reverse chronological order based on + # the session_date sort key dialogues = self.db.query( KeyConditionExpression=Key("sender_id").eq(sender_id), Limit=1, diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index 8aa535b5c4bf..72d09d2d819b 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -10,6 +10,7 @@ parse_last_positional_argument_as_model_path, get_validated_path, ) +from tests.conftest import assert_log_emitted @pytest.mark.parametrize( @@ -84,3 +85,25 @@ def test_validate_with_invalid_directory_if_default_is_valid(caplog: LogCaptureF def test_print_error_and_exit(): with pytest.raises(SystemExit): rasa.cli.utils.print_error_and_exit("") + + +def test_logging_capture(caplog: LogCaptureFixture): + logger = logging.getLogger(__name__) + + # make a random INFO log and ensure it passes decorator + info_text = "SOME INFO" + logger.info(info_text) + with assert_log_emitted(caplog, logger.name, logging.INFO, info_text): + pass + + +def test_logging_capture_failure(caplog: LogCaptureFixture): + logger = logging.getLogger(__name__) + + # make a random INFO log + logger.info("SOME INFO") + + # test for string in log that wasn't emitted + with pytest.raises(AssertionError): + with assert_log_emitted(caplog, logger.name, logging.INFO, "NONONO"): + pass diff --git a/tests/conftest.py b/tests/conftest.py index 70a50783a0d6..8a1bb709b2ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import logging +from contextlib import contextmanager from typing import Text, List import pytest @@ -218,5 +219,51 @@ async def rasa_server_without_api() -> Sanic: def get_test_client(server): test_client = server.test_client test_client.port = None - return test_client + + +@contextmanager +def assert_log_emitted( + _caplog: LogCaptureFixture, logger_name: Text, log_level: int, text: Text = None +) -> None: + """Context manager testing whether a logging message has been emitted. + + Provides a context in which an assertion is made about a logging message. + Raises an `AssertionError` if the log isn't emitted as expected. + + Example usage: + + ``` + with assert_log_emitted(caplog, LOGGER_NAME, LOGGING_LEVEL, TEXT): + + ``` + + Args: + _caplog: `LogCaptureFixture` used to capture logs. + logger_name: Name of the logger being examined. + log_level: Log level to be tested. + text: Logging message to be tested (optional). If left blank, assertion is made + only about `log_level` and `logger_name`. + + Yields: + `None` + + """ + + yield + + record_tuples = _caplog.record_tuples + + if not any( + ( + record[0] == logger_name + and record[1] == log_level + and (text in record[2] if text else True) + ) + for record in record_tuples + ): + raise AssertionError( + f"Did not detect expected logging output.\nExpected output is (logger " + f"name, log level, text): ({logger_name}, {log_level}, {text})\n" + f"Instead found records:\n{record_tuples}" + ) diff --git a/tests/core/test_dialogues.py b/tests/core/test_dialogues.py index bc8ba5b07f81..79e46e341981 100644 --- a/tests/core/test_dialogues.py +++ b/tests/core/test_dialogues.py @@ -4,10 +4,11 @@ import pytest import rasa.utils.io +from rasa.core.conversation import Dialogue from rasa.core.domain import Domain from rasa.core.tracker_store import InMemoryTrackerStore -from tests.core.utilities import tracker_from_dialogue_file from tests.core.conftest import TEST_DIALOGUES, EXAMPLE_DOMAINS +from tests.core.utilities import tracker_from_dialogue_file @pytest.mark.parametrize("filename", TEST_DIALOGUES) @@ -36,3 +37,12 @@ def test_tracker_restaurant(): tracker = tracker_from_dialogue_file(filename, domain) assert tracker.get_slot("price") == "lo" assert tracker.get_slot("name") is None # slot doesn't exist! + + +def test_dialogue_from_parameters(): + domain = Domain.load("examples/restaurantbot/domain.yml") + filename = "data/test_dialogues/restaurantbot.json" + tracker = tracker_from_dialogue_file(filename, domain) + serialised_dialogue = InMemoryTrackerStore.serialise_tracker(tracker) + deserialised_dialogue = Dialogue.from_parameters(json.loads(serialised_dialogue)) + assert tracker.as_dialogue().as_dict() == deserialised_dialogue.as_dict() diff --git a/tests/core/test_tracker_stores.py b/tests/core/test_tracker_stores.py index b14ccca0415c..a4829f0ad18e 100644 --- a/tests/core/test_tracker_stores.py +++ b/tests/core/test_tracker_stores.py @@ -1,6 +1,9 @@ +import logging import tempfile +from typing import Tuple import pytest +from _pytest.logging import LogCaptureFixture from moto import mock_dynamodb2 from rasa.core.channels.channel import UserMessage @@ -13,8 +16,10 @@ SQLTrackerStore, DynamoTrackerStore, ) - +import rasa.core.tracker_store +from rasa.core.trackers import DialogueStateTracker from rasa.utils.endpoints import EndpointConfig, read_endpoint_config +from tests.conftest import assert_log_emitted from tests.core.conftest import DEFAULT_ENDPOINTS_FILE domain = Domain.load("data/test_domains/default.yml") @@ -139,16 +144,24 @@ def test_tracker_store_from_invalid_string(default_domain): assert isinstance(tracker_store, InMemoryTrackerStore) -def test_tracker_serialisation(): - slot_key = "location" - slot_val = "Easter Island" +def _tracker_store_and_tracker_with_slot_set() -> Tuple[ + InMemoryTrackerStore, DialogueStateTracker +]: + # returns an InMemoryTrackerStore containing a tracker with a slot set - store = InMemoryTrackerStore(domain) + slot_key = "cuisine" + slot_val = "French" + store = InMemoryTrackerStore(domain) tracker = store.get_or_create_tracker(UserMessage.DEFAULT_SENDER_ID) ev = SlotSet(slot_key, slot_val) tracker.update(ev) + return store, tracker + + +def test_tracker_serialisation(): + store, tracker = _tracker_store_and_tracker_with_slot_set() serialised = store.serialise_tracker(tracker) assert tracker == store.deserialise_tracker( @@ -156,6 +169,29 @@ def test_tracker_serialisation(): ) +def test_deprecated_pickle_deserialisation(caplog: LogCaptureFixture): + def pickle_serialise_tracker(_tracker): + # mocked version of TrackerStore.serialise_tracker() that uses + # the deprecated pickle serialisation + import pickle + + dialogue = _tracker.as_dialogue() + + return pickle.dumps(dialogue) + + store, tracker = _tracker_store_and_tracker_with_slot_set() + + serialised = pickle_serialise_tracker(tracker) + + # deprecation warning should be emitted + with assert_log_emitted( + caplog, rasa.core.tracker_store.logger.name, logging.WARNING, "DEPRECATION" + ): + assert tracker == store.deserialise_tracker( + UserMessage.DEFAULT_SENDER_ID, serialised + ) + + @pytest.mark.parametrize( "full_url", [