
Update debug mode for relation prompt (#3263)
* update debug mode for relation prompt

* update

* update
linjieccc authored Sep 14, 2022
1 parent 135e9fa commit 87613d4
Showing 6 changed files with 160 additions and 74 deletions.
32 changes: 19 additions & 13 deletions model_zoo/uie/README.md
@@ -640,7 +640,7 @@ python finetune.py \
--device gpu
```

Multi-GPU launch
If running in a GPU environment, multi-GPU training can be enabled by specifying the ``gpus`` parameter

```shell
python -u -m paddle.distributed.launch --gpus "0,1" finetune.py \
@@ -701,18 +701,24 @@ python evaluate.py \
Example of printed output:

```text
[2022-06-23 08:25:23,017] [ INFO] - -----------------------------
[2022-06-23 08:25:23,017] [ INFO] - Class name: 时间
[2022-06-23 08:25:23,018] [ INFO] - Evaluation precision: 1.00000 | recall: 1.00000 | F1: 1.00000
[2022-06-23 08:25:23,145] [ INFO] - -----------------------------
[2022-06-23 08:25:23,146] [ INFO] - Class name: 目的地
[2022-06-23 08:25:23,146] [ INFO] - Evaluation precision: 0.64286 | recall: 0.90000 | F1: 0.75000
[2022-06-23 08:25:23,272] [ INFO] - -----------------------------
[2022-06-23 08:25:23,273] [ INFO] - Class name: 费用
[2022-06-23 08:25:23,273] [ INFO] - Evaluation precision: 0.11111 | recall: 0.10000 | F1: 0.10526
[2022-06-23 08:25:23,399] [ INFO] - -----------------------------
[2022-06-23 08:25:23,399] [ INFO] - Class name: 出发地
[2022-06-23 08:25:23,400] [ INFO] - Evaluation precision: 1.00000 | recall: 1.00000 | F1: 1.00000
[2022-09-14 03:13:58,877] [ INFO] - -----------------------------
[2022-09-14 03:13:58,877] [ INFO] - Class Name: 疾病
[2022-09-14 03:13:58,877] [ INFO] - Evaluation Precision: 0.89744 | Recall: 0.83333 | F1: 0.86420
[2022-09-14 03:13:59,145] [ INFO] - -----------------------------
[2022-09-14 03:13:59,145] [ INFO] - Class Name: 手术治疗
[2022-09-14 03:13:59,145] [ INFO] - Evaluation Precision: 0.90000 | Recall: 0.85714 | F1: 0.87805
[2022-09-14 03:13:59,439] [ INFO] - -----------------------------
[2022-09-14 03:13:59,440] [ INFO] - Class Name: 检查
[2022-09-14 03:13:59,440] [ INFO] - Evaluation Precision: 0.77778 | Recall: 0.56757 | F1: 0.65625
[2022-09-14 03:13:59,708] [ INFO] - -----------------------------
[2022-09-14 03:13:59,709] [ INFO] - Class Name: X的手术治疗
[2022-09-14 03:13:59,709] [ INFO] - Evaluation Precision: 0.90000 | Recall: 0.85714 | F1: 0.87805
[2022-09-14 03:13:59,893] [ INFO] - -----------------------------
[2022-09-14 03:13:59,893] [ INFO] - Class Name: X的实验室检查
[2022-09-14 03:13:59,894] [ INFO] - Evaluation Precision: 0.71429 | Recall: 0.55556 | F1: 0.62500
[2022-09-14 03:14:00,057] [ INFO] - -----------------------------
[2022-09-14 03:14:00,058] [ INFO] - Class Name: X的影像学检查
[2022-09-14 03:14:00,058] [ INFO] - Evaluation Precision: 0.69231 | Recall: 0.45000 | F1: 0.54545
```
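
For reference, the precision/recall/F1 values above follow the usual span-level definitions. Below is a minimal sketch (not part of this commit, with illustrative counts rather than the real SpanEvaluator tallies):

```python
# Span-level metrics derived from correct / predicted / gold span counts.
# The counts 35 / 39 / 42 are hypothetical values chosen only because they
# happen to reproduce the first "疾病" line above; the real numbers come
# from SpanEvaluator during evaluation.

def span_prf(num_correct: int, num_infer: int, num_label: int):
    precision = num_correct / num_infer if num_infer else 0.0
    recall = num_correct / num_label if num_label else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1

print("Precision: %.5f | Recall: %.5f | F1: %.5f" % span_prf(35, 39, 42))
# Precision: 0.89744 | Recall: 0.83333 | F1: 0.86420
```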

Description of configurable parameters:
7 changes: 0 additions & 7 deletions model_zoo/uie/data_distill/README.md
@@ -146,13 +146,6 @@ python train.py \
'text': '登革热'}]}]
```

## Evaluation results

| Model | Entity-F1 | SPO-F1 |
| :---: | :--------: | :--------: |
| UIE-Finetune | 78.57 | 56.25 |
| GPLinker-ernie-3.0-mini-zh | 68.18 | 47.06 |
| GPLinker-ernie-3.0-mini-zh + UIE data distillation | 76.38 | 50.42 |

# References

2 changes: 1 addition & 1 deletion model_zoo/uie/data_distill/data_distill.py
@@ -85,7 +85,7 @@ def do_data_distill():
for text in tqdm(infer_texts, desc="Predicting: ", leave=False):
infer_results.extend(uie(text))

train_synthetic_lines = synthetic2distill(texts, infer_results,
train_synthetic_lines = synthetic2distill(infer_texts, infer_results,
args.task_type)

# Concat origin and synthetic data
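The one-line change above swaps `texts` for `infer_texts`: `synthetic2distill` presumably pairs each text with its prediction by position, so the list passed in must be the same one that was fed to the predictor. A toy sketch of that assumption (hypothetical, not the repository code):

```python
# Each entry in infer_results corresponds to the text at the same index of
# infer_texts; pairing against a different list (e.g. the full `texts`
# corpus) would attach predictions to the wrong sentences.
infer_texts = ["患者三天前发热", "血糖偏高"]
infer_results = [{"症状": [{"text": "发热"}]}, {"指标": [{"text": "血糖"}]}]

for text, result in zip(infer_texts, infer_results):
    print(text, "->", result)
```
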
44 changes: 33 additions & 11 deletions model_zoo/uie/evaluate.py
@@ -23,7 +23,7 @@
from paddlenlp.utils.log import logger

from model import UIE
from utils import convert_example, reader, unify_prompt_name
from utils import convert_example, reader, unify_prompt_name, get_relation_type_dict, create_data_loader


@paddle.no_grad()
@@ -60,28 +60,34 @@ def do_eval():
max_seq_len=args.max_seq_len,
lazy=False)
class_dict = {}
relation_data = []
if args.debug:
for data in test_ds:
class_name = unify_prompt_name(data['prompt'])
# Only positive examples are evaluated in debug mode
if len(data['result_list']) != 0:
class_dict.setdefault(class_name, []).append(data)
if "的" not in data['prompt']:
class_dict.setdefault(class_name, []).append(data)
else:
relation_data.append((data['prompt'], data))
relation_type_dict = get_relation_type_dict(relation_data)
else:
class_dict["all_classes"] = test_ds

trans_fn = partial(convert_example,
tokenizer=tokenizer,
max_seq_len=args.max_seq_len)

for key in class_dict.keys():
if args.debug:
test_ds = MapDataset(class_dict[key])
else:
test_ds = class_dict[key]
test_ds = test_ds.map(
partial(convert_example,
tokenizer=tokenizer,
max_seq_len=args.max_seq_len))
test_batch_sampler = paddle.io.BatchSampler(dataset=test_ds,
batch_size=args.batch_size,
shuffle=False)
test_data_loader = paddle.io.DataLoader(
dataset=test_ds, batch_sampler=test_batch_sampler, return_list=True)

test_data_loader = create_data_loader(test_ds,
mode="test",
batch_size=args.batch_size,
trans_fn=trans_fn)

metric = SpanEvaluator()
precision, recall, f1 = evaluate(model, metric, test_data_loader)
@@ -90,6 +96,22 @@ def do_eval():
logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" %
(precision, recall, f1))

if args.debug and len(relation_type_dict.keys()) != 0:
for key in relation_type_dict.keys():
test_ds = MapDataset(relation_type_dict[key])

test_data_loader = create_data_loader(test_ds,
mode="test",
batch_size=args.batch_size,
trans_fn=trans_fn)

metric = SpanEvaluator()
precision, recall, f1 = evaluate(model, metric, test_data_loader)
logger.info("-----------------------------")
logger.info("Class Name: X的%s" % key)
logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" %
(precision, recall, f1))


if __name__ == "__main__":
# yapf: disable
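In debug mode, evaluate.py now separates entity prompts from relation prompts: prompts without "的" are grouped by class name as before, while relation prompts such as "糖尿病的检查" are collected into `relation_data` and grouped by `get_relation_type_dict`, so each relation type is reported once as "X的<relation>". A simplified sketch of that grouping (not the repository implementation, which matches common suffixes between pairs of prompts):

```python
from collections import defaultdict

# Illustrative prompts: one entity prompt and three relation prompts.
prompts = ["疾病", "糖尿病的检查", "肺炎的检查", "糖尿病的手术治疗"]

entity_groups = defaultdict(list)
relation_groups = defaultdict(list)
for prompt in prompts:
    if "的" not in prompt:
        entity_groups[prompt].append(prompt)
    else:
        # group by the relation name after the last "的"
        relation_groups[prompt.rsplit("的", 1)[1]].append(prompt)

for relation, group in relation_groups.items():
    print("Class Name: X的%s (%d prompt(s))" % (relation, len(group)))
# Class Name: X的检查 (2 prompt(s))
# Class Name: X的手术治疗 (1 prompt(s))
```
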
37 changes: 13 additions & 24 deletions model_zoo/uie/finetune.py
@@ -26,7 +26,7 @@

from model import UIE
from evaluate import evaluate
from utils import set_seed, convert_example, reader, MODEL_MAP
from utils import set_seed, convert_example, reader, MODEL_MAP, create_data_loader


def do_train():
@@ -57,28 +57,18 @@ def do_train():
max_seq_len=args.max_seq_len,
lazy=False)

train_ds = train_ds.map(
partial(convert_example,
tokenizer=tokenizer,
max_seq_len=args.max_seq_len))
dev_ds = dev_ds.map(
partial(convert_example,
tokenizer=tokenizer,
max_seq_len=args.max_seq_len))

train_batch_sampler = paddle.io.BatchSampler(dataset=train_ds,
batch_size=args.batch_size,
shuffle=True)
train_data_loader = paddle.io.DataLoader(dataset=train_ds,
batch_sampler=train_batch_sampler,
return_list=True)

dev_batch_sampler = paddle.io.BatchSampler(dataset=dev_ds,
batch_size=args.batch_size,
shuffle=False)
dev_data_loader = paddle.io.DataLoader(dataset=dev_ds,
batch_sampler=dev_batch_sampler,
return_list=True)
trans_fn = partial(convert_example,
tokenizer=tokenizer,
max_seq_len=args.max_seq_len)

train_data_loader = create_data_loader(train_ds,
mode="train",
batch_size=args.batch_size,
trans_fn=trans_fn)
dev_data_loader = create_data_loader(dev_ds,
mode="dev",
batch_size=args.batch_size,
trans_fn=trans_fn)

if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
state_dict = paddle.load(args.init_from_ckpt)
@@ -95,7 +85,6 @@ def do_train():

loss_list = []
global_step = 0
best_step = 0
best_f1 = 0
tic_train = time.time()
for epoch in range(1, args.num_epochs + 1):
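The dataloader boilerplate removed above now lives in `utils.create_data_loader`, which uses a `DistributedBatchSampler` for training so that each process launched via `paddle.distributed.launch` iterates over its own shard of the data, while evaluation keeps a plain sequential sampler. A toy illustration of the sharding idea (not Paddle's implementation):

```python
# Strided sharding, similar in spirit to what a distributed batch sampler does:
# with --gpus "0,1", rank 0 and rank 1 each iterate over disjoint indices.
def shard_indices(num_samples: int, num_ranks: int, rank: int):
    return list(range(rank, num_samples, num_ranks))

print(shard_indices(10, 2, 0))  # [0, 2, 4, 6, 8]
print(shard_indices(10, 2, 1))  # [1, 3, 5, 7, 9]
```
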
112 changes: 94 additions & 18 deletions model_zoo/uie/utils.py
@@ -118,6 +118,35 @@ def set_seed(seed):
np.random.seed(seed)


def create_data_loader(dataset, mode="train", batch_size=1, trans_fn=None):
"""
Create dataloader.
Args:
dataset(obj:`paddle.io.Dataset`): Dataset instance.
mode(obj:`str`, optional, defaults to obj:`train`): If mode is 'train', it will shuffle the dataset randomly.
batch_size(obj:`int`, optional, defaults to 1): The sample number of a mini-batch.
trans_fn(obj:`callable`, optional, defaults to `None`): function to convert a data sample to input ids, etc.
Returns:
dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches.
"""
if trans_fn:
dataset = dataset.map(trans_fn)

shuffle = True if mode == 'train' else False
if mode == "train":
sampler = paddle.io.DistributedBatchSampler(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle)
else:
sampler = paddle.io.BatchSampler(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle)
dataloader = paddle.io.DataLoader(dataset,
batch_sampler=sampler,
return_list=True)
return dataloader


def convert_example(example, tokenizer, max_seq_len):
"""
example: {
@@ -267,6 +296,48 @@ def unify_prompt_name(prompt):
return prompt


def get_relation_type_dict(relation_data):

def compare(a, b):
a = a[::-1]
b = b[::-1]
res = ''
for i in range(min(len(a), len(b))):
if a[i] == b[i]:
res += a[i]
else:
break
if res == "":
return res
elif res[::-1][0] == "的":
return res[::-1][1:]
return ""

relation_type_dict = {}
added_list = []
for i in range(len(relation_data)):
added = False
if relation_data[i][0] not in added_list:
for j in range(i + 1, len(relation_data)):
match = compare(relation_data[i][0], relation_data[j][0])
if match != "":
match = unify_prompt_name(match)
if relation_data[i][0] not in added_list:
added_list.append(relation_data[i][0])
relation_type_dict.setdefault(match, []).append(
relation_data[i][1])
added_list.append(relation_data[j][0])
relation_type_dict.setdefault(match, []).append(
relation_data[j][1])
added = True
if not added:
added_list.append(relation_data[i][0])
suffix = relation_data[i][0].rsplit("的", 1)[1]
suffix = unify_prompt_name(suffix)
relation_type_dict[suffix] = relation_data[i][1]
return relation_type_dict


def add_entity_negative_example(examples, texts, prompts, label_set,
negative_ratio):
negative_examples = []
@@ -610,26 +681,31 @@ def _sep_cls_label(label, separator):
redundants1 = inverse_relation_list[i]

# 2. entity_name_set ^ subject_goldens[i]
nonentity_list = list(
set(entity_name_set) ^ set(subject_goldens[i]))
nonentity_list.sort()

redundants2 = [
nonentity + "的" + predicate_list[i][random.randrange(
len(predicate_list[i]))]
for nonentity in nonentity_list
]
redundants2 = []
if len(predicate_list[i]) != 0:
nonentity_list = list(
set(entity_name_set) ^ set(subject_goldens[i]))
nonentity_list.sort()

redundants2 = [
nonentity + "的" +
predicate_list[i][random.randrange(
len(predicate_list[i]))]
for nonentity in nonentity_list
]

# 3. entity_label_set ^ entity_prompts[i]
non_ent_label_list = list(
set(entity_label_set) ^ set(entity_prompts[i]))
non_ent_label_list.sort()

redundants3 = [
subject_goldens[i][random.randrange(
len(subject_goldens[i]))] + "的" + non_ent_label
for non_ent_label in non_ent_label_list
]
redundants3 = []
if len(subject_goldens[i]) != 0:
non_ent_label_list = list(
set(entity_label_set) ^ set(entity_prompts[i]))
non_ent_label_list.sort()

redundants3 = [
subject_goldens[i][random.randrange(
len(subject_goldens[i]))] + "的" + non_ent_label
for non_ent_label in non_ent_label_list
]

redundants_list = [redundants1, redundants2, redundants3]

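
The last hunk guards the negative-example construction: `random.randrange(len(...))` raises a `ValueError` on an empty list, so `redundants2`/`redundants3` are now only built when `predicate_list[i]` / `subject_goldens[i]` are non-empty. A small standalone check of that failure mode (illustrative, not repository code):

```python
import random

predicate_list = []  # e.g. an example with no relation prompts at all

try:
    predicate_list[random.randrange(len(predicate_list))]
except ValueError as err:
    print("unguarded sampling fails:", err)  # ValueError: empty range

redundants = []
if len(predicate_list) != 0:  # the guard added in this commit
    redundants = ["某实体的" + random.choice(predicate_list)]
print(redundants)  # [] -- construction is safely skipped
```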
