Update debug mode for relation prompt #3263

Merged · 3 commits · Sep 14, 2022
32 changes: 19 additions & 13 deletions model_zoo/uie/README.md
@@ -640,7 +640,7 @@ python finetune.py \
--device gpu
```

Multi-GPU launch
If running in a GPU environment, you can specify the ``gpus`` parameter for multi-GPU training:

```shell
python -u -m paddle.distributed.launch --gpus "0,1" finetune.py \
@@ -701,18 +701,24 @@ python evaluate.py \
Example output:

```text
[2022-06-23 08:25:23,017] [ INFO] - -----------------------------
[2022-06-23 08:25:23,017] [ INFO] - Class name: 时间
[2022-06-23 08:25:23,018] [ INFO] - Evaluation precision: 1.00000 | recall: 1.00000 | F1: 1.00000
[2022-06-23 08:25:23,145] [ INFO] - -----------------------------
[2022-06-23 08:25:23,146] [ INFO] - Class name: 目的地
[2022-06-23 08:25:23,146] [ INFO] - Evaluation precision: 0.64286 | recall: 0.90000 | F1: 0.75000
[2022-06-23 08:25:23,272] [ INFO] - -----------------------------
[2022-06-23 08:25:23,273] [ INFO] - Class name: 费用
[2022-06-23 08:25:23,273] [ INFO] - Evaluation precision: 0.11111 | recall: 0.10000 | F1: 0.10526
[2022-06-23 08:25:23,399] [ INFO] - -----------------------------
[2022-06-23 08:25:23,399] [ INFO] - Class name: 出发地
[2022-06-23 08:25:23,400] [ INFO] - Evaluation precision: 1.00000 | recall: 1.00000 | F1: 1.00000
[2022-09-14 03:13:58,877] [ INFO] - -----------------------------
[2022-09-14 03:13:58,877] [ INFO] - Class Name: 疾病
[2022-09-14 03:13:58,877] [ INFO] - Evaluation Precision: 0.89744 | Recall: 0.83333 | F1: 0.86420
[2022-09-14 03:13:59,145] [ INFO] - -----------------------------
[2022-09-14 03:13:59,145] [ INFO] - Class Name: 手术治疗
[2022-09-14 03:13:59,145] [ INFO] - Evaluation Precision: 0.90000 | Recall: 0.85714 | F1: 0.87805
[2022-09-14 03:13:59,439] [ INFO] - -----------------------------
[2022-09-14 03:13:59,440] [ INFO] - Class Name: 检查
[2022-09-14 03:13:59,440] [ INFO] - Evaluation Precision: 0.77778 | Recall: 0.56757 | F1: 0.65625
[2022-09-14 03:13:59,708] [ INFO] - -----------------------------
[2022-09-14 03:13:59,709] [ INFO] - Class Name: X的手术治疗
[2022-09-14 03:13:59,709] [ INFO] - Evaluation Precision: 0.90000 | Recall: 0.85714 | F1: 0.87805
[2022-09-14 03:13:59,893] [ INFO] - -----------------------------
[2022-09-14 03:13:59,893] [ INFO] - Class Name: X的实验室检查
[2022-09-14 03:13:59,894] [ INFO] - Evaluation Precision: 0.71429 | Recall: 0.55556 | F1: 0.62500
[2022-09-14 03:14:00,057] [ INFO] - -----------------------------
[2022-09-14 03:14:00,058] [ INFO] - Class Name: X的影像学检查
[2022-09-14 03:14:00,058] [ INFO] - Evaluation Precision: 0.69231 | Recall: 0.45000 | F1: 0.54545
```
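
In the updated log above, relation prompts are no longer reported once per subject entity; they are grouped under a unified class name of the form "X的<relation>" (e.g. "X的实验室检查"). The sketch below is a simplified, hypothetical illustration of that normalization; the actual grouping is done pairwise by `get_relation_type_dict` in `utils.py` (shown further down), so relation names that themselves contain "的" are still grouped correctly.

```python
# Simplified illustration (not the evaluate.py implementation): map a relation
# prompt "<subject>的<relation>" to the unified debug-mode class name "X的<relation>".
def unify_relation_name(prompt):
    if "的" in prompt:
        relation = prompt.rsplit("的", 1)[1]  # keep only the relation suffix
        return "X的" + relation
    return prompt  # entity prompts are reported under their own name


if __name__ == "__main__":
    for p in ["糖尿病的实验室检查", "胆囊炎的影像学检查", "疾病"]:
        print(p, "->", unify_relation_name(p))
```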

Description of configurable parameters:
7 changes: 0 additions & 7 deletions model_zoo/uie/data_distill/README.md
@@ -146,13 +146,6 @@ python train.py \
'text': '登革热'}]}]
```

## Performance validation

| Model | Entity-F1 | SPO-F1 |
| :---: | :--------: | :--------: |
| UIE-Finetune | 78.57 | 56.25 |
| GPLinker-ernie-3.0-mini-zh | 68.18 | 47.06 |
| GPLinker-ernie-3.0-mini-zh + UIE data distillation | 76.38 | 50.42 |

# References

2 changes: 1 addition & 1 deletion model_zoo/uie/data_distill/data_distill.py
@@ -85,7 +85,7 @@ def do_data_distill():
for text in tqdm(infer_texts, desc="Predicting: ", leave=False):
infer_results.extend(uie(text))

train_synthetic_lines = synthetic2distill(texts, infer_results,
train_synthetic_lines = synthetic2distill(infer_texts, infer_results,
args.task_type)

# Concat origin and synthetic data
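
The one-line change above passes `infer_texts` (the texts that were actually run through the UIE predictor) to `synthetic2distill` instead of the unrelated `texts` list. Since the synthetic examples appear to be built by walking the text list and `infer_results` in parallel, passing the wrong list would pair predictions with the wrong sentences. Below is a self-contained illustration with stub data and a hypothetical `pair_predictions` helper, not the real `synthetic2distill`:

```python
# Stub with the positional pairing behaviour assumed of synthetic2distill.
def pair_predictions(texts, results):
    return [{"text": t, "prediction": r} for t, r in zip(texts, results)]


texts = ["sentence A", "sentence B", "sentence C"]  # full corpus read earlier
infer_texts = ["sentence B", "sentence C"]          # texts actually sent to the predictor
infer_results = [{"label": 1}, {"label": 0}]        # one result per entry of infer_texts

# Passing `texts` would pair "sentence A" with the result produced for "sentence B";
# passing `infer_texts` keeps every prediction attached to its own text.
print(pair_predictions(infer_texts, infer_results))
```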
44 changes: 33 additions & 11 deletions model_zoo/uie/evaluate.py
@@ -23,7 +23,7 @@
from paddlenlp.utils.log import logger

from model import UIE
from utils import convert_example, reader, unify_prompt_name
from utils import convert_example, reader, unify_prompt_name, get_relation_type_dict, create_data_loader


@paddle.no_grad()
@@ -60,28 +60,34 @@ def do_eval():
max_seq_len=args.max_seq_len,
lazy=False)
class_dict = {}
relation_data = []
if args.debug:
for data in test_ds:
class_name = unify_prompt_name(data['prompt'])
# Only positive examples are evaluated in debug mode
if len(data['result_list']) != 0:
class_dict.setdefault(class_name, []).append(data)
if "的" not in data['prompt']:
class_dict.setdefault(class_name, []).append(data)
else:
relation_data.append((data['prompt'], data))
relation_type_dict = get_relation_type_dict(relation_data)
else:
class_dict["all_classes"] = test_ds

trans_fn = partial(convert_example,
tokenizer=tokenizer,
max_seq_len=args.max_seq_len)

for key in class_dict.keys():
if args.debug:
test_ds = MapDataset(class_dict[key])
else:
test_ds = class_dict[key]
test_ds = test_ds.map(
partial(convert_example,
tokenizer=tokenizer,
max_seq_len=args.max_seq_len))
test_batch_sampler = paddle.io.BatchSampler(dataset=test_ds,
batch_size=args.batch_size,
shuffle=False)
test_data_loader = paddle.io.DataLoader(
dataset=test_ds, batch_sampler=test_batch_sampler, return_list=True)

test_data_loader = create_data_loader(test_ds,
mode="test",
batch_size=args.batch_size,
trans_fn=trans_fn)

metric = SpanEvaluator()
precision, recall, f1 = evaluate(model, metric, test_data_loader)
@@ -90,6 +96,22 @@ def do_eval():
logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" %
(precision, recall, f1))

if args.debug and len(relation_type_dict.keys()) != 0:
for key in relation_type_dict.keys():
test_ds = MapDataset(relation_type_dict[key])

test_data_loader = create_data_loader(test_ds,
mode="test",
batch_size=args.batch_size,
trans_fn=trans_fn)

metric = SpanEvaluator()
precision, recall, f1 = evaluate(model, metric, test_data_loader)
logger.info("-----------------------------")
logger.info("Class Name: X的%s" % key)
logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" %
(precision, recall, f1))


if __name__ == "__main__":
# yapf: disable
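
For reference, a hypothetical sketch of the data shapes the new debug branch works with: prompts without "的" stay in `class_dict` keyed by class name, while relation prompts are collected as `(prompt, example)` pairs and grouped afterwards by `get_relation_type_dict` (defined in `utils.py` below). The field values here are made up.

```python
# Hypothetical examples; only the keys used by the debug branch are shown.
relation_data = [
    ("糖尿病的实验室检查", {"prompt": "糖尿病的实验室检查", "result_list": [{"text": "血常规"}]}),
    ("胆囊炎的实验室检查", {"prompt": "胆囊炎的实验室检查", "result_list": [{"text": "腹部B超"}]}),
]
# get_relation_type_dict(relation_data) should group both examples under the
# shared relation name, e.g. {"实验室检查": [example_1, example_2]}, which the
# loop above then reports as "Class Name: X的实验室检查".
```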
37 changes: 13 additions & 24 deletions model_zoo/uie/finetune.py
@@ -26,7 +26,7 @@

from model import UIE
from evaluate import evaluate
from utils import set_seed, convert_example, reader, MODEL_MAP
from utils import set_seed, convert_example, reader, MODEL_MAP, create_data_loader


def do_train():
@@ -57,28 +57,18 @@ def do_train():
max_seq_len=args.max_seq_len,
lazy=False)

train_ds = train_ds.map(
partial(convert_example,
tokenizer=tokenizer,
max_seq_len=args.max_seq_len))
dev_ds = dev_ds.map(
partial(convert_example,
tokenizer=tokenizer,
max_seq_len=args.max_seq_len))

train_batch_sampler = paddle.io.BatchSampler(dataset=train_ds,
batch_size=args.batch_size,
shuffle=True)
train_data_loader = paddle.io.DataLoader(dataset=train_ds,
batch_sampler=train_batch_sampler,
return_list=True)

dev_batch_sampler = paddle.io.BatchSampler(dataset=dev_ds,
batch_size=args.batch_size,
shuffle=False)
dev_data_loader = paddle.io.DataLoader(dataset=dev_ds,
batch_sampler=dev_batch_sampler,
return_list=True)
trans_fn = partial(convert_example,
tokenizer=tokenizer,
max_seq_len=args.max_seq_len)

train_data_loader = create_data_loader(train_ds,
mode="train",
batch_size=args.batch_size,
trans_fn=trans_fn)
dev_data_loader = create_data_loader(dev_ds,
mode="dev",
batch_size=args.batch_size,
trans_fn=trans_fn)

if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
state_dict = paddle.load(args.init_from_ckpt)
@@ -95,7 +85,6 @@ def do_train():

loss_list = []
global_step = 0
best_step = 0
best_f1 = 0
tic_train = time.time()
for epoch in range(1, args.num_epochs + 1):
112 changes: 94 additions & 18 deletions model_zoo/uie/utils.py
@@ -118,6 +118,35 @@ def set_seed(seed):
np.random.seed(seed)


def create_data_loader(dataset, mode="train", batch_size=1, trans_fn=None):
"""
Create dataloader.
Args:
dataset(obj:`paddle.io.Dataset`): Dataset instance.
mode(obj:`str`, optional, defaults to obj:`train`): If mode is 'train', it will shuffle the dataset randomly.
batch_size(obj:`int`, optional, defaults to 1): The sample number of a mini-batch.
trans_fn(obj:`callable`, optional, defaults to `None`): function to convert a data sample to input ids, etc.
Returns:
dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches.
"""
if trans_fn:
dataset = dataset.map(trans_fn)

shuffle = True if mode == 'train' else False
if mode == "train":
sampler = paddle.io.DistributedBatchSampler(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle)
else:
sampler = paddle.io.BatchSampler(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle)
dataloader = paddle.io.DataLoader(dataset,
batch_sampler=sampler,
return_list=True)
return dataloader


def convert_example(example, tokenizer, max_seq_len):
"""
example: {
@@ -267,6 +296,48 @@ def unify_prompt_name(prompt):
return prompt


def get_relation_type_dict(relation_data):

def compare(a, b):
a = a[::-1]
b = b[::-1]
res = ''
for i in range(min(len(a), len(b))):
if a[i] == b[i]:
res += a[i]
else:
break
if res == "":
return res
elif res[::-1][0] == "的":
return res[::-1][1:]
return ""

relation_type_dict = {}
added_list = []
for i in range(len(relation_data)):
added = False
if relation_data[i][0] not in added_list:
for j in range(i + 1, len(relation_data)):
match = compare(relation_data[i][0], relation_data[j][0])
if match != "":
match = unify_prompt_name(match)
if relation_data[i][0] not in added_list:
added_list.append(relation_data[i][0])
relation_type_dict.setdefault(match, []).append(
relation_data[i][1])
added_list.append(relation_data[j][0])
relation_type_dict.setdefault(match, []).append(
relation_data[j][1])
added = True
if not added:
added_list.append(relation_data[i][0])
suffix = relation_data[i][0].rsplit("的", 1)[1]
suffix = unify_prompt_name(suffix)
relation_type_dict[suffix] = relation_data[i][1]
return relation_type_dict


def add_entity_negative_example(examples, texts, prompts, label_set,
negative_ratio):
negative_examples = []
@@ -610,26 +681,31 @@ def _sep_cls_label(label, separator):
redundants1 = inverse_relation_list[i]

# 2. entity_name_set ^ subject_goldens[i]
nonentity_list = list(
set(entity_name_set) ^ set(subject_goldens[i]))
nonentity_list.sort()

redundants2 = [
nonentity + "的" + predicate_list[i][random.randrange(
len(predicate_list[i]))]
for nonentity in nonentity_list
]
redundants2 = []
if len(predicate_list[i]) != 0:
nonentity_list = list(
set(entity_name_set) ^ set(subject_goldens[i]))
nonentity_list.sort()

redundants2 = [
nonentity + "的" +
predicate_list[i][random.randrange(
len(predicate_list[i]))]
for nonentity in nonentity_list
]

# 3. entity_label_set ^ entity_prompts[i]
non_ent_label_list = list(
set(entity_label_set) ^ set(entity_prompts[i]))
non_ent_label_list.sort()

redundants3 = [
subject_goldens[i][random.randrange(
len(subject_goldens[i]))] + "的" + non_ent_label
for non_ent_label in non_ent_label_list
]
redundants3 = []
if len(subject_goldens[i]) != 0:
non_ent_label_list = list(
set(entity_label_set) ^ set(entity_prompts[i]))
non_ent_label_list.sort()

redundants3 = [
subject_goldens[i][random.randrange(
len(subject_goldens[i]))] + "的" + non_ent_label
for non_ent_label in non_ent_label_list
]

redundants_list = [redundants1, redundants2, redundants3]

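
To make the pairwise matching in `get_relation_type_dict` concrete: two relation prompts are grouped only when they share a full "的<relation>" suffix. The following is a simplified, standalone re-implementation of that check for illustration; it is not imported from `utils.py`.

```python
def common_relation_suffix(a, b):
    # Grow the shared suffix of the two prompts from the right-hand side.
    res = ""
    for x, y in zip(reversed(a), reversed(b)):
        if x != y:
            break
        res = x + res
    # Only a suffix that starts with "的" identifies a shared relation name.
    return res[1:] if res.startswith("的") else ""


if __name__ == "__main__":
    print(common_relation_suffix("糖尿病的实验室检查", "胆囊炎的实验室检查"))  # -> 实验室检查
    print(common_relation_suffix("糖尿病的并发症", "胆囊炎的实验室检查"))      # -> "" (no shared relation)
```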