Add FedSGLD Exp #520

Merged · 6 commits · Feb 15, 2023
15 changes: 8 additions & 7 deletions federatedscope/contrib/trainer/local_entropy.py
@@ -17,8 +17,8 @@ def copy_params(src):

 def prox_term(cur, last):
     loss = .0
-    for name, tensor in last.items():
-        loss += 0.5 * torch.sum((cur[name] - tensor)**2)
+    for name, w in cur.named_parameters():
+        loss += 0.5 * torch.sum((w - last[name])**2)
     return loss
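This is the substantive fix: `Module.state_dict()` returns detached tensors by default, so the old prox term contributed nothing to the gradient of the descent step. A minimal self-contained sketch (toy model and names assumed, not the PR's code) confirming the rewritten term backpropagates into the live weights:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
# stand-in for the received global weights; detached, like a state_dict copy
last = {k: torch.zeros_like(v) for k, v in model.state_dict().items()}

def prox_term(cur, last):
    # iterate live parameters so the loss stays on the autograd graph
    loss = 0.0
    for name, w in cur.named_parameters():
        loss += 0.5 * torch.sum((w - last[name]) ** 2)
    return loss

prox_term(model, last).backward()
print(model.weight.grad)  # non-zero: the prox term now actually regularizes
```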


@@ -46,6 +46,7 @@ def __init__(self, model, data, device, **kwargs):
         self.config = kwargs['config']
         self.optim_config = self.config.train.optimizer
         self.local_entropy_config = self.config.trainer.local_entropy
+        self._thermal = self.local_entropy_config.gamma

     def train(self):
         # Criterion & Optimizer
@@ -72,21 +73,21 @@ def train(self):
     def run_epoch(self, optimizer, criterion, current_global_model, mu):
         running_loss = 0.0
         num_samples = 0
-        thermal = self.local_entropy_config.gamma
         # for inputs, targets in self.trainloader:
         for inputs, targets in self.data['train']:
             inputs = inputs.to(self.device)
             targets = targets.to(self.device)

             # Descent Step
             optimizer.zero_grad()
             outputs = self.model(inputs)
             ce_loss = criterion(outputs, targets)
-            loss = ce_loss + thermal * prox_term(self.model.state_dict(),
-                                                 current_global_model)
+            loss = ce_loss + self._thermal * prox_term(self.model,
+                                                       current_global_model)
             loss.backward()
             optimizer.step()

-            # add noise for langevine dynamics
+            # add noise for langevin dynamics
             add_noise(
                 self.model,
                 math.sqrt(self.optim_config.lr) *
@@ -100,7 +101,7 @@ def run_epoch(self, optimizer, criterion, current_global_model, mu):
             running_loss += targets.shape[0] * ce_loss.item()

             num_samples += targets.shape[0]
-            thermal *= 1.001
+            self._thermal *= self.local_entropy_config.inc_factor

         return num_samples, running_loss
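The `add_noise` call above is what turns the inner loop into stochastic gradient Langevin dynamics rather than plain SGD; its body and the full noise-scale expression are collapsed in this diff. A hedged sketch of a typical implementation (an assumption, not the PR's actual helper):

```python
import torch

def add_noise(model, scale):
    """Add i.i.d. Gaussian noise with the given std to every parameter, in place."""
    with torch.no_grad():
        for p in model.parameters():
            p.add_(scale * torch.randn_like(p))

# Usage matching the call shape above, assuming eps sets the noise magnitude:
# add_noise(model, math.sqrt(lr) * eps)
```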

11 changes: 11 additions & 0 deletions federatedscope/core/aggregators/fedopt_aggregator.py
@@ -13,6 +13,15 @@ def __init__(self, config, model, device='cpu'):
         super(FedOptAggregator, self).__init__(model, device, config)
         self.optimizer = get_optimizer(model=self.model,
                                        **config.fedopt.optimizer)
+        if config.fedopt.annealing:
+            self._annealing = True
+            # TODO: generic scheduler construction
+            self.scheduler = torch.optim.lr_scheduler.StepLR(
+                self.optimizer,
+                step_size=config.fedopt.annealing_step_size,
+                gamma=config.fedopt.annealing_gamma)
+        else:
+            self._annealing = False

     def aggregate(self, agg_info):
         """
@@ -29,5 +38,7 @@ def aggregate(self, agg_info):
             if key in new_model.keys():
                 p.grad = grads[key]
         self.optimizer.step()
+        if self._annealing:
+            self.scheduler.step()

         return self.model.state_dict()
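The scheduler steps once per aggregation round, so the server-side learning rate decays on a round schedule. A small self-contained sketch (stock PyTorch, values hardcoded to the new defaults) of how that behaves:

```python
import torch

# toy parameter standing in for the server model
params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=1.0)
# defaults introduced below: halve the lr every 2000 rounds
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=2000, gamma=0.5)

for rnd in range(6000):
    opt.step()    # one aggregation round
    sched.step()  # anneal the server lr
print(opt.param_groups[0]['lr'])  # 0.125 after 6000 rounds
```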
3 changes: 3 additions & 0 deletions federatedscope/core/configs/cfg_fl_algo.py
@@ -16,6 +16,9 @@ def extend_fl_algo_cfg(cfg):
         'SGD', description="optimizer type for FedOPT")
     cfg.fedopt.optimizer.lr = Argument(
         0.01, description="learning rate for FedOPT optimizer")
+    cfg.fedopt.annealing = False
+    cfg.fedopt.annealing_step_size = 2000
+    cfg.fedopt.annealing_gamma = 0.5

     # ---------------------------------------------------------------------- #
     # fedprox related options, a general fl algorithm
5 changes: 3 additions & 2 deletions federatedscope/core/configs/cfg_training.py
@@ -16,8 +16,9 @@ def extend_training_cfg(cfg):
     cfg.trainer.sam.eta = .0

     cfg.trainer.local_entropy = CN()
-    cfg.trainer.local_entropy.gamma = 1e-4
-    cfg.trainer.local_entropy.eps = 1e-3
+    cfg.trainer.local_entropy.gamma = 0.03
+    cfg.trainer.local_entropy.inc_factor = 1.0
+    cfg.trainer.local_entropy.eps = 1e-4
     cfg.trainer.local_entropy.alpha = 0.75

     # atc (TODO: merge later)
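For intuition about the `local_entropy` defaults above: `gamma` is the starting value of the thermal coefficient and `inc_factor` its per-update multiplier, so the proximal weight grows geometrically during training. A quick back-of-the-envelope sketch (using the values from the CIFAR-10 config further down, as an illustration):

```python
# With gamma=0.03 and inc_factor=1.0001, the thermal coefficient roughly
# e-folds over 10000 updates: 0.03 * 1.0001**10000 ≈ 0.0816.
gamma, inc_factor = 0.03, 1.0001
thermal = gamma
for step in range(10000):
    thermal *= inc_factor
print(round(thermal, 4))  # ~0.0816
```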
10 changes: 0 additions & 10 deletions scripts/fedsam_exp_scripts/hpo_for_fedentsgd.sh

This file was deleted.

18 changes: 0 additions & 18 deletions scripts/fedsam_exp_scripts/run_fedentsgd_on_cifar10.sh

This file was deleted.

10 changes: 0 additions & 10 deletions scripts/fedsam_exp_scripts/run_on_cifar10.sh

This file was deleted.

@@ -42,5 +42,5 @@ train:
 eval:
   freq: 100
   metrics: ['acc', 'correct']
-  best_res_update_round_wise_key: test_loss
+  best_res_update_round_wise_key: test_acc
   count_flops: False
@@ -36,5 +36,5 @@ train:
 eval:
   freq: 100
   metrics: ['acc', 'correct']
-  best_res_update_round_wise_key: test_loss
+  best_res_update_round_wise_key: test_acc
   count_flops: False
@@ -1,4 +1,5 @@
 use_gpu: True
+outdir: exp
 device: 0
 early_stop:
   patience: 0
@@ -12,9 +13,10 @@ federate:
 fedopt:
   use: True
   optimizer:
-    lr: 0.0001
+    lr: 1.0
     weight_decay: 0.0
     momentum: 0.0
+  annealing: True
 data:
   root: data/
   type: 'CIFAR10@torchvision'
@@ -36,8 +38,9 @@ criterion:
 trainer:
   type: local_entropy_trainer
   local_entropy:
-    gamma: 0.0001
-    eps: 0.001
+    gamma: 0.03
+    inc_factor: 1.0001
+    eps: 0.0001
     alpha: 0.75
 train:
   batch_or_epoch: 'epoch'
@@ -48,5 +51,14 @@ train:
 eval:
   freq: 100
   metrics: ['acc', 'correct']
-  best_res_update_round_wise_key: test_loss
+  best_res_update_round_wise_key: test_acc
   count_flops: False
+hpo:
+  scheduler: bo_gp
+  num_workers: 0
+  ss: 'scripts/wide_valley_exp_scripts/search_space_for_fedentsgd.yaml'
+  sha:
+    budgets: [10000, 10000]
+    iter: 400
+  metric: server_global_eval.test_acc
+  working_folder: bo_gp_fedentsgd
@@ -42,5 +42,5 @@ train:
 eval:
   freq: 100
   metrics: ['acc', 'correct']
-  best_res_update_round_wise_key: test_loss
+  best_res_update_round_wise_key: test_acc
   count_flops: False
9 changes: 9 additions & 0 deletions scripts/wide_valley_exp_scripts/hpo_for_fedentsgd.sh
@@ -0,0 +1,9 @@
+set -e
+
+alpha=$1
+device=$2
+
+echo $alpha
+echo $device
+
+CUDA_VISIBLE_DEVICES="${device}" python federatedscope/hpo.py --cfg scripts/wide_valley_exp_scripts/fedentsgd_on_cifar10.yaml hpo.working_folder bo_gp_fedentsgd_${device} outdir bo_gp_fedentsgd_${device} >/dev/null 2>/dev/null
23 changes: 23 additions & 0 deletions scripts/wide_valley_exp_scripts/run_fedentsgd_on_cifar10.sh
@@ -0,0 +1,23 @@
+set -e
+
+lda_alpha=$1
+cudaid=$2
+gamma=$3
+lr=$4
+eps=$5
+alpha=$6
+annealing=$7
+
+
+echo $lda_alpha
+echo $cudaid
+echo $gamma
+echo $lr
+echo $eps
+echo $alpha
+echo $annealing
+
+for (( i=0; i<5; i++ ))
+do
+    CUDA_VISIBLE_DEVICES="${cudaid}" python federatedscope/main.py --cfg scripts/wide_valley_exp_scripts/fedentsgd_on_cifar10.yaml seed $i data.splitter_args "[{'alpha': ${lda_alpha}}]" trainer.local_entropy.gamma $gamma fedopt.optimizer.lr 1.0 fedopt.annealing $annealing trainer.local_entropy.eps $eps trainer.local_entropy.alpha $alpha train.optimizer.lr $lr expname fedentsgd_${lda_alpha}_${gamma}_${eps}_${annealing}_${i}
Review comment (Collaborator): need & ?
Reply (Author): no need as they would run on the same device
+done
9 changes: 9 additions & 0 deletions scripts/wide_valley_exp_scripts/run_on_cifar10.sh
@@ -0,0 +1,9 @@
+set -e
+
+algo=$1
+alpha=$2
+
+for (( i=0; i<5; i++ ))
+do
+    CUDA_VISIBLE_DEVICES="${i}" python federatedscope/main.py --cfg scripts/wide_valley_exp_scripts/${algo}_on_cifar10.yaml seed $i data.splitter_args "[{'alpha': ${alpha}}]" expname ${algo}_${alpha}_${i} &
+done
24 changes: 24 additions & 0 deletions scripts/wide_valley_exp_scripts/search_space_for_fedentsgd.yaml
@@ -0,0 +1,24 @@
+trainer.local_entropy.gamma:
+  type: float
+  lower: 0.0
+  upper: 0.1
+trainer.local_entropy.inc_factor:
+  type: cate
+  choices: [0.0, 1.0001, 1.001]
+trainer.local_entropy.eps:
+  type: float
+  lower: 1e-5
+  upper: 1e-2
+  log: True
+trainer.local_entropy.alpha:
+  type: float
+  lower: 0.75
+  upper: 1.0
+fedopt.optimizer.lr:
+  type: float
+  lower: 0.01
+  upper: 10.0
+  log: True
+fedopt.annealing:
+  type: cate
+  choices: [False, True]
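For orientation, a hedged sketch (sampled values are illustrative, not from any actual HPO run) of how one point drawn from this search space maps onto the space-separated `key value` overrides that `federatedscope/main.py` accepts, in the same style as the run scripts above:

```python
# Hypothetical sampled configuration rendered as CLI overrides for main.py.
sampled = {
    'trainer.local_entropy.gamma': 0.03,
    'trainer.local_entropy.inc_factor': 1.0001,
    'trainer.local_entropy.eps': 1e-4,
    'trainer.local_entropy.alpha': 0.75,
    'fedopt.optimizer.lr': 1.0,
    'fedopt.annealing': True,
}
overrides = ' '.join(f'{k} {v}' for k, v in sampled.items())
print('python federatedscope/main.py --cfg '
      'scripts/wide_valley_exp_scripts/fedentsgd_on_cifar10.yaml ' + overrides)
```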