Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the EvaluationStore to match prediction and target entities #6419

Merged
Merged
109 changes: 93 additions & 16 deletions rasa/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
ENTITY_ATTRIBUTE_TYPE,
)
from rasa.constants import RESULTS_FILE, PERCENTAGE_KEY
from rasa.core.utils import pad_lists_to_size
from rasa.core.events import ActionExecuted, UserUttered
from rasa.nlu.training_data.formats.markdown import MarkdownWriter
from rasa.core.trackers import DialogueStateTracker
Expand Down Expand Up @@ -105,26 +104,104 @@ def has_prediction_target_mismatch(self) -> bool:

def serialise(self) -> Tuple[List[Text], List[Text]]:
davidezanella marked this conversation as resolved.
Show resolved Hide resolved
"""Turn targets and predictions to lists of equal size for sklearn."""
# sklearn does not cope with lists of unequal size, nor None values
davidezanella marked this conversation as resolved.
Show resolved Hide resolved

targets = (
self.action_targets
+ self.intent_targets
+ [
TrainingDataWriter.generate_entity(gold.get("text"), gold)
for gold in self.entity_targets
]
texts = sorted(
list(
set(
[e.get("text") for e in self.entity_targets]
+ [e.get("text") for e in self.entity_predictions]
)
)
)

entity_targets_fixed = []
entity_predictions_fixed = []
davidezanella marked this conversation as resolved.
Show resolved Hide resolved

for text in texts:
# sort the entities of this sentence to compare them directly
entity_targets = sorted(
filter(lambda x: x.get("text") == text, self.entity_targets),
key=lambda x: x.get("start"),
)
entity_predictions = sorted(
filter(lambda x: x.get("text") == text, self.entity_predictions),
key=lambda x: x.get("start"),
)

i_pred = 0
i_target = 0

def compare_entities():
davidezanella marked this conversation as resolved.
Show resolved Hide resolved
"""
Compare the current predicted and target entities and decide which one
comes first. If the predicted entity comes first it returns -1,
while it returns 1 if the target entity comes first.
If target and predicted are aligned it returns 0
"""
pred = None
target = None
if i_pred < len(entity_predictions):
pred = entity_predictions[i_pred]
if i_target < len(entity_targets):
target = entity_targets[i_target]
if target and pred:
if pred.get("start") < target.get("start"):
return -1
elif target.get("start") < pred.get("start"):
return 1
else:
if pred.get("end") < target.get("end"):
return -1
elif target.get("end") < pred.get("end"):
return 1
else:
return 0
return 1 if target else -1

while i_pred < len(entity_predictions) or i_target < len(entity_targets):
cmp = compare_entities()
if cmp == -1: # predicted comes first
entity_predictions_fixed.append(
TrainingDataWriter.generate_entity(
entity_predictions[i_pred].get("text"),
entity_predictions[i_pred],
)
davidezanella marked this conversation as resolved.
Show resolved Hide resolved
)
entity_targets_fixed.append("None")
i_pred += 1
elif cmp == 1: # target entity comes first
entity_targets_fixed.append(
TrainingDataWriter.generate_entity(
entity_targets[i_target].get("text"),
entity_targets[i_target],
)
)
entity_predictions_fixed.append("None")
i_target += 1
else: # target and predicted entity are aligned
entity_predictions_fixed.append(
TrainingDataWriter.generate_entity(
entity_predictions[i_pred].get("text"),
entity_predictions[i_pred],
)
)
entity_targets_fixed.append(
TrainingDataWriter.generate_entity(
entity_targets[i_target].get("text"),
entity_targets[i_target],
)
)
i_pred += 1
i_target += 1

targets = self.action_targets + self.intent_targets + entity_targets_fixed

predictions = (
self.action_predictions
+ self.intent_predictions
+ [
TrainingDataWriter.generate_entity(predicted.get("text"), predicted)
for predicted in self.entity_predictions
]
self.action_predictions + self.intent_predictions + entity_predictions_fixed
)

# sklearn does not cope with lists of unequal size, nor None values
return pad_lists_to_size(targets, predictions, padding_value="None")
return targets, predictions


class WronglyPredictedAction(ActionExecuted):
Expand Down
149 changes: 149 additions & 0 deletions tests/test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,152 @@ async def test_e2e_warning_if_no_nlu_model(
agent_load.assert_called_once()
_, _, kwargs = agent_load.mock_calls[0]
assert isinstance(kwargs["interpreter"], RegexInterpreter)


@pytest.mark.parametrize(
"entity_predictions,entity_targets",
[
(
[{"text": "hi, how are you", "start": 4, "end": 7, "entity": "aa"}],
[
{"text": "hi, how are you", "start": 0, "end": 2, "entity": "bb"},
{"text": "hi, how are you", "start": 4, "end": 7, "entity": "aa"},
],
),
(
[
{"text": "hi, how are you", "start": 0, "end": 2, "entity": "bb"},
{"text": "hi, how are you", "start": 4, "end": 7, "entity": "aa"},
],
[
{"text": "hi, how are you", "start": 0, "end": 2, "entity": "bb"},
{"text": "hi, how are you", "start": 4, "end": 7, "entity": "aa"},
],
),
(
[
{"text": "hi, how are you", "start": 0, "end": 2, "entity": "bb"},
{"text": "hi, how are you", "start": 4, "end": 7, "entity": "aa"},
],
[{"text": "hi, how are you", "start": 4, "end": 7, "entity": "aa"},],
),
(
[
{
"text": "Tanja is currently in Munich, but she lives in Berlin",
"start": 0,
"end": 5,
"entity": "person",
},
{
"text": "Tanja is currently in Munich, but she lives in Berlin",
"start": 22,
"end": 28,
"entity": "city",
},
{
"text": "Tanja is currently in Munich, but she lives in Berlin",
"start": 47,
"end": 53,
"entity": "city",
},
],
[
{
"text": "Tanja is currently in Munich, but she lives in Berlin",
"start": 22,
"end": 28,
"entity": "city",
},
],
),
(
[
{
"text": "Tanja is currently in Munich, but she lives in Berlin",
"start": 0,
"end": 5,
"entity": "person",
},
{
"text": "Tanja is currently in Munich, but she lives in Berlin",
"start": 47,
"end": 53,
"entity": "city",
},
],
[
{
"text": "Tanja is currently in Munich, but she lives in Berlin",
"start": 22,
"end": 28,
"entity": "city",
},
{
"text": "Tanja is currently in Munich, but she lives in Berlin",
"start": 47,
"end": 53,
"entity": "city",
},
],
),
(
[
{
"text": "Tanja is currently in Munich, but she lives in Berlin",
"start": 47,
"end": 53,
"entity": "city",
}
],
[
{
"text": "Tanja is currently in Munich, but she lives in Berlin",
"start": 0,
"end": 5,
"entity": "person",
},
{
"text": "Tanja is currently in Munich, but she lives in Berlin",
"start": 22,
"end": 28,
"entity": "city",
},
{
"text": "Tanja is currently in Munich, but she lives in Berlin",
"start": 47,
"end": 53,
"entity": "city",
},
],
),
],
)
def test_evaluation_store_serialise(entity_predictions, entity_targets):
from rasa.core.test import EvaluationStore
from rasa.nlu.training_data.formats.readerwriter import TrainingDataWriter

store = EvaluationStore(
entity_predictions=entity_predictions, entity_targets=entity_targets
)

targets, predictions = store.serialise()

assert len(targets) == len(predictions)

i_pred = 0
i_target = 0
for i, prediction in enumerate(predictions):
target = targets[i]
if prediction != "None" and target != "None":
predicted = entity_predictions[i_pred]
assert prediction == TrainingDataWriter.generate_entity(
predicted.get("text"), predicted
)
assert predicted.get("start") == entity_targets[i_target].get("start")
assert predicted.get("end") == entity_targets[i_target].get("end")

if prediction != "None":
i_pred += 1
if target != "None":
i_target += 1