forked from google/fedjax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fed_avg.py
103 lines (86 loc) · 4.03 KB
/
fed_avg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Basic federated averaging implementation using fedjax.experimental.
This is the basic implementation in that sense that we write out the native
Python for loops over clients by hand. While this is much more straightforward
to implement, it's not the fastest. This is a good starting point for
implementing custom algorithms since it tends to follow pseudocode closely.
Based on the paper:
Communication-Efficient Learning of Deep Networks from Decentralized Data
H. Brendan McMahan, Eider Moore, Daniel Ramage,
Seth Hampson, Blaise Aguera y Arcas. AISTATS 2017.
https://arxiv.org/abs/1602.05629
"""
from typing import Any, Callable, Mapping, Sequence, Tuple
import fedjax
import jax
ClientId = bytes
Grads = fedjax.Params
@fedjax.dataclass
class ServerState:
"""State of server passed between rounds.
Attributes:
params: A pytree representing the server model parameters.
opt_state: A pytree representing the server optimizer state.
"""
params: fedjax.Params
opt_state: fedjax.OptState
def federated_averaging(
grad_fn: Callable[[fedjax.Params, fedjax.BatchExample, fedjax.PRNGKey],
Grads],
client_optimizer: fedjax.Optimizer,
server_optimizer: fedjax.Optimizer,
client_batch_hparams: fedjax.ShuffleRepeatBatchHParams
) -> fedjax.FederatedAlgorithm:
"""Builds the basic implementation of federated averaging."""
def init(params: fedjax.Params) -> ServerState:
opt_state = server_optimizer.init(params)
return ServerState(params, opt_state)
def apply(
server_state: ServerState,
clients: Sequence[Tuple[ClientId, fedjax.ClientDataset, fedjax.PRNGKey]]
) -> Tuple[ServerState, Mapping[ClientId, Any]]:
client_diagnostics = {}
# We use a list here for clarity, but we strongly recommend avoiding loading
# all client outputs into memory since the outputs can be quite large
# depending on the size of the model.
client_delta_params_weights = []
for client_id, client_dataset, client_rng in clients:
delta_params = client_update(server_state.params, client_dataset,
client_rng)
client_delta_params_weights.append((delta_params, len(client_dataset)))
# We record the l2 norm of client updates as an example, but it is not
# required for the algorithm.
client_diagnostics[client_id] = {
'delta_l2_norm': fedjax.tree_util.tree_l2_norm(delta_params)
}
mean_delta_params = fedjax.tree_util.tree_mean(client_delta_params_weights)
server_state = server_update(server_state, mean_delta_params)
return server_state, client_diagnostics
def client_update(server_params, client_dataset, client_rng):
params = server_params
opt_state = client_optimizer.init(params)
for batch in client_dataset.shuffle_repeat_batch(client_batch_hparams):
client_rng, use_rng = jax.random.split(client_rng)
grads = grad_fn(params, batch, use_rng)
opt_state, params = client_optimizer.apply(grads, opt_state, params)
delta_params = jax.tree_util.tree_multimap(lambda a, b: a - b,
server_params, params)
return delta_params
def server_update(server_state, mean_delta_params):
opt_state, params = server_optimizer.apply(mean_delta_params,
server_state.opt_state,
server_state.params)
return ServerState(params, opt_state)
return fedjax.FederatedAlgorithm(init, apply)