forked from cvat-ai/cvat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_handler.py
66 lines (53 loc) · 2.15 KB
/
model_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright (C) 2020 Intel Corporation
#
# SPDX-License-Identifier: MIT
from tools.test import *
import os
from copy import copy
import jsonpickle
import numpy as np
class ModelHandler:
    """Wrap a pretrained SiamMask network for frame-by-frame object tracking.

    The handler keeps only the loaded network, its configuration and the
    target device. Per-track data lives in the ``state`` dict that
    :meth:`infer` returns and expects back on the next call.
    """

    def __init__(self):
        """Load the SiamMask configuration and pretrained weights.

        The model directory comes from the ``MODEL_PATH`` environment
        variable and falls back to
        ``/opt/nuclio/SiamMask/experiments/siammask_sharp``.
        """
        # Setup device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        torch.backends.cudnn.benchmark = True

        base_dir = os.path.abspath(os.environ.get("MODEL_PATH",
            "/opt/nuclio/SiamMask/experiments/siammask_sharp"))

        # load_config is handed an object exposing a ``config`` attribute
        # that holds the path of the JSON configuration file.
        class configPath:
            config = os.path.join(base_dir, "config_davis.json")

        self.config = load_config(configPath)

        # Kept as a function-scope import, after the configuration is loaded,
        # exactly as in the original statement order.
        from custom import Custom
        siammask = Custom(anchors=self.config['anchors'])
        self.siammask = load_pretrain(siammask, os.path.join(base_dir, "SiamMask_DAVIS.pth"))
        self.siammask.eval().to(self.device)

    @staticmethod
    def _box_to_target(shape):
        """Convert ``[xtl, ytl, xbr, ybr]`` into (center, size) numpy arrays."""
        try:
            xtl, ytl, xbr, ybr = shape
        except ValueError as err:
            # Same exception type as the bare unpacking, but with a message
            # that names the expected layout.
            raise ValueError(
                "shape must be [xtl, ytl, xbr, ybr], got %r" % (shape,)
            ) from err
        target_pos = np.array([(xtl + xbr) / 2, (ytl + ybr) / 2])
        target_sz = np.array([xbr - xtl, ybr - ytl])
        return target_pos, target_sz

    def encode_state(self, state):
        """Serialize a tracker state dict into jsonpickle strings.

        The dict is modified in place and also returned. The network stored
        under ``'net'`` is dropped and only its ``zf`` attribute is kept,
        under the ``'net.zf'`` key. The ``'mask'`` entry is dropped as well.

        Raises:
            KeyError: if ``state`` has no ``'net'`` entry.
        """
        state['net.zf'] = state['net'].zf
        state.pop('net', None)
        state.pop('mask', None)

        # Iterate over a snapshot of the keys so the dict is never written
        # to while one of its own views is being iterated.
        for key in list(state):
            state[key] = jsonpickle.encode(state[key])

        return state

    def decode_state(self, state):
        """Rebuild a tracker state dict from its jsonpickle-encoded form.

        The dict is modified in place and also returned. A shallow copy of
        the loaded network is stored under ``'net'`` and receives the decoded
        ``'net.zf'`` value as its ``zf`` attribute; the ``'net.zf'`` key is
        then removed.

        Raises:
            KeyError: if ``state`` has no ``'net.zf'`` entry.
        """
        # NOTE(review): SECURITY - if `state` can originate from an untrusted
        # caller, jsonpickle.decode may instantiate arbitrary objects and
        # therefore run arbitrary code. Confirm the trust boundary, or replace
        # this with an explicit, type-restricted decoder.
        for key in list(state):
            state[key] = jsonpickle.decode(state[key])

        state['net'] = copy(self.siammask)
        state['net'].zf = state['net.zf']
        del state['net.zf']

        return state

    def infer(self, image, shape, state):
        """Start tracking a target or advance an existing track by one frame.

        Args:
            image: the current frame; it is converted with ``np.array``.
            shape: ``[xtl, ytl, xbr, ybr]`` box of the target. It is only
                read when ``state`` is ``None``.
            state: ``None`` to initialise a new track, otherwise the encoded
                state returned by the previous call.

        Returns:
            A ``(shape, state)`` tuple. When a track is initialised, ``shape``
            is the input box unchanged. When tracking, it is the flattened
            list taken from the tracker's ``'ploygon'`` entry. ``state`` is
            the encoded state to pass to the next call.

        Raises:
            ValueError: if a new track is requested and ``shape`` does not
                hold exactly four values.
        """
        image = np.array(image)

        if state is None:  # init tracking
            target_pos, target_sz = self._box_to_target(shape)
            siammask = copy(self.siammask)  # don't modify self.siammask
            state = siamese_init(image, target_pos, target_sz, siammask,
                self.config['hp'], device=self.device)
        else:  # track
            state = self.decode_state(state)
            state = siamese_track(state, image, mask_enable=True,
                refine_enable=True, device=self.device)
            # NOTE(review): 'ploygon' (sic) is presumably the key spelling
            # written by siamese_track - confirm before "fixing" it.
            shape = state['ploygon'].flatten().tolist()

        return shape, self.encode_state(state)