Skip to content

Commit

Permalink
add option to remove small connected components in postprocesssing
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Aug 22, 2024
1 parent c90b5d9 commit 4c507d9
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
## Master
* add option to remove small connected components in postprocessing


## Release 2.4.0
Expand Down
6 changes: 5 additions & 1 deletion totalsegmentator/bin/TotalSegmentator.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ def main():
help="In multilabel file order classes as in v1. New v2 classes will be removed.",
default=False)

parser.add_argument("-rmb", "--remove_small_blobs", action="store_true", help="Remove small connected components (<0.2ml) from the final segmentations.",
default=False) # ~30s runtime because of the large number of classes

# "mps" is for apple silicon; the latest pytorch nightly version supports 3D Conv but not ConvTranspose3D which is
# also needed by nnU-Net. So "mps" not working for now.
# https://github.com/pytorch/pytorch/issues/77818
Expand Down Expand Up @@ -140,7 +143,8 @@ def main():
args.statistics, args.radiomics, args.crop_path, args.body_seg,
args.force_split, args.output_type, args.quiet, args.verbose, args.test, args.skip_saving,
args.device, args.license_number, not args.stats_include_incomplete,
args.no_derived_masks, args.v1_order, args.fastest, args.roi_subset_robust)
args.no_derived_masks, args.v1_order, args.fastest, args.roi_subset_robust,
"mean", args.remove_small_blobs)


if __name__ == "__main__":
Expand Down
18 changes: 15 additions & 3 deletions totalsegmentator/nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
crop_addon=[3,3,3], roi_subset=None, output_type="nifti",
statistics=False, quiet=False, verbose=False, test=0, skip_saving=False,
device="cuda", exclude_masks_at_border=True, no_derived_masks=False,
v1_order=False, stats_aggregation="mean"):
v1_order=False, stats_aggregation="mean", remove_small_blobs=False):
"""
crop: string or a nibabel image
resample: None or float (target spacing for all dimensions) or list of floats
Expand Down Expand Up @@ -527,11 +527,23 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_

if task_name == "body":
vox_vol = np.prod(img_pred.header.get_zooms())
size_thr_mm3 = 50000 / vox_vol
size_thr_mm3 = 50000
img_pred_pp = remove_small_blobs_multilabel(img_pred.get_fdata().astype(np.uint8),
class_map[task_name], ["body_extremities"],
interval=[size_thr_mm3, 1e10], debug=False, quiet=quiet)
interval=[size_thr_mm3/vox_vol, 1e10], debug=False, quiet=quiet)
img_pred = nib.Nifti1Image(img_pred_pp, img_pred.affine)

# General postprocessing
if remove_small_blobs:
if not quiet: print("Removing small blobs...")
st = time.time()
vox_vol = np.prod(img_pred.header.get_zooms())
size_thr_mm3 = 200
img_pred_pp = remove_small_blobs_multilabel(img_pred.get_fdata().astype(np.uint8),
class_map[task_name], list(class_map[task_name].values()),
interval=[size_thr_mm3/vox_vol, 1e10], debug=False, quiet=quiet) # ~24s
img_pred = nib.Nifti1Image(img_pred_pp, img_pred.affine)
if not quiet: print(f" Removed in {time.time() - st:.2f}s")

if preview:
from totalsegmentator.preview import generate_preview
Expand Down
5 changes: 3 additions & 2 deletions totalsegmentator/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
force_split=False, output_type="nifti", quiet=False, verbose=False, test=0,
skip_saving=False, device="gpu", license_number=None,
statistics_exclude_masks_at_border=True, no_derived_masks=False,
v1_order=False, fastest=False, roi_subset_robust=None, stats_aggregation="mean"):
v1_order=False, fastest=False, roi_subset_robust=None, stats_aggregation="mean",
remove_small_blobs=False):
"""
Run TotalSegmentator from within python.
Expand Down Expand Up @@ -473,7 +474,7 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
quiet=quiet, verbose=verbose, test=test, skip_saving=skip_saving, device=device,
exclude_masks_at_border=statistics_exclude_masks_at_border,
no_derived_masks=no_derived_masks, v1_order=v1_order,
stats_aggregation=stats_aggregation)
stats_aggregation=stats_aggregation, remove_small_blobs=remove_small_blobs)
seg = seg_img.get_fdata().astype(np.uint8)

try:
Expand Down
8 changes: 4 additions & 4 deletions totalsegmentator/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ def touches_border(mask):
"""
Check if mask touches any of the borders. Then we do not calc any statistics for it because the mask
is incomplete.
Do not check last slice by one previous, because segmentation on last slice often bad.
Do not check last two slices but the previous one, because segmentation on last slices often bad.
"""
if np.any(mask[1, :, :]) or np.any(mask[-2, :, :]):
if np.any(mask[2, :, :]) or np.any(mask[-3, :, :]):
return True
if np.any(mask[:, 1, :]) or np.any(mask[:, -2, :]):
if np.any(mask[:, 2, :]) or np.any(mask[:, -3, :]):
return True
if np.any(mask[:, :, 1]) or np.any(mask[:, :, -2]):
if np.any(mask[:, :, 2]) or np.any(mask[:, :, -3]):
return True
return False

Expand Down

0 comments on commit 4c507d9

Please sign in to comment.