diff --git a/predict.py b/predict.py index ee1b1ee8..cff01c3a 100644 --- a/predict.py +++ b/predict.py @@ -13,15 +13,15 @@ import argparse def rect_to_bb(rect): - # take a bounding predicted by dlib and convert it - # to the format (x, y, w, h) as we would normally do - # with OpenCV - x = rect.left() - y = rect.top() - w = rect.right() - x - h = rect.bottom() - y - # return a tuple of (x, y, w, h) - return (x, y, w, h) + # take a bounding predicted by dlib and convert it + # to the format (x, y, w, h) as we would normally do + # with OpenCV + x = rect.left() + y = rect.top() + w = rect.right() - x + h = rect.bottom() - y + # return a tuple of (x, y, w, h) + return (x, y, w, h) def detect_face(image_paths, SAVE_DETECTED_AT, default_max_size=800,size = 300, padding = 0.25): cnn_face_detector = dlib.cnn_face_detection_model_v1('dlib_models/mmod_human_face_detector.dat') @@ -59,17 +59,17 @@ def detect_face(image_paths, SAVE_DETECTED_AT, default_max_size=800,size = 300, def predidct_age_gender_race(save_prediction_at, imgs_path = 'cropped_faces/'): img_names = [os.path.join(imgs_path, x) for x in os.listdir(imgs_path)] - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device('cpu') model_fair_7 = torchvision.models.resnet34(pretrained=True) model_fair_7.fc = nn.Linear(model_fair_7.fc.in_features, 18) - model_fair_7.load_state_dict(torch.load('fair_face_models/fairface_alldata_20191111.pt')) + model_fair_7.load_state_dict(torch.load('fair_face_models/res34_fair_align_multi_7_20190809.pt', map_location=torch.device('cpu'))) model_fair_7 = model_fair_7.to(device) model_fair_7.eval() model_fair_4 = torchvision.models.resnet34(pretrained=True) model_fair_4.fc = nn.Linear(model_fair_4.fc.in_features, 18) - model_fair_4.load_state_dict(torch.load('fair_face_models/fairface_alldata_4race_20191111.pt')) + model_fair_4.load_state_dict(torch.load('fair_face_models/fairface_alldata_4race_20191111.pt', map_location=torch.device('cpu'))) model_fair_4 = model_fair_4.to(device) model_fair_4.eval() @@ -206,7 +206,7 @@ def ensure_dir(directory): parser = argparse.ArgumentParser() parser.add_argument('--csv', dest='input_csv', action='store', help='csv file of image path where col name for image path is "img_path') - dlib.DLIB_USE_CUDA = True + dlib.DLIB_USE_CUDA = False print("using CUDA?: %s" % dlib.DLIB_USE_CUDA) args = parser.parse_args() SAVE_DETECTED_AT = "detected_faces"