Skip to content

Commit

Permalink
Merge pull request #6355 from RasaHQ/fallback-classifier-disambiguate
Browse files Browse the repository at this point in the history
add `ambiguity_threshold` param to `FallbackClassifier`
  • Loading branch information
rasabot authored Aug 10, 2020
2 parents d600261 + 6196ba7 commit 3e781a9
Show file tree
Hide file tree
Showing 23 changed files with 299 additions and 116 deletions.
5 changes: 3 additions & 2 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
OPEN_UTTERANCE_PREDICTION_KEY,
RESPONSE_SELECTOR_PROPERTY_NAME,
INTENT_RANKING_KEY,
INTENT_NAME_KEY,
)

from rasa.core.events import (
Expand Down Expand Up @@ -722,14 +723,14 @@ async def run(
tracker: "DialogueStateTracker",
domain: "Domain",
) -> List[Event]:
intent_to_affirm = tracker.latest_message.intent.get("name")
intent_to_affirm = tracker.latest_message.intent.get(INTENT_NAME_KEY)

intent_ranking = tracker.latest_message.intent.get(INTENT_RANKING_KEY, [])
if (
intent_to_affirm == DEFAULT_NLU_FALLBACK_INTENT_NAME
and len(intent_ranking) > 1
):
intent_to_affirm = intent_ranking[1]["name"]
intent_to_affirm = intent_ranking[1][INTENT_NAME_KEY]

affirmation_message = f"Did you mean '{intent_to_affirm}'?"

Expand Down
11 changes: 6 additions & 5 deletions rasa/core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ruamel.yaml import YAMLError

import rasa.core.constants
from rasa.nlu.constants import INTENT_NAME_KEY
from rasa.utils.common import (
raise_warning,
lazy_property,
Expand Down Expand Up @@ -675,7 +676,7 @@ def get_parsing_states(self, tracker: "DialogueStateTracker") -> Dict[Text, floa
if not latest_message:
return state_dict

intent_name = latest_message.intent.get("name")
intent_name = latest_message.intent.get(INTENT_NAME_KEY)

if intent_name:
for entity_name in self._get_featurized_entities(latest_message):
Expand All @@ -699,18 +700,18 @@ def get_parsing_states(self, tracker: "DialogueStateTracker") -> Dict[Text, floa

if "intent_ranking" in latest_message.parse_data:
for intent in latest_message.parse_data["intent_ranking"]:
if intent.get("name"):
intent_id = "intent_{}".format(intent["name"])
if intent.get(INTENT_NAME_KEY):
intent_id = "intent_{}".format(intent[INTENT_NAME_KEY])
state_dict[intent_id] = intent["confidence"]

elif intent_name:
intent_id = "intent_{}".format(latest_message.intent["name"])
intent_id = "intent_{}".format(latest_message.intent[INTENT_NAME_KEY])
state_dict[intent_id] = latest_message.intent.get("confidence", 1.0)

return state_dict

def _get_featurized_entities(self, latest_message: UserUttered) -> Set[Text]:
intent_name = latest_message.intent.get("name")
intent_name = latest_message.intent.get(INTENT_NAME_KEY)
intent_config = self.intent_config(intent_name)
entities = latest_message.entities
entity_names = {
Expand Down
17 changes: 11 additions & 6 deletions rasa/core/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
EXTERNAL_MESSAGE_PREFIX,
ACTION_NAME_SENDER_ID_CONNECTOR_STR,
)
from rasa.nlu.constants import INTENT_NAME_KEY

if typing.TYPE_CHECKING:
from rasa.core.trackers import DialogueStateTracker
Expand Down Expand Up @@ -258,7 +259,11 @@ def _from_parse_data(

def __hash__(self) -> int:
return hash(
(self.text, self.intent.get("name"), jsonpickle.encode(self.entities))
(
self.text,
self.intent.get(INTENT_NAME_KEY),
jsonpickle.encode(self.entities),
)
)

def __eq__(self, other) -> bool:
Expand All @@ -267,11 +272,11 @@ def __eq__(self, other) -> bool:
else:
return (
self.text,
self.intent.get("name"),
self.intent.get(INTENT_NAME_KEY),
[jsonpickle.encode(ent) for ent in self.entities],
) == (
other.text,
other.intent.get("name"),
other.intent.get(INTENT_NAME_KEY),
[jsonpickle.encode(ent) for ent in other.entities],
)

Expand Down Expand Up @@ -324,11 +329,11 @@ def as_story_string(self, e2e: bool = False) -> Text:
ent_string = ""

parse_string = "{intent}{entities}".format(
intent=self.intent.get("name", ""), entities=ent_string
intent=self.intent.get(INTENT_NAME_KEY, ""), entities=ent_string
)
if e2e:
message = md_format_message(self.text, self.intent, self.entities)
return "{}: {}".format(self.intent.get("name"), message)
return "{}: {}".format(self.intent.get(INTENT_NAME_KEY), message)
else:
return parse_string
else:
Expand All @@ -344,7 +349,7 @@ def create_external(
) -> "UserUttered":
return UserUttered(
text=f"{EXTERNAL_MESSAGE_PREFIX}{intent_name}",
intent={"name": intent_name},
intent={INTENT_NAME_KEY: intent_name},
metadata={IS_EXTERNAL: True},
entities=entity_list or [],
)
Expand Down
7 changes: 4 additions & 3 deletions rasa/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from rasa.core import constants
from rasa.core.trackers import DialogueStateTracker
from rasa.core.constants import INTENT_MESSAGE_PREFIX
from rasa.nlu.constants import INTENT_NAME_KEY
from rasa.utils.common import raise_warning, class_from_module_path
from rasa.utils.endpoints import EndpointConfig

Expand Down Expand Up @@ -171,8 +172,8 @@ def synchronous_parse(

return {
"text": message_text,
"intent": {"name": intent, "confidence": confidence},
"intent_ranking": [{"name": intent, "confidence": confidence}],
"intent": {INTENT_NAME_KEY: intent, "confidence": confidence},
"intent_ranking": [{INTENT_NAME_KEY: intent, "confidence": confidence}],
"entities": entities,
}

Expand All @@ -195,7 +196,7 @@ async def parse(
Return a default value if the parsing of the text failed."""

default_return = {
"intent": {"name": "", "confidence": 0.0},
"intent": {INTENT_NAME_KEY: "", "confidence": 0.0},
"entities": [],
"text": "",
}
Expand Down
3 changes: 2 additions & 1 deletion rasa/core/policies/mapping_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from rasa.constants import DOCS_URL_POLICIES, DOCS_URL_MIGRATION_GUIDE
import rasa.utils.io
from rasa.nlu.constants import INTENT_NAME_KEY
from rasa.utils import common as common_utils

from rasa.core.actions.action import (
Expand Down Expand Up @@ -108,7 +109,7 @@ def predict_action_probabilities(

result = self._default_predictions(domain)

intent = tracker.latest_message.intent.get("name")
intent = tracker.latest_message.intent.get(INTENT_NAME_KEY)
if intent == USER_INTENT_RESTART:
action = ACTION_RESTART_NAME
elif intent == USER_INTENT_BACK:
Expand Down
3 changes: 2 additions & 1 deletion rasa/core/policies/two_stage_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from rasa.core.policies.policy import confidence_scores_for
from rasa.core.trackers import DialogueStateTracker
from rasa.core.constants import FALLBACK_POLICY_PRIORITY
from rasa.nlu.constants import INTENT_NAME_KEY

if typing.TYPE_CHECKING:
from rasa.core.policies.ensemble import PolicyEnsemble
Expand Down Expand Up @@ -121,7 +122,7 @@ def predict_action_probabilities(
"""Predicts the next action if NLU confidence is low."""

nlu_data = tracker.latest_message.parse_data
last_intent_name = nlu_data["intent"].get("name", None)
last_intent_name = nlu_data["intent"].get(INTENT_NAME_KEY, None)
should_nlu_fallback = self.should_nlu_fallback(
nlu_data, tracker.latest_action_name
)
Expand Down
5 changes: 3 additions & 2 deletions rasa/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from rasa.core.policies.ensemble import PolicyEnsemble
from rasa.core.tracker_store import TrackerStore
from rasa.core.trackers import DialogueStateTracker, EventVerbosity
from rasa.nlu.constants import INTENT_NAME_KEY
from rasa.utils.common import raise_warning
from rasa.utils.endpoints import EndpointConfig

Expand Down Expand Up @@ -436,7 +437,7 @@ def _check_for_unseen_features(self, parse_data: Dict[Text, Any]) -> None:
if not self.domain or self.domain.is_empty():
return

intent = parse_data["intent"]["name"]
intent = parse_data["intent"][INTENT_NAME_KEY]
if intent:
known_intents = self.domain.intents + DEFAULT_INTENTS
if intent not in known_intents:
Expand Down Expand Up @@ -520,7 +521,7 @@ async def _handle_message_with_tracker(
def _should_handle_message(tracker: DialogueStateTracker):
return (
not tracker.is_paused()
or tracker.latest_message.intent.get("name") == USER_INTENT_RESTART
or tracker.latest_message.intent.get(INTENT_NAME_KEY) == USER_INTENT_RESTART
)

def is_action_limit_reached(
Expand Down
5 changes: 4 additions & 1 deletion rasa/core/tracker_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from rasa.core.events import SessionStarted
from rasa.core.trackers import ActionExecuted, DialogueStateTracker, EventVerbosity
import rasa.cli.utils as rasa_cli_utils
from rasa.nlu.constants import INTENT_NAME_KEY
from rasa.utils.common import class_from_module_path, raise_warning, arguments_of
from rasa.utils.endpoints import EndpointConfig
import sqlalchemy as sa
Expand Down Expand Up @@ -910,7 +911,9 @@ def save(self, tracker: DialogueStateTracker) -> None:

for event in events:
data = event.as_dict()
intent = data.get("parse_data", {}).get("intent", {}).get("name")
intent = (
data.get("parse_data", {}).get("intent", {}).get(INTENT_NAME_KEY)
)
action = data.get("name")
timestamp = data.get("timestamp")

Expand Down
35 changes: 23 additions & 12 deletions rasa/core/training/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from aiohttp import ClientError
from colorclass import Color

from rasa.nlu.constants import INTENT_NAME_KEY
from rasa.nlu.training_data.loading import MARKDOWN, RASA
from sanic import Sanic, response
from sanic.exceptions import NotFound
Expand Down Expand Up @@ -322,12 +323,19 @@ def _selection_choices_from_intent_prediction(
) -> List[Dict[Text, Any]]:
""""Given a list of ML predictions create a UI choice list."""

sorted_intents = sorted(predictions, key=lambda k: (-k["confidence"], k["name"]))
sorted_intents = sorted(
predictions, key=lambda k: (-k["confidence"], k[INTENT_NAME_KEY])
)

choices = []
for p in sorted_intents:
name_with_confidence = f'{p.get("confidence"):03.2f} {p.get("name"):40}'
choice = {"name": name_with_confidence, "value": p.get("name")}
name_with_confidence = (
f'{p.get("confidence"):03.2f} {p.get(INTENT_NAME_KEY):40}'
)
choice = {
INTENT_NAME_KEY: name_with_confidence,
"value": p.get(INTENT_NAME_KEY),
}
choices.append(choice)

return choices
Expand Down Expand Up @@ -416,15 +424,15 @@ async def _request_intent_from_user(

predictions = latest_message.get("parse_data", {}).get("intent_ranking", [])

predicted_intents = {p["name"] for p in predictions}
predicted_intents = {p[INTENT_NAME_KEY] for p in predictions}

for i in intents:
if i not in predicted_intents:
predictions.append({"name": i, "confidence": 0.0})
predictions.append({INTENT_NAME_KEY: i, "confidence": 0.0})

# convert intents to ui list and add <other> as a free text alternative
choices = [
{"name": "<create_new_intent>", "value": OTHER_INTENT}
{INTENT_NAME_KEY: "<create_new_intent>", "value": OTHER_INTENT}
] + _selection_choices_from_intent_prediction(predictions)

intent_name = await _request_selection_from_intents(
Expand All @@ -433,11 +441,12 @@ async def _request_intent_from_user(

if intent_name == OTHER_INTENT:
intent_name = await _request_free_text_intent(conversation_id, endpoint)
selected_intent = {"name": intent_name, "confidence": 1.0}
selected_intent = {INTENT_NAME_KEY: intent_name, "confidence": 1.0}
else:
# returns the selected intent with the original probability value
selected_intent = next(
(x for x in predictions if x["name"] == intent_name), {"name": None}
(x for x in predictions if x[INTENT_NAME_KEY] == intent_name),
{INTENT_NAME_KEY: None},
)

return selected_intent
Expand Down Expand Up @@ -479,7 +488,7 @@ def colored(txt: Text, color: Text) -> Text:

def format_user_msg(user_event: UserUttered, max_width: int) -> Text:
intent = user_event.intent or {}
intent_name = intent.get("name", "")
intent_name = intent.get(INTENT_NAME_KEY, "")
_confidence = intent.get("confidence", 1.0)
_md = _as_md_message(user_event.parse_data)

Expand Down Expand Up @@ -745,7 +754,9 @@ def _collect_messages(events: List[Dict[Text, Any]]) -> List[Message]:
if event.get("event") == UserUttered.type_name:
data = event.get("parse_data", {})
rasa_nlu_training_data_utils.remove_untrainable_entities_from(data)
msg = Message.build(data["text"], data["intent"]["name"], data["entities"])
msg = Message.build(
data["text"], data["intent"][INTENT_NAME_KEY], data["entities"]
)
messages.append(msg)
elif event.get("event") == UserUtteranceReverted.type_name and messages:
messages.pop() # user corrected the nlu, remove incorrect example
Expand Down Expand Up @@ -1117,7 +1128,7 @@ def _validate_user_regex(latest_message: Dict[Text, Any], intents: List[Text]) -
`/greet`. Return `True` if the intent is a known one."""

parse_data = latest_message.get("parse_data", {})
intent = parse_data.get("intent", {}).get("name")
intent = parse_data.get("intent", {}).get(INTENT_NAME_KEY)

if intent in intents:
return True
Expand All @@ -1134,7 +1145,7 @@ async def _validate_user_text(

parse_data = latest_message.get("parse_data", {})
text = _as_md_message(parse_data)
intent = parse_data.get("intent", {}).get("name")
intent = parse_data.get("intent", {}).get(INTENT_NAME_KEY)
entities = parse_data.get("entities", [])
if entities:
message = (
Expand Down
3 changes: 2 additions & 1 deletion rasa/core/training/story_reader/markdown_story_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from rasa.core.training.story_reader.story_reader import StoryReader
from rasa.core.training.structures import StoryStep, FORM_PREFIX
from rasa.data import MARKDOWN_FILE_EXTENSION
from rasa.nlu.constants import INTENT_NAME_KEY
from rasa.utils.common import raise_warning

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -209,7 +210,7 @@ async def _parse_message(self, message: Text, line_num: int) -> UserUttered:
utterance = UserUttered(
message, parse_data.get("intent"), parse_data.get("entities"), parse_data
)
intent_name = utterance.intent.get("name")
intent_name = utterance.intent.get(INTENT_NAME_KEY)
if self.domain and intent_name not in self.domain.intents:
raise_warning(
f"Found unknown intent '{intent_name}' on line {line_num}. "
Expand Down
3 changes: 2 additions & 1 deletion rasa/core/training/story_reader/yaml_story_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from rasa.core.training.story_reader.story_reader import StoryReader
from rasa.core.training.structures import StoryStep
from rasa.data import YAML_FILE_EXTENSIONS
from rasa.nlu.constants import INTENT_NAME_KEY

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -249,7 +250,7 @@ def _parse_user_utterance(self, step: Dict[Text, Any]) -> None:
self.current_step_builder.add_user_messages([utterance])

def _validate_that_utterance_is_in_domain(self, utterance: UserUttered) -> None:
intent_name = utterance.intent.get("name")
intent_name = utterance.intent.get(INTENT_NAME_KEY)

if not self.domain:
logger.debug(
Expand Down
Loading

0 comments on commit 3e781a9

Please sign in to comment.