Skip to content

Commit

Permalink
add bbox support
Browse files Browse the repository at this point in the history
  • Loading branch information
bingogome committed Aug 8, 2023
1 parent 9e75332 commit b4b10f0
Show file tree
Hide file tree
Showing 10 changed files with 419 additions and 142 deletions.
80 changes: 76 additions & 4 deletions Dev.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,46 @@
"slicer.util.selectModule(\"SammBase\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "642430aa-9f67-4af0-940f-540047fe92a3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"vtkPoints (0x62a4f80)\n",
" Debug: Off\n",
" Modified Time: 1776089\n",
" Reference Count: 1\n",
" Registered Events: (none)\n",
" Data: 0x102c7160\n",
" Data Array Name: Points\n",
" Number Of Points: 4\n",
" Bounds: \n",
" Xmin,Xmax: (-32.0076, 95.4545)\n",
" Ymin,Ymax: (-34.8485, 63.2576)\n",
" Zmin,Zmax: (0, 0)\n",
"\n",
"\n",
"0.0\n"
]
}
],
"source": [
"import vtk\n",
"\n",
"logic = SammBaseLogic()\n",
"logic._parameterNode.SetNodeReferenceID(\"sammPrompt2DBox\", \"vtkMRMLMarkupsPlaneNode1\")\n",
"plane = logic._parameterNode.GetNodeReference(\"sammPrompt2DBox\")\n",
"points = vtk.vtkPoints() \n",
"plane.GetPlaneCornerPoints(points)\n",
"print(points)\n",
"print(points.GetPoint(0)[2])\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
Expand Down Expand Up @@ -643,17 +683,49 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "92be4fc6-f838-47d0-ae4e-47c2598eb353",
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"data": {
"text/plain": [
"array([1, 2, 3, 4])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"np.array([1,2,3,4]).reshape([1,4])[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 20,
"id": "4e83cab3-759b-4bf4-a2ff-2f0134331451",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"False\n"
]
}
],
"source": [
"print([]==None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "512d3635-580d-443a-b011-f51e4bc4f952",
"metadata": {},
"outputs": [],
"source": []
}
Expand Down
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

[Laboratory of Biomechanical and Image Guided Surgical Systems](https://bigss.lcsr.jhu.edu/), [Johns Hopkins University](https://www.jhu.edu/)

[Yihao Liu](https://yihao.one/), [Jeremy Zhang](https://jeremyzz830.github.io/), Zhangcong She

## Known issues
#13 Resolved with a not-so-elegant way. Tested work on MRHead, DZ-MR, MRBrainTumor2 and CTChest data from Slicer. Please report a bug if your data is not working. Note for the first few seconds when you start "Mask Sync", the server is not so stable, wait a few seconds then slide it up and down, the mask then will be updated. Note now you can only work on the RED view. Will update later to support all 3 views.

Expand Down
8 changes: 4 additions & 4 deletions samm-python-terminal/sam_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ def looping(self):
self.cleanup()
except zmq.error.Again:
continue
except Exception as e:
logging.error(traceback.format_exc())
time.sleep(self.interv)
continue
# except Exception as e:
# logging.error(traceback.format_exc())
# time.sleep(self.interv)
# continue

time.sleep(self.interv)

Expand Down
9 changes: 8 additions & 1 deletion samm-python-terminal/utl_sam_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def getDecodedData(msgbyte):

'''
n : int
view : char
bbox : int 32, 4
positivePrompts, int 32, n * 2
negativePrompts, int 32
'''
Expand All @@ -118,12 +120,14 @@ def getEncodedData(self):

n = self.msg["n"]
view = SammViewMapper[self.msg["view"]]
bbox2D = self.msg["bbox2D"]
positivePoints = self.msg["positivePrompts"]
negativePoints = self.msg["negativePrompts"]

msg = b''
msg += np.array([n], dtype='int32').tobytes()
msg += np.array([view], dtype='int32').tobytes()
msg += bbox2D.astype("int32").tobytes()

if positivePoints is not None and positivePoints.shape[0] > 0:
msg += np.array([positivePoints.shape[0]], dtype='int32').tobytes()
Expand All @@ -150,6 +154,9 @@ def getDecodedData(msgbyte):
msg["view"] = SammViewMapper["DICT"][np.frombuffer(msgbyte[pt:pt+4], dtype="int32").reshape([1])[0]]
pt += 4

msg["bbox2D"] = np.frombuffer(msgbyte[pt:pt+4*4], dtype="int32").reshape([1,4])[0]
pt += 4*4

positivePromptNum = np.frombuffer(msgbyte[pt:pt+4], dtype="int32").reshape([1])[0]
pt += 4

Expand All @@ -164,7 +171,7 @@ def getDecodedData(msgbyte):
pt += 4

if negativePromptNum > 0:
negativePrompt = np.frombuffer(msgbyte[pt:pt+4*2*negativePromptNum], dtype="int32").reshape([positivePromptNum, 2])
negativePrompt = np.frombuffer(msgbyte[pt:pt+4*2*negativePromptNum], dtype="int32").reshape([negativePromptNum, 2])
pt += 4 * 2 * negativePromptNum
msg["negativePrompts"] = negativePrompt
else:
Expand Down
43 changes: 30 additions & 13 deletions samm-python-terminal/utl_sam_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,21 @@ def CalculateEmbeddings(msg):

print("[SAMM INFO] Embeddings Cached.")

def helperPredict(dataNode, msg, points, labels, bbox2d):
dataNode.samPredictor[msg["view"]].features = dataNode.features[msg["view"]][msg["n"]].to("cuda")
seg, _, _ = dataNode.samPredictor[msg["view"]].predict(
point_coords = points,
point_labels = labels,
box = bbox2d,
multimask_output = False,)
seg = seg[0]
return seg

def sammProcessingCallBack_INFERENCE(msg):
dataNode = SammParameterNode()
positivePoints = msg["positivePrompts"]
negativePoints = msg["negativePrompts"]
bbox2d = msg["bbox2D"]

points = []
labels = []
Expand All @@ -176,20 +187,26 @@ def sammProcessingCallBack_INFERENCE(msg):
labels.append(0)

seg = None
if len(points) > 0:
if msg["view"] == "R":
tempsize = [dataNode.imageSize[1], dataNode.imageSize[2]]
if msg["view"] == "G":
tempsize = [dataNode.imageSize[0], dataNode.imageSize[2]]
if msg["view"] == "Y":
tempsize = [dataNode.imageSize[0], dataNode.imageSize[1]]
if len(points) > 0 and (bbox2d[0]!=-404):

points = np.array(points)
point_labels = np.array(labels)
bbox2d = np.array(bbox2d)
seg = helperPredict(dataNode, msg, points, point_labels, bbox2d)

elif len(points) == 0 and (bbox2d[0]!=-404):

points = None
point_labels = None
bbox2d = np.array(bbox2d)
seg = helperPredict(dataNode, msg, points, point_labels, bbox2d)

elif len(points) > 0 and (bbox2d[0]==-404):

dataNode.samPredictor[msg["view"]].features = dataNode.features[msg["view"]][msg["n"]].to("cuda")
seg, _, _ = dataNode.samPredictor[msg["view"]].predict(
point_coords = np.array(points),
point_labels = np.array(labels),
multimask_output = False,)
seg = seg[0]
points = np.array(points)
point_labels = np.array(labels)
bbox2d = None
seg = helperPredict(dataNode, msg, points, point_labels, bbox2d)

else:
if msg["view"] == "R":
Expand Down
Loading

0 comments on commit b4b10f0

Please sign in to comment.