Skip to content

Commit

Permalink
Merge pull request #900 from mlcommons/new-apis_v0.1.0-dev_metrics_su…
Browse files Browse the repository at this point in the history
…pport_separate_target-pred_files

Metrics library updated to support comma separate inputs
  • Loading branch information
sarthakpati authored Jul 12, 2024
2 parents df378db + 5832c03 commit 3e02cec
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 61 deletions.
150 changes: 119 additions & 31 deletions GANDLF/cli/generate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,41 @@
)


def __update_header_location_case_insensitive(
input_df: pd.DataFrame, expected_column_name: str, required: bool = True
) -> pd.DataFrame:
"""
This function checks for a column in the dataframe in a case-insensitive manner and renames it.
Args:
input_df (pd.DataFrame): The input dataframe.
expected_column_name (str): The expected column name.
required (bool, optional): Whether the column is required. Defaults to True.
Returns:
pd.DataFrame: The updated dataframe.
"""
actual_column_name = None
for col in input_df.columns:
if col.lower() == expected_column_name.lower():
actual_column_name = col
break

if required:
assert (
actual_column_name is not None
), f"Column {expected_column_name} not found in the dataframe"

return input_df.rename(columns={actual_column_name: expected_column_name})
else:
return input_df


def generate_metrics_dict(
input_csv: str, config: str, outputfile: Optional[str] = None
input_csv: str,
config: str,
outputfile: Optional[str] = None,
missing_prediction: int = -1,
) -> dict:
"""
This function generates metrics from the input csv and the config.
Expand All @@ -41,27 +74,83 @@ def generate_metrics_dict(
input_csv (str): The input CSV.
config (str): The input yaml config.
outputfile (str, optional): The output file to save the metrics. Defaults to None.
missing_prediction (int, optional): The value to use for missing predictions as penalty. Default is -1.
Returns:
dict: The metrics dictionary.
"""
input_df = pd.read_csv(input_csv)
# the case where the input is a comma-separated 2 files with targets and predictions
if "," in input_csv:
target_csv, prediction_csv = input_csv.split(",")
target_df = pd.read_csv(target_csv)
prediction_df = pd.read_csv(prediction_csv)
## start sanity checks
# if missing predictions are not to be penalized, check if the number of rows in the target and prediction files are the same
if missing_prediction == -1:
assert (
target_df.shape[0] == prediction_df.shape[0]
), "The number of rows in the target and prediction files should be the same"

# check if the number of columns in the target and prediction files are the same
assert (
target_df.shape[1] == prediction_df.shape[1]
), "The number of columns in the target and prediction files should be the same"
assert (
target_df.shape[1] == 2
), "The target and prediction files should have *exactly* 2 columns"

# find the correct header for the subjectID column
target_df = __update_header_location_case_insensitive(target_df, "SubjectID")
prediction_df = __update_header_location_case_insensitive(
prediction_df, "SubjectID"
)
# check if prediction_df has extra subjectIDs
assert (
prediction_df["SubjectID"].isin(target_df["SubjectID"]).all()
), "The `SubjectID` column in the prediction file should be a subset of the `SubjectID` column in the target file"

# individual checks for target and prediction dataframes
for df in [target_df, prediction_df]:
# check if the "subjectID" column has duplicates
assert (
df["SubjectID"].duplicated().sum() == 0
), "The `SubjectID` column should not have duplicates"

# check if SubjectID is the first column
assert (
df.columns[0] == "SubjectID"
), "The `SubjectID` column should be the first column in the target and prediction files"

# change the column name after subjectID to target and prediction
target_df = target_df.rename(columns={target_df.columns[1]: "Target"})
prediction_df = prediction_df.rename(
columns={prediction_df.columns[1]: "Prediction"}
)

# combine the two dataframes
input_df = target_df.merge(prediction_df, how="left", on="SubjectID").fillna(
missing_prediction
)

# check required headers in a case insensitive manner
headers = {}
required_columns = ["subjectid", "prediction", "target"]
for col, _ in input_df.items():
col_lower = col.lower()
else:
# the case where the input is a single file with targets and predictions
input_df = pd.read_csv(input_csv)

# check required headers in a case insensitive manner and rename them
required_columns = ["SubjectID", "Prediction", "Target"]
for column_to_check in required_columns:
if column_to_check == col_lower:
headers[column_to_check] = col
if col_lower == "mask":
headers["mask"] = col
for column in required_columns:
assert column in headers, f"The input csv should have a column named {column}"
input_df = __update_header_location_case_insensitive(
input_df, column_to_check
)

# check if the "subjectID" column has duplicates
assert (
input_df["SubjectID"].duplicated().sum() == 0
), "The `SubjectID` column should not have duplicates"

overall_stats_dict = {}
parameters = ConfigManager(config)
# ensure that the problem_type is set
problem_type = parameters.get("problem_type", None)
problem_type = (
find_problem_type_from_parameters(parameters)
Expand All @@ -70,12 +159,14 @@ def generate_metrics_dict(
)
parameters["problem_type"] = problem_type

if problem_type == "regression" or problem_type == "classification":
parameters["model"]["num_classes"] = len(parameters["model"]["class_list"])
predictions_tensor = torch.from_numpy(
input_df[headers["prediction"]].to_numpy().ravel()
if problem_type == "classification":
parameters["model"]["num_classes"] = parameters["model"].get(
"num_classes", len(parameters["model"]["class_list"])
)
labels_tensor = torch.from_numpy(input_df[headers["target"]].to_numpy().ravel())

if problem_type == "regression" or problem_type == "classification":
predictions_tensor = torch.from_numpy(input_df["Prediction"].to_numpy().ravel())
labels_tensor = torch.from_numpy(input_df["Target"].to_numpy().ravel())
overall_stats_dict = overall_stats(
predictions_tensor, labels_tensor, parameters
)
Expand All @@ -84,10 +175,10 @@ def generate_metrics_dict(
# read images and then calculate metrics
class_list = parameters["model"]["class_list"]
for _, row in tqdm(input_df.iterrows(), total=input_df.shape[0]):
current_subject_id = row[headers["subjectid"]]
current_subject_id = row["SubjectID"]
overall_stats_dict[current_subject_id] = {}
label_image = torchio.LabelMap(row[headers["target"]])
pred_image = torchio.LabelMap(row[headers["prediction"]])
label_image = torchio.LabelMap(row["Target"])
pred_image = torchio.LabelMap(row["Prediction"])
label_tensor = label_image.data
pred_tensor = pred_image.data
spacing = label_image.spacing
Expand Down Expand Up @@ -225,20 +316,17 @@ def __percentile_clip(
) # normalizes values to [0;1]
return output_tensor

input_df = __update_header_location_case_insensitive(input_df, "Mask", False)
for _, row in tqdm(input_df.iterrows(), total=input_df.shape[0]):
current_subject_id = row[headers["subjectid"]]
current_subject_id = row["SubjectID"]
overall_stats_dict[current_subject_id] = {}
target_image = __fix_2d_tensor(
torchio.ScalarImage(row[headers["target"]]).data
)
pred_image = __fix_2d_tensor(
torchio.ScalarImage(row[headers["prediction"]]).data
)
# if "mask" is not in the row, we assume that the whole image is the mask
target_image = __fix_2d_tensor(torchio.ScalarImage(row["Target"]).data)
pred_image = __fix_2d_tensor(torchio.ScalarImage(row["Prediction"]).data)
# if "Mask" is not in the row, we assume that the whole image is the mask
# always cast to byte tensor
mask = (
__fix_2d_tensor(torchio.LabelMap(row[headers["mask"]]).data)
if "mask" in row
__fix_2d_tensor(torchio.LabelMap(row["Mask"]).data)
if "Mask" in row
else torch.from_numpy(
np.ones(target_image.numpy().shape, dtype=np.uint8)
)
Expand Down
52 changes: 41 additions & 11 deletions GANDLF/entrypoints/generate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
from GANDLF.entrypoints import append_copyright_to_help


def _generate_metrics(input_data: str, config: str, output_file: Optional[str]):
try:
generate_metrics_dict(input_data, config, output_file)
except Exception as e:
# TODO: why catch this? why not rely on normal python behavior?
sys.exit("ERROR: " + str(e))

def _generate_metrics(
input_data: str,
config: str,
output_file: Optional[str],
missing_prediction: int = -1,
):
generate_metrics_dict(input_data, config, output_file, missing_prediction)
print("Finished.")


Expand All @@ -35,7 +35,7 @@ def _generate_metrics(input_data: str, config: str, output_file: Optional[str]):
"--input-data",
"-i",
required=True,
type=click.Path(exists=True, file_okay=True, dir_okay=False),
type=str,
help="The CSV file of input data that is used to generate the metrics; "
"should contain 3 columns: 'SubjectID,Target,Prediction'",
)
Expand All @@ -45,11 +45,30 @@ def _generate_metrics(input_data: str, config: str, output_file: Optional[str]):
type=click.Path(file_okay=True, dir_okay=False),
help="Location to save the output dictionary. If not provided, will print to stdout.",
)
@click.option(
"--missing-prediction",
"-m",
required=False,
type=int,
default=-1,
help="The value to use for missing predictions as penalty; if `-1`, this does not get added. This is only used in the case where the targets and predictions are passed independently.",
)
@click.option("--raw-input", hidden=True)
@append_copyright_to_help
def new_way(config: str, input_data: str, output_file: Optional[str], raw_input: str):
def new_way(
config: str,
input_data: str,
output_file: Optional[str],
missing_prediction: int,
raw_input: str,
):
"""Metrics calculator."""
_generate_metrics(input_data=input_data, config=config, output_file=output_file)
_generate_metrics(
input_data=input_data,
config=config,
output_file=output_file,
missing_prediction=missing_prediction,
)


@deprecated(
Expand Down Expand Up @@ -95,6 +114,14 @@ def old_way():
default=None,
help="Location to save the output dictionary. If not provided, will print to stdout.",
)
parser.add_argument(
"-m",
"--missingprediction",
metavar="",
type=int,
default=-1,
help="The value to use for missing predictions as penalty; if `-1`, this does not get added. This is only used in the case where the targets and predictions are passed independently.",
)
parser.add_argument(
"-v",
"--version",
Expand All @@ -112,7 +139,10 @@ def old_way():
assert args.inputdata is not None, "Missing required parameter: inputdata"

_generate_metrics(
input_data=args.inputdata, config=args.config, output_file=args.outputfile
input_data=args.inputdata,
config=args.config,
output_file=args.outputfile,
missing_prediction=args.missingprediction,
)


Expand Down
22 changes: 12 additions & 10 deletions testing/entrypoints/test_generate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,25 @@
should_succeed=True,
new_way_lines=[
# full command
"--input-data input.csv --output-file output.json --config config.yaml",
"--input-data input.csv --output-file output.json --config config.yaml --missing-prediction 666",
# tests short arg aliases
"-i input.csv -o output.json -c config.yaml",
"-i input.csv -o output.json -c config.yaml -m 666",
# --raw-input param exists that do nothing
"-i input.csv -o output.json -c config.yaml --raw-input 123321",
"-i input.csv -o output.json -c config.yaml --raw-input 123321 -m 666",
],
old_way_lines=[
"--inputdata input.csv --outputfile output.json --config config.yaml",
"--data_path input.csv --output_path output.json --parameters_file config.yaml",
"-i input.csv -o output.json -c config.yaml",
"--inputdata input.csv --outputfile output.json --config config.yaml --missingprediction 666",
"--data_path input.csv --output_path output.json --parameters_file config.yaml --missingprediction 666",
"-i input.csv -o output.json -c config.yaml -m 666",
# --raw-input param exists that do nothing
"-i input.csv -o output.json -c config.yaml --rawinput 123321",
"-i input.csv -o output.json -c config.yaml -rawinput 123321",
"-i input.csv -o output.json -c config.yaml --rawinput 123321 -m 666",
"-i input.csv -o output.json -c config.yaml -rawinput 123321 -m 666",
],
expected_args={
"input_csv": "input.csv",
"config": "config.yaml",
"outputfile": "output.json",
"missing_prediction": 666,
},
),
CliCase(
Expand All @@ -56,6 +57,7 @@
"input_csv": "input.csv",
"config": "config.yaml",
"outputfile": None,
"missing_prediction": -1,
},
),
CliCase(
Expand All @@ -69,6 +71,7 @@
"input_csv": "input.csv",
"config": "config.yaml",
"outputfile": "output_na.json",
"missing_prediction": -1,
},
),
CliCase(
Expand All @@ -78,8 +81,7 @@
"-o output.json -c config.yaml",
"-i input.csv -o output.json",
# input, config should point to existing file, not dir
"-i path_na -o output.json -c config.yaml",
"-i tmp_dir/ -o output.json -c config.yaml",
# "-i tmp_dir/ -o output.json -c config.yaml",
"-i input.csv -o output.json -c path_na",
"-i input.csv -o output.json -c tmp_dir/",
# output if passed should not point to dir
Expand Down
Loading

0 comments on commit 3e02cec

Please sign in to comment.