
Transfer learning using pre-trained object detection model FCOS: Fully Convolutional One-Stage Object Detection architecture #5932

Closed
santhoshnumberone opened this issue May 3, 2022 · 8 comments


@santhoshnumberone

santhoshnumberone commented May 3, 2022

🚀 The feature

FCOS architecture

I used requires_grad to freeze all the layers in the network and decided to train only the classification head, which originally predicts 90 classes, so that it predicts just one class.

I used the code below to replace the classification logits, keeping only the 77th label: the classification weight and bias were replaced with the background row plus the row for class 77.

import torch
import torchvision

selected_label = 77  # index of the COCO class to keep (77 = cell phone, as reported below)

# load an object detection model pre-trained on COCO
model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True)

selected_head_classification_head_cls_logits_weight = 0
selected_head_classification_head_cls_logits_bias = 0

# List out all the names of the parameters whose gradient can be altered for further training
for name, param in model.named_parameters():

    # If the parameter requires a gradient
    if param.requires_grad:

        if name == "head.classification_head.cls_logits.bias" or name == "head.classification_head.cls_logits.weight":
            layer_para_name = "weight" if name.split('.')[-1] == 'weight' else "bias"
            print("\nReplacing", name, "layer containing 90 class score", layer_para_name, ", with", selected_label, "th class score", layer_para_name)
            print("####################################")
            print("Original layer size(0:Background + 90 classes): ", param.data.size())
            # Reshaping bias
            if name.split('.')[-1] == 'bias':
                selected_head_classification_head_cls_logits_bias = param.data[selected_label:selected_label+1]
                param.data = torch.cat([param.data[:1], selected_head_classification_head_cls_logits_bias])
            # Reshaping weight
            if name.split('.')[-1] == 'weight':
                selected_head_classification_head_cls_logits_weight = torch.tensor(param.data[selected_label][:].reshape([1, 256, 3, 3]))
                param.data = torch.cat([param.data[:1], selected_head_classification_head_cls_logits_weight])
            print("Altered layer size(0:Background +", selected_label, "th class): ", param.data.size())
            print("####################################")
            print("Finished enabling requires gradient for", name, "layer......")
            # Make the layer trainable
            param.requires_grad = True

        else:
            # Make the layer non-trainable
            param.requires_grad = False

I got this output

Replacing head.classification_head.cls_logits.weight layer containing 90 class score weight , with 77 th class score weight
####################################
Original layer size(0:Background + 90 classes):  torch.Size([91, 256, 3, 3])
Altered layer size(0:Background + 77 th class):  torch.Size([2, 256, 3, 3])
####################################
Finished enabling requires gradient for head.classification_head.cls_logits.weight layer......

Replacing head.classification_head.cls_logits.bias layer containing 90 class score bias , with 77 th class score bias
####################################
Original layer size(0:Background + 90 classes):  torch.Size([91])
Altered layer size(0:Background + 77 th class):  torch.Size([2])
####################################
Finished enabling requires gradient for head.classification_head.cls_logits.bias layer......
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:26: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
FCOS(
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=1e-05)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=1e-05)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=1e-05)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=1e-05)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d(256, eps=1e-05)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=1e-05)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=1e-05)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=1e-05)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=1e-05)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=1e-05)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=1e-05)
          (relu): ReLU(inplace=True)
        )
      )
      (layer2): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=1e-05)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=1e-05)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=1e-05)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(512, eps=1e-05)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=1e-05)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=1e-05)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=1e-05)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=1e-05)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=1e-05)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=1e-05)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=1e-05)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=1e-05)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=1e-05)
          (relu): ReLU(inplace=True)
        )
      )
      (layer3): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=1e-05)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=1e-05)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=1e-05)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(1024, eps=1e-05)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=1e-05)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=1e-05)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=1e-05)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=1e-05)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=1e-05)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=1e-05)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=1e-05)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=1e-05)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=1e-05)
          (relu): ReLU(inplace=True)
        )
        (4): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=1e-05)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=1e-05)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=1e-05)
          (relu): ReLU(inplace=True)
        )
        (5): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=1e-05)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=1e-05)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=1e-05)
          (relu): ReLU(inplace=True)
        )
      )
      (layer4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=1e-05)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=1e-05)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=1e-05)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(2048, eps=1e-05)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=1e-05)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=1e-05)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=1e-05)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=1e-05)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=1e-05)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=1e-05)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (fpn): FeaturePyramidNetwork(
      (inner_blocks): ModuleList(
        (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
        (2): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (layer_blocks): ModuleList(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (extra_blocks): LastLevelP6P7(
        (p6): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (p7): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      )
    )
  )
  (anchor_generator): AnchorGenerator()
  (head): FCOSHead(
    (classification_head): FCOSClassificationHead(
      (conv): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): GroupNorm(32, 256, eps=1e-05, affine=True)
        (2): ReLU()
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): GroupNorm(32, 256, eps=1e-05, affine=True)
        (5): ReLU()
        (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (7): GroupNorm(32, 256, eps=1e-05, affine=True)
        (8): ReLU()
        (9): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (10): GroupNorm(32, 256, eps=1e-05, affine=True)
        (11): ReLU()
      )
      (cls_logits): Conv2d(256, 91, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (regression_head): FCOSRegressionHead(
      (conv): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): GroupNorm(32, 256, eps=1e-05, affine=True)
        (2): ReLU()
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): GroupNorm(32, 256, eps=1e-05, affine=True)
        (5): ReLU()
        (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (7): GroupNorm(32, 256, eps=1e-05, affine=True)
        (8): ReLU()
        (9): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (10): GroupNorm(32, 256, eps=1e-05, affine=True)
        (11): ReLU()
      )
      (bbox_reg): Conv2d(256, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bbox_ctrness): Conv2d(256, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
)

Motivation, pitch

I was following these guides: An Instance segmentation model for PennFudan Dataset and Building your own object detector — PyTorch vs TensorFlow and how to even get started?

There's no predictor class for the FCOS (Fully Convolutional One-Stage Object Detection) model in /torchvision/models/detection/fcos.py, like FastRCNNPredictor for Faster R-CNN or MaskRCNNPredictor for Mask R-CNN.
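For comparison, the torchvision finetuning tutorial swaps the Faster R-CNN predictor roughly like this (a minimal sketch; num_classes=2 assumes one custom class plus background):

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# load a COCO pre-trained Faster R-CNN and swap its box predictor
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=2)  # 1 class + background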

When I started training the model with the script below

num_epochs = 10
for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch,print_freq=10)
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the test dataset
    evaluate(model, data_loader_test, device=device)

I got this error

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:490: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  cpuset_checked))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-17-05e881bbc3b2> in <module>()
      2 for epoch in range(num_epochs):
      3     # train for one epoch, printing every 10 iterations
----> 4     train_one_epoch(model, optimizer, data_loader, device, epoch,print_freq=10)
      5     # update the learning rate
      6     lr_scheduler.step()

6 frames
/content/drive/MyDrive/PytorchObjectDetector/engine.py in train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq)
     30         print("######################",targets)
     31 
---> 32         loss_dict = model(images, targets)
     33 
     34         losses = sum(loss for loss in loss_dict.values())

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/fcos.py in forward(self, images, targets)
    594 
    595         # compute the fcos heads outputs using the features
--> 596         head_outputs = self.head(features)
    597 
    598         # create the set of anchors

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/fcos.py in forward(self, x)
    120 
    121     def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
--> 122         cls_logits = self.classification_head(x)
    123         bbox_regression, bbox_ctrness = self.regression_head(x)
    124         return {

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/fcos.py in forward(self, x)
    184             # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
    185             N, _, H, W = cls_logits.shape
--> 186             cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
    187             cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
    188             cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, 4)

RuntimeError: shape '[2, -1, 91, 168, 96]' is invalid for input of size 64512

What do I do to alter this and get it working?

Alternatives

I have no idea how to correct the issue

Additional context

It looks like a dimension mismatch, but how and where do I correct it?

cc @datumbox @vfdev-5 @YosuaMichael

@datumbox
Contributor

datumbox commented May 9, 2022

@santhoshnumberone The FCOSClassificationHead receives the num_classes as parameter and it uses it to estimate the dimensions of cls_logits and manipulate the output during forward. See:

self.num_classes = num_classes

Though it's possible to modify the entire network to replace all references of num_classes with the new value, I think accuracy-wise you should get similar results to just keeping the pre-trained weights of the backbones and initializing the heads from scratch.

I'm going to close the issue, as I believe this answers your question but feel free to reopen if you face issues.

@datumbox datumbox closed this as completed May 9, 2022
@santhoshnumberone
Author

@santhoshnumberone The FCOSClassificationHead receives the num_classes as parameter and it uses it to estimate the dimensions of cls_logits and manipulate the output during forward. See:

self.num_classes = num_classes

Though it's possible to modify the entire network to replace all references of num_classes with the new value, I think accuracy-wise you should get similar results to just keeping the pre-trained weights of the backbones and initializing the heads from scratch.

I'm going to close the issue, as I believe this answers your question but feel free to reopen if you face issues.

I have an interesting finding that challenges your view that accuracy stays the same whether the head is trained from scratch or initialized with pre-trained weights.

I did the same with Faster R-CNN ResNet-50 FPN

I replaced the prediction head with the weights for label 77 (Cell Phone), which is what my custom object was detected as when running inference with the original pre-trained object detection network:

Replacing roi_heads.box_predictor.cls_score.weight layer containing 90 class score weight , with 77 th class score weight
####################################
Original layer size(0:Background + 90 classes):  torch.Size([91, 1024])
Altered layer size(0:Background + 77 th class):  torch.Size([2, 1024])
####################################
Finished enabling requires gradient for roi_heads.box_predictor.cls_score.weight layer......

Replacing roi_heads.box_predictor.cls_score.bias layer containing 90 class score bias , with 77 th class score bias
####################################
Original layer size(0:Background + 90 classes):  torch.Size([91])
Altered layer size(0:Background + 77 th class):  torch.Size([2])
####################################
Finished enabling requires gradient for roi_heads.box_predictor.cls_score.bias layer......

I trained using Adam+LambdaLR

optimizer = torch.optim.Adam(params, lr=0.005, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0005, amsgrad=False)
lambda1 = lambda epoch: 0.65 ** epoch
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)

Trained for 10 epochs

num_epochs = 10
for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch,print_freq=10)
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the test dataset
    evaluate(model, data_loader_test, device=device)

Here's the output for the first two epochs

Epoch: [0]  [188/189]  eta: 0:00:05  lr: 0.005000  loss: 0.1684 (0.1853)  loss_classifier: 0.0370 (0.0534)  loss_box_reg: 0.1215 (0.1199)  loss_objectness: 0.0060 (0.0063)  loss_rpn_box_reg: 0.0054 (0.0056)  time: 5.4183  data: 0.0438  max mem: 8036
Epoch: [0] Total time: 0:16:33 (5.2578 s / it)
Test: Total time: 0:02:22 (0.3780 s / it)

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.062
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.166
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.039
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.063
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.104
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.329
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.329
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.329

Epoch: [1]  [188/189]  eta: 0:00:05  lr: 0.003250  loss: 0.1665 (0.1679)  loss_classifier: 0.0327 (0.0352)  loss_box_reg: 0.1244 (0.1207)  loss_objectness: 0.0067 (0.0065)  loss_rpn_box_reg: 0.0047 (0.0056)  time: 4.8139  data: 0.0454  max mem: 8036
Epoch: [1] Total time: 0:16:07 (5.1194 s / it)
Test: Total time: 0:02:23 (0.3783 s / it)

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.058
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.161
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.033
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.060
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.100
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.337
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.337
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.337

Here's the training result for the first two epochs when the head was trained from scratch

Epoch: [0]  [755/756]  eta: 0:00:00  lr: 0.005000  loss: 0.2446 (4932930780181.6865)  loss_classifier: 0.1180 (2722672846675.1851)  loss_box_reg: 0.1029 (2209564656677.2778)  loss_objectness: 0.0182 (185259575.7658)  loss_rpn_box_reg: 0.0049 (507922124.5077)  time: 0.6842  data: 0.0082  max mem: 4576
Epoch: [0] Total time: 0:09:34 (0.7604 s / it)

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.000

 Epoch: [1]  [755/756]  eta: 0:00:00  lr: 0.003250  loss: 0.2392 (204026750352061792256.0000)  loss_classifier: 0.1184 (75947324116073545728.0000)  loss_box_reg: 0.0976 (128079420092685189120.0000)  loss_objectness: 0.0207 (0.0186)  loss_rpn_box_reg: 0.0061 (0.0065)  time: 0.7242  data: 0.0082  max mem: 4576
Epoch: [1] Total time: 0:09:26 (0.7489 s / it)

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.000

So you can see the difference here: training from pre-trained weights gives better results than training from scratch.

So I request you to look into what the actual issue is with this particular FCOS model implementation in PyTorch.

@datumbox
Contributor

@santhoshnumberone You are using an optimizer configuration that I haven't tested, but the 0.0 accuracies on the from-scratch run might indicate an issue with the LR or initialisation. Note that 2 epochs are not nearly enough to get good results. Our recent refresh of the models used 400 epochs to get near SOTA results (see #5763 for details to copy the recipe and try it out).

So far, I don't see any evidence that there is a problem with the implementation of FCOS. We have fully reproduced the result of the paper and thus I think it's more likely that the issue could be on the way you train or modify the model. I recommend trying the recipe posted above to see what kind of results you get between from scratch and pre-trained.

@santhoshnumberone
Author

@santhoshnumberone You are using an optimizer configuration that I haven't tested, but the 0.0 accuracies on the from-scratch run might indicate an issue with the LR or initialisation. Note that 2 epochs are not nearly enough to get good results. Our recent refresh of the models used 400 epochs to get near SOTA results (see #5763 for details to copy the recipe and try it out).

So far, I don't see any evidence that there is a problem with the implementation of FCOS. We have fully reproduced the result of the paper and thus I think it's more likely that the issue could be on the way you train or modify the model. I recommend trying the recipe posted above to see what kind of results you get between from scratch and pre-trained.

I agree. I don't have the resources (a GPU, or 18 min/epoch on Google Colab with a GPU) to justify my claims for this particular case of transfer learning for custom object detection (I can't disclose the data either).

I could have achieved better mAP and lower losses in far fewer epochs by loading the head with pre-trained weights rather than training it from scratch.

Could you please guide me on how to overcome this error when trying to train with FCOS?

/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/fcos.py in forward(self, x)
    184             # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
    185             N, _, H, W = cls_logits.shape
--> 186             cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
    187             cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
    188             cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, 4)

RuntimeError: shape '[2, -1, 91, 168, 96]' is invalid for input of size 64512

@datumbox
Contributor

I agree. I don't have the resources (a GPU, or 18 min/epoch on Google Colab with a GPU) to justify my claims for this particular case of transfer learning for custom object detection (I can't disclose the data either).

Understood, but it's very hard to provide much help if I can't reproduce the problem. I can only tell you what I believe the problem is not; due to the lack of info I'm forced to guess.

RuntimeError: shape '[2, -1, 91, 168, 96]' is invalid for input of size 64512

As mentioned earlier, the 91 here is the number of classes + background. You would need to modify num_classes everywhere to avoid this (set it to 2 in your case). You can try to trace all the usages of num_classes through the code and modify the network to adjust to the new number of classes. But if I were you, I would just grab the pretrained backbone of FCOS and train the detection heads from scratch. This would require fewer iterations, but doing about 20-25 epochs would still be needed.
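A minimal sketch of that suggestion (assuming the torchvision 0.12 signature of fcos_resnet50_fpn, which accepts num_classes and pretrained_backbone):

import torchvision

# COCO pre-trained backbone, heads initialised from scratch for 1 class + background
model = torchvision.models.detection.fcos_resnet50_fpn(
    pretrained=False, pretrained_backbone=True, num_classes=2
)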

@santhoshnumberone
Author

I agree. I don't have the resources (a GPU, or 18 min/epoch on Google Colab with a GPU) to justify my claims for this particular case of transfer learning for custom object detection (I can't disclose the data either).

Understood, but it's very hard to provide much help if I can't reproduce the problem. I can only tell you what I believe the problem is not; due to the lack of info I'm forced to guess.

RuntimeError: shape '[2, -1, 91, 168, 96]' is invalid for input of size 64512

As mentioned earlier, the 91 here is the number of classes + background. You would need to modify num_classes everywhere to avoid this (set it to 2 in your case). You can try to trace all the usages of num_classes through the code and modify the network to adjust to the new number of classes. But if I were you, I would just grab the pretrained backbone of FCOS and train the detection heads from scratch. This would require fewer iterations, but doing about 20-25 epochs would still be needed.

Since 91 (0: background + 90 classes) is hard-coded, it would have to be changed manually wherever I get this error.

Let me try it and get back with the result.

I wonder why this is not an issue with the Faster R-CNN ResNet-50 FPN code.

@k-bandi

k-bandi commented Jun 3, 2022

Hi, I have a question regarding this. I'm trying to do transfer learning using FCOS but with a completely different set of classes compared to the pre-trained case. I couldn't find any tutorial on how to modify the layers to accommodate the custom classes. I still want to leverage the pre-trained weights. Could you please help, @datumbox?

@k-bandi

k-bandi commented Jun 6, 2022

@santhoshnumberone and @datumbox - I figured out how to do this. Here's a way you can modify FCOS for a custom number of classes.

import math

import torch
import torchvision


def make_custom_object_detection_model_fcos(num_classes):
    # load an object detection model pre-trained on COCO
    model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True)
    num_anchors = model.head.classification_head.num_anchors
    # keep the stored num_classes in sync so the reshape in forward() matches the new head
    model.head.classification_head.num_classes = num_classes
    out_channels = model.head.classification_head.conv[9].out_channels
    # rebuild the final classification conv for the custom number of classes
    cls_logits = torch.nn.Conv2d(out_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
    torch.nn.init.normal_(cls_logits.weight, std=0.01)
    torch.nn.init.constant_(cls_logits.bias, -math.log((1 - 0.01) / 0.01))  # focal-loss bias prior

    model.head.classification_head.cls_logits = cls_logits

    return model
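For a single foreground class plus background this would be called with num_classes=2, since torchvision detection models count the background as a class; a minimal usage sketch:

model = make_custom_object_detection_model_fcos(num_classes=2)  # 1 custom class + background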
