OpAgg eval tweaks
tomhosking committed Feb 13, 2024
1 parent 7a4ecf2 commit e464ea9
Showing 4 changed files with 111 additions and 26 deletions.
78 changes: 69 additions & 9 deletions torchseq/eval/recipes/opagg/hero_post.py
@@ -8,6 +8,8 @@
from torchseq.eval.recipes import EvalRecipe
from torchseq.utils.rouge import get_jackknife_rouge

from nltk.tokenize import sent_tokenize, word_tokenize


class Recipe(EvalRecipe):
name: str = "opagg.twostage_post"
@@ -16,7 +18,8 @@ def run(
self,
predicted_summaries: Optional[list[str]] = None,
prev_model: Literal["vitc", "vitc-base", "mnli", "mnli-base"] = "vitc",
variant: Literal["oneshot", "sentencewise", "extractive"] = "extractive",
variant: Literal["oneshot", "piecewise", "extractive"] = "extractive",
llm_name: str = "llama7b",
silent: bool = False,
) -> dict[str, Any]:
result = {}
@@ -33,24 +36,32 @@ def run(
# TODO: label each output with its origin so that we don't have to do this
if variant == "extractive":
predicted_summaries = extractive_summaries["extractive_summaries"]
output_name = f"extractive_{self.split_str}"
else:
# Load the LLM outputs
with jsonlines.open(
os.path.join(self.model_path, "eval", f"llm_outputs_{variant}_{self.split_str}.jsonl")
os.path.join(self.model_path, "eval", f"llm_outputs_{variant}_{self.split_str}_{llm_name}.jsonl")
) as reader:
llm_outputs = list(reader)

output_name = f"{variant}_{self.split_str}_{llm_name}"

if variant == "oneshot":
predicted_summaries = [resp["response"] for resp in llm_outputs]
elif variant == "sentencewise":
predicted_summaries = [self.cleanup_llm_output(resp["response"]) for resp in llm_outputs]

elif variant == "piecewise":
i = 0
predicted_summaries = []
for clusters in extractive_summaries["evidence"]:
curr_summ = []
for cluster in clusters:
curr_summ.append(llm_outputs[i]["response"].strip())
sent = self.cleanup_llm_output(llm_outputs[i]["response"])
curr_summ.append(sent)
i += 1
predicted_summaries.append(" ".join(curr_summ))

with open(os.path.join(self.model_path, "eval", f"hero_{output_name}.txt"), "w") as f:
f.writelines([summ + "\n" for summ in predicted_summaries])
else:
# Allow this recipe to be used for external systems (ie baselines)
if not silent:
@@ -69,11 +80,14 @@ def run(

# Score the summaries

result["word_count"] = np.mean([len(word_tokenize(summ)) for summ in predicted_summaries])
result["sent_count"] = np.mean([len(sent_tokenize(summ)) for summ in predicted_summaries])

# Rouge
result["rouge"] = get_jackknife_rouge(predicted_summaries, [row["summaries"] for row in eval_data])

# SummaC
from summac.model_summac import SummaCConv
# from summac.model_summac import SummaCConv

# model_conv = SummaCConv(
# models=["vitc"],
@@ -85,6 +99,11 @@ def run(
# agg="mean",
# )

# for imager in model_conv.imagers:
# imager.cache_folder = os.path.expanduser("~/.summac_cache/")
# os.makedirs(imager.cache_folder, exist_ok=True)
# imager.load_cache()

# print("Evaling SC_ins")
# docs = [" ".join([" ".join(sent for sent in rev["sentences"]) for rev in ent["reviews"]]) for ent in eval_data]
# res = model_conv.score(docs, predicted_summaries)
@@ -95,16 +114,40 @@ def run(
# res = model_conv.score(docs, predicted_summaries)
# result["sc_refs"] = np.mean(res["scores"]) * 100

# if variant == 'piecewise':
# if not silent:
# print("Evaling attribution")
# docs = [" ".join(cluster) for clusters in extractive_summaries["evidence"] for cluster in clusters]
# preds = [pred['response'] for pred in llm_outputs]
# assert len(docs) == len(preds)
# res = model_conv.score(docs, preds)
# result["sc_attr"] = np.mean(res["scores"]) * 100
# elif variant == 'oneshot':
# if not silent:
# print("Evaling attribution")
# docs = [" ".join([sent for cluster in clusters for sent in cluster]) for clusters in extractive_summaries["evidence"]]
# preds = predicted_summaries
# assert len(docs) == len(preds)
# res = model_conv.score(docs, preds)
# result["sc_attr"] = np.mean(res["scores"]) * 100
# else:
# result['sc_attr'] = None

# for imager in model_conv.imagers:
# imager.save_cache()

# # Prevalence
if not silent:
print("Evaling prevalence")
from torchseq.metric_hooks.prevalence_metric import PrevalenceMetric

review_limit = 500
review_limit = 200

prevmet = PrevalenceMetric(model_name=prev_model)
prevmet = PrevalenceMetric(
model_name=prev_model, cache_name=("space-" if "space" in dataset_eval else "amasum-") + self.split_str
)
adjusted_prevalence, (prevs, reds, trivs, gens), _ = prevmet.get_prevalence(
[[" ".join(rev["sentences"]) for rev in row["reviews"][-review_limit:]] for row in eval_data],
[[" ".join(rev["sentences"]) for rev in row["reviews"][:review_limit]] for row in eval_data],
predicted_summaries,
pbar=not silent,
product_names=[row["entity_name"] for row in eval_data],
@@ -119,3 +162,20 @@ def run(
)

return result

def cleanup_llm_output(self, text):
# Cleanup whitespace
sents = sent_tokenize(text.replace("\n", " ").strip())
# Strip "helpful" LLM padding
sents = [
sent.strip()
for sent in sents
if sent.lower()[:4] != "sure" and sent.lower()[:9] != "here is a" and sent.lower()[:8] != "here are"
]
sents = [
sent[1:-1].strip() if sent[0] == '"' and sent[-1] == '"' else sent for sent in sents
] # inefficiently strip quotes
sents = [
sent[1:-1].strip() if sent[0] == "'" and sent[-1] == "'" else sent for sent in sents
] # inefficiently strip quotes
return " ".join(sents).strip()
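For reference, here is a minimal standalone sketch of what the new cleanup_llm_output helper does to a typical LLM reply (the raw string and expected output are invented for illustration; the logic mirrors the method above):

from nltk.tokenize import sent_tokenize

# Invented LLM reply: a "helpful" padding sentence followed by a quoted answer.
raw = 'Sure!\n"The rooms were spotless and quiet."'

# Mirrors Recipe.cleanup_llm_output: flatten whitespace, drop padding sentences,
# then strip surrounding quotes.
sents = sent_tokenize(raw.replace("\n", " ").strip())
sents = [s.strip() for s in sents if not s.lower().startswith(("sure", "here is a", "here are"))]
sents = [s[1:-1].strip() if s and s[0] in "\"'" and s[-1] == s[0] else s for s in sents]
print(" ".join(sents).strip())  # expected: The rooms were spotless and quiet.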
28 changes: 20 additions & 8 deletions torchseq/eval/recipes/opagg/hero_pre.py
@@ -4,14 +4,17 @@
from torchseq.utils.model_loader import model_from_path
from torchseq.metric_hooks.opsumm_cluster_aug import OpSummClusterAugMetricHook

PROMPT_TEMPLATE_PIECEWISE = """Here is a list of sentences taken from reviews of a single hotel:
from nltk.tokenize import word_tokenize
import numpy as np

PROMPT_TEMPLATE_SENTENCEWISE = """Here is a list of sentences taken from reviews of a single {:}:
{:}
In no more than 15 words, write a single short sentence using very simple language that includes the main point:
"""

PROMPT_TEMPLATE_ONESHOT = """Here is a list of sentences taken from reviews of a single hotel:
PROMPT_TEMPLATE_ONESHOT = """Here is a list of sentences taken from reviews of a single {:}:
{:}
@@ -22,6 +25,8 @@
class Recipe(EvalRecipe):
name: str = "opagg.twostage_pre"

cluster_limit = 24

def run(self):
result = {}

@@ -41,30 +46,37 @@ def run(self):
clusters_per_entity = summaries["evidence"]
entity_ids = summaries["entity_ids"]

prompts_flat_piecewise = [
{"entity_id": ent_id, "prompt": PROMPT_TEMPLATE_PIECEWISE.format("\n".join(cluster))}
product_type = "hotel" if "space" in self.model_path else "product"

prompts_flat_sentencewise = [
{
"entity_id": ent_id,
"prompt": PROMPT_TEMPLATE_SENTENCEWISE.format(product_type, "\n".join(cluster[: self.cluster_limit])),
}
for clusters, ent_id in zip(clusters_per_entity, entity_ids)
for cluster in clusters
]

result["prompts_piecewise"] = prompts_flat_piecewise
result["prompts_sentencewise"] = max(
[len(word_tokenize(prompt["prompt"])) for prompt in prompts_flat_sentencewise]
)

with jsonlines.open(
os.path.join(self.model_path, "eval", f"llm_inputs_piecewise_{self.split_str}.jsonl"), "w"
) as writer:
writer.write_all(prompts_flat_piecewise)
writer.write_all(prompts_flat_sentencewise)

prompts_flat_oneshot = [
{
"entity_id": ent_id,
"prompt": PROMPT_TEMPLATE_ONESHOT.format(
"\n".join([sent for cluster in clusters for sent in cluster])
product_type, "\n".join([sent for cluster in clusters for sent in cluster[: self.cluster_limit]])
),
}
for clusters, ent_id in zip(clusters_per_entity, entity_ids)
]

result["prompts_oneshot"] = prompts_flat_oneshot
result["prompts_oneshot"] = max([len(word_tokenize(prompt["prompt"])) for prompt in prompts_flat_oneshot])

with jsonlines.open(
os.path.join(self.model_path, "eval", f"llm_inputs_oneshot_{self.split_str}.jsonl"), "w"
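As a quick illustration of how the pre-processing recipe above fills the sentencewise template (the review sentences are made up; the template, cluster_limit, and product_type logic are taken from the diff):

# Template as defined in hero_pre.py above.
PROMPT_TEMPLATE_SENTENCEWISE = """Here is a list of sentences taken from reviews of a single {:}:
{:}
In no more than 15 words, write a single short sentence using very simple language that includes the main point:
"""

cluster = ["The pool was lovely and warm.", "We loved the heated pool.", "Great pool area."]  # invented cluster
cluster_limit = 24  # matches Recipe.cluster_limit above
product_type = "hotel"  # "hotel" when the model path contains "space", otherwise "product"

prompt = PROMPT_TEMPLATE_SENTENCEWISE.format(product_type, "\n".join(cluster[:cluster_limit]))
print(prompt)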
28 changes: 20 additions & 8 deletions torchseq/metric_hooks/opsumm_cluster_aug.py
@@ -1453,14 +1453,17 @@ def prefilter_condition(sentence, hotel_aspect_filter=True, amazon_filter=False,

all_evidence = []
all_evidence_paths = []

with Pool(8) as pool:
for row in tqdm(
eval_data,
desc="Selecting centroids (using {:})".format(
config.eval.metrics.opsumm_cluster_aug.get("summary_centroid_method", "rouge")
),
disable=agent.silent,
all_merges = []

with Pool(4) as pool:
for row_ix, row in enumerate(
tqdm(
eval_data,
desc="Selecting centroids (using {:})".format(
config.eval.metrics.opsumm_cluster_aug.get("summary_centroid_method", "rouge")
),
disable=agent.silent,
)
):
summary_sentences = []
summary_evidence = []
@@ -1678,6 +1681,14 @@ def prefilter_condition(sentence, hotel_aspect_filter=True, amazon_filter=False,
summary_evidence_merged[-1].extend(summary_evidence[tgtix])
summary_evidence_paths_merged[-1].extend(summary_evidence_paths[tgtix])

# if row_ix == 0:
# print('Merges:')
# print(summary_sentences)
# print(summary_sentences_merged)
# print(ixs_to_merge)

all_merges.append(ixs_to_merge)

summary_sentences = summary_sentences_merged
summary_evidence = summary_evidence_merged
summary_evidence_paths = summary_evidence_paths_merged
@@ -1854,6 +1865,7 @@ def prefilter_condition(sentence, hotel_aspect_filter=True, amazon_filter=False,
"refs": gold_summs,
"entity_ids": [row["entity_id"] for row in eval_data], # include these to enable alignment downstream
"entity_names": [row["entity_name"] for row in eval_data], # include these to enable alignment downstream
"merges": all_merges,
}

@abstractmethod
3 changes: 2 additions & 1 deletion torchseq/metric_hooks/prevalence_metric.py
@@ -20,6 +20,7 @@ def __init__(
model_name: Literal["vitc", "vitc-base", "mnli", "mnli-base"] = "mnli",
threshold: float = 0.04,
use_cache: bool = True,
cache_name: str = "",
):
# Default model was originally mnli - changed to vitc
self.model = SummaCZS(
@@ -30,7 +31,7 @@ def __init__(
self.use_cache = use_cache

if self.use_cache:
self.model.imager.cache_folder = os.path.expanduser("~/.summac_cache/")
self.model.imager.cache_folder = os.path.expanduser(f"~/.summac_cache/{cache_name}")
os.makedirs(self.model.imager.cache_folder, exist_ok=True)
self.model.imager.load_cache()

