From 35f39c34b8c5b7bbc7e6a1e99138e27693cd7903 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrique=20Mendon=C3=A7a?= Date: Sun, 5 Jan 2020 19:42:14 +0100 Subject: [PATCH 1/4] pretrained parameter in model_selection --- classification/network/models.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/classification/network/models.py b/classification/network/models.py index 06b010e..6b4fe9c 100644 --- a/classification/network/models.py +++ b/classification/network/models.py @@ -38,11 +38,11 @@ class TransferModel(nn.Module): Simple transfer learning model that takes an imagenet pretrained model with a fc layer as base model and retrains a new fc layer for num_out_classes """ - def __init__(self, modelchoice, num_out_classes=2, dropout=0.0): + def __init__(self, modelchoice, num_out_classes=2, dropout=0.0, pretrained=True): super(TransferModel, self).__init__() self.modelchoice = modelchoice if modelchoice == 'xception': - self.model = return_pytorch04_xception() + self.model = return_pytorch04_xception(pretrained) # Replace fc num_ftrs = self.model.last_linear.in_features if not dropout: @@ -116,18 +116,21 @@ def forward(self, x): def model_selection(modelname, num_out_classes, - dropout=None): + dropout=None, pretrained=True): """ :param modelname: :return: model, image size, pretraining, input_list """ if modelname == 'xception': return TransferModel(modelchoice='xception', - num_out_classes=num_out_classes), 299, \ - True, ['image'], None + num_out_classes=num_out_classes, + pretrained=pretrained), \ + 299, True, ['image'], None elif modelname == 'resnet18': - return TransferModel(modelchoice='resnet18', dropout=dropout, - num_out_classes=num_out_classes), \ + return TransferModel(modelchoice='resnet18', + dropout=dropout, + num_out_classes=num_out_classes, + pretrained=pretrained), \ 224, True, ['image'], None else: raise NotImplementedError(modelname) From ad62be149f42689db39e0ba81ff9b6cb8d481d7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrique=20Mendon=C3=A7a?= Date: Sun, 5 Jan 2020 19:51:08 +0100 Subject: [PATCH 2/4] pretrained parameter in model_selection --- classification/detect_from_video.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/classification/detect_from_video.py b/classification/detect_from_video.py index 19c31ee..ee5d3ff 100644 --- a/classification/detect_from_video.py +++ b/classification/detect_from_video.py @@ -132,12 +132,13 @@ def test_full_image_network(video_path, model_path, output_path, face_detector = dlib.get_frontal_face_detector() # Load model - model, *_ = model_selection(modelname='xception', num_out_classes=2) + pretrained = (model_path is None) + model, *_ = model_selection(modelname='xception', num_out_classes=2, pretrained=pretrained) if model_path is not None: model = torch.load(model_path) print('Model found in {}'.format(model_path)) else: - print('No model found, initializing random model.') + print('No model found, using pretrained model.') if cuda: model = model.cuda() From 650930da684fe105e69a0dd88a3c8615abbed53f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrique=20Mendon=C3=A7a?= Date: Mon, 6 Jan 2020 12:16:43 +0100 Subject: [PATCH 3/4] typo in add_argument model_path --- classification/detect_from_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/classification/detect_from_video.py b/classification/detect_from_video.py index ee5d3ff..4bd705e 100644 --- a/classification/detect_from_video.py +++ b/classification/detect_from_video.py @@ -222,7 +222,7 @@ def test_full_image_network(video_path, model_path, output_path, p = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter) p.add_argument('--video_path', '-i', type=str) - p.add_argument('--model_path', '-mi', type=str, default=None) + p.add_argument('--model_path', '-m', type=str, default=None) p.add_argument('--output_path', '-o', type=str, default='.') p.add_argument('--start_frame', type=int, default=0) From 8cca5302b5e5f2fcaddf0d41e7be030c96745504 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrique=20Mendon=C3=A7a?= Date: Mon, 6 Jan 2020 13:51:36 +0100 Subject: [PATCH 4/4] pretrained resnet --- classification/network/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/classification/network/models.py b/classification/network/models.py index 6b4fe9c..823c76b 100644 --- a/classification/network/models.py +++ b/classification/network/models.py @@ -55,9 +55,9 @@ def __init__(self, modelchoice, num_out_classes=2, dropout=0.0, pretrained=True) ) elif modelchoice == 'resnet50' or modelchoice == 'resnet18': if modelchoice == 'resnet50': - self.model = torchvision.models.resnet50(pretrained=True) + self.model = torchvision.models.resnet50(pretrained=pretrained) if modelchoice == 'resnet18': - self.model = torchvision.models.resnet18(pretrained=True) + self.model = torchvision.models.resnet18(pretrained=pretrained) # Replace fc num_ftrs = self.model.fc.in_features if not dropout: