diff --git a/model_torch.py b/model_torch.py index 47f25ab..b407f96 100644 --- a/model_torch.py +++ b/model_torch.py @@ -238,11 +238,13 @@ def loadImages(folder): if __name__ == "__main__": model = res_skip() model.load_state_dict(torch.load('erika.pth')) - - model.cuda() + is_cuda = torch.cuda.is_available() + if is_cuda: + model.cuda() + else: + model.cpu() model.eval() - filelists = loadImages(sys.argv[1]) with torch.no_grad(): @@ -255,8 +257,11 @@ def loadImages(folder): # manually construct a batch. You can change it based on your usecases. patch = np.ones((1,1,rows,cols),dtype="float32") patch[0,0,0:src.shape[0],0:src.shape[1]] = src - - tensor = torch.from_numpy(patch).cuda() + + if is_cuda: + tensor = torch.from_numpy(patch).cuda() + else: + tensor = torch.from_numpy(patch).cpu() y = model(tensor) print(imname, torch.max(y), torch.min(y)) @@ -266,8 +271,3 @@ def loadImages(folder): head, tail = os.path.split(imname) cv2.imwrite(sys.argv[2]+"/"+tail.replace(".jpg",".png"),yc[0:src.shape[0],0:src.shape[1]]) - - - - - diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5bddf1c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +torch==1.9.1 +opencv-python \ No newline at end of file