Optimizing eager aggregation - reduce deepcopy ops (#588)
ShixiongQi authored Jun 3, 2024
1 parent 334de7b commit 5888067
Showing 3 changed files with 12 additions and 10 deletions.
4 changes: 2 additions & 2 deletions docs/prerequisites.md
@@ -86,7 +86,7 @@ Python 3.9.6
The target runtime environment is Linux. Development has mainly been conducted in a macOS environment. This section describes how to set up a development environment on macOS (Intel chip) and Ubuntu.

The following tools and packages are needed at a minimum:
-- go 1.18+
+- go 1.22+
- golangci-lint

After installing the above packages, you can try a development setup called `fiab`, an acronym for flame-in-a-box, which is found [here](system/fiab.md).
@@ -112,7 +112,7 @@ sudo apt update

Install golang and golangci-lint.
```bash
-golang_file=go1.18.6.linux-amd64.tar.gz
+golang_file=go1.22.3.linux-amd64.tar.gz
curl -LO https://go.dev/dl/$golang_file && tar -C $HOME -xzf $golang_file
echo "PATH=\"\$HOME/go/bin:\$PATH\"" >> $HOME/.bashrc
source $HOME/.bashrc
@@ -37,6 +37,7 @@ def _aggregate_weights(self, tag: str) -> None:
return

total = 0
+base_weights = deepcopy(self.weights)

for msg, metadata in channel.recv_fifo(channel.ends()):
end, _ = metadata
@@ -60,18 +61,18 @@ def _aggregate_weights(self, tag: str) -> None:

# optimizer conducts optimization (in this case, aggregation)
global_weights = self.optimizer.do(
-deepcopy(self.weights), self.cache, total=total
+base_weights, self.cache, total=total
)
if global_weights is None:
logger.debug("failed model aggregation")
time.sleep(1)
return

-# save global weights before updating it
-self.prev_weights = self.weights
+# save global weights before updating it
+self.prev_weights = self.weights

-# set global weights
-self.weights = global_weights
+# set global weights
+self.weights = global_weights

logger.debug(f"received {len(self.cache)} trainer updates in cache")

@@ -39,6 +39,7 @@ def _aggregate_weights(self, tag: str) -> None:
return

total = 0
+base_weights = deepcopy(self.weights)

for msg, metadata in channel.recv_fifo(channel.ends()):
end, timestamp = metadata
@@ -72,7 +73,7 @@ def _aggregate_weights(self, tag: str) -> None:

# optimizer conducts optimization (in this case, aggregation)
global_weights = self.optimizer.do(
-deepcopy(self.weights),
+base_weights,
self.cache,
total=total,
num_trainers=len(channel.ends()),
@@ -82,8 +83,8 @@ def _aggregate_weights(self, tag: str) -> None:
time.sleep(1)
return

-# set global weights
-self.weights = global_weights
+# set global weights
+self.weights = global_weights

# update model with global weights
self._update_model()
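
The change is the same in both aggregator files: the current global weights are deep-copied once, before the receive loop (`base_weights = deepcopy(self.weights)`), and that copy is handed to `self.optimizer.do(...)` instead of deep-copying at the call site. Deep-copying a model-sized weight set costs time and memory proportional to the number of parameters, and the commit title indicates the copy was previously being taken more than once per aggregation round on the eager path; hoisting it caps the cost at one copy per round no matter how many trainer updates arrive. The snippet below is a rough, self-contained illustration of that difference — plain dictionaries stand in for model weights, and the sizes and update count are invented, so only the relative timings are meaningful:

```python
import time
from copy import deepcopy

# a dict standing in for model weights (~200k parameters)
weights = {f"param_{i}": 0.0 for i in range(200_000)}
num_updates = 5  # pretend five trainer updates arrive in one round

# old pattern: deep-copy the global weights every time aggregation runs
start = time.perf_counter()
for _ in range(num_updates):
    working_copy = deepcopy(weights)
per_update = time.perf_counter() - start

# new pattern: deep-copy once per round and reuse that copy
start = time.perf_counter()
base_weights = deepcopy(weights)
for _ in range(num_updates):
    working_copy = base_weights  # no further copies
hoisted = time.perf_counter() - start

print(f"deepcopy per update: {per_update:.3f}s | single deepcopy: {hoisted:.3f}s")
```

The first loop performs `num_updates` full copies while the second performs exactly one, so the ratio between the two timings should land close to `num_updates` — the same factor the hoisted `base_weights` removes from the aggregation path.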
