summarizer_module.py
""" This module contains the summarizer functions using the bart-large-cnn and flan-T5-xl models.
These functions are used in generate_parametric_persona_blog_posts.py
"""
import torch


def bart_summarizer(model, tokenizer, article, max_blog_length, max_summary_length):
"""
Returns the bart-large-cnn model's summary of the article provided. Used in generate_blogs()
Args:
model: An instance of the pretrained BartForConditionalGeneration with the appropriate model id.
tokenizer: An instance of the pretrained BartTokenizer with the appropriate model id
article: The blog post to be summarized
max_blog_length: The maximum blog post length
max_summary_length: The maximum summary length
Returns:
The blog post summary.
"""
    # For now, this is meant for ONE SINGLE ARTICLE (but works for multiple too!)
    inputs = tokenizer(
        [article],
        max_length=max_blog_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    # Move the input ids to the same device as the model before generating.
    input_ids = inputs["input_ids"].to(model.device)
    summary_ids = model.generate(input_ids, num_beams=4, max_length=max_summary_length, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return str(summary)
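

# A minimal usage sketch (an addition, not part of the original module), assuming
# the facebook/bart-large-cnn checkpoint. The article text and the length limits
# below are illustrative placeholders; 1024 tokens is bart-large-cnn's maximum
# input length.
def _example_bart_usage():
    from transformers import BartForConditionalGeneration, BartTokenizer

    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    article = "Some long blog post text ..."
    return bart_summarizer(model, tokenizer, article, max_blog_length=1024, max_summary_length=150)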


def flan_summarizer(model, tokenizer, article, max_blog_length, max_summary_length):
    """
    Returns the flan-t5-xl model's summary of the article provided. Used in generate_blogs().
    Args:
        model: An instance of the pretrained AutoModelForSeq2SeqLM with the appropriate model id.
        tokenizer: An instance of the pretrained AutoTokenizer with the appropriate model id.
        article: The blog post to be summarized.
        max_blog_length: The maximum blog post length (in tokens).
        max_summary_length: The maximum summary length (in tokens).
    Returns:
        The blog post summary.
    """
    # For now, this is meant for ONE SINGLE ARTICLE (but works for multiple too!)
    inputs = [article]
    prompt_string = "Produce an article summary of the following blog post:"
    # Prepend the instruction prompt to each article before tokenizing.
    inputs = [f"{prompt_string.strip()} {i.strip()}" for i in inputs]
    input_tokens = tokenizer(
        inputs,
        max_length=max_blog_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    # Move every tensor to the model's device (rather than a hard-coded
    # "cuda:0", which would fail on CPU-only machines).
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to(model.device)
    outputs = model.generate(
        **input_tokens,
        use_cache=True,
        num_beams=5,
        min_length=5,
        max_new_tokens=max_summary_length,
        no_repeat_ngram_size=3,
    )
    summary = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return str(summary[0])
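

# A minimal usage sketch (an addition, not part of the original module), assuming
# the google/flan-t5-xl checkpoint. The article text and the length limits below
# are illustrative placeholders; flan_summarizer moves the inputs to the model's
# device, so this runs on CPU as well as GPU (though flan-t5-xl is large and is
# far faster on a GPU).
def _example_flan_usage():
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl")
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
    article = "Some long blog post text ..."
    return flan_summarizer(model, tokenizer, article, max_blog_length=512, max_summary_length=150)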