Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use actual event names for featurization instead of constants #8669

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions rasa/architecture_prototype/config_to_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def _core_config_to_train_graph_schema(
"uses": StoryToTrainingDataConverter,
"fn": "convert_for_training",
"config": {},
"needs": {"story_graph": "load_stories"},
"needs": {"story_graph": "load_stories", "domain": "load_domain"},
"persistor": False,
}
nlu_train_graph_schema, nlu_outs = _nlu_config_to_train_graph_schema(
Expand Down Expand Up @@ -353,14 +353,14 @@ def _core_config_to_predict_graph_schema(
"uses": StoryToTrainingDataConverter,
"fn": "convert_for_inference",
"config": {},
"needs": {"tracker": "add_parsed_nlu_message"},
"needs": {"tracker": "add_parsed_nlu_message",},
"persistor": False,
},
"create_e2e_lookup": {
"uses": MessageToE2EFeatureConverter,
"fn": "convert",
"config": {},
"needs": {"messages": nlu_e2e_out,},
"needs": {"messages": nlu_e2e_out},
"persistor": False,
},
**nlu_e2e_predict_graph_schema,
Expand Down
30 changes: 25 additions & 5 deletions rasa/architecture_prototype/graph_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from rasa.architecture_prototype.interfaces import ComponentPersistorInterface
from rasa.core.channels import CollectingOutputChannel, UserMessage
from rasa.shared.constants import DEFAULT_DATA_PATH, DEFAULT_DOMAIN_PATH
from rasa.shared.core.constants import DEFAULT_ACTION_NAMES
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import ActionExecuted, UserUttered, Event
from rasa.shared.core.generator import TrackerWithCachedStates
Expand Down Expand Up @@ -106,10 +107,28 @@ def generate(


class StoryToTrainingDataConverter(GraphComponent):
def convert_for_training(self, story_graph: StoryGraph) -> TrainingData:
messages = []
def convert_for_training(
self, story_graph: StoryGraph, domain: Domain
) -> TrainingData:
user_actions = [
ActionExecuted(user_action) for user_action in domain.user_actions
]
end_to_end_actions = [
ActionExecuted(action_text=action_text)
for action_text in domain.action_texts
]
default_actions = [
ActionExecuted(default_action) for default_action in DEFAULT_ACTION_NAMES
]
user_events = []
for step in story_graph.story_steps:
messages += self._convert_tracker_to_messages(step.events)
user_events += [
event for event in step.events if isinstance(event, UserUttered)
]

messages = self._convert_tracker_to_messages(
user_actions + end_to_end_actions + default_actions + user_events
)

# Workaround: add at least one end to end message to initialize
# the `CountVectorizer` for e2e. Alternatives: Store information or simply config
Expand Down Expand Up @@ -148,6 +167,7 @@ def convert_for_inference(self, tracker: DialogueStateTracker) -> List[Message]:

class MessageToE2EFeatureConverter(GraphComponent):
"""Collects featurised messages for use by an e2e policy."""

def convert(
self, messages: Union[TrainingData, List[Message]]
) -> Dict[Text, Message]:
Expand All @@ -156,8 +176,8 @@ def convert(
additional_features = {}
for message in messages:
key = next(
k
for k in message.data.keys()
v
for k, v in message.data.items()
if k in {ACTION_NAME, ACTION_TEXT, INTENT, TEXT}
)
additional_features[key] = message
Expand Down
30 changes: 10 additions & 20 deletions rasa/core/featurizers/single_state_featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ def _create_entity_tag_specs(
]

def prepare_for_training(
self,
domain: Domain,
bilou_tagging: bool = False,
self, domain: Domain, bilou_tagging: bool = False,
) -> None:
"""Gets necessary information for featurization from domain.

Expand Down Expand Up @@ -221,10 +219,12 @@ def _extract_state_features(
e2e_features: Optional[Dict[Text, Message]] = None,
) -> Dict[Text, List["Features"]]:
key = next(
k for k in sub_state.keys() if k in {ACTION_NAME, ACTION_TEXT, INTENT, TEXT}
v
for k, v in sub_state.items()
if k in {ACTION_NAME, ACTION_TEXT, INTENT, TEXT}
)
# TODO: We need a fallback for unexpected user texts during prediction time
parsed_message = e2e_features[key]
assert parsed_message

# remove entities from possible attributes
attributes = set(
Expand All @@ -247,9 +247,7 @@ def _extract_state_features(
return output

def encode_state(
self,
state: State,
e2e_features: Optional[Dict[Text, Message]] = None,
self, state: State, e2e_features: Optional[Dict[Text, Message]] = None,
) -> Dict[Text, List["Features"]]:
"""Encode the given state.

Expand Down Expand Up @@ -289,9 +287,7 @@ def encode_state(
return state_features

def encode_entities(
self,
entity_data: Dict[Text, Any],
bilou_tagging: bool = False,
self, entity_data: Dict[Text, Any], bilou_tagging: bool = False,
) -> Dict[Text, List["Features"]]:
"""Encode the given entity data.

Expand Down Expand Up @@ -335,9 +331,7 @@ def encode_entities(
}

def _encode_action(
self,
action: Text,
e2e_features: Optional[Dict[Text, Message]] = None,
self, action: Text, e2e_features: Optional[Dict[Text, Message]] = None,
) -> Dict[Text, List["Features"]]:
if action in self.action_texts:
action_as_sub_state = {ACTION_TEXT: action}
Expand All @@ -349,9 +343,7 @@ def _encode_action(
)

def encode_all_actions(
self,
domain: Domain,
e2e_features: Optional[Dict[Text, Message]] = None,
self, domain: Domain, e2e_features: Optional[Dict[Text, Message]] = None,
) -> List[Dict[Text, List["Features"]]]:
"""Encode all action from the domain.

Expand Down Expand Up @@ -381,9 +373,7 @@ def __init__(self) -> None:
)

def _extract_state_features(
self,
sub_state: SubState,
sparse: bool = False,
self, sub_state: SubState, sparse: bool = False,
) -> Dict[Text, List["Features"]]:
# create a special method that doesn't use passed interpreter
name_attribute = self._get_name_attribute(set(sub_state.keys()))
Expand Down
2 changes: 1 addition & 1 deletion tests/architecture_prototype/graph_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
"uses": StoryToTrainingDataConverter,
"fn": "convert_for_training",
"config": {},
"needs": {"story_graph": "load_stories"},
"needs": {"story_graph": "load_stories", "domain": "load_domain"},
"persistor": False,
},
"process_core_WhitespaceTokenizer_0": {
Expand Down
Binary file modified tests/architecture_prototype/train_graph_schema.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.