Skip to content

Commit

Permalink
test gpu
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli committed Jul 4, 2023
1 parent 9a28579 commit c774e92
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
34 changes: 19 additions & 15 deletions monai/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
MetricReduction,
convert_data_type,
convert_to_tensor,
convert_to_numpy,
ensure_tuple_rep,
look_up_option,
optional_import,
Expand Down Expand Up @@ -165,38 +166,41 @@ def get_mask_edges(
seg_pred = seg_pred == label_idx
if seg_gt.dtype not in (bool, torch.bool):
seg_gt = seg_gt == label_idx

if crop:
if not (seg_pred | seg_gt).any():
pred, gt = np.zeros_like(seg_pred), np.zeros_like(seg_gt)
or_vol = seg_pred | seg_gt
if not or_vol.any():
pred, gt = np.zeros(seg_pred.shape, dtype=bool), np.zeros(seg_gt.shape, dtype=bool)
return (pred, gt) if spacing is None else (pred, gt, pred, gt) # type: ignore
channel_first = [seg_pred[None], seg_gt[None], or_vol[None]]
if spacing is None: # cpu only erosion
seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, device="cpu", dtype=bool)
else: # pytorch subvoxel, maybe on gpu, but croppad boolean values on GPU is not supported
seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, dtype=torch.float16)
cropper = CropForegroundD(
["pred", "gt"], source_key="src", margin=1, allow_smaller=True, start_coord_key=None, end_coord_key=None
)
mask = seg_pred | seg_gt
cropped = cropper({"pred": seg_pred[None], "gt": seg_gt[None], "src": mask[None]}) # type: ignore
seg_pred = cropped["pred"][0]
seg_gt = cropped["gt"][0]

if spacing is None:
# Do binary erosion and use XOR to get edges
seg_pred = convert_data_type(seg_pred, np.ndarray)[0]
seg_gt = convert_data_type(seg_gt, np.ndarray)[0]
cropped = cropper({"pred": seg_pred, "gt": seg_gt, "src": or_vol}) # type: ignore
seg_pred, seg_gt = cropped["pred"][0], cropped["gt"][0]

if spacing is None: # Do binary erosion and use XOR to get edges
seg_pred, seg_gt = convert_to_numpy([seg_pred, seg_gt], dtype=bool)
edges_pred = binary_erosion(seg_pred) ^ seg_pred
edges_gt = binary_erosion(seg_gt) ^ seg_gt
return edges_pred, edges_gt
code_to_area_table, k = get_code_to_measure_table(spacing)
code_to_area_table, k = get_code_to_measure_table(spacing, device=seg_pred.device)
spatial_dims = len(spacing)
conv = torch.nn.functional.conv3d if spatial_dims == 3 else torch.nn.functional.conv2d
code_pred, code_gt = conv(torch.stack([seg_pred[None], seg_gt[None]], dim=0).float(), k.float()) # type: ignore
vol = torch.stack([seg_pred[None], seg_gt[None]], dim=0).float()
code_pred, code_gt = conv(vol, k.to(vol)) # type: ignore
# edges
all_ones = len(code_to_area_table) - 1
edges_pred = (code_pred != 0) & (code_pred != all_ones)
edges_gt = (code_gt != 0) & (code_gt != all_ones)
# areas of edges
areas_pred = torch.index_select(code_to_area_table, 0, code_pred.view(-1).int()).reshape(code_pred.shape)
areas_gt = torch.index_select(code_to_area_table, 0, code_gt.view(-1).int()).reshape(code_gt.shape)
return edges_pred.array[0], edges_gt.array[0], areas_pred.array[0], areas_gt.array[0] # type: ignore
ret = (edges_pred[0], edges_gt[0], areas_pred[0], areas_gt[0])
return convert_to_numpy(ret, wrap_sequence=False)


def get_surface_distance(
Expand Down
3 changes: 2 additions & 1 deletion tests/test_surface_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,8 @@ def test_compute_surface_dice_subvoxel(self):
)
assert_allclose(res, 0.5, type_test=False)

mask_gt, mask_pred = torch.zeros(1, 1, 100, 100, 100), torch.zeros(1, 1, 100, 100, 100)
d = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mask_gt, mask_pred = torch.zeros(1, 1, 100, 100, 100, device=d), torch.zeros(1, 1, 100, 100, 100, device=d)
mask_gt[0, 0, 0:50, :, :] = 1
mask_pred[0, 0, 0:51, :, :] = 1
res = compute_surface_dice(
Expand Down

0 comments on commit c774e92

Please sign in to comment.