Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor interpreter usage in story visualization #9821

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog/9766.removal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
`rasa.core.agent.Agent.visualize` was removed. Please use `rasa visualize` or
`rasa.core.visualize.visualize` instead.
3 changes: 1 addition & 2 deletions rasa/cli/arguments/visualize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse

from rasa.cli.arguments.default_arguments import (
add_config_param,
add_domain_param,
add_stories_param,
add_out_param,
Expand All @@ -10,9 +9,9 @@


def set_visualize_stories_arguments(parser: argparse.ArgumentParser) -> None:
"""Sets the CLI arguments for `rasa data visualize`."""
add_domain_param(parser)
add_stories_param(parser)
add_config_param(parser)

add_out_param(
parser,
Expand Down
6 changes: 2 additions & 4 deletions rasa/cli/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ def visualize_stories(args: argparse.Namespace) -> None:
if args.nlu is None and os.path.exists(DEFAULT_DATA_PATH):
args.nlu = rasa.shared.data.get_nlu_directory(DEFAULT_DATA_PATH)

rasa.utils.common.run_in_loop(
rasa.core.visualize.visualize(
args.config, args.domain, args.stories, args.nlu, args.out, args.max_history
)
rasa.core.visualize.visualize(
args.domain, args.stories, args.nlu, args.out, args.max_history
)
31 changes: 0 additions & 31 deletions rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from rasa.nlu.utils import is_url
from rasa.shared.exceptions import RasaException
import rasa.shared.utils.io
from rasa.shared.nlu.training_data.training_data import TrainingData
from rasa.utils.endpoints import EndpointConfig

from rasa.core.tracker_store import TrackerStore
Expand Down Expand Up @@ -499,36 +498,6 @@ def _set_fingerprint(self, fingerprint: Optional[Text] = None) -> None:
else:
self.fingerprint = uuid.uuid4().hex

async def visualize(
self,
resource_name: Text,
output_file: Text,
max_history: Optional[int] = None,
nlu_training_data: Optional[TrainingData] = None,
should_merge_nodes: bool = True,
fontsize: int = 12,
) -> None:
"""Visualize the loaded training data from the resource."""
# TODO: This needs to be fixed to not use the interpreter
from rasa.shared.core.training_data.visualization import visualize_stories
from rasa.shared.core.training_data import loading

# if the user doesn't provide a max history, we will use the
# largest value from any policy
max_history = max_history or self._max_history()

story_steps = loading.load_data_from_resource(resource_name, self.domain)
await visualize_stories(
story_steps,
self.domain,
output_file,
max_history,
self.interpreter,
nlu_training_data,
should_merge_nodes,
fontsize,
)

@staticmethod
def _create_tracker_store(
store: Optional[TrackerStore], domain: Domain
Expand Down
39 changes: 22 additions & 17 deletions rasa/core/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,34 @@
from typing import Text

from rasa import telemetry
from rasa.shared.core.training_data import loading
from rasa.shared.utils.cli import print_error

from rasa.shared.core.domain import InvalidDomain
from rasa.shared.core.domain import InvalidDomain, Domain

logger = logging.getLogger(__name__)


async def visualize(
config_path: Text,
def visualize(
domain_path: Text,
stories_path: Text,
nlu_data_path: Text,
output_path: Text,
max_history: int,
) -> None:
from rasa.core.agent import Agent
from rasa.core import config
"""Visualizes stories as graph.

try:
policies = config.load(config_path)
except Exception as e:
print_error(
f"Could not load config due to: '{e}'. To specify a valid config file use "
f"the '--config' argument."
)
return
Args:
domain_path: Path to the domain file.
stories_path: Path to the stories files.
nlu_data_path: Path to the NLU training data, which can be used to interpolate
intents with actual examples in the graph.
output_path: Path where the created graph should be persisted.
max_history: Max history to use for the story visualization.
"""
import rasa.shared.core.training_data.visualization

try:
agent = Agent(domain=domain_path, policies=policies)
domain = Domain.load(domain_path)
except InvalidDomain as e:
print_error(
f"Could not load domain due to: '{e}'. To specify a valid domain path use "
Expand All @@ -53,8 +52,14 @@ async def visualize(

logger.info("Starting to visualize stories...")
telemetry.track_visualization()
await agent.visualize(
stories_path, output_path, max_history, nlu_training_data=nlu_training_data

story_steps = loading.load_data_from_resource(stories_path, domain)
rasa.shared.core.training_data.visualization.visualize_stories(
story_steps,
domain,
output_path,
max_history,
nlu_training_data=nlu_training_data,
)

full_output_path = "file://{}".format(os.path.abspath(output_path))
Expand Down
63 changes: 24 additions & 39 deletions rasa/shared/core/training_data/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Text, List, Dict, Optional, TYPE_CHECKING, Set

import rasa.shared.utils.io
from rasa.shared.constants import INTENT_MESSAGE_PREFIX
from rasa.shared.core.constants import ACTION_LISTEN_NAME
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import UserUttered, ActionExecuted, Event
Expand All @@ -14,7 +15,6 @@
INTENT,
TEXT,
ENTITY_ATTRIBUTE_TYPE,
ENTITIES,
INTENT_NAME_KEY,
)

Expand Down Expand Up @@ -58,25 +58,15 @@ def _contains_same_entity(entities: Dict[Text, Any], e: Dict[Text, Any]) -> bool
) != e.get(ENTITY_ATTRIBUTE_VALUE)

def message_for_data(self, structured_info: Dict[Text, Any]) -> Any:
"""Find a data sample with the same intent and entities.

Given the parsed data from a message (intent and entities) finds a
message in the data that has the same intent and entities."""

"""Find a data sample with the same intent."""
if structured_info.get(INTENT) is not None:
intent_name = structured_info.get(INTENT, {}).get(INTENT_NAME_KEY)
usable_examples = self.mapping.get(intent_name, [])[:]
random.shuffle(usable_examples)
for example in usable_examples:
entities = {
e.get(ENTITY_ATTRIBUTE_TYPE): e.get(ENTITY_ATTRIBUTE_VALUE)
for e in example.get(ENTITIES, [])
}
for e in structured_info.get(ENTITIES, []):
if self._contains_same_entity(entities, e):
break
else:
return example.get(TEXT)

if usable_examples:
return usable_examples[0].get(TEXT)

return structured_info.get(TEXT)


Expand Down Expand Up @@ -265,13 +255,12 @@ def _merge_equivalent_nodes(graph: "networkx.MultiDiGraph", max_history: int) ->
graph.remove_node(j)


async def _replace_edge_labels_with_nodes(
graph: "networkx.MultiDiGraph",
next_id: int,
interpreter,
nlu_training_data: "TrainingData",
def _replace_edge_labels_with_nodes(
graph: "networkx.MultiDiGraph", next_id: int, nlu_training_data: "TrainingData",
) -> None:
"""User messages are created as edge labels. This removes the labels and
"""Replaces edge labels with nodes.

User messages are created as edge labels. This removes the labels and
creates nodes instead.

The algorithms (e.g. merging) are simpler if the user messages are labels
Expand All @@ -287,11 +276,14 @@ async def _replace_edge_labels_with_nodes(
edges = list(graph.edges(keys=True, data=True))
for s, e, k, d in edges:
if k != EDGE_NONE_LABEL:
if message_generator and d.get("label", k) is not None:
parsed_info = await interpreter.parse(d.get("label", k))
label = d.get("label", k)

if message_generator:
parsed_info = {TEXT: label}
if label.startswith(INTENT_MESSAGE_PREFIX):
parsed_info[INTENT] = {INTENT_NAME_KEY: label[1:]}
wochinge marked this conversation as resolved.
Show resolved Hide resolved

label = message_generator.message_for_data(parsed_info)
else:
label = d.get("label", k)
next_id += 1
graph.remove_edge(s, e, k)
graph.add_node(
Expand Down Expand Up @@ -411,12 +403,11 @@ def _add_message_edge(
)


async def visualize_neighborhood(
def visualize_neighborhood(
current: Optional[List[Event]],
event_sequences: List[List[Event]],
output_file: Optional[Text] = None,
max_history: int = 2,
interpreter=None,
nlu_training_data: Optional["TrainingData"] = None,
should_merge_nodes: bool = True,
max_distance: int = 1,
Expand Down Expand Up @@ -446,10 +437,8 @@ async def visualize_neighborhood(
idx -= 1
break
if isinstance(el, UserUttered):
if not el.intent:
message = await interpreter.parse(el.text)
else:
message = el.parse_data
message = el.parse_data
message[TEXT] = f"{INTENT_MESSAGE_PREFIX}{el.intent_name}"
elif (
isinstance(el, ActionExecuted) and el.action_name != ACTION_LISTEN_NAME
):
Expand Down Expand Up @@ -507,9 +496,7 @@ async def visualize_neighborhood(

if should_merge_nodes:
_merge_equivalent_nodes(graph, max_history)
await _replace_edge_labels_with_nodes(
graph, next_node_idx, interpreter, nlu_training_data
)
_replace_edge_labels_with_nodes(graph, next_node_idx, nlu_training_data)

_remove_auxiliary_nodes(graph, special_node_idx)

Expand Down Expand Up @@ -538,12 +525,11 @@ def _remove_auxiliary_nodes(
ps.add(pred)


async def visualize_stories(
def visualize_stories(
story_steps: List[StoryStep],
domain: Domain,
output_file: Optional[Text],
max_history: int,
interpreter=None, # TODO: Fix this to use processor:
nlu_training_data: Optional["TrainingData"] = None,
should_merge_nodes: bool = True,
fontsize: int = 12,
Expand Down Expand Up @@ -587,12 +573,11 @@ async def visualize_stories(
completed_trackers = g.generate()
event_sequences = [t.events for t in completed_trackers]

graph = await visualize_neighborhood(
graph = visualize_neighborhood(
None,
event_sequences,
output_file,
max_history,
interpreter,
nlu_training_data,
should_merge_nodes,
max_distance=1,
Expand Down
3 changes: 1 addition & 2 deletions tests/cli/test_rasa_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ def test_visualize_help(run: Callable[..., RunResult]):
output = run("visualize", "--help")

help_text = """usage: rasa visualize [-h] [-v] [-vv] [--quiet] [-d DOMAIN] [-s STORIES]
[-c CONFIG] [--out OUT] [--max-history MAX_HISTORY]
[-u NLU]"""
[--out OUT] [--max-history MAX_HISTORY] [-u NLU]"""

lines = help_text.split("\n")
# expected help text lines should appear somewhere in the output
Expand Down
35 changes: 0 additions & 35 deletions tests/core/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from rasa.shared.core.domain import Domain
from rasa.core.policies.ted_policy import TEDPolicy

from rasa.shared.core.training_data.visualization import visualize_stories
import rasa.model_training
import rasa.shared.utils.io

Expand All @@ -22,40 +21,6 @@ def test_load_training_data_reader_not_found_throws(tmp_path: Path, domain: Doma
training.load_data(str(tmp_path), domain)


async def test_story_visualization(domain: Domain, tmp_path: Path):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved these tests to the other visualization tests.

import rasa.shared.core.training_data.loading as core_loading

story_steps = core_loading.load_data_from_resource(
"data/test_yaml_stories/stories.yml", domain
)
out_file = str(tmp_path / "graph.html")
generated_graph = await visualize_stories(
story_steps,
domain,
output_file=out_file,
max_history=3,
should_merge_nodes=False,
)

assert len(generated_graph.nodes()) == 51

assert len(generated_graph.edges()) == 56


async def test_story_visualization_with_merging(domain: Domain):
import rasa.shared.core.training_data.loading as core_loading

story_steps = core_loading.load_data_from_resource(
"data/test_yaml_stories/stories.yml", domain
)
generated_graph = await visualize_stories(
story_steps, domain, output_file=None, max_history=3, should_merge_nodes=True,
)
assert 15 < len(generated_graph.nodes()) < 33

assert 20 < len(generated_graph.edges()) < 33


def test_training_script_with_restart_stories(tmp_path: Path, domain_path: Text):
model_file = rasa.model_training.train_core(
domain_path,
Expand Down
Loading