diff --git a/federatedscope/contrib/trainer/local_entropy.py b/federatedscope/contrib/trainer/local_entropy.py
index 57a75eab8..7680d4b14 100644
--- a/federatedscope/contrib/trainer/local_entropy.py
+++ b/federatedscope/contrib/trainer/local_entropy.py
@@ -17,8 +17,8 @@ def copy_params(src):
 
 def prox_term(cur, last):
     loss = .0
-    for name, tensor in last.items():
-        loss += 0.5 * torch.sum((cur[name] - tensor)**2)
+    for name, w in cur.named_parameters():
+        loss += 0.5 * torch.sum((w - last[name])**2)
     return loss
 
 
@@ -46,6 +46,7 @@ def __init__(self, model, data, device, **kwargs):
         self.config = kwargs['config']
         self.optim_config = self.config.train.optimizer
         self.local_entropy_config = self.config.trainer.local_entropy
+        self._thermal = self.local_entropy_config.gamma
 
     def train(self):
         # Criterion & Optimizer
@@ -72,21 +73,21 @@ def run_epoch(self, optimizer, criterion, current_global_model, mu):
         running_loss = 0.0
         num_samples = 0
-        thermal = self.local_entropy_config.gamma
 
         # for inputs, targets in self.trainloader:
         for inputs, targets in self.data['train']:
             inputs = inputs.to(self.device)
             targets = targets.to(self.device)
 
             # Descent Step
+            optimizer.zero_grad()
             outputs = self.model(inputs)
             ce_loss = criterion(outputs, targets)
-            loss = ce_loss + thermal * prox_term(self.model.state_dict(),
-                                                 current_global_model)
+            loss = ce_loss + self._thermal * prox_term(self.model,
+                                                       current_global_model)
             loss.backward()
             optimizer.step()
 
-            # add noise for langevine dynamics
+            # add noise for langevin dynamics
             add_noise(
                 self.model,
                 math.sqrt(self.optim_config.lr) *
@@ -100,7 +101,7 @@ def run_epoch(self, optimizer, criterion, current_global_model, mu):
             running_loss += targets.shape[0] * ce_loss.item()
             num_samples += targets.shape[0]
 
-        thermal *= 1.001
+        self._thermal *= self.local_entropy_config.inc_factor
 
         return num_samples, running_loss
 
diff --git a/federatedscope/core/aggregators/fedopt_aggregator.py b/federatedscope/core/aggregators/fedopt_aggregator.py
index e6e63cab5..6ca64c7e0 100644
--- a/federatedscope/core/aggregators/fedopt_aggregator.py
+++ b/federatedscope/core/aggregators/fedopt_aggregator.py
@@ -13,6 +13,15 @@ def __init__(self, config, model, device='cpu'):
         super(FedOptAggregator, self).__init__(model, device, config)
         self.optimizer = get_optimizer(model=self.model,
                                        **config.fedopt.optimizer)
+        if config.fedopt.annealing:
+            self._annealing = True
+            # TODO: generic scheduler construction
+            self.scheduler = torch.optim.lr_scheduler.StepLR(
+                self.optimizer,
+                step_size=config.fedopt.annealing_step_size,
+                gamma=config.fedopt.annealing_gamma)
+        else:
+            self._annealing = False
 
     def aggregate(self, agg_info):
         """
@@ -29,5 +38,7 @@ def aggregate(self, agg_info):
             if key in new_model.keys():
                 p.grad = grads[key]
         self.optimizer.step()
+        if self._annealing:
+            self.scheduler.step()
 
         return self.model.state_dict()
diff --git a/federatedscope/core/configs/cfg_fl_algo.py b/federatedscope/core/configs/cfg_fl_algo.py
index e1df1d335..7a4f9e7bf 100644
--- a/federatedscope/core/configs/cfg_fl_algo.py
+++ b/federatedscope/core/configs/cfg_fl_algo.py
@@ -16,6 +16,9 @@ def extend_fl_algo_cfg(cfg):
         'SGD', description="optimizer type for FedOPT")
     cfg.fedopt.optimizer.lr = Argument(
         0.01, description="learning rate for FedOPT optimizer")
+    cfg.fedopt.annealing = False
+    cfg.fedopt.annealing_step_size = 2000
+    cfg.fedopt.annealing_gamma = 0.5
 
     # ---------------------------------------------------------------------- #
     # fedprox related options, a general fl algorithm
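Note (illustrative, not part of the patch): the prox_term rewrite at the top of this diff is a gradient-flow fix, not a refactor. state_dict() returns detached tensors, so the old penalty evaluated fine but contributed nothing to backward(), and the pull toward current_global_model was silently dropped; iterating named_parameters() keeps the autograd graph intact. A minimal sketch of the fixed behaviour (the Linear model and shifted anchor are stand-ins, not repo code):

    import torch
    import torch.nn as nn

    def prox_term(cur, last):
        loss = .0
        for name, w in cur.named_parameters():
            loss += 0.5 * torch.sum((w - last[name])**2)
        return loss

    model = nn.Linear(4, 2)
    # a shifted copy standing in for current_global_model
    anchor = {k: v + 1.0 for k, v in model.state_dict().items()}

    prox_term(model, anchor).backward()
    print(model.weight.grad.abs().sum() > 0)  # tensor(True): the proximal
                                              # pull now reaches the optimizer

The newly added optimizer.zero_grad() is a companion fix: without it, gradients would accumulate across batches.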
diff --git a/federatedscope/core/configs/cfg_training.py b/federatedscope/core/configs/cfg_training.py
index b38db67a6..c896d29d1 100644
--- a/federatedscope/core/configs/cfg_training.py
+++ b/federatedscope/core/configs/cfg_training.py
@@ -16,8 +16,9 @@ def extend_training_cfg(cfg):
     cfg.trainer.sam.eta = .0
 
     cfg.trainer.local_entropy = CN()
-    cfg.trainer.local_entropy.gamma = 1e-4
-    cfg.trainer.local_entropy.eps = 1e-3
+    cfg.trainer.local_entropy.gamma = 0.03
+    cfg.trainer.local_entropy.inc_factor = 1.0
+    cfg.trainer.local_entropy.eps = 1e-4
     cfg.trainer.local_entropy.alpha = 0.75
 
     # atc (TODO: merge later)
diff --git a/scripts/fedsam_exp_scripts/hpo_for_fedentsgd.sh b/scripts/fedsam_exp_scripts/hpo_for_fedentsgd.sh
deleted file mode 100644
index 0c8dbbcec..000000000
--- a/scripts/fedsam_exp_scripts/hpo_for_fedentsgd.sh
+++ /dev/null
@@ -1,10 +0,0 @@
-set -e
-
-bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 0 1e-4 1e-4 0.1 >/dev/null 2>/dev/null &
-bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 1 1e-4 1e-4 1.0 >/dev/null 2>/dev/null &
-bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 2 1e-4 1e-3 0.1 >/dev/null 2>/dev/null &
-bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 3 1e-4 1e-3 1.0 >/dev/null 2>/dev/null &
-bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 4 1e-3 1e-4 0.1 >/dev/null 2>/dev/null &
-bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 5 1e-3 1e-4 1.0 >/dev/null 2>/dev/null &
-bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 6 1e-3 1e-3 0.1 >/dev/null 2>/dev/null &
-bash scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh 0.05 7 1e-3 1e-3 1.0 >/dev/null 2>/dev/null &
diff --git a/scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh b/scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh
deleted file mode 100644
index 7498a3097..000000000
--- a/scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh
+++ /dev/null
@@ -1,18 +0,0 @@
-set -e
-
-lda_alpha=$1
-cudaid=$2
-gamma=$3
-eps=$4
-lr=$5
-
-echo $lda_alpha
-echo $cudaid
-echo $gamma
-echo $eps
-echo $lr
-
-for (( i=0; i<5; i++ ))
-do
-    python federatedscope/main.py --cfg scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml seed $i device $cudaid data.splitter_args "[{'alpha': ${lda_alpha}}]" trainer.local_entropy.gamma $gamma fedopt.optimizer.lr $gamma trainer.local_entropy.eps $eps train.optimizer.lr $lr expname fedentsgd_${lda_alpha}_${gamma}_${eps}_${lr}_${i}
-done
diff --git a/scripts/fedsam_exp_scripts/run_on_cifar10.sh b/scripts/fedsam_exp_scripts/run_on_cifar10.sh
deleted file mode 100644
index 53c2f5d37..000000000
--- a/scripts/fedsam_exp_scripts/run_on_cifar10.sh
+++ /dev/null
@@ -1,10 +0,0 @@
-set -e
-
-algo=$1
-alpha=$2
-cudaid=$3
-
-for (( i=0; i<5; i++ ))
-do
-    python federatedscope/main.py --cfg scripts/fedsam_exp_scripts/${algo}_on_cifar10.yaml seed $i device $cudaid data.splitter_args "[{'alpha': ${alpha}}]" expname ${algo}_${alpha}_${i}
-done
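Note (illustrative, not part of the patch): the new inc_factor knob replaces the 1.001 that was hard-coded in run_epoch. Its semantics, inferred from the trainer change above: gamma seeds the proximal coupling, and inc_factor compounds it once per local epoch, so the default of 1.0 reproduces a constant coupling. A sketch of the schedule:

    gamma, inc_factor, epochs = 0.03, 1.0001, 5  # values from the yaml below

    thermal = gamma
    for _ in range(epochs):
        # ... one local epoch penalized by thermal * prox_term(...) ...
        thermal *= inc_factor
    print(thermal)  # 0.03 * 1.0001**5 ~= 0.030015

Because _thermal is now initialized once in __init__ and multiplied at the end of each run_epoch, the schedule also persists across federated rounds instead of resetting to gamma at every epoch as before.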
diff --git a/scripts/fedsam_exp_scripts/fedasam_on_cifar10.yaml b/scripts/wide_valley_exp_scripts/fedasam_on_cifar10.yaml
similarity index 96%
rename from scripts/fedsam_exp_scripts/fedasam_on_cifar10.yaml
rename to scripts/wide_valley_exp_scripts/fedasam_on_cifar10.yaml
index dc94392c3..792fe4a41 100644
--- a/scripts/fedsam_exp_scripts/fedasam_on_cifar10.yaml
+++ b/scripts/wide_valley_exp_scripts/fedasam_on_cifar10.yaml
@@ -42,5 +42,5 @@ train:
 eval:
   freq: 100
   metrics: ['acc', 'correct']
-  best_res_update_round_wise_key: test_loss
+  best_res_update_round_wise_key: test_acc
   count_flops: False
diff --git a/scripts/fedsam_exp_scripts/fedavg_on_cifar10.yaml b/scripts/wide_valley_exp_scripts/fedavg_on_cifar10.yaml
similarity index 95%
rename from scripts/fedsam_exp_scripts/fedavg_on_cifar10.yaml
rename to scripts/wide_valley_exp_scripts/fedavg_on_cifar10.yaml
index 37c51022f..eb5705051 100644
--- a/scripts/fedsam_exp_scripts/fedavg_on_cifar10.yaml
+++ b/scripts/wide_valley_exp_scripts/fedavg_on_cifar10.yaml
@@ -36,5 +36,5 @@ train:
 eval:
   freq: 100
   metrics: ['acc', 'correct']
-  best_res_update_round_wise_key: test_loss
+  best_res_update_round_wise_key: test_acc
   count_flops: False
diff --git a/scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml b/scripts/wide_valley_exp_scripts/fedentsgd_on_cifar10.yaml
similarity index 75%
rename from scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml
rename to scripts/wide_valley_exp_scripts/fedentsgd_on_cifar10.yaml
index 9570bffd4..2d94ef1e8 100644
--- a/scripts/fedsam_exp_scripts/fedentsgd_on_cifar10.yaml
+++ b/scripts/wide_valley_exp_scripts/fedentsgd_on_cifar10.yaml
@@ -1,4 +1,5 @@
 use_gpu: True
+outdir: exp
 device: 0
 early_stop:
   patience: 0
@@ -12,9 +13,10 @@ federate:
 fedopt:
   use: True
   optimizer:
-    lr: 0.0001
+    lr: 1.0
     weight_decay: 0.0
     momentum: 0.0
+  annealing: True
 data:
   root: data/
   type: 'CIFAR10@torchvision'
@@ -36,8 +38,9 @@ criterion:
 trainer:
   type: local_entropy_trainer
   local_entropy:
-    gamma: 0.0001
-    eps: 0.001
+    gamma: 0.03
+    inc_factor: 1.0001
+    eps: 0.0001
     alpha: 0.75
 train:
   batch_or_epoch: 'epoch'
@@ -48,5 +51,14 @@ train:
 eval:
   freq: 100
   metrics: ['acc', 'correct']
-  best_res_update_round_wise_key: test_loss
+  best_res_update_round_wise_key: test_acc
   count_flops: False
+hpo:
+  scheduler: bo_gp
+  num_workers: 0
+  ss: 'scripts/wide_valley_exp_scripts/search_space_for_fedentsgd.yaml'
+  sha:
+    budgets: [10000, 10000]
+    iter: 400
+  metric: server_global_eval.test_acc
+  working_folder: bo_gp_fedentsgd
diff --git a/scripts/fedsam_exp_scripts/fedsam_on_cifar10.yaml b/scripts/wide_valley_exp_scripts/fedsam_on_cifar10.yaml
similarity index 96%
rename from scripts/fedsam_exp_scripts/fedsam_on_cifar10.yaml
rename to scripts/wide_valley_exp_scripts/fedsam_on_cifar10.yaml
index f4975ffac..d31eb7ff4 100644
--- a/scripts/fedsam_exp_scripts/fedsam_on_cifar10.yaml
+++ b/scripts/wide_valley_exp_scripts/fedsam_on_cifar10.yaml
@@ -42,5 +42,5 @@ train:
 eval:
   freq: 100
   metrics: ['acc', 'correct']
-  best_res_update_round_wise_key: test_loss
+  best_res_update_round_wise_key: test_acc
   count_flops: False
diff --git a/scripts/wide_valley_exp_scripts/hpo_for_fedentsgd.sh b/scripts/wide_valley_exp_scripts/hpo_for_fedentsgd.sh
new file mode 100644
index 000000000..a1f7e9293
--- /dev/null
+++ b/scripts/wide_valley_exp_scripts/hpo_for_fedentsgd.sh
@@ -0,0 +1,9 @@
+set -e
+
+alpha=$1
+device=$2
+
+echo $alpha
+echo $device
+
+CUDA_VISIBLE_DEVICES="${device}" python federatedscope/hpo.py --cfg scripts/wide_valley_exp_scripts/fedentsgd_on_cifar10.yaml hpo.working_folder bo_gp_fedentsgd_${device} outdir bo_gp_fedentsgd_${device} >/dev/null 2>/dev/null
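Note (illustrative, not part of the patch): fedentsgd_on_cifar10.yaml above turns fedopt.annealing on with a server learning rate of 1.0. Given the StepLR wiring in FedOptAggregator and the defaults from cfg_fl_algo (step_size 2000, gamma 0.5), the server lr follows lr_r = lr_0 * 0.5 ** (r // 2000) across aggregation rounds. A runnable sketch of that schedule, with a dummy parameter standing in for the global model:

    import torch

    w = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([w], lr=1.0)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=2000, gamma=0.5)

    for fl_round in range(4000):
        opt.step()    # stands in for FedOptAggregator.aggregate()
        sched.step()  # one scheduler tick per round, as in aggregate()
    print(opt.param_groups[0]['lr'])  # 0.25: halved at rounds 2000 and 4000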
diff --git a/scripts/wide_valley_exp_scripts/run_fedentsgd_on_cifar10.sh b/scripts/wide_valley_exp_scripts/run_fedentsgd_on_cifar10.sh
new file mode 100644
index 000000000..906054664
--- /dev/null
+++ b/scripts/wide_valley_exp_scripts/run_fedentsgd_on_cifar10.sh
@@ -0,0 +1,23 @@
+set -e
+
+lda_alpha=$1
+cudaid=$2
+gamma=$3
+lr=$4
+eps=$5
+alpha=$6
+annealing=$7
+
+
+echo $lda_alpha
+echo $cudaid
+echo $gamma
+echo $lr
+echo $eps
+echo $alpha
+echo $annealing
+
+for (( i=0; i<5; i++ ))
+do
+    CUDA_VISIBLE_DEVICES="${cudaid}" python federatedscope/main.py --cfg scripts/wide_valley_exp_scripts/fedentsgd_on_cifar10.yaml seed $i data.splitter_args "[{'alpha': ${lda_alpha}}]" trainer.local_entropy.gamma $gamma fedopt.optimizer.lr 1.0 fedopt.annealing $annealing trainer.local_entropy.eps $eps trainer.local_entropy.alpha $alpha train.optimizer.lr $lr expname fedentsgd_${lda_alpha}_${gamma}_${eps}_${annealing}_${i}
+done
diff --git a/scripts/wide_valley_exp_scripts/run_on_cifar10.sh b/scripts/wide_valley_exp_scripts/run_on_cifar10.sh
new file mode 100644
index 000000000..efbb8c1c7
--- /dev/null
+++ b/scripts/wide_valley_exp_scripts/run_on_cifar10.sh
@@ -0,0 +1,9 @@
+set -e
+
+algo=$1
+alpha=$2
+
+for (( i=0; i<5; i++ ))
+do
+    CUDA_VISIBLE_DEVICES="${i}" python federatedscope/main.py --cfg scripts/wide_valley_exp_scripts/${algo}_on_cifar10.yaml seed $i data.splitter_args "[{'alpha': ${alpha}}]" expname ${algo}_${alpha}_${i} &
+done
diff --git a/scripts/wide_valley_exp_scripts/search_space_for_fedentsgd.yaml b/scripts/wide_valley_exp_scripts/search_space_for_fedentsgd.yaml
new file mode 100644
index 000000000..5b5a17849
--- /dev/null
+++ b/scripts/wide_valley_exp_scripts/search_space_for_fedentsgd.yaml
@@ -0,0 +1,24 @@
+trainer.local_entropy.gamma:
+  type: float
+  lower: 0.0
+  upper: 0.1
+trainer.local_entropy.inc_factor:
+  type: cate
+  choices: [0.0, 1.0001, 1.001]
+trainer.local_entropy.eps:
+  type: float
+  lower: 1e-5
+  upper: 1e-2
+  log: True
+trainer.local_entropy.alpha:
+  type: float
+  lower: 0.75
+  upper: 1.0
+fedopt.optimizer.lr:
+  type: float
+  lower: 0.01
+  upper: 10.0
+  log: True
+fedopt.annealing:
+  type: cate
+  choices: [False, True]
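Note (illustrative, not part of the patch): the eps range in the search space above scales the Langevin noise injected after each descent step in run_epoch. A sketch of that noise step; the add_noise helper below is an assumption inferred from the call site, not the repo's implementation:

    import math
    import torch

    def add_noise(model, std):
        # SGLD-style perturbation: i.i.d. Gaussian noise per parameter
        with torch.no_grad():
            for p in model.parameters():
                p.add_(std * torch.randn_like(p))

    model = torch.nn.Linear(4, 2)
    lr, eps = 0.1, 1e-4
    add_noise(model, math.sqrt(lr) * eps)

Scaling the standard deviation by sqrt(lr), as the call site does, matches the usual SGLD discretization where noise grows with the square root of the step size, so tuning eps effectively tunes the temperature of the local-entropy smoothing.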