diff --git a/changelog/9766.removal.md b/changelog/9766.removal.md new file mode 100644 index 000000000000..006c5419033a --- /dev/null +++ b/changelog/9766.removal.md @@ -0,0 +1,2 @@ +`rasa.core.agent.Agent.visualize` was removed. Please use `rasa visualize` or +`rasa.core.visualize.visualize` instead. diff --git a/rasa/cli/arguments/visualize.py b/rasa/cli/arguments/visualize.py index 360c6959d4bb..2f7853e00b17 100644 --- a/rasa/cli/arguments/visualize.py +++ b/rasa/cli/arguments/visualize.py @@ -1,7 +1,6 @@ import argparse from rasa.cli.arguments.default_arguments import ( - add_config_param, add_domain_param, add_stories_param, add_out_param, @@ -10,9 +9,9 @@ def set_visualize_stories_arguments(parser: argparse.ArgumentParser) -> None: + """Sets the CLI arguments for `rasa visualize`.""" add_domain_param(parser) add_stories_param(parser) - add_config_param(parser) add_out_param( parser, diff --git a/rasa/cli/visualize.py b/rasa/cli/visualize.py index 3da0c33b55fb..7774673fba1f 100644 --- a/rasa/cli/visualize.py +++ b/rasa/cli/visualize.py @@ -37,8 +37,6 @@ def visualize_stories(args: argparse.Namespace) -> None: if args.nlu is None and os.path.exists(DEFAULT_DATA_PATH): args.nlu = rasa.shared.data.get_nlu_directory(DEFAULT_DATA_PATH) - rasa.utils.common.run_in_loop( - rasa.core.visualize.visualize( - args.config, args.domain, args.stories, args.nlu, args.out, args.max_history - ) + rasa.core.visualize.visualize( + args.domain, args.stories, args.nlu, args.out, args.max_history ) diff --git a/rasa/core/agent.py b/rasa/core/agent.py index fd89bd78ab79..c100f0cb09df 100644 --- a/rasa/core/agent.py +++ b/rasa/core/agent.py @@ -36,7 +36,6 @@ from rasa.nlu.utils import is_url from rasa.shared.exceptions import RasaException import rasa.shared.utils.io -from rasa.shared.nlu.training_data.training_data import TrainingData from rasa.utils.endpoints import EndpointConfig from rasa.core.tracker_store import TrackerStore @@ -499,36 +498,6 @@ def _set_fingerprint(self, 
fingerprint: Optional[Text] = None) -> None: else: self.fingerprint = uuid.uuid4().hex - async def visualize( - self, - resource_name: Text, - output_file: Text, - max_history: Optional[int] = None, - nlu_training_data: Optional[TrainingData] = None, - should_merge_nodes: bool = True, - fontsize: int = 12, - ) -> None: - """Visualize the loaded training data from the resource.""" - # TODO: This needs to be fixed to not use the interpreter - from rasa.shared.core.training_data.visualization import visualize_stories - from rasa.shared.core.training_data import loading - - # if the user doesn't provide a max history, we will use the - # largest value from any policy - max_history = max_history or self._max_history() - - story_steps = loading.load_data_from_resource(resource_name, self.domain) - await visualize_stories( - story_steps, - self.domain, - output_file, - max_history, - self.interpreter, - nlu_training_data, - should_merge_nodes, - fontsize, - ) - @staticmethod def _create_tracker_store( store: Optional[TrackerStore], domain: Domain diff --git a/rasa/core/visualize.py b/rasa/core/visualize.py index 85ca6601a76a..ce8fa339e8e7 100644 --- a/rasa/core/visualize.py +++ b/rasa/core/visualize.py @@ -3,35 +3,34 @@ from typing import Text from rasa import telemetry +from rasa.shared.core.training_data import loading from rasa.shared.utils.cli import print_error - -from rasa.shared.core.domain import InvalidDomain +from rasa.shared.core.domain import InvalidDomain, Domain logger = logging.getLogger(__name__) -async def visualize( - config_path: Text, +def visualize( domain_path: Text, stories_path: Text, nlu_data_path: Text, output_path: Text, max_history: int, ) -> None: - from rasa.core.agent import Agent - from rasa.core import config + """Visualizes stories as graph. - try: - policies = config.load(config_path) - except Exception as e: - print_error( - f"Could not load config due to: '{e}'. To specify a valid config file use " - f"the '--config' argument." 
- ) - return + Args: + domain_path: Path to the domain file. + stories_path: Path to the stories files. + nlu_data_path: Path to the NLU training data which can be used to interpolate + intents with actual examples in the graph. + output_path: Path where the created graph should be persisted. + max_history: Max history to use for the story visualization. + """ + import rasa.shared.core.training_data.visualization try: - agent = Agent(domain=domain_path, policies=policies) + domain = Domain.load(domain_path) except InvalidDomain as e: print_error( f"Could not load domain due to: '{e}'. To specify a valid domain path use " @@ -53,8 +52,14 @@ async def visualize( logger.info("Starting to visualize stories...") telemetry.track_visualization() - await agent.visualize( - stories_path, output_path, max_history, nlu_training_data=nlu_training_data + + story_steps = loading.load_data_from_resource(stories_path, domain) + rasa.shared.core.training_data.visualization.visualize_stories( + story_steps, + domain, + output_path, + max_history, + nlu_training_data=nlu_training_data, ) full_output_path = "file://{}".format(os.path.abspath(output_path)) diff --git a/rasa/shared/core/training_data/visualization.py b/rasa/shared/core/training_data/visualization.py index cad605dcab62..99dd98b21ba1 100644 --- a/rasa/shared/core/training_data/visualization.py +++ b/rasa/shared/core/training_data/visualization.py @@ -4,6 +4,7 @@ from typing import Any, Text, List, Dict, Optional, TYPE_CHECKING, Set import rasa.shared.utils.io +from rasa.shared.constants import INTENT_MESSAGE_PREFIX from rasa.shared.core.constants import ACTION_LISTEN_NAME from rasa.shared.core.domain import Domain from rasa.shared.core.events import UserUttered, ActionExecuted, Event @@ -14,7 +15,6 @@ INTENT, TEXT, ENTITY_ATTRIBUTE_TYPE, - ENTITIES, INTENT_NAME_KEY, ) @@ -58,25 +58,15 @@ def _contains_same_entity(entities: Dict[Text, Any], e: Dict[Text, Any]) -> bool ) != e.get(ENTITY_ATTRIBUTE_VALUE) def 
message_for_data(self, structured_info: Dict[Text, Any]) -> Any: - """Find a data sample with the same intent and entities. - - Given the parsed data from a message (intent and entities) finds a - message in the data that has the same intent and entities.""" - + """Find a data sample with the same intent.""" if structured_info.get(INTENT) is not None: intent_name = structured_info.get(INTENT, {}).get(INTENT_NAME_KEY) usable_examples = self.mapping.get(intent_name, [])[:] random.shuffle(usable_examples) - for example in usable_examples: - entities = { - e.get(ENTITY_ATTRIBUTE_TYPE): e.get(ENTITY_ATTRIBUTE_VALUE) - for e in example.get(ENTITIES, []) - } - for e in structured_info.get(ENTITIES, []): - if self._contains_same_entity(entities, e): - break - else: - return example.get(TEXT) + + if usable_examples: + return usable_examples[0].get(TEXT) + return structured_info.get(TEXT) @@ -265,13 +255,12 @@ def _merge_equivalent_nodes(graph: "networkx.MultiDiGraph", max_history: int) -> graph.remove_node(j) -async def _replace_edge_labels_with_nodes( - graph: "networkx.MultiDiGraph", - next_id: int, - interpreter, - nlu_training_data: "TrainingData", +def _replace_edge_labels_with_nodes( + graph: "networkx.MultiDiGraph", next_id: int, nlu_training_data: "TrainingData", ) -> None: - """User messages are created as edge labels. This removes the labels and + """Replaces edge labels with nodes. + + User messages are created as edge labels. This removes the labels and creates nodes instead. The algorithms (e.g. 
merging) are simpler if the user messages are labels @@ -287,11 +276,14 @@ async def _replace_edge_labels_with_nodes( edges = list(graph.edges(keys=True, data=True)) for s, e, k, d in edges: if k != EDGE_NONE_LABEL: - if message_generator and d.get("label", k) is not None: - parsed_info = await interpreter.parse(d.get("label", k)) + label = d.get("label", k) + + if message_generator: + parsed_info = {TEXT: label} + if label.startswith(INTENT_MESSAGE_PREFIX): + parsed_info[INTENT] = {INTENT_NAME_KEY: label[1:]} + label = message_generator.message_for_data(parsed_info) - else: - label = d.get("label", k) next_id += 1 graph.remove_edge(s, e, k) graph.add_node( @@ -411,12 +403,11 @@ def _add_message_edge( ) -async def visualize_neighborhood( +def visualize_neighborhood( current: Optional[List[Event]], event_sequences: List[List[Event]], output_file: Optional[Text] = None, max_history: int = 2, - interpreter=None, nlu_training_data: Optional["TrainingData"] = None, should_merge_nodes: bool = True, max_distance: int = 1, @@ -446,10 +437,8 @@ async def visualize_neighborhood( idx -= 1 break if isinstance(el, UserUttered): - if not el.intent: - message = await interpreter.parse(el.text) - else: - message = el.parse_data + message = el.parse_data + message[TEXT] = f"{INTENT_MESSAGE_PREFIX}{el.intent_name}" elif ( isinstance(el, ActionExecuted) and el.action_name != ACTION_LISTEN_NAME ): @@ -507,9 +496,7 @@ async def visualize_neighborhood( if should_merge_nodes: _merge_equivalent_nodes(graph, max_history) - await _replace_edge_labels_with_nodes( - graph, next_node_idx, interpreter, nlu_training_data - ) + _replace_edge_labels_with_nodes(graph, next_node_idx, nlu_training_data) _remove_auxiliary_nodes(graph, special_node_idx) @@ -538,12 +525,11 @@ def _remove_auxiliary_nodes( ps.add(pred) -async def visualize_stories( +def visualize_stories( story_steps: List[StoryStep], domain: Domain, output_file: Optional[Text], max_history: int, - interpreter=None, # TODO: Fix this to 
use processor: nlu_training_data: Optional["TrainingData"] = None, should_merge_nodes: bool = True, fontsize: int = 12, @@ -587,12 +573,11 @@ async def visualize_stories( completed_trackers = g.generate() event_sequences = [t.events for t in completed_trackers] - graph = await visualize_neighborhood( + graph = visualize_neighborhood( None, event_sequences, output_file, max_history, - interpreter, nlu_training_data, should_merge_nodes, max_distance=1, diff --git a/tests/cli/test_rasa_visualize.py b/tests/cli/test_rasa_visualize.py index 7b3e99cda5c1..27dd46fc2831 100644 --- a/tests/cli/test_rasa_visualize.py +++ b/tests/cli/test_rasa_visualize.py @@ -6,8 +6,7 @@ def test_visualize_help(run: Callable[..., RunResult]): output = run("visualize", "--help") help_text = """usage: rasa visualize [-h] [-v] [-vv] [--quiet] [-d DOMAIN] [-s STORIES] - [-c CONFIG] [--out OUT] [--max-history MAX_HISTORY] - [-u NLU]""" + [--out OUT] [--max-history MAX_HISTORY] [-u NLU]""" lines = help_text.split("\n") # expected help text lines should appear somewhere in the output diff --git a/tests/core/test_training.py b/tests/core/test_training.py index 65a6fe28290a..d92191302cda 100644 --- a/tests/core/test_training.py +++ b/tests/core/test_training.py @@ -10,7 +10,6 @@ from rasa.shared.core.domain import Domain from rasa.core.policies.ted_policy import TEDPolicy -from rasa.shared.core.training_data.visualization import visualize_stories import rasa.model_training import rasa.shared.utils.io @@ -22,40 +21,6 @@ def test_load_training_data_reader_not_found_throws(tmp_path: Path, domain: Doma training.load_data(str(tmp_path), domain) -async def test_story_visualization(domain: Domain, tmp_path: Path): - import rasa.shared.core.training_data.loading as core_loading - - story_steps = core_loading.load_data_from_resource( - "data/test_yaml_stories/stories.yml", domain - ) - out_file = str(tmp_path / "graph.html") - generated_graph = await visualize_stories( - story_steps, - domain, - 
output_file=out_file, - max_history=3, - should_merge_nodes=False, - ) - - assert len(generated_graph.nodes()) == 51 - - assert len(generated_graph.edges()) == 56 - - -async def test_story_visualization_with_merging(domain: Domain): - import rasa.shared.core.training_data.loading as core_loading - - story_steps = core_loading.load_data_from_resource( - "data/test_yaml_stories/stories.yml", domain - ) - generated_graph = await visualize_stories( - story_steps, domain, output_file=None, max_history=3, should_merge_nodes=True, - ) - assert 15 < len(generated_graph.nodes()) < 33 - - assert 20 < len(generated_graph.edges()) < 33 - - def test_training_script_with_restart_stories(tmp_path: Path, domain_path: Text): model_file = rasa.model_training.train_core( domain_path, diff --git a/tests/shared/core/training_data/test_visualization.py b/tests/shared/core/training_data/test_visualization.py index d3125442494a..569398494ad5 100644 --- a/tests/shared/core/training_data/test_visualization.py +++ b/tests/shared/core/training_data/test_visualization.py @@ -1,10 +1,14 @@ from pathlib import Path +from typing import Text import rasa.shared.utils.io from rasa.shared.core.domain import Domain from rasa.shared.core.events import ActionExecuted, SlotSet, UserUttered from rasa.shared.core.training_data import visualization import rasa.utils.io +from rasa.shared.nlu.constants import TEXT, INTENT +from rasa.shared.nlu.training_data.message import Message +from rasa.shared.nlu.training_data.training_data import TrainingData def test_style_transfer(): @@ -77,7 +81,7 @@ def test_common_action_prefix_unequal(): assert num_common == 0 -async def test_graph_persistence(domain: Domain, tmp_path: Path): +def test_graph_persistence(domain: Domain, tmp_path: Path): from os.path import isfile from networkx.drawing import nx_pydot import rasa.shared.core.training_data.loading as core_loading @@ -86,7 +90,7 @@ async def test_graph_persistence(domain: Domain, tmp_path: Path): 
"data/test_yaml_stories/stories.yml", domain ) out_file = str(tmp_path / "graph.html") - generated_graph = await visualization.visualize_stories( + generated_graph = visualization.visualize_stories( story_steps, domain, output_file=out_file, @@ -104,7 +108,7 @@ async def test_graph_persistence(domain: Domain, tmp_path: Path): assert "graph = `{}`".format(generated_graph.to_string()) in content -async def test_merge_nodes(domain: Domain, tmp_path: Path): +def test_merge_nodes(domain: Domain, tmp_path: Path): from os.path import isfile import rasa.shared.core.training_data.loading as core_loading @@ -112,7 +116,7 @@ async def test_merge_nodes(domain: Domain, tmp_path: Path): "data/test_yaml_stories/stories.yml", domain ) out_file = str(tmp_path / "graph.html") - await visualization.visualize_stories( + visualization.visualize_stories( story_steps, domain, output_file=out_file, @@ -120,3 +124,67 @@ async def test_merge_nodes(domain: Domain, tmp_path: Path): should_merge_nodes=True, ) assert isfile(out_file) + + +def test_story_visualization(domain: Domain, tmp_path: Path): + import rasa.shared.core.training_data.loading as core_loading + + story_steps = core_loading.load_data_from_resource( + "data/test_yaml_stories/stories.yml", domain + ) + out_file = tmp_path / "graph.html" + generated_graph = visualization.visualize_stories( + story_steps, + domain, + output_file=str(out_file), + max_history=3, + should_merge_nodes=False, + ) + + assert str(None) not in out_file.read_text() + assert "/affirm" in out_file.read_text() + assert len(generated_graph.nodes()) == 51 + assert len(generated_graph.edges()) == 56 + + +def test_story_visualization_with_training_data( + domain: Domain, tmp_path: Path, nlu_data_path: Text +): + import rasa.shared.core.training_data.loading as core_loading + + story_steps = core_loading.load_data_from_resource( + "data/test_yaml_stories/stories.yml", domain + ) + out_file = tmp_path / "graph.html" + test_text = "test text" + test_intent = 
"affirm" + generated_graph = visualization.visualize_stories( + story_steps, + domain, + output_file=str(out_file), + max_history=3, + should_merge_nodes=False, + nlu_training_data=TrainingData( + [Message({TEXT: test_text, INTENT: test_intent})] + ), + ) + + assert test_text in out_file.read_text() + assert test_intent not in out_file.read_text() + + assert len(generated_graph.nodes()) == 51 + assert len(generated_graph.edges()) == 56 + + +def test_story_visualization_with_merging(domain: Domain): + import rasa.shared.core.training_data.loading as core_loading + + story_steps = core_loading.load_data_from_resource( + "data/test_yaml_stories/stories.yml", domain + ) + generated_graph = visualization.visualize_stories( + story_steps, domain, output_file=None, max_history=3, should_merge_nodes=True, + ) + assert 15 < len(generated_graph.nodes()) < 33 + + assert 20 < len(generated_graph.edges()) < 33