From b4b10f0a9d0d3e9946af914aef655e2b60a29c20 Mon Sep 17 00:00:00 2001 From: Yihao Liu <yliu333@jhu.edu> Date: Tue, 8 Aug 2023 19:25:17 -0400 Subject: [PATCH] add bbox support --- Dev.ipynb | 80 ++++- README.md | 2 - samm-python-terminal/sam_server.py | 8 +- samm-python-terminal/utl_sam_msg.py | 9 +- samm-python-terminal/utl_sam_server.py | 43 ++- samm/SammBase/Resources/UI/SammBase.ui | 357 +++++++++++++------ samm/SammBase/SammBaseLib/LogicSamm.py | 32 +- samm/SammBase/SammBaseLib/UtilConnections.py | 2 +- samm/SammBase/SammBaseLib/UtilMsgFactory.py | 9 +- samm/SammBase/SammBaseLib/WidgetSammBase.py | 19 +- 10 files changed, 419 insertions(+), 142 deletions(-) diff --git a/Dev.ipynb b/Dev.ipynb index 751f5f6..c5da79c 100644 --- a/Dev.ipynb +++ b/Dev.ipynb @@ -26,6 +26,46 @@ "slicer.util.selectModule(\"SammBase\")" ] }, + { + "cell_type": "code", + "execution_count": 19, + "id": "642430aa-9f67-4af0-940f-540047fe92a3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "vtkPoints (0x62a4f80)\n", + " Debug: Off\n", + " Modified Time: 1776089\n", + " Reference Count: 1\n", + " Registered Events: (none)\n", + " Data: 0x102c7160\n", + " Data Array Name: Points\n", + " Number Of Points: 4\n", + " Bounds: \n", + " Xmin,Xmax: (-32.0076, 95.4545)\n", + " Ymin,Ymax: (-34.8485, 63.2576)\n", + " Zmin,Zmax: (0, 0)\n", + "\n", + "\n", + "0.0\n" + ] + } + ], + "source": [ + "import vtk\n", + "\n", + "logic = SammBaseLogic()\n", + "logic._parameterNode.SetNodeReferenceID(\"sammPrompt2DBox\", \"vtkMRMLMarkupsPlaneNode1\")\n", + "plane = logic._parameterNode.GetNodeReference(\"sammPrompt2DBox\")\n", + "points = vtk.vtkPoints() \n", + "plane.GetPlaneCornerPoints(points)\n", + "print(points)\n", + "print(points.GetPoint(0)[2])\n" + ] + }, { "cell_type": "code", "execution_count": 3, @@ -643,17 +683,49 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "92be4fc6-f838-47d0-ae4e-47c2598eb353", "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "data": { + "text/plain": [ + "array([1, 2, 3, 4])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "np.array([1,2,3,4]).reshape([1,4])[0]" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "id": "4e83cab3-759b-4bf4-a2ff-2f0134331451", "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "False\n" + ] + } + ], + "source": [ + "print([]==None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "512d3635-580d-443a-b011-f51e4bc4f952", + "metadata": {}, "outputs": [], "source": [] } diff --git a/README.md b/README.md index 2ecd12e..8a2561c 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,6 @@ [Laboratory of Biomechanical and Image Guided Surgical Systems](https://bigss.lcsr.jhu.edu/), [Johns Hopkins University](https://www.jhu.edu/) -[Yihao Liu](https://yihao.one/), [Jeremy Zhang](https://jeremyzz830.github.io/), Zhangcong She - ## Known issues #13 Resolved with a not-so-elegant way. Tested work on MRHead, DZ-MR, MRBrainTumor2 and CTChest data from Slicer. Please report a bug if your data is not working. Note for the first few seconds when you start "Mask Sync", the server is not so stable, wait a few seconds then slide it up and down, the mask then will be updated. Note now you can only work on the RED view. Will update later to support all 3 views. diff --git a/samm-python-terminal/sam_server.py b/samm-python-terminal/sam_server.py index 2ff4df4..1ee64a7 100644 --- a/samm-python-terminal/sam_server.py +++ b/samm-python-terminal/sam_server.py @@ -68,10 +68,10 @@ def looping(self): self.cleanup() except zmq.error.Again: continue - except Exception as e: - logging.error(traceback.format_exc()) - time.sleep(self.interv) - continue + # except Exception as e: + # logging.error(traceback.format_exc()) + # time.sleep(self.interv) + # continue time.sleep(self.interv) diff --git a/samm-python-terminal/utl_sam_msg.py b/samm-python-terminal/utl_sam_msg.py index 504b0e9..30234af 100644 --- a/samm-python-terminal/utl_sam_msg.py +++ b/samm-python-terminal/utl_sam_msg.py @@ -109,6 +109,8 @@ def getDecodedData(msgbyte): ''' n : int +view : char +bbox : int 32, 4 positivePrompts, int 32, n * 2 negativePrompts, int 32 ''' @@ -118,12 +120,14 @@ def getEncodedData(self): n = self.msg["n"] view = SammViewMapper[self.msg["view"]] + bbox2D = self.msg["bbox2D"] positivePoints = self.msg["positivePrompts"] negativePoints = self.msg["negativePrompts"] msg = b'' msg += np.array([n], dtype='int32').tobytes() msg += np.array([view], dtype='int32').tobytes() + msg += bbox2D.astype("int32").tobytes() if positivePoints is not None and positivePoints.shape[0] > 0: msg += np.array([positivePoints.shape[0]], dtype='int32').tobytes() @@ -150,6 +154,9 @@ def getDecodedData(msgbyte): msg["view"] = SammViewMapper["DICT"][np.frombuffer(msgbyte[pt:pt+4], dtype="int32").reshape([1])[0]] pt += 4 + msg["bbox2D"] = np.frombuffer(msgbyte[pt:pt+4*4], dtype="int32").reshape([1,4])[0] + pt += 4*4 + positivePromptNum = np.frombuffer(msgbyte[pt:pt+4], dtype="int32").reshape([1])[0] pt += 4 @@ -164,7 +171,7 @@ def getDecodedData(msgbyte): pt += 4 if negativePromptNum > 0: - negativePrompt = np.frombuffer(msgbyte[pt:pt+4*2*negativePromptNum], dtype="int32").reshape([positivePromptNum, 2]) + negativePrompt = np.frombuffer(msgbyte[pt:pt+4*2*negativePromptNum], dtype="int32").reshape([negativePromptNum, 2]) pt += 4 * 2 * negativePromptNum msg["negativePrompts"] = negativePrompt else: diff --git a/samm-python-terminal/utl_sam_server.py b/samm-python-terminal/utl_sam_server.py index ae72200..d8ee7ef 100644 --- a/samm-python-terminal/utl_sam_server.py +++ b/samm-python-terminal/utl_sam_server.py @@ -158,10 +158,21 @@ def CalculateEmbeddings(msg): print("[SAMM INFO] Embeddings Cached.") +def helperPredict(dataNode, msg, points, labels, bbox2d): + dataNode.samPredictor[msg["view"]].features = dataNode.features[msg["view"]][msg["n"]].to("cuda") + seg, _, _ = dataNode.samPredictor[msg["view"]].predict( + point_coords = points, + point_labels = labels, + box = bbox2d, + multimask_output = False,) + seg = seg[0] + return seg + def sammProcessingCallBack_INFERENCE(msg): dataNode = SammParameterNode() positivePoints = msg["positivePrompts"] negativePoints = msg["negativePrompts"] + bbox2d = msg["bbox2D"] points = [] labels = [] @@ -176,20 +187,26 @@ def sammProcessingCallBack_INFERENCE(msg): labels.append(0) seg = None - if len(points) > 0: - if msg["view"] == "R": - tempsize = [dataNode.imageSize[1], dataNode.imageSize[2]] - if msg["view"] == "G": - tempsize = [dataNode.imageSize[0], dataNode.imageSize[2]] - if msg["view"] == "Y": - tempsize = [dataNode.imageSize[0], dataNode.imageSize[1]] + if len(points) > 0 and (bbox2d[0]!=-404): + + points = np.array(points) + point_labels = np.array(labels) + bbox2d = np.array(bbox2d) + seg = helperPredict(dataNode, msg, points, point_labels, bbox2d) + + elif len(points) == 0 and (bbox2d[0]!=-404): + + points = None + point_labels = None + bbox2d = np.array(bbox2d) + seg = helperPredict(dataNode, msg, points, point_labels, bbox2d) + + elif len(points) > 0 and (bbox2d[0]==-404): - dataNode.samPredictor[msg["view"]].features = dataNode.features[msg["view"]][msg["n"]].to("cuda") - seg, _, _ = dataNode.samPredictor[msg["view"]].predict( - point_coords = np.array(points), - point_labels = np.array(labels), - multimask_output = False,) - seg = seg[0] + points = np.array(points) + point_labels = np.array(labels) + bbox2d = None + seg = helperPredict(dataNode, msg, points, point_labels, bbox2d) else: if msg["view"] == "R": diff --git a/samm/SammBase/Resources/UI/SammBase.ui b/samm/SammBase/Resources/UI/SammBase.ui index 9fc52f4..81214c4 100644 --- a/samm/SammBase/Resources/UI/SammBase.ui +++ b/samm/SammBase/Resources/UI/SammBase.ui @@ -7,10 +7,34 @@ <x>0</x> <y>0</y> <width>416</width> - <height>854</height> + <height>1055</height> </rect> </property> <layout class="QGridLayout" name="gridLayout"> + <item row="1" column="0"> + <widget class="QTabWidget" name="tabWidget_3"> + <property name="currentIndex"> + <number>0</number> + </property> + <widget class="QWidget" name="tab_4"> + <attribute name="title"> + <string>SAM and Variants</string> + </attribute> + <layout class="QGridLayout" name="gridLayout_5"> + <item row="1" column="0"> + <widget class="QComboBox" name="comboModel"/> + </item> + <item row="0" column="0"> + <widget class="QLabel" name="label_11"> + <property name="text"> + <string>Model Selection</string> + </property> + </widget> + </item> + </layout> + </widget> + </widget> + </item> <item row="5" column="0"> <spacer name="verticalSpacer"> <property name="orientation"> @@ -138,17 +162,190 @@ </attribute> </widget> </item> - <item row="1" column="1"> - <widget class="QRadioButton" name="radioWorkOnGreen"> - <property name="enabled"> - <bool>true</bool> + <item row="4" column="1"> + <widget class="QPushButton" name="pushUnfreezeSlice"> + <property name="text"> + <string>Unfreeze Slice</string> </property> + </widget> + </item> + <item row="5" column="0" colspan="2"> + <widget class="QTabWidget" name="tabWidget_4"> + <property name="currentIndex"> + <number>0</number> + </property> + <widget class="QWidget" name="tab_5"> + <attribute name="title"> + <string>Points</string> + </attribute> + <layout class="QGridLayout" name="gridLayout_6"> + <item row="0" column="0"> + <widget class="QLabel" name="label_3"> + <property name="minimumSize"> + <size> + <width>0</width> + <height>30</height> + </size> + </property> + <property name="text"> + <string>Add</string> + </property> + </widget> + </item> + <item row="1" column="0"> + <widget class="qSlicerSimpleMarkupsWidget" name="markupsAdd"> + <property name="enterPlaceModeOnNodeChange"> + <bool>false</bool> + </property> + <property name="nodeColor"> + <color> + <red>0</red> + <green>255</green> + <blue>0</blue> + </color> + </property> + <property name="defaultNodeColor"> + <color> + <red>0</red> + <green>255</green> + <blue>0</blue> + </color> + </property> + </widget> + </item> + <item row="2" column="0"> + <widget class="QLabel" name="label_4"> + <property name="minimumSize"> + <size> + <width>0</width> + <height>30</height> + </size> + </property> + <property name="text"> + <string>Remove</string> + </property> + </widget> + </item> + <item row="3" column="0"> + <widget class="qSlicerSimpleMarkupsWidget" name="markupsRemove"> + <property name="enterPlaceModeOnNodeChange"> + <bool>false</bool> + </property> + <property name="defaultNodeColor"> + <color> + <red>255</red> + <green>0</green> + <blue>0</blue> + </color> + </property> + </widget> + </item> + </layout> + </widget> + <widget class="QWidget" name="tab_6"> + <attribute name="title"> + <string>2D Box</string> + </attribute> + <layout class="QGridLayout" name="gridLayout_7"> + <item row="0" column="0"> + <widget class="QLabel" name="label_12"> + <property name="text"> + <string>Bounding Box</string> + </property> + </widget> + </item> + <item row="2" column="0"> + <spacer name="verticalSpacer_3"> + <property name="orientation"> + <enum>Qt::Vertical</enum> + </property> + <property name="sizeHint" stdset="0"> + <size> + <width>20</width> + <height>40</height> + </size> + </property> + </spacer> + </item> + <item row="1" column="0"> + <widget class="qMRMLNodeComboBox" name="markups2DBox"> + <property name="enabled"> + <bool>true</bool> + </property> + <property name="nodeTypes"> + <stringlist notr="true"> + <string>vtkMRMLMarkupsPlaneNode</string> + </stringlist> + </property> + <property name="hideChildNodeTypes"> + <stringlist notr="true"/> + </property> + <property name="interactionNodeSingletonTag"> + <string notr="true"/> + </property> + </widget> + </item> + <item row="1" column="1"> + <widget class="QPushButton" name="pushMarkups2DBox"> + <property name="text"> + <string>Add a BBox</string> + </property> + </widget> + </item> + </layout> + </widget> + <widget class="QWidget" name="tab_7"> + <attribute name="title"> + <string>3D Box</string> + </attribute> + <layout class="QGridLayout" name="gridLayout_8"> + <item row="0" column="0"> + <widget class="QLabel" name="label_13"> + <property name="text"> + <string>Bounding Box</string> + </property> + </widget> + </item> + <item row="2" column="0"> + <spacer name="verticalSpacer_4"> + <property name="orientation"> + <enum>Qt::Vertical</enum> + </property> + <property name="sizeHint" stdset="0"> + <size> + <width>20</width> + <height>40</height> + </size> + </property> + </spacer> + </item> + <item row="1" column="0"> + <widget class="qMRMLNodeComboBox" name="markups3DBox"> + <property name="enabled"> + <bool>false</bool> + </property> + <property name="nodeTypes"> + <stringlist notr="true"> + <string>vtkMRMLMarkupsROINode</string> + </stringlist> + </property> + <property name="hideChildNodeTypes"> + <stringlist notr="true"/> + </property> + <property name="interactionNodeSingletonTag"> + <string notr="true"/> + </property> + </widget> + </item> + </layout> + </widget> + </widget> + </item> + <item row="3" column="1"> + <widget class="QPushButton" name="pushStopMaskSync"> <property name="text"> - <string>Green</string> + <string>Stop Mask Sync</string> </property> - <attribute name="buttonGroup"> - <string notr="true">buttonGroupWorkOn</string> - </attribute> </widget> </item> <item row="2" column="0"> @@ -164,20 +361,6 @@ </attribute> </widget> </item> - <item row="3" column="0"> - <widget class="QPushButton" name="pushStartMaskSync"> - <property name="text"> - <string>Start Mask Sync</string> - </property> - </widget> - </item> - <item row="3" column="1"> - <widget class="QPushButton" name="pushStopMaskSync"> - <property name="text"> - <string>Stop Mask Sync</string> - </property> - </widget> - </item> <item row="4" column="0"> <widget class="QPushButton" name="pushFreezeSlice"> <property name="text"> @@ -185,71 +368,23 @@ </property> </widget> </item> - <item row="4" column="1"> - <widget class="QPushButton" name="pushUnfreezeSlice"> - <property name="text"> - <string>Unfreeze Slice</string> - </property> - </widget> - </item> - <item row="5" column="0"> - <widget class="QLabel" name="label_3"> - <property name="minimumSize"> - <size> - <width>0</width> - <height>30</height> - </size> + <item row="1" column="1"> + <widget class="QRadioButton" name="radioWorkOnGreen"> + <property name="enabled"> + <bool>true</bool> </property> <property name="text"> - <string>Prompt - Add</string> - </property> - </widget> - </item> - <item row="6" column="0" colspan="2"> - <widget class="qSlicerSimpleMarkupsWidget" name="markupsAdd"> - <property name="enterPlaceModeOnNodeChange"> - <bool>false</bool> - </property> - <property name="nodeColor"> - <color> - <red>0</red> - <green>255</green> - <blue>0</blue> - </color> - </property> - <property name="defaultNodeColor"> - <color> - <red>0</red> - <green>255</green> - <blue>0</blue> - </color> + <string>Green</string> </property> + <attribute name="buttonGroup"> + <string notr="true">buttonGroupWorkOn</string> + </attribute> </widget> </item> - <item row="7" column="0"> - <widget class="QLabel" name="label_4"> - <property name="minimumSize"> - <size> - <width>0</width> - <height>30</height> - </size> - </property> + <item row="3" column="0"> + <widget class="QPushButton" name="pushStartMaskSync"> <property name="text"> - <string>Prompt - Remove</string> - </property> - </widget> - </item> - <item row="8" column="0" colspan="2"> - <widget class="qSlicerSimpleMarkupsWidget" name="markupsRemove"> - <property name="enterPlaceModeOnNodeChange"> - <bool>false</bool> - </property> - <property name="defaultNodeColor"> - <color> - <red>255</red> - <green>0</green> - <blue>0</blue> - </color> + <string>Start Mask Sync</string> </property> </widget> </item> @@ -391,30 +526,6 @@ </widget> </widget> </item> - <item row="1" column="0"> - <widget class="QTabWidget" name="tabWidget_3"> - <property name="currentIndex"> - <number>0</number> - </property> - <widget class="QWidget" name="tab_4"> - <attribute name="title"> - <string>SAM and Variants</string> - </attribute> - <layout class="QGridLayout" name="gridLayout_5"> - <item row="1" column="0"> - <widget class="QComboBox" name="comboModel"/> - </item> - <item row="0" column="0"> - <widget class="QLabel" name="label_11"> - <property name="text"> - <string>Model Selection</string> - </property> - </widget> - </item> - </layout> - </widget> - </widget> - </item> </layout> </widget> <customwidgets> @@ -512,9 +623,41 @@ </hint> </hints> </connection> + <connection> + <sender>SammBase</sender> + <signal>mrmlSceneChanged(vtkMRMLScene*)</signal> + <receiver>markups2DBox</receiver> + <slot>setMRMLScene(vtkMRMLScene*)</slot> + <hints> + <hint type="sourcelabel"> + <x>207</x> + <y>527</y> + </hint> + <hint type="destinationlabel"> + <x>207</x> + <y>679</y> + </hint> + </hints> + </connection> + <connection> + <sender>SammBase</sender> + <signal>mrmlSceneChanged(vtkMRMLScene*)</signal> + <receiver>markups3DBox</receiver> + <slot>setMRMLScene(vtkMRMLScene*)</slot> + <hints> + <hint type="sourcelabel"> + <x>207</x> + <y>527</y> + </hint> + <hint type="destinationlabel"> + <x>207</x> + <y>700</y> + </hint> + </hints> + </connection> </connections> <buttongroups> - <buttongroup name="buttonGroupWorkOn"/> <buttongroup name="buttonGroupDataType"/> + <buttongroup name="buttonGroupWorkOn"/> </buttongroups> </ui> diff --git a/samm/SammBase/SammBaseLib/LogicSamm.py b/samm/SammBase/SammBaseLib/LogicSamm.py index b60c1df..b87287b 100644 --- a/samm/SammBase/SammBaseLib/LogicSamm.py +++ b/samm/SammBase/SammBaseLib/LogicSamm.py @@ -205,6 +205,7 @@ def processInitPromptSync(self): # Init self._prompt_add = self._parameterNode.GetNodeReference("sammPromptAdd") self._prompt_remove = self._parameterNode.GetNodeReference("sammPromptRemove") + self._prompt_2dbox = self._parameterNode.GetNodeReference("sammPrompt2DBox") if self._parameterNode.GetParameter("sammViewOptions") == "RED": self._slider = \ @@ -250,7 +251,9 @@ def utilGetCurrentSliceIndex(self): imshape = (self._parameterNode._volMetaData[0][0], self._parameterNode._volMetaData[0][1]) return curslc, view, imshape - def utilGetPositionOnSlicer(self, temp, view): + def utilGetPositionOnSlice(self, ras, view): + temp = self._volumeRasToIjk.MultiplyPoint([ras[0],ras[1],ras[2],1]) + temp = np.array([temp[0], temp[1], temp[2]])[self._parameterNode.RGYNpArrOrder] if view == "RED": return [round(temp[1]),round(temp[2])] if view == "GREEN": @@ -268,7 +271,7 @@ def processStartPromptSync(self): if self._flag_prompt_sync: curslc, view, imshape = self.utilGetCurrentSliceIndex() - curslc = int(curslc) + curslc = round(curslc) mask = None @@ -278,21 +281,34 @@ def processStartPromptSync(self): for i in range(numControlPoints): ras = vtk.vtkVector3d(0.0, 0.0, 0.0) self._prompt_add.GetNthControlPointPosition(i,ras) - temp = self._volumeRasToIjk.MultiplyPoint([ras[0],ras[1],ras[2],1]) - temp = np.array([temp[0], temp[1], temp[2]])[self._parameterNode.RGYNpArrOrder] - prompt_add_point.append(self.utilGetPositionOnSlicer(temp, view)) + prompt_add_point.append(self.utilGetPositionOnSlice(ras, view)) numControlPoints = self._prompt_remove.GetNumberOfControlPoints() for i in range(numControlPoints): ras = vtk.vtkVector3d(0.0, 0.0, 0.0) self._prompt_remove.GetNthControlPointPosition(i,ras) - temp = self._volumeRasToIjk.MultiplyPoint([ras[0],ras[1],ras[2],1]) - temp = np.array([temp[0], temp[1], temp[2]])[self._parameterNode.RGYNpArrOrder] - prompt_remove_point.append(self.utilGetPositionOnSlicer(temp, view)) + prompt_remove_point.append(self.utilGetPositionOnSlice(ras, view)) + + plane = self._parameterNode.GetNodeReference("sammPrompt2DBox") + + if plane: + points = vtk.vtkPoints() + plane.GetPlaneCornerPoints(points) + ras = [points.GetPoint(0)[0],points.GetPoint(0)[1],points.GetPoint(0)[2]] + bboxmin = self.utilGetPositionOnSlice(ras, view) + ras = [points.GetPoint(2)[0],points.GetPoint(2)[1],points.GetPoint(2)[2]] + bboxmax = self.utilGetPositionOnSlice(ras, view) + if bboxmin[0] > bboxmax[0]: + bboxmin_ = [bboxmin[0], bboxmin[1]] + bboxmin = [bboxmax[0], bboxmax[1]] + bboxmax = [bboxmin_[0], bboxmin_[1]] + else: + bboxmin, bboxmax = [-404,-404], [-404,-404] mask = self._connections.pushRequest(SammMsgType.INFERENCE, { "n" : curslc, "view" : self._parameterNode.GetParameter("sammViewOptions")[0], + "bbox2D" : np.array([bboxmin[0], bboxmin[1], bboxmax[0], bboxmax[1]]), "positivePrompts" : np.array(prompt_add_point), "negativePrompts" : np.array(prompt_remove_point) }) diff --git a/samm/SammBase/SammBaseLib/UtilConnections.py b/samm/SammBase/SammBaseLib/UtilConnections.py index fcd8f14..a8c31ce 100644 --- a/samm/SammBase/SammBaseLib/UtilConnections.py +++ b/samm/SammBase/SammBaseLib/UtilConnections.py @@ -34,7 +34,7 @@ def pushRequest(self, requireType, MSG): msgByte = SammMsgSolverMapper[requireType](MSG).getEncodedData() sock = self.context.socket(zmq.REQ) # if no receiption, try extending the wait time. The first setup time takes longer - sock.setsockopt(zmq.RCVTIMEO, 20000) + sock.setsockopt(zmq.RCVTIMEO, 10000) sock.connect("tcp://%s:%s" % (self.ip, self.portControl)) sock.send_multipart([commandByte, msgByte]) feedback = sock.recv() diff --git a/samm/SammBase/SammBaseLib/UtilMsgFactory.py b/samm/SammBase/SammBaseLib/UtilMsgFactory.py index 504b0e9..30234af 100644 --- a/samm/SammBase/SammBaseLib/UtilMsgFactory.py +++ b/samm/SammBase/SammBaseLib/UtilMsgFactory.py @@ -109,6 +109,8 @@ def getDecodedData(msgbyte): ''' n : int +view : char +bbox : int 32, 4 positivePrompts, int 32, n * 2 negativePrompts, int 32 ''' @@ -118,12 +120,14 @@ def getEncodedData(self): n = self.msg["n"] view = SammViewMapper[self.msg["view"]] + bbox2D = self.msg["bbox2D"] positivePoints = self.msg["positivePrompts"] negativePoints = self.msg["negativePrompts"] msg = b'' msg += np.array([n], dtype='int32').tobytes() msg += np.array([view], dtype='int32').tobytes() + msg += bbox2D.astype("int32").tobytes() if positivePoints is not None and positivePoints.shape[0] > 0: msg += np.array([positivePoints.shape[0]], dtype='int32').tobytes() @@ -150,6 +154,9 @@ def getDecodedData(msgbyte): msg["view"] = SammViewMapper["DICT"][np.frombuffer(msgbyte[pt:pt+4], dtype="int32").reshape([1])[0]] pt += 4 + msg["bbox2D"] = np.frombuffer(msgbyte[pt:pt+4*4], dtype="int32").reshape([1,4])[0] + pt += 4*4 + positivePromptNum = np.frombuffer(msgbyte[pt:pt+4], dtype="int32").reshape([1])[0] pt += 4 @@ -164,7 +171,7 @@ def getDecodedData(msgbyte): pt += 4 if negativePromptNum > 0: - negativePrompt = np.frombuffer(msgbyte[pt:pt+4*2*negativePromptNum], dtype="int32").reshape([positivePromptNum, 2]) + negativePrompt = np.frombuffer(msgbyte[pt:pt+4*2*negativePromptNum], dtype="int32").reshape([negativePromptNum, 2]) pt += 4 * 2 * negativePromptNum msg["negativePrompts"] = negativePrompt else: diff --git a/samm/SammBase/SammBaseLib/WidgetSammBase.py b/samm/SammBase/SammBaseLib/WidgetSammBase.py index cb72f86..65db2e2 100644 --- a/samm/SammBase/SammBaseLib/WidgetSammBase.py +++ b/samm/SammBase/SammBaseLib/WidgetSammBase.py @@ -67,6 +67,8 @@ def setup(self): self.ui.markupsAdd.connect("markupsNodeChanged()", self.updateParameterNodeFromGUI) self.ui.markupsRemove.connect("markupsNodeChanged()", self.updateParameterNodeFromGUI) + self.ui.markups2DBox.connect("markupsNodeChanged()", self.updateParameterNodeFromGUI) + self.ui.pushMarkups2DBox.connect("clicked(bool)", self.onPushMarkups2DBox) self.ui.markupsAdd.markupsPlaceWidget().setPlaceModePersistency(True) self.ui.markupsRemove.markupsPlaceWidget().setPlaceModePersistency(True) @@ -92,6 +94,7 @@ def updateGUIFromParameterNode(self, caller=None, event=None): self.ui.comboVolumeNode.setCurrentNode(self._parameterNode.GetNodeReference("sammInputVolume")) self.ui.markupsAdd.setCurrentNode(self._parameterNode.GetNodeReference("sammPromptAdd")) self.ui.markupsRemove.setCurrentNode(self._parameterNode.GetNodeReference("sammPromptRemove")) + self.ui.markups2DBox.setCurrentNode(self._parameterNode.GetNodeReference("sammPrompt2DBox")) self.ui.comboSegmentationNode.setCurrentNode(self._parameterNode.GetNodeReference("sammSegmentation")) @@ -128,6 +131,8 @@ def updateParameterNodeFromGUI(self, caller=None, event=None): self._parameterNode.SetNodeReferenceID("sammInputVolume", self.ui.comboVolumeNode.currentNodeID) self._parameterNode.SetNodeReferenceID("sammPromptAdd", self.ui.markupsAdd.currentNode().GetID()) self._parameterNode.SetNodeReferenceID("sammPromptRemove", self.ui.markupsRemove.currentNode().GetID()) + if self.ui.markups2DBox.currentNode(): + self._parameterNode.SetNodeReferenceID("sammPrompt2DBox", self.ui.markups2DBox.currentNode().GetID()) self._parameterNode._workspace = os.path.dirname(os.path.abspath(self.ui.pathWorkSpace.currentPath.strip())) self._parameterNode.SetNodeReferenceID("sammSegmentation", self.ui.comboSegmentationNode.currentNodeID) self._parameterNode.GetNodeReference("sammSegmentation").SetReferenceImageGeometryParameterFromVolumeNode( @@ -203,4 +208,16 @@ def onPushModuleSeg(self): slicer.util.selectModule("Segmentations") def onPushModuleSegEditor(self): - slicer.util.selectModule("SegmentEditor") \ No newline at end of file + slicer.util.selectModule("SegmentEditor") + + def onPushMarkups2DBox(self): + planeNode = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLMarkupsPlaneNode').GetID() + selectionNode = slicer.mrmlScene.GetNodeByID("vtkMRMLSelectionNodeSingleton") + selectionNode.SetReferenceActivePlaceNodeID(planeNode) + interactionNode = slicer.mrmlScene.GetNodeByID("vtkMRMLInteractionNodeSingleton") + placeModePersistence = 0 + interactionNode.SetPlaceModePersistence(placeModePersistence) + # mode 1 is Place, can also be accessed via slicer.vtkMRMLInteractionNode().Place + interactionNode.SetCurrentInteractionMode(1) + + self._parameterNode.SetNodeReferenceID("sammPrompt2DBox", planeNode) \ No newline at end of file