support layoutxlm re dygraph to static (#3325)
* support layoutxlm re dygraph to static

* fix error
WenmuZhou authored Sep 22, 2022
1 parent ad6fe24 commit b525401
Showing 1 changed file with 111 additions and 65 deletions.
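For context: "dygraph to static" means tracing the imperative (dynamic-graph) model into a static program that can be saved for inference, which is why the diff below replaces per-document Python dicts, sets, and lists with fixed-shape Paddle tensors. A minimal sketch of the export flow this enables; the wrapper class, input shapes, and save path are illustrative assumptions, not part of this commit:

import paddle
from paddle.static import InputSpec

class REWrapper(paddle.nn.Layer):
    """Hypothetical wrapper fixing a tensor-only positional argument order."""

    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, input_ids, bbox, image, attention_mask, entities, relations):
        return self.net(input_ids=input_ids, bbox=bbox, image=image,
                        attention_mask=attention_mask,
                        entities=entities, relations=relations)

# net = LayoutXLMForRelationExtraction.from_pretrained(...)  # built elsewhere
static_net = paddle.jit.to_static(
    REWrapper(net),
    input_spec=[
        InputSpec(shape=[None, 512], dtype="int64"),            # input_ids
        InputSpec(shape=[None, 512, 4], dtype="int64"),         # bbox
        InputSpec(shape=[None, 3, 224, 224], dtype="float32"),  # image
        InputSpec(shape=[None, 512], dtype="int64"),            # attention_mask
        InputSpec(shape=[None, None, 3], dtype="int64"),        # entities, packed
        InputSpec(shape=[None, None, 3], dtype="int64"),        # relations, packed
    ])
paddle.jit.save(static_net, "inference/re")  # writes re.pdmodel / re.pdiparams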
paddlenlp/transformers/layoutxlm/modeling.py
@@ -1310,71 +1310,115 @@ def __init__(self, hidden_size=768, hidden_dropout_prob=0.1):
         self.loss_fct = CrossEntropyLoss()

     def build_relation(self, relations, entities):
-        batch_size = len(relations)
-        new_relations = []
+        batch_size, max_seq_len = paddle.shape(entities)[:2]
+        new_relations = paddle.full(
+            shape=[batch_size, max_seq_len * max_seq_len, 3],
+            fill_value=-1,
+            dtype=relations.dtype)
         for b in range(batch_size):
-            if len(entities[b]["start"]) <= 2:
-                entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]}
-            all_possible_relations = set([
-                (i, j) for i in range(len(entities[b]["label"]))
-                for j in range(len(entities[b]["label"]))
-                if entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2
-            ])
+            if entities[b, 0, 0] <= 2:
+                entitie_new = paddle.full(shape=[512, 3],
+                                          fill_value=-1,
+                                          dtype=entities.dtype)
+                entitie_new[0, :] = 2
+                entitie_new[1:3, 0] = 0  # start
+                entitie_new[1:3, 1] = 1  # end
+                entitie_new[1:3, 2] = 0  # label
+                entities[b] = entitie_new
+            entitie_label = entities[b, 1:entities[b, 0, 2] + 1, 2]
+            all_possible_relations1 = paddle.arange(0,
+                                                    entities[b, 0, 2],
+                                                    dtype=entities.dtype)
+            all_possible_relations1 = all_possible_relations1[entitie_label ==
+                                                              1]
+            all_possible_relations2 = paddle.arange(0,
+                                                    entities[b, 0, 2],
+                                                    dtype=entities.dtype)
+            all_possible_relations2 = all_possible_relations2[entitie_label ==
+                                                              2]
+
+            all_possible_relations = paddle.stack(paddle.meshgrid(
+                all_possible_relations1, all_possible_relations2),
+                                                  axis=2).reshape([-1, 2])
             if len(all_possible_relations) == 0:
-                all_possible_relations = {(0, 1)}
-            if "head" in relations[b]:
-                positive_relations = set(
-                    list(zip(relations[b]["head"], relations[b]["tail"])))
+                all_possible_relations = paddle.full_like(
+                    all_possible_relations, fill_value=-1, dtype=entities.dtype)
+                all_possible_relations[0, 0] = 0
+                all_possible_relations[0, 1] = 1
+
+            relation_head = relations[b, 1:relations[b, 0, 0] + 1, 0]
+            relation_tail = relations[b, 1:relations[b, 0, 1] + 1, 1]
+            positive_relations = paddle.stack([relation_head, relation_tail],
+                                              axis=1)
+
+            all_possible_relations_repeat = all_possible_relations.unsqueeze(
+                axis=1).tile([1, len(positive_relations), 1])
+            positive_relations_repeat = positive_relations.unsqueeze(
+                axis=0).tile([len(all_possible_relations), 1, 1])
+            mask = paddle.all(
+                all_possible_relations_repeat == positive_relations_repeat,
+                axis=2)
+            negative_mask = paddle.any(mask, axis=1) == False
+            negative_relations = all_possible_relations[negative_mask]
+
+            positive_mask = paddle.any(mask, axis=0) == True
+            positive_relations = positive_relations[positive_mask]
+            if negative_mask.sum() > 0:
+                reordered_relations = paddle.concat(
+                    [positive_relations, negative_relations])
             else:
-                positive_relations = set()
-            negative_relations = all_possible_relations - positive_relations
-            positive_relations = set(
-                [i for i in positive_relations if i in all_possible_relations])
-            reordered_relations = list(positive_relations) + list(
-                negative_relations)
-            relation_per_doc = {
-                "head": [i[0] for i in reordered_relations],
-                "tail": [i[1] for i in reordered_relations],
-                "label": [1] * len(positive_relations) + [0] *
-                (len(reordered_relations) - len(positive_relations))
-            }
-            assert len(relation_per_doc["head"]) != 0
-            new_relations.append(relation_per_doc)
+                reordered_relations = positive_relations
+
+            relation_per_doc_label = paddle.zeros(
+                [len(reordered_relations), 1], dtype=reordered_relations.dtype)
+            relation_per_doc_label[:len(positive_relations)] = 1
+            relation_per_doc = paddle.concat(
+                [reordered_relations, relation_per_doc_label], axis=1)
+            assert len(relation_per_doc[:, 0]) != 0
+            new_relations[b, 0] = paddle.shape(relation_per_doc)[0].astype(
+                new_relations.dtype)
+            new_relations[b, 1:len(relation_per_doc) + 1] = relation_per_doc
+            # new_relations.append(relation_per_doc)
         return new_relations, entities

     def get_predicted_relations(self, logits, relations, entities):
         pred_relations = []
         for i, pred_label in enumerate(logits.argmax(-1)):
             if pred_label != 1:
                 continue
-            rel = {}
-            rel["head_id"] = relations["head"][i]
-            rel["head"] = (entities["start"][rel["head_id"]],
-                           entities["end"][rel["head_id"]])
-            rel["head_type"] = entities["label"][rel["head_id"]]
-
-            rel["tail_id"] = relations["tail"][i]
-            rel["tail"] = (entities["start"][rel["tail_id"]],
-                           entities["end"][rel["tail_id"]])
-            rel["tail_type"] = entities["label"][rel["tail_id"]]
-            rel["type"] = 1
+            rel = paddle.full(shape=[7, 2],
+                              fill_value=-1,
+                              dtype=relations.dtype)
+            rel[0, 0] = relations[:, 0][i]
+            rel[1, 0] = entities[:, 0][relations[:, 0][i] + 1]
+            rel[1, 1] = entities[:, 1][relations[:, 0][i] + 1]
+            rel[2, 0] = entities[:, 2][relations[:, 0][i] + 1]
+            rel[3, 0] = relations[:, 1][i]
+            rel[4, 0] = entities[:, 0][relations[:, 1][i] + 1]
+            rel[4, 1] = entities[:, 1][relations[:, 1][i] + 1]
+            rel[5, 0] = entities[:, 2][relations[:, 1][i] + 1]
+            rel[6, 0] = 1
             pred_relations.append(rel)
         return pred_relations

     def forward(self, hidden_states, entities, relations):
-        batch_size, max_n_words, context_dim = hidden_states.shape
+        batch_size, max_length, _ = paddle.shape(entities)
         relations, entities = self.build_relation(relations, entities)
         loss = 0
-        all_pred_relations = []
+        all_pred_relations = paddle.full(
+            shape=[batch_size, max_length * max_length, 7, 2],
+            fill_value=-1,
+            dtype=entities.dtype)
         for b in range(batch_size):
-            if "head" not in relations[b]:
-                continue
-            head_entities = paddle.to_tensor(relations[b]["head"])
-            tail_entities = paddle.to_tensor(relations[b]["tail"])
-            relation_labels = paddle.to_tensor(relations[b]["label"],
-                                               dtype='int64')
-            entities_start_index = paddle.to_tensor(entities[b]["start"])
-            entities_labels = paddle.to_tensor(entities[b]["label"])
+            relation = relations[b, 1:relations[b, 0, 0] + 1]
+            head_entities = relation[:, 0]
+            tail_entities = relation[:, 1]
+            relation_labels = relation[:, 2]
+            entities_start_index = paddle.to_tensor(
+                entities[b, 1:entities[b, 0, 0] + 1, 0])
+            entities_labels = paddle.to_tensor(entities[b,
+                                                        1:entities[b, 0, 2] + 1,
+                                                        2])
             head_index = entities_start_index[head_entities]
             head_label = entities_labels[head_entities]
             head_label_repr = self.entity_emb(head_label)
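For reference, the convention the rewritten build_relation relies on: each entities[b] / relations[b] slice is a fixed-shape integer tensor whose row 0 stores the number of valid rows (the guard entities[b, 0, 0] <= 2 above swaps in two dummy entities when a document has too few), while rows 1..n hold the data, (start, end, label) for entities and (head, tail, label) for relations. A small sketch of that layout; pack_entities is a hypothetical helper, not part of the commit:

import paddle

def pack_entities(starts, ends, labels, max_len=512):
    """Hypothetical helper producing the packed layout build_relation reads:
    row 0 carries the count n in every column, rows 1..n carry
    (start, end, label), everything else stays at the -1 padding value."""
    n = len(starts)
    packed = paddle.full([max_len, 3], fill_value=-1, dtype="int64")
    packed[0, :] = n                      # entities[b, 0, 0] == entities[b, 0, 2] == n
    packed[1:n + 1, 0] = paddle.to_tensor(starts, dtype="int64")
    packed[1:n + 1, 1] = paddle.to_tensor(ends, dtype="int64")
    packed[1:n + 1, 2] = paddle.to_tensor(labels, dtype="int64")
    return packed

entities_b = pack_entities(starts=[5, 20], ends=[9, 24], labels=[1, 2])
n = int(entities_b[0, 2])                 # count, as build_relation reads it
entity_labels = entities_b[1:n + 1, 2]    # matches entities[b, 1:entities[b, 0, 2] + 1, 2]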
@@ -1400,8 +1444,13 @@ def forward(self, hidden_states, entities, relations):
             logits = self.rel_classifier(heads, tails)
             loss += self.loss_fct(logits, relation_labels)
             pred_relations = self.get_predicted_relations(
-                logits, relations[b], entities[b])
-            all_pred_relations.append(pred_relations)
+                logits, relation, entities[b])
+            if len(pred_relations) > 0:
+                pred_relations = paddle.stack(pred_relations)
+                all_pred_relations[b, 0, :, :] = paddle.shape(
+                    pred_relations)[0].astype(all_pred_relations.dtype)
+                all_pred_relations[b, 1:len(pred_relations) +
+                                   1, :, :] = pred_relations
         return loss, all_pred_relations
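The records the old get_predicted_relations returned as dicts (head_id, head, head_type, tail_id, tail, tail_type, type) now live in fixed [7, 2] integer tensors, packed into all_pred_relations with the same count-in-row-0 convention. A sketch of reading one batch element back into the old dict form; decode_pred_relations is a hypothetical helper for illustration:

def decode_pred_relations(all_pred_relations_b):
    """Hypothetical decoder for one batch element of all_pred_relations
    (shape [max_length * max_length, 7, 2]); inverts the packing in forward()."""
    n = int(all_pred_relations_b[0, 0, 0])  # row 0 stores the record count
    relations = []
    for rel in all_pred_relations_b[1:n + 1]:
        relations.append({
            "head_id": int(rel[0, 0]),
            "head": (int(rel[1, 0]), int(rel[1, 1])),  # (start, end)
            "head_type": int(rel[2, 0]),
            "tail_id": int(rel[3, 0]),
            "tail": (int(rel[4, 0]), int(rel[4, 1])),
            "tail_type": int(rel[5, 0]),
            "type": int(rel[6, 0]),                    # always 1 here
        })
    return relations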


@@ -1464,14 +1513,14 @@ def forward(
         self,
         input_ids,
         bbox,
-        labels=None,
         image=None,
         attention_mask=None,
-        entities=None,
-        relations=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
+        entities=None,
+        relations=None,
+        labels=None,
     ):
         outputs = self.layoutxlm(
             input_ids=input_ids,
@@ -1482,23 +1531,20 @@ def forward(
             position_ids=position_ids,
             head_mask=head_mask,
         )

         seq_length = input_ids.shape[1]
         sequence_output, image_output = outputs[0][:, :seq_length], outputs[
             0][:, seq_length:]

         sequence_output = self.dropout(sequence_output)
         loss, pred_relations = self.extractor(sequence_output, entities,
                                               relations)

-        hidden_states = {
-            f"hidden_states_{idx}": outputs[2][f"{idx}_data"]
+        hidden_states = [
+            outputs[2][f"{idx}_data"]
             for idx in range(self.layoutxlm.config["num_hidden_layers"])
-        }
-        res = dict(
-            loss=loss,
-            entities=entities,
-            relations=relations,
-            pred_relations=pred_relations,
-        )
-        res.update(hidden_states)
+        ]
+        hidden_states = paddle.stack(hidden_states, axis=1)
+
+        res = dict(loss=loss,
+                   pred_relations=pred_relations,
+                   hidden_states=hidden_states)
         return res
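The returned dict changes as well: instead of one hidden_states_{idx} entry per layer plus pass-through entities/relations, the per-layer activations are stacked along axis 1 into a single tensor, keeping the program outputs a flat set of tensors for export. Recovering one layer from the new layout is a single index; a usage sketch, with variable names and the token-dimension size assumed:

res = model(input_ids, bbox, image=image, attention_mask=attention_mask,
            entities=entities, relations=relations)
hidden = res["hidden_states"]  # [batch, num_layers, num_tokens, hidden_size], assumed
layer_5 = hidden[:, 5]         # what res["hidden_states_5"] used to hold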
