fix: dataset loading and standardize naming
Leonardo Schettini committed Dec 17, 2024
1 parent bb15f35 commit 792fb4f
Showing 4 changed files with 153 additions and 131 deletions.
@@ -8,6 +8,7 @@
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat
from helm.common.hierarchical_logger import hlog


class MedCalcBenchMetric(Metric):
@@ -33,25 +34,33 @@ def evaluate_generation(
"Only one was expected"
)

final_answer = request_state.result.completions[0].text.strip().lower().split("final answer:")[-1].strip()

try:
correctness = self.medcalc_bench_range_metric_calculation(
answer=final_answer,
ground_truth=request_state.instance.extra_data["ground_truth"],
calid=int(request_state.instance.extra_data["calculator_id"]),
upper_limit=request_state.instance.extra_data["upper_limit"],
lower_limit=request_state.instance.extra_data["lower_limit"],
)
except ValueError:
raise ValueError(
"Failed to calculate the correctess of the output for a MedCalc-Bench instance."
)
final_answer = (
request_state.result.completions[0]
.text.strip()
.lower()
.split("calculated value:")[-1]
.strip()
)

stat = Stat(MetricName("medcalc_bench_metric"))
stat.add(int(correctness))
correctness = 0
if final_answer:
try:
correctness = self.medcalc_bench_metric_calculation(
answer=final_answer,
ground_truth=request_state.instance.extra_data["ground_truth"],
calid=int(request_state.instance.extra_data["calculator_id"]),
upper_limit=request_state.instance.extra_data["upper_limit"],
lower_limit=request_state.instance.extra_data["lower_limit"],
)
except ValueError as e:
hlog(
(
"Failed to calculate the correctess of the output for MedCalc-Bench instance "
f'with id {request_state.instance.extra_data["id"]}: {e}'
)
)

return [stat]
return [Stat(MetricName("medcalc_bench_metric")).add(correctness)]

def medcalc_bench_metric_calculation(
self,
@@ -120,8 +129,8 @@ def medcalc_bench_metric_calculation(
69,
]:
# Output Type: integer A
answer = round(eval(answer))
if answer == eval(ground_truth):
answer = round(int(answer))
if answer == int(ground_truth):
correctness = 1
else:
correctness = 0
@@ -162,8 +171,8 @@ def medcalc_bench_metric_calculation(
67,
]:
# Output Type: decimal
answer = eval(answer)
if answer >= eval(lower_limit) and answer <= eval(upper_limit):
answer = float(answer)
if answer >= float(lower_limit) and answer <= float(upper_limit):
correctness = 1
else:
correctness = 0
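For reference, here is a minimal standalone sketch of the check the updated metric performs: the completion is lowercased, everything after the last `calculated value:` marker is taken as the answer, and it is compared either to a rounded integer ground truth or to the tolerated `[lower_limit, upper_limit]` range, with parse failures scored as 0 instead of raising. The helper name is hypothetical and the parsing is deliberately simplified; it is not the HELM implementation itself.

```python
def check_medcalc_answer(
    completion: str,
    ground_truth: str,
    lower_limit: str,
    upper_limit: str,
    integer_output: bool,
) -> int:
    """Return 1 if the model's final answer counts as correct, else 0."""
    # Mirror the diff: lowercase, then take everything after the last
    # "calculated value:" marker.
    final_answer = completion.strip().lower().split("calculated value:")[-1].strip()
    if not final_answer:
        return 0
    try:
        if integer_output:
            # Integer-valued calculators: exact match after rounding.
            return int(round(float(final_answer)) == int(ground_truth))
        # Decimal-valued calculators: accept any value inside [lower, upper].
        value = float(final_answer)
        return int(float(lower_limit) <= value <= float(upper_limit))
    except ValueError:
        # Unparseable answers are scored 0 instead of aborting the run,
        # matching the log-and-continue behavior introduced above.
        return 0


print(check_medcalc_answer(
    "steps: ... calculated value: 87.5",
    ground_truth="88.0",
    lower_limit="83.6",
    upper_limit="92.4",
    integer_output=False,
))  # -> 1
```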
@@ -1,6 +1,7 @@
import json
from typing import Dict, List

from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.common_adapter_specs import get_generation_adapter_spec
from helm.benchmark.metrics.common_metric_specs import get_basic_metric_specs
from helm.benchmark.metrics.metric import MetricSpec
@@ -10,44 +11,47 @@
ONE_SHOT_EXAMPLES_URL = "https://raw.githubusercontent.com/ncbi-nlp/MedCalc-Bench/048ba77dbe332e9190935e4a30965bff444b940e/evaluation/one_shot_finalized_explanation.json"


@run_spec_function("med_calc_bench_zero_shot_cot")
def get_med_calc_bench_zero_shot_spec() -> RunSpec:
@run_spec_function("medcalc_bench")
def get_medcalc_bench_spec(method: str) -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.med_calc_bench_scenario.MedCalcBenchScenario",
class_name="helm.benchmark.scenarios.medcalc_bench_scenario.MedCalcBenchScenario",
args={},
)

adapter_spec = get_generation_adapter_spec(
instructions=_get_zero_shot_cot_instructions(),
input_noun=None, # Set directly in the scenario.
output_noun="Calculated Value",
max_tokens=500,
)
if method == "zero_shot":
adapter_spec = get_medcalc_bench_zero_shot_adapter()
elif method == "one_shot":
adapter_spec = get_medcalc_bench_one_shot_adapter()
else:
raise ValueError(f"Invalid method for MedCalc-Bench: {method}")

metric_specs: List[MetricSpec] = [
MetricSpec(
class_name="helm.benchmark.metrics.med_calc_bench_metrics.MedCalcBenchMetric",
class_name="helm.benchmark.metrics.medcalc_bench_metrics.MedCalcBenchMetric",
args={},
)
] + get_basic_metric_specs([])
] # + get_basic_metric_specs([])

return RunSpec(
name="med_calc_bench",
name=f"medcalc_bench:method{method}",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
groups=["med_calc_bench"],
groups=["medcalc_bench"],
)


@run_spec_function("med_calc_bench_one_shot_cot")
def get_med_calc_bench_one_shot_spec() -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.med_calc_bench_scenario.MedCalcBenchScenario",
args={},
def get_medcalc_bench_zero_shot_adapter() -> AdapterSpec:
return get_generation_adapter_spec(
instructions=_get_zero_shot_cot_instructions(),
input_noun=None, # Set directly in the scenario.
output_noun="\n\nCalculated Value",
max_tokens=500,
)

adapter_spec = get_generation_adapter_spec(

def get_medcalc_bench_one_shot_adapter() -> AdapterSpec:
return get_generation_adapter_spec(
instructions=_get_one_shot_cot_instructions(
# TODO: Modify this to retrieve the question and calculator ID from the respective dataset sample.
# For more information see the docstring for the `_get_one_shot_cot_instructions` function.
@@ -63,28 +67,13 @@ def get_med_calc_bench_one_shot_spec() -> RunSpec:
calculator_id="2",
),
input_noun=None, # Set directly in the scenario.
output_noun="Calculated Value",
output_noun="\n\nCalculated Value",
max_tokens=500,
)

metric_specs: List[MetricSpec] = [
MetricSpec(
class_name="helm.benchmark.metrics.med_calc_bench_metrics.MedCalcBenchMetric",
args={},
)
] + get_basic_metric_specs([])

return RunSpec(
name="med_calc_bench",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
groups=["med_calc_bench"],
)


def _get_zero_shot_cot_instructions() -> str:
"""Generate instructions for the MedCalcBench scenario.
"""Generate instructions for the MedCalc-Bench scenario.
This function is inspired by the system prompt definition in the original code:
https://github.com/ncbi-nlp/MedCalc-Bench/blob/048ba77dbe332e9190935e4a30965bff444b940e/evaluation/run.py#L16
@@ -101,7 +90,7 @@ def _get_zero_shot_cot_instructions() -> str:


def _get_one_shot_cot_instructions(question: str, calculator_id: str) -> str:
"""Generate instructions for the MedCalcBench scenario.
"""Generate instructions for the MedCalc-Bench scenario.
This function is inspired by the system prompt definition in the original code:
https://github.com/ncbi-nlp/MedCalc-Bench/blob/048ba77dbe332e9190935e4a30965bff444b940e/evaluation/run.py#L26
@@ -120,7 +109,7 @@ def _get_one_shot_cot_instructions(question: str, calculator_id: str) -> str:

if not examples:
raise ValueError(
"Failed to load one-shot examples for the MedCalcBench scenario."
"Failed to load one-shot examples for the MedCalc-Bench scenario."
)

example = examples.get(calculator_id, {})
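Since the two old entry points (`med_calc_bench_zero_shot_cot` and `med_calc_bench_one_shot_cot`) are now collapsed into a single `medcalc_bench` run spec selected by a `method` argument, usage looks roughly like the sketch below. The import path is an assumption based on HELM's usual run-spec layout; the diff does not show this file's name.

```python
# Assumed module path; adjust to wherever this run-spec file actually lives.
from helm.benchmark.run_specs.medcalc_bench_run_specs import get_medcalc_bench_spec

zero_shot = get_medcalc_bench_spec(method="zero_shot")  # zero-shot CoT adapter
one_shot = get_medcalc_bench_spec(method="one_shot")    # one-shot CoT adapter; may fetch the one-shot examples file

print(zero_shot.groups)  # ["medcalc_bench"]

try:
    get_medcalc_bench_spec(method="few_shot")
except ValueError as err:
    print(err)  # Invalid method for MedCalc-Bench: few_shot
```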
72 changes: 0 additions & 72 deletions src/helm/benchmark/scenarios/med_calc_bench_scenario.py

This file was deleted.

96 changes: 96 additions & 0 deletions src/helm/benchmark/scenarios/medcalc_bench_scenario.py
@@ -0,0 +1,96 @@
import json
import os
from typing import Dict, List

from datasets import DatasetDict, load_dataset

from helm.benchmark.scenarios.scenario import (
CORRECT_TAG,
TEST_SPLIT,
TRAIN_SPLIT,
Input,
Instance,
Output,
Reference,
Scenario,
)
from helm.common.general import ensure_directory_exists


class MedCalcBenchScenario(Scenario):
"""
MedCalcBench scenario: Processes a medical calculation dataset with explanations.
Each record in the dataset has:
- Row Number
- Calculator ID
- Calculator Name
- Category
- Output Type
- Note ID
- Note Type
- Question
- Ground Truth Explanation
- Patient Note
- Relevant Entities
- Lower Limit
- Upper Limit
- Ground Truth Answer
The output is formatted as:
"The answer is <calculated value>. Steps: <explanation>"
"""

HUGGING_FACE_DATASET_PATH: str = "ncbi/MedCalc-Bench-v1.0"

# TODO: Add a base url
DATASET_DOWNLOAD_BASE_URL: str = ""

name = "medcalcbench"
description = "Medical calculation questions with step-by-step explanations."
tags = ["reasoning", "medicine", "calculation"]

def get_instances(self, output_path: str) -> List[Instance]:
data_path: str = os.path.join(output_path, "data")
ensure_directory_exists(data_path)
dataset: DatasetDict = load_dataset(self.HUGGING_FACE_DATASET_PATH)

splits = {TRAIN_SPLIT: "train", TEST_SPLIT: "test"}
instances: List[Instance] = []
for (
helm_split_name,
dataset_split_name,
) in splits.items(): # Iterate over the splits
split_data = dataset[dataset_split_name]

for example in split_data:
question = example["Question"]
patient_note = example["Patient Note"]

input_text = (
f"Patient Note:\n\n{patient_note}\n\nQuestion:\n\n{question}"
)

# Format the final answer with explanation
instances.append(
Instance(
input=Input(text=input_text),
references=[
Reference(
Output(text=example["Ground Truth Answer"]),
tags=[CORRECT_TAG],
)
],
split=helm_split_name,
extra_data={
"id": example["Row Number"],
"relevant_entities": example["Relevant Entities"],
"lower_limit": example["Lower Limit"],
"upper_limit": example["Upper Limit"],
"calculator_id": example["Calculator ID"],
"ground_truth": example["Ground Truth Answer"],
},
)
)

return instances
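To see what the new scenario feeds into each `Instance`, here is a small standalone sketch that reads one row of the same Hugging Face dataset and rebuilds the prompt text and `extra_data` fields the way `get_instances` does. It requires network access to the Hub and is not part of the commit; the column names and dataset id are taken from the diff above.

```python
from datasets import load_dataset

# Same dataset id the scenario declares above; needs access to the HF Hub.
dataset = load_dataset("ncbi/MedCalc-Bench-v1.0", split="test")
example = dataset[0]

# Prompt text in the same shape get_instances builds it.
input_text = (
    f"Patient Note:\n\n{example['Patient Note']}\n\n"
    f"Question:\n\n{example['Question']}"
)

# The fields the metric later reads back from extra_data.
extra_data = {
    "id": example["Row Number"],
    "relevant_entities": example["Relevant Entities"],
    "lower_limit": example["Lower Limit"],
    "upper_limit": example["Upper Limit"],
    "calculator_id": example["Calculator ID"],
    "ground_truth": example["Ground Truth Answer"],
}

print(input_text[:200])
print(extra_data["calculator_id"], extra_data["ground_truth"])
```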
