From b4b10f0a9d0d3e9946af914aef655e2b60a29c20 Mon Sep 17 00:00:00 2001
From: Yihao Liu <yliu333@jhu.edu>
Date: Tue, 8 Aug 2023 19:25:17 -0400
Subject: [PATCH] add bbox support

---
 Dev.ipynb                                    |  80 ++++-
 README.md                                    |   2 -
 samm-python-terminal/sam_server.py           |   8 +-
 samm-python-terminal/utl_sam_msg.py          |   9 +-
 samm-python-terminal/utl_sam_server.py       |  43 ++-
 samm/SammBase/Resources/UI/SammBase.ui       | 357 +++++++++++++------
 samm/SammBase/SammBaseLib/LogicSamm.py       |  32 +-
 samm/SammBase/SammBaseLib/UtilConnections.py |   2 +-
 samm/SammBase/SammBaseLib/UtilMsgFactory.py  |   9 +-
 samm/SammBase/SammBaseLib/WidgetSammBase.py  |  19 +-
 10 files changed, 419 insertions(+), 142 deletions(-)

diff --git a/Dev.ipynb b/Dev.ipynb
index 751f5f6..c5da79c 100644
--- a/Dev.ipynb
+++ b/Dev.ipynb
@@ -26,6 +26,46 @@
     "slicer.util.selectModule(\"SammBase\")"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "642430aa-9f67-4af0-940f-540047fe92a3",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "vtkPoints (0x62a4f80)\n",
+      "  Debug: Off\n",
+      "  Modified Time: 1776089\n",
+      "  Reference Count: 1\n",
+      "  Registered Events: (none)\n",
+      "  Data: 0x102c7160\n",
+      "  Data Array Name: Points\n",
+      "  Number Of Points: 4\n",
+      "  Bounds: \n",
+      "    Xmin,Xmax: (-32.0076, 95.4545)\n",
+      "    Ymin,Ymax: (-34.8485, 63.2576)\n",
+      "    Zmin,Zmax: (0, 0)\n",
+      "\n",
+      "\n",
+      "0.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "import vtk\n",
+    "\n",
+    "logic = SammBaseLogic()\n",
+    "logic._parameterNode.SetNodeReferenceID(\"sammPrompt2DBox\", \"vtkMRMLMarkupsPlaneNode1\")\n",
+    "plane = logic._parameterNode.GetNodeReference(\"sammPrompt2DBox\")\n",
+    "points = vtk.vtkPoints() \n",
+    "plane.GetPlaneCornerPoints(points)\n",
+    "print(points)\n",
+    "print(points.GetPoint(0)[2])\n"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 3,
@@ -643,17 +683,49 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 6,
    "id": "92be4fc6-f838-47d0-ae4e-47c2598eb353",
    "metadata": {},
-   "outputs": [],
-   "source": []
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([1, 2, 3, 4])"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "np.array([1,2,3,4]).reshape([1,4])[0]"
+   ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 20,
    "id": "4e83cab3-759b-4bf4-a2ff-2f0134331451",
    "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "False\n"
+     ]
+    }
+   ],
+   "source": [
+    "print([]==None)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "512d3635-580d-443a-b011-f51e4bc4f952",
+   "metadata": {},
    "outputs": [],
    "source": []
   }
diff --git a/README.md b/README.md
index 2ecd12e..8a2561c 100644
--- a/README.md
+++ b/README.md
@@ -4,8 +4,6 @@
 
 [Laboratory of Biomechanical and Image Guided Surgical Systems](https://bigss.lcsr.jhu.edu/), [Johns Hopkins University](https://www.jhu.edu/)
 
-[Yihao Liu](https://yihao.one/), [Jeremy Zhang](https://jeremyzz830.github.io/), Zhangcong She
-
 ## Known issues
 #13 Resolved with a not-so-elegant way. Tested work on MRHead, DZ-MR, MRBrainTumor2 and CTChest data from Slicer. Please report a bug if your data is not working. Note for the first few seconds when you start "Mask Sync", the server is not so stable, wait a few seconds then slide it up and down, the mask then will be updated. Note now you can only work on the RED view. Will update later to support all 3 views.
 
diff --git a/samm-python-terminal/sam_server.py b/samm-python-terminal/sam_server.py
index 2ff4df4..1ee64a7 100644
--- a/samm-python-terminal/sam_server.py
+++ b/samm-python-terminal/sam_server.py
@@ -68,10 +68,10 @@ def looping(self):
                 self.cleanup()
             except zmq.error.Again:
                 continue
-            except Exception as e:
-                logging.error(traceback.format_exc())
-                time.sleep(self.interv)
-                continue
+            # except Exception as e:
+            #     logging.error(traceback.format_exc())
+            #     time.sleep(self.interv)
+            #     continue
 
             time.sleep(self.interv)
 
diff --git a/samm-python-terminal/utl_sam_msg.py b/samm-python-terminal/utl_sam_msg.py
index 504b0e9..30234af 100644
--- a/samm-python-terminal/utl_sam_msg.py
+++ b/samm-python-terminal/utl_sam_msg.py
@@ -109,6 +109,8 @@ def getDecodedData(msgbyte):
 
 '''
 n : int
+view : char
+bbox : int 32, 4
 positivePrompts, int 32, n * 2
 negativePrompts, int 32
 '''
@@ -118,12 +120,14 @@ def getEncodedData(self):
         
         n = self.msg["n"]
         view = SammViewMapper[self.msg["view"]]
+        bbox2D = self.msg["bbox2D"]
         positivePoints = self.msg["positivePrompts"]
         negativePoints = self.msg["negativePrompts"]
         
         msg = b''
         msg += np.array([n], dtype='int32').tobytes()
         msg += np.array([view], dtype='int32').tobytes()
+        msg += bbox2D.astype("int32").tobytes()
 
         if positivePoints is not None and positivePoints.shape[0] > 0:
             msg += np.array([positivePoints.shape[0]], dtype='int32').tobytes()
@@ -150,6 +154,9 @@ def getDecodedData(msgbyte):
         msg["view"] = SammViewMapper["DICT"][np.frombuffer(msgbyte[pt:pt+4], dtype="int32").reshape([1])[0]]
         pt += 4
 
+        msg["bbox2D"] = np.frombuffer(msgbyte[pt:pt+4*4], dtype="int32").reshape([1,4])[0]
+        pt += 4*4
+
         positivePromptNum = np.frombuffer(msgbyte[pt:pt+4], dtype="int32").reshape([1])[0]
         pt += 4
 
@@ -164,7 +171,7 @@ def getDecodedData(msgbyte):
         pt += 4
 
         if negativePromptNum > 0:
-            negativePrompt = np.frombuffer(msgbyte[pt:pt+4*2*negativePromptNum], dtype="int32").reshape([positivePromptNum, 2])
+            negativePrompt = np.frombuffer(msgbyte[pt:pt+4*2*negativePromptNum], dtype="int32").reshape([negativePromptNum, 2])
             pt += 4 * 2 * negativePromptNum
             msg["negativePrompts"] = negativePrompt
         else:
diff --git a/samm-python-terminal/utl_sam_server.py b/samm-python-terminal/utl_sam_server.py
index ae72200..d8ee7ef 100644
--- a/samm-python-terminal/utl_sam_server.py
+++ b/samm-python-terminal/utl_sam_server.py
@@ -158,10 +158,21 @@ def CalculateEmbeddings(msg):
 
     print("[SAMM INFO] Embeddings Cached.")
 
+def helperPredict(dataNode, msg, points, labels, bbox2d):
+    dataNode.samPredictor[msg["view"]].features = dataNode.features[msg["view"]][msg["n"]].to("cuda")
+    seg, _, _ = dataNode.samPredictor[msg["view"]].predict(
+        point_coords = points,
+        point_labels = labels,
+        box = bbox2d,
+        multimask_output = False,)
+    seg = seg[0]
+    return seg
+
 def sammProcessingCallBack_INFERENCE(msg):
     dataNode = SammParameterNode()
     positivePoints = msg["positivePrompts"]
     negativePoints = msg["negativePrompts"]
+    bbox2d = msg["bbox2D"]
 
     points = []
     labels = []
@@ -176,20 +187,26 @@ def sammProcessingCallBack_INFERENCE(msg):
             labels.append(0)
 
     seg = None
-    if len(points) > 0:
-        if msg["view"] == "R":
-            tempsize = [dataNode.imageSize[1], dataNode.imageSize[2]]
-        if msg["view"] == "G":
-            tempsize = [dataNode.imageSize[0], dataNode.imageSize[2]]
-        if msg["view"] == "Y":
-            tempsize = [dataNode.imageSize[0], dataNode.imageSize[1]]
+    if len(points) > 0 and (bbox2d[0]!=-404):
+
+        points = np.array(points)
+        point_labels = np.array(labels)
+        bbox2d = np.array(bbox2d)
+        seg = helperPredict(dataNode, msg, points, point_labels, bbox2d)
+    
+    elif len(points) == 0 and (bbox2d[0]!=-404):
+        
+        points = None
+        point_labels = None
+        bbox2d = np.array(bbox2d)
+        seg = helperPredict(dataNode, msg, points, point_labels, bbox2d)
+
+    elif len(points) > 0 and (bbox2d[0]==-404):
         
-        dataNode.samPredictor[msg["view"]].features = dataNode.features[msg["view"]][msg["n"]].to("cuda")
-        seg, _, _ = dataNode.samPredictor[msg["view"]].predict(
-            point_coords = np.array(points),
-            point_labels = np.array(labels),
-            multimask_output = False,)
-        seg = seg[0]
+        points = np.array(points)
+        point_labels = np.array(labels)
+        bbox2d = None
+        seg = helperPredict(dataNode, msg, points, point_labels, bbox2d)
         
     else:
         if msg["view"] == "R":
diff --git a/samm/SammBase/Resources/UI/SammBase.ui b/samm/SammBase/Resources/UI/SammBase.ui
index 9fc52f4..81214c4 100644
--- a/samm/SammBase/Resources/UI/SammBase.ui
+++ b/samm/SammBase/Resources/UI/SammBase.ui
@@ -7,10 +7,34 @@
     <x>0</x>
     <y>0</y>
     <width>416</width>
-    <height>854</height>
+    <height>1055</height>
    </rect>
   </property>
   <layout class="QGridLayout" name="gridLayout">
+   <item row="1" column="0">
+    <widget class="QTabWidget" name="tabWidget_3">
+     <property name="currentIndex">
+      <number>0</number>
+     </property>
+     <widget class="QWidget" name="tab_4">
+      <attribute name="title">
+       <string>SAM and Variants</string>
+      </attribute>
+      <layout class="QGridLayout" name="gridLayout_5">
+       <item row="1" column="0">
+        <widget class="QComboBox" name="comboModel"/>
+       </item>
+       <item row="0" column="0">
+        <widget class="QLabel" name="label_11">
+         <property name="text">
+          <string>Model Selection</string>
+         </property>
+        </widget>
+       </item>
+      </layout>
+     </widget>
+    </widget>
+   </item>
    <item row="5" column="0">
     <spacer name="verticalSpacer">
      <property name="orientation">
@@ -138,17 +162,190 @@
          </attribute>
         </widget>
        </item>
-       <item row="1" column="1">
-        <widget class="QRadioButton" name="radioWorkOnGreen">
-         <property name="enabled">
-          <bool>true</bool>
+       <item row="4" column="1">
+        <widget class="QPushButton" name="pushUnfreezeSlice">
+         <property name="text">
+          <string>Unfreeze Slice</string>
          </property>
+        </widget>
+       </item>
+       <item row="5" column="0" colspan="2">
+        <widget class="QTabWidget" name="tabWidget_4">
+         <property name="currentIndex">
+          <number>0</number>
+         </property>
+         <widget class="QWidget" name="tab_5">
+          <attribute name="title">
+           <string>Points</string>
+          </attribute>
+          <layout class="QGridLayout" name="gridLayout_6">
+           <item row="0" column="0">
+            <widget class="QLabel" name="label_3">
+             <property name="minimumSize">
+              <size>
+               <width>0</width>
+               <height>30</height>
+              </size>
+             </property>
+             <property name="text">
+              <string>Add</string>
+             </property>
+            </widget>
+           </item>
+           <item row="1" column="0">
+            <widget class="qSlicerSimpleMarkupsWidget" name="markupsAdd">
+             <property name="enterPlaceModeOnNodeChange">
+              <bool>false</bool>
+             </property>
+             <property name="nodeColor">
+              <color>
+               <red>0</red>
+               <green>255</green>
+               <blue>0</blue>
+              </color>
+             </property>
+             <property name="defaultNodeColor">
+              <color>
+               <red>0</red>
+               <green>255</green>
+               <blue>0</blue>
+              </color>
+             </property>
+            </widget>
+           </item>
+           <item row="2" column="0">
+            <widget class="QLabel" name="label_4">
+             <property name="minimumSize">
+              <size>
+               <width>0</width>
+               <height>30</height>
+              </size>
+             </property>
+             <property name="text">
+              <string>Remove</string>
+             </property>
+            </widget>
+           </item>
+           <item row="3" column="0">
+            <widget class="qSlicerSimpleMarkupsWidget" name="markupsRemove">
+             <property name="enterPlaceModeOnNodeChange">
+              <bool>false</bool>
+             </property>
+             <property name="defaultNodeColor">
+              <color>
+               <red>255</red>
+               <green>0</green>
+               <blue>0</blue>
+              </color>
+             </property>
+            </widget>
+           </item>
+          </layout>
+         </widget>
+         <widget class="QWidget" name="tab_6">
+          <attribute name="title">
+           <string>2D Box</string>
+          </attribute>
+          <layout class="QGridLayout" name="gridLayout_7">
+           <item row="0" column="0">
+            <widget class="QLabel" name="label_12">
+             <property name="text">
+              <string>Bounding Box</string>
+             </property>
+            </widget>
+           </item>
+           <item row="2" column="0">
+            <spacer name="verticalSpacer_3">
+             <property name="orientation">
+              <enum>Qt::Vertical</enum>
+             </property>
+             <property name="sizeHint" stdset="0">
+              <size>
+               <width>20</width>
+               <height>40</height>
+              </size>
+             </property>
+            </spacer>
+           </item>
+           <item row="1" column="0">
+            <widget class="qMRMLNodeComboBox" name="markups2DBox">
+             <property name="enabled">
+              <bool>true</bool>
+             </property>
+             <property name="nodeTypes">
+              <stringlist notr="true">
+               <string>vtkMRMLMarkupsPlaneNode</string>
+              </stringlist>
+             </property>
+             <property name="hideChildNodeTypes">
+              <stringlist notr="true"/>
+             </property>
+             <property name="interactionNodeSingletonTag">
+              <string notr="true"/>
+             </property>
+            </widget>
+           </item>
+           <item row="1" column="1">
+            <widget class="QPushButton" name="pushMarkups2DBox">
+             <property name="text">
+              <string>Add a BBox</string>
+             </property>
+            </widget>
+           </item>
+          </layout>
+         </widget>
+         <widget class="QWidget" name="tab_7">
+          <attribute name="title">
+           <string>3D Box</string>
+          </attribute>
+          <layout class="QGridLayout" name="gridLayout_8">
+           <item row="0" column="0">
+            <widget class="QLabel" name="label_13">
+             <property name="text">
+              <string>Bounding Box</string>
+             </property>
+            </widget>
+           </item>
+           <item row="2" column="0">
+            <spacer name="verticalSpacer_4">
+             <property name="orientation">
+              <enum>Qt::Vertical</enum>
+             </property>
+             <property name="sizeHint" stdset="0">
+              <size>
+               <width>20</width>
+               <height>40</height>
+              </size>
+             </property>
+            </spacer>
+           </item>
+           <item row="1" column="0">
+            <widget class="qMRMLNodeComboBox" name="markups3DBox">
+             <property name="enabled">
+              <bool>false</bool>
+             </property>
+             <property name="nodeTypes">
+              <stringlist notr="true">
+               <string>vtkMRMLMarkupsROINode</string>
+              </stringlist>
+             </property>
+             <property name="hideChildNodeTypes">
+              <stringlist notr="true"/>
+             </property>
+             <property name="interactionNodeSingletonTag">
+              <string notr="true"/>
+             </property>
+            </widget>
+           </item>
+          </layout>
+         </widget>
+        </widget>
+       </item>
+       <item row="3" column="1">
+        <widget class="QPushButton" name="pushStopMaskSync">
          <property name="text">
-          <string>Green</string>
+          <string>Stop Mask Sync</string>
          </property>
-         <attribute name="buttonGroup">
-          <string notr="true">buttonGroupWorkOn</string>
-         </attribute>
         </widget>
        </item>
        <item row="2" column="0">
@@ -164,20 +361,6 @@
          </attribute>
         </widget>
        </item>
-       <item row="3" column="0">
-        <widget class="QPushButton" name="pushStartMaskSync">
-         <property name="text">
-          <string>Start Mask Sync</string>
-         </property>
-        </widget>
-       </item>
-       <item row="3" column="1">
-        <widget class="QPushButton" name="pushStopMaskSync">
-         <property name="text">
-          <string>Stop Mask Sync</string>
-         </property>
-        </widget>
-       </item>
        <item row="4" column="0">
         <widget class="QPushButton" name="pushFreezeSlice">
          <property name="text">
@@ -185,71 +368,23 @@
          </property>
         </widget>
        </item>
-       <item row="4" column="1">
-        <widget class="QPushButton" name="pushUnfreezeSlice">
-         <property name="text">
-          <string>Unfreeze Slice</string>
-         </property>
-        </widget>
-       </item>
-       <item row="5" column="0">
-        <widget class="QLabel" name="label_3">
-         <property name="minimumSize">
-          <size>
-           <width>0</width>
-           <height>30</height>
-          </size>
+       <item row="1" column="1">
+        <widget class="QRadioButton" name="radioWorkOnGreen">
+         <property name="enabled">
+          <bool>true</bool>
          </property>
          <property name="text">
-          <string>Prompt - Add</string>
-         </property>
-        </widget>
-       </item>
-       <item row="6" column="0" colspan="2">
-        <widget class="qSlicerSimpleMarkupsWidget" name="markupsAdd">
-         <property name="enterPlaceModeOnNodeChange">
-          <bool>false</bool>
-         </property>
-         <property name="nodeColor">
-          <color>
-           <red>0</red>
-           <green>255</green>
-           <blue>0</blue>
-          </color>
-         </property>
-         <property name="defaultNodeColor">
-          <color>
-           <red>0</red>
-           <green>255</green>
-           <blue>0</blue>
-          </color>
+          <string>Green</string>
          </property>
+         <attribute name="buttonGroup">
+          <string notr="true">buttonGroupWorkOn</string>
+         </attribute>
         </widget>
        </item>
-       <item row="7" column="0">
-        <widget class="QLabel" name="label_4">
-         <property name="minimumSize">
-          <size>
-           <width>0</width>
-           <height>30</height>
-          </size>
-         </property>
+       <item row="3" column="0">
+        <widget class="QPushButton" name="pushStartMaskSync">
          <property name="text">
-          <string>Prompt - Remove</string>
-         </property>
-        </widget>
-       </item>
-       <item row="8" column="0" colspan="2">
-        <widget class="qSlicerSimpleMarkupsWidget" name="markupsRemove">
-         <property name="enterPlaceModeOnNodeChange">
-          <bool>false</bool>
-         </property>
-         <property name="defaultNodeColor">
-          <color>
-           <red>255</red>
-           <green>0</green>
-           <blue>0</blue>
-          </color>
+          <string>Start Mask Sync</string>
          </property>
         </widget>
        </item>
@@ -391,30 +526,6 @@
      </widget>
     </widget>
    </item>
-   <item row="1" column="0">
-    <widget class="QTabWidget" name="tabWidget_3">
-     <property name="currentIndex">
-      <number>0</number>
-     </property>
-     <widget class="QWidget" name="tab_4">
-      <attribute name="title">
-       <string>SAM and Variants</string>
-      </attribute>
-      <layout class="QGridLayout" name="gridLayout_5">
-       <item row="1" column="0">
-        <widget class="QComboBox" name="comboModel"/>
-       </item>
-       <item row="0" column="0">
-        <widget class="QLabel" name="label_11">
-         <property name="text">
-          <string>Model Selection</string>
-         </property>
-        </widget>
-       </item>
-      </layout>
-     </widget>
-    </widget>
-   </item>
   </layout>
  </widget>
  <customwidgets>
@@ -512,9 +623,41 @@
     </hint>
    </hints>
   </connection>
+  <connection>
+   <sender>SammBase</sender>
+   <signal>mrmlSceneChanged(vtkMRMLScene*)</signal>
+   <receiver>markups2DBox</receiver>
+   <slot>setMRMLScene(vtkMRMLScene*)</slot>
+   <hints>
+    <hint type="sourcelabel">
+     <x>207</x>
+     <y>527</y>
+    </hint>
+    <hint type="destinationlabel">
+     <x>207</x>
+     <y>679</y>
+    </hint>
+   </hints>
+  </connection>
+  <connection>
+   <sender>SammBase</sender>
+   <signal>mrmlSceneChanged(vtkMRMLScene*)</signal>
+   <receiver>markups3DBox</receiver>
+   <slot>setMRMLScene(vtkMRMLScene*)</slot>
+   <hints>
+    <hint type="sourcelabel">
+     <x>207</x>
+     <y>527</y>
+    </hint>
+    <hint type="destinationlabel">
+     <x>207</x>
+     <y>700</y>
+    </hint>
+   </hints>
+  </connection>
  </connections>
  <buttongroups>
-  <buttongroup name="buttonGroupWorkOn"/>
   <buttongroup name="buttonGroupDataType"/>
+  <buttongroup name="buttonGroupWorkOn"/>
  </buttongroups>
 </ui>
diff --git a/samm/SammBase/SammBaseLib/LogicSamm.py b/samm/SammBase/SammBaseLib/LogicSamm.py
index b60c1df..b87287b 100644
--- a/samm/SammBase/SammBaseLib/LogicSamm.py
+++ b/samm/SammBase/SammBaseLib/LogicSamm.py
@@ -205,6 +205,7 @@ def processInitPromptSync(self):
         # Init
         self._prompt_add    = self._parameterNode.GetNodeReference("sammPromptAdd")
         self._prompt_remove = self._parameterNode.GetNodeReference("sammPromptRemove")
+        self._prompt_2dbox = self._parameterNode.GetNodeReference("sammPrompt2DBox")
 
         if self._parameterNode.GetParameter("sammViewOptions") == "RED":
             self._slider = \
@@ -250,7 +251,9 @@ def utilGetCurrentSliceIndex(self):
             imshape = (self._parameterNode._volMetaData[0][0], self._parameterNode._volMetaData[0][1])
         return curslc, view, imshape
     
-    def utilGetPositionOnSlicer(self, temp, view):
+    def utilGetPositionOnSlice(self, ras, view):
+        temp = self._volumeRasToIjk.MultiplyPoint([ras[0],ras[1],ras[2],1])
+        temp = np.array([temp[0], temp[1], temp[2]])[self._parameterNode.RGYNpArrOrder]
         if view == "RED":
             return [round(temp[1]),round(temp[2])]
         if view == "GREEN":
@@ -268,7 +271,7 @@ def processStartPromptSync(self):
         if self._flag_prompt_sync:
 
             curslc, view, imshape = self.utilGetCurrentSliceIndex()
-            curslc = int(curslc)
+            curslc = round(curslc)
 
             mask = None
 
@@ -278,21 +281,34 @@ def processStartPromptSync(self):
                 for i in range(numControlPoints):
                     ras = vtk.vtkVector3d(0.0, 0.0, 0.0)
                     self._prompt_add.GetNthControlPointPosition(i,ras)
-                    temp = self._volumeRasToIjk.MultiplyPoint([ras[0],ras[1],ras[2],1])
-                    temp = np.array([temp[0], temp[1], temp[2]])[self._parameterNode.RGYNpArrOrder]
-                    prompt_add_point.append(self.utilGetPositionOnSlicer(temp, view))
+                    prompt_add_point.append(self.utilGetPositionOnSlice(ras, view))
 
                 numControlPoints = self._prompt_remove.GetNumberOfControlPoints()
                 for i in range(numControlPoints):
                     ras = vtk.vtkVector3d(0.0, 0.0, 0.0)
                     self._prompt_remove.GetNthControlPointPosition(i,ras)
-                    temp = self._volumeRasToIjk.MultiplyPoint([ras[0],ras[1],ras[2],1])
-                    temp = np.array([temp[0], temp[1], temp[2]])[self._parameterNode.RGYNpArrOrder]
-                    prompt_remove_point.append(self.utilGetPositionOnSlicer(temp, view))
+                    prompt_remove_point.append(self.utilGetPositionOnSlice(ras, view))
+
+                plane = self._parameterNode.GetNodeReference("sammPrompt2DBox")
+
+                if plane:
+                    points = vtk.vtkPoints() 
+                    plane.GetPlaneCornerPoints(points)
+                    ras = [points.GetPoint(0)[0],points.GetPoint(0)[1],points.GetPoint(0)[2]]
+                    bboxmin = self.utilGetPositionOnSlice(ras, view)
+                    ras = [points.GetPoint(2)[0],points.GetPoint(2)[1],points.GetPoint(2)[2]]
+                    bboxmax = self.utilGetPositionOnSlice(ras, view)
+                    if bboxmin[0] > bboxmax[0]:
+                        bboxmin_ = [bboxmin[0], bboxmin[1]]
+                        bboxmin = [bboxmax[0], bboxmax[1]]
+                        bboxmax = [bboxmin_[0], bboxmin_[1]]
+                else:
+                    bboxmin, bboxmax = [-404,-404], [-404,-404]
 
                 mask = self._connections.pushRequest(SammMsgType.INFERENCE, {
                     "n" : curslc,
                     "view" : self._parameterNode.GetParameter("sammViewOptions")[0],
+                    "bbox2D" : np.array([bboxmin[0], bboxmin[1], bboxmax[0], bboxmax[1]]),
                     "positivePrompts" : np.array(prompt_add_point),
                     "negativePrompts" : np.array(prompt_remove_point)
                 })
diff --git a/samm/SammBase/SammBaseLib/UtilConnections.py b/samm/SammBase/SammBaseLib/UtilConnections.py
index fcd8f14..a8c31ce 100644
--- a/samm/SammBase/SammBaseLib/UtilConnections.py
+++ b/samm/SammBase/SammBaseLib/UtilConnections.py
@@ -34,7 +34,7 @@ def pushRequest(self, requireType, MSG):
         msgByte = SammMsgSolverMapper[requireType](MSG).getEncodedData()
         sock = self.context.socket(zmq.REQ)
         # if no receiption, try extending the wait time. The first setup time takes longer
-        sock.setsockopt(zmq.RCVTIMEO, 20000) 
+        sock.setsockopt(zmq.RCVTIMEO, 10000) 
         sock.connect("tcp://%s:%s" % (self.ip, self.portControl))
         sock.send_multipart([commandByte, msgByte])
         feedback = sock.recv()
diff --git a/samm/SammBase/SammBaseLib/UtilMsgFactory.py b/samm/SammBase/SammBaseLib/UtilMsgFactory.py
index 504b0e9..30234af 100644
--- a/samm/SammBase/SammBaseLib/UtilMsgFactory.py
+++ b/samm/SammBase/SammBaseLib/UtilMsgFactory.py
@@ -109,6 +109,8 @@ def getDecodedData(msgbyte):
 
 '''
 n : int
+view : char
+bbox : int 32, 4
 positivePrompts, int 32, n * 2
 negativePrompts, int 32
 '''
@@ -118,12 +120,14 @@ def getEncodedData(self):
         
         n = self.msg["n"]
         view = SammViewMapper[self.msg["view"]]
+        bbox2D = self.msg["bbox2D"]
         positivePoints = self.msg["positivePrompts"]
         negativePoints = self.msg["negativePrompts"]
         
         msg = b''
         msg += np.array([n], dtype='int32').tobytes()
         msg += np.array([view], dtype='int32').tobytes()
+        msg += bbox2D.astype("int32").tobytes()
 
         if positivePoints is not None and positivePoints.shape[0] > 0:
             msg += np.array([positivePoints.shape[0]], dtype='int32').tobytes()
@@ -150,6 +154,9 @@ def getDecodedData(msgbyte):
         msg["view"] = SammViewMapper["DICT"][np.frombuffer(msgbyte[pt:pt+4], dtype="int32").reshape([1])[0]]
         pt += 4
 
+        msg["bbox2D"] = np.frombuffer(msgbyte[pt:pt+4*4], dtype="int32").reshape([1,4])[0]
+        pt += 4*4
+
         positivePromptNum = np.frombuffer(msgbyte[pt:pt+4], dtype="int32").reshape([1])[0]
         pt += 4
 
@@ -164,7 +171,7 @@ def getDecodedData(msgbyte):
         pt += 4
 
         if negativePromptNum > 0:
-            negativePrompt = np.frombuffer(msgbyte[pt:pt+4*2*negativePromptNum], dtype="int32").reshape([positivePromptNum, 2])
+            negativePrompt = np.frombuffer(msgbyte[pt:pt+4*2*negativePromptNum], dtype="int32").reshape([negativePromptNum, 2])
             pt += 4 * 2 * negativePromptNum
             msg["negativePrompts"] = negativePrompt
         else:
diff --git a/samm/SammBase/SammBaseLib/WidgetSammBase.py b/samm/SammBase/SammBaseLib/WidgetSammBase.py
index cb72f86..65db2e2 100644
--- a/samm/SammBase/SammBaseLib/WidgetSammBase.py
+++ b/samm/SammBase/SammBaseLib/WidgetSammBase.py
@@ -67,6 +67,8 @@ def setup(self):
 
         self.ui.markupsAdd.connect("markupsNodeChanged()", self.updateParameterNodeFromGUI)
         self.ui.markupsRemove.connect("markupsNodeChanged()", self.updateParameterNodeFromGUI)
+        self.ui.markups2DBox.connect("markupsNodeChanged()", self.updateParameterNodeFromGUI)
+        self.ui.pushMarkups2DBox.connect("clicked(bool)", self.onPushMarkups2DBox)
         self.ui.markupsAdd.markupsPlaceWidget().setPlaceModePersistency(True)
         self.ui.markupsRemove.markupsPlaceWidget().setPlaceModePersistency(True)
 
@@ -92,6 +94,7 @@ def updateGUIFromParameterNode(self, caller=None, event=None):
         self.ui.comboVolumeNode.setCurrentNode(self._parameterNode.GetNodeReference("sammInputVolume"))
         self.ui.markupsAdd.setCurrentNode(self._parameterNode.GetNodeReference("sammPromptAdd"))
         self.ui.markupsRemove.setCurrentNode(self._parameterNode.GetNodeReference("sammPromptRemove"))
+        self.ui.markups2DBox.setCurrentNode(self._parameterNode.GetNodeReference("sammPrompt2DBox"))
 
         self.ui.comboSegmentationNode.setCurrentNode(self._parameterNode.GetNodeReference("sammSegmentation"))
 
@@ -128,6 +131,8 @@ def updateParameterNodeFromGUI(self, caller=None, event=None):
         self._parameterNode.SetNodeReferenceID("sammInputVolume", self.ui.comboVolumeNode.currentNodeID)
         self._parameterNode.SetNodeReferenceID("sammPromptAdd", self.ui.markupsAdd.currentNode().GetID())
         self._parameterNode.SetNodeReferenceID("sammPromptRemove", self.ui.markupsRemove.currentNode().GetID())
+        if self.ui.markups2DBox.currentNode():
+            self._parameterNode.SetNodeReferenceID("sammPrompt2DBox", self.ui.markups2DBox.currentNode().GetID())
         self._parameterNode._workspace = os.path.dirname(os.path.abspath(self.ui.pathWorkSpace.currentPath.strip()))
         self._parameterNode.SetNodeReferenceID("sammSegmentation", self.ui.comboSegmentationNode.currentNodeID)
         self._parameterNode.GetNodeReference("sammSegmentation").SetReferenceImageGeometryParameterFromVolumeNode(
@@ -203,4 +208,16 @@ def onPushModuleSeg(self):
         slicer.util.selectModule("Segmentations")
 
     def onPushModuleSegEditor(self):
-        slicer.util.selectModule("SegmentEditor")
\ No newline at end of file
+        slicer.util.selectModule("SegmentEditor")
+
+    def onPushMarkups2DBox(self):
+        planeNode = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLMarkupsPlaneNode').GetID()
+        selectionNode = slicer.mrmlScene.GetNodeByID("vtkMRMLSelectionNodeSingleton")
+        selectionNode.SetReferenceActivePlaceNodeID(planeNode)
+        interactionNode = slicer.mrmlScene.GetNodeByID("vtkMRMLInteractionNodeSingleton")
+        placeModePersistence = 0
+        interactionNode.SetPlaceModePersistence(placeModePersistence)
+        # mode 1 is Place, can also be accessed via slicer.vtkMRMLInteractionNode().Place
+        interactionNode.SetCurrentInteractionMode(1)
+
+        self._parameterNode.SetNodeReferenceID("sammPrompt2DBox", planeNode)
\ No newline at end of file