Skip to content

Commit

Permalink
Merge pull request #61 from bingogome/dev-refactor-everything
Browse files Browse the repository at this point in the history
mobile samm support
  • Loading branch information
bingogome authored Aug 8, 2023
2 parents 40da608 + 9e75332 commit f85ad57
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 8 deletions.
3 changes: 2 additions & 1 deletion samm-python-terminal/utl_sam_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,5 +206,6 @@ def getDecodedData(msgbyte):
"vit_b" : 0,
"vit_l" : 1,
"vit_h" : 2,
"DICT" : ["vit_b", "vit_l", "vit_h"]
"mobile_vit_t" : 3,
"DICT" : ["vit_b", "vit_l", "vit_h", "mobile_vit_t"]
}
31 changes: 26 additions & 5 deletions samm-python-terminal/utl_sam_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from tqdm import tqdm
import sys,os, cv2, matplotlib.pyplot as plt
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"#
from segment_anything import sam_model_registry, SamPredictor
from segment_anything import sam_model_registry as sam_model_registry_sam
from segment_anything import SamPredictor as SamPredictor_sam
from mobile_sam import sam_model_registry as sam_model_registry_mobile
from mobile_sam import SamPredictor as SamPredictor_mobile
import torch, functools, pickle

def singleton(cls):
Expand Down Expand Up @@ -49,6 +52,8 @@ def initNetwork(self, model = "vit_b"):

if model.startswith('vit_'):
self.initNetworkSam(model)
if model.startswith('mobile_'):
self.initNetworkMobile(model)

def initNetworkSam(self, model):
dictpath = {
Expand All @@ -60,12 +65,28 @@ def initNetworkSam(self, model):
if not os.path.isfile(self.sam_checkpoint):
raise Exception("[SAMM ERROR] SAM model file is not in " + self.sam_checkpoint)
model_type = model
sam = sam_model_registry[model_type](checkpoint=self.sam_checkpoint)
sam = sam_model_registry_sam[model_type](checkpoint=self.sam_checkpoint)
sam.to(device=self.device)

self.samPredictor["R"] = SamPredictor(sam)
self.samPredictor["G"] = SamPredictor(sam)
self.samPredictor["Y"] = SamPredictor(sam)
self.samPredictor["R"] = SamPredictor_sam(sam)
self.samPredictor["G"] = SamPredictor_sam(sam)
self.samPredictor["Y"] = SamPredictor_sam(sam)
print(f'[SAMM INFO] Model initialzed to: "{model}"')

def initNetworkMobile(self, model):
dictpath = {
"mobile_vit_t" : "mobile_sam.pt"
}
self.sam_checkpoint = self.workspace + "/" + dictpath[model]
if not os.path.isfile(self.sam_checkpoint):
raise Exception("[SAMM ERROR] SAM model file is not in " + self.sam_checkpoint)
model_type = model[7:]
sam = sam_model_registry_mobile[model_type](checkpoint=self.sam_checkpoint)
sam.to(device=self.device)

self.samPredictor["R"] = SamPredictor_mobile(sam)
self.samPredictor["G"] = SamPredictor_mobile(sam)
self.samPredictor["Y"] = SamPredictor_mobile(sam)
print(f'[SAMM INFO] Model initialzed to: "{model}"')


Expand Down
3 changes: 2 additions & 1 deletion samm/SammBase/SammBaseLib/UtilMsgFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,5 +206,6 @@ def getDecodedData(msgbyte):
"vit_b" : 0,
"vit_l" : 1,
"vit_h" : 2,
"DICT" : ["vit_b", "vit_l", "vit_h"]
"mobile_vit_t" : 3,
"DICT" : ["vit_b", "vit_l", "vit_h", "mobile_vit_t"]
}
2 changes: 1 addition & 1 deletion samm/SammBase/SammBaseLib/WidgetSammBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def setup(self):
self.ui.comboSegmentNode.connect("currentIndexChanged(int)", self.updateParameterNodeFromGUI)
self.ui.comboModel.connect("currentIndexChanged(int)", self.updateParameterNodeFromGUI)
self.ui.comboModel.connect("currentIndexChanged(int)", self.onUpdateComboModel)
comboModelItems = ['vit_b', 'vit_l', 'vit_h']
comboModelItems = ['vit_b', 'vit_l', 'vit_h', 'mobile_vit_t']
for item in comboModelItems:
self.ui.comboModel.addItem(item)

Expand Down

0 comments on commit f85ad57

Please sign in to comment.