From f5613c0a572383aaeed0d1bc47a3b18fa2b5d676 Mon Sep 17 00:00:00 2001 From: nemonameless Date: Tue, 31 Jan 2023 14:14:09 +0000 Subject: [PATCH 1/5] add ppyoloe distill with soft loss and feature loss --- 0run_x_to_l.sh | 18 + 1small_m_to_s.sh | 18 + .../distill/ppyoloe_plus_crn_l_80e_coco.yml | 19 + .../distill/ppyoloe_plus_crn_m_80e_coco.yml | 19 + .../distill/ppyoloe_plus_crn_s_80e_coco.yml | 19 + .../distill/ppyoloe_plus_distill_l_to_m.yml | 26 + .../distill/ppyoloe_plus_distill_m_to_s.yml | 26 + .../distill/ppyoloe_plus_distill_x_to_l.yml | 26 + ppdet/modeling/architectures/ppyoloe.py | 9 +- ppdet/modeling/assigners/atss_assigner.py | 2 +- .../assigners/task_aligned_assigner.py | 2 +- ppdet/modeling/heads/ppyoloe_head.py | 57 +- ppdet/slim/__init__.py | 4 + ppdet/slim/distill_ppyoloe.py | 695 ++++++++++++++++++ 14 files changed, 921 insertions(+), 19 deletions(-) create mode 100644 0run_x_to_l.sh create mode 100644 1small_m_to_s.sh create mode 100644 configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco.yml create mode 100644 configs/ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco.yml create mode 100644 configs/ppyoloe/distill/ppyoloe_plus_crn_s_80e_coco.yml create mode 100644 configs/ppyoloe/distill/ppyoloe_plus_distill_l_to_m.yml create mode 100644 configs/ppyoloe/distill/ppyoloe_plus_distill_m_to_s.yml create mode 100644 configs/ppyoloe/distill/ppyoloe_plus_distill_x_to_l.yml create mode 100644 ppdet/slim/distill_ppyoloe.py diff --git a/0run_x_to_l.sh b/0run_x_to_l.sh new file mode 100644 index 00000000000..de52857a76a --- /dev/null +++ b/0run_x_to_l.sh @@ -0,0 +1,18 @@ +export FLAGS_allocator_strategy=auto_growth +model_type=ppyoloe/distill +job_name=ppyoloe_plus_crn_l_80e_coco +job_name_tea=ppyoloe_plus_distill_x_to_l + +config=configs/${model_type}/${job_name}.yml +slim_config=configs/${model_type}/${job_name_tea}.yml +log_dir=log_dir/${job_name} +weights=output/${job_name_tea}/model_final.pdparams + +# 1. training +#CUDA_VISIBLE_DEVICES=3 python3.7 tools/train.py -c ${config} --slim_config ${slim_config} #--eval --amp +python3.7 -m paddle.distributed.launch --log_dir=${log_dir} --gpus 0,1,2,3,4,5,6,7 tools/train.py -c ${config} --slim_config ${slim_config} --eval +# -r output/ppyoloe_plus_distill_x_to_l/14 --amp + +# 2. eval +#CUDA_VISIBLE_DEVICES=0 python3.7 tools/eval.py -c ${config} -o weights=https://paddledet.bj.bcebos.com/models/${job_name}.pdparams +#CUDA_VISIBLE_DEVICES=2 python3.7 tools/eval.py -c ${config} -o weights=${weights} diff --git a/1small_m_to_s.sh b/1small_m_to_s.sh new file mode 100644 index 00000000000..932f8143e1e --- /dev/null +++ b/1small_m_to_s.sh @@ -0,0 +1,18 @@ +export FLAGS_allocator_strategy=auto_growth +model_type=ppyoloe/distill +job_name=ppyoloe_plus_crn_s_80e_coco +job_name_tea=ppyoloe_plus_distill_m_to_s + +config=configs/${model_type}/${job_name}.yml +slim_config=configs/${model_type}/${job_name_tea}.yml +log_dir=log_dir/${job_name} +weights=output/${job_name_tea}/model_final.pdparams + +# 1. training +#CUDA_VISIBLE_DEVICES=3 python3.7 tools/train.py -c ${config} --slim_config ${slim_config} #--eval --amp +python3.7 -m paddle.distributed.launch --log_dir=${log_dir} --gpus 0,1,2,3,4,5,6,7 tools/train.py -c ${config} --slim_config ${slim_config} --eval +#-r output/ppyoloe_plus_distill_m_to_s/14 # --amp + +# 2. 
eval +#CUDA_VISIBLE_DEVICES=0 python3.7 tools/eval.py -c ${config} -o weights=https://paddledet.bj.bcebos.com/models/${job_name}.pdparams +#CUDA_VISIBLE_DEVICES=2 python3.7 tools/eval.py -c ${config} -o weights=${weights} diff --git a/configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco.yml b/configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco.yml new file mode 100644 index 00000000000..ffb4af2e23e --- /dev/null +++ b/configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco.yml @@ -0,0 +1,19 @@ +_BASE_: [ + '../ppyoloe_plus_crn_l_80e_coco.yml', +] +architecture: PPYOLOE +PPYOLOE: + backbone: CSPResNet + neck: CustomCSPPAN + yolo_head: PPYOLOEHead + post_process: ~ + for_distill: True + + +log_iter: 100 +snapshot_epoch: 5 +weights: output/ppyoloe_plus_crn_l_80e_coco/model_final + +pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_l_obj365_pretrained.pdparams +depth_mult: 1.0 +width_mult: 1.0 diff --git a/configs/ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco.yml b/configs/ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco.yml new file mode 100644 index 00000000000..63e95a706ad --- /dev/null +++ b/configs/ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco.yml @@ -0,0 +1,19 @@ +_BASE_: [ + '../ppyoloe_plus_crn_m_80e_coco.yml', +] +architecture: PPYOLOE +PPYOLOE: + backbone: CSPResNet + neck: CustomCSPPAN + yolo_head: PPYOLOEHead + post_process: ~ + for_distill: True + + +log_iter: 100 +snapshot_epoch: 5 +weights: output/ppyoloe_plus_crn_m_80e_coco/model_final + +pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_m_obj365_pretrained.pdparams +depth_mult: 0.67 +width_mult: 0.75 diff --git a/configs/ppyoloe/distill/ppyoloe_plus_crn_s_80e_coco.yml b/configs/ppyoloe/distill/ppyoloe_plus_crn_s_80e_coco.yml new file mode 100644 index 00000000000..0c13205d0de --- /dev/null +++ b/configs/ppyoloe/distill/ppyoloe_plus_crn_s_80e_coco.yml @@ -0,0 +1,19 @@ +_BASE_: [ + '../ppyoloe_plus_crn_s_80e_coco.yml', +] +architecture: PPYOLOE +PPYOLOE: + backbone: CSPResNet + neck: CustomCSPPAN + yolo_head: PPYOLOEHead + post_process: ~ + for_distill: True + + +log_iter: 100 +snapshot_epoch: 5 +weights: output/ppyoloe_plus_crn_s_80e_coco/model_final + +pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams +depth_mult: 0.33 +width_mult: 0.50 diff --git a/configs/ppyoloe/distill/ppyoloe_plus_distill_l_to_m.yml b/configs/ppyoloe/distill/ppyoloe_plus_distill_l_to_m.yml new file mode 100644 index 00000000000..98be3cb2a51 --- /dev/null +++ b/configs/ppyoloe/distill/ppyoloe_plus_distill_l_to_m.yml @@ -0,0 +1,26 @@ +# teacher config +_BASE_: [ + '../ppyoloe_plus_crn_l_80e_coco.yml', +] +depth_mult: 1.0 +width_mult: 1.0 + +architecture: PPYOLOE +PPYOLOE: + backbone: CSPResNet + neck: CustomCSPPAN + yolo_head: PPYOLOEHead + post_process: ~ + for_distill: True + +pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco.pdparams +find_unused_parameters: True + + +slim: Distill +slim_method: PPYOLOEDistill +distill_loss: DistillPPYOLOELoss + +DistillPPYOLOELoss: # L -> M + teacher_width_mult: 1.0 + student_width_mult: 0.75 diff --git a/configs/ppyoloe/distill/ppyoloe_plus_distill_m_to_s.yml b/configs/ppyoloe/distill/ppyoloe_plus_distill_m_to_s.yml new file mode 100644 index 00000000000..3e54fc90ab0 --- /dev/null +++ b/configs/ppyoloe/distill/ppyoloe_plus_distill_m_to_s.yml @@ -0,0 +1,26 @@ +# teacher config +_BASE_: [ + '../ppyoloe_plus_crn_l_80e_coco.yml', +] +depth_mult: 0.67 +width_mult: 0.75 + 
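+# Note: DistillPPYOLOELoss scales its neck_out_channels default [768, 384, 192] by the
+# teacher/student width_mult values configured below, so this M -> S pair distills
+# teacher neck features with [576, 288, 144] channels into student features with
+# [384, 192, 96] channels; the feature-loss modules insert a 1x1 conv align layer
+# whenever the two channel counts differ.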
+architecture: PPYOLOE +PPYOLOE: + backbone: CSPResNet + neck: CustomCSPPAN + yolo_head: PPYOLOEHead + post_process: ~ + for_distill: True + +pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_m_80e_coco.pdparams +find_unused_parameters: True + + +slim: Distill +slim_method: PPYOLOEDistill +distill_loss: DistillPPYOLOELoss + +DistillPPYOLOELoss: # M -> S + teacher_width_mult: 0.75 + student_width_mult: 0.50 diff --git a/configs/ppyoloe/distill/ppyoloe_plus_distill_x_to_l.yml b/configs/ppyoloe/distill/ppyoloe_plus_distill_x_to_l.yml new file mode 100644 index 00000000000..6ac9809a597 --- /dev/null +++ b/configs/ppyoloe/distill/ppyoloe_plus_distill_x_to_l.yml @@ -0,0 +1,26 @@ +# teacher config +_BASE_: [ + '../ppyoloe_plus_crn_x_80e_coco.yml', +] +depth_mult: 1.33 +width_mult: 1.25 + +architecture: PPYOLOE +PPYOLOE: + backbone: CSPResNet + neck: CustomCSPPAN + yolo_head: PPYOLOEHead + post_process: ~ + for_distill: True + +pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_x_80e_coco.pdparams +find_unused_parameters: True + + +slim: Distill +slim_method: PPYOLOEDistill +distill_loss: DistillPPYOLOELoss + +DistillPPYOLOELoss: # X -> L + teacher_width_mult: 1.25 + student_width_mult: 1.0 diff --git a/ppdet/modeling/architectures/ppyoloe.py b/ppdet/modeling/architectures/ppyoloe.py index 7ff7c254da9..a48d18c0acd 100644 --- a/ppdet/modeling/architectures/ppyoloe.py +++ b/ppdet/modeling/architectures/ppyoloe.py @@ -36,6 +36,7 @@ def __init__(self, neck='CustomCSPPAN', yolo_head='PPYOLOEHead', post_process='BBoxPostProcess', + for_distill=False, for_mot=False): """ PPYOLOE network, see https://arxiv.org/abs/2203.16250 @@ -54,6 +55,7 @@ def __init__(self, self.yolo_head = yolo_head self.post_process = post_process self.for_mot = for_mot + self.for_distill = for_distill @classmethod def from_config(cls, cfg, *args, **kwargs): @@ -80,7 +82,12 @@ def _forward(self): if self.training: yolo_losses = self.yolo_head(neck_feats, self.inputs) - return yolo_losses + + if self.for_distill: + self.yolo_head.distill_pairs['emb_feats'] = neck_feats + return {'det_losses': yolo_losses, 'emb_feats': neck_feats} + else: + return yolo_losses else: yolo_head_outs = self.yolo_head(neck_feats) if self.post_process is not None: diff --git a/ppdet/modeling/assigners/atss_assigner.py b/ppdet/modeling/assigners/atss_assigner.py index ec7d23448fc..a1e753c9434 100644 --- a/ppdet/modeling/assigners/atss_assigner.py +++ b/ppdet/modeling/assigners/atss_assigner.py @@ -221,4 +221,4 @@ def forward(self, paddle.zeros_like(gather_scores)) assigned_scores *= gather_scores.unsqueeze(-1) - return assigned_labels, assigned_bboxes, assigned_scores + return assigned_labels, assigned_bboxes, assigned_scores, mask_positive diff --git a/ppdet/modeling/assigners/task_aligned_assigner.py b/ppdet/modeling/assigners/task_aligned_assigner.py index 23af79439ae..5a756fa67da 100644 --- a/ppdet/modeling/assigners/task_aligned_assigner.py +++ b/ppdet/modeling/assigners/task_aligned_assigner.py @@ -190,4 +190,4 @@ def forward(self, alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1) assigned_scores = assigned_scores * alignment_metrics - return assigned_labels, assigned_bboxes, assigned_scores + return assigned_labels, assigned_bboxes, assigned_scores, mask_positive diff --git a/ppdet/modeling/heads/ppyoloe_head.py b/ppdet/modeling/heads/ppyoloe_head.py index d29e9ac73ac..93855387869 100644 --- a/ppdet/modeling/heads/ppyoloe_head.py +++ b/ppdet/modeling/heads/ppyoloe_head.py @@ -134,6 +134,7 @@ def 
__init__(self, self.proj_conv = nn.Conv2D(self.reg_channels, 1, 1, bias_attr=False) self.proj_conv.skip_quant = True self._init_weights() + self.distill_pairs = {} @classmethod def from_config(cls, cfg, input_shape): @@ -321,10 +322,14 @@ def _bbox_loss(self, pred_dist, pred_bboxes, anchor_points, assigned_labels, loss_dfl = self._df_loss(pred_dist_pos, assigned_ltrb_pos, self.reg_range[0]) * bbox_weight loss_dfl = loss_dfl.sum() / assigned_scores_sum + self.distill_pairs['pred_bboxes_pos'] = pred_bboxes_pos + self.distill_pairs['pred_dist_pos'] = pred_dist_pos + self.distill_pairs['bbox_weight'] = bbox_weight else: loss_l1 = paddle.zeros([1]) loss_iou = paddle.zeros([1]) loss_dfl = pred_dist.sum() * 0. + self.distill_pairs['null_loss'] = pred_dist.sum() * 0. return loss_l1, loss_iou, loss_dfl def get_loss(self, head_outs, gt_meta, aux_pred=None): @@ -343,7 +348,7 @@ def get_loss(self, head_outs, gt_meta, aux_pred=None): pad_gt_mask = gt_meta['pad_gt_mask'] # label assignment if gt_meta['epoch_id'] < self.static_assigner_epoch: - assigned_labels, assigned_bboxes, assigned_scores = \ + assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \ self.static_assigner( anchors, num_anchors_list, @@ -356,7 +361,7 @@ def get_loss(self, head_outs, gt_meta, aux_pred=None): else: if self.sm_use: # only used in smalldet of PPYOLOE-SOD model - assigned_labels, assigned_bboxes, assigned_scores = \ + assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \ self.assigner( pred_scores.detach(), pred_bboxes.detach() * stride_tensor, @@ -368,18 +373,28 @@ def get_loss(self, head_outs, gt_meta, aux_pred=None): bg_index=self.num_classes) else: if aux_pred is None: - assigned_labels, assigned_bboxes, assigned_scores = \ - self.assigner( - pred_scores.detach(), - pred_bboxes.detach() * stride_tensor, - anchor_points, - num_anchors_list, - gt_labels, - gt_bboxes, - pad_gt_mask, - bg_index=self.num_classes) + if not hasattr(self, "assigned_labels"): + assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \ + self.assigner( + pred_scores.detach(), + pred_bboxes.detach() * stride_tensor, + anchor_points, + num_anchors_list, + gt_labels, + gt_bboxes, + pad_gt_mask, + bg_index=self.num_classes) + self.assigned_labels = assigned_labels + self.assigned_bboxes = assigned_bboxes + self.assigned_scores = assigned_scores + self.mask_positive = mask_positive + else: + assigned_labels = self.assigned_labels + assigned_bboxes = self.assigned_bboxes + assigned_scores = self.assigned_scores + mask_positive = self.mask_positive else: - assigned_labels, assigned_bboxes, assigned_scores = \ + assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \ self.assigner( pred_scores_aux.detach(), pred_bboxes_aux.detach() * stride_tensor, @@ -395,12 +410,14 @@ def get_loss(self, head_outs, gt_meta, aux_pred=None): assign_out_dict = self.get_loss_from_assign( pred_scores, pred_distri, pred_bboxes, anchor_points_s, - assigned_labels, assigned_bboxes, assigned_scores, alpha_l) + assigned_labels, assigned_bboxes, assigned_scores, mask_positive, + alpha_l) if aux_pred is not None: assign_out_dict_aux = self.get_loss_from_assign( aux_pred[0], aux_pred[1], pred_bboxes_aux, anchor_points_s, - assigned_labels, assigned_bboxes, assigned_scores, alpha_l) + assigned_labels, assigned_bboxes, assigned_scores, + mask_positive, alpha_l) loss = {} for key in assign_out_dict.keys(): loss[key] = assign_out_dict[key] + assign_out_dict_aux[key] @@ -411,7 +428,7 @@ def get_loss(self, head_outs, gt_meta, 
aux_pred=None): def get_loss_from_assign(self, pred_scores, pred_distri, pred_bboxes, anchor_points_s, assigned_labels, assigned_bboxes, - assigned_scores, alpha_l): + assigned_scores, mask_positive, alpha_l): # cls loss if self.use_varifocal_loss: one_hot_label = F.one_hot(assigned_labels, @@ -428,6 +445,14 @@ def get_loss_from_assign(self, pred_scores, pred_distri, pred_bboxes, assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.) loss_cls /= assigned_scores_sum + self.distill_pairs['pred_cls_scores'] = pred_scores + self.distill_pairs['pos_num'] = assigned_scores_sum + self.distill_pairs['assigned_scores'] = assigned_scores + self.distill_pairs['mask_positive'] = mask_positive + one_hot_label = F.one_hot(assigned_labels, + self.num_classes + 1)[..., :-1] + self.distill_pairs['target_labels'] = one_hot_label + loss_l1, loss_iou, loss_dfl = \ self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s, assigned_labels, assigned_bboxes, assigned_scores, diff --git a/ppdet/slim/__init__.py b/ppdet/slim/__init__.py index 17ffb030ec1..11d4d29840d 100644 --- a/ppdet/slim/__init__.py +++ b/ppdet/slim/__init__.py @@ -20,6 +20,7 @@ from .prune import * from .quant import * from .distill import * +from .distill_ppyoloe import * from .unstructured_prune import * from .ofa import * @@ -45,6 +46,9 @@ def build_slim_model(cfg, slim_cfg, mode='train'): elif "slim_method" in slim_load_cfg and slim_load_cfg[ 'slim_method'] == "CWD": model = CWDDistillModel(cfg, slim_cfg) + elif "slim_method" in slim_load_cfg and slim_load_cfg[ + 'slim_method'] == "PPYOLOEDistill": + model = PPYOLOEDistillModel(cfg, slim_cfg) else: model = DistillModel(cfg, slim_cfg) cfg['model'] = model diff --git a/ppdet/slim/distill_ppyoloe.py b/ppdet/slim/distill_ppyoloe.py new file mode 100644 index 00000000000..315a9a5700b --- /dev/null +++ b/ppdet/slim/distill_ppyoloe.py @@ -0,0 +1,695 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
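+
+# Overview (sketch): PPYOLOEDistillModel below wraps a frozen teacher and a trainable
+# student, and combines the head ("soft") distillation loss, the neck feature
+# distillation loss and the student's own detection loss as
+#
+#     total_loss = soft_loss + alpha * feat_loss + alpha * det_loss    # alpha defaults to 0.125
+#
+# where soft_loss weights the QFL, GIoU and DFL distillation terms computed by
+# DistillPPYOLOELoss by loss_weight['class'], ['iou'] and ['dfl'] (0.5, 1.25, 0.25
+# by default), and feat_loss is the FGD/PKD/MGD loss over the CustomCSPPAN outputs.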
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from ppdet.core.workspace import register, create, load_config +from ppdet.utils.checkpoint import load_pretrain_weight +from .distill import parameter_init +from ppdet.modeling.losses.iou_loss import GIoULoss +from ppdet.utils.logger import setup_logger + +logger = setup_logger(__name__) + + +class PPYOLOEDistillModel(nn.Layer): + def __init__(self, cfg, slim_cfg): + super(PPYOLOEDistillModel, self).__init__() + self.student_model = create(cfg.architecture) + logger.debug('Load student model pretrain_weights:{}'.format( + cfg.pretrain_weights)) + load_pretrain_weight(self.student_model, cfg.pretrain_weights) + + slim_cfg = load_config(slim_cfg) + self.teacher_model = create(slim_cfg.architecture) + self.distill_loss = create(slim_cfg.distill_loss) + logger.debug('Load teacher model pretrain_weights:{}'.format( + slim_cfg.pretrain_weights)) + load_pretrain_weight(self.teacher_model, slim_cfg.pretrain_weights) + + for param in self.teacher_model.parameters(): + param.trainable = False + + def parameters(self): + return self.student_model.parameters() + + def forward(self, inputs, alpha=0.125): + if self.training: + with paddle.no_grad(): + teacher_out = self.teacher_model(inputs) + + if hasattr(self.teacher_model.yolo_head, "assigned_labels"): + self.student_model.yolo_head.assigned_labels, self.student_model.yolo_head.assigned_bboxes, self.student_model.yolo_head.assigned_scores, self.student_model.yolo_head.mask_positive = \ + self.teacher_model.yolo_head.assigned_labels, self.teacher_model.yolo_head.assigned_bboxes, self.teacher_model.yolo_head.assigned_scores, self.teacher_model.yolo_head.mask_positive + delattr(self.teacher_model.yolo_head, "assigned_labels") + delattr(self.teacher_model.yolo_head, "assigned_bboxes") + delattr(self.teacher_model.yolo_head, "assigned_scores") + delattr(self.teacher_model.yolo_head, "mask_positive") + + student_out = self.student_model(inputs) + + # head loss concerned + soft_loss, feat_loss, distill_loss_dict = self.distill_loss( + self.teacher_model, self.student_model) + stu_loss = student_out['det_losses'] + stu_det_total_loss = stu_loss['loss'] + + # conbined distill + stu_loss[ + 'loss'] = soft_loss + alpha * feat_loss + alpha * stu_det_total_loss + stu_loss['soft_loss'] = soft_loss + stu_loss['feat_loss'] = feat_loss + return stu_loss + else: + return self.student_model(inputs) + + +@register +class DistillPPYOLOELoss(nn.Layer): + def __init__( + self, + teacher_width_mult=1.0, # default as L + student_width_mult=0.75, # default as M + neck_out_channels=[768, 384, 192], # default as L + loss_weight={ + 'class': 0.5, + 'iou': 1.25, + 'dfl': 0.25, + }, + kd_neck=True, + kd_type='fgd'): + super(DistillPPYOLOELoss, self).__init__() + self.loss_bbox = GIoULoss() + self.bbox_loss_weight = loss_weight['iou'] + self.dfl_loss_weight = loss_weight['dfl'] + self.qfl_loss_weight = loss_weight['class'] + + self.kd_neck = kd_neck + self.kd_type = kd_type + if self.kd_neck: + # Knowledge Distillation for Detectors in necks + distill_loss_module_list = [] + self.t_channel_list = [ + int(c * teacher_width_mult) for c in neck_out_channels + ] + self.s_channel_list = [ + int(c * student_width_mult) for c in neck_out_channels + ] + for i in range(len(neck_out_channels)): + if self.kd_type == 'fgd': + distill_loss_module = FGDLoss( + 
student_channels=self.s_channel_list[i], + teacher_channels=self.t_channel_list[i]) + elif self.kd_type == 'pkd': + distill_loss_module = PKDLoss( + student_channels=self.s_channel_list[i], + teacher_channels=self.t_channel_list[i], + resize_stu=False) + elif self.kd_type == 'mgd': + distill_loss_module = MGDSSIMLoss( + student_channels=self.s_channel_list[i], + teacher_channels=self.t_channel_list[i]) + else: + raise ValueError + distill_loss_module_list.append(distill_loss_module) + + self.distill_loss_module_list = nn.LayerList( + distill_loss_module_list) + + def bbox_loss(self, s_bbox, t_bbox, weight_targets=None): + # [x,y,w,h] + if weight_targets is not None: + loss_bbox = paddle.sum( + self.loss_bbox(s_bbox, t_bbox) * weight_targets) + avg_factor = weight_targets.sum() + loss_bbox = loss_bbox / avg_factor + else: + loss_bbox = paddle.mean(self.loss_bbox(s_bbox, t_bbox)) + return loss_bbox + + def quality_focal_loss(self, pred_logits, soft_target_logits, beta=2.0, \ + use_sigmoid=True, label_weights=None, num_total_pos=None, pos_mask=None): + if use_sigmoid: + func = F.binary_cross_entropy_with_logits + soft_target = F.sigmoid(soft_target_logits) + pred_sigmoid = F.sigmoid(pred_logits) + preds = pred_logits + else: + func = F.binary_cross_entropy + soft_target = soft_target_logits + pred_sigmoid = pred_logits + preds = pred_sigmoid + + scale_factor = pred_sigmoid - soft_target + loss = func( + preds, soft_target, reduction='none') * scale_factor.abs().pow(beta) + loss = loss + if pos_mask is not None: + loss *= pos_mask + + loss = loss.sum(1) + if label_weights is not None: + loss = loss * label_weights + if num_total_pos is not None: + loss = loss.sum() / num_total_pos + else: + loss = loss.mean() + return loss + + def distribution_focal_loss(self, pred_corners, target_corners, + weight_targets): + target_corners_label = paddle.nn.functional.softmax( + target_corners, axis=-1) + loss_dfl = paddle.nn.functional.cross_entropy( + pred_corners, + target_corners_label, + soft_label=True, + reduction='none') + loss_dfl = loss_dfl.sum(1) + if weight_targets is not None: + loss_dfl = loss_dfl * (weight_targets.expand([-1, 4]).reshape([-1])) + loss_dfl = loss_dfl.sum(-1) / weight_targets.sum() + else: + loss_dfl = loss_dfl.mean(-1) + loss_dfl = loss_dfl / 4.0 # 4 direction + return loss_dfl + + def forward(self, teacher_model, student_model): + teacher_distill_pairs = teacher_model.yolo_head.distill_pairs + student_distill_pairs = student_model.yolo_head.distill_pairs + distill_bbox_loss, distill_dfl_loss, distill_cls_loss = [], [], [] + distill_bbox_loss.append( + self.bbox_loss(student_distill_pairs['pred_bboxes_pos'], + teacher_distill_pairs['pred_bboxes_pos'].detach(), + weight_targets=student_distill_pairs['bbox_weight'] + ) if 'pred_bboxes_pos' in student_distill_pairs and \ + 'pred_bboxes_pos' in teacher_distill_pairs and \ + 'bbox_weight' in student_distill_pairs + else student_distill_pairs['null_loss'] + ) + distill_dfl_loss.append(self.distribution_focal_loss( + student_distill_pairs['pred_dist_pos'].reshape((-1, student_distill_pairs['pred_dist_pos'].shape[-1])), + teacher_distill_pairs['pred_dist_pos'].detach().reshape((-1, teacher_distill_pairs['pred_dist_pos'].shape[-1])), \ + weight_targets=student_distill_pairs['bbox_weight'] + ) if 'pred_dist_pos' in student_distill_pairs and \ + 'pred_dist_pos' in teacher_distill_pairs and \ + 'bbox_weight' in student_distill_pairs + else student_distill_pairs['null_loss'] + ) + distill_cls_loss.append( + self.quality_focal_loss( + 
student_distill_pairs['pred_cls_scores'].reshape(( + -1, student_distill_pairs['pred_cls_scores'].shape[-1])), + teacher_distill_pairs['pred_cls_scores'].detach().reshape(( + -1, teacher_distill_pairs['pred_cls_scores'].shape[-1])), + num_total_pos=student_distill_pairs['pos_num'], + use_sigmoid=False)) + distill_bbox_loss = paddle.add_n(distill_bbox_loss) + distill_cls_loss = paddle.add_n(distill_cls_loss) + distill_dfl_loss = paddle.add_n(distill_dfl_loss) + + if self.kd_neck: + # Knowledge Distillation for Detectors in necks + distill_neck_global_loss = [] + inputs = student_model.inputs + teacher_fpn_feats = teacher_distill_pairs['emb_feats'] + student_fpn_feats = student_distill_pairs['emb_feats'] + assert 'gt_bbox' in inputs + for i, distill_loss_module in enumerate( + self.distill_loss_module_list): + distill_neck_global_loss.append( + distill_loss_module(student_fpn_feats[i], teacher_fpn_feats[ + i], inputs)) + distill_neck_global_loss = paddle.add_n(distill_neck_global_loss) + else: + distill_neck_global_loss = paddle.to_tensor([0]) + + soft_loss = ( + distill_bbox_loss * self.bbox_loss_weight + distill_cls_loss * + self.qfl_loss_weight + distill_dfl_loss * self.dfl_loss_weight) + student_model.yolo_head.distill_pairs.clear() + teacher_model.yolo_head.distill_pairs.clear() + return soft_loss, \ + distill_neck_global_loss, \ + {'dfl_loss': distill_dfl_loss, 'qfl_loss': distill_cls_loss, 'bbox_loss': distill_bbox_loss} + + +@register +class FGDLoss(nn.Layer): + """ + Focal and Global Knowledge Distillation for Detectors + The code is reference from https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py + + Args: + student_channels (int): The number of channels in the student's FPN feature map. Default to 256. + teacher_channels (int): The number of channels in the teacher's FPN feature map. Default to 256. + normalize (bool): Whether to normalize the feature maps. + temp (float, optional): The temperature coefficient. Defaults to 0.5. + alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001 + beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005 + gamma_fgd (float, optional): The weight of mask_loss. Defaults to 0.001 + lambda_fgd (float, optional): The weight of relation_loss. 
Defaults to 0.000005 + """ + + def __init__( + self, + student_channels=256, + teacher_channels=256, + normalize=True, + temp=0.5, + alpha_fgd=0.00001, # 0.001 + beta_fgd=0.000005, # 0.0005 + gamma_fgd=0.00001, # 0.001 + lambda_fgd=0.00000005): # 0.000005 + super(FGDLoss, self).__init__() + self.temp = temp + self.alpha_fgd = alpha_fgd + self.beta_fgd = beta_fgd + self.gamma_fgd = gamma_fgd + self.lambda_fgd = lambda_fgd + self.normalize = normalize + kaiming_init = parameter_init("kaiming") + zeros_init = parameter_init("constant", 0.0) + + if student_channels != teacher_channels: + self.align = nn.Conv2D( + student_channels, + teacher_channels, + kernel_size=1, + stride=1, + padding=0, + weight_attr=kaiming_init) + student_channels = teacher_channels + else: + self.align = None + + self.conv_mask_s = nn.Conv2D( + student_channels, 1, kernel_size=1, weight_attr=kaiming_init) + self.conv_mask_t = nn.Conv2D( + teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init) + + self.stu_conv_block = nn.Sequential( + nn.Conv2D( + student_channels, + student_channels // 2, + kernel_size=1, + weight_attr=zeros_init), + nn.LayerNorm([student_channels // 2, 1, 1]), + nn.ReLU(), + nn.Conv2D( + student_channels // 2, + student_channels, + kernel_size=1, + weight_attr=zeros_init)) + self.tea_conv_block = nn.Sequential( + nn.Conv2D( + teacher_channels, + teacher_channels // 2, + kernel_size=1, + weight_attr=zeros_init), + nn.LayerNorm([teacher_channels // 2, 1, 1]), + nn.ReLU(), + nn.Conv2D( + teacher_channels // 2, + teacher_channels, + kernel_size=1, + weight_attr=zeros_init)) + + def norm(self, feat): + # Normalize the feature maps to have zero mean and unit variances. + assert len(feat.shape) == 4 + N, C, H, W = feat.shape + feat = feat.transpose([1, 0, 2, 3]).reshape([C, -1]) + mean = feat.mean(axis=-1, keepdim=True) + std = feat.std(axis=-1, keepdim=True) + feat = (feat - mean) / (std + 1e-6) + return feat.reshape([C, N, H, W]).transpose([1, 0, 2, 3]) + + def spatial_channel_attention(self, x, t=0.5): + shape = paddle.shape(x) + N, C, H, W = shape + _f = paddle.abs(x) + spatial_map = paddle.reshape( + paddle.mean( + _f, axis=1, keepdim=True) / t, [N, -1]) + spatial_map = F.softmax(spatial_map, axis=1, dtype="float32") * H * W + spatial_att = paddle.reshape(spatial_map, [N, H, W]) + + channel_map = paddle.mean( + paddle.mean( + _f, axis=2, keepdim=False), axis=2, keepdim=False) + channel_att = F.softmax(channel_map / t, axis=1, dtype="float32") * C + return [spatial_att, channel_att] + + def spatial_pool(self, x, mode="teacher"): + batch, channel, width, height = x.shape + x_copy = x + x_copy = paddle.reshape(x_copy, [batch, channel, height * width]) + x_copy = x_copy.unsqueeze(1) + if mode.lower() == "student": + context_mask = self.conv_mask_s(x) + else: + context_mask = self.conv_mask_t(x) + + context_mask = paddle.reshape(context_mask, [batch, 1, height * width]) + context_mask = F.softmax(context_mask, axis=2) + context_mask = context_mask.unsqueeze(-1) + context = paddle.matmul(x_copy, context_mask) + context = paddle.reshape(context, [batch, channel, 1, 1]) + return context + + def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att, + tea_spatial_att): + def _func(a, b): + return paddle.sum(paddle.abs(a - b)) / len(a) + + mask_loss = _func(stu_channel_att, tea_channel_att) + _func( + stu_spatial_att, tea_spatial_att) + return mask_loss + + def feature_loss(self, stu_feature, tea_feature, mask_fg, mask_bg, + tea_channel_att, tea_spatial_att): + mask_fg = 
mask_fg.unsqueeze(axis=1) + mask_bg = mask_bg.unsqueeze(axis=1) + tea_channel_att = tea_channel_att.unsqueeze(axis=-1).unsqueeze(axis=-1) + tea_spatial_att = tea_spatial_att.unsqueeze(axis=1) + + fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att)) + fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att)) + fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(mask_fg)) + bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(mask_bg)) + + fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att)) + fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att)) + fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(mask_fg)) + bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(mask_bg)) + + fg_loss = F.mse_loss(fg_fea_s, fg_fea_t, reduction="sum") / len(mask_fg) + bg_loss = F.mse_loss(bg_fea_s, bg_fea_t, reduction="sum") / len(mask_bg) + return fg_loss, bg_loss + + def relation_loss(self, stu_feature, tea_feature): + context_s = self.spatial_pool(stu_feature, "student") + context_t = self.spatial_pool(tea_feature, "teacher") + out_s = stu_feature + self.stu_conv_block(context_s) + out_t = tea_feature + self.tea_conv_block(context_t) + rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s) + return rela_loss + + def mask_value(self, mask, xl, xr, yl, yr, value): + mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value) + return mask + + def forward(self, stu_feature, tea_feature, inputs): + assert stu_feature.shape[-2:] == stu_feature.shape[-2:], \ + f'The shape of Student feature {stu_feature.shape} and Teacher feature {tea_feature.shape} should be the same.' + assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys( + ), "ERROR! FGDFeatureLoss need gt_bbox and im_shape as inputs." + gt_bboxes = inputs['gt_bbox'] + ins_shape = [ + inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0]) + ] + if self.align is not None: + stu_feature = self.align(stu_feature) + if self.normalize: + stu_feature, tea_feature = self.norm(stu_feature), self.norm( + tea_feature) + + tea_spatial_att, tea_channel_att = self.spatial_channel_attention( + tea_feature, self.temp) + stu_spatial_att, stu_channel_att = self.spatial_channel_attention( + stu_feature, self.temp) + + mask_fg = paddle.zeros(tea_spatial_att.shape) + mask_bg = paddle.ones_like(tea_spatial_att) + one_tmp = paddle.ones([*tea_spatial_att.shape[1:]]) + zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]]) + wmin, wmax, hmin, hmax = [], [], [], [] + + N, _, H, W = stu_feature.shape + if gt_bboxes.shape[1] != 0: + for i in range(N): + tmp_box = paddle.ones_like(gt_bboxes[i]) + tmp_box[:, 0] = gt_bboxes[i][:, 0] / ins_shape[i][1] * W + tmp_box[:, 2] = gt_bboxes[i][:, 2] / ins_shape[i][1] * W + tmp_box[:, 1] = gt_bboxes[i][:, 1] / ins_shape[i][0] * H + tmp_box[:, 3] = gt_bboxes[i][:, 3] / ins_shape[i][0] * H + + zero = paddle.zeros_like(tmp_box[:, 0], dtype="int32") + ones = paddle.ones_like(tmp_box[:, 2], dtype="int32") + wmin.append( + paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum( + zero)) + wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32")) + hmin.append( + paddle.cast(paddle.floor(tmp_box[:, 1]), "int32").maximum( + zero)) + hmax.append(paddle.cast(paddle.ceil(tmp_box[:, 3]), "int32")) + + area_recip = 1.0 / ( + hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / ( + wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1])) + + for j in range(len(gt_bboxes[i])): + if gt_bboxes[i][j].sum() > 0: + mask_fg[i] = self.mask_value( + mask_fg[i], hmin[i][j], hmax[i][j] + 1, wmin[i][j], + 
wmax[i][j] + 1, area_recip[0][j]) + + mask_bg[i] = paddle.where(mask_fg[i] > zero_tmp, zero_tmp, + one_tmp) + + if paddle.sum(mask_bg[i]): + mask_bg[i] /= paddle.sum(mask_bg[i]) + + fg_loss, bg_loss = self.feature_loss( + stu_feature, tea_feature, mask_fg, mask_bg, tea_channel_att, + tea_spatial_att) + mask_loss = self.mask_loss(stu_channel_att, tea_channel_att, + stu_spatial_att, tea_spatial_att) + rela_loss = self.relation_loss(stu_feature, tea_feature) + loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \ + + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss + else: + rela_loss = self.relation_loss(stu_feature, tea_feature) + loss = self.lambda_fgd * rela_loss + return loss + + +@register +class PKDLoss(nn.Layer): + """ + PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient. + + Args: + loss_weight (float): Weight of loss. Defaults to 1.0. + resize_stu (bool): If True, we'll down/up sample the features of the + student model to the spatial size of those of the teacher model if + their spatial sizes are different. And vice versa. Defaults to + True. + """ + + def __init__(self, + student_channels=256, + teacher_channels=256, + normalize=True, + loss_weight=1.0, + resize_stu=True): + super(PKDLoss, self).__init__() + self.normalize = normalize + self.loss_weight = loss_weight + self.resize_stu = resize_stu + kaiming_init = parameter_init("kaiming") + if student_channels != teacher_channels: + self.align = nn.Conv2D( + student_channels, + teacher_channels, + kernel_size=1, + stride=1, + padding=0, + weight_attr=kaiming_init) + else: + self.align = None + + def norm(self, feat): + # Normalize the feature maps to have zero mean and unit variances. + assert len(feat.shape) == 4 + N, C, H, W = feat.shape + feat = feat.transpose([1, 0, 2, 3]).reshape([C, -1]) + mean = feat.mean(axis=-1, keepdim=True) + std = feat.std(axis=-1, keepdim=True) + feat = (feat - mean) / (std + 1e-6) + return feat.reshape([C, N, H, W]).transpose([1, 0, 2, 3]) + + def forward(self, stu_feature, tea_feature, inputs): + if self.align is not None: + stu_feature = self.align(stu_feature) + + loss = 0. + size_s, size_t = stu_feature.shape[2:], tea_feature.shape[2:] + if size_s[0] != size_t[0]: + if self.resize_stu: + stu_feature = F.interpolate( + stu_feature, size_t, mode='bilinear') + else: + tea_feature = F.interpolate( + tea_feature, size_s, mode='bilinear') + assert stu_feature.shape == tea_feature.shape + + if self.normalize: + norm_stu_feature = self.norm(stu_feature) + norm_tea_feature = self.norm(tea_feature) + + # First conduct feature normalization and then calculate the + # MSE loss. Methematically, it is equivalent to firstly calculate + # the Pearson Correlation Coefficient (r) between two feature + # vectors, and then use 1-r as the new feature imitation loss. 
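+            # With norm() giving each channel zero mean and unit variance over its
+            # N*H*W elements, mean((a - b) ** 2) = mean(a ** 2) + mean(b ** 2) - 2 * mean(a * b)
+            # = 2 * (1 - r), so the mean squared error divided by 2 below reduces to
+            # the 1 - r loss described above, averaged over channels.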
+ loss += F.mse_loss(norm_stu_feature, norm_tea_feature) / 2 + return loss * self.loss_weight + + +@register +class MGDSSIMLoss(nn.Layer): + def __init__(self, + student_channels=256, + teacher_channels=256, + normalize=True, + ssim=True, + loss_weight=1.0, + max_alpha=1.0, + min_alpha=0.2): + super(MGDSSIMLoss, self).__init__() + self.normalize = normalize + self.loss_weight = loss_weight + self.max_alpha = max_alpha + self.min_alpha = min_alpha + + self.mse_loss = nn.MSELoss(reduction='sum') + self.ssim_loss = SSIM(11) + + kaiming_init = parameter_init("kaiming") + if student_channels != teacher_channels: + self.align_layer = nn.Conv2D( + student_channels, + teacher_channels, + kernel_size=1, + stride=1, + padding=0, + weight_attr=kaiming_init, + bias_attr=False) + else: + self.align_layer = None + + self.generations = nn.Sequential( + nn.Conv2D( + teacher_channels, teacher_channels, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv2D( + teacher_channels, teacher_channels, kernel_size=3, padding=1)) + + def norm(self, feat): + # Normalize the feature maps to have zero mean and unit variances. + assert len(feat.shape) == 4 + N, C, H, W = feat.shape + feat = feat.transpose([1, 0, 2, 3]).reshape([C, -1]) + mean = feat.mean(axis=-1, keepdim=True) + std = feat.std(axis=-1, keepdim=True) + feat = (feat - mean) / (std + 1e-6) + return feat.reshape([C, N, H, W]).transpose([1, 0, 2, 3]) + + def forward(self, stu_feature, tea_feature, input): + N = stu_feature.shape[0] + masked_fea = self.align_layer(stu_feature) + stu_feature = self.generations(masked_fea) + + if self.normalize: + stu_feature = self.norm(stu_feature) + tea_feature = self.norm(tea_feature) + + if self.ssim is False: + dis_loss = self.mse_loss(stu_feature, tea_feature) / N + else: + ssim_loss = self.ssim_loss(stu_feature, tea_feature) + dis_loss = paddle.clip((1 - ssim_loss) / 2, 0, 1) + return dis_loss * self.loss_weight + + +class SSIM(nn.Layer): + def __init__(self, window_size=11, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = self.create_window(window_size, self.channel) + + def gaussian(self, window_size, sigma): + gauss = paddle.to_tensor([ + math.exp(-(x - window_size // 2)**2 / float(2 * sigma**2)) + for x in range(window_size) + ]) + return gauss / gauss.sum() + + def create_window(self, window_size, channel): + _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0) + window = _2D_window.expand([channel, 1, window_size, window_size]) + return window + + def _ssim(self, img1, img2, window, window_size, channel, + size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d( + img1 * img1, window, padding=window_size // 2, + groups=channel) - mu1_sq + sigma2_sq = F.conv2d( + img2 * img2, window, padding=window_size // 2, + groups=channel) - mu2_sq + sigma12 = F.conv2d( + img1 * img2, window, padding=window_size // 2, + groups=channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( + 1e-12 + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean([1, 2, 3]) + + def forward(self, img1, img2): + channel = img1.shape[1] + if 
channel == self.channel and self.window.dtype == img1.dtype: + window = self.window + else: + window = self.create_window(self.window_size, channel) + self.window = window + self.channel = channel + + return self._ssim(img1, img2, window, self.window_size, channel, + self.size_average) From df52442d89a3116ff1c644956595b78ace0a3fe0 Mon Sep 17 00:00:00 2001 From: nemonameless Date: Thu, 2 Feb 2023 03:54:17 +0000 Subject: [PATCH 2/5] refine distill codes --- 0run_x_to_l.sh | 18 - 1small_m_to_s.sh | 18 - .../distill/ppyoloe_plus_distill_l_to_m.yml | 10 +- .../distill/ppyoloe_plus_distill_m_to_s.yml | 10 +- .../distill/ppyoloe_plus_distill_x_to_l.yml | 10 +- .../gfl_r101vd_fpn_coco_distill_cwd.yml | 4 +- .../retinanet_resnet101_coco_distill_cwd.yml | 5 +- ppdet/slim/__init__.py | 11 +- ppdet/slim/distill.py | 809 ---------------- ppdet/slim/distill_loss.py | 871 ++++++++++++++++++ ppdet/slim/distill_model.py | 357 +++++++ ppdet/slim/distill_ppyoloe.py | 695 -------------- 12 files changed, 1263 insertions(+), 1555 deletions(-) delete mode 100644 0run_x_to_l.sh delete mode 100644 1small_m_to_s.sh delete mode 100644 ppdet/slim/distill.py create mode 100644 ppdet/slim/distill_loss.py create mode 100644 ppdet/slim/distill_model.py delete mode 100644 ppdet/slim/distill_ppyoloe.py diff --git a/0run_x_to_l.sh b/0run_x_to_l.sh deleted file mode 100644 index de52857a76a..00000000000 --- a/0run_x_to_l.sh +++ /dev/null @@ -1,18 +0,0 @@ -export FLAGS_allocator_strategy=auto_growth -model_type=ppyoloe/distill -job_name=ppyoloe_plus_crn_l_80e_coco -job_name_tea=ppyoloe_plus_distill_x_to_l - -config=configs/${model_type}/${job_name}.yml -slim_config=configs/${model_type}/${job_name_tea}.yml -log_dir=log_dir/${job_name} -weights=output/${job_name_tea}/model_final.pdparams - -# 1. training -#CUDA_VISIBLE_DEVICES=3 python3.7 tools/train.py -c ${config} --slim_config ${slim_config} #--eval --amp -python3.7 -m paddle.distributed.launch --log_dir=${log_dir} --gpus 0,1,2,3,4,5,6,7 tools/train.py -c ${config} --slim_config ${slim_config} --eval -# -r output/ppyoloe_plus_distill_x_to_l/14 --amp - -# 2. eval -#CUDA_VISIBLE_DEVICES=0 python3.7 tools/eval.py -c ${config} -o weights=https://paddledet.bj.bcebos.com/models/${job_name}.pdparams -#CUDA_VISIBLE_DEVICES=2 python3.7 tools/eval.py -c ${config} -o weights=${weights} diff --git a/1small_m_to_s.sh b/1small_m_to_s.sh deleted file mode 100644 index 932f8143e1e..00000000000 --- a/1small_m_to_s.sh +++ /dev/null @@ -1,18 +0,0 @@ -export FLAGS_allocator_strategy=auto_growth -model_type=ppyoloe/distill -job_name=ppyoloe_plus_crn_s_80e_coco -job_name_tea=ppyoloe_plus_distill_m_to_s - -config=configs/${model_type}/${job_name}.yml -slim_config=configs/${model_type}/${job_name_tea}.yml -log_dir=log_dir/${job_name} -weights=output/${job_name_tea}/model_final.pdparams - -# 1. training -#CUDA_VISIBLE_DEVICES=3 python3.7 tools/train.py -c ${config} --slim_config ${slim_config} #--eval --amp -python3.7 -m paddle.distributed.launch --log_dir=${log_dir} --gpus 0,1,2,3,4,5,6,7 tools/train.py -c ${config} --slim_config ${slim_config} --eval -#-r output/ppyoloe_plus_distill_m_to_s/14 # --amp - -# 2. 
eval -#CUDA_VISIBLE_DEVICES=0 python3.7 tools/eval.py -c ${config} -o weights=https://paddledet.bj.bcebos.com/models/${job_name}.pdparams -#CUDA_VISIBLE_DEVICES=2 python3.7 tools/eval.py -c ${config} -o weights=${weights} diff --git a/configs/ppyoloe/distill/ppyoloe_plus_distill_l_to_m.yml b/configs/ppyoloe/distill/ppyoloe_plus_distill_l_to_m.yml index 98be3cb2a51..74299c7dd5d 100644 --- a/configs/ppyoloe/distill/ppyoloe_plus_distill_l_to_m.yml +++ b/configs/ppyoloe/distill/ppyoloe_plus_distill_l_to_m.yml @@ -22,5 +22,11 @@ slim_method: PPYOLOEDistill distill_loss: DistillPPYOLOELoss DistillPPYOLOELoss: # L -> M - teacher_width_mult: 1.0 - student_width_mult: 0.75 + loss_weight: {'logits': 4.0, 'feat': 1.0} + logits_distill: True + logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5} + feat_distill: True + feat_distiller: 'cwd' + teacher_width_mult: 1.0 # L + student_width_mult: 0.75 # M + neck_out_channels: [768, 384, 192] # The actual channel will multiply width_mult diff --git a/configs/ppyoloe/distill/ppyoloe_plus_distill_m_to_s.yml b/configs/ppyoloe/distill/ppyoloe_plus_distill_m_to_s.yml index 3e54fc90ab0..cb80edfe050 100644 --- a/configs/ppyoloe/distill/ppyoloe_plus_distill_m_to_s.yml +++ b/configs/ppyoloe/distill/ppyoloe_plus_distill_m_to_s.yml @@ -22,5 +22,11 @@ slim_method: PPYOLOEDistill distill_loss: DistillPPYOLOELoss DistillPPYOLOELoss: # M -> S - teacher_width_mult: 0.75 - student_width_mult: 0.50 + loss_weight: {'logits': 4.0, 'feat': 1.0} + logits_distill: True + logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5} + feat_distill: True + feat_distiller: 'cwd' + teacher_width_mult: 0.75 # M + student_width_mult: 0.5 # S + neck_out_channels: [768, 384, 192] # The actual channel will multiply width_mult diff --git a/configs/ppyoloe/distill/ppyoloe_plus_distill_x_to_l.yml b/configs/ppyoloe/distill/ppyoloe_plus_distill_x_to_l.yml index 6ac9809a597..2e0c44ddf78 100644 --- a/configs/ppyoloe/distill/ppyoloe_plus_distill_x_to_l.yml +++ b/configs/ppyoloe/distill/ppyoloe_plus_distill_x_to_l.yml @@ -22,5 +22,11 @@ slim_method: PPYOLOEDistill distill_loss: DistillPPYOLOELoss DistillPPYOLOELoss: # X -> L - teacher_width_mult: 1.25 - student_width_mult: 1.0 + loss_weight: {'logits': 4.0, 'feat': 1.0} + logits_distill: True + logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5} + feat_distill: True + feat_distiller: 'cwd' + teacher_width_mult: 1.25 # X + student_width_mult: 1.0 # L + neck_out_channels: [768, 384, 192] # The actual channel will multiply width_mult diff --git a/configs/slim/distill/gfl_r101vd_fpn_coco_distill_cwd.yml b/configs/slim/distill/gfl_r101vd_fpn_coco_distill_cwd.yml index e27646cfdc1..3af5ac17f2c 100644 --- a/configs/slim/distill/gfl_r101vd_fpn_coco_distill_cwd.yml +++ b/configs/slim/distill/gfl_r101vd_fpn_coco_distill_cwd.yml @@ -6,10 +6,10 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_ slim: Distill slim_method: CWD -distill_loss: ChannelWiseDivergence +distill_loss: CWDFeatureLoss distill_loss_name: ['cls_f_4', 'cls_f_3', 'cls_f_2', 'cls_f_1', 'cls_f_0'] -ChannelWiseDivergence: +CWDFeatureLoss: student_channels: 80 teacher_channels: 80 tau: 1.0 diff --git a/configs/slim/distill/retinanet_resnet101_coco_distill_cwd.yml b/configs/slim/distill/retinanet_resnet101_coco_distill_cwd.yml index 4073b3cb674..7087b85d040 100644 --- a/configs/slim/distill/retinanet_resnet101_coco_distill_cwd.yml +++ b/configs/slim/distill/retinanet_resnet101_coco_distill_cwd.yml @@ -7,12 +7,11 @@ pretrain_weights: 
https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_c slim: Distill slim_method: CWD -distill_loss: ChannelWiseDivergence +distill_loss: CWDFeatureLoss distill_loss_name: ['cls_f_4', 'cls_f_3', 'cls_f_2', 'cls_f_1', 'cls_f_0'] -ChannelWiseDivergence: +CWDFeatureLoss: student_channels: 80 teacher_channels: 80 - name: cwdloss tau: 1.0 weight: 5.0 diff --git a/ppdet/slim/__init__.py b/ppdet/slim/__init__.py index 11d4d29840d..7d75082b2b3 100644 --- a/ppdet/slim/__init__.py +++ b/ppdet/slim/__init__.py @@ -12,17 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from . import distill_loss +from . import distill_model +from . import ofa from . import prune from . import quant -from . import distill from . import unstructured_prune +from .distill_loss import * +from .distill_model import * +from .ofa import * from .prune import * from .quant import * -from .distill import * -from .distill_ppyoloe import * from .unstructured_prune import * -from .ofa import * import yaml from ppdet.core.workspace import load_config @@ -50,6 +52,7 @@ def build_slim_model(cfg, slim_cfg, mode='train'): 'slim_method'] == "PPYOLOEDistill": model = PPYOLOEDistillModel(cfg, slim_cfg) else: + # common distillation model model = DistillModel(cfg, slim_cfg) cfg['model'] = model cfg['slim_type'] = cfg.slim diff --git a/ppdet/slim/distill.py b/ppdet/slim/distill.py deleted file mode 100644 index e3e0e764313..00000000000 --- a/ppdet/slim/distill.py +++ /dev/null @@ -1,809 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import paddle -import paddle.nn as nn -import paddle.nn.functional as F -from paddle import ParamAttr - -from ppdet.core.workspace import register, create, load_config -from ppdet.modeling import ops -from ppdet.utils.checkpoint import load_pretrain_weight -from ppdet.utils.logger import setup_logger - -logger = setup_logger(__name__) - - -class DistillModel(nn.Layer): - def __init__(self, cfg, slim_cfg): - super(DistillModel, self).__init__() - - self.student_model = create(cfg.architecture) - logger.debug('Load student model pretrain_weights:{}'.format( - cfg.pretrain_weights)) - load_pretrain_weight(self.student_model, cfg.pretrain_weights) - - slim_cfg = load_config(slim_cfg) - - self.teacher_model = create(slim_cfg.architecture) - self.distill_loss = create(slim_cfg.distill_loss) - logger.debug('Load teacher model pretrain_weights:{}'.format( - slim_cfg.pretrain_weights)) - load_pretrain_weight(self.teacher_model, slim_cfg.pretrain_weights) - - for param in self.teacher_model.parameters(): - param.trainable = False - - def parameters(self): - return self.student_model.parameters() - - def forward(self, inputs): - if self.training: - teacher_loss = self.teacher_model(inputs) - student_loss = self.student_model(inputs) - loss = self.distill_loss(self.teacher_model, self.student_model) - student_loss['distill_loss'] = loss - student_loss['teacher_loss'] = teacher_loss['loss'] - student_loss['loss'] += student_loss['distill_loss'] - return student_loss - else: - return self.student_model(inputs) - - -class FGDDistillModel(nn.Layer): - """ - Build FGD distill model. - Args: - cfg: The student config. - slim_cfg: The teacher and distill config. - """ - - def __init__(self, cfg, slim_cfg): - super(FGDDistillModel, self).__init__() - - self.is_inherit = True - # build student model before load slim config - self.student_model = create(cfg.architecture) - self.arch = cfg.architecture - stu_pretrain = cfg['pretrain_weights'] - slim_cfg = load_config(slim_cfg) - self.teacher_cfg = slim_cfg - self.loss_cfg = slim_cfg - tea_pretrain = cfg['pretrain_weights'] - - self.teacher_model = create(self.teacher_cfg.architecture) - self.teacher_model.eval() - - for param in self.teacher_model.parameters(): - param.trainable = False - - if 'pretrain_weights' in cfg and stu_pretrain: - if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights: - load_pretrain_weight(self.student_model, - self.teacher_cfg.pretrain_weights) - logger.debug( - "Inheriting! 
loading teacher weights to student model!") - - load_pretrain_weight(self.student_model, stu_pretrain) - - if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights: - load_pretrain_weight(self.teacher_model, - self.teacher_cfg.pretrain_weights) - - self.fgd_loss_dic = self.build_loss( - self.loss_cfg.distill_loss, - name_list=self.loss_cfg['distill_loss_name']) - - def build_loss(self, - cfg, - name_list=[ - 'neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1', - 'neck_f_0' - ]): - loss_func = dict() - for idx, k in enumerate(name_list): - loss_func[k] = create(cfg) - return loss_func - - def forward(self, inputs): - if self.training: - s_body_feats = self.student_model.backbone(inputs) - s_neck_feats = self.student_model.neck(s_body_feats) - - with paddle.no_grad(): - t_body_feats = self.teacher_model.backbone(inputs) - t_neck_feats = self.teacher_model.neck(t_body_feats) - - loss_dict = {} - for idx, k in enumerate(self.fgd_loss_dic): - loss_dict[k] = self.fgd_loss_dic[k](s_neck_feats[idx], - t_neck_feats[idx], inputs) - if self.arch == "RetinaNet": - loss = self.student_model.head(s_neck_feats, inputs) - elif self.arch == "PicoDet": - head_outs = self.student_model.head( - s_neck_feats, self.student_model.export_post_process) - loss_gfl = self.student_model.head.get_loss(head_outs, inputs) - total_loss = paddle.add_n(list(loss_gfl.values())) - loss = {} - loss.update(loss_gfl) - loss.update({'loss': total_loss}) - else: - raise ValueError(f"Unsupported model {self.arch}") - for k in loss_dict: - loss['loss'] += loss_dict[k] - loss[k] = loss_dict[k] - return loss - else: - body_feats = self.student_model.backbone(inputs) - neck_feats = self.student_model.neck(body_feats) - head_outs = self.student_model.head(neck_feats) - if self.arch == "RetinaNet": - bbox, bbox_num = self.student_model.head.post_process( - head_outs, inputs['im_shape'], inputs['scale_factor']) - return {'bbox': bbox, 'bbox_num': bbox_num} - elif self.arch == "PicoDet": - head_outs = self.student_model.head( - neck_feats, self.student_model.export_post_process) - scale_factor = inputs['scale_factor'] - bboxes, bbox_num = self.student_model.head.post_process( - head_outs, - scale_factor, - export_nms=self.student_model.export_nms) - return {'bbox': bboxes, 'bbox_num': bbox_num} - else: - raise ValueError(f"Unsupported model {self.arch}") - - -class CWDDistillModel(nn.Layer): - """ - Build CWD distill model. - Args: - cfg: The student config. - slim_cfg: The teacher and distill config. - """ - - def __init__(self, cfg, slim_cfg): - super(CWDDistillModel, self).__init__() - - self.is_inherit = False - # build student model before load slim config - self.student_model = create(cfg.architecture) - self.arch = cfg.architecture - if self.arch not in ['GFL', 'RetinaNet']: - raise ValueError( - f"The arch can only be one of ['GFL', 'RetinaNet'], but received {self.arch}" - ) - - stu_pretrain = cfg['pretrain_weights'] - slim_cfg = load_config(slim_cfg) - self.teacher_cfg = slim_cfg - self.loss_cfg = slim_cfg - tea_pretrain = cfg['pretrain_weights'] - - self.teacher_model = create(self.teacher_cfg.architecture) - self.teacher_model.eval() - - for param in self.teacher_model.parameters(): - param.trainable = False - if 'pretrain_weights' in cfg and stu_pretrain: - if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights: - load_pretrain_weight(self.student_model, - self.teacher_cfg.pretrain_weights) - logger.debug( - "Inheriting! 
loading teacher weights to student model!") - - load_pretrain_weight(self.student_model, stu_pretrain) - - if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights: - load_pretrain_weight(self.teacher_model, - self.teacher_cfg.pretrain_weights) - - self.loss_dic = self.build_loss( - self.loss_cfg.distill_loss, - name_list=self.loss_cfg['distill_loss_name']) - - def build_loss(self, - cfg, - name_list=[ - 'neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1', - 'neck_f_0' - ]): - loss_func = dict() - for idx, k in enumerate(name_list): - loss_func[k] = create(cfg) - return loss_func - - def get_loss_retinanet(self, stu_fea_list, tea_fea_list, inputs): - loss = self.student_model.head(stu_fea_list, inputs) - distill_loss = {} - # cwd kd loss - for idx, k in enumerate(self.loss_dic): - distill_loss[k] = self.loss_dic[k](stu_fea_list[idx], - tea_fea_list[idx]) - - loss['loss'] += distill_loss[k] - loss[k] = distill_loss[k] - return loss - - def get_loss_gfl(self, stu_fea_list, tea_fea_list, inputs): - loss = {} - head_outs = self.student_model.head(stu_fea_list) - loss_gfl = self.student_model.head.get_loss(head_outs, inputs) - loss.update(loss_gfl) - total_loss = paddle.add_n(list(loss.values())) - loss.update({'loss': total_loss}) - # cwd kd loss - feat_loss = {} - loss_dict = {} - - s_cls_feat, t_cls_feat = [], [] - for s_neck_f, t_neck_f in zip(stu_fea_list, tea_fea_list): - conv_cls_feat, _ = self.student_model.head.conv_feat(s_neck_f) - cls_score = self.student_model.head.gfl_head_cls(conv_cls_feat) - t_conv_cls_feat, _ = self.teacher_model.head.conv_feat(t_neck_f) - t_cls_score = self.teacher_model.head.gfl_head_cls(t_conv_cls_feat) - s_cls_feat.append(cls_score) - t_cls_feat.append(t_cls_score) - - for idx, k in enumerate(self.loss_dic): - loss_dict[k] = self.loss_dic[k](s_cls_feat[idx], t_cls_feat[idx]) - feat_loss[f"neck_f_{idx}"] = self.loss_dic[k](stu_fea_list[idx], - tea_fea_list[idx]) - - for k in feat_loss: - loss['loss'] += feat_loss[k] - loss[k] = feat_loss[k] - - for k in loss_dict: - loss['loss'] += loss_dict[k] - loss[k] = loss_dict[k] - return loss - - def forward(self, inputs): - if self.training: - s_body_feats = self.student_model.backbone(inputs) - s_neck_feats = self.student_model.neck(s_body_feats) - - with paddle.no_grad(): - t_body_feats = self.teacher_model.backbone(inputs) - t_neck_feats = self.teacher_model.neck(t_body_feats) - - if self.arch == "RetinaNet": - loss = self.get_loss_retinanet(s_neck_feats, t_neck_feats, - inputs) - elif self.arch == "GFL": - loss = self.get_loss_gfl(s_neck_feats, t_neck_feats, inputs) - else: - raise ValueError(f"unsupported arch {self.arch}") - return loss - else: - body_feats = self.student_model.backbone(inputs) - neck_feats = self.student_model.neck(body_feats) - head_outs = self.student_model.head(neck_feats) - if self.arch == "RetinaNet": - bbox, bbox_num = self.student_model.head.post_process( - head_outs, inputs['im_shape'], inputs['scale_factor']) - return {'bbox': bbox, 'bbox_num': bbox_num} - elif self.arch == "GFL": - bbox_pred, bbox_num = head_outs - output = {'bbox': bbox_pred, 'bbox_num': bbox_num} - return output - else: - raise ValueError(f"unsupported arch {self.arch}") - - -@register -class ChannelWiseDivergence(nn.Layer): - def __init__(self, student_channels, teacher_channels, tau=1.0, weight=1.0): - super(ChannelWiseDivergence, self).__init__() - self.tau = tau - self.loss_weight = weight - - if student_channels != teacher_channels: - self.align = nn.Conv2D( - student_channels, - 
teacher_channels, - kernel_size=1, - stride=1, - padding=0) - else: - self.align = None - - def distill_softmax(self, x, t): - _, _, w, h = paddle.shape(x) - x = paddle.reshape(x, [-1, w * h]) - x /= t - return F.softmax(x, axis=1) - - def forward(self, preds_s, preds_t): - assert preds_s.shape[-2:] == preds_t.shape[ - -2:], 'the output dim of teacher and student differ' - N, C, W, H = preds_s.shape - eps = 1e-5 - if self.align is not None: - preds_s = self.align(preds_s) - - softmax_pred_s = self.distill_softmax(preds_s, self.tau) - softmax_pred_t = self.distill_softmax(preds_t, self.tau) - - loss = paddle.sum(-softmax_pred_t * paddle.log(eps + softmax_pred_s) + - softmax_pred_t * paddle.log(eps + softmax_pred_t)) - return self.loss_weight * loss / (C * N) - - -@register -class DistillYOLOv3Loss(nn.Layer): - def __init__(self, weight=1000): - super(DistillYOLOv3Loss, self).__init__() - self.weight = weight - - def obj_weighted_reg(self, sx, sy, sw, sh, tx, ty, tw, th, tobj): - loss_x = ops.sigmoid_cross_entropy_with_logits(sx, F.sigmoid(tx)) - loss_y = ops.sigmoid_cross_entropy_with_logits(sy, F.sigmoid(ty)) - loss_w = paddle.abs(sw - tw) - loss_h = paddle.abs(sh - th) - loss = paddle.add_n([loss_x, loss_y, loss_w, loss_h]) - weighted_loss = paddle.mean(loss * F.sigmoid(tobj)) - return weighted_loss - - def obj_weighted_cls(self, scls, tcls, tobj): - loss = ops.sigmoid_cross_entropy_with_logits(scls, F.sigmoid(tcls)) - weighted_loss = paddle.mean(paddle.multiply(loss, F.sigmoid(tobj))) - return weighted_loss - - def obj_loss(self, sobj, tobj): - obj_mask = paddle.cast(tobj > 0., dtype="float32") - obj_mask.stop_gradient = True - loss = paddle.mean( - ops.sigmoid_cross_entropy_with_logits(sobj, obj_mask)) - return loss - - def forward(self, teacher_model, student_model): - teacher_distill_pairs = teacher_model.yolo_head.loss.distill_pairs - student_distill_pairs = student_model.yolo_head.loss.distill_pairs - distill_reg_loss, distill_cls_loss, distill_obj_loss = [], [], [] - for s_pair, t_pair in zip(student_distill_pairs, teacher_distill_pairs): - distill_reg_loss.append( - self.obj_weighted_reg(s_pair[0], s_pair[1], s_pair[2], s_pair[ - 3], t_pair[0], t_pair[1], t_pair[2], t_pair[3], t_pair[4])) - distill_cls_loss.append( - self.obj_weighted_cls(s_pair[5], t_pair[5], t_pair[4])) - distill_obj_loss.append(self.obj_loss(s_pair[4], t_pair[4])) - distill_reg_loss = paddle.add_n(distill_reg_loss) - distill_cls_loss = paddle.add_n(distill_cls_loss) - distill_obj_loss = paddle.add_n(distill_obj_loss) - loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss - ) * self.weight - return loss - - -def parameter_init(mode="kaiming", value=0.): - if mode == "kaiming": - weight_attr = paddle.nn.initializer.KaimingUniform() - elif mode == "constant": - weight_attr = paddle.nn.initializer.Constant(value=value) - else: - weight_attr = paddle.nn.initializer.KaimingUniform() - - weight_init = ParamAttr(initializer=weight_attr) - return weight_init - - -@register -class FGDFeatureLoss(nn.Layer): - """ - The code is reference from https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py - Paddle version of `Focal and Global Knowledge Distillation for Detectors` - - Args: - student_channels(int): The number of channels in the student's FPN feature map. Default to 256. - teacher_channels(int): The number of channels in the teacher's FPN feature map. Default to 256. - temp (float, optional): The temperature coefficient. Defaults to 0.5. 
- alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001 - beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005 - gamma_fgd (float, optional): The weight of mask_loss. Defaults to 0.001 - lambda_fgd (float, optional): The weight of relation_loss. Defaults to 0.000005 - """ - - def __init__(self, - student_channels=256, - teacher_channels=256, - temp=0.5, - alpha_fgd=0.001, - beta_fgd=0.0005, - gamma_fgd=0.001, - lambda_fgd=0.000005): - super(FGDFeatureLoss, self).__init__() - self.temp = temp - self.alpha_fgd = alpha_fgd - self.beta_fgd = beta_fgd - self.gamma_fgd = gamma_fgd - self.lambda_fgd = lambda_fgd - - kaiming_init = parameter_init("kaiming") - zeros_init = parameter_init("constant", 0.0) - - if student_channels != teacher_channels: - self.align = nn.Conv2D( - student_channels, - teacher_channels, - kernel_size=1, - stride=1, - padding=0, - weight_attr=kaiming_init) - student_channels = teacher_channels - else: - self.align = None - - self.conv_mask_s = nn.Conv2D( - student_channels, 1, kernel_size=1, weight_attr=kaiming_init) - self.conv_mask_t = nn.Conv2D( - teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init) - - self.stu_conv_block = nn.Sequential( - nn.Conv2D( - student_channels, - student_channels // 2, - kernel_size=1, - weight_attr=zeros_init), - nn.LayerNorm([student_channels // 2, 1, 1]), - nn.ReLU(), - nn.Conv2D( - student_channels // 2, - student_channels, - kernel_size=1, - weight_attr=zeros_init)) - self.tea_conv_block = nn.Sequential( - nn.Conv2D( - teacher_channels, - teacher_channels // 2, - kernel_size=1, - weight_attr=zeros_init), - nn.LayerNorm([teacher_channels // 2, 1, 1]), - nn.ReLU(), - nn.Conv2D( - teacher_channels // 2, - teacher_channels, - kernel_size=1, - weight_attr=zeros_init)) - - def spatial_channel_attention(self, x, t=0.5): - shape = paddle.shape(x) - N, C, H, W = shape - - _f = paddle.abs(x) - spatial_map = paddle.reshape( - paddle.mean( - _f, axis=1, keepdim=True) / t, [N, -1]) - spatial_map = F.softmax(spatial_map, axis=1, dtype="float32") * H * W - spatial_att = paddle.reshape(spatial_map, [N, H, W]) - - channel_map = paddle.mean( - paddle.mean( - _f, axis=2, keepdim=False), axis=2, keepdim=False) - channel_att = F.softmax(channel_map / t, axis=1, dtype="float32") * C - return [spatial_att, channel_att] - - def spatial_pool(self, x, mode="teacher"): - batch, channel, width, height = x.shape - x_copy = x - x_copy = paddle.reshape(x_copy, [batch, channel, height * width]) - x_copy = x_copy.unsqueeze(1) - if mode.lower() == "student": - context_mask = self.conv_mask_s(x) - else: - context_mask = self.conv_mask_t(x) - - context_mask = paddle.reshape(context_mask, [batch, 1, height * width]) - context_mask = F.softmax(context_mask, axis=2) - context_mask = context_mask.unsqueeze(-1) - context = paddle.matmul(x_copy, context_mask) - context = paddle.reshape(context, [batch, channel, 1, 1]) - - return context - - def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att, - tea_spatial_att): - def _func(a, b): - return paddle.sum(paddle.abs(a - b)) / len(a) - - mask_loss = _func(stu_channel_att, tea_channel_att) + _func( - stu_spatial_att, tea_spatial_att) - - return mask_loss - - def feature_loss(self, stu_feature, tea_feature, Mask_fg, Mask_bg, - tea_channel_att, tea_spatial_att): - - Mask_fg = Mask_fg.unsqueeze(axis=1) - Mask_bg = Mask_bg.unsqueeze(axis=1) - - tea_channel_att = tea_channel_att.unsqueeze(axis=-1) - tea_channel_att = tea_channel_att.unsqueeze(axis=-1) - - tea_spatial_att = 
tea_spatial_att.unsqueeze(axis=1) - - fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att)) - fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att)) - fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_fg)) - bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_bg)) - - fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att)) - fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att)) - fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_fg)) - bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_bg)) - - fg_loss = F.mse_loss(fg_fea_s, fg_fea_t, reduction="sum") / len(Mask_fg) - bg_loss = F.mse_loss(bg_fea_s, bg_fea_t, reduction="sum") / len(Mask_bg) - - return fg_loss, bg_loss - - def relation_loss(self, stu_feature, tea_feature): - context_s = self.spatial_pool(stu_feature, "student") - context_t = self.spatial_pool(tea_feature, "teacher") - - out_s = stu_feature + self.stu_conv_block(context_s) - out_t = tea_feature + self.tea_conv_block(context_t) - - rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s) - - return rela_loss - - def mask_value(self, mask, xl, xr, yl, yr, value): - mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value) - return mask - - def forward(self, stu_feature, tea_feature, inputs): - """Forward function. - Args: - stu_feature(Tensor): Bs*C*H*W, student's feature map - tea_feature(Tensor): Bs*C*H*W, teacher's feature map - inputs: The inputs with gt bbox and input shape info. - """ - assert stu_feature.shape[-2:] == stu_feature.shape[-2:], \ - f'The shape of Student feature {stu_feature.shape} and Teacher feature {tea_feature.shape} should be the same.' - assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys( - ), "ERROR! FGDFeatureLoss need gt_bbox and im_shape as inputs." - gt_bboxes = inputs['gt_bbox'] - ins_shape = [ - inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0]) - ] - - index_gt = [] - for i in range(len(gt_bboxes)): - if gt_bboxes[i].size > 2: - index_gt.append(i) - # only distill feature with labeled GTbox - if len(index_gt) != len(gt_bboxes): - index_gt_t = paddle.to_tensor(index_gt) - preds_S = paddle.index_select(preds_S, index_gt_t) - preds_T = paddle.index_select(preds_T, index_gt_t) - - ins_shape = [ins_shape[c] for c in index_gt] - gt_bboxes = [gt_bboxes[c] for c in index_gt] - assert len(gt_bboxes) == preds_T.shape[ - 0], f"The number of selected GT box [{len(gt_bboxes)}] should be same with first dim of input tensor [{preds_T.shape[0]}]." 
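For reference, the focal-and-global objective computed by this class (and by the rewritten FGDFeatureLoss added later in this patch) is the weighted sum below. The weights are the class defaults documented above; DistillPPYOLOELoss later instantiates the loss with much smaller values. This restates the forward() code for orientation rather than changing it.

    loss = alpha_fgd  * fg_loss     (foreground feature imitation, default 0.001)
         + beta_fgd   * bg_loss     (background feature imitation, default 0.0005)
         + gamma_fgd  * mask_loss   (attention-mask imitation,     default 0.001)
         + lambda_fgd * rela_loss   (global relation term,         default 0.000005)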
- - if self.align is not None: - stu_feature = self.align(stu_feature) - - N, C, H, W = stu_feature.shape - - tea_spatial_att, tea_channel_att = self.spatial_channel_attention( - tea_feature, self.temp) - stu_spatial_att, stu_channel_att = self.spatial_channel_attention( - stu_feature, self.temp) - - Mask_fg = paddle.zeros(tea_spatial_att.shape) - Mask_bg = paddle.ones_like(tea_spatial_att) - one_tmp = paddle.ones([*tea_spatial_att.shape[1:]]) - zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]]) - Mask_fg.stop_gradient = True - Mask_bg.stop_gradient = True - one_tmp.stop_gradient = True - zero_tmp.stop_gradient = True - - wmin, wmax, hmin, hmax, area = [], [], [], [], [] - - for i in range(N): - tmp_box = paddle.ones_like(gt_bboxes[i]) - tmp_box.stop_gradient = True - tmp_box[:, 0] = gt_bboxes[i][:, 0] / ins_shape[i][1] * W - tmp_box[:, 2] = gt_bboxes[i][:, 2] / ins_shape[i][1] * W - tmp_box[:, 1] = gt_bboxes[i][:, 1] / ins_shape[i][0] * H - tmp_box[:, 3] = gt_bboxes[i][:, 3] / ins_shape[i][0] * H - - zero = paddle.zeros_like(tmp_box[:, 0], dtype="int32") - ones = paddle.ones_like(tmp_box[:, 2], dtype="int32") - zero.stop_gradient = True - ones.stop_gradient = True - - wmin.append( - paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum(zero)) - wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32")) - hmin.append( - paddle.cast(paddle.floor(tmp_box[:, 1]), "int32").maximum(zero)) - hmax.append(paddle.cast(paddle.ceil(tmp_box[:, 3]), "int32")) - - area_recip = 1.0 / ( - hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / ( - wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1])) - - for j in range(len(gt_bboxes[i])): - Mask_fg[i] = self.mask_value(Mask_fg[i], hmin[i][j], - hmax[i][j] + 1, wmin[i][j], - wmax[i][j] + 1, area_recip[0][j]) - - Mask_bg[i] = paddle.where(Mask_fg[i] > zero_tmp, zero_tmp, one_tmp) - - if paddle.sum(Mask_bg[i]): - Mask_bg[i] /= paddle.sum(Mask_bg[i]) - - fg_loss, bg_loss = self.feature_loss(stu_feature, tea_feature, Mask_fg, - Mask_bg, tea_channel_att, - tea_spatial_att) - mask_loss = self.mask_loss(stu_channel_att, tea_channel_att, - stu_spatial_att, tea_spatial_att) - rela_loss = self.relation_loss(stu_feature, tea_feature) - - loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \ - + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss - - return loss - - -class LDDistillModel(nn.Layer): - def __init__(self, cfg, slim_cfg): - super(LDDistillModel, self).__init__() - self.student_model = create(cfg.architecture) - logger.debug('Load student model pretrain_weights:{}'.format( - cfg.pretrain_weights)) - load_pretrain_weight(self.student_model, cfg.pretrain_weights) - - slim_cfg = load_config(slim_cfg) #rewrite student cfg - self.teacher_model = create(slim_cfg.architecture) - logger.debug('Load teacher model pretrain_weights:{}'.format( - slim_cfg.pretrain_weights)) - load_pretrain_weight(self.teacher_model, slim_cfg.pretrain_weights) - - for param in self.teacher_model.parameters(): - param.trainable = False - - def parameters(self): - return self.student_model.parameters() - - def forward(self, inputs): - if self.training: - - with paddle.no_grad(): - t_body_feats = self.teacher_model.backbone(inputs) - t_neck_feats = self.teacher_model.neck(t_body_feats) - t_head_outs = self.teacher_model.head(t_neck_feats) - - #student_loss = self.student_model(inputs) - s_body_feats = self.student_model.backbone(inputs) - s_neck_feats = self.student_model.neck(s_body_feats) - s_head_outs = self.student_model.head(s_neck_feats) - - 
soft_label_list = t_head_outs[0] - soft_targets_list = t_head_outs[1] - student_loss = self.student_model.head.get_loss( - s_head_outs, inputs, soft_label_list, soft_targets_list) - total_loss = paddle.add_n(list(student_loss.values())) - student_loss['loss'] = total_loss - return student_loss - else: - return self.student_model(inputs) - - -@register -class KnowledgeDistillationKLDivLoss(nn.Layer): - """Loss function for knowledge distilling using KL divergence. - - Args: - reduction (str): Options are `'none'`, `'mean'` and `'sum'`. - loss_weight (float): Loss weight of current loss. - T (int): Temperature for distillation. - """ - - def __init__(self, reduction='mean', loss_weight=1.0, T=10): - super(KnowledgeDistillationKLDivLoss, self).__init__() - assert reduction in ('none', 'mean', 'sum') - assert T >= 1 - self.reduction = reduction - self.loss_weight = loss_weight - self.T = T - - def knowledge_distillation_kl_div_loss(self, - pred, - soft_label, - T, - detach_target=True): - r"""Loss function for knowledge distilling using KL divergence. - - Args: - pred (Tensor): Predicted logits with shape (N, n + 1). - soft_label (Tensor): Target logits with shape (N, N + 1). - T (int): Temperature for distillation. - detach_target (bool): Remove soft_label from automatic differentiation - - Returns: - torch.Tensor: Loss tensor with shape (N,). - """ - - assert pred.shape == soft_label.shape - target = F.softmax(soft_label / T, axis=1) - if detach_target: - target = target.detach() - - kd_loss = F.kl_div( - F.log_softmax( - pred / T, axis=1), target, reduction='none').mean(1) * (T * T) - - return kd_loss - - def forward(self, - pred, - soft_label, - weight=None, - avg_factor=None, - reduction_override=None): - """Forward function. - - Args: - pred (Tensor): Predicted logits with shape (N, n + 1). - soft_label (Tensor): Target logits with shape (N, N + 1). - weight (Tensor, optional): The weight of loss for each - prediction. Defaults to None. - avg_factor (int, optional): Average factor that is used to average - the loss. Defaults to None. - reduction_override (str, optional): The reduction method used to - override the original reduction method of the loss. - Defaults to None. - """ - assert reduction_override in (None, 'none', 'mean', 'sum') - - reduction = (reduction_override - if reduction_override else self.reduction) - - loss_kd_out = self.knowledge_distillation_kl_div_loss( - pred, soft_label, T=self.T) - - if weight is not None: - loss_kd_out = weight * loss_kd_out - - if avg_factor is None: - if reduction == 'none': - loss = loss_kd_out - elif reduction == 'mean': - loss = loss_kd_out.mean() - elif reduction == 'sum': - loss = loss_kd_out.sum() - else: - # if reduction is mean, then average the loss by avg_factor - if reduction == 'mean': - loss = loss_kd_out.sum() / avg_factor - # if reduction is 'none', then do nothing, otherwise raise an error - elif reduction != 'none': - raise ValueError( - 'avg_factor can not be used with reduction="sum"') - - loss_kd = self.loss_weight * loss - - return loss_kd diff --git a/ppdet/slim/distill_loss.py b/ppdet/slim/distill_loss.py new file mode 100644 index 00000000000..f91add8d72f --- /dev/null +++ b/ppdet/slim/distill_loss.py @@ -0,0 +1,871 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr + +from ppdet.core.workspace import register, create +from ppdet.modeling import ops +from ppdet.modeling.losses.iou_loss import GIoULoss +from ppdet.utils.logger import setup_logger +logger = setup_logger(__name__) + +__all__ = [ + 'DistillYOLOv3Loss', + 'KnowledgeDistillationKLDivLoss', + 'DistillPPYOLOELoss', + 'FGDFeatureLoss', + 'CWDFeatureLoss', + 'PKDFeatureLoss', + 'MGDFeatureLoss', +] + + +def parameter_init(mode="kaiming", value=0.): + if mode == "kaiming": + weight_attr = paddle.nn.initializer.KaimingUniform() + elif mode == "constant": + weight_attr = paddle.nn.initializer.Constant(value=value) + else: + weight_attr = paddle.nn.initializer.KaimingUniform() + + weight_init = ParamAttr(initializer=weight_attr) + return weight_init + + +def feature_norm(feat): + # Normalize the feature maps to have zero mean and unit variances. + assert len(feat.shape) == 4 + N, C, H, W = feat.shape + feat = feat.transpose([1, 0, 2, 3]).reshape([C, -1]) + mean = feat.mean(axis=-1, keepdim=True) + std = feat.std(axis=-1, keepdim=True) + feat = (feat - mean) / (std + 1e-6) + return feat.reshape([C, N, H, W]).transpose([1, 0, 2, 3]) + + +@register +class DistillYOLOv3Loss(nn.Layer): + def __init__(self, weight=1000): + super(DistillYOLOv3Loss, self).__init__() + self.loss_weight = weight + + def obj_weighted_reg(self, sx, sy, sw, sh, tx, ty, tw, th, tobj): + loss_x = ops.sigmoid_cross_entropy_with_logits(sx, F.sigmoid(tx)) + loss_y = ops.sigmoid_cross_entropy_with_logits(sy, F.sigmoid(ty)) + loss_w = paddle.abs(sw - tw) + loss_h = paddle.abs(sh - th) + loss = paddle.add_n([loss_x, loss_y, loss_w, loss_h]) + weighted_loss = paddle.mean(loss * F.sigmoid(tobj)) + return weighted_loss + + def obj_weighted_cls(self, scls, tcls, tobj): + loss = ops.sigmoid_cross_entropy_with_logits(scls, F.sigmoid(tcls)) + weighted_loss = paddle.mean(paddle.multiply(loss, F.sigmoid(tobj))) + return weighted_loss + + def obj_loss(self, sobj, tobj): + obj_mask = paddle.cast(tobj > 0., dtype="float32") + obj_mask.stop_gradient = True + loss = paddle.mean( + ops.sigmoid_cross_entropy_with_logits(sobj, obj_mask)) + return loss + + def forward(self, teacher_model, student_model): + teacher_distill_pairs = teacher_model.yolo_head.loss.distill_pairs + student_distill_pairs = student_model.yolo_head.loss.distill_pairs + distill_reg_loss, distill_cls_loss, distill_obj_loss = [], [], [] + for s_pair, t_pair in zip(student_distill_pairs, teacher_distill_pairs): + distill_reg_loss.append( + self.obj_weighted_reg(s_pair[0], s_pair[1], s_pair[2], s_pair[ + 3], t_pair[0], t_pair[1], t_pair[2], t_pair[3], t_pair[4])) + distill_cls_loss.append( + self.obj_weighted_cls(s_pair[5], t_pair[5], t_pair[4])) + distill_obj_loss.append(self.obj_loss(s_pair[4], t_pair[4])) + distill_reg_loss = paddle.add_n(distill_reg_loss) + distill_cls_loss = paddle.add_n(distill_cls_loss) + distill_obj_loss 
= paddle.add_n(distill_obj_loss) + loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss + ) * self.loss_weight + return loss + + +@register +class KnowledgeDistillationKLDivLoss(nn.Layer): + """Loss function for knowledge distilling using KL divergence. + + Args: + reduction (str): Options are `'none'`, `'mean'` and `'sum'`. + loss_weight (float): Loss weight of current loss. + T (int): Temperature for distillation. + """ + + def __init__(self, reduction='mean', loss_weight=1.0, T=10): + super(KnowledgeDistillationKLDivLoss, self).__init__() + assert reduction in ('none', 'mean', 'sum') + assert T >= 1 + self.reduction = reduction + self.loss_weight = loss_weight + self.T = T + + def knowledge_distillation_kl_div_loss(self, + pred, + soft_label, + T, + detach_target=True): + r"""Loss function for knowledge distilling using KL divergence. + + Args: + pred (Tensor): Predicted logits with shape (N, n + 1). + soft_label (Tensor): Target logits with shape (N, N + 1). + T (int): Temperature for distillation. + detach_target (bool): Remove soft_label from automatic differentiation + + Returns: + torch.Tensor: Loss tensor with shape (N,). + """ + + assert pred.shape == soft_label.shape + target = F.softmax(soft_label / T, axis=1) + if detach_target: + target = target.detach() + + kd_loss = F.kl_div( + F.log_softmax( + pred / T, axis=1), target, reduction='none').mean(1) * (T * T) + + return kd_loss + + def forward(self, + pred, + soft_label, + weight=None, + avg_factor=None, + reduction_override=None): + """Forward function. + + Args: + pred (Tensor): Predicted logits with shape (N, n + 1). + soft_label (Tensor): Target logits with shape (N, N + 1). + weight (Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. 
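+
+        Example (illustrative usage only; the shapes are hypothetical and the
+            call is not part of the detection pipeline):
+
+            import paddle
+            kd_loss = KnowledgeDistillationKLDivLoss(reduction='mean', T=10)
+            pred = paddle.randn([8, 81])   # student logits
+            soft = paddle.randn([8, 81])   # teacher logits
+            loss = kd_loss(pred, soft)     # scalar Tensor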
+ """ + assert reduction_override in (None, 'none', 'mean', 'sum') + + reduction = (reduction_override + if reduction_override else self.reduction) + + loss_kd_out = self.knowledge_distillation_kl_div_loss( + pred, soft_label, T=self.T) + + if weight is not None: + loss_kd_out = weight * loss_kd_out + + if avg_factor is None: + if reduction == 'none': + loss = loss_kd_out + elif reduction == 'mean': + loss = loss_kd_out.mean() + elif reduction == 'sum': + loss = loss_kd_out.sum() + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + loss = loss_kd_out.sum() / avg_factor + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError( + 'avg_factor can not be used with reduction="sum"') + + loss_kd = self.loss_weight * loss + return loss_kd + + +@register +class DistillPPYOLOELoss(nn.Layer): + def __init__( + self, + loss_weight={'logits': 4.0, + 'feat': 1.0}, + logits_distill=True, + logits_loss_weight={'class': 1.0, + 'iou': 2.5, + 'dfl': 0.5}, + feat_distill=True, + feat_distiller='cwd', + teacher_width_mult=1.0, # L + student_width_mult=0.75, # M + neck_out_channels=[768, 384, 192]): + super(DistillPPYOLOELoss, self).__init__() + self.loss_weight_logits = loss_weight['logits'] + self.loss_weight_feat = loss_weight['feat'] + self.logits_distill = logits_distill + self.feat_distill = feat_distill + + if logits_distill and self.loss_weight_logits > 0: + self.bbox_loss_weight = logits_loss_weight['iou'] + self.dfl_loss_weight = logits_loss_weight['dfl'] + self.qfl_loss_weight = logits_loss_weight['class'] + self.loss_bbox = GIoULoss() + + if feat_distill and self.loss_weight_feat > 0: + assert feat_distiller in ['cwd', 'fgd', 'pkd', 'mgd'] + self.distill_feat_loss_modules = [] + self.t_channel_list = [ + int(c * teacher_width_mult) for c in neck_out_channels + ] + self.s_channel_list = [ + int(c * student_width_mult) for c in neck_out_channels + ] + for i in range(len(neck_out_channels)): + if feat_distiller == 'cwd': + feat_loss_module = CWDFeatureLoss( + student_channels=self.s_channel_list[i], + teacher_channels=self.t_channel_list[i], + normalize=True) + elif feat_distiller == 'fgd': + feat_loss_module = FGDFeatureLoss( + student_channels=self.s_channel_list[i], + teacher_channels=self.t_channel_list[i], + normalize=True, + alpha_fgd=0.00001, + beta_fgd=0.000005, + gamma_fgd=0.00001, + lambda_fgd=0.00000005) + elif feat_distiller == 'pkd': + feat_loss_module = PKDFeatureLoss( + student_channels=self.s_channel_list[i], + teacher_channels=self.t_channel_list[i], + normalize=True, + resize_stu=False) + elif feat_distiller == 'mgd': + feat_loss_module = MGDFeatureLoss( + student_channels=self.s_channel_list[i], + teacher_channels=self.t_channel_list[i], + normalize=True, + loss_func='ssim') + else: + raise ValueError + self.distill_feat_loss_modules.append(feat_loss_module) + + def bbox_loss(self, s_bbox, t_bbox, weight_targets=None): + # [x,y,w,h] + if weight_targets is not None: + loss_bbox = paddle.sum( + self.loss_bbox(s_bbox, t_bbox) * weight_targets) + avg_factor = weight_targets.sum() + loss_bbox = loss_bbox / avg_factor + else: + loss_bbox = paddle.mean(self.loss_bbox(s_bbox, t_bbox)) + return loss_bbox + + def quality_focal_loss(self, pred_logits, soft_target_logits, beta=2.0, \ + use_sigmoid=True, label_weights=None, num_total_pos=None, pos_mask=None): + if use_sigmoid: + func = F.binary_cross_entropy_with_logits + soft_target = F.sigmoid(soft_target_logits) + pred_sigmoid = 
F.sigmoid(pred_logits) + preds = pred_logits + else: + func = F.binary_cross_entropy + soft_target = soft_target_logits + pred_sigmoid = pred_logits + preds = pred_sigmoid + + scale_factor = pred_sigmoid - soft_target + loss = func( + preds, soft_target, reduction='none') * scale_factor.abs().pow(beta) + loss = loss + if pos_mask is not None: + loss *= pos_mask + + loss = loss.sum(1) + if label_weights is not None: + loss = loss * label_weights + if num_total_pos is not None: + loss = loss.sum() / num_total_pos + else: + loss = loss.mean() + return loss + + def distribution_focal_loss(self, pred_corners, target_corners, + weight_targets): + target_corners_label = paddle.nn.functional.softmax( + target_corners, axis=-1) + loss_dfl = paddle.nn.functional.cross_entropy( + pred_corners, + target_corners_label, + soft_label=True, + reduction='none') + loss_dfl = loss_dfl.sum(1) + if weight_targets is not None: + loss_dfl = loss_dfl * (weight_targets.expand([-1, 4]).reshape([-1])) + loss_dfl = loss_dfl.sum(-1) / weight_targets.sum() + else: + loss_dfl = loss_dfl.mean(-1) + loss_dfl = loss_dfl / 4.0 # 4 direction + return loss_dfl + + def forward(self, teacher_model, student_model): + if self.logits_distill and self.loss_weight_logits > 0: + teacher_distill_pairs = teacher_model.yolo_head.distill_pairs + student_distill_pairs = student_model.yolo_head.distill_pairs + distill_bbox_loss, distill_dfl_loss, distill_cls_loss = [], [], [] + distill_bbox_loss.append( + self.bbox_loss(student_distill_pairs['pred_bboxes_pos'], + teacher_distill_pairs['pred_bboxes_pos'].detach(), + weight_targets=student_distill_pairs['bbox_weight'] + ) if 'pred_bboxes_pos' in student_distill_pairs and \ + 'pred_bboxes_pos' in teacher_distill_pairs and \ + 'bbox_weight' in student_distill_pairs + else student_distill_pairs['null_loss'] + ) + distill_dfl_loss.append(self.distribution_focal_loss( + student_distill_pairs['pred_dist_pos'].reshape((-1, student_distill_pairs['pred_dist_pos'].shape[-1])), + teacher_distill_pairs['pred_dist_pos'].detach().reshape((-1, teacher_distill_pairs['pred_dist_pos'].shape[-1])), \ + weight_targets=student_distill_pairs['bbox_weight'] + ) if 'pred_dist_pos' in student_distill_pairs and \ + 'pred_dist_pos' in teacher_distill_pairs and \ + 'bbox_weight' in student_distill_pairs + else student_distill_pairs['null_loss'] + ) + distill_cls_loss.append( + self.quality_focal_loss( + student_distill_pairs['pred_cls_scores'].reshape( + (-1, student_distill_pairs['pred_cls_scores'].shape[-1] + )), + teacher_distill_pairs['pred_cls_scores'].detach().reshape( + (-1, teacher_distill_pairs['pred_cls_scores'].shape[-1] + )), + num_total_pos=student_distill_pairs['pos_num'], + use_sigmoid=False)) + distill_bbox_loss = paddle.add_n(distill_bbox_loss) + distill_cls_loss = paddle.add_n(distill_cls_loss) + distill_dfl_loss = paddle.add_n(distill_dfl_loss) + + logits_loss = distill_bbox_loss * self.bbox_loss_weight + distill_cls_loss * self.qfl_loss_weight + distill_dfl_loss * self.dfl_loss_weight + else: + logits_loss = paddle.to_tensor([0]) + + if self.feat_distill and self.loss_weight_feat > 0: + feat_loss_list = [] + inputs = student_model.inputs + teacher_fpn_feats = teacher_distill_pairs['emb_feats'] + student_fpn_feats = student_distill_pairs['emb_feats'] + assert 'gt_bbox' in inputs + for i, loss_module in enumerate(self.distill_feat_loss_modules): + feat_loss_list.append( + loss_module(student_fpn_feats[i], teacher_fpn_feats[i], + inputs)) + feat_loss = paddle.add_n(feat_loss_list) + else: + feat_loss = 
paddle.to_tensor([0]) + + student_model.yolo_head.distill_pairs.clear() + teacher_model.yolo_head.distill_pairs.clear() + return logits_loss * self.loss_weight_logits, feat_loss * self.loss_weight_feat + + +@register +class CWDFeatureLoss(nn.Layer): + def __init__(self, + student_channels, + teacher_channels, + normalize=False, + tau=1.0, + weight=1.0): + super(CWDFeatureLoss, self).__init__() + self.normalize = normalize + self.tau = tau + self.loss_weight = weight + + if student_channels != teacher_channels: + self.align = nn.Conv2D( + student_channels, + teacher_channels, + kernel_size=1, + stride=1, + padding=0) + else: + self.align = None + + def distill_softmax(self, x, tau): + _, _, w, h = paddle.shape(x) + x = paddle.reshape(x, [-1, w * h]) + x /= tau + return F.softmax(x, axis=1) + + def forward(self, preds_s, preds_t, inputs): + assert preds_s.shape[-2:] == preds_t.shape[-2:] + N, C, H, W = preds_s.shape + eps = 1e-5 + if self.align is not None: + preds_s = self.align(preds_s) + + if self.normalize: + preds_s = feature_norm(preds_s) + preds_t = feature_norm(preds_t) + + softmax_pred_s = self.distill_softmax(preds_s, self.tau) + softmax_pred_t = self.distill_softmax(preds_t, self.tau) + + loss = paddle.sum(-softmax_pred_t * paddle.log(eps + softmax_pred_s) + + softmax_pred_t * paddle.log(eps + softmax_pred_t)) + return self.loss_weight * loss / (C * N) + + +@register +class FGDFeatureLoss(nn.Layer): + """ + Focal and Global Knowledge Distillation for Detectors + The code is reference from https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py + + Args: + student_channels (int): The number of channels in the student's FPN feature map. Default to 256. + teacher_channels (int): The number of channels in the teacher's FPN feature map. Default to 256. + normalize (bool): Whether to normalize the feature maps. + temp (float, optional): The temperature coefficient. Defaults to 0.5. + alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001 + beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005 + gamma_fgd (float, optional): The weight of mask_loss. Defaults to 0.001 + lambda_fgd (float, optional): The weight of relation_loss. 
Defaults to 0.000005 + """ + + def __init__(self, + student_channels, + teacher_channels, + normalize=False, + loss_weight=1.0, + temp=0.5, + alpha_fgd=0.001, + beta_fgd=0.0005, + gamma_fgd=0.001, + lambda_fgd=0.000005): + super(FGDFeatureLoss, self).__init__() + self.normalize = normalize + self.loss_weight = loss_weight + self.temp = temp + self.alpha_fgd = alpha_fgd + self.beta_fgd = beta_fgd + self.gamma_fgd = gamma_fgd + self.lambda_fgd = lambda_fgd + kaiming_init = parameter_init("kaiming") + zeros_init = parameter_init("constant", 0.0) + + if student_channels != teacher_channels: + self.align = nn.Conv2D( + student_channels, + teacher_channels, + kernel_size=1, + stride=1, + padding=0, + weight_attr=kaiming_init) + student_channels = teacher_channels + else: + self.align = None + + self.conv_mask_s = nn.Conv2D( + student_channels, 1, kernel_size=1, weight_attr=kaiming_init) + self.conv_mask_t = nn.Conv2D( + teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init) + + self.stu_conv_block = nn.Sequential( + nn.Conv2D( + student_channels, + student_channels // 2, + kernel_size=1, + weight_attr=zeros_init), + nn.LayerNorm([student_channels // 2, 1, 1]), + nn.ReLU(), + nn.Conv2D( + student_channels // 2, + student_channels, + kernel_size=1, + weight_attr=zeros_init)) + self.tea_conv_block = nn.Sequential( + nn.Conv2D( + teacher_channels, + teacher_channels // 2, + kernel_size=1, + weight_attr=zeros_init), + nn.LayerNorm([teacher_channels // 2, 1, 1]), + nn.ReLU(), + nn.Conv2D( + teacher_channels // 2, + teacher_channels, + kernel_size=1, + weight_attr=zeros_init)) + + def spatial_channel_attention(self, x, t=0.5): + shape = paddle.shape(x) + N, C, H, W = shape + _f = paddle.abs(x) + spatial_map = paddle.reshape( + paddle.mean( + _f, axis=1, keepdim=True) / t, [N, -1]) + spatial_map = F.softmax(spatial_map, axis=1, dtype="float32") * H * W + spatial_att = paddle.reshape(spatial_map, [N, H, W]) + + channel_map = paddle.mean( + paddle.mean( + _f, axis=2, keepdim=False), axis=2, keepdim=False) + channel_att = F.softmax(channel_map / t, axis=1, dtype="float32") * C + return [spatial_att, channel_att] + + def spatial_pool(self, x, mode="teacher"): + batch, channel, width, height = x.shape + x_copy = x + x_copy = paddle.reshape(x_copy, [batch, channel, height * width]) + x_copy = x_copy.unsqueeze(1) + if mode.lower() == "student": + context_mask = self.conv_mask_s(x) + else: + context_mask = self.conv_mask_t(x) + + context_mask = paddle.reshape(context_mask, [batch, 1, height * width]) + context_mask = F.softmax(context_mask, axis=2) + context_mask = context_mask.unsqueeze(-1) + context = paddle.matmul(x_copy, context_mask) + context = paddle.reshape(context, [batch, channel, 1, 1]) + return context + + def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att, + tea_spatial_att): + def _func(a, b): + return paddle.sum(paddle.abs(a - b)) / len(a) + + mask_loss = _func(stu_channel_att, tea_channel_att) + _func( + stu_spatial_att, tea_spatial_att) + return mask_loss + + def feature_loss(self, stu_feature, tea_feature, mask_fg, mask_bg, + tea_channel_att, tea_spatial_att): + mask_fg = mask_fg.unsqueeze(axis=1) + mask_bg = mask_bg.unsqueeze(axis=1) + tea_channel_att = tea_channel_att.unsqueeze(axis=-1).unsqueeze(axis=-1) + tea_spatial_att = tea_spatial_att.unsqueeze(axis=1) + + fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att)) + fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att)) + fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(mask_fg)) + bg_fea_t = 
paddle.multiply(fea_t, paddle.sqrt(mask_bg)) + + fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att)) + fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att)) + fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(mask_fg)) + bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(mask_bg)) + + fg_loss = F.mse_loss(fg_fea_s, fg_fea_t, reduction="sum") / len(mask_fg) + bg_loss = F.mse_loss(bg_fea_s, bg_fea_t, reduction="sum") / len(mask_bg) + return fg_loss, bg_loss + + def relation_loss(self, stu_feature, tea_feature): + context_s = self.spatial_pool(stu_feature, "student") + context_t = self.spatial_pool(tea_feature, "teacher") + out_s = stu_feature + self.stu_conv_block(context_s) + out_t = tea_feature + self.tea_conv_block(context_t) + rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s) + return rela_loss + + def mask_value(self, mask, xl, xr, yl, yr, value): + mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value) + return mask + + def forward(self, stu_feature, tea_feature, inputs): + assert stu_feature.shape[-2:] == stu_feature.shape[-2:] + assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys() + gt_bboxes = inputs['gt_bbox'] + ins_shape = [ + inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0]) + ] + index_gt = [] + for i in range(len(gt_bboxes)): + if gt_bboxes[i].size > 2: + index_gt.append(i) + # only distill feature with labeled GTbox + if len(index_gt) != len(gt_bboxes): + index_gt_t = paddle.to_tensor(index_gt) + stu_feature = paddle.index_select(stu_feature, index_gt_t) + tea_feature = paddle.index_select(tea_feature, index_gt_t) + + ins_shape = [ins_shape[c] for c in index_gt] + gt_bboxes = [gt_bboxes[c] for c in index_gt] + assert len(gt_bboxes) == tea_feature.shape[0] + + if self.align is not None: + stu_feature = self.align(stu_feature) + + if self.normalize: + stu_feature = feature_norm(stu_feature) + tea_feature = feature_norm(tea_feature) + + tea_spatial_att, tea_channel_att = self.spatial_channel_attention( + tea_feature, self.temp) + stu_spatial_att, stu_channel_att = self.spatial_channel_attention( + stu_feature, self.temp) + + mask_fg = paddle.zeros(tea_spatial_att.shape) + mask_bg = paddle.ones_like(tea_spatial_att) + one_tmp = paddle.ones([*tea_spatial_att.shape[1:]]) + zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]]) + mask_fg.stop_gradient = True + mask_bg.stop_gradient = True + one_tmp.stop_gradient = True + zero_tmp.stop_gradient = True + + wmin, wmax, hmin, hmax = [], [], [], [] + + if gt_bboxes.shape[1] == 0: + loss = self.relation_loss(stu_feature, tea_feature) + return self.lambda_fgd * loss + + N, _, H, W = stu_feature.shape + for i in range(N): + tmp_box = paddle.ones_like(gt_bboxes[i]) + tmp_box.stop_gradient = True + tmp_box[:, 0] = gt_bboxes[i][:, 0] / ins_shape[i][1] * W + tmp_box[:, 2] = gt_bboxes[i][:, 2] / ins_shape[i][1] * W + tmp_box[:, 1] = gt_bboxes[i][:, 1] / ins_shape[i][0] * H + tmp_box[:, 3] = gt_bboxes[i][:, 3] / ins_shape[i][0] * H + + zero = paddle.zeros_like(tmp_box[:, 0], dtype="int32") + ones = paddle.ones_like(tmp_box[:, 2], dtype="int32") + zero.stop_gradient = True + ones.stop_gradient = True + wmin.append( + paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum(zero)) + wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32")) + hmin.append( + paddle.cast(paddle.floor(tmp_box[:, 1]), "int32").maximum(zero)) + hmax.append(paddle.cast(paddle.ceil(tmp_box[:, 3]), "int32")) + + area_recip = 1.0 / ( + hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / ( + 
wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1])) + + for j in range(len(gt_bboxes[i])): + if gt_bboxes[i][j].sum() > 0: + mask_fg[i] = self.mask_value( + mask_fg[i], hmin[i][j], hmax[i][j] + 1, wmin[i][j], + wmax[i][j] + 1, area_recip[0][j]) + + mask_bg[i] = paddle.where(mask_fg[i] > zero_tmp, zero_tmp, one_tmp) + + if paddle.sum(mask_bg[i]): + mask_bg[i] /= paddle.sum(mask_bg[i]) + + fg_loss, bg_loss = self.feature_loss(stu_feature, tea_feature, mask_fg, + mask_bg, tea_channel_att, + tea_spatial_att) + mask_loss = self.mask_loss(stu_channel_att, tea_channel_att, + stu_spatial_att, tea_spatial_att) + rela_loss = self.relation_loss(stu_feature, tea_feature) + loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \ + + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss + return loss * self.loss_weight + + +@register +class PKDFeatureLoss(nn.Layer): + """ + PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient. + + Args: + loss_weight (float): Weight of loss. Defaults to 1.0. + resize_stu (bool): If True, we'll down/up sample the features of the + student model to the spatial size of those of the teacher model if + their spatial sizes are different. And vice versa. Defaults to + True. + """ + + def __init__(self, + student_channels=256, + teacher_channels=256, + normalize=True, + loss_weight=1.0, + resize_stu=True): + super(PKDFeatureLoss, self).__init__() + self.normalize = normalize + self.loss_weight = loss_weight + self.resize_stu = resize_stu + + kaiming_init = parameter_init("kaiming") + if student_channels != teacher_channels: + self.align = nn.Conv2D( + student_channels, + teacher_channels, + kernel_size=1, + stride=1, + padding=0, + weight_attr=kaiming_init) + else: + self.align = None + + def forward(self, stu_feature, tea_feature, inputs): + if self.align is not None: + stu_feature = self.align(stu_feature) + + loss = 0. + size_s, size_t = stu_feature.shape[2:], tea_feature.shape[2:] + if size_s[0] != size_t[0]: + if self.resize_stu: + stu_feature = F.interpolate( + stu_feature, size_t, mode='bilinear') + else: + tea_feature = F.interpolate( + tea_feature, size_s, mode='bilinear') + assert stu_feature.shape == tea_feature.shape + + if self.normalize: + norm_stu_feature = feature_norm(stu_feature) + norm_tea_feature = feature_norm(tea_feature) + + # First conduct feature normalization and then calculate the + # MSE loss. Methematically, it is equivalent to firstly calculate + # the Pearson Correlation Coefficient (r) between two feature + # vectors, and then use 1-r as the new feature imitation loss. 
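+            # Concretely: after feature_norm, every channel of both maps has
+            # zero mean and unit variance, so per channel
+            #     mean((s - t) ** 2) / 2 = (E[s^2] + E[t^2] - 2 * E[s * t]) / 2
+            #                            = 1 - r,
+            # and the mean-reduced mse_loss below is the average of (1 - r)
+            # over channels (illustrative derivation of the comment above).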
+ loss += F.mse_loss(norm_stu_feature, norm_tea_feature) / 2 + return loss * self.loss_weight + + +@register +class MGDFeatureLoss(nn.Layer): + def __init__(self, + student_channels=256, + teacher_channels=256, + normalize=True, + loss_weight=1.0, + loss_func='mse'): + super(MGDFeatureLoss, self).__init__() + self.normalize = normalize + self.loss_weight = loss_weight + assert loss_func in ['mse', 'ssim'] + self.loss_func = loss_func + self.mse_loss = nn.MSELoss(reduction='sum') + self.ssim_loss = SSIM(11) + + kaiming_init = parameter_init("kaiming") + if student_channels != teacher_channels: + self.align = nn.Conv2D( + student_channels, + teacher_channels, + kernel_size=1, + stride=1, + padding=0, + weight_attr=kaiming_init, + bias_attr=False) + else: + self.align = None + + self.generation = nn.Sequential( + nn.Conv2D( + teacher_channels, teacher_channels, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv2D( + teacher_channels, teacher_channels, kernel_size=3, padding=1)) + + def forward(self, stu_feature, tea_feature, inputs): + N = stu_feature.shape[0] + if self.align is not None: + stu_feature = self.align(stu_feature) + stu_feature = self.generation(stu_feature) + + if self.normalize: + stu_feature = feature_norm(stu_feature) + tea_feature = feature_norm(tea_feature) + + if self.loss_func == 'mse': + loss = self.mse_loss(stu_feature, tea_feature) / N + elif self.loss_func == 'ssim': + ssim_loss = self.ssim_loss(stu_feature, tea_feature) + loss = paddle.clip((1 - ssim_loss) / 2, 0, 1) + else: + raise ValueError + return loss * self.loss_weight + + +class SSIM(nn.Layer): + def __init__(self, window_size=11, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = self.create_window(window_size, self.channel) + + def gaussian(self, window_size, sigma): + gauss = paddle.to_tensor([ + math.exp(-(x - window_size // 2)**2 / float(2 * sigma**2)) + for x in range(window_size) + ]) + return gauss / gauss.sum() + + def create_window(self, window_size, channel): + _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0) + window = _2D_window.expand([channel, 1, window_size, window_size]) + return window + + def _ssim(self, img1, img2, window, window_size, channel, + size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d( + img1 * img1, window, padding=window_size // 2, + groups=channel) - mu1_sq + sigma2_sq = F.conv2d( + img2 * img2, window, padding=window_size // 2, + groups=channel) - mu2_sq + sigma12 = F.conv2d( + img1 * img2, window, padding=window_size // 2, + groups=channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( + 1e-12 + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean([1, 2, 3]) + + def forward(self, img1, img2): + channel = img1.shape[1] + if channel == self.channel and self.window.dtype == img1.dtype: + window = self.window + else: + window = self.create_window(self.window_size, channel) + self.window = window + self.channel = channel + + return self._ssim(img1, img2, window, self.window_size, channel, + self.size_average) diff --git a/ppdet/slim/distill_model.py 
b/ppdet/slim/distill_model.py new file mode 100644 index 00000000000..f6658acbbb9 --- /dev/null +++ b/ppdet/slim/distill_model.py @@ -0,0 +1,357 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from IPython import embed + +from ppdet.core.workspace import register, create, load_config +from ppdet.utils.checkpoint import load_pretrain_weight +from ppdet.utils.logger import setup_logger +logger = setup_logger(__name__) + +__all__ = [ + 'DistillModel', + 'FGDDistillModel', + 'CWDDistillModel', + 'LDDistillModel', + 'PPYOLOEDistillModel', +] + + +@register +class DistillModel(nn.Layer): + """ + Build common distill model. + Args: + cfg: The student config. + slim_cfg: The teacher and distill config. + """ + + def __init__(self, cfg, slim_cfg): + super(DistillModel, self).__init__() + self.arch = cfg.architecture + + self.stu_cfg = cfg + self.student_model = create(self.stu_cfg.architecture) + if 'pretrain_weights' in self.stu_cfg and self.stu_cfg.pretrain_weights: + stu_pretrain = self.stu_cfg.pretrain_weights + else: + stu_pretrain = None + + slim_cfg = load_config(slim_cfg) + self.tea_cfg = slim_cfg + self.teacher_model = create(self.tea_cfg.architecture) + if 'pretrain_weights' in self.tea_cfg and self.tea_cfg.pretrain_weights: + tea_pretrain = self.tea_cfg.pretrain_weights + else: + tea_pretrain = None + self.distill_cfg = slim_cfg + + # load pretrain weights + self.is_inherit = False + if stu_pretrain: + if self.is_inherit and tea_pretrain: + load_pretrain_weight(self.student_model, tea_pretrain) + logger.debug( + "Inheriting! 
loading teacher weights to student model!") + load_pretrain_weight(self.student_model, stu_pretrain) + logger.info("Student model has loaded pretrain weights!") + if tea_pretrain: + load_pretrain_weight(self.teacher_model, tea_pretrain) + logger.info("Teacher model has loaded pretrain weights!") + + self.teacher_model.eval() + for param in self.teacher_model.parameters(): + param.trainable = False + + self.distill_loss = self.build_loss(self.distill_cfg) + + def build_loss(self, distill_cfg): + if 'distill_loss' in distill_cfg and distill_cfg.distill_loss: + return create(distill_cfg.distill_loss) + else: + return None + + def parameters(self): + return self.student_model.parameters() + + def forward(self, inputs): + if self.training: + student_loss = self.student_model(inputs) + with paddle.no_grad(): + teacher_loss = self.teacher_model(inputs) + + loss = self.distill_loss(self.teacher_model, self.student_model) + student_loss['distill_loss'] = loss + student_loss['teacher_loss'] = teacher_loss['loss'] + student_loss['loss'] += student_loss['distill_loss'] + return student_loss + else: + return self.student_model(inputs) + + +@register +class FGDDistillModel(DistillModel): + """ + Build FGD distill model. + Args: + cfg: The student config. + slim_cfg: The teacher and distill config. + """ + + def __init__(self, cfg, slim_cfg): + super(FGDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg) + assert self.arch in ['RetinaNet', 'PicoDet' + ], 'Unsupported arch: {}'.format(self.arch) + self.is_inherit = True + + def build_loss(self, distill_cfg): + assert 'distill_loss_name' in distill_cfg and distill_cfg.distill_loss_name + assert 'distill_loss' in distill_cfg and distill_cfg.distill_loss + loss_func = dict() + name_list = distill_cfg.distill_loss_name + for name in name_list: + loss_func[name] = create(distill_cfg.distill_loss) + return loss_func + + def forward(self, inputs): + if self.training: + s_body_feats = self.student_model.backbone(inputs) + s_neck_feats = self.student_model.neck(s_body_feats) + with paddle.no_grad(): + t_body_feats = self.teacher_model.backbone(inputs) + t_neck_feats = self.teacher_model.neck(t_body_feats) + + loss_dict = {} + for idx, k in enumerate(self.distill_loss): + loss_dict[k] = self.distill_loss[k](s_neck_feats[idx], + t_neck_feats[idx], inputs) + if self.arch == "RetinaNet": + loss = self.student_model.head(s_neck_feats, inputs) + elif self.arch == "PicoDet": + head_outs = self.student_model.head( + s_neck_feats, self.student_model.export_post_process) + loss_gfl = self.student_model.head.get_loss(head_outs, inputs) + total_loss = paddle.add_n(list(loss_gfl.values())) + loss = {} + loss.update(loss_gfl) + loss.update({'loss': total_loss}) + else: + raise ValueError(f"Unsupported model {self.arch}") + + for k in loss_dict: + loss['loss'] += loss_dict[k] + loss[k] = loss_dict[k] + return loss + else: + body_feats = self.student_model.backbone(inputs) + neck_feats = self.student_model.neck(body_feats) + head_outs = self.student_model.head(neck_feats) + if self.arch == "RetinaNet": + bbox, bbox_num = self.student_model.head.post_process( + head_outs, inputs['im_shape'], inputs['scale_factor']) + return {'bbox': bbox, 'bbox_num': bbox_num} + elif self.arch == "PicoDet": + head_outs = self.student_model.head( + neck_feats, self.student_model.export_post_process) + scale_factor = inputs['scale_factor'] + bboxes, bbox_num = self.student_model.head.post_process( + head_outs, + scale_factor, + export_nms=self.student_model.export_nms) + return {'bbox': bboxes, 
'bbox_num': bbox_num} + else: + raise ValueError(f"Unsupported model {self.arch}") + + +@register +class CWDDistillModel(DistillModel): + """ + Build CWD distill model. + Args: + cfg: The student config. + slim_cfg: The teacher and distill config. + """ + + def __init__(self, cfg, slim_cfg): + super(CWDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg) + assert self.arch in ['GFL', 'RetinaNet'], 'Unsupported arch: {}'.format( + self.arch) + + def build_loss(self, distill_cfg): + assert 'distill_loss_name' in distill_cfg and distill_cfg.distill_loss_name + assert 'distill_loss' in distill_cfg and distill_cfg.distill_loss + loss_func = dict() + name_list = distill_cfg.distill_loss_name + for name in name_list: + loss_func[name] = create(distill_cfg.distill_loss) + return loss_func + + def get_loss_retinanet(self, stu_fea_list, tea_fea_list, inputs): + loss = self.student_model.head(stu_fea_list, inputs) + distill_loss = {} + for idx, k in enumerate(self.loss_dic): + distill_loss[k] = self.loss_dic[k](stu_fea_list[idx], + tea_fea_list[idx]) + + loss['loss'] += distill_loss[k] + loss[k] = distill_loss[k] + return loss + + def get_loss_gfl(self, stu_fea_list, tea_fea_list, inputs): + loss = {} + head_outs = self.student_model.head(stu_fea_list) + loss_gfl = self.student_model.head.get_loss(head_outs, inputs) + loss.update(loss_gfl) + total_loss = paddle.add_n(list(loss.values())) + loss.update({'loss': total_loss}) + + feat_loss = {} + loss_dict = {} + s_cls_feat, t_cls_feat = [], [] + for s_neck_f, t_neck_f in zip(stu_fea_list, tea_fea_list): + conv_cls_feat, _ = self.student_model.head.conv_feat(s_neck_f) + cls_score = self.student_model.head.gfl_head_cls(conv_cls_feat) + t_conv_cls_feat, _ = self.teacher_model.head.conv_feat(t_neck_f) + t_cls_score = self.teacher_model.head.gfl_head_cls(t_conv_cls_feat) + s_cls_feat.append(cls_score) + t_cls_feat.append(t_cls_score) + + for idx, k in enumerate(self.loss_dic): + loss_dict[k] = self.loss_dic[k](s_cls_feat[idx], t_cls_feat[idx]) + feat_loss[f"neck_f_{idx}"] = self.loss_dic[k](stu_fea_list[idx], + tea_fea_list[idx]) + + for k in feat_loss: + loss['loss'] += feat_loss[k] + loss[k] = feat_loss[k] + + for k in loss_dict: + loss['loss'] += loss_dict[k] + loss[k] = loss_dict[k] + return loss + + def forward(self, inputs): + if self.training: + s_body_feats = self.student_model.backbone(inputs) + s_neck_feats = self.student_model.neck(s_body_feats) + with paddle.no_grad(): + t_body_feats = self.teacher_model.backbone(inputs) + t_neck_feats = self.teacher_model.neck(t_body_feats) + + if self.arch == "RetinaNet": + loss = self.get_loss_retinanet(s_neck_feats, t_neck_feats, + inputs) + elif self.arch == "GFL": + loss = self.get_loss_gfl(s_neck_feats, t_neck_feats, inputs) + else: + raise ValueError(f"unsupported arch {self.arch}") + return loss + else: + body_feats = self.student_model.backbone(inputs) + neck_feats = self.student_model.neck(body_feats) + head_outs = self.student_model.head(neck_feats) + if self.arch == "RetinaNet": + bbox, bbox_num = self.student_model.head.post_process( + head_outs, inputs['im_shape'], inputs['scale_factor']) + return {'bbox': bbox, 'bbox_num': bbox_num} + elif self.arch == "GFL": + bbox_pred, bbox_num = head_outs + output = {'bbox': bbox_pred, 'bbox_num': bbox_num} + return output + else: + raise ValueError(f"unsupported arch {self.arch}") + + +@register +class LDDistillModel(DistillModel): + """ + Build LD distill model. + Args: + cfg: The student config. + slim_cfg: The teacher and distill config. 
+ """ + + def __init__(self, cfg, slim_cfg): + super(LDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg) + assert self.arch in ['GFL'], 'Unsupported arch: {}'.format(self.arch) + + def forward(self, inputs): + if self.training: + s_body_feats = self.student_model.backbone(inputs) + s_neck_feats = self.student_model.neck(s_body_feats) + s_head_outs = self.student_model.head(s_neck_feats) + with paddle.no_grad(): + t_body_feats = self.teacher_model.backbone(inputs) + t_neck_feats = self.teacher_model.neck(t_body_feats) + t_head_outs = self.teacher_model.head(t_neck_feats) + + soft_label_list = t_head_outs[0] + soft_targets_list = t_head_outs[1] + student_loss = self.student_model.head.get_loss( + s_head_outs, inputs, soft_label_list, soft_targets_list) + total_loss = paddle.add_n(list(student_loss.values())) + student_loss['loss'] = total_loss + return student_loss + else: + return self.student_model(inputs) + + +@register +class PPYOLOEDistillModel(DistillModel): + """ + Build PPYOLOE distill model, only used in PPYOLOE + Args: + cfg: The student config. + slim_cfg: The teacher and distill config. + """ + + def __init__(self, cfg, slim_cfg): + super(PPYOLOEDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg) + assert self.arch in ['PPYOLOE'], 'Unsupported arch: {}'.format( + self.arch) + + def forward(self, inputs, alpha=0.125): + if self.training: + if hasattr(self.teacher_model.yolo_head, "assigned_labels"): + self.student_model.yolo_head.assigned_labels, self.student_model.yolo_head.assigned_bboxes, self.student_model.yolo_head.assigned_scores, self.student_model.yolo_head.mask_positive = \ + self.teacher_model.yolo_head.assigned_labels, self.teacher_model.yolo_head.assigned_bboxes, self.teacher_model.yolo_head.assigned_scores, self.teacher_model.yolo_head.mask_positive + delattr(self.teacher_model.yolo_head, "assigned_labels") + delattr(self.teacher_model.yolo_head, "assigned_bboxes") + delattr(self.teacher_model.yolo_head, "assigned_scores") + delattr(self.teacher_model.yolo_head, "mask_positive") + student_out = self.student_model(inputs) + with paddle.no_grad(): + teacher_out = self.teacher_model(inputs) + + logits_loss, feat_loss = self.distill_loss(self.teacher_model, + self.student_model) + student_loss = student_out['det_losses'] + det_total_loss = student_loss['loss'] + + total_loss = alpha * (det_total_loss + logits_loss + feat_loss) + student_loss['loss'] = total_loss + student_loss['det_loss'] = det_total_loss + student_loss['logits_loss'] = logits_loss + student_loss['feat_loss'] = feat_loss + return student_loss + else: + return self.student_model(inputs) diff --git a/ppdet/slim/distill_ppyoloe.py b/ppdet/slim/distill_ppyoloe.py deleted file mode 100644 index 315a9a5700b..00000000000 --- a/ppdet/slim/distill_ppyoloe.py +++ /dev/null @@ -1,695 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import math -import numpy as np - -import paddle -import paddle.nn as nn -import paddle.nn.functional as F -from ppdet.core.workspace import register, create, load_config -from ppdet.utils.checkpoint import load_pretrain_weight -from .distill import parameter_init -from ppdet.modeling.losses.iou_loss import GIoULoss -from ppdet.utils.logger import setup_logger - -logger = setup_logger(__name__) - - -class PPYOLOEDistillModel(nn.Layer): - def __init__(self, cfg, slim_cfg): - super(PPYOLOEDistillModel, self).__init__() - self.student_model = create(cfg.architecture) - logger.debug('Load student model pretrain_weights:{}'.format( - cfg.pretrain_weights)) - load_pretrain_weight(self.student_model, cfg.pretrain_weights) - - slim_cfg = load_config(slim_cfg) - self.teacher_model = create(slim_cfg.architecture) - self.distill_loss = create(slim_cfg.distill_loss) - logger.debug('Load teacher model pretrain_weights:{}'.format( - slim_cfg.pretrain_weights)) - load_pretrain_weight(self.teacher_model, slim_cfg.pretrain_weights) - - for param in self.teacher_model.parameters(): - param.trainable = False - - def parameters(self): - return self.student_model.parameters() - - def forward(self, inputs, alpha=0.125): - if self.training: - with paddle.no_grad(): - teacher_out = self.teacher_model(inputs) - - if hasattr(self.teacher_model.yolo_head, "assigned_labels"): - self.student_model.yolo_head.assigned_labels, self.student_model.yolo_head.assigned_bboxes, self.student_model.yolo_head.assigned_scores, self.student_model.yolo_head.mask_positive = \ - self.teacher_model.yolo_head.assigned_labels, self.teacher_model.yolo_head.assigned_bboxes, self.teacher_model.yolo_head.assigned_scores, self.teacher_model.yolo_head.mask_positive - delattr(self.teacher_model.yolo_head, "assigned_labels") - delattr(self.teacher_model.yolo_head, "assigned_bboxes") - delattr(self.teacher_model.yolo_head, "assigned_scores") - delattr(self.teacher_model.yolo_head, "mask_positive") - - student_out = self.student_model(inputs) - - # head loss concerned - soft_loss, feat_loss, distill_loss_dict = self.distill_loss( - self.teacher_model, self.student_model) - stu_loss = student_out['det_losses'] - stu_det_total_loss = stu_loss['loss'] - - # conbined distill - stu_loss[ - 'loss'] = soft_loss + alpha * feat_loss + alpha * stu_det_total_loss - stu_loss['soft_loss'] = soft_loss - stu_loss['feat_loss'] = feat_loss - return stu_loss - else: - return self.student_model(inputs) - - -@register -class DistillPPYOLOELoss(nn.Layer): - def __init__( - self, - teacher_width_mult=1.0, # default as L - student_width_mult=0.75, # default as M - neck_out_channels=[768, 384, 192], # default as L - loss_weight={ - 'class': 0.5, - 'iou': 1.25, - 'dfl': 0.25, - }, - kd_neck=True, - kd_type='fgd'): - super(DistillPPYOLOELoss, self).__init__() - self.loss_bbox = GIoULoss() - self.bbox_loss_weight = loss_weight['iou'] - self.dfl_loss_weight = loss_weight['dfl'] - self.qfl_loss_weight = loss_weight['class'] - - self.kd_neck = kd_neck - self.kd_type = kd_type - if self.kd_neck: - # Knowledge Distillation for Detectors in necks - distill_loss_module_list = [] - self.t_channel_list = [ - int(c * teacher_width_mult) for c in neck_out_channels - ] - self.s_channel_list = [ - int(c * student_width_mult) for c in neck_out_channels - ] - for i in range(len(neck_out_channels)): - if self.kd_type == 'fgd': - distill_loss_module = FGDLoss( - 
student_channels=self.s_channel_list[i], - teacher_channels=self.t_channel_list[i]) - elif self.kd_type == 'pkd': - distill_loss_module = PKDLoss( - student_channels=self.s_channel_list[i], - teacher_channels=self.t_channel_list[i], - resize_stu=False) - elif self.kd_type == 'mgd': - distill_loss_module = MGDSSIMLoss( - student_channels=self.s_channel_list[i], - teacher_channels=self.t_channel_list[i]) - else: - raise ValueError - distill_loss_module_list.append(distill_loss_module) - - self.distill_loss_module_list = nn.LayerList( - distill_loss_module_list) - - def bbox_loss(self, s_bbox, t_bbox, weight_targets=None): - # [x,y,w,h] - if weight_targets is not None: - loss_bbox = paddle.sum( - self.loss_bbox(s_bbox, t_bbox) * weight_targets) - avg_factor = weight_targets.sum() - loss_bbox = loss_bbox / avg_factor - else: - loss_bbox = paddle.mean(self.loss_bbox(s_bbox, t_bbox)) - return loss_bbox - - def quality_focal_loss(self, pred_logits, soft_target_logits, beta=2.0, \ - use_sigmoid=True, label_weights=None, num_total_pos=None, pos_mask=None): - if use_sigmoid: - func = F.binary_cross_entropy_with_logits - soft_target = F.sigmoid(soft_target_logits) - pred_sigmoid = F.sigmoid(pred_logits) - preds = pred_logits - else: - func = F.binary_cross_entropy - soft_target = soft_target_logits - pred_sigmoid = pred_logits - preds = pred_sigmoid - - scale_factor = pred_sigmoid - soft_target - loss = func( - preds, soft_target, reduction='none') * scale_factor.abs().pow(beta) - loss = loss - if pos_mask is not None: - loss *= pos_mask - - loss = loss.sum(1) - if label_weights is not None: - loss = loss * label_weights - if num_total_pos is not None: - loss = loss.sum() / num_total_pos - else: - loss = loss.mean() - return loss - - def distribution_focal_loss(self, pred_corners, target_corners, - weight_targets): - target_corners_label = paddle.nn.functional.softmax( - target_corners, axis=-1) - loss_dfl = paddle.nn.functional.cross_entropy( - pred_corners, - target_corners_label, - soft_label=True, - reduction='none') - loss_dfl = loss_dfl.sum(1) - if weight_targets is not None: - loss_dfl = loss_dfl * (weight_targets.expand([-1, 4]).reshape([-1])) - loss_dfl = loss_dfl.sum(-1) / weight_targets.sum() - else: - loss_dfl = loss_dfl.mean(-1) - loss_dfl = loss_dfl / 4.0 # 4 direction - return loss_dfl - - def forward(self, teacher_model, student_model): - teacher_distill_pairs = teacher_model.yolo_head.distill_pairs - student_distill_pairs = student_model.yolo_head.distill_pairs - distill_bbox_loss, distill_dfl_loss, distill_cls_loss = [], [], [] - distill_bbox_loss.append( - self.bbox_loss(student_distill_pairs['pred_bboxes_pos'], - teacher_distill_pairs['pred_bboxes_pos'].detach(), - weight_targets=student_distill_pairs['bbox_weight'] - ) if 'pred_bboxes_pos' in student_distill_pairs and \ - 'pred_bboxes_pos' in teacher_distill_pairs and \ - 'bbox_weight' in student_distill_pairs - else student_distill_pairs['null_loss'] - ) - distill_dfl_loss.append(self.distribution_focal_loss( - student_distill_pairs['pred_dist_pos'].reshape((-1, student_distill_pairs['pred_dist_pos'].shape[-1])), - teacher_distill_pairs['pred_dist_pos'].detach().reshape((-1, teacher_distill_pairs['pred_dist_pos'].shape[-1])), \ - weight_targets=student_distill_pairs['bbox_weight'] - ) if 'pred_dist_pos' in student_distill_pairs and \ - 'pred_dist_pos' in teacher_distill_pairs and \ - 'bbox_weight' in student_distill_pairs - else student_distill_pairs['null_loss'] - ) - distill_cls_loss.append( - self.quality_focal_loss( - 
student_distill_pairs['pred_cls_scores'].reshape(( - -1, student_distill_pairs['pred_cls_scores'].shape[-1])), - teacher_distill_pairs['pred_cls_scores'].detach().reshape(( - -1, teacher_distill_pairs['pred_cls_scores'].shape[-1])), - num_total_pos=student_distill_pairs['pos_num'], - use_sigmoid=False)) - distill_bbox_loss = paddle.add_n(distill_bbox_loss) - distill_cls_loss = paddle.add_n(distill_cls_loss) - distill_dfl_loss = paddle.add_n(distill_dfl_loss) - - if self.kd_neck: - # Knowledge Distillation for Detectors in necks - distill_neck_global_loss = [] - inputs = student_model.inputs - teacher_fpn_feats = teacher_distill_pairs['emb_feats'] - student_fpn_feats = student_distill_pairs['emb_feats'] - assert 'gt_bbox' in inputs - for i, distill_loss_module in enumerate( - self.distill_loss_module_list): - distill_neck_global_loss.append( - distill_loss_module(student_fpn_feats[i], teacher_fpn_feats[ - i], inputs)) - distill_neck_global_loss = paddle.add_n(distill_neck_global_loss) - else: - distill_neck_global_loss = paddle.to_tensor([0]) - - soft_loss = ( - distill_bbox_loss * self.bbox_loss_weight + distill_cls_loss * - self.qfl_loss_weight + distill_dfl_loss * self.dfl_loss_weight) - student_model.yolo_head.distill_pairs.clear() - teacher_model.yolo_head.distill_pairs.clear() - return soft_loss, \ - distill_neck_global_loss, \ - {'dfl_loss': distill_dfl_loss, 'qfl_loss': distill_cls_loss, 'bbox_loss': distill_bbox_loss} - - -@register -class FGDLoss(nn.Layer): - """ - Focal and Global Knowledge Distillation for Detectors - The code is reference from https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py - - Args: - student_channels (int): The number of channels in the student's FPN feature map. Default to 256. - teacher_channels (int): The number of channels in the teacher's FPN feature map. Default to 256. - normalize (bool): Whether to normalize the feature maps. - temp (float, optional): The temperature coefficient. Defaults to 0.5. - alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001 - beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005 - gamma_fgd (float, optional): The weight of mask_loss. Defaults to 0.001 - lambda_fgd (float, optional): The weight of relation_loss. 
Defaults to 0.000005 - """ - - def __init__( - self, - student_channels=256, - teacher_channels=256, - normalize=True, - temp=0.5, - alpha_fgd=0.00001, # 0.001 - beta_fgd=0.000005, # 0.0005 - gamma_fgd=0.00001, # 0.001 - lambda_fgd=0.00000005): # 0.000005 - super(FGDLoss, self).__init__() - self.temp = temp - self.alpha_fgd = alpha_fgd - self.beta_fgd = beta_fgd - self.gamma_fgd = gamma_fgd - self.lambda_fgd = lambda_fgd - self.normalize = normalize - kaiming_init = parameter_init("kaiming") - zeros_init = parameter_init("constant", 0.0) - - if student_channels != teacher_channels: - self.align = nn.Conv2D( - student_channels, - teacher_channels, - kernel_size=1, - stride=1, - padding=0, - weight_attr=kaiming_init) - student_channels = teacher_channels - else: - self.align = None - - self.conv_mask_s = nn.Conv2D( - student_channels, 1, kernel_size=1, weight_attr=kaiming_init) - self.conv_mask_t = nn.Conv2D( - teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init) - - self.stu_conv_block = nn.Sequential( - nn.Conv2D( - student_channels, - student_channels // 2, - kernel_size=1, - weight_attr=zeros_init), - nn.LayerNorm([student_channels // 2, 1, 1]), - nn.ReLU(), - nn.Conv2D( - student_channels // 2, - student_channels, - kernel_size=1, - weight_attr=zeros_init)) - self.tea_conv_block = nn.Sequential( - nn.Conv2D( - teacher_channels, - teacher_channels // 2, - kernel_size=1, - weight_attr=zeros_init), - nn.LayerNorm([teacher_channels // 2, 1, 1]), - nn.ReLU(), - nn.Conv2D( - teacher_channels // 2, - teacher_channels, - kernel_size=1, - weight_attr=zeros_init)) - - def norm(self, feat): - # Normalize the feature maps to have zero mean and unit variances. - assert len(feat.shape) == 4 - N, C, H, W = feat.shape - feat = feat.transpose([1, 0, 2, 3]).reshape([C, -1]) - mean = feat.mean(axis=-1, keepdim=True) - std = feat.std(axis=-1, keepdim=True) - feat = (feat - mean) / (std + 1e-6) - return feat.reshape([C, N, H, W]).transpose([1, 0, 2, 3]) - - def spatial_channel_attention(self, x, t=0.5): - shape = paddle.shape(x) - N, C, H, W = shape - _f = paddle.abs(x) - spatial_map = paddle.reshape( - paddle.mean( - _f, axis=1, keepdim=True) / t, [N, -1]) - spatial_map = F.softmax(spatial_map, axis=1, dtype="float32") * H * W - spatial_att = paddle.reshape(spatial_map, [N, H, W]) - - channel_map = paddle.mean( - paddle.mean( - _f, axis=2, keepdim=False), axis=2, keepdim=False) - channel_att = F.softmax(channel_map / t, axis=1, dtype="float32") * C - return [spatial_att, channel_att] - - def spatial_pool(self, x, mode="teacher"): - batch, channel, width, height = x.shape - x_copy = x - x_copy = paddle.reshape(x_copy, [batch, channel, height * width]) - x_copy = x_copy.unsqueeze(1) - if mode.lower() == "student": - context_mask = self.conv_mask_s(x) - else: - context_mask = self.conv_mask_t(x) - - context_mask = paddle.reshape(context_mask, [batch, 1, height * width]) - context_mask = F.softmax(context_mask, axis=2) - context_mask = context_mask.unsqueeze(-1) - context = paddle.matmul(x_copy, context_mask) - context = paddle.reshape(context, [batch, channel, 1, 1]) - return context - - def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att, - tea_spatial_att): - def _func(a, b): - return paddle.sum(paddle.abs(a - b)) / len(a) - - mask_loss = _func(stu_channel_att, tea_channel_att) + _func( - stu_spatial_att, tea_spatial_att) - return mask_loss - - def feature_loss(self, stu_feature, tea_feature, mask_fg, mask_bg, - tea_channel_att, tea_spatial_att): - mask_fg = 
mask_fg.unsqueeze(axis=1) - mask_bg = mask_bg.unsqueeze(axis=1) - tea_channel_att = tea_channel_att.unsqueeze(axis=-1).unsqueeze(axis=-1) - tea_spatial_att = tea_spatial_att.unsqueeze(axis=1) - - fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att)) - fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att)) - fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(mask_fg)) - bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(mask_bg)) - - fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att)) - fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att)) - fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(mask_fg)) - bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(mask_bg)) - - fg_loss = F.mse_loss(fg_fea_s, fg_fea_t, reduction="sum") / len(mask_fg) - bg_loss = F.mse_loss(bg_fea_s, bg_fea_t, reduction="sum") / len(mask_bg) - return fg_loss, bg_loss - - def relation_loss(self, stu_feature, tea_feature): - context_s = self.spatial_pool(stu_feature, "student") - context_t = self.spatial_pool(tea_feature, "teacher") - out_s = stu_feature + self.stu_conv_block(context_s) - out_t = tea_feature + self.tea_conv_block(context_t) - rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s) - return rela_loss - - def mask_value(self, mask, xl, xr, yl, yr, value): - mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value) - return mask - - def forward(self, stu_feature, tea_feature, inputs): - assert stu_feature.shape[-2:] == stu_feature.shape[-2:], \ - f'The shape of Student feature {stu_feature.shape} and Teacher feature {tea_feature.shape} should be the same.' - assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys( - ), "ERROR! FGDFeatureLoss need gt_bbox and im_shape as inputs." - gt_bboxes = inputs['gt_bbox'] - ins_shape = [ - inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0]) - ] - if self.align is not None: - stu_feature = self.align(stu_feature) - if self.normalize: - stu_feature, tea_feature = self.norm(stu_feature), self.norm( - tea_feature) - - tea_spatial_att, tea_channel_att = self.spatial_channel_attention( - tea_feature, self.temp) - stu_spatial_att, stu_channel_att = self.spatial_channel_attention( - stu_feature, self.temp) - - mask_fg = paddle.zeros(tea_spatial_att.shape) - mask_bg = paddle.ones_like(tea_spatial_att) - one_tmp = paddle.ones([*tea_spatial_att.shape[1:]]) - zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]]) - wmin, wmax, hmin, hmax = [], [], [], [] - - N, _, H, W = stu_feature.shape - if gt_bboxes.shape[1] != 0: - for i in range(N): - tmp_box = paddle.ones_like(gt_bboxes[i]) - tmp_box[:, 0] = gt_bboxes[i][:, 0] / ins_shape[i][1] * W - tmp_box[:, 2] = gt_bboxes[i][:, 2] / ins_shape[i][1] * W - tmp_box[:, 1] = gt_bboxes[i][:, 1] / ins_shape[i][0] * H - tmp_box[:, 3] = gt_bboxes[i][:, 3] / ins_shape[i][0] * H - - zero = paddle.zeros_like(tmp_box[:, 0], dtype="int32") - ones = paddle.ones_like(tmp_box[:, 2], dtype="int32") - wmin.append( - paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum( - zero)) - wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32")) - hmin.append( - paddle.cast(paddle.floor(tmp_box[:, 1]), "int32").maximum( - zero)) - hmax.append(paddle.cast(paddle.ceil(tmp_box[:, 3]), "int32")) - - area_recip = 1.0 / ( - hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / ( - wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1])) - - for j in range(len(gt_bboxes[i])): - if gt_bboxes[i][j].sum() > 0: - mask_fg[i] = self.mask_value( - mask_fg[i], hmin[i][j], hmax[i][j] + 1, wmin[i][j], - 
wmax[i][j] + 1, area_recip[0][j]) - - mask_bg[i] = paddle.where(mask_fg[i] > zero_tmp, zero_tmp, - one_tmp) - - if paddle.sum(mask_bg[i]): - mask_bg[i] /= paddle.sum(mask_bg[i]) - - fg_loss, bg_loss = self.feature_loss( - stu_feature, tea_feature, mask_fg, mask_bg, tea_channel_att, - tea_spatial_att) - mask_loss = self.mask_loss(stu_channel_att, tea_channel_att, - stu_spatial_att, tea_spatial_att) - rela_loss = self.relation_loss(stu_feature, tea_feature) - loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \ - + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss - else: - rela_loss = self.relation_loss(stu_feature, tea_feature) - loss = self.lambda_fgd * rela_loss - return loss - - -@register -class PKDLoss(nn.Layer): - """ - PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient. - - Args: - loss_weight (float): Weight of loss. Defaults to 1.0. - resize_stu (bool): If True, we'll down/up sample the features of the - student model to the spatial size of those of the teacher model if - their spatial sizes are different. And vice versa. Defaults to - True. - """ - - def __init__(self, - student_channels=256, - teacher_channels=256, - normalize=True, - loss_weight=1.0, - resize_stu=True): - super(PKDLoss, self).__init__() - self.normalize = normalize - self.loss_weight = loss_weight - self.resize_stu = resize_stu - kaiming_init = parameter_init("kaiming") - if student_channels != teacher_channels: - self.align = nn.Conv2D( - student_channels, - teacher_channels, - kernel_size=1, - stride=1, - padding=0, - weight_attr=kaiming_init) - else: - self.align = None - - def norm(self, feat): - # Normalize the feature maps to have zero mean and unit variances. - assert len(feat.shape) == 4 - N, C, H, W = feat.shape - feat = feat.transpose([1, 0, 2, 3]).reshape([C, -1]) - mean = feat.mean(axis=-1, keepdim=True) - std = feat.std(axis=-1, keepdim=True) - feat = (feat - mean) / (std + 1e-6) - return feat.reshape([C, N, H, W]).transpose([1, 0, 2, 3]) - - def forward(self, stu_feature, tea_feature, inputs): - if self.align is not None: - stu_feature = self.align(stu_feature) - - loss = 0. - size_s, size_t = stu_feature.shape[2:], tea_feature.shape[2:] - if size_s[0] != size_t[0]: - if self.resize_stu: - stu_feature = F.interpolate( - stu_feature, size_t, mode='bilinear') - else: - tea_feature = F.interpolate( - tea_feature, size_s, mode='bilinear') - assert stu_feature.shape == tea_feature.shape - - if self.normalize: - norm_stu_feature = self.norm(stu_feature) - norm_tea_feature = self.norm(tea_feature) - - # First conduct feature normalization and then calculate the - # MSE loss. Methematically, it is equivalent to firstly calculate - # the Pearson Correlation Coefficient (r) between two feature - # vectors, and then use 1-r as the new feature imitation loss. 
- loss += F.mse_loss(norm_stu_feature, norm_tea_feature) / 2 - return loss * self.loss_weight - - -@register -class MGDSSIMLoss(nn.Layer): - def __init__(self, - student_channels=256, - teacher_channels=256, - normalize=True, - ssim=True, - loss_weight=1.0, - max_alpha=1.0, - min_alpha=0.2): - super(MGDSSIMLoss, self).__init__() - self.normalize = normalize - self.loss_weight = loss_weight - self.max_alpha = max_alpha - self.min_alpha = min_alpha - - self.mse_loss = nn.MSELoss(reduction='sum') - self.ssim_loss = SSIM(11) - - kaiming_init = parameter_init("kaiming") - if student_channels != teacher_channels: - self.align_layer = nn.Conv2D( - student_channels, - teacher_channels, - kernel_size=1, - stride=1, - padding=0, - weight_attr=kaiming_init, - bias_attr=False) - else: - self.align_layer = None - - self.generations = nn.Sequential( - nn.Conv2D( - teacher_channels, teacher_channels, kernel_size=3, padding=1), - nn.ReLU(), - nn.Conv2D( - teacher_channels, teacher_channels, kernel_size=3, padding=1)) - - def norm(self, feat): - # Normalize the feature maps to have zero mean and unit variances. - assert len(feat.shape) == 4 - N, C, H, W = feat.shape - feat = feat.transpose([1, 0, 2, 3]).reshape([C, -1]) - mean = feat.mean(axis=-1, keepdim=True) - std = feat.std(axis=-1, keepdim=True) - feat = (feat - mean) / (std + 1e-6) - return feat.reshape([C, N, H, W]).transpose([1, 0, 2, 3]) - - def forward(self, stu_feature, tea_feature, input): - N = stu_feature.shape[0] - masked_fea = self.align_layer(stu_feature) - stu_feature = self.generations(masked_fea) - - if self.normalize: - stu_feature = self.norm(stu_feature) - tea_feature = self.norm(tea_feature) - - if self.ssim is False: - dis_loss = self.mse_loss(stu_feature, tea_feature) / N - else: - ssim_loss = self.ssim_loss(stu_feature, tea_feature) - dis_loss = paddle.clip((1 - ssim_loss) / 2, 0, 1) - return dis_loss * self.loss_weight - - -class SSIM(nn.Layer): - def __init__(self, window_size=11, size_average=True): - super(SSIM, self).__init__() - self.window_size = window_size - self.size_average = size_average - self.channel = 1 - self.window = self.create_window(window_size, self.channel) - - def gaussian(self, window_size, sigma): - gauss = paddle.to_tensor([ - math.exp(-(x - window_size // 2)**2 / float(2 * sigma**2)) - for x in range(window_size) - ]) - return gauss / gauss.sum() - - def create_window(self, window_size, channel): - _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1) - _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0) - window = _2D_window.expand([channel, 1, window_size, window_size]) - return window - - def _ssim(self, img1, img2, window, window_size, channel, - size_average=True): - mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) - mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) - mu1_sq = mu1.pow(2) - mu2_sq = mu2.pow(2) - mu1_mu2 = mu1 * mu2 - - sigma1_sq = F.conv2d( - img1 * img1, window, padding=window_size // 2, - groups=channel) - mu1_sq - sigma2_sq = F.conv2d( - img2 * img2, window, padding=window_size // 2, - groups=channel) - mu2_sq - sigma12 = F.conv2d( - img1 * img2, window, padding=window_size // 2, - groups=channel) - mu1_mu2 - - C1 = 0.01**2 - C2 = 0.03**2 - ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( - 1e-12 + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) - - if size_average: - return ssim_map.mean() - else: - return ssim_map.mean([1, 2, 3]) - - def forward(self, img1, img2): - channel = img1.shape[1] - if 
channel == self.channel and self.window.dtype == img1.dtype: - window = self.window - else: - window = self.create_window(self.window_size, channel) - self.window = window - self.channel = channel - - return self._ssim(img1, img2, window, self.window_size, channel, - self.size_average) From 34b4870bacc2eaf9c445dd233931b129f351cc26 Mon Sep 17 00:00:00 2001 From: nemonameless Date: Thu, 2 Feb 2023 09:01:08 +0000 Subject: [PATCH 3/5] fix configs and docs --- configs/ppyoloe/distill/README.md | 40 ++++ ...> ppyoloe_plus_crn_l_80e_coco_distill.yml} | 22 +- ...> ppyoloe_plus_crn_m_80e_coco_distill.yml} | 22 +- ...> ppyoloe_plus_crn_s_80e_coco_distill.yml} | 22 +- configs/slim/distill/README.md | 36 ++++ .../distill/ppyoloe_plus_distill_l_to_m.yml | 10 +- .../distill/ppyoloe_plus_distill_m_to_s.yml | 10 +- .../distill/ppyoloe_plus_distill_x_to_l.yml | 10 +- ppdet/modeling/architectures/ppyoloe.py | 21 +- ppdet/modeling/architectures/yolo.py | 2 +- ppdet/modeling/heads/pico_head.py | 4 +- ppdet/modeling/heads/ppyoloe_contrast_head.py | 6 +- ppdet/modeling/heads/ppyoloe_head.py | 33 +-- ppdet/modeling/heads/ppyoloe_r_head.py | 4 +- ppdet/modeling/heads/tood_head.py | 4 +- ppdet/slim/distill_loss.py | 189 ++++++++++-------- ppdet/slim/distill_model.py | 6 +- 17 files changed, 308 insertions(+), 133 deletions(-) create mode 100644 configs/ppyoloe/distill/README.md rename configs/ppyoloe/distill/{ppyoloe_plus_crn_l_80e_coco.yml => ppyoloe_plus_crn_l_80e_coco_distill.yml} (54%) rename configs/ppyoloe/distill/{ppyoloe_plus_crn_m_80e_coco.yml => ppyoloe_plus_crn_m_80e_coco_distill.yml} (54%) rename configs/ppyoloe/distill/{ppyoloe_plus_crn_s_80e_coco.yml => ppyoloe_plus_crn_s_80e_coco_distill.yml} (54%) rename configs/{ppyoloe => slim}/distill/ppyoloe_plus_distill_l_to_m.yml (71%) rename configs/{ppyoloe => slim}/distill/ppyoloe_plus_distill_m_to_s.yml (71%) rename configs/{ppyoloe => slim}/distill/ppyoloe_plus_distill_x_to_l.yml (71%) diff --git a/configs/ppyoloe/distill/README.md b/configs/ppyoloe/distill/README.md new file mode 100644 index 00000000000..7453cdc2ded --- /dev/null +++ b/configs/ppyoloe/distill/README.md @@ -0,0 +1,40 @@ +# PPYOLOE+ Distillation(PPYOLOE+ 蒸馏) + +PaddleDetection提供了对PPYOLOE+ 进行模型蒸馏的方案,结合了logits蒸馏和feature蒸馏。 + + +## 模型库 + + + +## 快速开始 + +### 训练 +```shell +# 单卡 +python tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml +# 多卡 +python3.7 -m paddle.distributed.launch --log_dir=ppyoloe_plus_distill_x_to_l/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml +``` + +- `-c`: 指定模型配置文件,也是student配置文件。 +- `--slim_config`: 指定压缩策略配置文件,也是teacher配置文件。 + +### 评估 +```shell +python tools/eval.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams +``` + +- `-c`: 指定模型配置文件,也是student配置文件。 +- `--slim_config`: 指定压缩策略配置文件,也是teacher配置文件。 +- `-o weights`: 指定压缩算法训好的模型路径。 + +### 测试 +```shell +python tools/infer.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams --infer_img=demo/000000014439_640x640.jpg +``` + +- `-c`: 指定模型配置文件。 +- `--slim_config`: 指定压缩策略配置文件。 +- `-o weights`: 指定压缩算法训好的模型路径。 +- `--infer_img`: 指定测试图像路径。 diff --git a/configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco.yml 
b/configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml similarity index 54% rename from configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco.yml rename to configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml index ffb4af2e23e..edd2e744a53 100644 --- a/configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco.yml +++ b/configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml @@ -1,13 +1,33 @@ _BASE_: [ '../ppyoloe_plus_crn_l_80e_coco.yml', ] +for_distill: True architecture: PPYOLOE PPYOLOE: backbone: CSPResNet neck: CustomCSPPAN yolo_head: PPYOLOEHead post_process: ~ - for_distill: True + #for_distill: True + +epoch: 80 +LearningRate: + base_lr: 0.001 + schedulers: + - !CosineDecay + max_epochs: 96 + - !LinearWarmup + start_factor: 0. + epochs: 5 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + clip_grad_by_norm: 30. log_iter: 100 diff --git a/configs/ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco.yml b/configs/ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco_distill.yml similarity index 54% rename from configs/ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco.yml rename to configs/ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco_distill.yml index 63e95a706ad..ff03c24ee76 100644 --- a/configs/ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco.yml +++ b/configs/ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco_distill.yml @@ -1,13 +1,33 @@ _BASE_: [ '../ppyoloe_plus_crn_m_80e_coco.yml', ] +for_distill: True architecture: PPYOLOE PPYOLOE: backbone: CSPResNet neck: CustomCSPPAN yolo_head: PPYOLOEHead post_process: ~ - for_distill: True + #for_distill: True + +epoch: 80 +LearningRate: + base_lr: 0.001 + schedulers: + - !CosineDecay + max_epochs: 96 + - !LinearWarmup + start_factor: 0. + epochs: 5 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + clip_grad_by_norm: 30. log_iter: 100 diff --git a/configs/ppyoloe/distill/ppyoloe_plus_crn_s_80e_coco.yml b/configs/ppyoloe/distill/ppyoloe_plus_crn_s_80e_coco_distill.yml similarity index 54% rename from configs/ppyoloe/distill/ppyoloe_plus_crn_s_80e_coco.yml rename to configs/ppyoloe/distill/ppyoloe_plus_crn_s_80e_coco_distill.yml index 0c13205d0de..6aec0e41fa7 100644 --- a/configs/ppyoloe/distill/ppyoloe_plus_crn_s_80e_coco.yml +++ b/configs/ppyoloe/distill/ppyoloe_plus_crn_s_80e_coco_distill.yml @@ -1,13 +1,33 @@ _BASE_: [ '../ppyoloe_plus_crn_s_80e_coco.yml', ] +for_distill: True architecture: PPYOLOE PPYOLOE: backbone: CSPResNet neck: CustomCSPPAN yolo_head: PPYOLOEHead post_process: ~ - for_distill: True + #for_distill: True + +epoch: 80 +LearningRate: + base_lr: 0.001 + schedulers: + - !CosineDecay + max_epochs: 96 + - !LinearWarmup + start_factor: 0. + epochs: 5 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + clip_grad_by_norm: 30. 
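For intuition about the schedule the distill student configs above add (base_lr 0.001, half-cosine decay with max_epochs 96, 5-epoch linear warmup from factor 0), here is an epoch-level approximation. It is editorial: the exact step-level behaviour is defined by ppdet's `CosineDecay`/`LinearWarmup` schedulers, and the formulas below are only the usual half-cosine and linear-warmup shapes assumed for illustration. Note that `max_epochs: 96` exceeds the 80 training epochs, so the learning rate never fully decays to zero during training.

```python
# Rough, epoch-level sketch of the configured schedule (assumed standard
# half-cosine decay plus linear warmup; real ppdet schedulers work per step).
import math


def lr_at(epoch, base_lr=0.001, max_epochs=96, warmup_epochs=5, start_factor=0.0):
    cosine = base_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / max_epochs))
    if epoch < warmup_epochs:
        alpha = epoch / warmup_epochs
        return cosine * (start_factor + (1.0 - start_factor) * alpha)
    return cosine


print([round(lr_at(e), 6) for e in (0, 5, 40, 80)])
# e.g. [0.0, 0.000993, 0.000629, 6.7e-05]
```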
log_iter: 100 diff --git a/configs/slim/distill/README.md b/configs/slim/distill/README.md index 159556a8069..d3f16135adf 100644 --- a/configs/slim/distill/README.md +++ b/configs/slim/distill/README.md @@ -38,6 +38,42 @@ CWD全称为[Channel-wise Knowledge Distillation for Dense Prediction*](https:// |gfl_r50_fpn_1x| student | 41.0 |[download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_1x_coco.pdparams) | |gfl_r50_fpn_2x + CWD| student | 44.0 |[download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_2x_coco_cwd.pdparams) | +## PPYOLOE+模型蒸馏 + + + +## 快速开始 + +### 训练 +```shell +# 单卡 +python tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml +# 多卡 +python3.7 -m paddle.distributed.launch --log_dir=ppyoloe_plus_distill_x_to_l/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml +``` + +- `-c`: 指定模型配置文件,也是student配置文件。 +- `--slim_config`: 指定压缩策略配置文件,也是teacher配置文件。 + +### 评估 +```shell +python tools/eval.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams +``` + +- `-c`: 指定模型配置文件,也是student配置文件。 +- `--slim_config`: 指定压缩策略配置文件,也是teacher配置文件。 +- `-o weights`: 指定压缩算法训好的模型路径。 + +### 测试 +```shell +python tools/infer.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams --infer_img=demo/000000014439_640x640.jpg +``` + +- `-c`: 指定模型配置文件。 +- `--slim_config`: 指定压缩策略配置文件。 +- `-o weights`: 指定压缩算法训好的模型路径。 +- `--infer_img`: 指定测试图像路径。 + ## Citations ``` diff --git a/configs/ppyoloe/distill/ppyoloe_plus_distill_l_to_m.yml b/configs/slim/distill/ppyoloe_plus_distill_l_to_m.yml similarity index 71% rename from configs/ppyoloe/distill/ppyoloe_plus_distill_l_to_m.yml rename to configs/slim/distill/ppyoloe_plus_distill_l_to_m.yml index 74299c7dd5d..46e0346d476 100644 --- a/configs/ppyoloe/distill/ppyoloe_plus_distill_l_to_m.yml +++ b/configs/slim/distill/ppyoloe_plus_distill_l_to_m.yml @@ -1,6 +1,6 @@ -# teacher config +# teacher and slim config _BASE_: [ - '../ppyoloe_plus_crn_l_80e_coco.yml', + '../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml', ] depth_mult: 1.0 width_mult: 1.0 @@ -15,6 +15,7 @@ PPYOLOE: pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco.pdparams find_unused_parameters: True +for_distill: True slim: Distill @@ -26,7 +27,8 @@ DistillPPYOLOELoss: # L -> M logits_distill: True logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5} feat_distill: True - feat_distiller: 'cwd' + feat_distiller: 'fgd' # ['cwd', 'fgd', 'pkd', 'mgd', 'mimic'] + feat_distill_place: 'neck_feats' teacher_width_mult: 1.0 # L student_width_mult: 0.75 # M - neck_out_channels: [768, 384, 192] # The actual channel will multiply width_mult + feat_out_channels: [768, 384, 192] # The actual channel will multiply width_mult diff --git a/configs/ppyoloe/distill/ppyoloe_plus_distill_m_to_s.yml b/configs/slim/distill/ppyoloe_plus_distill_m_to_s.yml similarity index 71% rename from configs/ppyoloe/distill/ppyoloe_plus_distill_m_to_s.yml rename to configs/slim/distill/ppyoloe_plus_distill_m_to_s.yml index cb80edfe050..46747b19477 100644 --- a/configs/ppyoloe/distill/ppyoloe_plus_distill_m_to_s.yml +++ b/configs/slim/distill/ppyoloe_plus_distill_m_to_s.yml @@ -1,6 +1,6 @@ -# teacher config +# teacher and 
slim config _BASE_: [ - '../ppyoloe_plus_crn_l_80e_coco.yml', + '../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml', ] depth_mult: 0.67 width_mult: 0.75 @@ -15,6 +15,7 @@ PPYOLOE: pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_m_80e_coco.pdparams find_unused_parameters: True +for_distill: True slim: Distill @@ -26,7 +27,8 @@ DistillPPYOLOELoss: # M -> S logits_distill: True logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5} feat_distill: True - feat_distiller: 'cwd' + feat_distiller: 'fgd' # ['cwd', 'fgd', 'pkd', 'mgd', 'mimic'] + feat_distill_place: 'neck_feats' teacher_width_mult: 0.75 # M student_width_mult: 0.5 # S - neck_out_channels: [768, 384, 192] # The actual channel will multiply width_mult + feat_out_channels: [768, 384, 192] # The actual channel will multiply width_mult diff --git a/configs/ppyoloe/distill/ppyoloe_plus_distill_x_to_l.yml b/configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml similarity index 71% rename from configs/ppyoloe/distill/ppyoloe_plus_distill_x_to_l.yml rename to configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml index 2e0c44ddf78..01512aa984a 100644 --- a/configs/ppyoloe/distill/ppyoloe_plus_distill_x_to_l.yml +++ b/configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml @@ -1,6 +1,6 @@ -# teacher config +# teacher and slim config _BASE_: [ - '../ppyoloe_plus_crn_x_80e_coco.yml', + '../../ppyoloe/ppyoloe_plus_crn_x_80e_coco.yml', ] depth_mult: 1.33 width_mult: 1.25 @@ -15,6 +15,7 @@ PPYOLOE: pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_x_80e_coco.pdparams find_unused_parameters: True +for_distill: True slim: Distill @@ -26,7 +27,8 @@ DistillPPYOLOELoss: # X -> L logits_distill: True logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5} feat_distill: True - feat_distiller: 'cwd' + feat_distiller: 'fgd' # ['cwd', 'fgd', 'pkd', 'mgd', 'mimic'] + feat_distill_place: 'neck_feats' teacher_width_mult: 1.25 # X student_width_mult: 1.0 # L - neck_out_channels: [768, 384, 192] # The actual channel will multiply width_mult + feat_out_channels: [768, 384, 192] # The actual channel will multiply width_mult diff --git a/ppdet/modeling/architectures/ppyoloe.py b/ppdet/modeling/architectures/ppyoloe.py index a48d18c0acd..f7f32a0dab6 100644 --- a/ppdet/modeling/architectures/ppyoloe.py +++ b/ppdet/modeling/architectures/ppyoloe.py @@ -20,15 +20,17 @@ import copy from ppdet.core.workspace import register, create from .meta_arch import BaseArch +from IPython import embed __all__ = ['PPYOLOE', 'PPYOLOEWithAuxHead'] -# PP-YOLOE and PP-YOLOE+ are recommended to use this architecture -# PP-YOLOE and PP-YOLOE+ can also use the same architecture of YOLOv3 in yolo.py +# PP-YOLOE and PP-YOLOE+ are recommended to use this architecture, especially when use distillation or aux head +# PP-YOLOE and PP-YOLOE+ can also use the same architecture of YOLOv3 in yolo.py when not use distillation or aux head @register class PPYOLOE(BaseArch): __category__ = 'architecture' + __shared__ = ['for_distill'] __inject__ = ['post_process'] def __init__(self, @@ -37,6 +39,7 @@ def __init__(self, yolo_head='PPYOLOEHead', post_process='BBoxPostProcess', for_distill=False, + feat_distill_place='neck_feats', for_mot=False): """ PPYOLOE network, see https://arxiv.org/abs/2203.16250 @@ -56,6 +59,9 @@ def __init__(self, self.post_process = post_process self.for_mot = for_mot self.for_distill = for_distill + self.feat_distill_place = feat_distill_place + if for_distill: + assert feat_distill_place in ['backbone_feats', 'neck_feats'] 
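For orientation, the `teacher_width_mult` / `student_width_mult` / `feat_out_channels` fields in the slim configs above translate into per-FPN-level channel counts inside `DistillPPYOLOELoss` (see the `distill_loss.py` hunk later in this patch). A quick worked example of that scaling:

```python
# Worked example of how DistillPPYOLOELoss derives per-level channel counts
# from the config values shown above (X->L, L->M and M->S teacher/student pairs).
feat_out_channels = [768, 384, 192]          # base CustomCSPPAN output channels


def scaled(width_mult):
    return [int(c * width_mult) for c in feat_out_channels]


print(scaled(1.25), scaled(1.0))   # X teacher [960, 480, 240] -> L student [768, 384, 192]
print(scaled(1.0), scaled(0.75))   # L teacher [768, 384, 192] -> M student [576, 288, 144]
print(scaled(0.75), scaled(0.5))   # M teacher [576, 288, 144] -> S student [384, 192, 96]
```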
@classmethod def from_config(cls, cfg, *args, **kwargs): @@ -84,10 +90,13 @@ def _forward(self): yolo_losses = self.yolo_head(neck_feats, self.inputs) if self.for_distill: - self.yolo_head.distill_pairs['emb_feats'] = neck_feats - return {'det_losses': yolo_losses, 'emb_feats': neck_feats} - else: - return yolo_losses + if self.feat_distill_place == 'backbone_feats': + self.yolo_head.distill_pairs['backbone_feats'] = body_feats + elif self.feat_distill_place == 'neck_feats': + self.yolo_head.distill_pairs['neck_feats'] = neck_feats + else: + raise ValueError + return yolo_losses else: yolo_head_outs = self.yolo_head(neck_feats) if self.post_process is not None: diff --git a/ppdet/modeling/architectures/yolo.py b/ppdet/modeling/architectures/yolo.py index 78c7654913c..237fcfd42de 100644 --- a/ppdet/modeling/architectures/yolo.py +++ b/ppdet/modeling/architectures/yolo.py @@ -22,7 +22,7 @@ __all__ = ['YOLOv3'] # YOLOv3,PP-YOLO,PP-YOLOv2,PP-YOLOE,PP-YOLOE+ use the same architecture as YOLOv3 -# PP-YOLOE and PP-YOLOE+ are recommended to use PPYOLOE architecture in ppyoloe.py +# PP-YOLOE and PP-YOLOE+ are recommended to use PPYOLOE architecture in ppyoloe.py, especially when use distillation or aux head @register diff --git a/ppdet/modeling/heads/pico_head.py b/ppdet/modeling/heads/pico_head.py index adcd05fc6b2..e5232239910 100644 --- a/ppdet/modeling/heads/pico_head.py +++ b/ppdet/modeling/heads/pico_head.py @@ -651,7 +651,7 @@ def get_loss(self, head_outs, gt_meta): # label assignment if gt_meta['epoch_id'] < self.static_assigner_epoch: - assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner( + assigned_labels, assigned_bboxes, assigned_scores, _ = self.static_assigner( anchors, num_anchors_list, gt_labels, @@ -662,7 +662,7 @@ def get_loss(self, head_outs, gt_meta): pred_bboxes=pred_bboxes.detach() * stride_tensor_list) else: - assigned_labels, assigned_bboxes, assigned_scores = self.assigner( + assigned_labels, assigned_bboxes, assigned_scores, _ = self.assigner( pred_scores.detach(), pred_bboxes.detach() * stride_tensor_list, centers, diff --git a/ppdet/modeling/heads/ppyoloe_contrast_head.py b/ppdet/modeling/heads/ppyoloe_contrast_head.py index 190c519cfe5..4f80ea9c71f 100644 --- a/ppdet/modeling/heads/ppyoloe_contrast_head.py +++ b/ppdet/modeling/heads/ppyoloe_contrast_head.py @@ -136,7 +136,7 @@ def get_loss(self, head_outs, gt_meta): pad_gt_mask = gt_meta['pad_gt_mask'] # label assignment if gt_meta['epoch_id'] < self.static_assigner_epoch: - assigned_labels, assigned_bboxes, assigned_scores = \ + assigned_labels, assigned_bboxes, assigned_scores, _ = \ self.static_assigner( anchors, num_anchors_list, @@ -148,7 +148,7 @@ def get_loss(self, head_outs, gt_meta): alpha_l = 0.25 else: if self.sm_use: - assigned_labels, assigned_bboxes, assigned_scores = \ + assigned_labels, assigned_bboxes, assigned_scores, _ = \ self.assigner( pred_scores.detach(), pred_bboxes.detach() * stride_tensor, @@ -159,7 +159,7 @@ def get_loss(self, head_outs, gt_meta): pad_gt_mask, bg_index=self.num_classes) else: - assigned_labels, assigned_bboxes, assigned_scores = \ + assigned_labels, assigned_bboxes, assigned_scores, _ = \ self.assigner( pred_scores.detach(), pred_bboxes.detach() * stride_tensor, diff --git a/ppdet/modeling/heads/ppyoloe_head.py b/ppdet/modeling/heads/ppyoloe_head.py index 93855387869..83001082ba6 100644 --- a/ppdet/modeling/heads/ppyoloe_head.py +++ b/ppdet/modeling/heads/ppyoloe_head.py @@ -53,7 +53,7 @@ def forward(self, feat, avg_feat): class PPYOLOEHead(nn.Layer): 
__shared__ = [ 'num_classes', 'eval_size', 'trt', 'exclude_nms', - 'exclude_post_process', 'use_shared_conv' + 'exclude_post_process', 'use_shared_conv', 'for_distill' ] __inject__ = ['static_assigner', 'assigner', 'nms'] @@ -81,7 +81,8 @@ def __init__(self, attn_conv='convbn', exclude_nms=False, exclude_post_process=False, - use_shared_conv=True): + use_shared_conv=True, + for_distill=False): super(PPYOLOEHead, self).__init__() assert len(in_channels) > 0, "len(in_channels) should > 0" self.in_channels = in_channels @@ -110,6 +111,7 @@ def __init__(self, self.exclude_nms = exclude_nms self.exclude_post_process = exclude_post_process self.use_shared_conv = use_shared_conv + self.for_distill = for_distill # stem self.stem_cls = nn.LayerList() @@ -134,7 +136,9 @@ def __init__(self, self.proj_conv = nn.Conv2D(self.reg_channels, 1, 1, bias_attr=False) self.proj_conv.skip_quant = True self._init_weights() - self.distill_pairs = {} + + if self.for_distill: + self.distill_pairs = {} @classmethod def from_config(cls, cfg, input_shape): @@ -322,14 +326,14 @@ def _bbox_loss(self, pred_dist, pred_bboxes, anchor_points, assigned_labels, loss_dfl = self._df_loss(pred_dist_pos, assigned_ltrb_pos, self.reg_range[0]) * bbox_weight loss_dfl = loss_dfl.sum() / assigned_scores_sum - self.distill_pairs['pred_bboxes_pos'] = pred_bboxes_pos - self.distill_pairs['pred_dist_pos'] = pred_dist_pos - self.distill_pairs['bbox_weight'] = bbox_weight + if self.for_distill: + self.distill_pairs['pred_bboxes_pos'] = pred_bboxes_pos + self.distill_pairs['pred_dist_pos'] = pred_dist_pos + self.distill_pairs['bbox_weight'] = bbox_weight else: loss_l1 = paddle.zeros([1]) loss_iou = paddle.zeros([1]) loss_dfl = pred_dist.sum() * 0. - self.distill_pairs['null_loss'] = pred_dist.sum() * 0. return loss_l1, loss_iou, loss_dfl def get_loss(self, head_outs, gt_meta, aux_pred=None): @@ -445,13 +449,14 @@ def get_loss_from_assign(self, pred_scores, pred_distri, pred_bboxes, assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.) 
loss_cls /= assigned_scores_sum - self.distill_pairs['pred_cls_scores'] = pred_scores - self.distill_pairs['pos_num'] = assigned_scores_sum - self.distill_pairs['assigned_scores'] = assigned_scores - self.distill_pairs['mask_positive'] = mask_positive - one_hot_label = F.one_hot(assigned_labels, - self.num_classes + 1)[..., :-1] - self.distill_pairs['target_labels'] = one_hot_label + if self.for_distill: + self.distill_pairs['pred_cls_scores'] = pred_scores + self.distill_pairs['pos_num'] = assigned_scores_sum + self.distill_pairs['assigned_scores'] = assigned_scores + self.distill_pairs['mask_positive'] = mask_positive + one_hot_label = F.one_hot(assigned_labels, + self.num_classes + 1)[..., :-1] + self.distill_pairs['target_labels'] = one_hot_label loss_l1, loss_iou, loss_dfl = \ self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s, diff --git a/ppdet/modeling/heads/ppyoloe_r_head.py b/ppdet/modeling/heads/ppyoloe_r_head.py index aaf21063204..9adbffa09c6 100644 --- a/ppdet/modeling/heads/ppyoloe_r_head.py +++ b/ppdet/modeling/heads/ppyoloe_r_head.py @@ -258,7 +258,7 @@ def get_loss(self, head_outs, gt_meta): pad_gt_mask = gt_meta['pad_gt_mask'] # label assignment if gt_meta['epoch_id'] < self.static_assigner_epoch: - assigned_labels, assigned_bboxes, assigned_scores = \ + assigned_labels, assigned_bboxes, assigned_scores, _ = \ self.static_assigner( anchor_points, stride_tensor, @@ -271,7 +271,7 @@ def get_loss(self, head_outs, gt_meta): pred_bboxes.detach() ) else: - assigned_labels, assigned_bboxes, assigned_scores = \ + assigned_labels, assigned_bboxes, assigned_scores, _ = \ self.assigner( pred_scores.detach(), pred_bboxes.detach(), diff --git a/ppdet/modeling/heads/tood_head.py b/ppdet/modeling/heads/tood_head.py index f463ef2397b..81b2edd7b72 100644 --- a/ppdet/modeling/heads/tood_head.py +++ b/ppdet/modeling/heads/tood_head.py @@ -293,7 +293,7 @@ def get_loss(self, head_outs, gt_meta): pad_gt_mask = gt_meta['pad_gt_mask'] # label assignment if gt_meta['epoch_id'] < self.static_assigner_epoch: - assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner( + assigned_labels, assigned_bboxes, assigned_scores, _ = self.static_assigner( anchors, num_anchors_list, gt_labels, @@ -302,7 +302,7 @@ def get_loss(self, head_outs, gt_meta): bg_index=self.num_classes) alpha_l = 0.25 else: - assigned_labels, assigned_bboxes, assigned_scores = self.assigner( + assigned_labels, assigned_bboxes, assigned_scores, _ = self.assigner( pred_scores.detach(), pred_bboxes.detach() * stride_tensor, bbox_center(anchors), diff --git a/ppdet/slim/distill_loss.py b/ppdet/slim/distill_loss.py index f91add8d72f..a0539277fbe 100644 --- a/ppdet/slim/distill_loss.py +++ b/ppdet/slim/distill_loss.py @@ -140,11 +140,7 @@ def knowledge_distillation_kl_div_loss(self, soft_label (Tensor): Target logits with shape (N, N + 1). T (int): Temperature for distillation. detach_target (bool): Remove soft_label from automatic differentiation - - Returns: - torch.Tensor: Loss tensor with shape (N,). 
""" - assert pred.shape == soft_label.shape target = F.softmax(soft_label / T, axis=1) if detach_target: @@ -217,10 +213,11 @@ def __init__( 'iou': 2.5, 'dfl': 0.5}, feat_distill=True, - feat_distiller='cwd', + feat_distiller='fgd', + feat_distill_place='neck_feats', teacher_width_mult=1.0, # L student_width_mult=0.75, # M - neck_out_channels=[768, 384, 192]): + feat_out_channels=[768, 384, 192]): super(DistillPPYOLOELoss, self).__init__() self.loss_weight_logits = loss_weight['logits'] self.loss_weight_feat = loss_weight['feat'] @@ -234,15 +231,17 @@ def __init__( self.loss_bbox = GIoULoss() if feat_distill and self.loss_weight_feat > 0: - assert feat_distiller in ['cwd', 'fgd', 'pkd', 'mgd'] - self.distill_feat_loss_modules = [] + assert feat_distiller in ['cwd', 'fgd', 'pkd', 'mgd', 'mimic'] + assert feat_distill_place in ['backbone_feats', 'neck_feats'] + self.feat_distill_place = feat_distill_place self.t_channel_list = [ - int(c * teacher_width_mult) for c in neck_out_channels + int(c * teacher_width_mult) for c in feat_out_channels ] self.s_channel_list = [ - int(c * student_width_mult) for c in neck_out_channels + int(c * student_width_mult) for c in feat_out_channels ] - for i in range(len(neck_out_channels)): + self.distill_feat_loss_modules = [] + for i in range(len(feat_out_channels)): if feat_distiller == 'cwd': feat_loss_module = CWDFeatureLoss( student_channels=self.s_channel_list[i], @@ -262,30 +261,28 @@ def __init__( student_channels=self.s_channel_list[i], teacher_channels=self.t_channel_list[i], normalize=True, - resize_stu=False) + resize_stu=True) elif feat_distiller == 'mgd': feat_loss_module = MGDFeatureLoss( student_channels=self.s_channel_list[i], teacher_channels=self.t_channel_list[i], normalize=True, loss_func='ssim') + elif feat_distiller == 'mimic': + feat_loss_module = MimicFeatureLoss( + student_channels=self.s_channel_list[i], + teacher_channels=self.t_channel_list[i], + normalize=True) else: raise ValueError self.distill_feat_loss_modules.append(feat_loss_module) - def bbox_loss(self, s_bbox, t_bbox, weight_targets=None): - # [x,y,w,h] - if weight_targets is not None: - loss_bbox = paddle.sum( - self.loss_bbox(s_bbox, t_bbox) * weight_targets) - avg_factor = weight_targets.sum() - loss_bbox = loss_bbox / avg_factor - else: - loss_bbox = paddle.mean(self.loss_bbox(s_bbox, t_bbox)) - return loss_bbox - - def quality_focal_loss(self, pred_logits, soft_target_logits, beta=2.0, \ - use_sigmoid=True, label_weights=None, num_total_pos=None, pos_mask=None): + def quality_focal_loss(self, + pred_logits, + soft_target_logits, + beta=2.0, + use_sigmoid=False, + num_total_pos=None): if use_sigmoid: func = F.binary_cross_entropy_with_logits soft_target = F.sigmoid(soft_target_logits) @@ -300,42 +297,60 @@ def quality_focal_loss(self, pred_logits, soft_target_logits, beta=2.0, \ scale_factor = pred_sigmoid - soft_target loss = func( preds, soft_target, reduction='none') * scale_factor.abs().pow(beta) - loss = loss - if pos_mask is not None: - loss *= pos_mask - loss = loss.sum(1) - if label_weights is not None: - loss = loss * label_weights + if num_total_pos is not None: loss = loss.sum() / num_total_pos else: loss = loss.mean() return loss - def distribution_focal_loss(self, pred_corners, target_corners, - weight_targets): - target_corners_label = paddle.nn.functional.softmax( - target_corners, axis=-1) - loss_dfl = paddle.nn.functional.cross_entropy( + def bbox_loss(self, s_bbox, t_bbox, weight_targets=None): + # [x,y,w,h] + if weight_targets is not None: + loss = 
paddle.sum(self.loss_bbox(s_bbox, t_bbox) * weight_targets) + avg_factor = weight_targets.sum() + loss = loss / avg_factor + else: + loss = paddle.mean(self.loss_bbox(s_bbox, t_bbox)) + return loss + + def distribution_focal_loss(self, + pred_corners, + target_corners, + weight_targets=None): + target_corners_label = F.softmax(target_corners, axis=-1) + loss_dfl = F.cross_entropy( pred_corners, target_corners_label, soft_label=True, reduction='none') loss_dfl = loss_dfl.sum(1) + if weight_targets is not None: loss_dfl = loss_dfl * (weight_targets.expand([-1, 4]).reshape([-1])) loss_dfl = loss_dfl.sum(-1) / weight_targets.sum() else: loss_dfl = loss_dfl.mean(-1) - loss_dfl = loss_dfl / 4.0 # 4 direction - return loss_dfl + return loss_dfl / 4.0 # 4 direction def forward(self, teacher_model, student_model): + teacher_distill_pairs = teacher_model.yolo_head.distill_pairs + student_distill_pairs = student_model.yolo_head.distill_pairs if self.logits_distill and self.loss_weight_logits > 0: - teacher_distill_pairs = teacher_model.yolo_head.distill_pairs - student_distill_pairs = student_model.yolo_head.distill_pairs distill_bbox_loss, distill_dfl_loss, distill_cls_loss = [], [], [] + + distill_cls_loss.append( + self.quality_focal_loss( + student_distill_pairs['pred_cls_scores'].reshape( + (-1, student_distill_pairs['pred_cls_scores'].shape[-1] + )), + teacher_distill_pairs['pred_cls_scores'].detach().reshape( + (-1, teacher_distill_pairs['pred_cls_scores'].shape[-1] + )), + num_total_pos=student_distill_pairs['pos_num'], + use_sigmoid=False)) + distill_bbox_loss.append( self.bbox_loss(student_distill_pairs['pred_bboxes_pos'], teacher_distill_pairs['pred_bboxes_pos'].detach(), @@ -343,48 +358,40 @@ def forward(self, teacher_model, student_model): ) if 'pred_bboxes_pos' in student_distill_pairs and \ 'pred_bboxes_pos' in teacher_distill_pairs and \ 'bbox_weight' in student_distill_pairs - else student_distill_pairs['null_loss'] - ) - distill_dfl_loss.append(self.distribution_focal_loss( + else paddle.zeros([1])) + + distill_dfl_loss.append( + self.distribution_focal_loss( student_distill_pairs['pred_dist_pos'].reshape((-1, student_distill_pairs['pred_dist_pos'].shape[-1])), teacher_distill_pairs['pred_dist_pos'].detach().reshape((-1, teacher_distill_pairs['pred_dist_pos'].shape[-1])), \ weight_targets=student_distill_pairs['bbox_weight'] ) if 'pred_dist_pos' in student_distill_pairs and \ 'pred_dist_pos' in teacher_distill_pairs and \ 'bbox_weight' in student_distill_pairs - else student_distill_pairs['null_loss'] - ) - distill_cls_loss.append( - self.quality_focal_loss( - student_distill_pairs['pred_cls_scores'].reshape( - (-1, student_distill_pairs['pred_cls_scores'].shape[-1] - )), - teacher_distill_pairs['pred_cls_scores'].detach().reshape( - (-1, teacher_distill_pairs['pred_cls_scores'].shape[-1] - )), - num_total_pos=student_distill_pairs['pos_num'], - use_sigmoid=False)) - distill_bbox_loss = paddle.add_n(distill_bbox_loss) + else paddle.zeros([1])) + distill_cls_loss = paddle.add_n(distill_cls_loss) + distill_bbox_loss = paddle.add_n(distill_bbox_loss) distill_dfl_loss = paddle.add_n(distill_dfl_loss) logits_loss = distill_bbox_loss * self.bbox_loss_weight + distill_cls_loss * self.qfl_loss_weight + distill_dfl_loss * self.dfl_loss_weight else: - logits_loss = paddle.to_tensor([0]) + logits_loss = paddle.zeros([1]) if self.feat_distill and self.loss_weight_feat > 0: feat_loss_list = [] inputs = student_model.inputs - teacher_fpn_feats = teacher_distill_pairs['emb_feats'] - 
student_fpn_feats = student_distill_pairs['emb_feats'] assert 'gt_bbox' in inputs + assert self.feat_distill_place in student_distill_pairs + assert self.feat_distill_place in teacher_distill_pairs + stu_feats = student_distill_pairs[self.feat_distill_place] + tea_feats = teacher_distill_pairs[self.feat_distill_place] for i, loss_module in enumerate(self.distill_feat_loss_modules): feat_loss_list.append( - loss_module(student_fpn_feats[i], teacher_fpn_feats[i], - inputs)) + loss_module(stu_feats[i], tea_feats[i], inputs)) feat_loss = paddle.add_n(feat_loss_list) else: - feat_loss = paddle.to_tensor([0]) + feat_loss = paddle.zeros([1]) student_model.yolo_head.distill_pairs.clear() teacher_model.yolo_head.distill_pairs.clear() @@ -714,15 +721,44 @@ def __init__(self, self.loss_weight = loss_weight self.resize_stu = resize_stu - kaiming_init = parameter_init("kaiming") + def forward(self, stu_feature, tea_feature, inputs): + size_s, size_t = stu_feature.shape[2:], tea_feature.shape[2:] + if size_s[0] != size_t[0]: + if self.resize_stu: + stu_feature = F.interpolate( + stu_feature, size_t, mode='bilinear') + else: + tea_feature = F.interpolate( + tea_feature, size_s, mode='bilinear') + assert stu_feature.shape == tea_feature.shape + + if self.normalize: + stu_feature = feature_norm(stu_feature) + tea_feature = feature_norm(tea_feature) + + loss = F.mse_loss(stu_feature, tea_feature) / 2 + return loss * self.loss_weight + + +@register +class MimicFeatureLoss(nn.Layer): + def __init__(self, + student_channels=256, + teacher_channels=256, + normalize=True, + loss_weight=1.0): + super(MimicFeatureLoss, self).__init__() + self.normalize = normalize + self.loss_weight = loss_weight + self.mse_loss = nn.MSELoss() + if student_channels != teacher_channels: self.align = nn.Conv2D( student_channels, teacher_channels, kernel_size=1, stride=1, - padding=0, - weight_attr=kaiming_init) + padding=0) else: self.align = None @@ -730,26 +766,11 @@ def forward(self, stu_feature, tea_feature, inputs): if self.align is not None: stu_feature = self.align(stu_feature) - loss = 0. - size_s, size_t = stu_feature.shape[2:], tea_feature.shape[2:] - if size_s[0] != size_t[0]: - if self.resize_stu: - stu_feature = F.interpolate( - stu_feature, size_t, mode='bilinear') - else: - tea_feature = F.interpolate( - tea_feature, size_s, mode='bilinear') - assert stu_feature.shape == tea_feature.shape - if self.normalize: - norm_stu_feature = feature_norm(stu_feature) - norm_tea_feature = feature_norm(tea_feature) - - # First conduct feature normalization and then calculate the - # MSE loss. Methematically, it is equivalent to firstly calculate - # the Pearson Correlation Coefficient (r) between two feature - # vectors, and then use 1-r as the new feature imitation loss. 
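The comment being removed above (and the behaviour kept by the new `PKDFeatureLoss`) states that taking the MSE of mean/variance-normalized features is equivalent to using 1 − Pearson correlation as the imitation loss. A small self-contained numpy check of that identity, added here for the reader and not part of the patch:

```python
# Numerical check: for zero-mean, unit-variance vectors, mean((x - y)^2) / 2
# equals 1 - Pearson r, which is exactly what PKD relies on.
import numpy as np

rng = np.random.default_rng(0)
x, y = rng.normal(size=256), rng.normal(size=256)

xn = (x - x.mean()) / (x.std() + 1e-6)
yn = (y - y.mean()) / (y.std() + 1e-6)

mse_half = np.mean((xn - yn) ** 2) / 2
pearson = np.corrcoef(x, y)[0, 1]
print(mse_half, 1.0 - pearson)   # the two values agree up to numerical precision
```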
- loss += F.mse_loss(norm_stu_feature, norm_tea_feature) / 2 + stu_feature = feature_norm(stu_feature) + tea_feature = feature_norm(tea_feature) + + loss = self.mse_loss(stu_feature, tea_feature) return loss * self.loss_weight diff --git a/ppdet/slim/distill_model.py b/ppdet/slim/distill_model.py index f6658acbbb9..85d7be4233d 100644 --- a/ppdet/slim/distill_model.py +++ b/ppdet/slim/distill_model.py @@ -338,15 +338,13 @@ def forward(self, inputs, alpha=0.125): delattr(self.teacher_model.yolo_head, "assigned_bboxes") delattr(self.teacher_model.yolo_head, "assigned_scores") delattr(self.teacher_model.yolo_head, "mask_positive") - student_out = self.student_model(inputs) + student_loss = self.student_model(inputs) with paddle.no_grad(): - teacher_out = self.teacher_model(inputs) + teacher_loss = self.teacher_model(inputs) logits_loss, feat_loss = self.distill_loss(self.teacher_model, self.student_model) - student_loss = student_out['det_losses'] det_total_loss = student_loss['loss'] - total_loss = alpha * (det_total_loss + logits_loss + feat_loss) student_loss['loss'] = total_loss student_loss['det_loss'] = det_total_loss From 6413df71cb95ce69cea9641eeafca02a96b8e499 Mon Sep 17 00:00:00 2001 From: nemonameless Date: Thu, 2 Feb 2023 11:44:05 +0000 Subject: [PATCH 4/5] clean codes --- .../ppyoloe_plus_crn_l_80e_coco_distill.yml | 20 ------------------- .../ppyoloe_plus_crn_m_80e_coco_distill.yml | 20 ------------------- .../ppyoloe_plus_crn_s_80e_coco_distill.yml | 20 ------------------- ppdet/modeling/architectures/ppyoloe.py | 1 - ppdet/slim/distill_model.py | 1 - 5 files changed, 62 deletions(-) diff --git a/configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml b/configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml index edd2e744a53..a75e5857bad 100644 --- a/configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml +++ b/configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml @@ -8,26 +8,6 @@ PPYOLOE: neck: CustomCSPPAN yolo_head: PPYOLOEHead post_process: ~ - #for_distill: True - -epoch: 80 -LearningRate: - base_lr: 0.001 - schedulers: - - !CosineDecay - max_epochs: 96 - - !LinearWarmup - start_factor: 0. - epochs: 5 - -OptimizerBuilder: - optimizer: - momentum: 0.9 - type: Momentum - regularizer: - factor: 0.0005 - type: L2 - clip_grad_by_norm: 30. log_iter: 100 diff --git a/configs/ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco_distill.yml b/configs/ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco_distill.yml index ff03c24ee76..5838110fe3a 100644 --- a/configs/ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco_distill.yml +++ b/configs/ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco_distill.yml @@ -8,26 +8,6 @@ PPYOLOE: neck: CustomCSPPAN yolo_head: PPYOLOEHead post_process: ~ - #for_distill: True - -epoch: 80 -LearningRate: - base_lr: 0.001 - schedulers: - - !CosineDecay - max_epochs: 96 - - !LinearWarmup - start_factor: 0. - epochs: 5 - -OptimizerBuilder: - optimizer: - momentum: 0.9 - type: Momentum - regularizer: - factor: 0.0005 - type: L2 - clip_grad_by_norm: 30. 
 log_iter: 100
diff --git a/configs/ppyoloe/distill/ppyoloe_plus_crn_s_80e_coco_distill.yml b/configs/ppyoloe/distill/ppyoloe_plus_crn_s_80e_coco_distill.yml
index 6aec0e41fa7..45d281378e5 100644
--- a/configs/ppyoloe/distill/ppyoloe_plus_crn_s_80e_coco_distill.yml
+++ b/configs/ppyoloe/distill/ppyoloe_plus_crn_s_80e_coco_distill.yml
@@ -8,26 +8,6 @@ PPYOLOE:
   neck: CustomCSPPAN
   yolo_head: PPYOLOEHead
   post_process: ~
-  #for_distill: True
-
-epoch: 80
-LearningRate:
-  base_lr: 0.001
-  schedulers:
-    - !CosineDecay
-      max_epochs: 96
-    - !LinearWarmup
-      start_factor: 0.
-      epochs: 5
-
-OptimizerBuilder:
-  optimizer:
-    momentum: 0.9
-    type: Momentum
-  regularizer:
-    factor: 0.0005
-    type: L2
-  clip_grad_by_norm: 30.
 
 
 log_iter: 100
diff --git a/ppdet/modeling/architectures/ppyoloe.py b/ppdet/modeling/architectures/ppyoloe.py
index f7f32a0dab6..8d9edeb2cb2 100644
--- a/ppdet/modeling/architectures/ppyoloe.py
+++ b/ppdet/modeling/architectures/ppyoloe.py
@@ -20,7 +20,6 @@
 import copy
 from ppdet.core.workspace import register, create
 from .meta_arch import BaseArch
-from IPython import embed
 
 __all__ = ['PPYOLOE', 'PPYOLOEWithAuxHead']
 # PP-YOLOE and PP-YOLOE+ are recommended to use this architecture, especially when use distillation or aux head
diff --git a/ppdet/slim/distill_model.py b/ppdet/slim/distill_model.py
index 85d7be4233d..6ca085c437a 100644
--- a/ppdet/slim/distill_model.py
+++ b/ppdet/slim/distill_model.py
@@ -20,7 +20,6 @@
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle import ParamAttr
-from IPython import embed
 
 from ppdet.core.workspace import register, create, load_config
 from ppdet.utils.checkpoint import load_pretrain_weight

From 967cf9b069dc826ac327f52a28b5f01235b3fd9c Mon Sep 17 00:00:00 2001
From: nemonameless
Date: Thu, 2 Feb 2023 13:16:10 +0000
Subject: [PATCH 5/5] merge cam, fix export

---
 ppdet/modeling/architectures/ppyoloe.py | 24 ++++++++++++++++++------
 ppdet/modeling/heads/ppyoloe_head.py    |  8 +++++---
 2 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/ppdet/modeling/architectures/ppyoloe.py b/ppdet/modeling/architectures/ppyoloe.py
index 8d9edeb2cb2..e646c30dbe3 100644
--- a/ppdet/modeling/architectures/ppyoloe.py
+++ b/ppdet/modeling/architectures/ppyoloe.py
@@ -97,15 +97,21 @@ def _forward(self):
                 raise ValueError
             return yolo_losses
         else:
+            cam_data = {}  # record bbox scores and index before nms
             yolo_head_outs = self.yolo_head(neck_feats)
+            cam_data['scores'] = yolo_head_outs[0]
+
             if self.post_process is not None:
-                bbox, bbox_num = self.post_process(
+                bbox, bbox_num, before_nms_indexes = self.post_process(
                     yolo_head_outs, self.yolo_head.mask_anchors,
                     self.inputs['im_shape'], self.inputs['scale_factor'])
+                cam_data['before_nms_indexes'] = before_nms_indexes
             else:
-                bbox, bbox_num = self.yolo_head.post_process(
+                bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
                     yolo_head_outs, self.inputs['scale_factor'])
-            output = {'bbox': bbox, 'bbox_num': bbox_num}
+                # data for cam
+                cam_data['before_nms_indexes'] = before_nms_indexes
+            output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
 
             return output
 
@@ -195,15 +201,21 @@ def _forward(self):
                 aux_pred=[aux_cls_scores, aux_bbox_preds])
             return loss
         else:
+            cam_data = {}  # record bbox scores and index before nms
             yolo_head_outs = self.yolo_head(neck_feats)
+            cam_data['scores'] = yolo_head_outs[0]
+
             if self.post_process is not None:
-                bbox, bbox_num = self.post_process(
+                bbox, bbox_num, before_nms_indexes = self.post_process(
                     yolo_head_outs, self.yolo_head.mask_anchors,
                     self.inputs['im_shape'], self.inputs['scale_factor'])
+                cam_data['before_nms_indexes'] = before_nms_indexes
             else:
-                bbox, bbox_num = self.yolo_head.post_process(
+                bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
                     yolo_head_outs, self.inputs['scale_factor'])
-            output = {'bbox': bbox, 'bbox_num': bbox_num}
+                # data for cam
+                cam_data['before_nms_indexes'] = before_nms_indexes
+            output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
 
             return output
diff --git a/ppdet/modeling/heads/ppyoloe_head.py b/ppdet/modeling/heads/ppyoloe_head.py
index 6c8c6fe8a78..38d4d541545 100644
--- a/ppdet/modeling/heads/ppyoloe_head.py
+++ b/ppdet/modeling/heads/ppyoloe_head.py
@@ -480,7 +480,8 @@ def post_process(self, head_outs, scale_factor):
         pred_bboxes *= stride_tensor
         if self.exclude_post_process:
             return paddle.concat(
-                [pred_bboxes, pred_scores.transpose([0, 2, 1])], axis=-1), None
+                [pred_bboxes, pred_scores.transpose([0, 2, 1])],
+                axis=-1), None, None
         else:
             # scale bbox to origin
@@ -490,9 +491,10 @@ def post_process(self, head_outs, scale_factor):
             pred_bboxes /= scale_factor
             if self.exclude_nms:
                 # `exclude_nms=True` just use in benchmark
-                return pred_bboxes, pred_scores
+                return pred_bboxes, pred_scores, None
             else:
-                bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes, pred_scores)
+                bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes,
+                                                                   pred_scores)
                 return bbox_pred, bbox_num, before_nms_indexes
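
Note on the feature-imitation loss touched above: the comment removed from the old forward() states that halving the MSE between feature-normalized (zero-mean, unit-variance) student and teacher features is the same as using 1 - r, where r is the Pearson correlation coefficient of the two flattened feature vectors; for standardized x and y of length n, (1/n)*sum((x_i - y_i)^2) = 2 - 2r. The snippet below is a minimal, self-contained sketch of that identity, not code from the patch: standardize() is an assumed stand-in for the repo's feature_norm(), and the random tensors are placeholders for real FPN features.

import paddle
import paddle.nn.functional as F


def standardize(feat):
    # flatten, then shift/scale to zero mean and unit variance
    feat = feat.flatten()
    return (feat - feat.mean()) / (feat.std() + 1e-6)


stu = paddle.rand([1, 256, 20, 20])  # placeholder student feature map
tea = paddle.rand([1, 256, 20, 20])  # placeholder teacher feature map

xs, xt = standardize(stu), standardize(tea)
mse_half = F.mse_loss(xs, xt) / 2                    # form used by the old loss
r = (xs * xt).mean() / (xs.std() * xt.std() + 1e-6)  # Pearson correlation
print(float(mse_half), float(1. - r))                # the two values nearly coincide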
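
The cam_data dictionary added to the inference output above carries the pre-NMS class scores and the indexes of the boxes kept by NMS, so that CAM-style visualizations can be built outside the model. A hedged sketch of how a caller might unpack that dictionary follows; the helper name, the dummy tensors, and their shapes are assumptions for illustration, and only the dictionary keys mirror the patch.

import paddle


def unpack_ppyoloe_output(outputs):
    # outputs: dict returned by PPYOLOE._forward at inference time
    bbox, bbox_num = outputs['bbox'], outputs['bbox_num']
    cam_data = outputs['cam_data']
    # scores recorded before NMS plus the indexes of the kept boxes;
    # a CAM tool can map the indexes back onto head feature locations
    return bbox, bbox_num, cam_data['scores'], cam_data['before_nms_indexes']


# dummy output with assumed shapes, just to exercise the helper
dummy = {
    'bbox': paddle.zeros([1, 6]),
    'bbox_num': paddle.to_tensor([1]),
    'cam_data': {
        'scores': paddle.zeros([1, 80, 8400]),
        'before_nms_indexes': paddle.zeros([1, 1], dtype='int64'),
    },
}
print([t.shape for t in unpack_ppyoloe_output(dummy)])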