Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/state_averaging' into state_aver…
Browse files Browse the repository at this point in the history
…aging
  • Loading branch information
justheuristic committed Nov 15, 2021
2 parents 09c5817 + 0ee16c8 commit cf40b94
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions hivemind/optim/experimental/state_averager.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _local_tensors(self) -> Iterator[torch.Tensor]:
for param_group in self.optimizer.param_groups:
for param in param_group["params"]:
yield self.optimizer.state[param][stats]
yield from iter(self.extra_tensors)
yield from self.extra_tensors

@torch.no_grad()
def _init_averaged_tensors(self) -> Sequence[torch.Tensor]:
Expand Down Expand Up @@ -505,7 +505,7 @@ def load_state_from_peers(self, **kwargs):

metadata, flat_tensors = loaded_state
if (not isinstance(metadata.get("epoch"), int)) or metadata["epoch"] < self.local_epoch:
logger.error("Cowardly refusing to load state from peer: peer's epoch is behind our local epoch.")
logger.warning("Cowardly refusing to load state from peer: peer's epoch is behind our local epoch")
return

loaded_parameters_and_extras = flat_tensors[:num_parameters_and_extras]
Expand All @@ -517,7 +517,7 @@ def load_state_from_peers(self, **kwargs):
try:
load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
except StopIteration:
logger.error("Failed to load state from peer, received inconsistent number of optimizer statistics.")
logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
return

with torch.no_grad():
Expand Down

0 comments on commit cf40b94

Please sign in to comment.