[Docs] Add files for KD algo DFND
Cbtor committed Sep 21, 2023
1 parent 90c7af1 commit 9e1b1f5
Showing 9 changed files with 487 additions and 4 deletions.
31 changes: 31 additions & 0 deletions configs/distill/mmcls/dfnd/README.md
@@ -0,0 +1,31 @@
# Learning Student Networks in the Wild (DFND)

> [Learning Student Networks in the Wild](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Learning_Student_Networks_in_the_Wild_CVPR_2021_paper.pdf)
<!-- [ALGORITHM] -->

## Abstract

Data-free learning for student networks is a new paradigm for solving users’ anxiety caused by the privacy problem of using original training data. Since the architectures of modern convolutional neural networks (CNNs) are compact and sophisticated, the alternative images or meta-data generated from the teacher network are often broken. Thus, the student network cannot achieve the comparable performance to that of the pre-trained teacher network especially on the large-scale image dataset. Different to previous works, we present to maximally utilize the massive available unlabeled data in the wild. Specifically, we first thoroughly analyze the output differences between teacher and student network on the original data and develop a data collection method. Then, a noisy knowledge distillation algorithm is proposed for achieving the performance of the student network. In practice, an adaptation matrix is learned with the student network for correcting the label noise produced by the teacher network on the collected unlabeled images. The effectiveness of our DFND (Data-Free Noisy Distillation) method is then verified on several benchmarks to demonstrate its superiority over state-of-the-art data-free distillation methods. Experiments on various datasets demonstrate that the student networks learned by the proposed method can achieve comparable performance with those using the original dataset.

<img width="910" alt="pipeline" src="./dfnd.PNG">
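
The noisy distillation idea in the abstract can be illustrated with a small, dependency-free sketch. It is not the code added by this commit: every function name below (`softmax`, `select_confident`, `kd_kl`, `adapt`) is hypothetical, and the selection rule (keep the fraction of samples on which the teacher is most confident) is an assumption about how a ratio such as `batch_select` might be used. The noise matrix is likewise only an illustration of mapping student probabilities through a row-stochastic matrix.

```python
import math


def softmax(logits, tau=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / tau for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def select_confident(teacher_logits, ratio=0.5):
    """Indices of the `ratio` fraction of samples the teacher is most sure about.

    Assumption: confidence is the teacher's maximum softmax probability.
    """
    confidence = [max(softmax(row)) for row in teacher_logits]
    keep = max(1, int(len(teacher_logits) * ratio))
    ranked = sorted(
        range(len(confidence)), key=lambda i: confidence[i], reverse=True)
    return sorted(ranked[:keep])


def kd_kl(student_logits, teacher_logits, tau=4.0):
    """KL(teacher || student) on temperature-softened outputs, times tau**2."""
    p_t = softmax(teacher_logits, tau)
    p_s = softmax(student_logits, tau)
    return tau**2 * sum(
        t * math.log(t / s) for t, s in zip(p_t, p_s) if t > 0)


def adapt(student_probs, noise_matrix):
    """Map student probabilities through a row-stochastic label-noise matrix."""
    num_out = len(noise_matrix[0])
    return [
        sum(p * noise_matrix[i][j] for i, p in enumerate(student_probs))
        for j in range(num_out)
    ]


# Toy batch of four "wild" images with three classes (made-up logits).
teacher = [[8.0, 0.5, 0.1], [1.0, 1.1, 0.9], [0.2, 6.0, 0.3], [0.5, 0.4, 0.6]]
student = [[5.0, 1.0, 0.2], [1.0, 1.0, 1.0], [0.1, 4.0, 0.5], [0.3, 0.3, 0.3]]

idx = select_confident(teacher, ratio=0.5)
loss = sum(kd_kl(student[i], teacher[i]) for i in idx) / len(idx)
```

Only the two samples with sharply peaked teacher outputs contribute to `loss` here; the near-uniform ones are dropped, which is the intuition behind filtering unlabeled data by teacher confidence.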

## Results and models

### Classification

| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :---------------: | :------: | :-----: | :-----: | :---: | :----: | :----: | :----: | :------- |
| backbone & logits | CIFAR-10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 94.78 | 95.34 | 94.82 | [config](./dfnd_logits_resnet34_resnet18_8xb32_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \| [model](https://drive.google.com/file/d/1_MekfTkCsEl68meWPqtdNZIxdJO2R2Eb/view?usp=drive_link) |

## Citation

```latex
@inproceedings{chen2021learning,
title={Learning student networks in the wild},
author={Chen, Hanting and Guo, Tianyu and Xu, Chang and Li, Wenshuo and Xu, Chunjing and Xu, Chao and Wang, Yunhe},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={6428--6437},
year={2021}
}
```
Binary file added configs/distill/mmcls/dfnd/dfnd.PNG
@@ -0,0 +1,102 @@
_base_ = [
    'mmcls::_base_/default_runtime.py'
]

# optimizer
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001))
# learning policy
param_scheduler = dict(
    type='MultiStepLR', by_epoch=True, milestones=[320, 640], gamma=0.1)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=800, val_interval=1)
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=128)


train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=32),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackClsInputs'),
]

train_dataloader = dict(
    batch_size=256,
    num_workers=5,
    dataset=dict(
        type='ImageNet',
        data_root='/cache/data/imagenet/',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)


test_pipeline = [
    dict(type='PackClsInputs'),
]


val_dataloader = dict(
    batch_size=16,
    num_workers=2,
    dataset=dict(
        type='CIFAR10',
        data_prefix='/cache/data/cifar',
        test_mode=True,
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, ))

test_dataloader = val_dataloader
test_evaluator = val_evaluator


teacher_ckpt = '/cache/models/resnet_model.pth' # noqa: E501

model = dict(
    _scope_='mmrazor',
    type='DFNDDistill',
    calculate_student_loss=False,
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        bgr_to_rgb=True),
    val_data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[125.307, 122.961, 113.8575],
        std=[51.5865, 50.847, 51.255],
        # no BGR-to-RGB conversion for the validation data
        bgr_to_rgb=False),
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False),
    teacher=dict(
        cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py', pretrained=False),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        teacher_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        distill_losses=dict(
            loss_kl=dict(type='DFNDLoss', tau=4, loss_weight=1,
                         num_classes=10, batch_select=0.5)),
        loss_forward_mappings=dict(
            loss_kl=dict(
                preds_S=dict(from_student=True, recorder='fc'),
                preds_T=dict(from_student=False, recorder='fc')))))

find_unused_parameters = True

val_cfg = dict(type='mmrazor.DFNDValLoop')
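
The schedule in this config (`lr=0.1`, `milestones=[320, 640]`, `gamma=0.1`, `max_epochs=800`) and the `base_batch_size=128` reference can be checked with a few lines of arithmetic. The sketch below is only an illustration: `multistep_lr` and `scaled_lr` are hypothetical helpers, not MMEngine functions, and whether LR auto-scaling is actually applied depends on how the run is launched.

```python
def multistep_lr(epoch, base_lr=0.1, milestones=(320, 640), gamma=0.1):
    """Learning rate in effect during `epoch` (0-based) under a step schedule."""
    drops = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma**drops


def scaled_lr(base_lr, batch_size_per_gpu, num_gpus, base_batch_size=128):
    """Linear scaling rule: the LR grows in proportion to the total batch size."""
    return base_lr * (batch_size_per_gpu * num_gpus) / base_batch_size


# LR at the start of training, just after each milestone, and in the last epoch.
schedule = {e: multistep_lr(e) for e in (0, 319, 320, 640, 799)}

# Example numbers only: one GPU with the dataloader's batch_size=256.
example = scaled_lr(0.1, batch_size_per_gpu=256, num_gpus=1)
```

Under these assumptions the rate stays at 0.1 for the first 320 epochs, drops to roughly 0.01 until epoch 640, and finishes at roughly 0.001; a total batch of 256 against a base of 128 would double the starting rate if scaling were enabled.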
5 changes: 3 additions & 2 deletions mmrazor/engine/runner/__init__.py
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .autoslim_greedy_search_loop import AutoSlimGreedySearchLoop
 from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
-from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop
+from .distill_val_loop import (DFNDValLoop, SelfDistillValLoop,
+                               SingleTeacherDistillValLoop)
 from .evolution_search_loop import EvolutionSearchLoop
 from .iteprune_val_loop import ItePruneValLoop
 from .quantization_loops import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop,
@@ -15,5 +16,5 @@
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
     'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop',
     'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'QATEpochBasedLoop',
-    'PTQLoop', 'LSQEpochBasedLoop', 'QATValLoop'
+    'PTQLoop', 'LSQEpochBasedLoop', 'QATValLoop', 'DFNDValLoop'
 ]
35 changes: 35 additions & 0 deletions mmrazor/engine/runner/distill_val_loop.py
@@ -125,3 +125,38 @@ def run(self):
 
         self.runner.call_hook('after_val_epoch', metrics=student_metrics)
         self.runner.call_hook('after_val')
+
+
+@LOOPS.register_module()
+class DFNDValLoop(SingleTeacherDistillValLoop):
+    """Validation loop for DFND. DFND requires different dataset for training
+    and validation.
+
+    Args:
+        runner (Runner): A reference of runner.
+        dataloader (Dataloader or dict): A dataloader object or a dict to
+            build a dataloader.
+        evaluator (Evaluator or dict or list): Used for computing metrics.
+        fp16 (bool): Whether to enable fp16 validation. Defaults to
+            False.
+    """
+
+    def __init__(self,
+                 runner,
+                 dataloader: Union[DataLoader, Dict],
+                 evaluator: Union[Evaluator, Dict, List],
+                 fp16: bool = False) -> None:
+        super().__init__(runner, dataloader, evaluator, fp16)
+        if self.runner.distributed:
+            assert hasattr(self.runner.model.module, 'teacher')
+            # TODO: remove hard code after mmcls add data_preprocessor
+            data_preprocessor = self.runner.model.module.val_data_preprocessor
+            self.teacher = self.runner.model.module.teacher
+            self.teacher.data_preprocessor = data_preprocessor
+
+        else:
+            assert hasattr(self.runner.model, 'teacher')
+            # TODO: remove hard code after mmcls add data_preprocessor
+            data_preprocessor = self.runner.model.val_data_preprocessor
+            self.teacher = self.runner.model.teacher
+            self.teacher.data_preprocessor = data_preprocessor
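
The two branches of `DFNDValLoop.__init__` differ only in whether the algorithm is reached through a `.module` wrapper. As a possible simplification (a sketch, not part of this commit), the lookup could be factored into one helper. The names `unwrap` and `attach_val_preprocessor` are hypothetical, and `SimpleNamespace` objects stand in for the real runner model and teacher.

```python
from types import SimpleNamespace


def unwrap(model, distributed):
    """Return the underlying model, looking through a `.module` wrapper."""
    return model.module if distributed else model


def attach_val_preprocessor(model, distributed):
    """Give the teacher the validation-time preprocessor; return the teacher."""
    core = unwrap(model, distributed)
    assert hasattr(core, 'teacher')
    core.teacher.data_preprocessor = core.val_data_preprocessor
    return core.teacher


# Hypothetical stand-ins for the algorithm and its distributed wrapper.
algorithm = SimpleNamespace(
    teacher=SimpleNamespace(data_preprocessor='train-preprocessor'),
    val_data_preprocessor='val-preprocessor')
wrapped = SimpleNamespace(module=algorithm)

teacher = attach_val_preprocessor(wrapped, distributed=True)
```

Both the wrapped and the unwrapped call end up swapping the same teacher's preprocessor, which mirrors what the duplicated branches in the hunk do.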
3 changes: 2 additions & 1 deletion mmrazor/models/algorithms/distill/configurable/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .datafree_distillation import (DAFLDataFreeDistillation,
                                     DataFreeDistillation)
+from .dfnd_distill import DFNDDistill
 from .fpn_teacher_distill import FpnTeacherDistill
 from .overhaul_feature_distillation import OverhaulFeatureDistillation
 from .self_distill import SelfDistill
@@ -9,5 +10,5 @@
 __all__ = [
     'SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill',
     'DataFreeDistillation', 'DAFLDataFreeDistillation',
-    'OverhaulFeatureDistillation'
+    'OverhaulFeatureDistillation', 'DFNDDistill'
 ]