From fc41eb4ce311334a43ccbfc2d4332ca784ac4a61 Mon Sep 17 00:00:00 2001 From: benjijamorris <54606172+benjijamorris@users.noreply.github.com> Date: Fri, 8 Mar 2024 16:06:45 -0800 Subject: [PATCH] threshold on gpu (#348) Co-authored-by: Benjamin Morris --- cyto_dl/models/im2im/utils/instance_seg.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/cyto_dl/models/im2im/utils/instance_seg.py b/cyto_dl/models/im2im/utils/instance_seg.py index c3b8e6f32..9223321ff 100644 --- a/cyto_dl/models/im2im/utils/instance_seg.py +++ b/cyto_dl/models/im2im/utils/instance_seg.py @@ -504,12 +504,11 @@ def cluster_object(self, semantic, skel, embedding): return out def __call__(self, image): - image = image.detach().cpu().half().numpy() + image = image.detach().half() + naive_labeling, _ = label((image[1] > self.semantic_threshold).cpu()) + skel = image[0].cpu().numpy() + embedding = image[2 : 2 + self.dim].cpu().numpy() - skel = image[0] - naive_labeling, _ = label(image[1] > self.semantic_threshold) - - embedding = image[2 : 2 + self.dim] regions = enumerate(find_objects(naive_labeling), start=1) highest_cell_idx = 0