Skip to content

Commit

Permalink
Fixed validate_and_filter_scores method, and replaced _restrict_task_r…
Browse files Browse the repository at this point in the history
…esults with it
  • Loading branch information
x-tabdeveloping committed Oct 21, 2024
1 parent 0bf3746 commit 607c998
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 46 deletions.
35 changes: 1 addition & 34 deletions mteb/load_results/benchmark_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,39 +25,6 @@
Score = Any


def _restrict_task_results(task_result: TaskResult, task: AbsTask) -> TaskResult:
    """Build a copy of ``task_result`` limited to what ``task`` evaluates.

    Only splits listed in ``task.metadata.eval_splits`` are kept, and within each
    kept split only the score entries whose ``"hf_subset"`` is required. For a
    multilingual task the required subsets are ``task.hf_subsets`` when that
    attribute exists, otherwise the keys of
    ``task.metadata.hf_subsets_to_langscripts``; for any other task the single
    required subset is ``"default"``.

    Args:
        task_result: The result whose scores should be restricted.
        task: The task defining the required splits and subsets.

    Returns:
        A new TaskResult created from ``task_result.to_dict()`` with the
        ``"scores"`` entry replaced by the restricted scores.

    Raises:
        ValueError: If a kept split lacks a required subset, or if a required
            split does not appear in the scores.
    """
    splits = task.metadata.eval_splits
    if not task.is_multilingual:
        hf_subsets = {"default"}
    else:
        # Prefer the task's own subset selection; fall back to all declared subsets.
        hf_subsets = set(
            getattr(task, "hf_subsets", task.metadata.hf_subsets_to_langscripts.keys())
        )

    restricted = {}
    for split in task_result.scores:
        if split not in splits:
            continue
        kept_entries = [
            entry
            for entry in task_result.scores[split]
            if entry["hf_subset"] in hf_subsets
        ]
        restricted[split] = kept_entries
        seen_subsets = {entry["hf_subset"] for entry in kept_entries}
        if seen_subsets != hf_subsets:
            raise ValueError(
                f"Missing subsets {hf_subsets - seen_subsets} for split {split}"
            )

    # Every split that survived the filter has an entry in `restricted`.
    seen_splits = set(restricted)
    if seen_splits != set(splits):
        raise ValueError(f"Missing splits {set(splits) - seen_splits}")

    return TaskResult.from_dict({**task_result.to_dict(), "scores": restricted})


class ModelResult(BaseModel):
model_name: str
model_revision: str | None
Expand Down Expand Up @@ -101,7 +68,7 @@ def filter_tasks(
def select_tasks(self, tasks: list[AbsTask]) -> "ModelResult":
task_name_to_task = {task.metadata.name: task for task in tasks}
new_task_results = [
_restrict_task_results(task_res, task_name_to_task[task_res.task_name])
task_res.validate_and_filter_scores(task_name_to_task[task_res.task_name])
for task_res in self.task_results
if task_res.task_name in task_name_to_task
]
Expand Down
2 changes: 1 addition & 1 deletion mteb/load_results/load_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def load_results(
task = task_names[r.task_name]
else:
task = None
r.validate_and_filter_scores(task=task)
r = r.validate_and_filter_scores(task=task)
filtered_results.append(r)
except Exception as e:
logger.warning(
Expand Down
25 changes: 14 additions & 11 deletions mteb/load_results/task_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,9 +469,10 @@ def get_score(
def __repr__(self) -> str:
    """Return a short representation showing the task name, with the scores elided as ``...``."""
    return f"TaskResult(task_name={self.task_name}, scores=...)"

def validate_and_filter_scores(self, task: AbsTask | None = None) -> "TaskResult":
    """Return a new TaskResult restricted to the splits and subsets required by the task.

    Splits that are not listed in the task's ``eval_splits`` are dropped, as are score
    entries whose ``"hf_subset"`` is not required. Every required split, and every
    required subset within each kept split, must be present in the scores.
    This object itself is not modified.

    Args:
        task: The task to validate the scores against. If the task is limited to certain
            splits and subsets, the scores are filtered down to those. If None, the task
            is looked up with ``get_task(self.task_name)``.

    Returns:
        A new TaskResult built from ``self.to_dict()`` with the filtered scores.

    Raises:
        ValueError: If a kept split is missing a required subset, or if a required split
            is missing from the scores.
    """
    # NOTE(review): `get_task` must already be in scope here; the diff view hides the
    # lines just above this point, so confirm where it is imported.
    if task is None:
        task = get_task(self.task_name)

    splits = task.metadata.eval_splits
    if task.is_multilingual:
        # Prefer the task's own subset selection; fall back to all declared subsets.
        hf_subsets = set(
            getattr(task, "hf_subsets", task.metadata.hf_subsets_to_langscripts.keys())
        )
    else:
        hf_subsets = {"default"}

    new_scores = {}
    seen_splits = set()
    # Read from `self`: this is a method, there is no `task_result` argument.
    for split in self.scores:
        if split not in splits:
            continue
        new_scores[split] = []
        seen_subsets = set()
        for _scores in self.scores[split]:
            if _scores["hf_subset"] not in hf_subsets:
                continue
            new_scores[split].append(_scores)
            seen_subsets.add(_scores["hf_subset"])
        if seen_subsets != hf_subsets:
            raise ValueError(
                f"Missing subsets {hf_subsets - seen_subsets} for split {split}"
            )
        seen_splits.add(split)

    if seen_splits != set(splits):
        raise ValueError(f"Missing splits {set(splits) - seen_splits}")

    # Build a fresh object rather than mutating `self.scores`, so callers can keep
    # the unfiltered result alongside the filtered one.
    return TaskResult.from_dict({**self.to_dict(), "scores": new_scores})

0 comments on commit 607c998

Please sign in to comment.