from __future__ import annotations
import os
import numpy as np
import logging
from typing import Any, Text, List, Dict, Tuple, Type
import tensorflow as tf
from rasa.engine.graph import ExecutionContext, GraphComponent
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
from rasa.engine.storage.resource import Resource
from rasa.engine.storage.storage import ModelStorage
from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
from rasa.nlu.tokenizers.tokenizer import Token, Tokenizer
from rasa.shared.nlu.training_data.training_data import TrainingData
from rasa.shared.nlu.training_data.message import Message
from rasa.nlu.constants import (
    DENSE_FEATURIZABLE_ATTRIBUTES,
    SEQUENCE_FEATURES,
    SENTENCE_FEATURES,
    NO_LENGTH_RESTRICTION,
    NUMBER_OF_SUB_TOKENS,
    TOKENS_NAMES,
)
from rasa.shared.nlu.constants import TEXT, ACTION_TEXT
from rasa.utils import train_utils
logger = logging.getLogger(__name__)
# Disable `tokenizers` warning (https://github.com/huggingface/transformers/issues/5486)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
DEFAULT_MODEL_NAME = "bert"  # used if neither `model_name` nor `model_weights` is set
DEFAULT_MODEL_WEIGHTS = {
    "bert": "rasa/LaBSE",
    "gpt": "openai-gpt",
    "gpt2": "gpt2",
    "xlnet": "xlnet-base-cased",
    "distilbert": "distilbert-base-uncased",
    "roberta": "roberta-base",
}
CLS_TOKEN = "[CLS]"
# Those models are known to break the component as of transformers 4.13.0,
# due to the tokenizer specific cleanup not working, the AutoConfig not providing
# the necessary data, or since the models expect different input than just text.
INCOMPATIBLE_MODELS = [
    "Bartpho",
    "Bertweet",
    "BlenderbotSmall",
    "ByT5",
    "CTRL",
    "LayoutLMv2",
    "LayoutXLM",
    "MBart50",
    "Perceiver",
    "Phobert",
    "Reformer",
    "Speech2Text",
    "Speech2Text2",
    "Tapas",
    "TransfoXL",
    "Wav2Vec2CTC",
    "Wav2Vec2",
    "XLMProphetNet",
    "T5",
    "BertGeneration",
    "LED",
    "Canine",
    "CLIP",
]


def get_model_weights(config: Dict[str, Any]) -> str:
    """Gets the model weights from the configuration.

    In case no model weights are specified, but the model name is from the supported
    list in `DEFAULT_MODEL_WEIGHTS`, the default model weights are used. Otherwise a
    KeyError is raised.
    """
    model_name = config.get("model_name", DEFAULT_MODEL_NAME)
    model_weights = config.get("model_weights")

    if model_weights is None:
        if model_name in DEFAULT_MODEL_WEIGHTS:
            model_weights = DEFAULT_MODEL_WEIGHTS[model_name]
            logger.info(
                f"Model weights not specified. Will choose default model "
                f"weights: {model_weights}"
            )
        else:
            raise KeyError(
                f"No model_weights specified and there are no default weights"
                f" available for the provided model_name {model_name}. Please"
                f" specify model_weights explicitly."
            )
    return model_weights
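

# Illustrative sketch (comments only, not executed as part of this module) of how
# `get_model_weights` resolves weights; the configs shown here are hypothetical:
#
#     get_model_weights({"model_name": "roberta"})          # -> "roberta-base"
#     get_model_weights({"model_weights": "rasa/LaBSE"})    # -> "rasa/LaBSE"
#     get_model_weights({"model_name": "unknown-model"})    # raises KeyError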


@DefaultV1Recipe.register(
    DefaultV1Recipe.ComponentType.MESSAGE_FEATURIZER, is_trainable=False
)
class LanguageModelFeaturizer(DenseFeaturizer, GraphComponent):
    """A featurizer that uses transformer-based language models.

    This component loads a pre-trained language model
    from the Transformers library (https://github.com/huggingface/transformers)
    including BERT, GPT, GPT-2, xlnet, distilbert, and roberta.
    It also tokenizes and featurizes the featurizable dense attributes of
    each message.
    """

    @classmethod
    def required_components(cls) -> List[Type]:
        """Components that should be included in the pipeline before this component."""
        return [Tokenizer]

    def __init__(
        self, config: Dict[Text, Any], execution_context: ExecutionContext
    ) -> None:
        """Initializes the featurizer with the model in the config."""
        super(LanguageModelFeaturizer, self).__init__(
            execution_context.node_name, config
        )
        self._load_model_metadata()
        self._load_model_instance()

    @staticmethod
    def get_default_config() -> Dict[Text, Any]:
        """Returns LanguageModelFeaturizer's default config."""
        return {
            **DenseFeaturizer.get_default_config(),
            # name of the language model to load
            "model_name": "bert",
            # pre-trained weights to be loaded (string)
            "model_weights": None,
            # an optional path to a specific directory to download
            # and cache the pre-trained model weights
            "cache_dir": None,
            # allows skipping model loading for unit tests
            # (e.g. if only tokenization is tested)
            "load_model": True,
        }
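
    # Illustrative pipeline entry (YAML, shown here only as a comment) that relies on
    # the defaults above; the surrounding tokenizer is an example, any Tokenizer
    # component registered before this featurizer would do:
    #
    #     pipeline:
    #       - name: WhitespaceTokenizer
    #       - name: LanguageModelFeaturizer
    #         model_name: "bert"
    #         model_weights: "rasa/LaBSE"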

    @classmethod
    def validate_config(cls, config: Dict[Text, Any]) -> None:
        """Validates the configuration."""
        pass

    @classmethod
    def create(
        cls,
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ) -> LanguageModelFeaturizer:
        """Creates a LanguageModelFeaturizer.

        Loads the model specified in the config.
        """
        return cls(config, execution_context)

    @staticmethod
    def required_packages() -> List[Text]:
        """Returns the extra python dependencies required."""
        return ["transformers"]

    def _load_model_metadata(self) -> None:
        """Loads the metadata for the specified model and stores it as properties.

        This includes the model name, model weights, cache directory and the
        maximum sequence length the model can handle.
        """
        # Note: these imports are only done locally, since otherwise `sys.modules`
        # contains all HuggingFace classes, including the ones depending on `torch`,
        # which leads to issues in unit tests, particularly with `freezegun` making use
        # of the module cache and throwing a ModuleNotFound error for `torch`.
        from transformers import AutoConfig

        self.model_weights = get_model_weights(self._config)
        self.cache_dir = self._config["cache_dir"]

        model_config = AutoConfig.from_pretrained(self.model_weights)
        model_name = type(model_config).__name__.replace("Config", "")
        if model_name in INCOMPATIBLE_MODELS:
            raise ValueError(
                f"You tried using a {model_name} model, which is not compatible with "
                f"`LanguageModelFeaturizer`. Please consult the documentation under "
                f"https://rasa.com/docs/rasa/components/#languagemodelfeaturizer on "
                f"which models are supported."
            )

        self.max_model_sequence_length = model_config.max_position_embeddings

    def _load_model_instance(self) -> None:
        """Tries to load the model instance.

        Model loading should be skipped in unit tests.
        See unit tests for examples.
        """
        # Note: these imports are only done locally, since otherwise `sys.modules`
        # contains all HuggingFace classes, including the ones depending on `torch`,
        # which leads to issues in unit tests, particularly with `freezegun` making use
        # of the module cache and throwing a ModuleNotFound error for `torch`.
        from transformers import AutoTokenizer, TFAutoModel

        logger.debug(f"Loading Tokenizer and Model for {self.model_weights}")

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_weights, cache_dir=self.cache_dir
        )
        if self._config["load_model"]:
            self.model = TFAutoModel.from_pretrained(  # type: ignore
                self.model_weights, cache_dir=self.cache_dir
            )

        # Use a universal pad token, since not all transformer architectures define a
        # consistent one. Instead of pad_token_id we use unk_token_id because
        # pad_token_id is not set for all architectures. We also can't add a new token,
        # since vocabulary resizing is not yet supported for TF classes.
        # This does not hurt the model predictions since we use an attention mask
        # while feeding input.
        self.pad_token_id = self.tokenizer.unk_token_id

        # Get the special token ids that are added by the tokenizer (e.g. [CLS]) so
        # they can be removed from the output.
        # Remove the UNK token from this list, since we're only interested in the
        # tokens that the tokenizer adds on top of the existing tokens; the UNK token
        # can stand in for an original token and should therefore not be filtered
        # from the output.
        self.special_token_ids = self.tokenizer.all_special_ids.copy()
        if self.tokenizer.unk_token_id in self.special_token_ids:
            self.special_token_ids.remove(self.tokenizer.unk_token_id)

    def _lm_tokenize(self, text: Text) -> Tuple[List[int], List[Text]]:
        """Passes the text through the tokenizer of the language model.

        Args:
            text: Text to be tokenized.

        Returns: List of token ids and token strings.
        """
        split_token_ids = self.tokenizer.encode(text, add_special_tokens=False)
        split_token_strings = self.tokenizer.convert_ids_to_tokens(split_token_ids)

        return split_token_ids, split_token_strings
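
    # Illustrative example (the exact ids and sub-tokens depend on the loaded
    # tokenizer; the values below assume a BERT-style WordPiece vocabulary):
    #
    #     self._lm_tokenize("unbelievable")
    #     # -> ([<three ids>], ["un", "##believ", "##able"])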

    def _add_lm_specific_special_tokens(
        self, token_ids: List[List[int]]
    ) -> List[List[int]]:
        """Adds the language-model-specific special tokens used during training.

        Args:
            token_ids: List of token ids for each example in the batch.

        Returns: Augmented list of token ids for each example in the batch.
        """
        augmented_tokens = [
            self.tokenizer.build_inputs_with_special_tokens(example_token_ids)
            for example_token_ids in token_ids
        ]
        return augmented_tokens

    def _lm_specific_token_cleanup(
        self, split_token_ids: List[int], token_strings: List[Text]
    ) -> Tuple[List[int], List[Text]]:
        """Cleans up special chars added by tokenizers of language models.

        Many language models add a special char in front/back of (some) words. We clean
        up those chars as they are not needed once the features are already computed.

        Args:
            split_token_ids: List of token ids received as output from the language
                model specific tokenizer.
            token_strings: List of token strings received as output from the language
                model specific tokenizer.

        Returns: Cleaned up token ids and token strings.
        """
        # NOTE: We're using `PretrainedTokenizer.convert_tokens_to_string` to remove
        # delimiter pre- and suffixes, such as `##` (BERT), `Ġ` (GPT2), or `▁` (XLNET).
        # This function expects a list of tokens and builds a string by
        # concatenating and removing the delimiters. The additional empty string here
        # is necessary since for BERT style tokenizers, the `##` signifies a sub-token
        # and the filtering only takes effect if there is a starting sub-token before.
        # For the other tokenizers, the delimiters signify whitespace.
        token_ids_string = [
            (id_, self.tokenizer.convert_tokens_to_string(["", token]))
            for id_, token in zip(split_token_ids, token_strings)
        ]
        token_ids_string = [(id_, token) for id_, token in token_ids_string if token]

        # return as individual token ids and token strings
        token_ids: List[int]
        token_ids, token_strings = zip(*token_ids_string)
        return token_ids, token_strings
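
    # Illustrative example for a BERT-style tokenizer (the ids are placeholders): the
    # `##` prefix of a sub-token is stripped, and any token that collapses to an
    # empty string would be dropped entirely:
    #
    #     self._lm_specific_token_cleanup([101, 102], ["bed", "##rooms"])
    #     # -> ((101, 102), ("bed", "rooms"))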

    @staticmethod
    def _post_process_sequence_embeddings(
        sequence_embeddings: np.ndarray,
        special_tokens_mask: np.ndarray,
        cls_token_idxs: List[Any],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Computes sentence and sequence level representations for relevant tokens.

        Args:
            sequence_embeddings: Sequence level dense features received as output from
                the language model.
            special_tokens_mask: A boolean mask signifying the special tokens added by
                the tokenizer.
            cls_token_idxs: A list with the index of the [CLS] token per example, if
                present, otherwise `None` for that example.

        Returns: Sentence and sequence level representations.
        """
        sentence_embeddings = []
        post_processed_sequence_embeddings = []

        for example_embedding, example_mask, example_cls_token_idx in zip(
            sequence_embeddings, special_tokens_mask, cls_token_idxs
        ):
            # The mask gets inverted, so the embeddings of special tokens are discarded
            example_post_processed_embedding = example_embedding[
                ~np.array(example_mask)
            ]
            # Use the embedding of the [CLS] token for BERT style architectures
            if example_cls_token_idx is not None:
                example_sentence_embedding = example_embedding[example_cls_token_idx]
            else:
                example_sentence_embedding = np.mean(
                    example_post_processed_embedding, axis=0
                )

            sentence_embeddings.append(example_sentence_embedding)
            post_processed_sequence_embeddings.append(example_post_processed_embedding)

        return (
            np.array(sentence_embeddings),
            np.array(post_processed_sequence_embeddings),
        )

    def _tokenize_example(
        self, message: Message, attribute: Text
    ) -> Tuple[List[Token], List[int]]:
        """Tokenizes a single message example.

        Many language models add a special char in front of (some) words and split
        words into sub-words. To ensure the entity start and end values match the
        token values, use the tokens produced by the Tokenizer component. If
        individual tokens are split up into multiple tokens, we add this information
        to the respective token.

        Args:
            message: Single message object to be processed.
            attribute: Property of message to be processed, one of ``TEXT`` or
                ``RESPONSE``.

        Returns: List of token strings and token ids for the corresponding
            attribute of the message.
        """
        tokens_in = message.get(TOKENS_NAMES[attribute])
        tokens_out = []
        token_ids_out = []

        for token in tokens_in:
            # use the language model specific tokenizer to further tokenize the text
            split_token_ids, split_token_strings = self._lm_tokenize(token.text)

            if not split_token_ids:
                # `token.text` only contains whitespace or other special characters,
                # which leaves `split_token_ids` and `split_token_strings` empty and
                # would make `self._lm_specific_token_cleanup()` raise an exception,
                # so we skip such tokens.
                continue

            (split_token_ids, split_token_strings) = self._lm_specific_token_cleanup(
                split_token_ids, split_token_strings
            )

            token_ids_out += split_token_ids

            token.set(NUMBER_OF_SUB_TOKENS, len(split_token_strings))

            tokens_out.append(token)

        return tokens_out, token_ids_out

    def _get_token_ids_for_batch(
        self, batch_examples: List[Message], attribute: Text
    ) -> Tuple[List[List[Token]], List[List[int]]]:
        """Computes token ids and token strings for each example in batch.

        A token id is the id of that token in the vocabulary of the language model.

        Args:
            batch_examples: Batch of message objects for which tokens need to be
                computed.
            attribute: Property of message to be processed, one of ``TEXT`` or
                ``RESPONSE``.

        Returns: List of token strings and token ids for each example in the batch.
        """
        batch_token_ids = []
        batch_tokens = []
        for example in batch_examples:
            example_tokens, example_token_ids = self._tokenize_example(
                example, attribute
            )
            batch_tokens.append(example_tokens)
            batch_token_ids.append(example_token_ids)

        return batch_tokens, batch_token_ids

    @staticmethod
    def _compute_attention_mask(
        actual_sequence_lengths: List[int], max_input_sequence_length: int
    ) -> np.ndarray:
        """Computes a mask for padding tokens.

        This mask will be used by the language model so that it does not attend to
        padding tokens.

        Args:
            actual_sequence_lengths: List of length of each example without any
                padding.
            max_input_sequence_length: Maximum length of a sequence that will be
                present in the input batch. This is after taking into consideration
                the maximum input sequence the model can handle. Hence it can never
                be greater than self.max_model_sequence_length in case the model
                applies length restriction.

        Returns: Computed attention mask, 0 for padding and 1 for non-padding
            tokens.
        """
        attention_mask = []

        for actual_sequence_length in actual_sequence_lengths:
            # add 1s for present tokens, fill up the remaining space up to max
            # sequence length with 0s (non-existing tokens)
            padded_sequence = [1] * min(
                actual_sequence_length, max_input_sequence_length
            ) + [0] * (
                max_input_sequence_length
                - min(actual_sequence_length, max_input_sequence_length)
            )
            attention_mask.append(padded_sequence)

        attention_mask = np.array(attention_mask).astype(np.float32)
        return attention_mask
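
    # Illustrative example: for actual_sequence_lengths=[2, 4] and
    # max_input_sequence_length=4, the resulting mask is
    #     [[1., 1., 0., 0.],
    #      [1., 1., 1., 1.]]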

    def _compute_special_tokens_mask(
        self, batch_token_ids: List[List[int]]
    ) -> Tuple[np.ndarray, List[Any]]:
        """Computes a mask for the special tokens added by the tokenizer.

        This mask will be used to filter out the special tokens before creating a
        sequence embedding.

        Returns: Computed mask, 1 for special tokens, 0 for normal tokens.
            List of indices of the [CLS] token, if present, otherwise None.
        """
        special_tokens_mask = []
        cls_token_idxs = []

        for token_ids in batch_token_ids:
            mask = [id_ in self.special_token_ids for id_ in token_ids]
            # Truncate the mask to the maximum sequence length of the model
            if (
                self.max_model_sequence_length != NO_LENGTH_RESTRICTION
                and len(mask) > self.max_model_sequence_length
            ):
                mask = mask[: self.max_model_sequence_length]
            special_tokens_mask.append(mask)

            # Keep the CLS token position for BERT style architectures to build
            # a sentence representation
            if CLS_TOKEN in self.tokenizer.all_special_tokens:
                try:
                    cls_idx = token_ids.index(self.tokenizer.cls_token_id)
                except ValueError:
                    cls_idx = None
            else:
                cls_idx = None
            cls_token_idxs.append(cls_idx)

        return np.array(special_tokens_mask), cls_token_idxs

    def _extract_sequence_lengths(
        self, batch_token_ids: List[List[int]]
    ) -> Tuple[List[int], int]:
        """Extracts the sequence length for each example and maximum sequence length.

        Args:
            batch_token_ids: List of token ids for each example in the batch.

        Returns:
            Tuple consisting of: the actual sequence lengths for each example,
            and the maximum input sequence length (taking into account the
            maximum sequence length that the model can handle).
        """
        # Compute max length across examples
        max_input_sequence_length = 0
        actual_sequence_lengths = []

        for example_token_ids in batch_token_ids:
            sequence_length = len(example_token_ids)
            actual_sequence_lengths.append(sequence_length)
            max_input_sequence_length = max(
                max_input_sequence_length, len(example_token_ids)
            )

        # Take into account the maximum sequence length the model can handle
        max_input_sequence_length = (
            max_input_sequence_length
            if self.max_model_sequence_length == NO_LENGTH_RESTRICTION
            else min(max_input_sequence_length, self.max_model_sequence_length)
        )

        return actual_sequence_lengths, max_input_sequence_length
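
    # Illustrative example: with batch_token_ids=[[1, 2, 3], [4, 5]] and
    # self.max_model_sequence_length=2, this returns ([3, 2], 2); without a model
    # length restriction it would return ([3, 2], 3).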

    def _add_padding_to_batch(
        self, batch_token_ids: List[List[int]], max_sequence_length_model: int
    ) -> List[List[int]]:
        """Adds padding so that all examples in the batch are of the same length.

        Args:
            batch_token_ids: Batch of examples where each example is a non-padded list
                of token ids.
            max_sequence_length_model: Maximum length of any input sequence in the
                batch to be fed to the model.

        Returns:
            Padded batch with all examples of the same length.
        """
        padded_token_ids = []

        # Add padding according to max_sequence_length
        # Some models don't have a pad token, so we use the unknown token as the
        # padding token. This doesn't affect the computation since we compute an
        # attention mask anyway.
        for example_token_ids in batch_token_ids:
            # Truncate any longer sequences so that they can be fed to the model
            if len(example_token_ids) > max_sequence_length_model:
                example_token_ids = example_token_ids[:max_sequence_length_model]

            padded_token_ids.append(
                example_token_ids
                + [self.pad_token_id]
                * (max_sequence_length_model - len(example_token_ids))
            )
        return padded_token_ids
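
    # Illustrative example: with pad_token_id=0 (a placeholder value),
    # batch_token_ids=[[5, 6], [7, 8, 9]] and max_sequence_length_model=3,
    # the padded batch is [[5, 6, 0], [7, 8, 9]].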

    @staticmethod
    def _extract_nonpadded_embeddings(
        embeddings: np.ndarray, actual_sequence_lengths: List[int]
    ) -> np.ndarray:
        """Extracts embeddings for actual tokens.

        Use pre-computed non-padded lengths of each example to extract embeddings
        for non-padding tokens.

        Args:
            embeddings: sequence level representations for each example of the batch.
            actual_sequence_lengths: non-padded lengths of each example of the batch.

        Returns:
            Sequence level embeddings for only non-padding tokens of the batch.
        """
        nonpadded_sequence_embeddings = []
        for index, embedding in enumerate(embeddings):
            unmasked_embedding = embedding[: actual_sequence_lengths[index]]
            nonpadded_sequence_embeddings.append(unmasked_embedding)

        return np.array(nonpadded_sequence_embeddings)

    def _compute_batch_sequence_features(
        self, batch_attention_mask: np.ndarray, padded_token_ids: List[List[int]]
    ) -> np.ndarray:
        """Feeds the padded batch to the language model.

        Args:
            batch_attention_mask: Mask of 0s and 1s which indicate whether the token
                is a padding token or not.
            padded_token_ids: Batch of token ids for each example. The batch is padded
                and hence can be fed at once.

        Returns:
            Sequence level representations from the language model.
        """
        model_outputs = self.model(
            tf.convert_to_tensor(padded_token_ids),
            attention_mask=tf.convert_to_tensor(batch_attention_mask),
        )

        # sequence hidden states is always the first output from all models
        sequence_hidden_states = model_outputs[0]

        sequence_hidden_states = sequence_hidden_states.numpy()
        return sequence_hidden_states

    def _validate_sequence_lengths(
        self,
        actual_sequence_lengths: List[int],
        batch_examples: List[Message],
        attribute: Text,
        inference_mode: bool = False,
    ) -> None:
        """Validates sequence length.

        Checks if sequence lengths of inputs are less than
        the max sequence length the model can handle.

        This method should throw an error during training, and log a debug
        message during inference if any of the input examples have a length
        greater than the maximum sequence length allowed.

        Args:
            actual_sequence_lengths: original sequence length of all inputs
            batch_examples: all message instances in the batch
            attribute: attribute of message object to be processed
            inference_mode: whether this is during training or inference
        """
        if self.max_model_sequence_length == NO_LENGTH_RESTRICTION:
            # There is no restriction on sequence length from the model
            return

        for sequence_length, example in zip(actual_sequence_lengths, batch_examples):
            if sequence_length > self.max_model_sequence_length:
                if not inference_mode:
                    raise RuntimeError(
                        f"The sequence length of '{example.get(attribute)[:20]}...' "
                        f"is too long ({sequence_length} tokens) for the "
                        f"chosen model {self.model_weights} which has a maximum "
                        f"sequence length of {self.max_model_sequence_length} tokens. "
                        f"Either shorten the message or use a model which has no "
                        f"restriction on input sequence length like XLNet."
                    )
                logger.debug(
                    f"The sequence length of '{example.get(attribute)[:20]}...' "
                    f"is too long ({sequence_length} tokens) for the "
                    f"chosen model {self.model_weights} which has a maximum "
                    f"sequence length of {self.max_model_sequence_length} tokens. "
                    f"Downstream model predictions may be affected because of this."
                )

    def _add_extra_padding(
        self, sequence_embeddings: np.ndarray, actual_sequence_lengths: List[int]
    ) -> np.ndarray:
        """Adds extra zero padding to match the original sequence length.

        This is only done if the input was truncated during the batch
        preparation of input for the model.

        Args:
            sequence_embeddings: Embeddings returned from the model
            actual_sequence_lengths: original sequence length of all inputs

        Returns:
            Modified sequence embeddings with padding if necessary
        """
        if self.max_model_sequence_length == NO_LENGTH_RESTRICTION:
            # No extra padding needed because there wouldn't have been any
            # truncation in the first place
            return sequence_embeddings

        reshaped_sequence_embeddings = []
        for index, embedding in enumerate(sequence_embeddings):
            embedding_size = embedding.shape[-1]
            if actual_sequence_lengths[index] > self.max_model_sequence_length:
                embedding = np.concatenate(
                    [
                        embedding,
                        np.zeros(
                            (
                                actual_sequence_lengths[index]
                                - self.max_model_sequence_length,
                                embedding_size,
                            ),
                            dtype=np.float32,
                        ),
                    ]
                )
            reshaped_sequence_embeddings.append(embedding)

        return np.array(reshaped_sequence_embeddings)
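
    # Illustrative example: if an example originally had 600 tokens but the model
    # only handles 512, the (512, H) embedding returned by the model is extended
    # with 88 zero rows so that downstream token alignment sees the original length.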

    def _get_model_features_for_batch(
        self,
        batch_token_ids: List[List[int]],
        batch_tokens: List[List[Token]],
        batch_examples: List[Message],
        attribute: Text,
        inference_mode: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Computes dense features of each example in the batch.

        We first add the special tokens corresponding to each language model. Next, we
        add appropriate padding and compute a mask for that padding so that it doesn't
        affect the feature computation. The padded batch is next fed to the language
        model and token level embeddings are computed. Using the pre-computed mask,
        embeddings for non-padding tokens are extracted and subsequently sentence
        level embeddings are computed.

        Args:
            batch_token_ids: List of token ids of each example in the batch.
            batch_tokens: List of token objects for each example in the batch.
            batch_examples: List of examples in the batch.
            attribute: attribute of the Message object to be processed.
            inference_mode: Whether the call is during training or during inference.

        Returns:
            Sentence and token level dense representations.
        """
        # Let's first add tokenizer specific special tokens to all examples
        batch_token_ids_augmented = self._add_lm_specific_special_tokens(
            batch_token_ids
        )

        # Keep a mask of the special tokens that got added
        special_tokens_mask, cls_token_idxs = self._compute_special_tokens_mask(
            batch_token_ids_augmented
        )

        # Compute sequence lengths for all examples
        (
            actual_sequence_lengths,
            max_input_sequence_length,
        ) = self._extract_sequence_lengths(batch_token_ids_augmented)

        # Validate that all sequences can be processed based on their sequence
        # lengths and the maximum sequence length the model can handle
        self._validate_sequence_lengths(
            actual_sequence_lengths, batch_examples, attribute, inference_mode
        )

        # Add padding so that the whole batch can be fed to the model
        padded_token_ids = self._add_padding_to_batch(
            batch_token_ids_augmented, max_input_sequence_length
        )

        # Compute attention mask based on actual_sequence_length
        batch_attention_mask = self._compute_attention_mask(
            actual_sequence_lengths, max_input_sequence_length
        )

        # Get token level features from the model
        sequence_hidden_states = self._compute_batch_sequence_features(
            batch_attention_mask, padded_token_ids
        )

        # Extract features for only non-padding tokens
        sequence_nonpadded_embeddings = self._extract_nonpadded_embeddings(
            sequence_hidden_states, actual_sequence_lengths
        )

        # Extract sentence level and post-processed features
        (
            sentence_embeddings,
            sequence_embeddings,
        ) = self._post_process_sequence_embeddings(
            sequence_nonpadded_embeddings, special_tokens_mask, cls_token_idxs
        )

        # Pad zeros for examples which were truncated in inference mode.
        # This is intentionally done after sentence embeddings have been
        # extracted so that they are not affected
        sequence_embeddings = self._add_extra_padding(
            sequence_embeddings, actual_sequence_lengths
        )

        # shape of the matrix for all sequence embeddings
        batch_dim = len(sequence_embeddings)
        seq_dim = max(e.shape[0] for e in sequence_embeddings)
        feature_dim = sequence_embeddings[0].shape[1]
        shape = (batch_dim, seq_dim, feature_dim)

        # align features with tokens so that we have just one vector per token
        # (don't include sub-tokens)
        sequence_embeddings = train_utils.align_token_features(
            batch_tokens, sequence_embeddings, shape
        )

        # sequence_embeddings is a padded numpy array
        # remove the padding, keep just the non-zero vectors
        sequence_final_embeddings = []
        for embeddings, tokens in zip(sequence_embeddings, batch_tokens):
            sequence_final_embeddings.append(embeddings[: len(tokens)])
        sequence_final_embeddings = np.array(sequence_final_embeddings)

        return sentence_embeddings, sequence_final_embeddings

    def _get_docs_for_batch(
        self,
        batch_examples: List[Message],
        attribute: Text,
        inference_mode: bool = False,
    ) -> List[Dict[Text, Any]]:
        """Computes language model docs for all examples in the batch.

        Args:
            batch_examples: Batch of message objects for which language model docs
                need to be computed.
            attribute: Property of message to be processed, one of ``TEXT`` or
                ``RESPONSE``.
            inference_mode: Whether the call is during inference or during training.

        Returns:
            List of language model docs for each message in batch.
        """
        batch_tokens, batch_token_ids = self._get_token_ids_for_batch(
            batch_examples, attribute
        )

        (
            batch_sentence_features,
            batch_sequence_features,
        ) = self._get_model_features_for_batch(
            batch_token_ids, batch_tokens, batch_examples, attribute, inference_mode
        )

        # A doc consists of
        # {'sequence_features': ..., 'sentence_features': ...}
        batch_docs = []
        for index in range(len(batch_examples)):
            doc = {
                SEQUENCE_FEATURES: batch_sequence_features[index],
                SENTENCE_FEATURES: np.reshape(batch_sentence_features[index], (1, -1)),
            }
            batch_docs.append(doc)

        return batch_docs
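
    # For a message whose Tokenizer produced N tokens and a model with hidden size H,
    # each doc holds an (N, H) array under SEQUENCE_FEATURES and a (1, H) array under
    # SENTENCE_FEATURES (shapes follow from the reshape above and the token alignment
    # in `_get_model_features_for_batch`).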

    def process_training_data(self, training_data: TrainingData) -> TrainingData:
        """Computes tokens and dense features for each message in training data.

        Args:
            training_data: NLU training data to be tokenized and featurized
        """
        batch_size = 64

        for attribute in DENSE_FEATURIZABLE_ATTRIBUTES:
            non_empty_examples = list(
                filter(lambda x: x.get(attribute), training_data.training_examples)
            )

            batch_start_index = 0
            while batch_start_index < len(non_empty_examples):
                batch_end_index = min(
                    batch_start_index + batch_size, len(non_empty_examples)
                )
                # Collect batch examples
                batch_messages = non_empty_examples[batch_start_index:batch_end_index]

                # Construct a doc with the relevant features
                # extracted (tokens, dense_features)
                batch_docs = self._get_docs_for_batch(batch_messages, attribute)

                for index, ex in enumerate(batch_messages):
                    self._set_lm_features(batch_docs[index], ex, attribute)
                batch_start_index += batch_size

        return training_data

    def process(self, messages: List[Message]) -> List[Message]:
        """Processes messages by computing tokens and dense features."""
        for message in messages:
            self._process_message(message)
        return messages

    def _process_message(self, message: Message) -> Message:
        """Processes a message by computing tokens and dense features."""
        # processing featurizers operates only on TEXT and ACTION_TEXT attributes,
        # because all other attributes are labels which are featurized during
        # training and their features are stored by the model itself.
        for attribute in {TEXT, ACTION_TEXT}:
            if message.get(attribute):
                self._set_lm_features(
                    self._get_docs_for_batch(
                        [message], attribute=attribute, inference_mode=True
                    )[0],
                    message,
                    attribute,
                )
        return message

    def _set_lm_features(
        self, doc: Dict[Text, Any], message: Message, attribute: Text = TEXT
    ) -> None:
        """Adds the precomputed word vectors to the message's features."""
        sequence_features = doc[SEQUENCE_FEATURES]
        sentence_features = doc[SENTENCE_FEATURES]

        self.add_features_to_message(
            sequence=sequence_features,
            sentence=sentence_features,
            attribute=attribute,
            message=message,
        )