diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0b5eb738549d..8699690e27f0 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,9 @@ Added Changed ------- +- Print info message when running Rasa X and a custom model server url was specified in ``endpoints.yml`` +- If a ``wait_time_between_pulls`` is configured for the model server in ``endpoints.yml``, + this will be used instead of the default one when running Rasa X Removed ------- diff --git a/rasa/cli/x.py b/rasa/cli/x.py index d2db80be0db3..dc2d05a97da0 100644 --- a/rasa/cli/x.py +++ b/rasa/cli/x.py @@ -2,6 +2,7 @@ import asyncio import importlib.util import logging +import warnings import os import signal import traceback @@ -112,12 +113,27 @@ def _overwrite_endpoints_for_local_x( from rasa.utils.endpoints import EndpointConfig import questionary + # Checking if endpoint.yml has existing url and wait time values set, if so give + # warning we are overwriting the endpoint.yml file. + custom_wait_time_pulls = endpoints.model.kwargs.get("wait_time_between_pulls") + custom_url = endpoints.model.url + default_rasax_model_server_url = ( + f"{rasa_x_url}/projects/default/models/tag/production" + ) + + if custom_url != default_rasax_model_server_url: + warnings.warn( + f"Ignoring url '{custom_url}' from 'endpoints.yml' and using " + f"'{default_rasax_model_server_url}' instead." + ) + endpoints.model = EndpointConfig( - f"{rasa_x_url}/projects/default/models/tags/production", + default_rasax_model_server_url, token=rasa_x_token, - wait_time_between_pulls=2, + wait_time_between_pulls=custom_wait_time_pulls or 2, ) + overwrite_existing_event_broker = False if endpoints.event_broker and not _is_correct_event_broker(endpoints.event_broker): cli_utils.print_error( "Rasa X currently only supports a SQLite event broker with path '{}' " @@ -132,9 +148,8 @@ def _overwrite_endpoints_for_local_x( if not overwrite_existing_event_broker: exit(0) - endpoints.event_broker = EndpointConfig( - type="sql", db=DEFAULT_EVENTS_DB, dialect="sqlite" - ) + if not endpoints.tracker_store or overwrite_existing_event_broker: + endpoints.event_broker = EndpointConfig(type="sql", db=DEFAULT_EVENTS_DB) def _is_correct_event_broker(event_broker: EndpointConfig) -> bool: diff --git a/tests/cli/test_rasa_x.py b/tests/cli/test_rasa_x.py index 7e5af395e69b..6e04f36ffee4 100644 --- a/tests/cli/test_rasa_x.py +++ b/tests/cli/test_rasa_x.py @@ -1,18 +1,18 @@ from pathlib import Path -from unittest.mock import Mock +import warnings -from typing import Callable, Dict, Text, Any import pytest +from typing import Callable, Dict from _pytest.pytester import RunResult -from _pytest.monkeypatch import MonkeyPatch -import questionary +from _pytest.logging import LogCaptureFixture + from aioresponses import aioresponses import rasa.utils.io as io_utils from rasa.cli import x -from rasa.core.utils import AvailableEndpoints from rasa.utils.endpoints import EndpointConfig +from rasa.core.utils import AvailableEndpoints def test_x_help(run: Callable[..., RunResult]): @@ -65,33 +65,6 @@ def test_prepare_credentials_if_already_valid(tmpdir: Path): assert actual == credentials -@pytest.mark.parametrize( - "event_broker", - [ - # Event broker was not configured. - {}, - # Event broker was explicitly configured to work with Rasa X in local mode. - {"type": "sql", "dialect": "sqlite", "db": x.DEFAULT_EVENTS_DB}, - # Event broker was configured but the values are not compatible for running Rasa - # X in local mode. - {"type": "sql", "dialect": "postgresql"}, - ], -) -def test_overwrite_endpoints_for_local_x( - event_broker: Dict[Text, Any], monkeypatch: MonkeyPatch -): - confirm = Mock() - confirm.return_value.ask.return_value = True - monkeypatch.setattr(questionary, "confirm", confirm) - - event_broker_config = EndpointConfig.from_dict(event_broker) - endpoints = AvailableEndpoints(event_broker=event_broker_config) - - x._overwrite_endpoints_for_local_x(endpoints, "test-token", "http://localhost:5002") - - assert x._is_correct_event_broker(endpoints.event_broker) - - def test_if_endpoint_config_is_valid_in_local_mode(): config = EndpointConfig(type="sql", dialect="sqlite", db=x.DEFAULT_EVENTS_DB) @@ -111,6 +84,42 @@ def test_if_endpoint_config_is_invalid_in_local_mode(kwargs: Dict): assert not x._is_correct_event_broker(config) +def test_overwrite_model_server_url(): + endpoint_config = EndpointConfig(url="http://testserver:5002/models/default@latest") + endpoints = AvailableEndpoints(model=endpoint_config) + with pytest.warns(UserWarning): + x._overwrite_endpoints_for_local_x(endpoints, "test", "http://localhost") + assert ( + endpoints.model.url == "http://localhost/projects/default/models/tag/production" + ) + + +def test_reuse_wait_time_between_pulls(): + test_wait_time = 5 + endpoint_config = EndpointConfig( + url="http://localhost:5002/models/default@latest", + wait_time_between_pulls=test_wait_time, + ) + endpoints = AvailableEndpoints(model=endpoint_config) + assert endpoints.model.kwargs["wait_time_between_pulls"] == test_wait_time + + +def test_default_wait_time_between_pulls(): + endpoint_config = EndpointConfig(url="http://localhost:5002/models/default@latest") + endpoints = AvailableEndpoints(model=endpoint_config) + x._overwrite_endpoints_for_local_x(endpoints, "test", "http://localhost") + assert endpoints.model.kwargs["wait_time_between_pulls"] == 2 + + +def test_default_model_server_url(): + endpoint_config = EndpointConfig() + endpoints = AvailableEndpoints(model=endpoint_config) + x._overwrite_endpoints_for_local_x(endpoints, "test", "http://localhost") + assert ( + endpoints.model.url == "http://localhost/projects/default/models/tag/production" + ) + + async def test_pull_runtime_config_from_server(): config_url = "http://example.com/api/config?token=token" credentials = "rasa: http://example.com:5002/api"