Skip to content

Commit

Permalink
Merge pull request #6419 from davidezanella/fix-EvaluationStore-predi…
Browse files Browse the repository at this point in the history
…cted-entities

Fix the EvaluationStore to match prediction and target entities
  • Loading branch information
degiz authored Sep 2, 2020
2 parents fca3951 + fa8b564 commit be72c78
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 16 deletions.
1 change: 1 addition & 0 deletions changelog/6419.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed a bug in the `serialise` method of the `EvaluationStore` class which resulted in a wrong end-to-end evaluation of the predicted entities.
111 changes: 95 additions & 16 deletions rasa/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
INTENT,
)
from rasa.constants import RESULTS_FILE, PERCENTAGE_KEY
from rasa.core.utils import pad_lists_to_size
from rasa.core.events import ActionExecuted, UserUttered
from rasa.core.trackers import DialogueStateTracker
from rasa.nlu.training_data.formats.readerwriter import TrainingDataWriter
Expand Down Expand Up @@ -106,28 +105,108 @@ def has_prediction_target_mismatch(self) -> bool:
or self.action_predictions != self.action_targets
)

@staticmethod
def _compare_entities(
entity_predictions: List[Dict[Text, Any]],
entity_targets: List[Dict[Text, Any]],
i_pred: int,
i_target: int,
) -> int:
"""
Compare the current predicted and target entities and decide which one
comes first. If the predicted entity comes first it returns -1,
while it returns 1 if the target entity comes first.
If target and predicted are aligned it returns 0
"""
pred = None
target = None
if i_pred < len(entity_predictions):
pred = entity_predictions[i_pred]
if i_target < len(entity_targets):
target = entity_targets[i_target]
if target and pred:
# Check which entity has the lower "start" value
if pred.get("start") < target.get("start"):
return -1
elif target.get("start") < pred.get("start"):
return 1
else:
# Since both have the same "start" values,
# check which one has the lower "end" value
if pred.get("end") < target.get("end"):
return -1
elif target.get("end") < pred.get("end"):
return 1
else:
# The entities have the same "start" and "end" values
return 0
return 1 if target else -1

@staticmethod
def _generate_entity_training_data(entity: Dict[Text, Any]) -> Text:
    """Return the training-data string for ``entity``.

    Delegates to ``TrainingDataWriter.generate_entity``, passing the value
    stored under the entity's ``"text"`` key together with the entity itself.
    """
    message_text = entity.get("text")
    return TrainingDataWriter.generate_entity(message_text, entity)

def serialise(self) -> Tuple[List[Text], List[Text]]:
    """Turn targets and predictions to lists of equal size for sklearn.

    Entity targets and predictions are aligned per message text: for every
    distinct ``"text"`` value the entities are ordered by their span and
    walked in lockstep. An entity without a counterpart on the other side is
    paired with the string ``"None"``, so both entity lists always grow by
    the same number of items.

    Returns:
        A ``(targets, predictions)`` tuple. Each list is the concatenation of
        the stored action values, the stored intent values and the aligned
        entity strings, in that order.
    """

    def span(entity: Dict[Text, Any]) -> Tuple[Any, Any]:
        # Must match the ordering used by `_compare_entities` ("start", then
        # "end"). Sorting by "start" alone leaves entities that share a start
        # in input order, which can differ between targets and predictions
        # and makes identical entities miss each other during alignment.
        return entity.get("start"), entity.get("end")

    texts = sorted(
        {e.get("text") for e in self.entity_targets}
        | {e.get("text") for e in self.entity_predictions}
    )

    aligned_entity_targets = []
    aligned_entity_predictions = []

    for text in texts:
        # sort the entities of this sentence to compare them directly
        entity_targets = sorted(
            (e for e in self.entity_targets if e.get("text") == text), key=span
        )
        entity_predictions = sorted(
            (e for e in self.entity_predictions if e.get("text") == text),
            key=span,
        )

        i_pred, i_target = 0, 0

        while i_pred < len(entity_predictions) or i_target < len(entity_targets):
            cmp = self._compare_entities(
                entity_predictions, entity_targets, i_pred, i_target
            )
            if cmp == -1:  # predicted comes first
                aligned_entity_predictions.append(
                    self._generate_entity_training_data(entity_predictions[i_pred])
                )
                aligned_entity_targets.append("None")
                i_pred += 1
            elif cmp == 1:  # target entity comes first
                aligned_entity_targets.append(
                    self._generate_entity_training_data(entity_targets[i_target])
                )
                aligned_entity_predictions.append("None")
                i_target += 1
            else:  # target and predicted entity are aligned
                aligned_entity_predictions.append(
                    self._generate_entity_training_data(entity_predictions[i_pred])
                )
                aligned_entity_targets.append(
                    self._generate_entity_training_data(entity_targets[i_target])
                )
                i_pred += 1
                i_target += 1

    targets = self.action_targets + self.intent_targets + aligned_entity_targets

    predictions = (
        self.action_predictions
        + self.intent_predictions
        + aligned_entity_predictions
    )

    # The string "None" (not the `None` object) marks a missing counterpart:
    # sklearn does not cope with lists of unequal size, nor None values.
    return targets, predictions


class WronglyPredictedAction(ActionExecuted):
Expand Down
148 changes: 148 additions & 0 deletions tests/test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,154 @@ def test_log_failed_stories(tmp_path: Path):
assert len(dump.split("\n")) == 1


_GREETING_TEXT = "hi, how are you"
_TRAVEL_TEXT = "Tanja is currently in Munich, but she lives in Berlin"


def _entity(text: str, start: int, end: int, entity: str) -> dict:
    """Build an entity dict with the keys used by the serialise test cases."""
    return {"text": text, "start": start, "end": end, "entity": entity}


@pytest.mark.parametrize(
    "entity_predictions,entity_targets",
    [
        # the first target entity is not predicted
        (
            [_entity(_GREETING_TEXT, 4, 7, "aa")],
            [
                _entity(_GREETING_TEXT, 0, 2, "bb"),
                _entity(_GREETING_TEXT, 4, 7, "aa"),
            ],
        ),
        # predictions and targets are identical
        (
            [
                _entity(_GREETING_TEXT, 0, 2, "bb"),
                _entity(_GREETING_TEXT, 4, 7, "aa"),
            ],
            [
                _entity(_GREETING_TEXT, 0, 2, "bb"),
                _entity(_GREETING_TEXT, 4, 7, "aa"),
            ],
        ),
        # an additional entity is predicted before the only target
        (
            [
                _entity(_GREETING_TEXT, 0, 2, "bb"),
                _entity(_GREETING_TEXT, 4, 7, "aa"),
            ],
            [_entity(_GREETING_TEXT, 4, 7, "aa")],
        ),
        # additional entities are predicted before and after the only target
        (
            [
                _entity(_TRAVEL_TEXT, 0, 5, "person"),
                _entity(_TRAVEL_TEXT, 22, 28, "city"),
                _entity(_TRAVEL_TEXT, 47, 53, "city"),
            ],
            [_entity(_TRAVEL_TEXT, 22, 28, "city")],
        ),
        # only the last entity of each side matches
        (
            [
                _entity(_TRAVEL_TEXT, 0, 5, "person"),
                _entity(_TRAVEL_TEXT, 47, 53, "city"),
            ],
            [
                _entity(_TRAVEL_TEXT, 22, 28, "city"),
                _entity(_TRAVEL_TEXT, 47, 53, "city"),
            ],
        ),
        # only the last target entity is predicted
        (
            [_entity(_TRAVEL_TEXT, 47, 53, "city")],
            [
                _entity(_TRAVEL_TEXT, 0, 5, "person"),
                _entity(_TRAVEL_TEXT, 22, 28, "city"),
                _entity(_TRAVEL_TEXT, 47, 53, "city"),
            ],
        ),
    ],
)
def test_evaluation_store_serialise(entity_predictions, entity_targets):
    """Check that `serialise` pairs each prediction with the matching target.

    The parametrised entity lists are already ordered by their start offset,
    so walking the serialised rows and counting the non-"None" items on each
    side yields the index of the entity that produced the row.
    """
    from rasa.nlu.training_data.formats.readerwriter import TrainingDataWriter

    store = EvaluationStore(
        entity_predictions=entity_predictions, entity_targets=entity_targets
    )

    targets, predictions = store.serialise()

    assert len(targets) == len(predictions)

    i_pred = 0
    i_target = 0
    for prediction, target in zip(predictions, targets):
        if prediction != "None" and target != "None":
            # an aligned row must stem from entities covering the same span
            predicted = entity_predictions[i_pred]
            expected = entity_targets[i_target]
            assert prediction == TrainingDataWriter.generate_entity(
                predicted.get("text"), predicted
            )
            assert target == TrainingDataWriter.generate_entity(
                expected.get("text"), expected
            )
            assert predicted.get("start") == expected.get("start")
            assert predicted.get("end") == expected.get("end")

        if prediction != "None":
            i_pred += 1
        if target != "None":
            i_target += 1

    # every entity has to show up exactly once in the serialised output
    assert i_pred == len(entity_predictions)
    assert i_target == len(entity_targets)


async def test_test_does_not_use_rules(tmp_path: Path, default_agent: Agent):
from rasa.core.test import _generate_trackers

Expand Down

0 comments on commit be72c78

Please sign in to comment.