collator.py
from dataclasses import dataclass
from typing import Any, Optional, Union

import numpy as np

from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy


@dataclass
class DataCollatorForAdvSeq2Seq:
    """Seq2seq data collator that additionally pads a parallel
    `adversarial_data` field into `adversarial_data` / `adversarial_mask`."""

    tokenizer: PreTrainedTokenizerBase
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors

        adversarial = "adversarial_data" in features[0]
        labels = (
            [feature["labels"] for feature in features]
            if "labels" in features[0].keys()
            else None
        )

        if adversarial:
            # Pad the adversarial sequences separately; they may differ in
            # length from the clean inputs.
            adv_features = self.tokenizer.pad(
                [
                    {"input_ids": feature["adversarial_data"]}
                    for feature in features
                ],
                padding=self.padding,
                max_length=self.max_length,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=return_tensors,
            )
            # Drop the raw adversarial field so the `tokenizer.pad` call
            # below is not handed a key it cannot pad.
            for feature in features:
                feature.pop("adversarial_data", None)

        # We have to pad the labels before calling `tokenizer.pad`, as that
        # method won't pad them and needs them to be the same length to
        # return tensors.
        if labels is not None:
            max_label_length = max(len(l) for l in labels)
            if self.pad_to_multiple_of is not None:
                # Round up to the nearest multiple of `pad_to_multiple_of`.
                max_label_length = (
                    (max_label_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                    * self.pad_to_multiple_of
                )

            padding_side = self.tokenizer.padding_side
            for feature in features:
                remainder = [self.label_pad_token_id] * (
                    max_label_length - len(feature["labels"])
                )
                if isinstance(feature["labels"], list):
                    feature["labels"] = (
                        feature["labels"] + remainder
                        if padding_side == "right"
                        else remainder + feature["labels"]
                    )
                elif padding_side == "right":
                    feature["labels"] = np.concatenate(
                        [feature["labels"], remainder]
                    ).astype(np.int64)
                else:
                    feature["labels"] = np.concatenate(
                        [remainder, feature["labels"]]
                    ).astype(np.int64)

        features = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        if adversarial:
            features["adversarial_data"] = adv_features["input_ids"]
            features["adversarial_mask"] = adv_features["attention_mask"]

        # Prepare decoder_input_ids from the padded labels when the model
        # supports it (e.g. T5/BART-style seq2seq models).
        if (
            labels is not None
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            features["decoder_input_ids"] = (
                self.model.prepare_decoder_input_ids_from_labels(
                    labels=features["labels"]
                )
            )

        return features
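

# Minimal usage sketch for the collator above. The "t5-small" checkpoint
# and the toy feature dicts are illustrative assumptions; any seq2seq
# tokenizer should behave the same way.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")

    # Each feature carries clean inputs, labels, and an adversarial variant.
    features = [
        {
            "input_ids": tokenizer("translate English to French: hello").input_ids,
            "labels": tokenizer("bonjour").input_ids,
            "adversarial_data": tokenizer("translate English to French: h3llo").input_ids,
        },
        {
            "input_ids": tokenizer("translate English to French: good morning").input_ids,
            "labels": tokenizer("bonjour").input_ids,
            "adversarial_data": tokenizer("translate English to French: g00d m0rning").input_ids,
        },
    ]

    collator = DataCollatorForAdvSeq2Seq(tokenizer=tokenizer)
    batch = collator(features)

    # Expect "input_ids", "attention_mask", "labels", "adversarial_data",
    # and "adversarial_mask", each padded to a common length per field.
    print({k: tuple(v.shape) for k, v in batch.items()})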