From 841986aa0081bdeaf785d1ed4c48dd108fa69a78 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 8 Feb 2021 20:34:29 +0530 Subject: [PATCH] Remove redundant `num_features` arg from Classification model (#88) * remove num_features arg from Classification model * add annotations to args Co-authored-by: Jirka Borovec --- flash/vision/classification/model.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 69a3fd8c85..5528cfc5d6 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -38,10 +38,9 @@ class ImageClassifier(ClassificationTask): def __init__( self, - num_classes, - backbone="resnet18", - num_features: int = None, - pretrained=True, + num_classes: int, + backbone: str = "resnet18", + pretrained: bool = True, loss_fn: Callable = F.cross_entropy, optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, metrics: Union[Callable, Mapping, Sequence, None] = (Accuracy()),