Skip to content

Commit

Permalink
Rename to HIRO
Browse files Browse the repository at this point in the history
  • Loading branch information
tomhosking committed Mar 13, 2024
1 parent 6a9f30e commit b1f2a2a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion torchseq/eval/recipes/opagg/hero_post.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def run(
i += 1
predicted_summaries.append(" ".join(curr_summ))

with open(os.path.join(self.model_path, "eval", f"hero_{output_name}.txt"), "w") as f:
with open(os.path.join(self.model_path, "eval", f"hiro_{output_name}.txt"), "w") as f:
f.writelines([summ + "\n" for summ in predicted_summaries])
else:
# Allow this recipe to be used for external systems (ie baselines)
Expand Down
16 changes: 8 additions & 8 deletions torchseq/metric_hooks/opsumm_cluster_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import logging
from truecase import get_true_case

logger = logging.getLogger("HEROMetric")
logger = logging.getLogger("HIROMetric")


# Check for equivalence of paths, allowing for wildcard values
Expand Down Expand Up @@ -104,14 +104,14 @@ def on_batch(self, batch, logits, output, memory, use_test=False):

def on_end_epoch(self, agent, use_test=False):
# Populate caches
logger.info("Populating HERO caches - this may take a while!")
logger.info("Populating HIRO caches - this may take a while!")
_, _ = OpSummClusterAugMetricHook.codes_from_cache(self.config, agent, test=False, train=False)
# _, _ = OpSummClusterAugMetricHook.codes_from_cache(self.config, agent, test=False, train=True)
logger.info("...done")

if self.config.eval.metrics.opsumm_cluster_aug.get("run_nli", False):
logger.info("Running NLI eval")
self.scores["hero_nli"], _, _ = OpSummClusterAugMetricHook.eval_nli(
self.scores["hiro_nli"], _, _ = OpSummClusterAugMetricHook.eval_nli(
self.config,
agent,
test=use_test,
Expand All @@ -130,7 +130,7 @@ def on_end_epoch(self, agent, use_test=False):
if self.config.eval.metrics.opsumm_cluster_aug.get("run_extract_summaries", False):
logger.info("Running generation using HRQ paths")
(
self.scores["hero_generation"],
self.scores["hiro_generation"],
generated_summaries,
) = OpSummClusterAugMetricHook.eval_extract_summaries_and_score(
self.config,
Expand All @@ -146,7 +146,7 @@ def on_end_epoch(self, agent, use_test=False):
if self.config.eval.metrics.opsumm_cluster_aug.get("run_selection_oracle_comparison", False):
logger.info("Running cluster vs oracle comparison...")
self.scores[
"hero_selection_vs_oracle"
"hiro_selection_vs_oracle"
] = OpSummClusterAugMetricHook.eval_compare_selected_clusters_to_oracle(
self.config,
agent.data_path,
Expand All @@ -161,7 +161,7 @@ def on_end_epoch(self, agent, use_test=False):
if self.config.eval.metrics.opsumm_cluster_aug.get("run_selection_prevalence", False):
logger.info("Running cluster prevalence eval...")
(
self.scores["hero_selection_prevalence"],
self.scores["hiro_selection_prevalence"],
_,
) = OpSummClusterAugMetricHook.eval_cluster_prevalence(
self.config,
Expand All @@ -175,7 +175,7 @@ def on_end_epoch(self, agent, use_test=False):
"run_purity_bleu", False
) or self.config.eval.metrics.opsumm_cluster_aug.get("run_purity_nli", False):
logger.info("Running cluster purity eval")
self.scores["hero_purity"] = OpSummClusterAugMetricHook.eval_cluster_purity(
self.scores["hiro_purity"] = OpSummClusterAugMetricHook.eval_cluster_purity(
self.config,
agent,
test=use_test,
Expand All @@ -187,7 +187,7 @@ def on_end_epoch(self, agent, use_test=False):

if self.config.eval.metrics.opsumm_cluster_aug.get("run_specialisation", False):
logger.info("Running specialisation eval")
self.scores["hero_specialisation"] = OpSummClusterAugMetricHook.eval_specialisation(
self.scores["hiro_specialisation"] = OpSummClusterAugMetricHook.eval_specialisation(
self.config,
agent,
test=use_test,
Expand Down
4 changes: 2 additions & 2 deletions torchseq/utils/config_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def migrate_23_to_24_encdec(cfg_dict, check_only=False):
return False if check_only else cfg_dict


def migrate_selfret_to_hero(cfg_dict, check_only=False):
def migrate_selfret_to_hiro(cfg_dict, check_only=False):
if "self_retrieval" in cfg_dict["eval"]["metrics"]:
if check_only:
return True
Expand All @@ -58,7 +58,7 @@ def migrate_selfret_to_hero(cfg_dict, check_only=False):
return False if check_only else cfg_dict


all_migrations = [migrate_optimizers_23, migrate_23_to_24_encdec, migrate_selfret_to_hero]
all_migrations = [migrate_optimizers_23, migrate_23_to_24_encdec, migrate_selfret_to_hiro]


def check_config(config):
Expand Down

0 comments on commit b1f2a2a

Please sign in to comment.