MedHelm: Implement medcalc bench scenario, metrics and specs #3207

Open: wants to merge 14 commits into base branch med-helm.
4 changes: 2 additions & 2 deletions helm-frontend/project_metadata.json
@@ -3,7 +3,7 @@
"title": "Lite",
"description": "Lightweight, broad evaluation of the capabilities of language models using in-context learning",
"id": "lite",
"releases": ["v1.10.0", "v1.9.0", "v1.8.0", "v1.7.0", "v1.6.0", "v1.5.0", "v1.4.0", "v1.3.0", "v1.2.0", "v1.1.0", "v1.0.0"]
"releases": ["v1.11.0", "v1.10.0", "v1.9.0", "v1.8.0", "v1.7.0", "v1.6.0", "v1.5.0", "v1.4.0", "v1.3.0", "v1.2.0", "v1.1.0", "v1.0.0"]
},
{
"title": "Classic",
@@ -27,7 +27,7 @@
"title": "MMLU",
"description": "Massive Multitask Language Understanding (MMLU) evaluations using standardized prompts",
"id": "mmlu",
"releases": ["v1.10.0", "v1.9.0", "v1.8.0", "v1.7.0", "v1.6.0", "v1.5.0", "v1.4.0", "v1.3.0", "v1.2.0", "v1.1.0", "v1.0.0"]
"releases": ["v1.11.0", "v1.10.0", "v1.9.0", "v1.8.0", "v1.7.0", "v1.6.0", "v1.5.0", "v1.4.0", "v1.3.0", "v1.2.0", "v1.1.0", "v1.0.0"]
},
{
"title": "VHELM",
2 changes: 1 addition & 1 deletion mkdocs.yml
@@ -51,7 +51,7 @@ nav:
- get_helm_rank.md
- benchmark.md
- huggingface_models.md
- Multimodality:
- Papers:
- heim.md
- vhelm.md
- Reference:
93 changes: 93 additions & 0 deletions src/helm/benchmark/metrics/chain_of_thought_metric.py
@@ -0,0 +1,93 @@
import re
from typing import List, Optional

from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.metrics.metric import Metric
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat


def extract_answer(output_text: str) -> Optional[str]:
"""
    Extracts the answer from the output text using two regex patterns.
Returns None if no valid answer is found.

Args:
output_text (str): The text from which to extract the answer.

Returns:
Optional[str]: The extracted answer (A-J) if found, otherwise None.
"""
# First regex: Matches "answer is (A-J)" with optional parentheses
match = re.search(r"answer is \(?([A-J])\)?", output_text)
if match:
return match.group(1)

    # Second regex: Matches "answer: (A-J)" or "Answer: (A-J)", allowing optional leading dots and parentheses
    match = re.search(r"\.*[aA]nswer:\s*\(?([A-J])\)?", output_text)
if match:
return match.group(1)

# If neither regex matches, return None
return None


class ChainOfThoughtMetric(Metric):
"""
This metric focuses on structured reasoning and the accuracy of extracted answers.
It compares model outputs against correct answers provided in a multiple-choice
format and returns a score indicating the correctness of the generated response.
"""

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
) -> List[Stat]:
"""
Evaluate the generated output for chain-of-thought reasoning accuracy.

The method extracts the model's output, determines the correct answer
from the provided references, and compares the two to compute a binary score.

Args:
adapter_spec (AdapterSpec): Specification of the adapter used for the evaluation.
request_state (RequestState): The state of the current request, including
the input instance, output results, and references.
metric_service (MetricService): A service used to compute metrics if needed.
eval_cache_path (str): Path to the evaluation cache for storing or retrieving data.

Returns:
List[Stat]: A list containing a single `Stat` object with the correctness
score (1 for correct, 0 for incorrect) under the metric
name "chain_of_thought_correct".
"""
# Assert that completions exist if the result is not None
assert (
request_state.result is not None and request_state.result.completions
), "Request state result must have completions."

# Set output_text if the assertion passes
output_text = request_state.result.completions[0].text

# Extract the answer using the updated logic
extracted_answer = extract_answer(output_text)

# Find the correct answer from references by translating index to letter
correct_answer = None
for index, option in enumerate(request_state.instance.references):
if option.is_correct:
correct_answer = chr(65 + index) # Translate index (0 -> A, 1 -> B, etc.)
break

# Raise an exception if no correct answer is found
if correct_answer is None:
raise ValueError(f"No correct answer found for instance ID {request_state.instance.id}")

# Compare extracted answer with the correct answer and compute the score
score = 1 if extracted_answer == correct_answer else 0
return [Stat(MetricName("chain_of_thought_correctness")).add(score)]
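For reference, a minimal usage sketch of the new extract_answer helper. The strings are invented model outputs, not fixtures from this PR; the import path matches the file added above.

from helm.benchmark.metrics.chain_of_thought_metric import extract_answer

# Invented example outputs for illustration only.
print(extract_answer("Let's think step by step... the answer is (C)."))  # C
print(extract_answer("[Answer: B]"))  # B
print(extract_answer("I could not decide on an option."))  # None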
181 changes: 181 additions & 0 deletions src/helm/benchmark/metrics/medcalc_bench_metrics.py
@@ -0,0 +1,181 @@
import re
from datetime import datetime
from typing import List

from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.metrics.metric import Metric
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat
from helm.common.hierarchical_logger import hlog


class MedCalcBenchMetric(Metric):
def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
) -> List[Stat]:
"""Metric for MedCalc-Bench dataset.

Original implementation:
https://github.com/ncbi-nlp/MedCalc-Bench/blob/048ba77dbe332e9190935e4a30965bff444b940e/evaluation/evaluate.py#L11
"""
assert request_state.instance.extra_data, (
"Could not find `extra_data` in the request state. "
"Both `lower_limit` and `upper_limit` are required for this metric."
)

        assert request_state.result is not None, "Request state result cannot be None."
        assert len(request_state.result.completions) == 1, (
            f"Found a total of {len(request_state.result.completions)} completions. "
            "Only one was expected."
        )

final_answer = (
request_state.result.completions[0]
.text.strip()
.lower()
.split("calculated value:")[-1]
.strip()
)

correctness = 0
if final_answer:
try:
correctness = self.medcalc_bench_metric_calculation(
answer=final_answer,
ground_truth=request_state.instance.extra_data["ground_truth"],
calid=int(request_state.instance.extra_data["calculator_id"]),
upper_limit=request_state.instance.extra_data["upper_limit"],
lower_limit=request_state.instance.extra_data["lower_limit"],
)
except ValueError as e:
hlog(
(
"Failed to calculate the correctess of the output for MedCalc-Bench instance "
f'with id {request_state.instance.extra_data["id"]}: {e}'
)
)

return [Stat(MetricName("medcalc_bench_metric")).add(correctness)]

def medcalc_bench_metric_calculation(
self,
answer: str,
ground_truth: str,
calid: int,
upper_limit: str,
lower_limit: str,
) -> int:
"""Calculate the metric for MedCalc-Bench dataset.

This method is basically a copy of the original implementation of this metric:
https://github.com/ncbi-nlp/MedCalc-Bench/blob/048ba77dbe332e9190935e4a30965bff444b940e/evaluation/evaluate.py#L11

Credits to the original authors: https://github.com/ncbi-nlp/MedCalc-Bench.
"""
if calid in [13, 68]:
# Output Type: date

if datetime.strptime(answer, "%m/%d/%Y").strftime(
"%-m/%-d/%Y"
) == datetime.strptime(ground_truth, "%m/%d/%Y").strftime("%-m/%-d/%Y"):
correctness = 1
else:
correctness = 0
elif calid in [69]:
# Output Type: integer (A, B)
match = re.search(
r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?",
ground_truth,
)
ground_truth = f"({match.group(1)}, {match.group(3)})"
match = re.search(
r"\(?[\"\']?(\d+)\s*(weeks?)?[\"\']?,?\s*[\"\']?(\d+)\s*(days?)?[\"\']?\s*\)?",
answer,
)
if match:
weeks = match.group(1)
days = match.group(3)
answer = f"({weeks}, {days})"
if eval(answer) == eval(ground_truth):
correctness = 1
else:
correctness = 0
else:
correctness = 0
elif calid in [
4,
15,
16,
17,
18,
20,
21,
25,
27,
28,
29,
32,
33,
36,
43,
45,
48,
51,
69,
]:
# Output Type: integer A
            answer = round(float(answer))  # accept decimal-formatted answers, then round to the nearest integer
if answer == int(ground_truth):
correctness = 1
else:
correctness = 0
elif calid in [
2,
3,
5,
6,
7,
8,
9,
10,
11,
19,
22,
23,
24,
26,
30,
31,
38,
39,
40,
44,
46,
49,
56,
57,
58,
59,
60,
61,
62,
63,
64,
65,
66,
67,
]:
# Output Type: decimal
answer = float(answer)
            if float(lower_limit) <= answer <= float(upper_limit):
correctness = 1
else:
correctness = 0
else:
raise ValueError(f"Unknown calculator ID: {calid}")
return correctness
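To make the branching above concrete, a minimal sketch with invented values. It assumes the Metric base class needs no constructor arguments; the ground truths and limits below are illustrative, not taken from the dataset, and the %-m/%-d date directives used by the metric are POSIX-only.

from helm.benchmark.metrics.medcalc_bench_metrics import MedCalcBenchMetric

metric = MedCalcBenchMetric()
# Date output (calid 13): exact match after date normalization -> 1
print(metric.medcalc_bench_metric_calculation("02/05/2021", "2/5/2021", 13, "", ""))
# Integer output (calid 4): rounded answer must equal the integer ground truth -> 1
print(metric.medcalc_bench_metric_calculation("3", "3", 4, "", ""))
# Decimal output (calid 2): answer must fall within [lower_limit, upper_limit] -> 1
print(metric.medcalc_bench_metric_calculation("12.4", "12.5", 2, upper_limit="13.1", lower_limit="11.9"))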
4 changes: 4 additions & 0 deletions src/helm/benchmark/presentation/run_entries_speech.conf
@@ -6,6 +6,10 @@ entries: [
{description: "vocal_sound:model=audiolm", priority: 1}
{description: "audiocaps:model=audiolm", priority: 1}
{description: "voxceleb2:model=audiolm", priority: 1}
{description: "air_bench_chat:subject=speech,model=audiolm", priority: 1}
{description: "air_bench_chat:subject=sound,model=audiolm", priority: 1}
{description: "air_bench_chat:subject=music,model=audiolm", priority: 1}
{description: "air_bench_chat:subject=mix,model=audiolm", priority: 1}

####################################################################################################################
# Fairness
20 changes: 20 additions & 0 deletions src/helm/benchmark/run_specs/audio_run_specs.py
@@ -373,3 +373,23 @@ def get_casual_conversations2_run_spec(subject: str) -> RunSpec:
metric_specs=metric_specs,
groups=[run_spec_name],
)


@run_spec_function("air_bench_chat")
def get_air_bench_chat_run_spec(subject: str) -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.audio_language.air_bench_chat_scenario." "AirBenchChatScenario",
args={"subject": subject},
)
adapter_spec = _get_generation_adapter_spec(
max_tokens=50,
)
metric_specs: List[MetricSpec] = _get_open_ended_generation_metric_specs()
run_spec_name: str = "air_bench_chat"
return RunSpec(
name=f"{run_spec_name}:subject={subject}",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
groups=[run_spec_name],
)
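Each of the four new air_bench_chat entries in run_entries_speech.conf resolves to this run spec function. A minimal sketch of that mapping, assuming the @run_spec_function decorator registers and returns the wrapped function unchanged:

from helm.benchmark.run_specs.audio_run_specs import get_air_bench_chat_run_spec

for subject in ["speech", "sound", "music", "mix"]:
    spec = get_air_bench_chat_run_spec(subject)
    print(spec.name)  # e.g. air_bench_chat:subject=speech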
31 changes: 31 additions & 0 deletions src/helm/benchmark/run_specs/enem_challenge_specs.py
@@ -0,0 +1,31 @@
from helm.benchmark.adaptation.adapters.adapter_factory import ADAPT_MULTIPLE_CHOICE_JOINT
from helm.benchmark.adaptation.common_adapter_specs import get_multiple_choice_adapter_spec
from helm.benchmark.metrics.common_metric_specs import get_exact_match_metric_specs
from helm.benchmark.run_spec import RunSpec, run_spec_function
from helm.benchmark.scenarios.scenario import ScenarioSpec


@run_spec_function("enem_challenge")
def get_enem_spec() -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.enem_challenge_scenario.ENEMChallengeScenario", args={}
)

adapter_spec = get_multiple_choice_adapter_spec(
method=ADAPT_MULTIPLE_CHOICE_JOINT,
instructions="Dê uma resposta selecionando uma letra entre as opções fornecidas. "
"Se as opções forem A, B, C, D e E, "
"sua resposta deve consistir em uma única letra que corresponde a resposta correta.\n"
"Exemplo: Qual é a capital da França?\nA. Londres\nB. Paris\nC. Roma\nD. Berlim\nE. Sydney\n"
"Resposta: B",
input_noun="Pergunta",
output_noun="Resposta",
)

return RunSpec(
name="enem_challenge",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_exact_match_metric_specs(),
groups=["enem_challenge"],
)
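A similar sketch for the new ENEM run spec, under the same decorator assumption; the conf entry in the comment follows the format used in run_entries_speech.conf above, with a placeholder model name:

from helm.benchmark.run_specs.enem_challenge_specs import get_enem_spec

spec = get_enem_spec()
print(spec.name)    # enem_challenge
print(spec.groups)  # ['enem_challenge']
# In a run entries conf file this would be referenced as, e.g.:
#   {description: "enem_challenge:model=<your_model>", priority: 1}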