models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, RobertaModel, RobertaTokenizer


def init_weights(m):
    """Xavier-initialize the weights and uniformly initialize the biases of linear layers."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.uniform_(m.bias)


class SpanEmbedder(nn.Module):
    def __init__(self, config, device):
        super(SpanEmbedder, self).__init__()
        self.bert_hidden_size = config.bert_hidden_size
        self.with_width_embedding = config.with_mention_width
        self.use_head_attention = config.with_head_attention
        self.device = device
        self.dropout = config.dropout
        # Zero vector used to pad variable-length spans to a common length.
        self.padded_vector = torch.zeros(self.bert_hidden_size, device=device)
        self.self_attention_layer = nn.Sequential(
            nn.Dropout(config.dropout),
            nn.Linear(self.bert_hidden_size, config.hidden_layer),
            # nn.Dropout(config['dropout']),
            nn.ReLU(),
            nn.Linear(config.hidden_layer, 1)
        )
        self.self_attention_layer.apply(init_weights)
        # Mention-width feature; widths are clamped to at most 4 in forward().
        self.width_feature = nn.Embedding(5, config.embedding_dimension)

    def pad_continous_embeddings(self, continuous_embeddings):
        # Pad each span's token embeddings to the length of the longest span and
        # build a binary mask that marks the real (non-padded) positions.
        max_length = max(len(v) for v in continuous_embeddings)
        padded_tokens_embeddings = torch.stack(
            [torch.cat((emb, self.padded_vector.repeat(max_length - len(emb), 1)))
             for emb in continuous_embeddings]
        )
        masks = torch.stack(
            [torch.cat(
                (torch.ones(len(emb), device=self.device),
                 torch.zeros(max_length - len(emb), device=self.device)))
             for emb in continuous_embeddings]
        )
        return padded_tokens_embeddings, masks

    def forward(self, start_end, continuous_embeddings, width):
        vector = start_end
        if self.use_head_attention:
            padded_tokens_embeddings, masks = self.pad_continous_embeddings(continuous_embeddings)
            attention_scores = self.self_attention_layer(padded_tokens_embeddings).squeeze(-1)
            attention_scores *= masks
            # Padded positions receive a large negative score so softmax assigns them ~0 weight.
            attention_scores = torch.where(attention_scores != 0, attention_scores,
                                           torch.tensor(-9e9, device=self.device))
            attention_scores = F.softmax(attention_scores, dim=1)
            # Attention-weighted sum of the span's token embeddings (soft head word).
            weighted_sum = (attention_scores.unsqueeze(-1) * padded_tokens_embeddings).sum(dim=1)
            vector = torch.cat((vector, weighted_sum), dim=1)
        if self.with_width_embedding:
            width = torch.clamp(width, max=4)
            width_embedding = self.width_feature(width)
            vector = torch.cat((vector, width_embedding), dim=1)
        return vector


class SpanScorer(nn.Module):
    def __init__(self, config):
        super(SpanScorer, self).__init__()
        # Input: start/end token vectors plus the attention-weighted span vector,
        # optionally concatenated with the mention-width embedding.
        self.input_layer = config.bert_hidden_size * 3
        if config.with_mention_width:
            self.input_layer += config.embedding_dimension
        self.mlp = nn.Sequential(
            nn.Dropout(config.dropout),
            nn.Linear(self.input_layer, config.hidden_layer),
            # nn.Dropout(config['dropout']),
            nn.ReLU(),
            nn.Linear(config.hidden_layer, 1)
        )
        self.mlp.apply(init_weights)

    def forward(self, span_embedding):
        # Returns a single mention score per span embedding.
        return self.mlp(span_embedding)


class SimplePairWiseClassifier(nn.Module):
    def __init__(self, config):
        super(SimplePairWiseClassifier, self).__init__()
        # Size of a single span representation, depending on which features are used.
        self.input_layer = config.bert_hidden_size * 3 if config.with_head_attention else config.bert_hidden_size * 2
        if config.with_mention_width:
            self.input_layer += config.embedding_dimension
        # The pair representation is [first, second, first * second], hence the factor of 3.
        self.input_layer *= 3
        self.hidden_layer = config.hidden_layer
        self.pairwise_mlp = nn.Sequential(
            nn.Dropout(config.dropout),
            nn.Linear(self.input_layer, self.hidden_layer),
            nn.ReLU(),
            nn.Linear(self.hidden_layer, self.hidden_layer),
            nn.Dropout(config.dropout),
            nn.ReLU(),
            nn.Linear(self.hidden_layer, 1),
        )
        self.pairwise_mlp.apply(init_weights)

    def forward(self, first, second):
        # Score a pair of spans from their concatenated and element-wise multiplied embeddings.
        return self.pairwise_mlp(torch.cat((first, second, first * second), dim=1))


class FullCrossEncoder(nn.Module):
    def __init__(self, config, is_training=True):
        super(FullCrossEncoder, self).__init__()
        self.segment_size = config.segment_window * 2
        self.tokenizer = RobertaTokenizer.from_pretrained(config.roberta_model)
        # Special tokens marking the mention boundaries in the input sequence.
        self.tokenizer.add_tokens(['[START]', '[END]'])
        if not is_training and config.pretrained_model:
            self.model = AutoModel.from_pretrained(config.pretrained_model)
        else:
            self.model = RobertaModel.from_pretrained(config.roberta_model)
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.hidden_size = self.model.config.hidden_size
        # self.linear = nn.Linear(self.hidden_size, 1)
        self.linear = nn.Sequential(
            nn.Linear(self.hidden_size, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
        self.linear.apply(init_weights)

    def forward(self, input_ids, attention_mask):
        # Index [0] selects the last hidden state and works whether the backbone
        # returns a tuple (older transformers) or a ModelOutput (transformers >= 4).
        output = self.model(input_ids, attention_mask=attention_mask)[0]
        cls_vector = output[:, 0, :]
        scores = self.linear(cls_vector)
        return scores
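

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): a small smoke test
# that wires SpanEmbedder, SpanScorer and SimplePairWiseClassifier together on
# dummy tensors, only to illustrate the expected input shapes. The
# SimpleNamespace values below are illustrative assumptions, not the project's
# real configuration; FullCrossEncoder is skipped here because it downloads a
# pretrained RoBERTa checkpoint.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    device = torch.device('cpu')
    cfg = SimpleNamespace(
        bert_hidden_size=768,     # hidden size of the span encoder (assumed)
        hidden_layer=1024,        # MLP hidden size (assumed)
        embedding_dimension=20,   # mention-width embedding size (assumed)
        dropout=0.3,
        with_mention_width=True,
        with_head_attention=True,
    )

    embedder = SpanEmbedder(cfg, device)
    scorer = SpanScorer(cfg)
    pairwise = SimplePairWiseClassifier(cfg)

    # start_end: concatenation of the start- and end-token vectors of each span.
    start_end = torch.randn(4, cfg.bert_hidden_size * 2)
    # continuous_embeddings: one variable-length token-embedding matrix per span.
    continuous_embeddings = [torch.randn(n, cfg.bert_hidden_size) for n in (2, 3, 1, 5)]
    width = torch.tensor([2, 3, 1, 5])

    span_vectors = embedder(start_end, continuous_embeddings, width)   # (4, 768*3 + 20)
    mention_scores = scorer(span_vectors)                              # (4, 1)
    pair_scores = pairwise(span_vectors[:2], span_vectors[2:])         # (2, 1)
    print(span_vectors.shape, mention_scores.shape, pair_scores.shape)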