Skip to content

Commit

Permalink
lint fix; add weights phase classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Mar 12, 2024
1 parent de8cc53 commit 554df4d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
Binary file added resources/contrast_phase_classifiers.pkl
Binary file not shown.
28 changes: 14 additions & 14 deletions totalsegmentator/bin/totalseg_get_phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
def pi_time_to_phase(pi_time: float) -> str:
"""
Convert the pi time to a phase and get a probability for the value.
native: 0-10
arterial_early: 10-30
arterial_late: 30-50
portal_venous: 50-100
delayed: 100+
returns: phase, probability
"""
if pi_time < 5:
Expand All @@ -43,41 +43,41 @@ def pi_time_to_phase(pi_time: float) -> str:
return "portal_venous", 0.7
else:
return "delayed", 0.7


def get_ct_contrast_phase(ct_img: nib.Nifti1Image):

organs = ["liver", "spleen", "kidney_left", "kidney_right", "pancreas", "urinary_bladder", "gallbladder",
"heart", "aorta", "inferior_vena_cava", "portal_vein_and_splenic_vein",
"iliac_vena_left", "iliac_vena_right", "iliac_artery_left", "iliac_artery_right",
"pulmonary_vein"]

seg_img, stats = totalsegmentator(ct_img, None, ml=True, fast=True, statistics=True,
roi_subset=None, quiet=False)

features = []
for organ in organs:
features.append(stats[organ]["intensity"])
# todo: adapt
# classifier_path = Path(__file__).parent / "classifier.pkl"
classifier_path = "/mnt/nvme/data/phase_classification/classifiers.pkl"

# weights from longitudinalliver dataset
classifier_path = Path(__file__).parents[2] / "resources" / "contrast_phase_classifiers.pkl"
# classifier_path = "/mnt/nvme/data/phase_classification/classifiers.pkl"
clfs = pickle.load(open(classifier_path, "rb"))

# ensemble across folds
preds = []
for fold, clf in clfs.items():
preds.append(clf.predict([features])[0])
preds = np.array(preds)
pi_time = round(float(np.mean(preds)), 2)
pi_time_std = round(float(np.std(preds)), 4)

print("Ensemble res:")
print(preds)
# print(f"mean: {pi_time} +/- {pi_time_std}")
print(f"mean: {pi_time} [{preds.min():.1f}-{preds.max():.1f}]")
phase, probability = pi_time_to_phase(pi_time)

return {"pi_time": pi_time, "phase": phase, "probability": probability}


Expand All @@ -101,10 +101,10 @@ def main():
args = parser.parse_args()

res = get_ct_contrast_phase(nib.load(args.input_file))

print("Result:")
pprint(res)

with open(args.output_file, "w") as f:
f.write(json.dumps(res, indent=4))

Expand Down

0 comments on commit 554df4d

Please sign in to comment.