Skip to content

Commit

Permalink
new CLI: rasa data convert responses
Browse files Browse the repository at this point in the history
  • Loading branch information
m-vdb committed Nov 5, 2020
1 parent 9ef7bbe commit feab7b6
Showing 1 changed file with 136 additions and 4 deletions.
140 changes: 136 additions & 4 deletions rasa/cli/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import shutil
from pathlib import Path
from typing import List, Text, Dict
from typing import Dict, List, Text, Tuple, TYPE_CHECKING

import rasa.shared.core.domain
from rasa import telemetry
Expand All @@ -17,19 +17,20 @@
DEFAULT_CONFIG_PATH,
DEFAULT_DOMAIN_PATH,
DOCS_URL_MIGRATION_GUIDE,
UTTER_PREFIX,
)
import rasa.shared.data
from rasa.shared.core.constants import (
USER_INTENT_OUT_OF_SCOPE,
ACTION_DEFAULT_FALLBACK_NAME,
)
from rasa.shared.core.events import ActionExecuted
from rasa.shared.core.training_data.story_reader.yaml_story_reader import (
YAMLStoryReader,
)
from rasa.shared.core.training_data.story_writer.yaml_story_writer import (
YAMLStoryWriter,
)
from rasa.shared.core.training_data.structures import StoryStep
from rasa.shared.importers.rasa import RasaFileImporter
import rasa.shared.nlu.training_data.loading
import rasa.shared.nlu.training_data.util
Expand All @@ -45,8 +46,13 @@
from rasa.core.policies.two_stage_fallback import TwoStageFallbackPolicy
from rasa.core.policies.mapping_policy import MappingPolicy

if TYPE_CHECKING:
from rasa.shared.core.training_data.structures import StoryStep

logger = logging.getLogger(__name__)

OBSOLETE_RESPOND_PREFIX = "respond_"


def add_subparser(
subparsers: SubParsersAction, parents: List[argparse.ArgumentParser]
Expand Down Expand Up @@ -131,6 +137,19 @@ def _add_data_convert_parsers(
"part of the migration. If the file doesn't exist, it will be created.",
)

convert_responses_parser = convert_subparsers.add_parser(
"responses",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
parents=parents,
help=(
"Convert retrieval intent responses between Rasa Open Source versions. "
"Please also run `rasa data convert nlg` to convert responses to the right format."
),
)
convert_responses_parser.set_defaults(func=_migrate_responses)
default_arguments.add_stories_param(convert_responses_parser)
default_arguments.add_domain_param(convert_responses_parser)


def _add_data_split_parsers(
data_subparsers, parents: List[argparse.ArgumentParser]
Expand Down Expand Up @@ -324,6 +343,119 @@ def _convert_nlg_data(args: argparse.Namespace) -> None:
)


def _migrate_responses(args: argparse.Namespace) -> None:
    """Migrate retrieval intent responses to the new 2.0 format.

    It does so by modifying the domain file first and the story files
    second, and finally reports the conversion through telemetry.

    Args:
        args: the CLI arguments
    """
    _migrate_responses_in_domain(args)
    rasa.utils.common.run_in_loop(_migrate_responses_in_stories(args))

    # The `responses` sub-parser set up in this module only registers the
    # stories and domain parameters, so `args.format` may be missing unless a
    # parent parser provides it. Fall back to "yaml" (the stories are written
    # back with `YAMLStoryWriter`) instead of crashing with an
    # `AttributeError` after the files have already been rewritten.
    output_format = getattr(args, "format", "yaml")
    telemetry.track_data_convert(output_format, "responses")


def _migrate_responses_in_domain(args: argparse.Namespace):
    """Rename retrieval intent responses in the domain to the 2.0 format.

    Before 2.0, retrieval intent responses needed to start
    with `respond_`. Now, they need to start with `utter_`.
    Every entry of the domain's `actions` list is passed through
    `_normalize_response_name`, and the result is written back to the same
    domain file, i.e. the file is updated in place.

    Args:
        args: the CLI arguments; `args.domain` is the path to the domain file
    """
    domain_path = Path(args.domain)
    cleaned = _get_domain(domain_path).cleaned_domain()

    # Build the full list of migrated names before replacing the old one.
    migrated_actions = []
    for action_name in cleaned["actions"]:
        migrated_actions.append(_normalize_response_name(action_name))
    cleaned["actions"] = migrated_actions

    Domain.from_dict(cleaned).persist_clean(domain_path)


async def _migrate_responses_in_stories(args: argparse.Namespace):
    """Rename retrieval intent responses in the stories to the 2.0 format.

    Before 2.0, retrieval intent responses needed to start
    with `respond_`. Now, they need to start with `utter_`.
    The action name of every `ActionExecuted` event is passed through
    `_normalize_response_name`, and each story file is then dumped back to
    its original location, i.e. the story files are updated in place.

    Args:
        args: the CLI arguments; `args.stories` is the stories resource
    """
    loaded_stories = await _load_stories_from_resource(args.stories)

    for path, steps in loaded_stories:
        # Only executed actions carry a response name that may need renaming.
        executed_actions = (
            event
            for step in steps
            for event in step.events
            if isinstance(event, ActionExecuted)
        )
        for executed in executed_actions:
            executed.action_name = _normalize_response_name(executed.action_name)

        # Write the (possibly modified) steps back over the source file.
        YAMLStoryWriter().dump(path, steps)


def _normalize_response_name(action_name: Text) -> Text:
    """Replace the obsolete response prefix of an action name.

    Args:
        action_name: the name of an action as found in a domain or story

    Returns:
        `action_name` with its leading `OBSOLETE_RESPOND_PREFIX` replaced by
        `UTTER_PREFIX` if it starts with the obsolete prefix, otherwise the
        name unchanged.
    """
    # Names without the obsolete prefix are already in the expected format.
    if not action_name.startswith(OBSOLETE_RESPOND_PREFIX):
        return action_name

    return f"{UTTER_PREFIX}{action_name[len(OBSOLETE_RESPOND_PREFIX):]}"


async def _load_stories_from_resource(
    resource: Text,
) -> List[Tuple[Text, List["StoryStep"]]]:
    """Loads core training data from a resource (folder or file).

    Args:
        resource: Folder/File with core training data files.

    Returns:
        One `(story_file, story_steps)` pair per story file found in
        `resource`, where `story_steps` is what the YAML story reader
        returned for that file.
    """
    story_files = rasa.shared.data.get_data_files(
        resource, rasa.shared.data.is_story_file
    )

    # NOTE(review): nothing is awaited in this coroutine. If
    # `YAMLStoryReader.read_from_file` is a coroutine function in the Rasa
    # version this file targets, its result must be awaited here - confirm.
    return [
        (story_file, _get_yaml_story_reader(story_file).read_from_file(story_file))
        for story_file in story_files
    ]


def _get_yaml_story_reader(filename: Text) -> YAMLStoryReader:
    """Get a `YAMLStoryReader` instance for a given file.

    This function also validates that the file is a valid YAML story file,
    and reports an error through `print_error_and_exit` if not.

    Args:
        filename: the name of the story file

    Returns:
        an instance of YAMLStoryReader
    """
    if YAMLStoryReader.is_stories_file(filename):
        return YAMLStoryReader(source_name=filename)

    message = f"File {filename} is not a valid YAML stories file."
    if filename.endswith(".md"):
        # Markdown stories are converted by the `core` converter; the `nlu`
        # converter only handles NLU training data.
        message += (
            " Please run `rasa data convert core` to convert your "
            "stories to the new YAML format."
        )

    rasa.shared.utils.cli.print_error_and_exit(message)


async def _convert_to_yaml(
args: argparse.Namespace, converter: TrainingDataConverter
) -> None:
Expand Down Expand Up @@ -571,7 +703,7 @@ def _get_rules_path(path: Text) -> Path:
return rules_file


def _dump_rules(path: Path, new_rules: List[StoryStep]) -> None:
def _dump_rules(path: Path, new_rules: List["StoryStep"]) -> None:
existing_rules = []
if path.exists():
rules_reader = YAMLStoryReader()
Expand All @@ -593,7 +725,7 @@ def _backup(path: Path) -> None:
shutil.copy(path, backup_file)


def _print_success_message(new_rules: List[StoryStep], output_file: Path) -> None:
def _print_success_message(new_rules: List["StoryStep"], output_file: Path) -> None:
if len(new_rules) > 1:
suffix = "rule"
verb = "was"
Expand Down

0 comments on commit feab7b6

Please sign in to comment.