From b4b10f0a9d0d3e9946af914aef655e2b60a29c20 Mon Sep 17 00:00:00 2001 From: Yihao Liu Date: Tue, 8 Aug 2023 19:25:17 -0400 Subject: [PATCH] add bbox support --- Dev.ipynb | 80 ++++- README.md | 2 - samm-python-terminal/sam_server.py | 8 +- samm-python-terminal/utl_sam_msg.py | 9 +- samm-python-terminal/utl_sam_server.py | 43 ++- samm/SammBase/Resources/UI/SammBase.ui | 357 +++++++++++++------ samm/SammBase/SammBaseLib/LogicSamm.py | 32 +- samm/SammBase/SammBaseLib/UtilConnections.py | 2 +- samm/SammBase/SammBaseLib/UtilMsgFactory.py | 9 +- samm/SammBase/SammBaseLib/WidgetSammBase.py | 19 +- 10 files changed, 419 insertions(+), 142 deletions(-) diff --git a/Dev.ipynb b/Dev.ipynb index 751f5f6..c5da79c 100644 --- a/Dev.ipynb +++ b/Dev.ipynb @@ -26,6 +26,46 @@ "slicer.util.selectModule(\"SammBase\")" ] }, + { + "cell_type": "code", + "execution_count": 19, + "id": "642430aa-9f67-4af0-940f-540047fe92a3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "vtkPoints (0x62a4f80)\n", + " Debug: Off\n", + " Modified Time: 1776089\n", + " Reference Count: 1\n", + " Registered Events: (none)\n", + " Data: 0x102c7160\n", + " Data Array Name: Points\n", + " Number Of Points: 4\n", + " Bounds: \n", + " Xmin,Xmax: (-32.0076, 95.4545)\n", + " Ymin,Ymax: (-34.8485, 63.2576)\n", + " Zmin,Zmax: (0, 0)\n", + "\n", + "\n", + "0.0\n" + ] + } + ], + "source": [ + "import vtk\n", + "\n", + "logic = SammBaseLogic()\n", + "logic._parameterNode.SetNodeReferenceID(\"sammPrompt2DBox\", \"vtkMRMLMarkupsPlaneNode1\")\n", + "plane = logic._parameterNode.GetNodeReference(\"sammPrompt2DBox\")\n", + "points = vtk.vtkPoints() \n", + "plane.GetPlaneCornerPoints(points)\n", + "print(points)\n", + "print(points.GetPoint(0)[2])\n" + ] + }, { "cell_type": "code", "execution_count": 3, @@ -643,17 +683,49 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "92be4fc6-f838-47d0-ae4e-47c2598eb353", "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "data": { + "text/plain": [ + "array([1, 2, 3, 4])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "np.array([1,2,3,4]).reshape([1,4])[0]" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "id": "4e83cab3-759b-4bf4-a2ff-2f0134331451", "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "False\n" + ] + } + ], + "source": [ + "print([]==None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "512d3635-580d-443a-b011-f51e4bc4f952", + "metadata": {}, "outputs": [], "source": [] } diff --git a/README.md b/README.md index 2ecd12e..8a2561c 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,6 @@ [Laboratory of Biomechanical and Image Guided Surgical Systems](https://bigss.lcsr.jhu.edu/), [Johns Hopkins University](https://www.jhu.edu/) -[Yihao Liu](https://yihao.one/), [Jeremy Zhang](https://jeremyzz830.github.io/), Zhangcong She - ## Known issues #13 Resolved with a not-so-elegant way. Tested work on MRHead, DZ-MR, MRBrainTumor2 and CTChest data from Slicer. Please report a bug if your data is not working. Note for the first few seconds when you start "Mask Sync", the server is not so stable, wait a few seconds then slide it up and down, the mask then will be updated. Note now you can only work on the RED view. Will update later to support all 3 views. diff --git a/samm-python-terminal/sam_server.py b/samm-python-terminal/sam_server.py index 2ff4df4..1ee64a7 100644 --- a/samm-python-terminal/sam_server.py +++ b/samm-python-terminal/sam_server.py @@ -68,10 +68,10 @@ def looping(self): self.cleanup() except zmq.error.Again: continue - except Exception as e: - logging.error(traceback.format_exc()) - time.sleep(self.interv) - continue + # except Exception as e: + # logging.error(traceback.format_exc()) + # time.sleep(self.interv) + # continue time.sleep(self.interv) diff --git a/samm-python-terminal/utl_sam_msg.py b/samm-python-terminal/utl_sam_msg.py index 504b0e9..30234af 100644 --- a/samm-python-terminal/utl_sam_msg.py +++ b/samm-python-terminal/utl_sam_msg.py @@ -109,6 +109,8 @@ def getDecodedData(msgbyte): ''' n : int +view : char +bbox : int 32, 4 positivePrompts, int 32, n * 2 negativePrompts, int 32 ''' @@ -118,12 +120,14 @@ def getEncodedData(self): n = self.msg["n"] view = SammViewMapper[self.msg["view"]] + bbox2D = self.msg["bbox2D"] positivePoints = self.msg["positivePrompts"] negativePoints = self.msg["negativePrompts"] msg = b'' msg += np.array([n], dtype='int32').tobytes() msg += np.array([view], dtype='int32').tobytes() + msg += bbox2D.astype("int32").tobytes() if positivePoints is not None and positivePoints.shape[0] > 0: msg += np.array([positivePoints.shape[0]], dtype='int32').tobytes() @@ -150,6 +154,9 @@ def getDecodedData(msgbyte): msg["view"] = SammViewMapper["DICT"][np.frombuffer(msgbyte[pt:pt+4], dtype="int32").reshape([1])[0]] pt += 4 + msg["bbox2D"] = np.frombuffer(msgbyte[pt:pt+4*4], dtype="int32").reshape([1,4])[0] + pt += 4*4 + positivePromptNum = np.frombuffer(msgbyte[pt:pt+4], dtype="int32").reshape([1])[0] pt += 4 @@ -164,7 +171,7 @@ def getDecodedData(msgbyte): pt += 4 if negativePromptNum > 0: - negativePrompt = np.frombuffer(msgbyte[pt:pt+4*2*negativePromptNum], dtype="int32").reshape([positivePromptNum, 2]) + negativePrompt = np.frombuffer(msgbyte[pt:pt+4*2*negativePromptNum], dtype="int32").reshape([negativePromptNum, 2]) pt += 4 * 2 * negativePromptNum msg["negativePrompts"] = negativePrompt else: diff --git a/samm-python-terminal/utl_sam_server.py b/samm-python-terminal/utl_sam_server.py index ae72200..d8ee7ef 100644 --- a/samm-python-terminal/utl_sam_server.py +++ b/samm-python-terminal/utl_sam_server.py @@ -158,10 +158,21 @@ def CalculateEmbeddings(msg): print("[SAMM INFO] Embeddings Cached.") +def helperPredict(dataNode, msg, points, labels, bbox2d): + dataNode.samPredictor[msg["view"]].features = dataNode.features[msg["view"]][msg["n"]].to("cuda") + seg, _, _ = dataNode.samPredictor[msg["view"]].predict( + point_coords = points, + point_labels = labels, + box = bbox2d, + multimask_output = False,) + seg = seg[0] + return seg + def sammProcessingCallBack_INFERENCE(msg): dataNode = SammParameterNode() positivePoints = msg["positivePrompts"] negativePoints = msg["negativePrompts"] + bbox2d = msg["bbox2D"] points = [] labels = [] @@ -176,20 +187,26 @@ def sammProcessingCallBack_INFERENCE(msg): labels.append(0) seg = None - if len(points) > 0: - if msg["view"] == "R": - tempsize = [dataNode.imageSize[1], dataNode.imageSize[2]] - if msg["view"] == "G": - tempsize = [dataNode.imageSize[0], dataNode.imageSize[2]] - if msg["view"] == "Y": - tempsize = [dataNode.imageSize[0], dataNode.imageSize[1]] + if len(points) > 0 and (bbox2d[0]!=-404): + + points = np.array(points) + point_labels = np.array(labels) + bbox2d = np.array(bbox2d) + seg = helperPredict(dataNode, msg, points, point_labels, bbox2d) + + elif len(points) == 0 and (bbox2d[0]!=-404): + + points = None + point_labels = None + bbox2d = np.array(bbox2d) + seg = helperPredict(dataNode, msg, points, point_labels, bbox2d) + + elif len(points) > 0 and (bbox2d[0]==-404): - dataNode.samPredictor[msg["view"]].features = dataNode.features[msg["view"]][msg["n"]].to("cuda") - seg, _, _ = dataNode.samPredictor[msg["view"]].predict( - point_coords = np.array(points), - point_labels = np.array(labels), - multimask_output = False,) - seg = seg[0] + points = np.array(points) + point_labels = np.array(labels) + bbox2d = None + seg = helperPredict(dataNode, msg, points, point_labels, bbox2d) else: if msg["view"] == "R": diff --git a/samm/SammBase/Resources/UI/SammBase.ui b/samm/SammBase/Resources/UI/SammBase.ui index 9fc52f4..81214c4 100644 --- a/samm/SammBase/Resources/UI/SammBase.ui +++ b/samm/SammBase/Resources/UI/SammBase.ui @@ -7,10 +7,34 @@ 0 0 416 - 854 + 1055 + + + + 0 + + + + SAM and Variants + + + + + + + + + Model Selection + + + + + + + @@ -138,17 +162,190 @@ - - - - true + + + + Unfreeze Slice + + + + + + 0 + + + + Points + + + + + + + 0 + 30 + + + + Add + + + + + + + false + + + + 0 + 255 + 0 + + + + + 0 + 255 + 0 + + + + + + + + + 0 + 30 + + + + Remove + + + + + + + false + + + + 255 + 0 + 0 + + + + + + + + + 2D Box + + + + + + Bounding Box + + + + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + true + + + + vtkMRMLMarkupsPlaneNode + + + + + + + + + + + + + + Add a BBox + + + + + + + + 3D Box + + + + + + Bounding Box + + + + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + false + + + + vtkMRMLMarkupsROINode + + + + + + + + + + + + + + + + - Green + Stop Mask Sync - - buttonGroupWorkOn - @@ -164,20 +361,6 @@ - - - - Start Mask Sync - - - - - - - Stop Mask Sync - - - @@ -185,71 +368,23 @@ - - - - Unfreeze Slice - - - - - - - - 0 - 30 - + + + + true - Prompt - Add - - - - - - - false - - - - 0 - 255 - 0 - - - - - 0 - 255 - 0 - + Green + + buttonGroupWorkOn + - - - - - 0 - 30 - - + + - Prompt - Remove - - - - - - - false - - - - 255 - 0 - 0 - + Start Mask Sync @@ -391,30 +526,6 @@ - - - - 0 - - - - SAM and Variants - - - - - - - - - Model Selection - - - - - - - @@ -512,9 +623,41 @@ + + SammBase + mrmlSceneChanged(vtkMRMLScene*) + markups2DBox + setMRMLScene(vtkMRMLScene*) + + + 207 + 527 + + + 207 + 679 + + + + + SammBase + mrmlSceneChanged(vtkMRMLScene*) + markups3DBox + setMRMLScene(vtkMRMLScene*) + + + 207 + 527 + + + 207 + 700 + + + - + diff --git a/samm/SammBase/SammBaseLib/LogicSamm.py b/samm/SammBase/SammBaseLib/LogicSamm.py index b60c1df..b87287b 100644 --- a/samm/SammBase/SammBaseLib/LogicSamm.py +++ b/samm/SammBase/SammBaseLib/LogicSamm.py @@ -205,6 +205,7 @@ def processInitPromptSync(self): # Init self._prompt_add = self._parameterNode.GetNodeReference("sammPromptAdd") self._prompt_remove = self._parameterNode.GetNodeReference("sammPromptRemove") + self._prompt_2dbox = self._parameterNode.GetNodeReference("sammPrompt2DBox") if self._parameterNode.GetParameter("sammViewOptions") == "RED": self._slider = \ @@ -250,7 +251,9 @@ def utilGetCurrentSliceIndex(self): imshape = (self._parameterNode._volMetaData[0][0], self._parameterNode._volMetaData[0][1]) return curslc, view, imshape - def utilGetPositionOnSlicer(self, temp, view): + def utilGetPositionOnSlice(self, ras, view): + temp = self._volumeRasToIjk.MultiplyPoint([ras[0],ras[1],ras[2],1]) + temp = np.array([temp[0], temp[1], temp[2]])[self._parameterNode.RGYNpArrOrder] if view == "RED": return [round(temp[1]),round(temp[2])] if view == "GREEN": @@ -268,7 +271,7 @@ def processStartPromptSync(self): if self._flag_prompt_sync: curslc, view, imshape = self.utilGetCurrentSliceIndex() - curslc = int(curslc) + curslc = round(curslc) mask = None @@ -278,21 +281,34 @@ def processStartPromptSync(self): for i in range(numControlPoints): ras = vtk.vtkVector3d(0.0, 0.0, 0.0) self._prompt_add.GetNthControlPointPosition(i,ras) - temp = self._volumeRasToIjk.MultiplyPoint([ras[0],ras[1],ras[2],1]) - temp = np.array([temp[0], temp[1], temp[2]])[self._parameterNode.RGYNpArrOrder] - prompt_add_point.append(self.utilGetPositionOnSlicer(temp, view)) + prompt_add_point.append(self.utilGetPositionOnSlice(ras, view)) numControlPoints = self._prompt_remove.GetNumberOfControlPoints() for i in range(numControlPoints): ras = vtk.vtkVector3d(0.0, 0.0, 0.0) self._prompt_remove.GetNthControlPointPosition(i,ras) - temp = self._volumeRasToIjk.MultiplyPoint([ras[0],ras[1],ras[2],1]) - temp = np.array([temp[0], temp[1], temp[2]])[self._parameterNode.RGYNpArrOrder] - prompt_remove_point.append(self.utilGetPositionOnSlicer(temp, view)) + prompt_remove_point.append(self.utilGetPositionOnSlice(ras, view)) + + plane = self._parameterNode.GetNodeReference("sammPrompt2DBox") + + if plane: + points = vtk.vtkPoints() + plane.GetPlaneCornerPoints(points) + ras = [points.GetPoint(0)[0],points.GetPoint(0)[1],points.GetPoint(0)[2]] + bboxmin = self.utilGetPositionOnSlice(ras, view) + ras = [points.GetPoint(2)[0],points.GetPoint(2)[1],points.GetPoint(2)[2]] + bboxmax = self.utilGetPositionOnSlice(ras, view) + if bboxmin[0] > bboxmax[0]: + bboxmin_ = [bboxmin[0], bboxmin[1]] + bboxmin = [bboxmax[0], bboxmax[1]] + bboxmax = [bboxmin_[0], bboxmin_[1]] + else: + bboxmin, bboxmax = [-404,-404], [-404,-404] mask = self._connections.pushRequest(SammMsgType.INFERENCE, { "n" : curslc, "view" : self._parameterNode.GetParameter("sammViewOptions")[0], + "bbox2D" : np.array([bboxmin[0], bboxmin[1], bboxmax[0], bboxmax[1]]), "positivePrompts" : np.array(prompt_add_point), "negativePrompts" : np.array(prompt_remove_point) }) diff --git a/samm/SammBase/SammBaseLib/UtilConnections.py b/samm/SammBase/SammBaseLib/UtilConnections.py index fcd8f14..a8c31ce 100644 --- a/samm/SammBase/SammBaseLib/UtilConnections.py +++ b/samm/SammBase/SammBaseLib/UtilConnections.py @@ -34,7 +34,7 @@ def pushRequest(self, requireType, MSG): msgByte = SammMsgSolverMapper[requireType](MSG).getEncodedData() sock = self.context.socket(zmq.REQ) # if no receiption, try extending the wait time. The first setup time takes longer - sock.setsockopt(zmq.RCVTIMEO, 20000) + sock.setsockopt(zmq.RCVTIMEO, 10000) sock.connect("tcp://%s:%s" % (self.ip, self.portControl)) sock.send_multipart([commandByte, msgByte]) feedback = sock.recv() diff --git a/samm/SammBase/SammBaseLib/UtilMsgFactory.py b/samm/SammBase/SammBaseLib/UtilMsgFactory.py index 504b0e9..30234af 100644 --- a/samm/SammBase/SammBaseLib/UtilMsgFactory.py +++ b/samm/SammBase/SammBaseLib/UtilMsgFactory.py @@ -109,6 +109,8 @@ def getDecodedData(msgbyte): ''' n : int +view : char +bbox : int 32, 4 positivePrompts, int 32, n * 2 negativePrompts, int 32 ''' @@ -118,12 +120,14 @@ def getEncodedData(self): n = self.msg["n"] view = SammViewMapper[self.msg["view"]] + bbox2D = self.msg["bbox2D"] positivePoints = self.msg["positivePrompts"] negativePoints = self.msg["negativePrompts"] msg = b'' msg += np.array([n], dtype='int32').tobytes() msg += np.array([view], dtype='int32').tobytes() + msg += bbox2D.astype("int32").tobytes() if positivePoints is not None and positivePoints.shape[0] > 0: msg += np.array([positivePoints.shape[0]], dtype='int32').tobytes() @@ -150,6 +154,9 @@ def getDecodedData(msgbyte): msg["view"] = SammViewMapper["DICT"][np.frombuffer(msgbyte[pt:pt+4], dtype="int32").reshape([1])[0]] pt += 4 + msg["bbox2D"] = np.frombuffer(msgbyte[pt:pt+4*4], dtype="int32").reshape([1,4])[0] + pt += 4*4 + positivePromptNum = np.frombuffer(msgbyte[pt:pt+4], dtype="int32").reshape([1])[0] pt += 4 @@ -164,7 +171,7 @@ def getDecodedData(msgbyte): pt += 4 if negativePromptNum > 0: - negativePrompt = np.frombuffer(msgbyte[pt:pt+4*2*negativePromptNum], dtype="int32").reshape([positivePromptNum, 2]) + negativePrompt = np.frombuffer(msgbyte[pt:pt+4*2*negativePromptNum], dtype="int32").reshape([negativePromptNum, 2]) pt += 4 * 2 * negativePromptNum msg["negativePrompts"] = negativePrompt else: diff --git a/samm/SammBase/SammBaseLib/WidgetSammBase.py b/samm/SammBase/SammBaseLib/WidgetSammBase.py index cb72f86..65db2e2 100644 --- a/samm/SammBase/SammBaseLib/WidgetSammBase.py +++ b/samm/SammBase/SammBaseLib/WidgetSammBase.py @@ -67,6 +67,8 @@ def setup(self): self.ui.markupsAdd.connect("markupsNodeChanged()", self.updateParameterNodeFromGUI) self.ui.markupsRemove.connect("markupsNodeChanged()", self.updateParameterNodeFromGUI) + self.ui.markups2DBox.connect("markupsNodeChanged()", self.updateParameterNodeFromGUI) + self.ui.pushMarkups2DBox.connect("clicked(bool)", self.onPushMarkups2DBox) self.ui.markupsAdd.markupsPlaceWidget().setPlaceModePersistency(True) self.ui.markupsRemove.markupsPlaceWidget().setPlaceModePersistency(True) @@ -92,6 +94,7 @@ def updateGUIFromParameterNode(self, caller=None, event=None): self.ui.comboVolumeNode.setCurrentNode(self._parameterNode.GetNodeReference("sammInputVolume")) self.ui.markupsAdd.setCurrentNode(self._parameterNode.GetNodeReference("sammPromptAdd")) self.ui.markupsRemove.setCurrentNode(self._parameterNode.GetNodeReference("sammPromptRemove")) + self.ui.markups2DBox.setCurrentNode(self._parameterNode.GetNodeReference("sammPrompt2DBox")) self.ui.comboSegmentationNode.setCurrentNode(self._parameterNode.GetNodeReference("sammSegmentation")) @@ -128,6 +131,8 @@ def updateParameterNodeFromGUI(self, caller=None, event=None): self._parameterNode.SetNodeReferenceID("sammInputVolume", self.ui.comboVolumeNode.currentNodeID) self._parameterNode.SetNodeReferenceID("sammPromptAdd", self.ui.markupsAdd.currentNode().GetID()) self._parameterNode.SetNodeReferenceID("sammPromptRemove", self.ui.markupsRemove.currentNode().GetID()) + if self.ui.markups2DBox.currentNode(): + self._parameterNode.SetNodeReferenceID("sammPrompt2DBox", self.ui.markups2DBox.currentNode().GetID()) self._parameterNode._workspace = os.path.dirname(os.path.abspath(self.ui.pathWorkSpace.currentPath.strip())) self._parameterNode.SetNodeReferenceID("sammSegmentation", self.ui.comboSegmentationNode.currentNodeID) self._parameterNode.GetNodeReference("sammSegmentation").SetReferenceImageGeometryParameterFromVolumeNode( @@ -203,4 +208,16 @@ def onPushModuleSeg(self): slicer.util.selectModule("Segmentations") def onPushModuleSegEditor(self): - slicer.util.selectModule("SegmentEditor") \ No newline at end of file + slicer.util.selectModule("SegmentEditor") + + def onPushMarkups2DBox(self): + planeNode = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLMarkupsPlaneNode').GetID() + selectionNode = slicer.mrmlScene.GetNodeByID("vtkMRMLSelectionNodeSingleton") + selectionNode.SetReferenceActivePlaceNodeID(planeNode) + interactionNode = slicer.mrmlScene.GetNodeByID("vtkMRMLInteractionNodeSingleton") + placeModePersistence = 0 + interactionNode.SetPlaceModePersistence(placeModePersistence) + # mode 1 is Place, can also be accessed via slicer.vtkMRMLInteractionNode().Place + interactionNode.SetCurrentInteractionMode(1) + + self._parameterNode.SetNodeReferenceID("sammPrompt2DBox", planeNode) \ No newline at end of file