diff --git a/paddlenlp/transformers/layoutxlm/modeling.py b/paddlenlp/transformers/layoutxlm/modeling.py index 3483ef575428..707e725a18c3 100644 --- a/paddlenlp/transformers/layoutxlm/modeling.py +++ b/paddlenlp/transformers/layoutxlm/modeling.py @@ -1310,36 +1310,75 @@ def __init__(self, hidden_size=768, hidden_dropout_prob=0.1): self.loss_fct = CrossEntropyLoss() def build_relation(self, relations, entities): - batch_size = len(relations) - new_relations = [] + batch_size, max_seq_len = paddle.shape(entities)[:2] + new_relations = paddle.full( + shape=[batch_size, max_seq_len * max_seq_len, 3], + fill_value=-1, + dtype=relations.dtype) for b in range(batch_size): - if len(entities[b]["start"]) <= 2: - entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]} - all_possible_relations = set([ - (i, j) for i in range(len(entities[b]["label"])) - for j in range(len(entities[b]["label"])) - if entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2 - ]) + if entities[b, 0, 0] <= 2: + entitie_new = paddle.full(shape=[512, 3], + fill_value=-1, + dtype=entities.dtype) + entitie_new[0, :] = 2 + entitie_new[1:3, 0] = 0 # start + entitie_new[1:3, 1] = 1 # end + entitie_new[1:3, 2] = 0 # label + entities[b] = entitie_new + entitie_label = entities[b, 1:entities[b, 0, 2] + 1, 2] + all_possible_relations1 = paddle.arange(0, + entities[b, 0, 2], + dtype=entities.dtype) + all_possible_relations1 = all_possible_relations1[entitie_label == + 1] + all_possible_relations2 = paddle.arange(0, + entities[b, 0, 2], + dtype=entities.dtype) + all_possible_relations2 = all_possible_relations2[entitie_label == + 2] + + all_possible_relations = paddle.stack(paddle.meshgrid( + all_possible_relations1, all_possible_relations2), + axis=2).reshape([-1, 2]) if len(all_possible_relations) == 0: - all_possible_relations = {(0, 1)} - if "head" in relations[b]: - positive_relations = set( - list(zip(relations[b]["head"], relations[b]["tail"]))) + all_possible_relations = paddle.full_like( + all_possible_relations, fill_value=-1, dtype=entities.dtype) + all_possible_relations[0, 0] = 0 + all_possible_relations[0, 1] = 1 + + relation_head = relations[b, 1:relations[b, 0, 0] + 1, 0] + relation_tail = relations[b, 1:relations[b, 0, 1] + 1, 1] + positive_relations = paddle.stack([relation_head, relation_tail], + axis=1) + + all_possible_relations_repeat = all_possible_relations.unsqueeze( + axis=1).tile([1, len(positive_relations), 1]) + positive_relations_repeat = positive_relations.unsqueeze( + axis=0).tile([len(all_possible_relations), 1, 1]) + mask = paddle.all( + all_possible_relations_repeat == positive_relations_repeat, + axis=2) + negative_mask = paddle.any(mask, axis=1) == False + negative_relations = all_possible_relations[negative_mask] + + positive_mask = paddle.any(mask, axis=0) == True + positive_relations = positive_relations[positive_mask] + if negative_mask.sum() > 0: + reordered_relations = paddle.concat( + [positive_relations, negative_relations]) else: - positive_relations = set() - negative_relations = all_possible_relations - positive_relations - positive_relations = set( - [i for i in positive_relations if i in all_possible_relations]) - reordered_relations = list(positive_relations) + list( - negative_relations) - relation_per_doc = { - "head": [i[0] for i in reordered_relations], - "tail": [i[1] for i in reordered_relations], - "label": [1] * len(positive_relations) + [0] * - (len(reordered_relations) - len(positive_relations)) - } - assert len(relation_per_doc["head"]) != 0 - new_relations.append(relation_per_doc) + reordered_relations = positive_relations + + relation_per_doc_label = paddle.zeros( + [len(reordered_relations), 1], dtype=reordered_relations.dtype) + relation_per_doc_label[:len(positive_relations)] = 1 + relation_per_doc = paddle.concat( + [reordered_relations, relation_per_doc_label], axis=1) + assert len(relation_per_doc[:, 0]) != 0 + new_relations[b, 0] = paddle.shape(relation_per_doc)[0].astype( + new_relations.dtype) + new_relations[b, 1:len(relation_per_doc) + 1] = relation_per_doc + # new_relations.append(relation_per_doc) return new_relations, entities def get_predicted_relations(self, logits, relations, entities): @@ -1347,34 +1386,39 @@ def get_predicted_relations(self, logits, relations, entities): for i, pred_label in enumerate(logits.argmax(-1)): if pred_label != 1: continue - rel = {} - rel["head_id"] = relations["head"][i] - rel["head"] = (entities["start"][rel["head_id"]], - entities["end"][rel["head_id"]]) - rel["head_type"] = entities["label"][rel["head_id"]] - - rel["tail_id"] = relations["tail"][i] - rel["tail"] = (entities["start"][rel["tail_id"]], - entities["end"][rel["tail_id"]]) - rel["tail_type"] = entities["label"][rel["tail_id"]] - rel["type"] = 1 + rel = paddle.full(shape=[7, 2], + fill_value=-1, + dtype=relations.dtype) + rel[0, 0] = relations[:, 0][i] + rel[1, 0] = entities[:, 0][relations[:, 0][i] + 1] + rel[1, 1] = entities[:, 1][relations[:, 0][i] + 1] + rel[2, 0] = entities[:, 2][relations[:, 0][i] + 1] + rel[3, 0] = relations[:, 1][i] + rel[4, 0] = entities[:, 0][relations[:, 1][i] + 1] + rel[4, 1] = entities[:, 1][relations[:, 1][i] + 1] + rel[5, 0] = entities[:, 2][relations[:, 1][i] + 1] + rel[6, 0] = 1 pred_relations.append(rel) return pred_relations def forward(self, hidden_states, entities, relations): - batch_size, max_n_words, context_dim = hidden_states.shape + batch_size, max_length, _ = paddle.shape(entities) relations, entities = self.build_relation(relations, entities) loss = 0 - all_pred_relations = [] + all_pred_relations = paddle.full( + shape=[batch_size, max_length * max_length, 7, 2], + fill_value=-1, + dtype=entities.dtype) for b in range(batch_size): - if "head" not in relations[b]: - continue - head_entities = paddle.to_tensor(relations[b]["head"]) - tail_entities = paddle.to_tensor(relations[b]["tail"]) - relation_labels = paddle.to_tensor(relations[b]["label"], - dtype='int64') - entities_start_index = paddle.to_tensor(entities[b]["start"]) - entities_labels = paddle.to_tensor(entities[b]["label"]) + relation = relations[b, 1:relations[b, 0, 0] + 1] + head_entities = relation[:, 0] + tail_entities = relation[:, 1] + relation_labels = relation[:, 2] + entities_start_index = paddle.to_tensor( + entities[b, 1:entities[b, 0, 0] + 1, 0]) + entities_labels = paddle.to_tensor(entities[b, + 1:entities[b, 0, 2] + 1, + 2]) head_index = entities_start_index[head_entities] head_label = entities_labels[head_entities] head_label_repr = self.entity_emb(head_label) @@ -1400,8 +1444,13 @@ def forward(self, hidden_states, entities, relations): logits = self.rel_classifier(heads, tails) loss += self.loss_fct(logits, relation_labels) pred_relations = self.get_predicted_relations( - logits, relations[b], entities[b]) - all_pred_relations.append(pred_relations) + logits, relation, entities[b]) + if len(pred_relations) > 0: + pred_relations = paddle.stack(pred_relations) + all_pred_relations[b, 0, :, :] = paddle.shape( + pred_relations)[0].astype(all_pred_relations.dtype) + all_pred_relations[b, 1:len(pred_relations) + + 1, :, :] = pred_relations return loss, all_pred_relations @@ -1464,14 +1513,14 @@ def forward( self, input_ids, bbox, - labels=None, image=None, attention_mask=None, + entities=None, + relations=None, token_type_ids=None, position_ids=None, head_mask=None, - entities=None, - relations=None, + labels=None, ): outputs = self.layoutxlm( input_ids=input_ids, @@ -1482,23 +1531,20 @@ def forward( position_ids=position_ids, head_mask=head_mask, ) - seq_length = input_ids.shape[1] sequence_output, image_output = outputs[0][:, :seq_length], outputs[ 0][:, seq_length:] + sequence_output = self.dropout(sequence_output) loss, pred_relations = self.extractor(sequence_output, entities, relations) - - hidden_states = { - f"hidden_states_{idx}": outputs[2][f"{idx}_data"] + hidden_states = [ + outputs[2][f"{idx}_data"] for idx in range(self.layoutxlm.config["num_hidden_layers"]) - } - res = dict( - loss=loss, - entities=entities, - relations=relations, - pred_relations=pred_relations, - ) - res.update(hidden_states) + ] + hidden_states = paddle.stack(hidden_states, axis=1) + + res = dict(loss=loss, + pred_relations=pred_relations, + hidden_states=hidden_states) return res