Skip to content

Commit

Permalink
Merge pull request #7965 from RasaHQ/config-file-schema-validation
Browse files Browse the repository at this point in the history
validate model config
  • Loading branch information
wochinge authored Feb 18, 2021
2 parents bb6c348 + 0f47b4e commit 2574c46
Show file tree
Hide file tree
Showing 19 changed files with 237 additions and 68 deletions.
1 change: 1 addition & 0 deletions changelog/7893.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Model configuration files are now validated against the expected schema.
2 changes: 1 addition & 1 deletion rasa/cli/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def _migrate_model_config(args: argparse.Namespace) -> None:
def _get_configuration(path: Path) -> Dict:
config = {}
try:
config = rasa.shared.utils.io.read_config_file(path)
config = rasa.shared.utils.io.read_model_configuration(path)
except Exception:
rasa.shared.utils.cli.print_error_and_exit(
f"'{path}' is not a path to a valid model configuration. "
Expand Down
15 changes: 0 additions & 15 deletions rasa/cli/default_config.yml

This file was deleted.

2 changes: 1 addition & 1 deletion rasa/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def load(config_file: Optional[Union[Text, Dict]]) -> List["Policy"]:

config_data = {}
if isinstance(config_file, str) and os.path.isfile(config_file):
config_data = rasa.shared.utils.io.read_config_file(config_file)
config_data = rasa.shared.utils.io.read_model_configuration(config_file)
elif isinstance(config_file, Dict):
config_data = config_file

Expand Down
2 changes: 1 addition & 1 deletion rasa/nlu/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def load(
config = DEFAULT_CONFIG_PATH

if config is not None:
file_config = rasa.shared.utils.io.read_config_file(config)
file_config = rasa.shared.utils.io.read_model_configuration(config)

return _load_from_dict(file_config, **kwargs)

Expand Down
5 changes: 3 additions & 2 deletions rasa/shared/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@
PACKAGE_NAME = "rasa"
NEXT_MAJOR_VERSION_FOR_DEPRECATIONS = "3.0.0"

CONFIG_SCHEMA_FILE = "shared/nlu/training_data/schemas/config.yml"
MODEL_CONFIG_SCHEMA_FILE = "shared/utils/schemas/model_config.yml"
CONFIG_SCHEMA_FILE = "shared/utils/schemas/config.yml"
RESPONSES_SCHEMA_FILE = "shared/nlu/training_data/schemas/responses.yml"
SCHEMA_EXTENSIONS_FILE = "shared/utils/pykwalify_extensions.py"
LATEST_TRAINING_DATA_FORMAT_VERSION = "2.0"

DOMAIN_SCHEMA_FILE = "utils/schemas/domain.yml"
DOMAIN_SCHEMA_FILE = "shared/utils/schemas/domain.yml"

DEFAULT_SESSION_EXPIRATION_TIME_IN_MINUTES = 60
DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
KEY_RULE_FOR_CONVERSATION_START = "conversation_start"


CORE_SCHEMA_FILE = "utils/schemas/stories.yml"
CORE_SCHEMA_FILE = "shared/utils/schemas/stories.yml"
DEFAULT_VALUE_TEXT_SLOTS = "filled"
DEFAULT_VALUE_LIST_SLOTS = [DEFAULT_VALUE_TEXT_SLOTS]

Expand Down
2 changes: 1 addition & 1 deletion rasa/shared/importers/multi_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
training_data_paths: Optional[Union[List[Text], Text]] = None,
project_directory: Optional[Text] = None,
):
self.config = rasa.shared.utils.io.read_config_file(config_file)
self.config = rasa.shared.utils.io.read_model_configuration(config_file)
if domain_path:
self._domain_paths = [domain_path]
else:
Expand Down
8 changes: 0 additions & 8 deletions rasa/shared/nlu/training_data/schemas/config.yml

This file was deleted.

2 changes: 1 addition & 1 deletion rasa/shared/nlu/training_data/schemas/nlu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ mapping:
type: "str"
examples: *examples_anchor
responses:
# see rasa/utils/schemas.yml
# see rasa/shared/nlu/training_data/schemas/responses.yml
include: responses
regex;(.*):
type: "any"
70 changes: 56 additions & 14 deletions rasa/shared/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
DEFAULT_LOG_LEVEL,
ENV_LOG_LEVEL,
NEXT_MAJOR_VERSION_FOR_DEPRECATIONS,
CONFIG_SCHEMA_FILE,
MODEL_CONFIG_SCHEMA_FILE,
)
from rasa.shared.exceptions import (
FileIOException,
FileNotFoundException,
YamlSyntaxException,
)
import rasa.shared.utils.validation

DEFAULT_ENCODING = "utf-8"
YAML_VERSION = (1, 2)
Expand Down Expand Up @@ -518,29 +521,68 @@ def raise_deprecation_warning(
raise_warning(message, FutureWarning, docs, **kwargs)


def read_validated_yaml(filename: Union[Text, Path], schema: Text) -> Any:
    """Reads a YAML file, validates it against a schema and returns its content.

    Args:
        filename: Location of the YAML file to load.
        schema: Path of the schema file the file content is validated against.

    Returns:
        The parsed content of the file.

    Raises:
        YamlValidationException: If the file content does not conform to the
            given schema.
    """
    raw_yaml = read_file(filename)
    # The schema check runs on the raw text; parsing happens afterwards.
    rasa.shared.utils.validation.validate_yaml_schema(raw_yaml, schema)

    parsed = read_yaml(raw_yaml)
    return parsed


def read_config_file(filename: Union[Path, Text]) -> Dict[Text, Any]:
"""Parses a yaml configuration file. Content needs to be a dictionary
"""Parses a yaml configuration file. Content needs to be a dictionary.
Args:
filename: The path to the file which should be read.
Raises:
YamlValidationException: In case file content is not a `Dict`.
Returns:
Parsed config file.
"""
content = read_yaml_file(filename)
return read_validated_yaml(filename, CONFIG_SCHEMA_FILE)

if content is None:
return {}
elif isinstance(content, dict):
return content
else:
raise YamlSyntaxException(
filename,
ValueError(
f"Tried to load configuration file '{filename}'. "
f"Expected a key value mapping but found a {type(content).__name__}"
),
)

def read_model_configuration(filename: Union[Path, Text]) -> Dict[Text, Any]:
    """Loads a model configuration file after validating it.

    Args:
        filename: Location of the model configuration file.

    Returns:
        The parsed model configuration.

    Raises:
        YamlValidationException: If the file content does not conform to the
            model configuration schema.
    """
    model_config = read_validated_yaml(filename, MODEL_CONFIG_SCHEMA_FILE)
    return model_config


def is_subdirectory(path: Text, potential_parent_directory: Text) -> bool:
"""Checks if `path` is a subdirectory of `potential_parent_directory`.
Args:
path: Path to a file or directory.
potential_parent_directory: Potential parent directory.
Returns:
`True` if `path` is a subdirectory of `potential_parent_directory`.
"""
if path is None or potential_parent_directory is None:
return False

Expand Down
2 changes: 2 additions & 0 deletions rasa/shared/utils/schemas/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
allowempty: True
type: map
File renamed without changes.
33 changes: 33 additions & 0 deletions rasa/shared/utils/schemas/model_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
allowempty: True
mapping:
version:
type: "str"
required: False
allowempty: False
language:
type: "str"
required: False
pipeline:
type: "seq"
required: False
sequence:
- type: "map"
# Only validate required items but do not validate each potential config param
# for the components
allowempty: True
mapping:
name:
type: str
required: True
policies:
type: "seq"
required: False
sequence:
- type: "map"
# Only validate required items but do not validate each potential config param
# for the policies
allowempty: True
mapping:
name:
type: str
required: True
File renamed without changes.
12 changes: 2 additions & 10 deletions tests/cli/test_rasa_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,20 +104,12 @@ def test_train_core_compare(run_in_simple_project: Callable[..., RunResult]):
temp_dir = os.getcwd()

rasa.shared.utils.io.write_yaml(
{
"language": "en",
"pipeline": "supervised_embeddings",
"policies": [{"name": "MemoizationPolicy"}],
},
{"language": "en", "policies": [{"name": "MemoizationPolicy"}],},
"config_1.yml",
)

rasa.shared.utils.io.write_yaml(
{
"language": "en",
"pipeline": "supervised_embeddings",
"policies": [{"name": "MemoizationPolicy"}],
},
{"language": "en", "policies": [{"name": "MemoizationPolicy"}],},
"config_2.yml",
)

Expand Down
8 changes: 2 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)
from rasa.shared.exceptions import RasaException

DEFAULT_CONFIG_PATH = "rasa/cli/default_config.yml"
DEFAULT_CONFIG_PATH = "rasa/shared/importers/default_config.yml"

DEFAULT_NLU_DATA = "examples/moodbot/data/nlu.yml"

Expand Down Expand Up @@ -287,11 +287,7 @@ async def trained_core_model(

@pytest.fixture(scope="session")
async def trained_nlu_model(
trained_async: Callable,
default_domain_path: Text,
default_config: List[Policy],
default_nlu_data: Text,
default_stories_file: Text,
trained_async: Callable, default_domain_path: Text, default_nlu_data: Text,
) -> Text:
trained_nlu_model_path = await trained_async(
domain=default_domain_path,
Expand Down
Loading

0 comments on commit 2574c46

Please sign in to comment.