-
Notifications
You must be signed in to change notification settings - Fork 4.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix the EvaluationStore to match prediction and target entities #6419
Changes from all commits
c052e5b
cba8a5c
205a2ee
9c62564
72fba81
2e4641b
465a2f8
a790f87
25e3d48
b2b7301
a07b69f
f9f70bc
fa8b564
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Fixed a bug in the `serialise` method of the `EvaluationStore` class which resulted in a wrong end-to-end evaluation of the predicted entities. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,6 @@ | |
INTENT, | ||
) | ||
from rasa.constants import RESULTS_FILE, PERCENTAGE_KEY | ||
from rasa.core.utils import pad_lists_to_size | ||
from rasa.core.events import ActionExecuted, UserUttered | ||
from rasa.core.trackers import DialogueStateTracker | ||
from rasa.nlu.training_data.formats.readerwriter import TrainingDataWriter | ||
|
@@ -106,28 +105,108 @@ def has_prediction_target_mismatch(self) -> bool: | |
or self.action_predictions != self.action_targets | ||
) | ||
|
||
@staticmethod | ||
def _compare_entities( | ||
entity_predictions: List[Dict[Text, Any]], | ||
entity_targets: List[Dict[Text, Any]], | ||
i_pred: int, | ||
i_target: int, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Minor: Rename to `index_prediction`. |
||
) -> int: | ||
""" | ||
Compare the current predicted and target entities and decide which one | ||
comes first. If the predicted entity comes first it returns -1, | ||
while it returns 1 if the target entity comes first. | ||
If target and predicted are aligned it returns 0 | ||
""" | ||
pred = None | ||
target = None | ||
if i_pred < len(entity_predictions): | ||
pred = entity_predictions[i_pred] | ||
if i_target < len(entity_targets): | ||
target = entity_targets[i_target] | ||
if target and pred: | ||
# Check which entity has the lower "start" value | ||
if pred.get("start") < target.get("start"): | ||
return -1 | ||
elif target.get("start") < pred.get("start"): | ||
return 1 | ||
else: | ||
# Since both have the same "start" values, | ||
# check which one has the lower "end" value | ||
if pred.get("end") < target.get("end"): | ||
return -1 | ||
elif target.get("end") < pred.get("end"): | ||
return 1 | ||
else: | ||
# The entities have the same "start" and "end" values | ||
return 0 | ||
return 1 if target else -1 | ||
|
||
@staticmethod | ||
def _generate_entity_training_data(entity: Dict[Text, Any]) -> Text: | ||
return TrainingDataWriter.generate_entity(entity.get("text"), entity) | ||
|
||
def serialise(self) -> Tuple[List[Text], List[Text]]: | ||
davidezanella marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Turn targets and predictions to lists of equal size for sklearn.""" | ||
|
||
targets = ( | ||
self.action_targets | ||
+ self.intent_targets | ||
+ [ | ||
TrainingDataWriter.generate_entity(gold.get("text"), gold) | ||
for gold in self.entity_targets | ||
] | ||
texts = sorted( | ||
list( | ||
set( | ||
[e.get("text") for e in self.entity_targets] | ||
+ [e.get("text") for e in self.entity_predictions] | ||
) | ||
) | ||
) | ||
|
||
aligned_entity_targets = [] | ||
aligned_entity_predictions = [] | ||
|
||
for text in texts: | ||
# sort the entities of this sentence to compare them directly | ||
entity_targets = sorted( | ||
filter(lambda x: x.get("text") == text, self.entity_targets), | ||
key=lambda x: x.get("start"), | ||
) | ||
entity_predictions = sorted( | ||
filter(lambda x: x.get("text") == text, self.entity_predictions), | ||
key=lambda x: x.get("start"), | ||
) | ||
|
||
i_pred, i_target = 0, 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Minor: We try not to use abbreviations in names. I would rename this to |
||
|
||
while i_pred < len(entity_predictions) or i_target < len(entity_targets): | ||
cmp = self._compare_entities( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Minor: I would rename this to |
||
entity_predictions, entity_targets, i_pred, i_target | ||
) | ||
if cmp == -1: # predicted comes first | ||
aligned_entity_predictions.append( | ||
self._generate_entity_training_data(entity_predictions[i_pred]) | ||
) | ||
aligned_entity_targets.append("None") | ||
i_pred += 1 | ||
elif cmp == 1: # target entity comes first | ||
aligned_entity_targets.append( | ||
self._generate_entity_training_data(entity_targets[i_target]) | ||
) | ||
aligned_entity_predictions.append("None") | ||
i_target += 1 | ||
else: # target and predicted entity are aligned | ||
aligned_entity_predictions.append( | ||
self._generate_entity_training_data(entity_predictions[i_pred]) | ||
) | ||
aligned_entity_targets.append( | ||
self._generate_entity_training_data(entity_targets[i_target]) | ||
) | ||
i_pred += 1 | ||
i_target += 1 | ||
|
||
targets = self.action_targets + self.intent_targets + aligned_entity_targets | ||
|
||
predictions = ( | ||
self.action_predictions | ||
+ self.intent_predictions | ||
+ [ | ||
TrainingDataWriter.generate_entity(predicted.get("text"), predicted) | ||
for predicted in self.entity_predictions | ||
] | ||
+ aligned_entity_predictions | ||
) | ||
|
||
# sklearn does not cope with lists of unequal size, nor None values | ||
return pad_lists_to_size(targets, predictions, padding_value="None") | ||
return targets, predictions | ||
|
||
|
||
class WronglyPredictedAction(ActionExecuted): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor: Rename to `index_prediction`.