From 3f3930ed40a4bfad2492aeb5a9da34656d165728 Mon Sep 17 00:00:00 2001
From: guangtai
Date: Mon, 20 Nov 2023 13:18:31 -0800
Subject: [PATCH] update zero1

Expose the gradient norm computed during clipping through a new `grad_norm`
property, require all parameters to live on the same XLA device when
sharding, fix the coalesced reduce_scatter result being assigned to the
wrong variable in `step()`, and record per-parameter shapes in
`state_dict()` via a new `get_shape_info()` helper.
---
 .../distributed/zero_redundancy_optimizer.py | 29 ++++++++++++++++---
 1 file changed, 25 insertions(+), 4 deletions(-)

diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py
index e55f1a84039..b3ae95a4132 100644
--- a/torch_xla/distributed/zero_redundancy_optimizer.py
+++ b/torch_xla/distributed/zero_redundancy_optimizer.py
@@ -79,6 +79,8 @@ def __init__(
     self.pin_layout = pin_layout
     self.coalesce_cc = coalesce_cc
 
+    self._grad_norm = None
+
     self.inited = False
     if not lazy_init:
       self.init_zero()
@@ -104,6 +106,10 @@ def init_zero(self):
     self._sync_param_groups(self.param_groups, self.base_optimizer.param_groups)
     self.inited = True
 
+  @property
+  def grad_norm(self):
+    return self._grad_norm
+
   @property
   def sharding_groups(self):
     return self._sharding_groups
@@ -160,12 +166,17 @@ def _shard_parameters(self):
     """
     Shard all parameters.
     """
+    self.device = None
     all_params = []
     for param_group in self.param_groups:
       for param in param_group['params']:
         all_params.append(param)
+        if self.device is None:
+          self.device = param.device
+        else:
+          assert self.device == param.device, "Params should be on the same device."
+    assert self.device.type == 'xla'
 
-    self.device = all_params[0].device
     xm.unlazy(all_params)
 
     sharded_params_groups = []
@@ -229,11 +240,11 @@ def _clip_grad_norm(
     """
     max_norm = float(max_norm)
     norm_type = float(norm_type)
-    total_norm = self._calc_grad_norm(norm_type)
+    self._grad_norm = self._calc_grad_norm(norm_type)
 
     clip_coeff = torch.tensor(
         max_norm, device=self.device) / (
-            total_norm + 1e-6)
+            self._grad_norm + 1e-6)
     clip_value = torch.where(clip_coeff < 1, clip_coeff,
                              torch.tensor(1., device=self.device))
     for param_group in self.base_optimizer.param_groups:
@@ -283,7 +294,7 @@ def step(self, closure=None, **kwargs):
             shard.grad = grad_shard
 
     if self.coalesce_cc:
-      grad_shard = xm.reduce_scatter(
+      grad_shards = xm.reduce_scatter(
          xm.REDUCE_SUM,
          padded_grads,
          scale=1.0 / self.local_world_size,
@@ -362,6 +373,7 @@ def state_dict(self):
     state_dict = super().state_dict()
     base_state = self.base_optimizer.state_dict()['state']
     state_dict['base_state'] = base_state
+    state_dict['shape_info'] = self.get_shape_info()
     return state_dict
 
   def load_state_dict(self, state_dict):
@@ -375,3 +387,12 @@ def load_state_dict(self, state_dict):
     tmp = self.base_optimizer.state_dict()
     tmp['state'] = base_state
     self.base_optimizer.load_state_dict(tmp)
+
+  def get_shape_info(self):
+    shape_info = {}
+    idx = 0
+    for param_group in self.param_groups:
+      for param in param_group['params']:
+        shape_info[idx] = param.shape
+        idx += 1
+    return shape_info
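
A minimal usage sketch (not part of the patch) of how the new `grad_norm`
property and the `shape_info` entry in `state_dict()` could be read back
after a step. The `grad_clipping`/`max_norm` constructor keywords, the toy
model, and the hyperparameters below are illustrative assumptions, not
taken from this diff:

    # Sketch: consume the grad norm and shape info exposed by this patch.
    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

    device = xm.xla_device()
    model = torch.nn.Linear(16, 4).to(device)

    # Gradient clipping must be enabled for _clip_grad_norm() to run and
    # record the norm (keyword names assumed here).
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(),
        torch.optim.SGD,
        lr=0.01,
        grad_clipping=True,
        max_norm=1.0,
    )

    loss = model(torch.randn(8, 16, device=device)).sum()
    loss.backward()
    optimizer.step()
    xm.mark_step()

    # New in this patch: the global gradient norm computed during clipping
    # is retained on the optimizer and exposed read-only as `grad_norm`.
    print("grad norm:", optimizer.grad_norm)

    # New in this patch: checkpoints carry per-parameter shapes keyed by
    # the parameter's flat index, e.g. for reshaping shards later.
    print("shape info:", optimizer.state_dict()['shape_info'])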