diff --git a/changelog/7292.improvement.md b/changelog/7292.improvement.md new file mode 100644 index 000000000000..b789979b8637 --- /dev/null +++ b/changelog/7292.improvement.md @@ -0,0 +1,12 @@ +Improve error handling and Sentry tracking: +- Raise `MarkdownException` when training data in Markdown format cannot be read. +- Raise `InvalidEntityFormatException` error instead of `json.JSONDecodeError` when entity format is in valid + in training data. +- Gracefully handle empty sections in endpoint config files. +- Introduce `ConnectionException` error and raise it when `TrackerStore` and `EventBroker` + cannot connect to 3rd party services, instead of raising exceptions from 3rd party libraries. +- Improve `rasa.shared.utils.common.class_from_module_path` function by making sure it always returns a class. + The function currently raises a deprecation warning if it detects an anomaly. +- Ignore `MemoryError` and `asyncio.CancelledError` in Sentry. +- `rasa.shared.utils.validation.validate_training_data` now raises a `SchemaValidationError` when validation fails + (this error inherits `jsonschema.ValidationError`, ensuring backwards compatibility). diff --git a/data/test_endpoints/example_endpoints.yml b/data/test_endpoints/example_endpoints.yml index 436cd1a15344..5ca07a8dcab5 100644 --- a/data/test_endpoints/example_endpoints.yml +++ b/data/test_endpoints/example_endpoints.yml @@ -23,3 +23,4 @@ tracker_store: #db: rasa #user: username #password: password +empty: diff --git a/rasa/core/brokers/broker.py b/rasa/core/brokers/broker.py index 01d5195e3a35..7a6dddccab43 100644 --- a/rasa/core/brokers/broker.py +++ b/rasa/core/brokers/broker.py @@ -5,6 +5,7 @@ import rasa.shared.utils.common import rasa.shared.utils.io +from rasa.shared.exceptions import ConnectionException from rasa.utils.endpoints import EndpointConfig logger = logging.getLogger(__name__) @@ -22,7 +23,16 @@ async def create( if isinstance(obj, EventBroker): return obj - return await _create_from_endpoint_config(obj, loop) + import aio_pika.exceptions + import sqlalchemy.exc + + try: + return await _create_from_endpoint_config(obj, loop) + except ( + sqlalchemy.exc.OperationalError, + aio_pika.exceptions.AMQPConnectionError, + ) as error: + raise ConnectionException("Cannot connect to event broker.") from error @classmethod async def from_endpoint_config( diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 879e2fa424ed..60ef27286ee2 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -21,7 +21,6 @@ ) from boto3.dynamodb.conditions import Key -from botocore.exceptions import ClientError import rasa.core.utils as core_utils import rasa.shared.utils.cli @@ -42,6 +41,7 @@ DialogueStateTracker, EventVerbosity, ) +from rasa.shared.exceptions import ConnectionException from rasa.shared.nlu.constants import INTENT_NAME_KEY from rasa.utils.endpoints import EndpointConfig import sqlalchemy as sa @@ -65,7 +65,7 @@ class TrackerStore: - """Class to hold all of the TrackerStore classes""" + """Represents common behavior and interface for all `TrackerStore`s.""" def __init__( self, @@ -115,7 +115,18 @@ def create( if isinstance(obj, TrackerStore): return obj - return _create_from_endpoint_config(obj, domain, event_broker) + from botocore.exceptions import BotoCoreError + import pymongo.errors + import sqlalchemy.exc + + try: + return _create_from_endpoint_config(obj, domain, event_broker) + except ( + BotoCoreError, + pymongo.errors.ConnectionFailure, + sqlalchemy.exc.OperationalError, + ) as error: + raise ConnectionException("Cannot connect to tracker store.") from error def get_or_create_tracker( self, @@ -407,7 +418,7 @@ def __init__( def get_or_create_table( self, table_name: Text ) -> "boto3.resources.factory.dynamodb.Table": - """Returns table or creates one if the table name is not in the table list""" + """Returns table or creates one if the table name is not in the table list.""" import boto3 dynamo = boto3.resource("dynamodb", region_name=self.region) @@ -442,7 +453,9 @@ def get_or_create_table( return table def save(self, tracker): - """Saves the current conversation state""" + """Saves the current conversation state.""" + from botocore.exceptions import ClientError + if self.event_broker: self.stream_events(tracker) serialized = self.serialise_tracker(tracker) @@ -475,7 +488,7 @@ def _retrieve_latest_session_date(self, sender_id: Text) -> Optional[int]: return dialogues[0].get("session_date") def serialise_tracker(self, tracker: "DialogueStateTracker") -> Dict: - """Serializes the tracker, returns object with decimal types""" + """Serializes the tracker, returns object with decimal types.""" d = tracker.as_dialogue().as_dict() d.update( {"sender_id": tracker.sender_id,} @@ -505,7 +518,7 @@ def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]: ) def keys(self) -> Iterable[Text]: - """Returns sender_ids of the DynamoTrackerStore""" + """Returns sender_ids of the `DynamoTrackerStore`.""" return [ i["sender_id"] for i in self.db.scan(ProjectionExpression="sender_id")["Items"] @@ -513,8 +526,7 @@ def keys(self) -> Iterable[Text]: class MongoTrackerStore(TrackerStore): - """ - Stores conversation history in Mongo + """Stores conversation history in Mongo. Property methods: conversations: returns the current conversation @@ -552,11 +564,11 @@ def __init__( @property def conversations(self): - """Returns the current conversation""" + """Returns the current conversation.""" return self.db[self.collection] def _ensure_indices(self): - """Create an index on the sender_id""" + """Create an index on the sender_id.""" self.conversations.create_index("sender_id") @staticmethod @@ -569,7 +581,7 @@ def _current_tracker_state_without_events(tracker: DialogueStateTracker) -> Dict return state def save(self, tracker, timeout=None): - """Saves the current conversation state""" + """Saves the current conversation state.""" if self.event_broker: self.stream_events(tracker) @@ -684,7 +696,7 @@ def retrieve_full_tracker( ) def keys(self) -> Iterable[Text]: - """Returns sender_ids of the Mongo Tracker Store""" + """Returns sender_ids of the Mongo Tracker Store.""" return [c["sender_id"] for c in self.conversations.find()] diff --git a/rasa/shared/exceptions.py b/rasa/shared/exceptions.py index 8a2eb43e3c48..f01164a5c77a 100644 --- a/rasa/shared/exceptions.py +++ b/rasa/shared/exceptions.py @@ -1,8 +1,15 @@ +import json from typing import Optional, Text +import jsonschema + class RasaException(Exception): - """Base exception class for all errors raised by Rasa Open Source.""" + """Base exception class for all errors raised by Rasa Open Source. + + These exceptions results from invalid use cases and will be reported + to the users, but will be ignored in telemetry. + """ class RasaCoreException(RasaException): @@ -17,6 +24,10 @@ class InvalidParameterException(RasaException, ValueError): """Raised when an invalid parameter is used.""" +class MarkdownException(RasaException, ValueError): + """Raised if there is an error reading Markdown.""" + + class YamlException(RasaException): """Raised if there is an error reading yaml.""" @@ -77,3 +88,26 @@ class InvalidConfigException(ValueError, RasaException): class UnsupportedFeatureException(RasaCoreException): """Raised if a requested feature is not supported.""" + + +class SchemaValidationError(RasaException, jsonschema.ValidationError): + """Raised if schema validation via `jsonschema` failed.""" + + +class InvalidEntityFormatException(RasaException, json.JSONDecodeError): + """Raised if the format of an entity is invalid.""" + + @classmethod + def create_from( + cls, other: json.JSONDecodeError, msg: Text + ) -> "InvalidEntityFormatException": + """Create an instance of `InvalidEntityFormatException` from a `JSONDecodeError`.""" + return cls(msg, other.doc, other.pos) + + +class ConnectionException(RasaException): + """Raised when a connection to a 3rd party service fails. + + It's used by our broker and tracker store classes, when + they can't connect to services like postgres, dynamoDB, mongo. + """ diff --git a/rasa/shared/nlu/training_data/entities_parser.py b/rasa/shared/nlu/training_data/entities_parser.py index e2916cee54ee..fcff405234ae 100644 --- a/rasa/shared/nlu/training_data/entities_parser.py +++ b/rasa/shared/nlu/training_data/entities_parser.py @@ -4,6 +4,7 @@ import rasa.shared.nlu.training_data.util from rasa.shared.constants import DOCS_URL_TRAINING_DATA_NLU +from rasa.shared.exceptions import InvalidEntityFormatException from rasa.shared.nlu.constants import ( ENTITY_ATTRIBUTE_VALUE, ENTITY_ATTRIBUTE_TYPE, @@ -11,7 +12,7 @@ ENTITY_ATTRIBUTE_ROLE, ) from rasa.shared.nlu.training_data.message import Message -import rasa.shared.utils.io + GROUP_ENTITY_VALUE = "value" GROUP_ENTITY_TYPE = "entity" @@ -127,8 +128,8 @@ def get_validated_dict(json_str: Text) -> Dict[Text, Text]: json_str: The entity dict as string without "{}". Raises: - ValidationError if validation of entity dict fails. - JSONDecodeError if provided entity dict is not valid json. + SchemaValidationError if validation of parsed entity fails. + InvalidEntityFormatException if provided entity is not valid json. Returns: Deserialized and validated `json_str`. @@ -141,12 +142,11 @@ def get_validated_dict(json_str: Text) -> Dict[Text, Text]: try: data = json.loads(f"{{{json_str}}}") except JSONDecodeError as e: - rasa.shared.utils.io.raise_warning( - f"Incorrect training data format ('{{{json_str}}}'). Make sure your " - f"data is valid.", - docs=DOCS_URL_TRAINING_DATA_NLU, - ) - raise e + raise InvalidEntityFormatException.create_from( + e, + f"Incorrect training data format ('{{{json_str}}}'). " + f"More info at {DOCS_URL_TRAINING_DATA_NLU}", + ) from e validation_utils.validate_training_data(data, schema.entity_dict_schema()) diff --git a/rasa/shared/nlu/training_data/formats/markdown.py b/rasa/shared/nlu/training_data/formats/markdown.py index 195202586cf7..3b95259a91f6 100644 --- a/rasa/shared/nlu/training_data/formats/markdown.py +++ b/rasa/shared/nlu/training_data/formats/markdown.py @@ -9,6 +9,7 @@ LEGACY_DOCS_BASE_URL, DOCS_URL_MIGRATION_GUIDE_MD_DEPRECATION, ) +from rasa.shared.exceptions import MarkdownException from rasa.shared.nlu.constants import TEXT from rasa.shared.nlu.training_data.formats.readerwriter import ( TrainingDataReader, @@ -135,47 +136,10 @@ def _parse_item(self, line: Text) -> None: self.current_title, item, self.lookup_tables ) - @staticmethod - def _get_validated_dict(json_str: Text) -> Dict[Text, Text]: - """Converts the provided json_str to a valid dict containing the entity - attributes. - - Users can specify entity roles, synonyms, groups for an entity in a dict, e.g. - [LA]{"entity": "city", "role": "to", "value": "Los Angeles"} - - Args: - json_str: the entity dict as string without "{}" - - Raises: - ValidationError if validation of entity dict fails. - JSONDecodeError if provided entity dict is not valid json. - - Returns: - a proper python dict - """ - import json - import rasa.shared.utils.validation as validation_utils - import rasa.shared.nlu.training_data.schemas.data_schema as schema - - # add {} as they are not part of the regex - try: - data = json.loads(f"{{{json_str}}}") - except JSONDecodeError as e: - rasa.shared.utils.io.raise_warning( - f"Incorrect training data format ('{{{json_str}}}'), make sure your " - f"data is valid. For more information about the format visit " - f"{LEGACY_DOCS_BASE_URL}/nlu/training-data-format/." - ) - raise e - - validation_utils.validate_training_data(data, schema.entity_dict_schema()) - - return data - def _set_current_section(self, section: Text, title: Text) -> None: """Update parsing mode.""" if section not in AVAILABLE_SECTIONS: - raise ValueError( + raise MarkdownException( "Found markdown section '{}' which is not " "in the allowed sections '{}'." "".format(section, "', '".join(AVAILABLE_SECTIONS)) diff --git a/rasa/shared/utils/common.py b/rasa/shared/utils/common.py index 21ae589d3053..f451811de2b9 100644 --- a/rasa/shared/utils/common.py +++ b/rasa/shared/utils/common.py @@ -5,6 +5,10 @@ import logging from typing import Text, Dict, Optional, Any, List, Callable, Collection +import rasa.shared.utils.io +from rasa.shared.constants import NEXT_MAJOR_VERSION_FOR_DEPRECATIONS + + logger = logging.getLogger(__name__) @@ -13,24 +17,42 @@ def class_from_module_path( ) -> Any: """Given the module name and path of a class, tries to retrieve the class. - The loaded class can be used to instantiate new objects.""" - # load the module, will raise ImportError if module cannot be loaded + The loaded class can be used to instantiate new objects. + + Args: + module_path: either an absolute path to a Python class, + or the name of the class in the local / global scope. + lookup_path: a path where to load the class from, if it cannot + be found in the local / global scope. + + Returns: + a Python class + + Raises: + ImportError, in case the Python class cannot be found. + """ + klass = None if "." in module_path: module_name, _, class_name = module_path.rpartition(".") m = importlib.import_module(module_name) - # get the class, will raise AttributeError if class cannot be found - return getattr(m, class_name) - else: - module = globals().get(module_path, locals().get(module_path)) - if module is not None: - return module - - if lookup_path: - # last resort: try to import the class from the lookup path - m = importlib.import_module(lookup_path) - return getattr(m, module_path) - else: - raise ImportError(f"Cannot retrieve class from path {module_path}.") + klass = getattr(m, class_name, None) + elif lookup_path: + # try to import the class from the lookup path + m = importlib.import_module(lookup_path) + klass = getattr(m, module_path, None) + + if klass is None: + raise ImportError(f"Cannot retrieve class from path {module_path}.") + + if not inspect.isclass(klass): + rasa.shared.utils.io.raise_deprecation_warning( + f"`class_from_module_path()` is expected to return a class, " + f"but {module_path} is not one. " + f"This warning will be converted " + f"into an exception in {NEXT_MAJOR_VERSION_FOR_DEPRECATIONS}." + ) + + return klass def all_subclasses(cls: Any) -> List[Any]: diff --git a/rasa/shared/utils/validation.py b/rasa/shared/utils/validation.py index 2c35ae5e703e..b5e52b4fbb9d 100644 --- a/rasa/shared/utils/validation.py +++ b/rasa/shared/utils/validation.py @@ -9,7 +9,11 @@ from ruamel.yaml.constructor import DuplicateKeyError import rasa.shared -from rasa.shared.exceptions import YamlException, YamlSyntaxException +from rasa.shared.exceptions import ( + YamlException, + YamlSyntaxException, + SchemaValidationError, +) import rasa.shared.utils.io from rasa.shared.constants import ( DOCS_URL_TRAINING_DATA, @@ -176,7 +180,7 @@ def validate_training_data(json_data: Dict[Text, Any], schema: Dict[Text, Any]) schema: the schema Raises: - ValidationError if validation fails. + SchemaValidationError if validation fails. """ from jsonschema import validate from jsonschema import ValidationError @@ -189,7 +193,7 @@ def validate_training_data(json_data: Dict[Text, Any], schema: Dict[Text, Any]) f"is valid. For more information about the format visit " f"{DOCS_URL_TRAINING_DATA}." ) - raise e + raise SchemaValidationError.create_from(e) from e def validate_training_data_format_version( diff --git a/rasa/telemetry.py b/rasa/telemetry.py index 1983c6c957df..06bcba98488a 100644 --- a/rasa/telemetry.py +++ b/rasa/telemetry.py @@ -631,7 +631,15 @@ def initialize_error_reporting() -> None: ], send_default_pii=False, # activate PII filter server_name=telemetry_id or "UNKNOWN", - ignore_errors=[KeyboardInterrupt, RasaException, NotImplementedError], + ignore_errors=[ + # std lib errors + KeyboardInterrupt, # user hit the interrupt key (Ctrl+C) + MemoryError, # machine is running out of memory + NotImplementedError, # user is using a feature that is not implemented + asyncio.CancelledError, # an async operation has been cancelled by the user + # expected Rasa errors + RasaException, + ], in_app_include=["rasa"], # only submit errors in this package with_locals=False, # don't submit local variables release=f"rasa-{rasa.__version__}", diff --git a/rasa/utils/endpoints.py b/rasa/utils/endpoints.py index e86e351ec170..5c6b049c49fd 100644 --- a/rasa/utils/endpoints.py +++ b/rasa/utils/endpoints.py @@ -25,10 +25,10 @@ def read_endpoint_config( try: content = rasa.shared.utils.io.read_config_file(filename) - if endpoint_type in content: - return EndpointConfig.from_dict(content[endpoint_type]) - else: + if content.get(endpoint_type) is None: return None + + return EndpointConfig.from_dict(content[endpoint_type]) except FileNotFoundError: logger.error( "Failed to read endpoint configuration " diff --git a/tests/core/test_broker.py b/tests/core/test_broker.py index 5996386195f2..5bab4cbebb71 100644 --- a/tests/core/test_broker.py +++ b/tests/core/test_broker.py @@ -2,12 +2,12 @@ import logging import textwrap from asyncio.events import AbstractEventLoop +from pathlib import Path +from typing import Union, Text, List, Optional, Type, Dict, Any +import aio_pika.exceptions import kafka import pytest - -from pathlib import Path -from typing import Union, Text, List, Optional, Type, Dict, Any from _pytest.logging import LogCaptureFixture from _pytest.monkeypatch import MonkeyPatch @@ -22,6 +22,7 @@ from rasa.core.brokers.pika import PikaEventBroker, DEFAULT_QUEUE_NAME from rasa.core.brokers.sql import SQLEventBroker from rasa.shared.core.events import Event, Restarted, SlotSet, UserUttered +from rasa.shared.exceptions import ConnectionException from rasa.utils.endpoints import EndpointConfig, read_endpoint_config TEST_EVENTS = [ @@ -190,14 +191,15 @@ async def test_file_broker_properly_logs_newlines(tmp_path: Path): assert recovered == [event_with_newline] -def test_load_custom_broker_name(tmp_path: Path): +async def test_load_custom_broker_name(tmp_path: Path): config = EndpointConfig( **{ "type": "rasa.core.brokers.file.FileEventBroker", "path": str(tmp_path / "rasa_event.log"), } ) - assert EventBroker.create(config) + broker = await EventBroker.create(config) + assert broker class CustomEventBrokerWithoutAsync(EventBroker): @@ -301,3 +303,40 @@ def test_warning_if_unsupported_ssl_env_variables(monkeypatch: MonkeyPatch): with pytest.warns(UserWarning): pika._create_rabbitmq_ssl_options() + + +async def test_pika_connection_error(monkeypatch: MonkeyPatch): + # patch PikaEventBroker to raise an AMQP connection error + async def connect(self) -> None: + raise aio_pika.exceptions.ProbableAuthenticationError("Oups") + + monkeypatch.setattr(PikaEventBroker, "connect", connect) + cfg = EndpointConfig.from_dict( + { + "type": "pika", + "url": "localhost", + "username": "username", + "password": "password", + "queues": ["queue-1"], + "connection_attempts": 1, + "retry_delay_in_seconds": 0, + } + ) + with pytest.raises(ConnectionException): + await EventBroker.create(cfg) + + +async def test_sql_connection_error(monkeypatch: MonkeyPatch): + cfg = EndpointConfig.from_dict( + { + "type": "sql", + "dialect": "postgresql", + "url": "0.0.0.0", + "port": 42, + "db": "boom", + "username": "user", + "password": "pw", + } + ) + with pytest.raises(ConnectionException): + await EventBroker.create(cfg) diff --git a/tests/core/test_tracker_stores.py b/tests/core/test_tracker_stores.py index 7a5bf58e73f7..a59d2aa0931c 100644 --- a/tests/core/test_tracker_stores.py +++ b/tests/core/test_tracker_stores.py @@ -31,6 +31,7 @@ BotUttered, Event, ) +from rasa.shared.exceptions import ConnectionException from rasa.core.tracker_store import ( TrackerStore, InMemoryTrackerStore, @@ -896,3 +897,13 @@ def test_current_state_without_events(default_domain: Domain): def test_login_db_with_no_postgresql(tmp_path: Path): with pytest.warns(UserWarning): SQLTrackerStore(db=str(tmp_path / "rasa.db"), login_db="other") + + +@pytest.mark.parametrize( + "config", [{"type": "mongod", "url": "mongodb://0.0.0.0:42",}, {"type": "dynamo",}], +) +def test_tracker_store_connection_error(config: Dict, default_domain: Domain): + store = EndpointConfig.from_dict(config) + + with pytest.raises(ConnectionException): + TrackerStore.create(store, default_domain) diff --git a/tests/shared/nlu/training_data/formats/test_markdown.py b/tests/shared/nlu/training_data/formats/test_markdown.py index 582a16bf5438..ae2a7774e643 100644 --- a/tests/shared/nlu/training_data/formats/test_markdown.py +++ b/tests/shared/nlu/training_data/formats/test_markdown.py @@ -2,13 +2,14 @@ import pytest +from rasa.shared.exceptions import MarkdownException +from rasa.shared.nlu.training_data.formats import RasaReader +from rasa.shared.nlu.training_data.formats import MarkdownReader, MarkdownWriter from rasa.shared.nlu.training_data.loading import load_data from rasa.nlu.extractors.crf_entity_extractor import CRFEntityExtractor from rasa.nlu.extractors.duckling_entity_extractor import DucklingEntityExtractor from rasa.nlu.extractors.mitie_entity_extractor import MitieEntityExtractor from rasa.nlu.extractors.spacy_entity_extractor import SpacyEntityExtractor -from rasa.shared.nlu.training_data.formats import RasaReader -from rasa.shared.nlu.training_data.formats import MarkdownReader, MarkdownWriter @pytest.mark.parametrize( @@ -75,7 +76,7 @@ def test_markdown_empty_section(): def test_markdown_not_existing_section(): - with pytest.raises(ValueError): + with pytest.raises(MarkdownException): load_data("data/test/markdown_single_sections/not_existing_section.md") diff --git a/tests/shared/nlu/training_data/test_entities_parser.py b/tests/shared/nlu/training_data/test_entities_parser.py index c4df476b0a56..6499c132aa20 100644 --- a/tests/shared/nlu/training_data/test_entities_parser.py +++ b/tests/shared/nlu/training_data/test_entities_parser.py @@ -3,6 +3,7 @@ import pytest import rasa.shared.nlu.training_data.entities_parser as entities_parser +from rasa.shared.exceptions import InvalidEntityFormatException, SchemaValidationError from rasa.shared.nlu.constants import TEXT @@ -137,3 +138,19 @@ def test_parse_training_example_with_entities(): assert message.get("entities") == [ {"start": 10, "end": 16, "value": "Berlin", "entity": "city"} ] + + +def test_markdown_entity_regex_error_handling_not_json(): + with pytest.raises(InvalidEntityFormatException): + entities_parser.find_entities_in_training_example( + # JSON syntax error: missing closing " for `role` + 'I want to fly from [Berlin]{"entity": "city", "role: "from"}' + ) + + +def test_markdown_entity_regex_error_handling_wrong_schema(): + with pytest.raises(SchemaValidationError): + entities_parser.find_entities_in_training_example( + # Schema error: "entiti" instead of "entity" + 'I want to fly from [Berlin]{"entiti": "city", "role": "from"}' + ) diff --git a/tests/shared/utils/test_common.py b/tests/shared/utils/test_common.py index fac28ab71b3f..ba9e5df8575d 100644 --- a/tests/shared/utils/test_common.py +++ b/tests/shared/utils/test_common.py @@ -1,9 +1,11 @@ import asyncio -from typing import Collection, List, Text +from typing import Any, Collection, List, Optional, Text from unittest.mock import Mock import pytest +from _pytest.recwarn import WarningsRecorder +import rasa.shared.core.domain import rasa.shared.utils.common @@ -140,3 +142,50 @@ async def my_function(): await asyncio.sleep(0) await my_function() + + +@pytest.mark.parametrize( + "module_path, lookup_path, outcome", + [ + ("rasa.shared.core.domain.Domain", None, "Domain"), + # lookup_path + ("Event", "rasa.shared.core.events", "Event"), + ], +) +def test_class_from_module_path( + module_path: Text, lookup_path: Optional[Text], outcome: Text +): + klass = rasa.shared.utils.common.class_from_module_path(module_path, lookup_path) + assert isinstance(klass, object) + assert klass.__name__ == outcome + + +@pytest.mark.parametrize( + "module_path, lookup_path", + [ + ("rasa.shared.core.domain.FunkyDomain", None), + ("FunkyDomain", None), + ("FunkyDomain", "rasa.shared.core.domain"), + ], +) +def test_class_from_module_path_not_found( + module_path: Text, lookup_path: Optional[Text] +): + with pytest.raises(ImportError): + rasa.shared.utils.common.class_from_module_path(module_path, lookup_path) + + +@pytest.mark.parametrize( + "module_path, result, outcome", + [ + ("rasa.shared.core.domain.Domain", rasa.shared.core.domain.Domain, True), + ("rasa.shared.core.domain.logger", rasa.shared.core.domain.logger, False), + ], +) +def test_class_from_module_path_ensure_class( + module_path: Text, outcome: bool, result: Any, recwarn: WarningsRecorder +): + klass = rasa.shared.utils.common.class_from_module_path(module_path) + assert klass is result + + assert bool(len(recwarn)) is not outcome diff --git a/tests/shared/utils/test_validation.py b/tests/shared/utils/test_validation.py index 543b4cb2daca..89cf95996af7 100644 --- a/tests/shared/utils/test_validation.py +++ b/tests/shared/utils/test_validation.py @@ -1,9 +1,8 @@ import pytest -from jsonschema import ValidationError from pep440_version_utils import Version -from rasa.shared.exceptions import YamlException +from rasa.shared.exceptions import YamlException, SchemaValidationError import rasa.shared.utils.io import rasa.shared.utils.validation as validation_utils import rasa.utils.io as io_utils @@ -71,7 +70,7 @@ def test_example_training_data_is_valid(): ], ) def test_validate_training_data_is_throwing_exceptions(invalid_data): - with pytest.raises(ValidationError): + with pytest.raises(SchemaValidationError): validation_utils.validate_training_data( invalid_data, schema.rasa_nlu_data_schema() ) @@ -123,7 +122,7 @@ def test_url_data_format(): ], ) def test_validate_entity_dict_is_throwing_exceptions(invalid_data): - with pytest.raises(ValidationError): + with pytest.raises(SchemaValidationError): validation_utils.validate_training_data( invalid_data, schema.entity_dict_schema() ) diff --git a/tests/utils/test_endpoints.py b/tests/utils/test_endpoints.py index 8fde713699f3..9dee2455d856 100644 --- a/tests/utils/test_endpoints.py +++ b/tests/utils/test_endpoints.py @@ -1,9 +1,11 @@ import logging +from typing import Text import pytest from aioresponses import aioresponses from tests.utilities import latest_request, json_of_latest_request +from tests.core.conftest import DEFAULT_ENDPOINTS_FILE import rasa.utils.endpoints as endpoint_utils @@ -120,3 +122,25 @@ async def test_request_non_json_response(): response = await endpoint.request("post", subpath="test") assert not response + + +@pytest.mark.parametrize( + "filename, endpoint_type", [(DEFAULT_ENDPOINTS_FILE, "tracker_store"),], +) +def test_read_endpoint_config(filename: Text, endpoint_type: Text): + conf = endpoint_utils.read_endpoint_config(filename, endpoint_type) + assert isinstance(conf, endpoint_utils.EndpointConfig) + + +@pytest.mark.parametrize( + "filename, endpoint_type", + [ + ("", "tracker_store"), + (DEFAULT_ENDPOINTS_FILE, "stuff"), + (DEFAULT_ENDPOINTS_FILE, "empty"), + ("/unknown/path.yml", "tracker_store"), + ], +) +def test_read_endpoint_config_not_found(filename: Text, endpoint_type: Text): + conf = endpoint_utils.read_endpoint_config(filename, endpoint_type) + assert conf is None