Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change url to host in custom tracker store #4808

Merged
merged 15 commits into from
Dec 4, 2019
2 changes: 2 additions & 0 deletions changelog/4734.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Changed ``url`` ``__init__()`` arguments for custom tracker stores to ``host`` to reflect the ``__init__`` arguments of
currently supported tracker stores. Note that in ``endpoints.yml``, these are still declared as ``url``.
5 changes: 3 additions & 2 deletions docs/api/tracker-stores.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ SQLTrackerStore
:Parameters:
- ``domain`` (default: ``None``): Domain object associated with this tracker store
- ``dialect`` (default: ``sqlite``): The dialect used to communicate with your SQL backend. Consult the `SQLAlchemy docs <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_ for available dialects.
- ``host`` (default: ``None``): URL of your SQL server
- ``url`` (default: ``None``): URL of your SQL server
- ``port`` (default: ``None``): Port of your SQL server
- ``db`` (default: ``rasa.db``): The path to the database to be used
- ``username`` (default: ``None``): The username which is used for authentication
Expand Down Expand Up @@ -106,6 +106,7 @@ RedisTrackerStore
- ``password`` (default: ``None``): Password used for authentication
(``None`` equals no authentication)
- ``record_exp`` (default: ``None``): Record expiry in seconds
- ``use_ssl`` (default: ``False``): whether or not to use SSL for transit encryption

MongoTrackerStore
~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -142,9 +143,9 @@ MongoTrackerStore
- ``db`` (default: ``rasa``): The database name which should be used
- ``username`` (default: ``0``): The username which is used for authentication
- ``password`` (default: ``None``): The password which is used for authentication
- ``auth_source`` (default: ``admin``): database name associated with the user’s credentials.
- ``collection`` (default: ``conversations``): The collection name which is
used to store the conversations
- ``auth_source`` (default: ``admin``): database name associated with the user’s credentials.

Custom Tracker Store
~~~~~~~~~~~~~~~~~~~~
Expand Down
48 changes: 27 additions & 21 deletions rasa/core/tracker_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@
# noinspection PyPep8Naming
from time import sleep

from rasa.core import utils
from rasa.utils import common
from rasa.core.actions.action import ACTION_LISTEN_NAME
from rasa.core.brokers.event_channel import EventChannel
from rasa.core.conversation import Dialogue
from rasa.core.domain import Domain
from rasa.core.trackers import ActionExecuted, DialogueStateTracker, EventVerbosity
from rasa.core.utils import replace_floats_with_decimals
from rasa.utils.common import class_from_module_path
from rasa.utils.endpoints import EndpointConfig

if typing.TYPE_CHECKING:
from sqlalchemy.engine.url import URL
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session
import boto3
import boto3.resources.factory.dynamodb.Table

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,14 +60,14 @@ def find_tracker_store(
logger.error(
f"Error when trying to connect to '{store.type}' "
f"tracker store. Using "
f"'{InMemoryTrackerStore.__name__}'' instead. "
f"'{InMemoryTrackerStore.__name__}' instead. "
f"The causing error was: {e}."
)

if not tracker_store:
tracker_store = InMemoryTrackerStore(domain, event_broker)

logger.debug("Connected to {}.".format(tracker_store.__class__.__name__))
logger.debug(f"Connected to {tracker_store.__class__.__name__}.")

return tracker_store

Expand Down Expand Up @@ -118,16 +118,26 @@ def load_tracker_from_module_string(
"""
custom_tracker = None
try:
custom_tracker = class_from_module_path(store.type)
custom_tracker = common.class_from_module_path(store.type)
except (AttributeError, ImportError):
warnings.warn(
f"Store type '{store.type}' not found. "
"Using InMemoryTrackerStore instead"
f"Using InMemoryTrackerStore instead."
)

if custom_tracker:
init_args = common.arguments_of(custom_tracker.__init__)
if "url" in init_args and "host" not in init_args:
warnings.warn(
"The `url` initialization argument for custom tracker stores is deprecated. Your "
"custom tracker store should take a `host` argument in ``__init__()`` instead.",
FutureWarning,
)
store.kwargs["url"] = store.url
else:
store.kwargs["host"] = store.url
return custom_tracker(
domain=domain, url=store.url, event_broker=event_broker, **store.kwargs
domain=domain, event_broker=event_broker, **store.kwargs,
)
else:
return InMemoryTrackerStore(domain)
Expand Down Expand Up @@ -172,10 +182,10 @@ def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
def stream_events(self, tracker: DialogueStateTracker) -> None:
"""Streams events to a message broker"""
offset = self.number_of_existing_events(tracker.sender_id)
evts = tracker.events
for evt in list(itertools.islice(evts, offset, len(evts))):
events = tracker.events
for event in list(itertools.islice(events, offset, len(events))):
body = {"sender_id": tracker.sender_id}
body.update(evt.as_dict())
body.update(event.as_dict())
self.event_broker.publish(body)

def number_of_existing_events(self, sender_id: Text) -> int:
Expand Down Expand Up @@ -278,7 +288,6 @@ def __init__(
record_exp: Optional[float] = None,
use_ssl: bool = False,
):

import redis

self.red = redis.StrictRedis(
Expand Down Expand Up @@ -381,7 +390,7 @@ def serialise_tracker(self, tracker: "DialogueStateTracker") -> Dict:
"session_date": int(datetime.now(tz=timezone.utc).timestamp()),
}
)
return replace_floats_with_decimals(d)
return utils.replace_floats_with_decimals(d)

def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
"""Create a tracker from all previously stored events."""
Expand Down Expand Up @@ -540,9 +549,7 @@ def __init__(
engine_url = self.get_db_url(
dialect, host, port, db, username, password, login_db, query
)
logger.debug(
"Attempting to connect to database via '{}'.".format(repr(engine_url))
)
logger.debug(f"Attempting to connect to database via '{engine_url}'.")

# Database might take a while to come up
while True:
Expand Down Expand Up @@ -711,9 +718,9 @@ def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
)
else:
logger.debug(
"Can't retrieve tracker matching "
"sender id '{}' from SQL storage. "
"Returning `None` instead.".format(sender_id)
f"Can't retrieve tracker matching "
f"sender id '{sender_id}' from SQL storage. "
f"Returning `None` instead."
)
return None

Expand Down Expand Up @@ -748,8 +755,7 @@ def save(self, tracker: DialogueStateTracker) -> None:
session.commit()

logger.debug(
"Tracker with sender_id '{}' "
"stored to database".format(tracker.sender_id)
f"Tracker with sender_id '{tracker.sender_id}' " f"stored to database."
)

def _additional_events(
Expand Down
33 changes: 27 additions & 6 deletions tests/core/test_tracker_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_find_tracker_store(default_domain: Domain, monkeypatch: MonkeyPatch):
)


class ExampleTrackerStore(RedisTrackerStore):
class URLExampleTrackerStore(RedisTrackerStore):
def __init__(self, domain, url, port, db, password, record_exp, event_broker=None):
super().__init__(
domain,
Expand All @@ -135,21 +135,41 @@ def __init__(self, domain, url, port, db, password, record_exp, event_broker=Non
)


def test_tracker_store_from_string(default_domain: Domain):
class HostExampleTrackerStore(RedisTrackerStore):
pass


def test_tracker_store_deprecated_url_argument_from_string(default_domain: Domain):
endpoints_path = "data/test_endpoints/custom_tracker_endpoints.yml"
store_config = read_endpoint_config(endpoints_path, "tracker_store")
store_config.type = "tests.core.test_tracker_stores.URLExampleTrackerStore"

with pytest.warns(FutureWarning):
tracker_store = TrackerStore.find_tracker_store(default_domain, store_config)

assert isinstance(tracker_store, URLExampleTrackerStore)


def test_tracker_store_with_host_argument_from_string(default_domain: Domain):
endpoints_path = "data/test_endpoints/custom_tracker_endpoints.yml"
store_config = read_endpoint_config(endpoints_path, "tracker_store")
store_config.type = "tests.core.test_tracker_stores.HostExampleTrackerStore"

with pytest.warns(None) as record:
tracker_store = TrackerStore.find_tracker_store(default_domain, store_config)

tracker_store = TrackerStore.find_tracker_store(default_domain, store_config)
assert len(record) == 0

assert isinstance(tracker_store, ExampleTrackerStore)
assert isinstance(tracker_store, HostExampleTrackerStore)


def test_tracker_store_from_invalid_module(default_domain: Domain):
endpoints_path = "data/test_endpoints/custom_tracker_endpoints.yml"
store_config = read_endpoint_config(endpoints_path, "tracker_store")
store_config.type = "a.module.which.cannot.be.found"

tracker_store = TrackerStore.find_tracker_store(default_domain, store_config)
with pytest.warns(UserWarning):
tracker_store = TrackerStore.find_tracker_store(default_domain, store_config)

assert isinstance(tracker_store, InMemoryTrackerStore)

Expand All @@ -159,7 +179,8 @@ def test_tracker_store_from_invalid_string(default_domain: Domain):
store_config = read_endpoint_config(endpoints_path, "tracker_store")
store_config.type = "any string"

tracker_store = TrackerStore.find_tracker_store(default_domain, store_config)
with pytest.warns(UserWarning):
tracker_store = TrackerStore.find_tracker_store(default_domain, store_config)

assert isinstance(tracker_store, InMemoryTrackerStore)

Expand Down