OpAgg eval updates
tomhosking committed Jan 24, 2024
1 parent 39295c5 commit 8b9a870
Showing 3 changed files with 408 additions and 167 deletions.
6 changes: 3 additions & 3 deletions torchseq/eval/recipes/opagg/twostage_post.py
@@ -87,14 +87,14 @@ def run(self, predicted_summaries: Optional[list[str]] = None) -> dict[str, Any]
         print("Evaling prevalence")
         from torchseq.metric_hooks.prevalence_metric import PrevalenceMetric

-        prevmet = PrevalenceMetric()
-        prevs, reds, trivs = prevmet.get_prevalence(
+        prevmet = PrevalenceMetric(model_name="vitc")
+        (prevs, reds, trivs), _ = prevmet.get_prevalence(
             [[" ".join(rev["sentences"]) for rev in row["reviews"]] for row in eval_data],
             predicted_summaries,
             pbar=False,
             product_names=[product_names[row["entity_id"]] for row in eval_data],
             trivial_template=trivial_template,
         )
-        result["prevalence"] = (np.mean(prevs), np.mean(reds), np.mean(trivs))
+        result["prevalence"] = (np.mean(prevs) * 100, np.mean(reds) * 100, np.mean(trivs) * 100)

         return result
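For reference, a minimal sketch of how the updated API is consumed (the review and summary data below are made-up stand-ins, not the recipe's real eval_data): get_prevalence now returns a pair of tuples, with per-entity mean scores first and the raw per-sentence arrays second, and the recipe keeps only the means, scaled to percentages.

    # Hypothetical standalone usage of the updated PrevalenceMetric API; inputs are placeholders.
    import numpy as np
    from torchseq.metric_hooks.prevalence_metric import PrevalenceMetric

    reviews = [["Great pool and friendly staff.", "The pool was spotless."]]  # one entity, two reviews
    summaries = ["The pool is praised."]

    prevmet = PrevalenceMetric(model_name="vitc")
    (prevs, reds, trivs), (raw_prevs, raw_reds, raw_trivs) = prevmet.get_prevalence(
        reviews,
        summaries,
        pbar=False,
    )
    # Per-entity means, reported as percentages as in the recipe above
    print(np.mean(prevs) * 100, np.mean(reds) * 100, np.mean(trivs) * 100)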
47 changes: 36 additions & 11 deletions torchseq/metric_hooks/prevalence_metric.py
@@ -2,6 +2,7 @@
 # Original is a CLI, this provides a more generic OO wrapper
 # See https://arxiv.org/abs/2307.14305

+import os
 import nltk.tokenize
 from summac.model_summac import SummaCZS
 from tqdm import tqdm
@@ -12,12 +13,27 @@
 class PrevalenceMetric:
     threshold: float = 0.04
     model: SummaCZS
+    use_cache: bool

-    def __init__(self, model_name: Literal["vitc", "vitc-base", "mnli", "mnli-base"] = "mnli"):
+    def __init__(
+        self,
+        model_name: Literal["vitc", "vitc-base", "mnli", "mnli-base"] = "mnli",
+        threshold: float = 0.04,
+        use_cache: bool = True,
+    ):
         # Default model was originally mnli - changed to vitc
         self.model = SummaCZS(
             granularity="document", model_name=model_name, bins="percentile", use_con=False, device="cuda"
         )
+        self.threshold = threshold
+
+        self.use_cache = use_cache
+
+        if self.use_cache:
+            self.model.imager.cache_folder = os.path.expanduser("~/.summac_cache/")
+            os.makedirs(self.model.imager.cache_folder, exist_ok=True)
+            cache_file = self.model.imager.get_cache_file()
+            self.model.imager.load_cache()

     def get_prevalence(
         self,
@@ -31,8 +47,6 @@ def get_prevalence(
         trivial_default: str = "a hotel",
         batch_size: int = 32,
     ):
-        threshold = 0.04
-
         if product_names is None:
             product_names = [trivial_default] * len(reviews)

@@ -76,12 +90,12 @@

             # Calculate which summary sentences were trivial
             trivial_scores = all_scores[:trivial_offset]
-            trivial_mask = np.array(trivial_scores) > threshold
+            trivial_mask = np.array(trivial_scores) > self.threshold

             # Calculate which sentences are redundant (wrt previous sentences)
             if not ignore_redundancy:
                 redundant_scores = all_scores[trivial_offset:redundant_offset]
-                redundant_mask_flat = np.array(redundant_scores) > threshold
+                redundant_mask_flat = np.array(redundant_scores) > self.threshold
                 redundant_mask_list = []
                 k = 0
                 for i, sent in enumerate(sents):
@@ -98,16 +112,27 @@

             # Calculate which sentences are supported by reviews
             implied_scores = all_scores[redundant_offset:]
-            implied_mask_flat = np.array(implied_scores) > threshold
+            implied_mask_flat = np.array(implied_scores) > self.threshold
             implied_counts = implied_mask_flat.reshape(len(curr_reviews), len(sents)).mean(axis=0)

             implied_counts = implied_counts * (
                 np.logical_not(trivial_mask) & (ignore_redundancy | np.logical_not(redundant_mask))
             )  # Ignore if trivial or redundant

             # Aggregate
-            prevalences.append(np.mean(implied_counts))
-            redundancies.append(np.mean(redundant_mask))
-            trivials.append(np.mean(trivial_mask))
-
-        return prevalences, redundancies, trivials
+            prevalences.append(implied_counts)
+            redundancies.append(redundant_mask)
+            trivials.append(trivial_mask)
+
+        if self.use_cache:
+            self.model.imager.save_cache()
+
+        return (
+            [np.mean(prevs) for prevs in prevalences],
+            [np.mean(reds) for reds in redundancies],
+            [np.mean(trivs) for trivs in trivials],
+        ), (
+            prevalences,
+            redundancies,
+            trivials,
+        )
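To make the aggregation in the last hunk concrete, here is a self-contained sketch of the thresholding and masking step, using fabricated scores in place of real SummaC outputs (the numbers and array sizes are illustrative only):

    # Sketch of the masking/aggregation logic above, with made-up entailment scores.
    import numpy as np

    threshold = 0.04
    num_reviews, num_sents = 3, 2

    trivial_scores = np.array([0.01, 0.10])        # one score per summary sentence
    redundant_mask = np.array([False, False])      # assume no sentence is redundant here
    implied_scores = np.array([0.9, 0.01, 0.8, 0.02, 0.7, 0.03])  # (review, sentence) scores, flattened

    trivial_mask = trivial_scores > threshold
    implied_mask = implied_scores > threshold
    # Fraction of reviews that entail each summary sentence
    implied_counts = implied_mask.reshape(num_reviews, num_sents).mean(axis=0)
    # Zero out sentences that are trivial or redundant, then average over sentences
    implied_counts = implied_counts * (np.logical_not(trivial_mask) & np.logical_not(redundant_mask))
    print(implied_counts.mean() * 100)  # prevalence for this entity, as a percentage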