diff --git a/classify/train.py b/classify/train.py index d55dc066d7a3..9fb7c52b545a 100644 --- a/classify/train.py +++ b/classify/train.py @@ -114,13 +114,13 @@ def train(opt, device): LOGGER.warning("WARNING: pass YOLOv5 classifier model with '-cls' suffix, i.e. '--model yolov5s-cls.pt'") model = ClassificationModel(model=model, nc=nc, cutoff=opt.cutoff or 10) # convert to classification model reshape_classifier_output(model, nc) # update class count - for p in model.parameters(): - p.requires_grad = True # for training for m in model.modules(): if not pretrained and hasattr(m, 'reset_parameters'): m.reset_parameters() if isinstance(m, torch.nn.Dropout) and opt.dropout is not None: m.p = opt.dropout # set dropout + for p in model.parameters(): + p.requires_grad = True # for training model = model.to(device) names = trainloader.dataset.classes # class names model.names = names # attach class names