Modified evaluation to use seqeval package #14

Merged: 3 commits, Jun 14, 2022
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,2 +1,4 @@
src/massive/**/__pycache__/*
*~
*.pyc
*.swp
1 change: 1 addition & 0 deletions conda_env.yml
@@ -80,6 +80,7 @@ dependencies:
- scikit-learn==1.0.1
- scipy==1.7.3
- sentencepiece==0.1.96
- seqeval==1.2.2
- six==1.16.0
- sklearn==0.0
- tabulate==0.8.9
122 changes: 65 additions & 57 deletions src/massive/utils/training_utils.py
@@ -24,6 +24,7 @@
from math import sqrt
import numpy as np
import os
from seqeval.metrics import f1_score
import sklearn.metrics as sklm
import torch
from transformers import (
@@ -374,9 +375,57 @@ def compute_metrics(p):

return compute_metrics

def convert_to_bio(seq_tags, outside='Other', labels_merge=None):
    """
    Converts a sequence of tags into BIO format. EX:

        ['city', 'city', 'Other', 'country', -100, 'Other']
        to
        ['B-city', 'I-city', 'O', 'B-country', 'I-country', 'O']
        where outside = 'Other' and labels_merge = [-100]

    :param seq_tags: the sequence of tags that should be converted
    :type seq_tags: list
    :param outside: The label(s) to treat as outside (tagged 'O'). Default: 'Other'
    :type outside: str or list
    :param labels_merge: The labels to merge leftward (i.e. for tokenized inputs)
    :type labels_merge: str or list
    :return: a BIO-tagged sequence
    :rtype: list
    """

    seq_tags = [str(x) for x in seq_tags]

    outside = [outside] if type(outside) != list else outside
    outside = [str(x) for x in outside]

    if labels_merge:
        labels_merge = [labels_merge] if type(labels_merge) != list else labels_merge
        labels_merge = [str(x) for x in labels_merge]
    else:
        labels_merge = []

    bio_tagged = []
    prev_tag = None
    for tag in seq_tags:
        if tag in outside:
            bio_tagged.append('O')
            prev_tag = tag
            continue
        # a new (non-merge) tag opens a B- span
        if tag != prev_tag and tag not in labels_merge:
            bio_tagged.append('B-' + tag)
            prev_tag = tag
            continue
        # repeated or merged tags continue the previous span (or stay outside)
        if tag == prev_tag or tag in labels_merge:
            if prev_tag in outside:
                bio_tagged.append('O')
            else:
                bio_tagged.append('I-' + prev_tag)

    return bio_tagged
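A quick sanity check of the new helper: a minimal sketch (using the same imports the PR already adds; the prediction sequence is invented and deliberately clips the 'city' span short) showing how the BIO output feeds seqeval's span-level F1:

```python
from seqeval.metrics import f1_score
from massive.utils.training_utils import convert_to_bio

# gold tags with a -100 subword position and 'Other' outside any slot
labels = ['city', 'city', 'Other', 'country', -100, 'Other']
# illustrative prediction that misses the second 'city' token
preds = ['city', 'Other', 'Other', 'country', -100, 'Other']

bio_labels = convert_to_bio(labels, outside='Other', labels_merge=-100)
bio_preds = convert_to_bio(preds, outside='Other', labels_merge=-100)
# bio_labels == ['B-city', 'I-city', 'O', 'B-country', 'I-country', 'O']
# bio_preds  == ['B-city', 'O', 'O', 'B-country', 'I-country', 'O']

# seqeval scores whole spans: the clipped 'city' span counts as both a FP and a FN,
# while 'country' matches exactly -> F1 = 2*1 / (2*1 + 1 + 1) = 0.5
print(f1_score([bio_labels], [bio_preds]))  # 0.5
```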

def eval_preds(pred_intents=None, lab_intents=None, pred_slots=None, lab_slots=None,
eval_metrics='all', labels_ignore='Other', labels_merge=None, pad='Other',
slot_level_combination=True):
eval_metrics='all', labels_ignore='Other', labels_merge=None, pad='Other'):
"""
Function to evaluate the predictions from a model

@@ -397,18 +446,8 @@ def eval_preds(pred_intents=None, lab_intents=None, pred_slots=None, lab_slots=N
:type labels_merge: str or list
:param pad: The value to use when padding slot predictions to match the length of ground truth
:type pad: str
:param slot_level_combination: Whether to merge adjacent tokens with the same slot label
:type slot_level_combination: bool
"""

# convert to correct types
labels_ignore = [labels_ignore] if type(labels_ignore) != list else labels_ignore
labels_ignore = [str(x) for x in labels_ignore]
if labels_merge:
labels_merge = [labels_merge] if type(labels_merge) != list else labels_merge
labels_merge = [str(x) for x in labels_merge]
else:
labels_merge = []
results = {}

# Check lengths
@@ -424,69 +463,38 @@ def eval_preds(pred_intents=None, pred_slots=None, lab_slots=N
results['intent_acc_stderr'] = sqrt(intent_acc*(1-intent_acc)/len(pred_intents))

if lab_slots is not None and pred_slots is not None:
pruned_slot_labels, pruned_slot_preds = [], []
bio_slot_labels, bio_slot_preds = [], []
for lab, pred in zip(lab_slots, pred_slots):

# Pad or truncate prediction as needed using `pad` arg
if type(pred) == list:
pred = pred[:len(lab)] + [pad]*(len(lab) - len(pred))

if slot_level_combination:
# for each prediction and label, we want to combine tokens with same slot into
# a single slot. So we'll make a string for each and concatenate with commas
new_lab, new_pred = [], []
prev_lab = ''

in_merge = False
for i in range(len(lab)):
if str(lab[i]) in labels_ignore:
prev_lab = str(lab[i])
in_merge = False
elif str(lab[i]) in labels_merge:
if i != 0 and not in_merge:
prev_lab = str(lab[i-1])
in_merge = True
# Combine slots
elif str(lab[i]) == prev_lab:
new_lab[-1] = new_lab[-1] + ',' + str(lab[i])
new_pred[-1] = new_pred[-1] + ',' + str(pred[i])
prev_lab = str(lab[i])
in_merge = False
else:
new_lab.append(str(lab[i]))
new_pred.append(str(pred[i]))
prev_lab = str(lab[i])
in_merge = False

pred = new_pred
lab = new_lab

pruned_slot_labels.append(lab)
pruned_slot_preds.append(pred)
# convert to BIO
bio_slot_labels.append(
convert_to_bio(lab, outside=labels_ignore, labels_merge=labels_merge)
)
bio_slot_preds.append(
convert_to_bio(pred, outside=labels_ignore, labels_merge=labels_merge)
)

if ('slot_micro_f1' in eval_metrics) or ('all' in eval_metrics):

# Flatten list of lists to a single list
flat_pruned_slot_labels = [item for sublist in pruned_slot_labels for item in sublist]
flat_pruned_slot_preds = [item for sublist in pruned_slot_preds for item in sublist]

# Calculate globally micro averaged slot f1 (~0.2 seconds)
smf1 = sklm.f1_score(flat_pruned_slot_labels,
flat_pruned_slot_preds,
average='micro',
zero_division=0)
# from seqeval
smf1 = f1_score(bio_slot_labels, bio_slot_preds)
results['slot_micro_f1'] = smf1
# Assuming normal distribution. Multiply by z (from "z table") to get confidence interval
results['slot_micro_f1_stderr'] = sqrt(smf1*(1-smf1)/len(flat_pruned_slot_preds))
total_slots = sum([len(x) for x in bio_slot_preds])
results['slot_micro_f1_stderr'] = sqrt(smf1*(1-smf1)/total_slots)

if ('ex_match_acc' in eval_metrics) or ('all' in eval_metrics):
# calculate exact match accuracy (~0.01 seconds)
matches = 0
denom = 0
for p_int, p_slot, l_int, l_slot in zip(pred_intents,
pruned_slot_preds,
bio_slot_preds,
lab_intents,
pruned_slot_labels):
bio_slot_labels):

if (p_int == l_int) and (p_slot == l_slot):
matches += 1
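For reviewers trying the new path end to end, a hedged sketch of calling the revised eval_preds (keyword names and returned keys follow the diff and tests in this PR; the intent and slot label strings are invented for illustration):

```python
from massive.utils.training_utils import eval_preds

results = eval_preds(
    pred_intents=['alarm_set', 'play_music'],
    lab_intents=['alarm_set', 'play_radio'],
    pred_slots=[['time', 'time', 'Other'], ['Other', 'artist_name']],
    lab_slots=[['time', 'time', 'Other'], ['Other', 'artist_name']],
    eval_metrics='all',
    labels_ignore='Other',
    labels_merge=-100,
)

print(results['intent_acc'])     # 0.5  (one of two intents correct)
print(results['slot_micro_f1'])  # 1.0  (both BIO sequences match their labels exactly)
print(results['ex_match_acc'])   # 0.5  (only the first example matches on intent and slots)

# the reported stderr treats micro-F1 as a proportion over all BIO positions,
# so a ~95% interval is roughly f1 +/- 1.96 * slot_micro_f1_stderr
low = results['slot_micro_f1'] - 1.96 * results['slot_micro_f1_stderr']
high = results['slot_micro_f1'] + 1.96 * results['slot_micro_f1_stderr']
```

With slot_micro_f1 at 1.0 the stderr here is 0, so the interval collapses to a point; on imperfect predictions it gives the +/- band the "z table" comment above refers to.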
Binary file not shown.
Binary file not shown.
77 changes: 71 additions & 6 deletions test/test_eval_metrics.py
@@ -15,7 +15,7 @@
"""

import pytest
from massive.utils.training_utils import eval_preds
from massive.utils.training_utils import convert_to_bio, eval_preds

cases = [
# ---------------- Slot F1 ---------------
@@ -42,14 +42,15 @@
),
(
# Test padding
# 2 TP and 1 FN = 2 / (2 + (1 + 0) / 2)
None,
None,
[['X', 'X']],
[['X', 'X', 'Y']],
[['X', 'X', 'Y', 'Other']],
[['X', 'X', 'Y', 'Other', 'Y']],
'Other',
None,
'slot_micro_f1',
{'slot_micro_f1': 0.5}
{'slot_micro_f1': 0.8}
),
(
# Test truncation
@@ -73,16 +74,50 @@
'slot_micro_f1',
{'slot_micro_f1': 0.5}
),
(
# Test prediction too long
None,
None,
[['X', 'X', 'X', 'Y', 'Y']],
[['X', 'X', 'Other', 'Y', 'Y']],
'Other',
None,
'slot_micro_f1',
{'slot_micro_f1': 0.5}
),
(
# Test prediction too short
None,
None,
[['X', 'Other', 'Other', 'Y', 'Y']],
[['X', 'X', 'Other', 'Y', 'Y']],
'Other',
None,
'slot_micro_f1',
{'slot_micro_f1': 0.5}
),
(
# Test prediction number mismatch
# 1 FN for Y and 1 TP for X = 1 / (1 + (0 + 1) / 2)
None,
None,
[['Other'], ['X']],
[['Y'], ['X']],
'Other',
None,
'slot_micro_f1',
{'slot_micro_f1': 0.67}
),
(
# Test -100 merging
None,
None,
[[50, -100, 50, -100, -100, 20, 20, 10, 10, -100, 20]],
[[50, -100, 50, -100, -100, 20, 0, 20, 0, -100, 20]],
[[50, -100, 50, -100, -100, 20, 0, 10, 0, -100, 20]],
[0],
[-100],
'slot_micro_f1',
{'slot_micro_f1': 0.75}
{'slot_micro_f1': 0.5}
),

# ------------- Exact match acc and intent acc ----------
@@ -107,6 +142,17 @@
'all',
{'ex_match_acc': 0.25, 'intent_acc': 0.75}
),
(
# Test prediction too long
['A'],
['A'],
[['X', 'X', 'X']],
[['X', 'X', 'Other']],
'Other',
None,
'all',
{'ex_match_acc': 0}
)
]
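The -100 merging case above expects 0.5; that value can be reproduced by hand. A sketch, using the lab_slots line consistent with that expectation:

```python
from massive.utils.training_utils import convert_to_bio
from seqeval.metrics import f1_score

pred = [50, -100, 50, -100, -100, 20, 20, 10, 10, -100, 20]
lab = [50, -100, 50, -100, -100, 20, 0, 10, 0, -100, 20]

bio_pred = convert_to_bio(pred, outside=[0], labels_merge=[-100])
bio_lab = convert_to_bio(lab, outside=[0], labels_merge=[-100])
# bio_pred spans: 50(0-4), 20(5-6), 10(7-9), 20(10)
# bio_lab spans:  50(0-4), 20(5),   10(7),   20(10)
# matches: 50 and the final 20 -> TP=2, FP=2, FN=2 -> F1 = 2*2 / (2*2 + 2 + 2) = 0.5
print(f1_score([bio_lab], [bio_pred]))  # 0.5
```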

@pytest.mark.parametrize(
@@ -131,3 +177,22 @@ def test_eval_preds(
    for key in out:
        assert key in results
        assert round(out[key], 2) == round(results[key], 2)

bio_cases = [
    (
        ['city', 'city', 'Other', 'country', -100, 'Other'],
        'Other',
        -100,
        ['B-city', 'I-city', 'O', 'B-country', 'I-country', 'O']
    ),
    (
        [1, 1, 3, 3, 9, 4],
        [3],
        [9],
        ['B-1', 'I-1', 'O', 'O', 'O', 'B-4']
    )
]

@pytest.mark.parametrize('seq_tags, outside, labels_merge, out', bio_cases)
def test_convert_to_bio(seq_tags, outside, labels_merge, out):
    assert convert_to_bio(seq_tags, outside, labels_merge) == out
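One subtlety the second BIO case exercises: a merge label (9) that follows an outside tag (3) inherits 'O' rather than opening a new span. A minimal check, assuming the package is importable as in the test imports above:

```python
from massive.utils.training_utils import convert_to_bio

# the 9 merges leftward onto the preceding 3; since 3 is 'outside', it becomes 'O'
assert convert_to_bio([1, 1, 3, 3, 9, 4], outside=[3], labels_merge=[9]) == \
    ['B-1', 'I-1', 'O', 'O', 'O', 'B-4']
```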