Commit

update zero1

hgt312 committed Nov 20, 2023
1 parent 35be453 commit 3f3930e
Showing 1 changed file with 25 additions and 4 deletions.
torch_xla/distributed/zero_redundancy_optimizer.py (25 additions, 4 deletions)
@@ -79,6 +79,8 @@ def __init__(
     self.pin_layout = pin_layout
     self.coalesce_cc = coalesce_cc
 
+    self._grad_norm = None
+
     self.inited = False
     if not lazy_init:
       self.init_zero()
@@ -104,6 +106,10 @@ def init_zero(self):
     self._sync_param_groups(self.param_groups, self.base_optimizer.param_groups)
     self.inited = True
 
+  @property
+  def grad_norm(self):
+    return self._grad_norm
+
   @property
   def sharding_groups(self):
     return self._sharding_groups
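
The two hunks above cache the gradient norm computed during clipping and expose it through a read-only grad_norm property. Below is a minimal usage sketch, assuming a single-device XLA setup; the constructor arguments shown (optimizer class, lr, grad_clipping, max_norm) are illustrative and should be checked against the class docstring rather than taken from this commit.

# Sketch: read the cached gradient norm after a training step. grad_norm is
# populated by the internal clipping pass, so grad_clipping must be enabled.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = torch.nn.Linear(16, 16).to(device)
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    torch.optim.SGD,
    lr=1e-2,
    grad_clipping=True,  # clipping is what fills the cached norm
    max_norm=1.0)

loss = model(torch.randn(4, 16, device=device)).sum()
loss.backward()
optimizer.step()
xm.mark_step()
print(optimizer.grad_norm)  # XLA tensor: the pre-clip global gradient norm
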
@@ -160,12 +166,17 @@ def _shard_parameters(self):
     """
     Shard all parameters.
     """
+    self.device = None
     all_params = []
     for param_group in self.param_groups:
       for param in param_group['params']:
         all_params.append(param)
+        if self.device is None:
+          self.device = param.device
+        else:
+          assert self.device == param.device, "Params should on the same device."
+    assert self.device.type == 'xla'
 
-    self.device = all_params[0].device
     xm.unlazy(all_params)
 
     sharded_params_groups = []
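
The hunk above replaces the single all_params[0].device lookup with a per-parameter check, so a mixed-device parameter list now fails loudly before sharding. The same check, written as a standalone helper purely for illustration (the helper name is hypothetical, not part of the module):

# Hypothetical standalone version of the device-consistency check added above.
def check_single_xla_device(param_groups):
  device = None
  for param_group in param_groups:
    for param in param_group['params']:
      if device is None:
        device = param.device
      else:
        assert device == param.device, "Params should be on the same device."
  assert device is not None and device.type == 'xla'
  return device
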
@@ -229,11 +240,11 @@ def _clip_grad_norm(
     """
     max_norm = float(max_norm)
     norm_type = float(norm_type)
-    total_norm = self._calc_grad_norm(norm_type)
+    self._grad_norm = self._calc_grad_norm(norm_type)
 
     clip_coeff = torch.tensor(
         max_norm, device=self.device) / (
-            self._grad_norm + 1e-6)
+            self._grad_norm + 1e-6)
     clip_value = torch.where(clip_coeff < 1, clip_coeff,
                              torch.tensor(1., device=self.device))
     for param_group in self.base_optimizer.param_groups:
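
For reference, the clipping arithmetic in this hunk is the standard global-norm rule: scale every gradient by min(max_norm / (total_norm + 1e-6), 1). A small CPU-only sketch of the same formula:

# CPU-only illustration of the clip_coeff / clip_value computation above.
import torch

max_norm = 1.0
grads = [torch.randn(10), torch.randn(10)]
total_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2)
clip_coeff = torch.tensor(max_norm) / (total_norm + 1e-6)
clip_value = torch.where(clip_coeff < 1, clip_coeff, torch.tensor(1.))
for g in grads:
  g.mul_(clip_value)  # scaling is effectively a no-op when the norm is already within max_norm
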
@@ -283,7 +294,7 @@ def step(self, closure=None, **kwargs):
             shard.grad = grad_shard
 
     if self.coalesce_cc:
-      grad_shard = xm.reduce_scatter(
+      grad_shards = xm.reduce_scatter(
           xm.REDUCE_SUM,
           padded_grads,
           scale=1.0 / self.local_world_size,
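
The rename from grad_shard to grad_shards reflects the coalesced path: the call already passes the list padded_grads, and the list form of xm.reduce_scatter performs one fused collective and returns a list of shards. A rough sketch under that assumption (requires a running multi-process XLA job; the wrapper function is illustrative):

# Rough sketch of the coalesced path: list in, list out.
import torch_xla.core.xla_model as xm


def shard_grads_coalesced(padded_grads, groups=None, pin_layout=True):
  world_size = xm.xrt_world_size()
  return xm.reduce_scatter(
      xm.REDUCE_SUM,
      padded_grads,  # list[Tensor] -> one fused collective -> list[Tensor]
      scale=1.0 / world_size,
      scatter_dim=0,
      shard_count=world_size,
      pin_layout=pin_layout,
      groups=groups)
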
@@ -362,6 +373,7 @@ def state_dict(self):
     state_dict = super().state_dict()
     base_state = self.base_optimizer.state_dict()['state']
     state_dict['base_state'] = base_state
+    state_dict['shape_info'] = self.get_shape_info()
     return state_dict
 
   def load_state_dict(self, state_dict):
@@ -375,3 +387,12 @@ def load_state_dict(self, state_dict):
     tmp = self.base_optimizer.state_dict()
     tmp['state'] = base_state
     self.base_optimizer.load_state_dict(tmp)
+
+  def get_shape_info(self):
+    shape_info = {}
+    idx = 0
+    for param_group in self.param_groups:
+      for param in param_group['params']:
+        shape_info[idx] = param.shape
+        idx += 1
+    return shape_info
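
get_shape_info records the full (unsharded) shape of every parameter by flat index, and state_dict now embeds that map under 'shape_info'. One plausible consumer, sketched below with a hypothetical checkpoint path and helper, is a consolidation script that unpads gathered shards back to their original shapes; the commit itself does not ship such a script.

# Hypothetical use of the saved 'shape_info': rebuild a full parameter (or its
# optimizer state) from per-rank shards. Path and gathering are illustrative.
import torch

ckpt = torch.load('zero1_rank0.ckpt')  # hypothetical checkpoint file
shape_info = ckpt['shape_info']        # {flat param index: full torch.Size}


def unshard(shards, full_shape):
  # Shards were produced by reduce-scatter over dim 0 after padding, so
  # concatenate, drop the padding, and restore the recorded shape.
  flat = torch.cat([s.flatten() for s in shards])
  return flat[:full_shape.numel()].view(full_shape)
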
