add new code about FedRep and Simple_tuning (#564)
* add new code

* updating

* updating

* re-format code

* resolve conflicts

* updating

* updating

* Update test_femnist_simple_tuning.py

* Update test_femnist_simple_tuning.py

format change (re-run unit tests)

* Update __init__.py

fix the import problem

* Update __init__.py

* re-formatted code

---------

Co-authored-by: yuexiang.xyx <[email protected]>
Co-authored-by: Osier-Yi <[email protected]>
Co-authored-by: Daoyuan Chen <[email protected]>
4 people authored Apr 3, 2023
1 parent 4e74f2a commit dfe6393
Showing 8 changed files with 382 additions and 1 deletion.
9 changes: 9 additions & 0 deletions federatedscope/core/auxiliaries/trainer_builder.py
@@ -222,6 +222,10 @@ def get_trainer(model=None,
        # copy construct style: instance a (class A) -> instance b (class B)
        trainer = FedEMTrainer(model_nums=config.model.model_num_per_trainer,
                               base_trainer=trainer)
    elif config.federate.method.lower() == "fedrep":
        from federatedscope.core.trainers import wrap_FedRepTrainer
        # wrap style: instance a (class A) -> instance a (class A)
        trainer = wrap_FedRepTrainer(trainer)

    # attacker plug-in
    if 'backdoor' in config.attack.attack_method:
@@ -246,4 +250,9 @@ def get_trainer(model=None,
        from federatedscope.core.trainers import wrap_fedprox_trainer
        trainer = wrap_fedprox_trainer(trainer)

    # different fine-tuning
    if config.finetune.before_eval and config.finetune.simple_tuning:
        from federatedscope.core.trainers import wrap_Simple_tuning_Trainer
        trainer = wrap_Simple_tuning_Trainer(trainer)

    return trainer
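Both new wrappers are switched on purely through the configuration fields checked above. The following is a minimal sketch that uses only the field names visible in this diff; any further get_trainer arguments are unchanged by this commit and are omitted here:

from federatedscope.core.configs.config import global_cfg

# Sketch: select the FedRep wrapper via the federate.method field checked above.
cfg = global_cfg.clone()
cfg.federate.method = 'FedRep'  # compared case-insensitively against "fedrep"
# trainer = get_trainer(model=..., config=cfg, ...)  # remaining arguments as usual

The simple-tuning branch is driven by the finetune fields added in cfg_training.py; an illustrative override is sketched after that diff below.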
7 changes: 7 additions & 0 deletions federatedscope/core/configs/cfg_fl_algo.py
@@ -63,6 +63,13 @@ def extend_fl_algo_cfg(cfg):
    cfg.personalization.K = 5  # the local approximation steps for pFedMe
    cfg.personalization.beta = 1.0  # the average moving parameter for pFedMe

    # parameters for FedRep:
    cfg.personalization.lr_feature = 0.1  # learning rate: feature extractor
    cfg.personalization.lr_linear = 0.1  # learning rate: linear head
    cfg.personalization.epoch_feature = 1  # training epochs for the feature extractor
    cfg.personalization.epoch_linear = 2  # training epochs for the linear head
    cfg.personalization.weight_decay = 0.0

    # ---------------------------------------------------------------------- #
    # FedSage+ related options, gfl
    # ---------------------------------------------------------------------- #
7 changes: 7 additions & 0 deletions federatedscope/core/configs/cfg_training.py
@@ -60,6 +60,13 @@ def extend_training_cfg(cfg):
    cfg.finetune.scheduler.type = ''
    cfg.finetune.scheduler.warmup_ratio = 0.0

    # simple-tuning
    cfg.finetune.simple_tuning = False  # use simple tuning, default: False
    cfg.finetune.epoch_linear = 10  # training epochs for the linear head, default: 10
    cfg.finetune.lr_linear = 0.005  # learning rate for training the linear head
    cfg.finetune.weight_decay = 0.0
    cfg.finetune.local_param = []  # top-level module names of the parameters to tune

    # ---------------------------------------------------------------------- #
    # Gradient related options
    # ---------------------------------------------------------------------- #
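A usage sketch for the new finetune fields above; the values are illustrative, and 'fc2' is an assumed head-module name that mirrors the local_param choice in the FedRep test at the end of this commit:

from federatedscope.core.configs.config import global_cfg

# Sketch: enable simple tuning as the fine-tuning step before evaluation.
cfg = global_cfg.clone()
cfg.finetune.before_eval = True  # required by the check in trainer_builder.py
cfg.finetune.simple_tuning = True  # triggers wrap_Simple_tuning_Trainer
cfg.finetune.epoch_linear = 10  # epochs for the re-initialized linear head
cfg.finetune.lr_linear = 0.005  # SGD learning rate for the head
cfg.finetune.weight_decay = 0.0
cfg.finetune.local_param = ['fc2']  # assumed top-level name of the linear head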
6 changes: 5 additions & 1 deletion federatedscope/core/trainers/__init__.py
@@ -7,6 +7,9 @@
from federatedscope.core.trainers.trainer_pFedMe import wrap_pFedMeTrainer
from federatedscope.core.trainers.trainer_Ditto import wrap_DittoTrainer
from federatedscope.core.trainers.trainer_FedEM import FedEMTrainer
from federatedscope.core.trainers.trainer_FedRep import wrap_FedRepTrainer
from federatedscope.core.trainers.trainer_simple_tuning import \
    wrap_Simple_tuning_Trainer
from federatedscope.core.trainers.context import Context
from federatedscope.core.trainers.trainer_fedprox import wrap_fedprox_trainer
from federatedscope.core.trainers.trainer_nbafl import wrap_nbafl_trainer, \
@@ -16,5 +19,6 @@
    'Trainer', 'Context', 'GeneralTorchTrainer', 'GeneralMultiModelTrainer',
    'wrap_pFedMeTrainer', 'wrap_DittoTrainer', 'FedEMTrainer',
    'wrap_fedprox_trainer', 'wrap_nbafl_trainer', 'wrap_nbafl_server',
    'wrap_Simple_tuning_Trainer', 'wrap_FedRepTrainer', 'BaseTrainer',
    'GeneralTFTrainer'
]
99 changes: 99 additions & 0 deletions federatedscope/core/trainers/trainer_FedRep.py
@@ -0,0 +1,99 @@
import copy
import torch
import logging

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer

from typing import Type

logger = logging.getLogger(__name__)


def wrap_FedRepTrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    # ---------------------------------------------------------------------- #
    # FedRep method:
    # https://arxiv.org/abs/2102.07078
    # First train the linear classifier, then the feature extractor
    # Linear classifier: local_param; feature extractor: global_param
    # ---------------------------------------------------------------------- #
    init_FedRep_ctx(base_trainer)

    base_trainer.register_hook_in_train(new_hook=hook_on_fit_start_fedrep,
                                        trigger="on_fit_start",
                                        insert_pos=-1)

    base_trainer.register_hook_in_train(new_hook=hook_on_epoch_start_fedrep,
                                        trigger="on_epoch_start",
                                        insert_pos=-1)

    return base_trainer


def init_FedRep_ctx(base_trainer):

    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    ctx.epoch_feature = cfg.personalization.epoch_feature
    ctx.epoch_linear = cfg.personalization.epoch_linear

    ctx.num_train_epoch = ctx.epoch_feature + ctx.epoch_linear

    ctx.epoch_number = 0

    ctx.lr_feature = cfg.personalization.lr_feature
    ctx.lr_linear = cfg.personalization.lr_linear
    ctx.weight_decay = cfg.personalization.weight_decay

    ctx.local_param = cfg.personalization.local_param

    ctx.local_update_param = []
    ctx.global_update_param = []

    for name, param in ctx.model.named_parameters():
        if name.split(".")[0] in ctx.local_param:
            ctx.local_update_param.append(param)
        else:
            ctx.global_update_param.append(param)


def hook_on_fit_start_fedrep(ctx):

    ctx.num_train_epoch = ctx.epoch_feature + ctx.epoch_linear
    ctx.epoch_number = 0

    ctx.optimizer_for_feature = torch.optim.SGD(ctx.global_update_param,
                                                lr=ctx.lr_feature,
                                                momentum=0,
                                                weight_decay=ctx.weight_decay)
    ctx.optimizer_for_linear = torch.optim.SGD(ctx.local_update_param,
                                               lr=ctx.lr_linear,
                                               momentum=0,
                                               weight_decay=ctx.weight_decay)

    for name, param in ctx.model.named_parameters():

        if name.split(".")[0] in ctx.local_param:
            param.requires_grad = True
        else:
            param.requires_grad = False

    ctx.optimizer = ctx.optimizer_for_linear


def hook_on_epoch_start_fedrep(ctx):

    ctx.epoch_number += 1

    if ctx.epoch_number == ctx.epoch_linear + 1:

        for name, param in ctx.model.named_parameters():

            if name.split(".")[0] in ctx.local_param:
                param.requires_grad = False
            else:
                param.requires_grad = True

        ctx.optimizer = ctx.optimizer_for_feature
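As a self-contained toy illustration of the schedule implemented by the two hooks above: parameters whose top-level module name appears in local_param (the linear head) are trained for epoch_linear epochs while the feature extractor is frozen, then the roles flip for epoch_feature epochs. The toy model and the 'fc2' name below are assumptions for illustration, not part of FederatedScope:

import torch.nn as nn

model = nn.Sequential()
model.add_module('features', nn.Linear(32, 16))  # stand-in feature extractor
model.add_module('fc2', nn.Linear(16, 10))       # assumed linear head

local_param = ['fc2']               # cf. cfg.personalization.local_param
epoch_linear, epoch_feature = 2, 1  # head epochs first, then feature epochs

for epoch in range(epoch_linear + epoch_feature):
    head_phase = epoch < epoch_linear  # FedRep trains the head first
    for name, param in model.named_parameters():
        is_local = name.split('.')[0] in local_param
        param.requires_grad = is_local if head_phase else not is_local
    # One local training epoch would run here, driven by optimizer_for_linear
    # during the head phase and optimizer_for_feature afterwards (see the
    # hooks above).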
75 changes: 75 additions & 0 deletions federatedscope/core/trainers/trainer_simple_tuning.py
@@ -0,0 +1,75 @@
import copy
import torch
import logging
import math

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer

from typing import Type

logger = logging.getLogger(__name__)


def wrap_Simple_tuning_Trainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    # ---------------------------------------------------------------------- #
    # Simple_tuning method:
    # https://arxiv.org/abs/2302.01677
    # Only tune the linear classifier and freeze the feature extractor;
    # the key is to re-initialize the linear classifier
    # ---------------------------------------------------------------------- #
    init_Simple_tuning_ctx(base_trainer)

    base_trainer.register_hook_in_ft(new_hook=hook_on_fit_start_simple_tuning,
                                     trigger="on_fit_start",
                                     insert_pos=-1)

    return base_trainer


def init_Simple_tuning_ctx(base_trainer):

    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    ctx.epoch_linear = cfg.finetune.epoch_linear

    ctx.num_train_epoch = ctx.epoch_linear

    ctx.epoch_number = 0

    ctx.lr_linear = cfg.finetune.lr_linear
    ctx.weight_decay = cfg.finetune.weight_decay

    ctx.local_param = cfg.finetune.local_param

    ctx.local_update_param = []

    for name, param in ctx.model.named_parameters():
        if name.split(".")[0] in ctx.local_param:
            ctx.local_update_param.append(param)


def hook_on_fit_start_simple_tuning(ctx):

    ctx.num_train_epoch = ctx.epoch_linear
    ctx.epoch_number = 0

    ctx.optimizer_for_linear = torch.optim.SGD(ctx.local_update_param,
                                               lr=ctx.lr_linear,
                                               momentum=0,
                                               weight_decay=ctx.weight_decay)

    for name, param in ctx.model.named_parameters():
        if name.split(".")[0] in ctx.local_param:
            if name.split(".")[1] == 'weight':
                stdv = 1. / math.sqrt(param.size(-1))
                param.data.uniform_(-stdv, stdv)
            else:
                # bias: reuse the stdv computed for the preceding weight
                param.data.uniform_(-stdv, stdv)
            param.requires_grad = True
        else:
            param.requires_grad = False

    ctx.optimizer = ctx.optimizer_for_linear
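A minimal stand-alone sketch of what the hook above does: re-initialize the linear head uniformly in [-stdv, stdv] with stdv = 1/sqrt(fan_in) taken from its weight, freeze every other parameter, and fine-tune only the head with plain SGD. The toy model and the 'fc2' name are assumptions for illustration:

import math

import torch
import torch.nn as nn

model = nn.Sequential()
model.add_module('features', nn.Linear(32, 16))  # stand-in feature extractor
model.add_module('fc2', nn.Linear(16, 10))       # assumed linear head
local_param = ['fc2']

head_params = []
for name, param in model.named_parameters():
    if name.split('.')[0] in local_param:
        if name.endswith('weight'):
            stdv = 1. / math.sqrt(param.size(-1))  # 1/sqrt(fan_in)
        # the bias reuses the stdv of its weight (weights precede biases in
        # named_parameters for standard layers, as the hook above relies on)
        param.data.uniform_(-stdv, stdv)
        param.requires_grad = True
        head_params.append(param)
    else:
        param.requires_grad = False

optimizer = torch.optim.SGD(head_params, lr=0.005, momentum=0, weight_decay=0.0)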
91 changes: 91 additions & 0 deletions tests/test_femnist_fedrep.py
@@ -0,0 +1,91 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from federatedscope.core.auxiliaries.data_builder import get_data
from federatedscope.core.auxiliaries.utils import setup_seed
from federatedscope.core.auxiliaries.logging import update_logger
from federatedscope.core.configs.config import global_cfg
from federatedscope.core.auxiliaries.runner_builder import get_runner
from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls

SAMPLE_CLIENT_NUM = 5


class FedRep_Testing(unittest.TestCase):
    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def set_config_Fedrep_femnist(self, cfg):
        backup_cfg = cfg.clone()

        import torch
        cfg.use_gpu = torch.cuda.is_available()
        cfg.eval.freq = 10
        cfg.eval.metrics = ['acc', 'loss_regular']

        cfg.federate.mode = 'standalone'
        cfg.train.local_update_steps = 5
        cfg.federate.total_round_num = 20
        cfg.federate.sample_client_num = SAMPLE_CLIENT_NUM

        cfg.data.root = 'test_data/'
        cfg.data.type = 'femnist'
        cfg.data.splits = [0.6, 0.2, 0.2]
        cfg.data.batch_size = 10
        cfg.data.subsample = 0.05
        cfg.data.transform = [['ToTensor'],
                              [
                                  'Normalize', {
                                      'mean': [0.9637],
                                      'std': [0.1592]
                                  }
                              ]]

        cfg.model.type = 'convnet2'
        cfg.model.hidden = 2048
        cfg.model.out_channels = 62

        cfg.train.optimizer.lr = 0.001
        cfg.train.optimizer.weight_decay = 0.0
        cfg.grad.grad_clip = 5.0

        cfg.criterion.type = 'CrossEntropyLoss'
        cfg.trainer.type = 'cvtrainer'
        cfg.seed = 123
        cfg.personalization.local_param = ['fc2']
        cfg.personalization.local_update_steps = 2
        cfg.personalization.regular_weight = 0.1
        cfg.personalization.epoch_feature = 2
        cfg.personalization.epoch_linear = 1
        cfg.personalization.lr_feature = 0.1
        cfg.personalization.lr_linear = 0.1

        return backup_cfg

    def test_femnist_standalone(self):
        init_cfg = global_cfg.clone()
        backup_cfg = self.set_config_Fedrep_femnist(init_cfg)
        setup_seed(init_cfg.seed)
        update_logger(init_cfg, True)

        data, modified_cfg = get_data(init_cfg.clone())
        init_cfg.merge_from_other_cfg(modified_cfg)
        self.assertIsNotNone(data)
        self.assertEqual(init_cfg.federate.sample_client_num,
                         SAMPLE_CLIENT_NUM)

        Fed_runner = get_runner(data=data,
                                server_class=get_server_cls(init_cfg),
                                client_class=get_client_cls(init_cfg),
                                config=init_cfg.clone())
        self.assertIsNotNone(Fed_runner)
        test_best_results = Fed_runner.run()
        print(test_best_results)
        init_cfg.merge_from_other_cfg(backup_cfg)
        self.assertLess(
            test_best_results["client_summarized_weighted_avg"]['test_loss'],
            1200)


if __name__ == '__main__':
    unittest.main()
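Since the file ends with the standard unittest.main() entry point, it can be run on its own, assuming the subsampled FEMNIST data can be materialized under test_data/, e.g. with python tests/test_femnist_fedrep.py or python -m unittest tests.test_femnist_fedrep from the repository root.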