Skip to content

Commit

Permalink
feat: support cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong0618 committed Sep 22, 2021
1 parent 5c4a63c commit ab30fcf
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
20 changes: 10 additions & 10 deletions model_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,13 @@ def loadImages(folder):
if __name__ == "__main__":
model = res_skip()
model.load_state_dict(torch.load('erika.pth'))

model.cuda()
is_cuda = torch.cuda.is_available()
if is_cuda:
model.cuda()
else:
model.cpu()
model.eval()


filelists = loadImages(sys.argv[1])

with torch.no_grad():
Expand All @@ -255,8 +257,11 @@ def loadImages(folder):
# manually construct a batch. You can change it based on your usecases.
patch = np.ones((1,1,rows,cols),dtype="float32")
patch[0,0,0:src.shape[0],0:src.shape[1]] = src

tensor = torch.from_numpy(patch).cuda()

if is_cuda:
tensor = torch.from_numpy(patch).cuda()
else:
tensor = torch.from_numpy(patch).cpu()
y = model(tensor)
print(imname, torch.max(y), torch.min(y))

Expand All @@ -266,8 +271,3 @@ def loadImages(folder):

head, tail = os.path.split(imname)
cv2.imwrite(sys.argv[2]+"/"+tail.replace(".jpg",".png"),yc[0:src.shape[0],0:src.shape[1]])





2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
torch==1.9.1
opencv-python

0 comments on commit ab30fcf

Please sign in to comment.