add new code about FedRep and Simple_tuning (#564)
* add new code
* updating
* updating
* re-format code
* conflict resolve
* updating
* updating
* Update test_femnist_simple_tuning.py
* Update test_femnist_simple_tuning.py: format change (rerun unittest)
* Update __init__.py: fix the import problem
* Update __init__.py
* re-formatted code

Co-authored-by: yuexiang.xyx <[email protected]>
Co-authored-by: Osier-Yi <[email protected]>
Co-authored-by: Daoyuan Chen <[email protected]>
1 parent: 4e74f2a. Commit: dfe6393. Showing 8 changed files with 382 additions and 1 deletion.
New file (99 lines added): the FedRep trainer wrapper.
import copy
import torch
import logging

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer

from typing import Type

logger = logging.getLogger(__name__)


def wrap_FedRepTrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    # ---------------------------------------------------------------------- #
    # FedRep method:
    # https://arxiv.org/abs/2102.07078
    # First trains the linear classifier, then the feature extractor.
    # Linear classifier: local_param; feature extractor: global_param
    # ---------------------------------------------------------------------- #
    init_FedRep_ctx(base_trainer)

    base_trainer.register_hook_in_train(new_hook=hook_on_fit_start_fedrep,
                                        trigger="on_fit_start",
                                        insert_pos=-1)

    base_trainer.register_hook_in_train(new_hook=hook_on_epoch_start_fedrep,
                                        trigger="on_epoch_start",
                                        insert_pos=-1)

    return base_trainer


def init_FedRep_ctx(base_trainer):
    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    # Total local epochs = epochs for the feature extractor
    # + epochs for the linear classifier.
    ctx.epoch_feature = cfg.personalization.epoch_feature
    ctx.epoch_linear = cfg.personalization.epoch_linear

    ctx.num_train_epoch = ctx.epoch_feature + ctx.epoch_linear

    ctx.epoch_number = 0

    ctx.lr_feature = cfg.personalization.lr_feature
    ctx.lr_linear = cfg.personalization.lr_linear
    ctx.weight_decay = cfg.personalization.weight_decay

    ctx.local_param = cfg.personalization.local_param

    ctx.local_update_param = []
    ctx.global_update_param = []

    # Split parameters into the locally kept classifier (local_param)
    # and the globally shared feature extractor (everything else).
    for name, param in ctx.model.named_parameters():
        if name.split(".")[0] in ctx.local_param:
            ctx.local_update_param.append(param)
        else:
            ctx.global_update_param.append(param)


def hook_on_fit_start_fedrep(ctx):
    ctx.num_train_epoch = ctx.epoch_feature + ctx.epoch_linear
    ctx.epoch_number = 0

    ctx.optimizer_for_feature = torch.optim.SGD(ctx.global_update_param,
                                                lr=ctx.lr_feature,
                                                momentum=0,
                                                weight_decay=ctx.weight_decay)
    ctx.optimizer_for_linear = torch.optim.SGD(ctx.local_update_param,
                                               lr=ctx.lr_linear,
                                               momentum=0,
                                               weight_decay=ctx.weight_decay)

    # Stage 1: only the local linear classifier is trainable.
    for name, param in ctx.model.named_parameters():
        if name.split(".")[0] in ctx.local_param:
            param.requires_grad = True
        else:
            param.requires_grad = False

    ctx.optimizer = ctx.optimizer_for_linear


def hook_on_epoch_start_fedrep(ctx):
    ctx.epoch_number += 1

    # Stage 2: after epoch_linear epochs, freeze the classifier and
    # train the feature extractor instead.
    if ctx.epoch_number == ctx.epoch_linear + 1:
        for name, param in ctx.model.named_parameters():
            if name.split(".")[0] in ctx.local_param:
                param.requires_grad = False
            else:
                param.requires_grad = True

        ctx.optimizer = ctx.optimizer_for_feature
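
As a point of reference, here is a minimal usage sketch for this wrapper. It only sets the cfg.personalization keys read by init_FedRep_ctx above; the base_trainer object, the module name 'fc2', and the concrete values are illustrative placeholders rather than part of this commit.

    # Hypothetical sketch: wrap an existing GeneralTorchTrainer with FedRep.
    # Only the personalization keys read by init_FedRep_ctx are set here;
    # `base_trainer` and the module name 'fc2' are placeholders.
    from federatedscope.core.configs.config import global_cfg

    cfg = global_cfg.clone()
    cfg.personalization.local_param = ['fc2']  # classifier kept local
    cfg.personalization.epoch_feature = 2      # epochs for the feature extractor
    cfg.personalization.epoch_linear = 1       # epochs for the linear classifier
    cfg.personalization.lr_feature = 0.1
    cfg.personalization.lr_linear = 0.1
    cfg.personalization.weight_decay = 0.0

    # base_trainer: a GeneralTorchTrainer built elsewhere with this cfg,
    # a model, and the client's data.
    fedrep_trainer = wrap_FedRepTrainer(base_trainer)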
New file (75 lines added): the Simple_tuning trainer wrapper.
import copy
import torch
import logging
import math

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer

from typing import Type

logger = logging.getLogger(__name__)


def wrap_Simple_tuning_Trainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    # ---------------------------------------------------------------------- #
    # Simple_tuning method:
    # https://arxiv.org/abs/2302.01677
    # Only tunes the linear classifier and freezes the feature extractor;
    # the key step is to re-initialize the linear classifier.
    # ---------------------------------------------------------------------- #
    init_Simple_tuning_ctx(base_trainer)

    base_trainer.register_hook_in_ft(new_hook=hook_on_fit_start_simple_tuning,
                                     trigger="on_fit_start",
                                     insert_pos=-1)

    return base_trainer


def init_Simple_tuning_ctx(base_trainer):
    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    ctx.epoch_linear = cfg.finetune.epoch_linear

    ctx.num_train_epoch = ctx.epoch_linear

    ctx.epoch_number = 0

    ctx.lr_linear = cfg.finetune.lr_linear
    ctx.weight_decay = cfg.finetune.weight_decay

    ctx.local_param = cfg.finetune.local_param

    ctx.local_update_param = []

    # Collect the parameters of the linear classifier (local_param);
    # only these will be fine-tuned.
    for name, param in ctx.model.named_parameters():
        if name.split(".")[0] in ctx.local_param:
            ctx.local_update_param.append(param)


def hook_on_fit_start_simple_tuning(ctx):
    ctx.num_train_epoch = ctx.epoch_linear
    ctx.epoch_number = 0

    ctx.optimizer_for_linear = torch.optim.SGD(ctx.local_update_param,
                                               lr=ctx.lr_linear,
                                               momentum=0,
                                               weight_decay=ctx.weight_decay)

    # Re-initialize the linear classifier and freeze everything else.
    # Note: for nn.Linear, named_parameters() yields the weight before the
    # bias, so the bias reuses the stdv computed from its weight.
    for name, param in ctx.model.named_parameters():
        if name.split(".")[0] in ctx.local_param:
            if name.split(".")[1] == 'weight':
                stdv = 1. / math.sqrt(param.size(-1))
                param.data.uniform_(-stdv, stdv)
            else:
                param.data.uniform_(-stdv, stdv)
            param.requires_grad = True
        else:
            param.requires_grad = False

    ctx.optimizer = ctx.optimizer_for_linear
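
The re-initialization in hook_on_fit_start_simple_tuning draws both the weight and the bias of the classifier from uniform(-stdv, stdv) with stdv = 1/sqrt(fan_in), i.e. the classic Linear initialization scheme. A standalone illustration of the same step follows; the layer sizes are arbitrary and not taken from this commit.

    import math
    import torch
    import torch.nn as nn

    # Illustrative only: re-initialize a classifier head the way the hook
    # above does, with stdv derived from the weight's fan-in (its last dim).
    head = nn.Linear(128, 62)
    stdv = 1. / math.sqrt(head.weight.size(-1))
    with torch.no_grad():
        head.weight.uniform_(-stdv, stdv)
        head.bias.uniform_(-stdv, stdv)  # the bias reuses the weight's stdv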
New file (91 lines added): a FEMNIST unit test for FedRep.
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from federatedscope.core.auxiliaries.data_builder import get_data
from federatedscope.core.auxiliaries.utils import setup_seed
from federatedscope.core.auxiliaries.logging import update_logger
from federatedscope.core.configs.config import global_cfg
from federatedscope.core.auxiliaries.runner_builder import get_runner
from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls

SAMPLE_CLIENT_NUM = 5


class FedRep_Testing(unittest.TestCase):
    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def set_config_Fedrep_femnist(self, cfg):
        backup_cfg = cfg.clone()

        import torch
        cfg.use_gpu = torch.cuda.is_available()
        cfg.eval.freq = 10
        cfg.eval.metrics = ['acc', 'loss_regular']

        cfg.federate.mode = 'standalone'
        cfg.train.local_update_steps = 5
        cfg.federate.total_round_num = 20
        cfg.federate.sample_client_num = SAMPLE_CLIENT_NUM

        cfg.data.root = 'test_data/'
        cfg.data.type = 'femnist'
        cfg.data.splits = [0.6, 0.2, 0.2]
        cfg.data.batch_size = 10
        cfg.data.subsample = 0.05
        cfg.data.transform = [['ToTensor'],
                              [
                                  'Normalize', {
                                      'mean': [0.9637],
                                      'std': [0.1592]
                                  }
                              ]]

        cfg.model.type = 'convnet2'
        cfg.model.hidden = 2048
        cfg.model.out_channels = 62

        cfg.train.optimizer.lr = 0.001
        cfg.train.optimizer.weight_decay = 0.0
        cfg.grad.grad_clip = 5.0

        cfg.criterion.type = 'CrossEntropyLoss'
        cfg.trainer.type = 'cvtrainer'
        cfg.seed = 123

        # FedRep-specific personalization settings: keep 'fc2' (the linear
        # classifier) local, and split the local epochs between the feature
        # extractor and the classifier.
        cfg.personalization.local_param = ['fc2']
        cfg.personalization.local_update_steps = 2
        cfg.personalization.regular_weight = 0.1
        cfg.personalization.epoch_feature = 2
        cfg.personalization.epoch_linear = 1
        cfg.personalization.lr_feature = 0.1
        cfg.personalization.lr_linear = 0.1

        return backup_cfg

    def test_femnist_standalone(self):
        init_cfg = global_cfg.clone()
        backup_cfg = self.set_config_Fedrep_femnist(init_cfg)
        setup_seed(init_cfg.seed)
        update_logger(init_cfg, True)

        data, modified_cfg = get_data(init_cfg.clone())
        init_cfg.merge_from_other_cfg(modified_cfg)
        self.assertIsNotNone(data)
        self.assertEqual(init_cfg.federate.sample_client_num,
                         SAMPLE_CLIENT_NUM)

        Fed_runner = get_runner(data=data,
                                server_class=get_server_cls(init_cfg),
                                client_class=get_client_cls(init_cfg),
                                config=init_cfg.clone())
        self.assertIsNotNone(Fed_runner)
        test_best_results = Fed_runner.run()
        print(test_best_results)
        init_cfg.merge_from_other_cfg(backup_cfg)
        self.assertLess(
            test_best_results["client_summarized_weighted_avg"]['test_loss'],
            1200)


if __name__ == '__main__':
    unittest.main()
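
An analogous test for the Simple_tuning wrapper would set the cfg.finetune keys read by init_Simple_tuning_ctx instead of the cfg.personalization block above. A hypothetical configuration fragment is sketched below; the module name 'fc2' and the concrete values are placeholders, not part of this diff.

    # Hypothetical sketch: the finetune keys a Simple_tuning test would set,
    # mirroring the personalization block in set_config_Fedrep_femnist above.
    cfg.finetune.local_param = ['fc2']  # classifier head to re-init and tune
    cfg.finetune.epoch_linear = 5       # fine-tuning epochs for the classifier
    cfg.finetune.lr_linear = 0.1
    cfg.finetune.weight_decay = 0.0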
(The remaining changed files in this commit are not shown.)