Tokenizers don't split words into sub-words #5756

Merged
merged 16 commits on May 5, 2020
5 changes: 5 additions & 0 deletions changelog/5756.improvement.rst
@@ -0,0 +1,5 @@
To avoid our entity extractors predicting entity labels for just a part of a word,
we used to apply a clean-up method after the prediction was done.
It is better to avoid such incorrect predictions in the first place.
To achieve this, tokenizers no longer split words into sub-words.
Instead, we take the mean of the sub-word feature vectors as the feature vector of the word.
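To illustrate the averaging described above, here is a minimal standalone sketch (not the Rasa implementation; the function name and array shapes are chosen for the example):

import numpy as np

def word_vector_from_sub_words(sub_word_vectors: np.ndarray) -> np.ndarray:
    # One row per sub-word; the word's feature vector is simply the mean
    # of its sub-word vectors.
    return sub_word_vectors.mean(axis=0)

# Toy example: "playing" split into "play" + "##ing" with 2-dim features.
sub_word_vectors = np.array([[1.0, 2.0], [3.0, 4.0]])
print(word_vector_from_sub_words(sub_word_vectors))  # [2. 3.]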
1 change: 0 additions & 1 deletion rasa/nlu/classifiers/diet_classifier.py
@@ -803,7 +803,6 @@ def _predict_entities(
)

entities = self.add_extractor_name(entities)
entities = self.clean_up_entities(message, entities)
entities = message.get(ENTITIES, []) + entities

return entities
1 change: 1 addition & 0 deletions rasa/nlu/constants.py
@@ -37,6 +37,7 @@

CLS_TOKEN = "__CLS__"
POSITION_OF_CLS_TOKEN = -1
NUMBER_OF_SUB_TOKENS = "number_of_sub_tokens"

MESSAGE_ATTRIBUTES = [TEXT, INTENT, RESPONSE]

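The new NUMBER_OF_SUB_TOKENS constant is the key a tokenizer can use to record how many sub-words a token was split into, so that downstream featurizers know how many sub-word vectors to average per word. A sketch of that idea follows; the Token.set call, the import paths, and the sub-word split shown here are assumptions for illustration, not part of this diff:

from rasa.nlu.constants import NUMBER_OF_SUB_TOKENS
from rasa.nlu.tokenizers.tokenizer import Token

token = Token("playing", start=0)   # one token per whole word
sub_words = ["play", "##ing"]       # hypothetical sub-word split
token.set(NUMBER_OF_SUB_TOKENS, len(sub_words))
# A featurizer can later read this count to know how many sub-word feature
# vectors it has to average into the single vector for this token.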
1 change: 0 additions & 1 deletion rasa/nlu/extractors/crf_entity_extractor.py
@@ -195,7 +195,6 @@ def _update_crf_order(self, training_data: TrainingData):
def process(self, message: Message, **kwargs: Any) -> None:
entities = self.extract_entities(message)
entities = self.add_extractor_name(entities)
entities = self.clean_up_entities(message, entities)
message.set(ENTITIES, message.get(ENTITIES, []) + entities, add_to_output=True)

def extract_entities(self, message: Message) -> List[Dict[Text, Any]]:
1 change: 0 additions & 1 deletion rasa/nlu/extractors/duckling_http_extractor.py
@@ -186,7 +186,6 @@ def process(self, message: Message, **kwargs: Any) -> None:
)

extracted = self.add_extractor_name(extracted)
extracted = self.clean_up_entities(message, extracted)
message.set(ENTITIES, message.get(ENTITIES, []) + extracted, add_to_output=True)

@classmethod
261 changes: 1 addition & 260 deletions rasa/nlu/extractors/extractor.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Text, Tuple, Optional, Union
from typing import Any, Dict, List, Text, Tuple, Optional

from rasa.constants import DOCS_URL_TRAINING_DATA_NLU
from rasa.nlu.training_data import TrainingData
@@ -41,265 +41,6 @@ def add_processor_name(self, entity: Dict[Text, Any]) -> Dict[Text, Any]:

return entity

def clean_up_entities(
self, message: Message, entities: List[Dict[Text, Any]], keep: bool = True
) -> List[Dict[Text, Any]]:
"""
Check if multiple entity labels are assigned to one word or if an entity label
is assigned to just a part of a word or if an entity label covers multiple
words, but one word just partly.

This might happen if you are using a tokenizer that splits up words into
sub-words and different entity labels are assigned to the individual sub-words.
If multiple entity labels are assigned to one word, we keep the entity label
with the highest confidence as entity label for that word. If just a part
of the word is annotated, that entity label is taken for the complete word.
If you set 'keep' to 'False', all entity labels for the word will be removed.

Args:
message: message object
entities: list of entities
keep:
If set to 'True', the entity label with the highest confidence is kept
if multiple entity labels are assigned to one word. If set to 'False'
all entity labels for that word will be removed.

Returns:
Updated entities.
"""
misaligned_entities = self._get_misaligned_entities(
message.get(TOKENS_NAMES[TEXT]), entities
)

entity_indices_to_remove = set()

for misaligned_entity in misaligned_entities:
# entity indices involved in the misalignment
entity_indices = misaligned_entity["entity_indices"]

if not keep:
entity_indices_to_remove.update(entity_indices)
continue

idx = self._entity_index_to_keep(entities, entity_indices)

if idx is None or idx not in entity_indices:
entity_indices_to_remove.update(entity_indices)
else:
# keep just one entity
entity_indices.remove(idx)
entity_indices_to_remove.update(entity_indices)

# update that entity to cover the complete word(s)
entities[idx][ENTITY_ATTRIBUTE_START] = misaligned_entity[
ENTITY_ATTRIBUTE_START
]
entities[idx][ENTITY_ATTRIBUTE_END] = misaligned_entity[
ENTITY_ATTRIBUTE_END
]
entities[idx][ENTITY_ATTRIBUTE_VALUE] = message.text[
misaligned_entity[ENTITY_ATTRIBUTE_START] : misaligned_entity[
ENTITY_ATTRIBUTE_END
]
]

# sort indices to remove entries at the end of the list first
# to avoid index out of range errors
for idx in sorted(entity_indices_to_remove, reverse=True):
entities.remove(entities[idx])

return entities

def _get_misaligned_entities(
self, tokens: List[Token], entities: List[Dict[Text, Any]]
) -> List[Dict[Text, Any]]:
"""Identify entities and tokens that are misaligned.

Misaligned entities are those that apply only to a part of a word, i.e.
sub-word.

Args:
tokens: list of tokens
entities: list of detected entities by the entity extractor

Returns:
Misaligned entities including the start and end position
of the final entity in the text and entity indices that are part of this
misalignment.
"""
if not tokens:
return []

# group tokens: one token cluster corresponds to one word
token_clusters = self._token_clusters(tokens)

# added for tests, should only happen if tokens are not set or len(tokens) == 1
if not token_clusters:
return []

misaligned_entities = []
for entity_idx, entity in enumerate(entities):
# get all tokens that are covered/touched by the entity
entity_tokens = self._tokens_of_entity(entity, token_clusters)

if len(entity_tokens) == 1:
# entity covers exactly one word
continue

# get start and end position of complete word
# needed to update the final entity later
start_position = entity_tokens[0].start
end_position = entity_tokens[-1].end

# check if an entity was already found that covers the exact same word(s)
_idx = self._misaligned_entity_index(
misaligned_entities, start_position, end_position
)

if _idx is None:
misaligned_entities.append(
{
ENTITY_ATTRIBUTE_START: start_position,
ENTITY_ATTRIBUTE_END: end_position,
"entity_indices": [entity_idx],
}
)
else:
# pytype: disable=attribute-error
misaligned_entities[_idx]["entity_indices"].append(entity_idx)
# pytype: enable=attribute-error

return misaligned_entities

@staticmethod
def _misaligned_entity_index(
word_entity_cluster: List[Dict[Text, Union[int, List[int]]]],
start_position: int,
end_position: int,
) -> Optional[int]:
"""Get index of matching misaligned entity.

Args:
word_entity_cluster: word entity cluster
start_position: start position
end_position: end position

Returns:
Index of the misaligned entity that matches the provided start and end
position.
"""
for idx, cluster in enumerate(word_entity_cluster):
if (
cluster[ENTITY_ATTRIBUTE_START] == start_position
and cluster[ENTITY_ATTRIBUTE_END] == end_position
):
return idx
return None

@staticmethod
def _tokens_of_entity(
entity: Dict[Text, Any], token_clusters: List[List[Token]]
) -> List[Token]:
"""Get all tokens of token clusters that are covered by the entity.

The entity can cover them completely or just partly.

Args:
entity: the entity
token_clusters: list of token clusters

Returns:
Token clusters that belong to the provided entity.

"""
entity_tokens = []
for token_cluster in token_clusters:
entity_starts_inside_cluster = (
token_cluster[0].start
<= entity[ENTITY_ATTRIBUTE_START]
<= token_cluster[-1].end
)
entity_ends_inside_cluster = (
token_cluster[0].start
<= entity[ENTITY_ATTRIBUTE_END]
<= token_cluster[-1].end
)

if entity_starts_inside_cluster or entity_ends_inside_cluster:
entity_tokens += token_cluster
return entity_tokens

@staticmethod
def _token_clusters(tokens: List[Token]) -> List[List[Token]]:
"""Build clusters of tokens that belong to one word.

Args:
tokens: list of tokens

Returns:
Token clusters.

"""
# token cluster = list of token indices that belong to one word
token_index_clusters = []

# start at 1 in order to check if current token and previous token belong
# to the same word
for token_idx in range(1, len(tokens)):
previous_token_idx = token_idx - 1
# two tokens belong to the same word if there is no other character
# between them
if tokens[token_idx].start == tokens[previous_token_idx].end:
# a word was split into multiple tokens
token_cluster_already_exists = (
token_index_clusters
and token_index_clusters[-1][-1] == previous_token_idx
)
if token_cluster_already_exists:
token_index_clusters[-1].append(token_idx)
else:
token_index_clusters.append([previous_token_idx, token_idx])
else:
# the token corresponds to a single word
if token_idx == 1:
token_index_clusters.append([previous_token_idx])
token_index_clusters.append([token_idx])

return [[tokens[idx] for idx in cluster] for cluster in token_index_clusters]

@staticmethod
def _entity_index_to_keep(
entities: List[Dict[Text, Any]], entity_indices: List[int]
) -> Optional[int]:
"""
Determine the entity index to keep.

If we just have one entity index, i.e. candidate, we return the index of that
candidate. If we have multiple candidates, we return the index of the entity
value with the highest confidence score. If no confidence score is present,
no entity label will be kept.

Args:
entities: the full list of entities
entity_indices: the entity indices to consider

Returns: the idx of the entity to keep
"""
if len(entity_indices) == 1:
return entity_indices[0]

confidences = [
entities[idx][ENTITY_ATTRIBUTE_CONFIDENCE_TYPE]
for idx in entity_indices
if ENTITY_ATTRIBUTE_CONFIDENCE_TYPE in entities[idx]
]

# we don't have confidence values for all entity labels
if len(confidences) != len(entity_indices):
return None

return confidences.index(max(confidences))

@staticmethod
def filter_irrelevant_entities(extracted: list, requested_dimensions: set) -> list:
"""Only return dimensions the user configured"""
1 change: 0 additions & 1 deletion rasa/nlu/extractors/mitie_entity_extractor.py
@@ -141,7 +141,6 @@ def process(self, message: Message, **kwargs: Any) -> None:
mitie_feature_extractor,
)
extracted = self.add_extractor_name(ents)
extracted = self.clean_up_entities(message, extracted)
message.set(ENTITIES, message.get(ENTITIES, []) + extracted, add_to_output=True)

@classmethod
1 change: 0 additions & 1 deletion rasa/nlu/extractors/spacy_entity_extractor.py
@@ -32,7 +32,6 @@ def process(self, message: Message, **kwargs: Any) -> None:
spacy_nlp = kwargs.get("spacy_nlp", None)
doc = spacy_nlp(message.text)
all_extracted = self.add_extractor_name(self.extract_entities(doc))
all_extracted = self.clean_up_entities(message, all_extracted)
dimensions = self.component_config["dimensions"]
extracted = SpacyEntityExtractor.filter_irrelevant_entities(
all_extracted, dimensions
24 changes: 11 additions & 13 deletions rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py
@@ -9,12 +9,7 @@
from rasa.nlu.tokenizers.convert_tokenizer import ConveRTTokenizer
from rasa.nlu.config import RasaNLUModelConfig
from rasa.nlu.training_data import Message, TrainingData
from rasa.nlu.constants import (
TEXT,
TOKENS_NAMES,
DENSE_FEATURE_NAMES,
DENSE_FEATURIZABLE_ATTRIBUTES,
)
from rasa.nlu.constants import TEXT, DENSE_FEATURE_NAMES, DENSE_FEATURIZABLE_ATTRIBUTES
import numpy as np
import tensorflow as tf

@@ -79,25 +74,28 @@ def _compute_sequence_encodings(
self, batch_examples: List[Message], attribute: Text = TEXT
) -> Tuple[np.ndarray, List[int]]:
list_of_tokens = [
example.get(TOKENS_NAMES[attribute]) for example in batch_examples
train_utils.tokens_without_cls(example, attribute)
for example in batch_examples
]

# remove CLS token from list of tokens
list_of_tokens = [sent_tokens[:-1] for sent_tokens in list_of_tokens]

number_of_tokens_in_sentence = [
len(sent_tokens) for sent_tokens in list_of_tokens
]

# join the tokens to get a clean text to ensure the sequence length of
# the returned embeddings from ConveRT matches the length of the tokens
# (including sub-tokens)
tokenized_texts = self._tokens_to_text(list_of_tokens)
token_features = self._sequence_encoding_of_text(tokenized_texts)

return (
self._sequence_encoding_of_text(tokenized_texts),
number_of_tokens_in_sentence,
# ConveRT might split up tokens into sub-tokens
# take the mean of the sub-token vectors and use that as the token vector
token_features = train_utils.align_token_features(
list_of_tokens, token_features
)

return token_features, number_of_tokens_in_sentence

def _combine_encodings(
self,
sentence_encodings: np.ndarray,
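The hunk above calls two train_utils helpers that are not shown on this page. Based on the removed lines and the in-diff comments, here is a rough per-message sketch of the behaviour they are expected to have; the function bodies, the token.get call, and the signatures are assumptions, not the actual rasa.utils.train_utils code:

import numpy as np
from rasa.nlu.constants import NUMBER_OF_SUB_TOKENS

def tokens_without_cls(tokens):
    # Rasa tokenizers append a trailing __CLS__ token; drop it, as the
    # removed `sent_tokens[:-1]` line used to do.
    return tokens[:-1]

def align_token_features(tokens, sub_token_features: np.ndarray) -> np.ndarray:
    # `sub_token_features` has one row per ConveRT sub-token. Collapse the
    # rows that belong to the same token into their mean, using the
    # sub-token count stored on each token (see NUMBER_OF_SUB_TOKENS above).
    aligned = []
    offset = 0
    for token in tokens:
        # token.get with a default is an assumption about the Token API
        count = token.get(NUMBER_OF_SUB_TOKENS, 1)
        aligned.append(sub_token_features[offset : offset + count].mean(axis=0))
        offset += count
    return np.stack(aligned)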
1 change: 1 addition & 0 deletions rasa/nlu/featurizers/featurizer.py
@@ -1,6 +1,7 @@
import numpy as np
import scipy.sparse
from typing import Any, Text, Union, Optional

from rasa.nlu.training_data import Message
from rasa.nlu.components import Component
from rasa.nlu.constants import SPARSE_FEATURE_NAMES, DENSE_FEATURE_NAMES, TEXT