Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the EvaluationStore to match prediction and target entities #6419

Merged
Merged
1 change: 1 addition & 0 deletions changelog/6419.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed a bug in the `serialise` method of the `EvaluationStore` class which resulted in a wrong end-to-end evaluation of the predicted entities.
111 changes: 95 additions & 16 deletions rasa/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
INTENT,
)
from rasa.constants import RESULTS_FILE, PERCENTAGE_KEY
from rasa.core.utils import pad_lists_to_size
from rasa.core.events import ActionExecuted, UserUttered
from rasa.core.trackers import DialogueStateTracker
from rasa.nlu.training_data.formats.readerwriter import TrainingDataWriter
Expand Down Expand Up @@ -106,28 +105,108 @@ def has_prediction_target_mismatch(self) -> bool:
or self.action_predictions != self.action_targets
)

@staticmethod
def _compare_entities(
    entity_predictions: List[Dict[Text, Any]],
    entity_targets: List[Dict[Text, Any]],
    index_prediction: int,
    index_target: int,
) -> int:
    """Compare the current predicted and target entities and decide their order.

    Entities are ordered by their character span: primarily by the "start"
    offset, with ties broken by the "end" offset.

    Args:
        entity_predictions: Predicted entities, sorted by "start".
        entity_targets: Target (gold) entities, sorted by "start".
        index_prediction: Index of the current entity in `entity_predictions`.
        index_target: Index of the current entity in `entity_targets`.

    Returns:
        -1 if the predicted entity comes first, 1 if the target entity
        comes first, and 0 if target and predicted entities are aligned
        (identical "start" and "end").
    """
    prediction = None
    target = None
    # An index past the end of its list means that side is exhausted.
    if index_prediction < len(entity_predictions):
        prediction = entity_predictions[index_prediction]
    if index_target < len(entity_targets):
        target = entity_targets[index_target]

    if target and prediction:
        # Order primarily by the "start" offset of the entity span.
        if prediction.get("start") < target.get("start"):
            return -1
        if target.get("start") < prediction.get("start"):
            return 1
        # Same "start": break the tie on the "end" offset.
        if prediction.get("end") < target.get("end"):
            return -1
        if target.get("end") < prediction.get("end"):
            return 1
        # Identical "start" and "end": the entities are aligned.
        return 0

    # Only one side still has entities left; that side comes first.
    return 1 if target else -1

@staticmethod
def _generate_entity_training_data(entity: Dict[Text, Any]) -> Text:
    """Serialise a single entity dict into training-data text markup.

    Delegates to `TrainingDataWriter.generate_entity`, passing the message
    text (the entity's "text" key) together with the entity dict itself.
    """
    return TrainingDataWriter.generate_entity(entity.get("text"), entity)

def serialise(self) -> Tuple[List[Text], List[Text]]:
    """Turn targets and predictions to lists of equal size for sklearn.

    Entity targets and predictions are aligned span-by-span within each
    message text so that sklearn compares corresponding entities; an entity
    present on only one side is paired with the padding value "None".

    Returns:
        A tuple `(targets, predictions)` of serialised labels.
    """
    # All message texts that carry at least one target or predicted entity,
    # sorted for a deterministic output order.
    texts = sorted(
        set(
            [e.get("text") for e in self.entity_targets]
            + [e.get("text") for e in self.entity_predictions]
        )
    )

    aligned_entity_targets = []
    aligned_entity_predictions = []

    for text in texts:
        # Sort the entities of this sentence by span start so both lists
        # can be merged in a single linear pass.
        entity_targets = sorted(
            filter(lambda x: x.get("text") == text, self.entity_targets),
            key=lambda x: x.get("start"),
        )
        entity_predictions = sorted(
            filter(lambda x: x.get("text") == text, self.entity_predictions),
            key=lambda x: x.get("start"),
        )

        index_prediction, index_target = 0, 0
        while index_prediction < len(entity_predictions) or index_target < len(
            entity_targets
        ):
            comparison_result = self._compare_entities(
                entity_predictions, entity_targets, index_prediction, index_target
            )
            if comparison_result == -1:  # predicted entity comes first
                aligned_entity_predictions.append(
                    self._generate_entity_training_data(
                        entity_predictions[index_prediction]
                    )
                )
                aligned_entity_targets.append("None")
                index_prediction += 1
            elif comparison_result == 1:  # target entity comes first
                aligned_entity_targets.append(
                    self._generate_entity_training_data(entity_targets[index_target])
                )
                aligned_entity_predictions.append("None")
                index_target += 1
            else:  # target and predicted entity are aligned
                aligned_entity_predictions.append(
                    self._generate_entity_training_data(
                        entity_predictions[index_prediction]
                    )
                )
                aligned_entity_targets.append(
                    self._generate_entity_training_data(entity_targets[index_target])
                )
                index_prediction += 1
                index_target += 1

    targets = self.action_targets + self.intent_targets + aligned_entity_targets

    predictions = (
        self.action_predictions + self.intent_predictions + aligned_entity_predictions
    )

    return targets, predictions


class WronglyPredictedAction(ActionExecuted):
Expand Down
148 changes: 148 additions & 0 deletions tests/test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,154 @@ def test_log_failed_stories(tmp_path: Path):
assert len(dump.split("\n")) == 1


_GREETING_TEXT = "hi, how are you"
_TANJA_TEXT = "Tanja is currently in Munich, but she lives in Berlin"


def _entity(text, start, end, entity):
    # Build an entity dict in the shape EvaluationStore expects.
    return {"text": text, "start": start, "end": end, "entity": entity}


@pytest.mark.parametrize(
    "entity_predictions,entity_targets",
    [
        (
            [_entity(_GREETING_TEXT, 4, 7, "aa")],
            [
                _entity(_GREETING_TEXT, 0, 2, "bb"),
                _entity(_GREETING_TEXT, 4, 7, "aa"),
            ],
        ),
        (
            [
                _entity(_GREETING_TEXT, 0, 2, "bb"),
                _entity(_GREETING_TEXT, 4, 7, "aa"),
            ],
            [
                _entity(_GREETING_TEXT, 0, 2, "bb"),
                _entity(_GREETING_TEXT, 4, 7, "aa"),
            ],
        ),
        (
            [
                _entity(_GREETING_TEXT, 0, 2, "bb"),
                _entity(_GREETING_TEXT, 4, 7, "aa"),
            ],
            [_entity(_GREETING_TEXT, 4, 7, "aa")],
        ),
        (
            [
                _entity(_TANJA_TEXT, 0, 5, "person"),
                _entity(_TANJA_TEXT, 22, 28, "city"),
                _entity(_TANJA_TEXT, 47, 53, "city"),
            ],
            [_entity(_TANJA_TEXT, 22, 28, "city")],
        ),
        (
            [
                _entity(_TANJA_TEXT, 0, 5, "person"),
                _entity(_TANJA_TEXT, 47, 53, "city"),
            ],
            [
                _entity(_TANJA_TEXT, 22, 28, "city"),
                _entity(_TANJA_TEXT, 47, 53, "city"),
            ],
        ),
        (
            [_entity(_TANJA_TEXT, 47, 53, "city")],
            [
                _entity(_TANJA_TEXT, 0, 5, "person"),
                _entity(_TANJA_TEXT, 22, 28, "city"),
                _entity(_TANJA_TEXT, 47, 53, "city"),
            ],
        ),
    ],
)
def test_evaluation_store_serialise(entity_predictions, entity_targets):
    """Serialised targets/predictions must be equal-length and span-aligned."""
    from rasa.nlu.training_data.formats.readerwriter import TrainingDataWriter

    store = EvaluationStore(
        entity_predictions=entity_predictions, entity_targets=entity_targets
    )

    targets, predictions = store.serialise()

    assert len(targets) == len(predictions)

    # Walk both serialised lists in lockstep; "None" marks a padded slot
    # where only one side had an entity.
    index_prediction = 0
    index_target = 0
    for prediction, target in zip(predictions, targets):
        if prediction != "None" and target != "None":
            predicted_entity = entity_predictions[index_prediction]
            expected = TrainingDataWriter.generate_entity(
                predicted_entity.get("text"), predicted_entity
            )
            assert prediction == expected
            # Aligned slots must refer to the same character span.
            target_entity = entity_targets[index_target]
            assert predicted_entity.get("start") == target_entity.get("start")
            assert predicted_entity.get("end") == target_entity.get("end")

        if prediction != "None":
            index_prediction += 1
        if target != "None":
            index_target += 1


async def test_test_does_not_use_rules(tmp_path: Path, default_agent: Agent):
from rasa.core.test import _generate_trackers

Expand Down