""" This is the main module of the project. It contains the functions for generating
the blog posts from the parametric medical personas (specified persona_variables.py),
using the novel cascaded Generator-Summarizer architecture.
It also contains functions evaluating the generated blog posts using the novel Original Correctness metric (from metric.py).
NOTE: This version of the blog generation script's main purpose is the evaluation of the blog posts
using the Original Correctness metric. Please refer to the old_blog_generation_code/ directory
for versions of this module which allow users to customize blog generation parameters like
blog_length, condition, time_period, prompt_style and so on.
"""
import argparse
import os

import torch
from huggingface_hub import login
from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers import BartForConditionalGeneration, BartTokenizer
from transformers import T5Tokenizer, T5ForConditionalGeneration

from llama_summarizer import llama_summarizer
from mmr_summarizer import mmr_summarizer
from summarizer_module import bart_summarizer, flan_summarizer
from plausibility_metric import evaluate_blog_post_progression
from persona_variables import PERSONAS

starting_prompt = "Write a blog post"
activities = ["birthday", "gym day", "day out with pets", "football game", "beach day",
              "picnic in the park", "ski trip", "dog walking adventure in New York"]
months = ["January", "February", "March", "April", "May", "June",
          "July", "August", "September", "October", "November", "December"]

# HELPER FUNCTIONS

def get_fact_prompt(universal_fact_list: list[str], condition: str = "asthma", prompt_type: str = "consistent") -> str:
    """
    Converts the list of universal facts into the part of the prompt that will
    eventually be fed to the Llama model. Used in generate_blogs() while creating
    the overall prompt to the system.

    Args:
        universal_fact_list: A list of strings unique to each persona, containing detailed facts
            pertaining to their condition, medication, physical activity restrictions, allergens, and so on.
        condition: The medical condition we want to generate blog posts for. Defaults to "asthma".
        prompt_type: The instruction to the model: either be "consistent" with the universal facts
            or keep track of the "changes" to the facts over time. Defaults to "consistent".

    Returns:
        The "fact_prompt", a portion of the prompt reminding the model of the universal facts
        regarding the persona's medical condition.
    """
    if prompt_type == "consistent":
        fp = "Be consistent with the following facts: "
    elif prompt_type == "changes":
        fp = "Write about how the following facts change over time: "
    else:
        # Guard against silently building a prompt with an undefined prefix.
        raise ValueError(f"Unknown prompt_type: {prompt_type!r}")
    for i, fact in enumerate(universal_fact_list, 1):
        fp = fp + str(i) + ". " + fact + " "
    return fp
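
# A minimal sketch of the expected output (the fact below is a hypothetical example,
# not one of the real persona facts from persona_variables.py):
#   >>> get_fact_prompt(["Uses an albuterol inhaler daily."], prompt_type="consistent")
#   'Be consistent with the following facts: 1. Uses an albuterol inhaler daily. '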

def clean_output(prompt: str, text: str) -> str:
    """
    Removes the prompt from Llama2's generated text, so that it only contains the blog post.
    Used in generate_blogs() to clean the LLM's output before it is appended to the list
    of blog posts.

    Args:
        prompt: The complete prompt fed to the Llama2 model prior to blog generation.
        text: The complete text generated by the Llama2 model, which includes the prompt.

    Returns:
        The generated text without the prompt.
    """
    reduced = text.split(maxsplit=1)[1]  # Remove the leading <s> token.
    blog = reduced.removeprefix(prompt)
    return blog.strip()
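
# A minimal sketch of the cleaning step (the strings are illustrative):
#   >>> clean_output("Write a post.", "<s> Write a post. Hello world.")
#   'Hello world.'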

def get_summary_prompt(subset: list[str]) -> str:
    """
    Returns the portion of the prompt eventually fed to Llama2 that demarcates
    the summaries of the previous blog posts (subset).
    Used in get_previous_n_summaries() to generate part of the prompt.

    Args:
        subset: A list of the previous few summaries to be included in the prompt.

    Returns:
        The "summary_prompt", a portion of the prompt containing the past few blog post summaries.
    """
    prompt = "This is a summary of all your previous blog posts:"
    for summary in subset:
        if summary != "":
            prompt = prompt + " " + summary
    return prompt

def get_previous_n_summaries(summaries: list[str], turn: int, n: int) -> str:
    """
    Returns the previous n (past-look-over parameter) summaries from the current turn.
    Used in generate_blogs() to create the summary prompt. When fewer than n summaries
    exist, the subset is left-padded with empty strings, which get_summary_prompt() skips.

    Args:
        summaries: A list of all the previous blog post summaries. summaries[0] is always
            the empty string, since the first turn has no predecessor.
        turn: The index of the current turn.
        n: The number of previous summaries to return (essentially the past-look-over).

    Returns:
        The "summary_prompt" portion of the overall Llama2 prompt.
    """
    if turn < n:
        subset = [""] * (n - turn) + summaries[1:turn + 1]
    else:
        subset = summaries[turn - n + 1:turn + 1]
    return get_summary_prompt(subset)
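
# A worked example of the left-padding, assuming past_look_over n=3 at turn 2
# (recall that summaries[0] is always ""):
#   >>> get_previous_n_summaries(["", "S1", "S2"], turn=2, n=3)
#   'This is a summary of all your previous blog posts: S1 S2'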

# GENERATOR FUNCTIONS

def llama_generate(prompt, model, tokenizer, temperature=0.8, max_blog_length=300):
    """
    Prompts the Llama2 model to generate the blog post. Used in generate_blogs().

    Args:
        prompt: The complete prompt (including system_prompt, fact_prompt, summary_prompt, and task_prompt).
        model: The pretrained AutoModelForCausalLM instance with the appropriate id.
        tokenizer: The pretrained AutoTokenizer instance with the appropriate id.
        temperature (float, optional): The temperature parameter for Llama2. Defaults to 0.8.
        max_blog_length (int, optional): The max number of tokens in the generated blog post. Defaults to 300.

    Returns:
        The text generated by the Llama2 model.
    """
    DEV = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(DEV)
    generate_kwargs = dict(
        input_ids=inputs,
        do_sample=True,  # Required for temperature/top_p/top_k to take effect.
        temperature=temperature,
        top_p=1.0,
        top_k=40,
        max_new_tokens=max_blog_length,
        repetition_penalty=1.3,
    )
    outputs = model.generate(**generate_kwargs)
    return str(tokenizer.decode(outputs[0]))
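
# Note: the decoded output begins with the <s> token followed by the full prompt,
# which is why the generation is post-processed before being stored (see
# clean_output() above and the "Blog Post:" split in generate_blogs()).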

# SUMMARIZER FUNCTIONS

def get_summary(summarizer_type, model, tokenizer, article, max_blog_length, max_summary_length, style="zero-shot"):
    """
    An abstraction over the summarizer functions, allowing the use of any of the four
    summarizer models via a keyword argument.

    Args:
        summarizer_type: One of the four summarizer models ("bart"/"flant5"/"llama"/"mmr").
        model: The huggingface summarizer model instance.
        tokenizer: The huggingface summarizer tokenizer instance.
        article: The blog post to be summarized.
        max_blog_length: The maximum blog post length.
        max_summary_length: The maximum summary length.
        style (str, optional): The prompting style ("zero-shot"/"one-shot"/"two-shot"),
            exclusively for the Llama2-based summarizer. Defaults to "zero-shot".

    Returns:
        The blog post summary.
    """
    if summarizer_type == "bart":
        return bart_summarizer(model, tokenizer, article, max_blog_length, max_summary_length)
    elif summarizer_type == "flant5":
        return flan_summarizer(model, tokenizer, article, max_blog_length, max_summary_length)
    elif summarizer_type == "llama":
        return llama_summarizer(model, tokenizer, article, 0.8, 300, style)
    elif summarizer_type == "mmr":
        # MMR is extractive, so it needs no neural model or tokenizer.
        return mmr_summarizer(article, 30, 0.5)
    raise ValueError(f"Unknown summarizer_type: {summarizer_type!r}")
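
# A minimal sketch of the dispatch (the model/tokenizer names are illustrative):
#   summary = get_summary("bart", bart_model, bart_tokenizer, blog_text,
#                         max_blog_length=300, max_summary_length=100)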

# THE MAIN STAGE

def generate_blogs(generator_model, generator_tokenizer,
                   summarizer_type,
                   summarizer_model, summarizer_tokenizer,
                   persona_id,
                   max_summary_length, max_blog_length, condition="asthma", past_look_over=1, style="zero-shot",
                   prompt_type="consistent", time_frame="monthly"):
    """
    Generates the blog posts and summaries in a cascaded fashion, using the universal
    facts list for a particular persona.

    Args:
        generator_model: An instance of the pretrained huggingface generator model (Llama2).
        generator_tokenizer: An instance of the pretrained huggingface generator tokenizer (Llama2).
        summarizer_type: One of the four summarizer models ("bart"/"flant5"/"llama"/"mmr").
        summarizer_model: An instance of the pretrained huggingface summarizer model.
        summarizer_tokenizer: An instance of the pretrained huggingface summarizer tokenizer.
        persona_id (int): One of the 10 persona ids (ranging from 1 to 10, inclusive).
        max_summary_length: The maximum summary length.
        max_blog_length: The maximum blog post length.
        condition (str, optional): The medical condition we want to generate blog posts for. Defaults to "asthma".
        past_look_over (int, optional): The number of previous blog post summaries to include
            in the current turn's prompt. Defaults to 1.
        style (str, optional): The prompting style ("zero-shot"/"one-shot"/"two-shot"),
            exclusively for the Llama2-based summarizer. Defaults to "zero-shot".
        prompt_type (str, optional): The instruction to the model: either be "consistent" with the
            universal facts or keep track of the "changes" to the facts over time. Defaults to "consistent".
        time_frame (str, optional): The time-frame over which the model generates blog posts
            ("daily"/"weekly"/"monthly"). Defaults to "monthly".

    Returns:
        A list of blog posts generated subject to the above conditions, and a list of the
        corresponding blog post summaries.
    """
    # System Prompts (SP), Universal Facts List (UFL), Gold Question-Answer Lists (GQAL)
    SP, UFL, GQAL = PERSONAS
    system_prompt = SP[persona_id]
    universal_fact_list = UFL[persona_id]
    blog_posts = []
    summaries = []
    print(system_prompt)
    for turn in range(len(months)):  # One blog post per turn (12 turns).
        fact_prompt = get_fact_prompt(universal_fact_list=universal_fact_list, condition=condition, prompt_type=prompt_type)
        summary = ""
        if turn == 0:
            summary_prompt = ""
            summaries.append("")
        else:
            summary = get_summary(summarizer_type=summarizer_type,
                                  model=summarizer_model, tokenizer=summarizer_tokenizer,
                                  article=blog_posts[turn - 1],
                                  max_blog_length=max_blog_length, max_summary_length=max_summary_length,
                                  style=style)
            summaries.append(summary)
            summary_prompt = get_previous_n_summaries(summaries=summaries,
                                                      turn=turn,
                                                      n=past_look_over)
        if time_frame == "daily":
            time_prompt = "Now write a blog post for Day " + str(turn)
        elif time_frame == "monthly":
            time_prompt = "Now write a blog post for the month of " + months[turn]
        elif time_frame == "weekly":
            time_prompt = "Now write a blog post for Week " + str(turn)
        else:
            raise ValueError(f"Unknown time_frame: {time_frame!r}")
        prompt = system_prompt + fact_prompt + summary_prompt + time_prompt + " Blog Post:"
        text = llama_generate(prompt=prompt,
                              model=generator_model,
                              tokenizer=generator_tokenizer,
                              temperature=0.8,
                              max_blog_length=max_blog_length)
        # Keep only the text after the final "Blog Post:" marker. Since the prompt
        # itself ends with "Blog Post:", this also strips the echoed prompt, so a
        # separate clean_output() call is unnecessary here.
        blog_i = text.split("Blog Post:")[-1]
        blog_posts.append(blog_i)
        print("_" * 135)
        print()
        print("SUMMARY", turn, ":")
        print(summary)
        print()
        print("Turn:", turn)
        if time_frame == "monthly":
            print("Month:", months[turn])
        elif time_frame == "daily":
            print("Day:", turn)
        elif time_frame == "weekly":
            print("Week:", turn)
        print("PROMPT:")
        print(prompt)
        print()
        print("BLOG", turn, ":")
        print(blog_i)
        print()
        print("_" * 135)
    return blog_posts, summaries
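
# A minimal sketch of a standalone call, assuming the models are already loaded
# (see main() below for the full setup); with summarizer_type="llama", the
# generator model and tokenizer double as the summarizer:
#   blog_posts, summaries = generate_blogs(gen_model, gen_tok, "llama", gen_model, gen_tok,
#                                          persona_id=1, max_summary_length=100, max_blog_length=300,
#                                          past_look_over=2, style="zero-shot",
#                                          prompt_type="changes", time_frame="monthly")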

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("summarizer_type", type=str, help='bart/flant5/llama/mmr')
    parser.add_argument("persona_id", type=int, help='1 to 10 (identifier assigned to each persona)')
    # Arguments for previous versions of the code that allowed customization of
    # condition, blog length, past_look_over, prompt_type, etc. (These are still
    # supported by the scripts in the old_blog_generation_code/ directory.)
    # parser.add_argument("condition", type=str, help='chronic/acute condition')
    # parser.add_argument("blog_length", type=int, help='max length of blog to be generated')
    # parser.add_argument("summary_length", type=int, help='max length of summary to be generated')
    # parser.add_argument("past_look_over", type=int, help='number of past summaries to look at')
    # parser.add_argument("style", default="zero-shot", type=str, help='ICL prompt style')
    # parser.add_argument("prompt_type", type=str, help='Be Consistent/Track Changes')
    # parser.add_argument("time_frame", type=str, help='Daily/Weekly/Monthly')
    args = parser.parse_args()
    summarizer_type = args.summarizer_type
    persona_id = args.persona_id

    # Loading the generator model.
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(False)
    cache_path = "/data/shire/data/aaditd/trial/"
    generator_model_name = "meta-llama/Llama-2-7b-chat-hf"
    # Read the Hugging Face access token from the environment rather than
    # hard-coding a secret in the source.
    login(token=os.environ["HF_TOKEN"])
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    generator_model = AutoModelForCausalLM.from_pretrained(generator_model_name,
                                                           torch_dtype=torch.bfloat16,
                                                           quantization_config=bnb_config,
                                                           cache_dir=cache_path)
    generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name, cache_dir=cache_path)

    # Loading the summarizer model.
    summarizer_model_name = ""
    if summarizer_type == "bart":
        summarizer_model_name = "facebook/bart-large-cnn"
        summarizer_tokenizer = BartTokenizer.from_pretrained(summarizer_model_name, cache_dir=cache_path)
        summarizer_model = BartForConditionalGeneration.from_pretrained(summarizer_model_name, cache_dir=cache_path)
    elif summarizer_type == "flant5":
        summarizer_model_name = "jordiclive/flan-t5-3b-summarizer"
        summarizer_tokenizer = AutoTokenizer.from_pretrained(summarizer_model_name, cache_dir=cache_path)
        summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(summarizer_model_name,
                                                                 device_map="auto",
                                                                 torch_dtype=torch.bfloat16,
                                                                 cache_dir=cache_path)
    elif summarizer_type == "llama":
        # The generator doubles as the summarizer.
        summarizer_model = generator_model
        summarizer_tokenizer = generator_tokenizer
    elif summarizer_type == "mmr":
        # MMR is extractive, so no neural model or tokenizer is needed.
        summarizer_tokenizer, summarizer_model = None, None
    else:
        raise ValueError(f"Unknown summarizer_type: {summarizer_type!r}")

    blog_posts, summaries = generate_blogs(generator_model, generator_tokenizer,
                                           summarizer_type,
                                           summarizer_model, summarizer_tokenizer,
                                           persona_id,
                                           max_summary_length=100, max_blog_length=300, condition="asthma",
                                           past_look_over=2, style="zero-shot",
                                           prompt_type="changes", time_frame="monthly")

    # Evaluation of the blog posts.
    qa_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large", cache_dir=cache_path)
    qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto", cache_dir=cache_path)
    SP, UFL, GQAL = PERSONAS
    gold_question_answers = GQAL[persona_id]
    metrics_file_name = f"Metrics/persona_{persona_id}.txt"
    plot_title_string = f"Asthma Monthly Persona {persona_id}"
    evaluate_blog_post_progression(qa_model=qa_model, qa_tokenizer=qa_tokenizer,
                                   passages=blog_posts,
                                   gold_question_answers=gold_question_answers,
                                   metrics_file=metrics_file_name,
                                   title_string=plot_title_string)

    print("GENERATOR MODEL USED: ", generator_model_name)
    print("SUMMARIZER MODEL USED: ", summarizer_model_name)
    print("MAX BLOG LENGTH: ", 300)
    print("MAX SUMMARY LENGTH: ", 100)
    print("PAST LOOK OVER: ", 2)
    print("*" * 93)
    print(f"SUMMARIES: {summaries}")
    print("*" * 93)
    print("DONE GURL!!")


if __name__ == "__main__":
    main()