Allocating CPU memory directly on CPU without transfering them from GPU (#360)

* Allocating CPU memory directly on CPU without transfering them from GPU

* formatting fixes
samyam authored Sep 4, 2020
1 parent 077cfd4 commit 7be128a
Showing 4 changed files with 19 additions and 15 deletions.
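For context, the central change in deepspeed/runtime/zero/stage2.py reorders how buffers destined for the CPU are built: tensors are moved to the target device first (or allocated there directly) and only then cloned, upcast, or pinned, so no temporary copy is materialized in GPU memory along the way. A minimal sketch of the before/after pattern for the fp32 master partition, with illustrative names and sizes not taken from the repository:

import torch

device = 'cpu'  # stands in for self.device when CPU offload is enabled
gpu = 'cuda' if torch.cuda.is_available() else 'cpu'
fp16_partition = torch.randn(1024, device=gpu).half()  # illustrative fp16 shard

# Before: clone and upcast on the GPU, then transfer, so the fp32 copy briefly lives in GPU memory.
master_before = fp16_partition.clone().float().detach().to(device)

# After: move to the target device first, then clone and upcast there.
master_after = fp16_partition.to(device).clone().float().detach()

assert torch.allclose(master_before, master_after)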
1 change: 1 addition & 0 deletions deepspeed/ops/adam/cpu_adam.py
@@ -4,6 +4,7 @@

ds_opt_adam = None


class DeepSpeedCPUAdam(torch.optim.Optimizer):

optimizer_id = 0
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage1.py
@@ -798,7 +798,7 @@ def _get_state_without_padding(self, state_with_padding, padding):
lean_state[key] = value[:lean_length]
else:
lean_state[key] = value

return lean_state

# Return base optimizer states.
24 changes: 13 additions & 11 deletions deepspeed/runtime/zero/stage2.py
@@ -262,8 +262,8 @@ def __init__(self,

# a partition of the fp32 master weights that will be updated by this process
self.single_partition_of_fp32_groups.append(
-                self.parallel_partitioned_fp16_groups[i]
-                [partition_id].clone().float().detach().to(self.device))
+                self.parallel_partitioned_fp16_groups[i][partition_id].to(
+                    self.device).clone().float().detach())

            # modify optimizer to have flat master weight
self.single_partition_of_fp32_groups[
@@ -331,7 +331,8 @@ def __init__(self,
self.local_overflow = False
self.grad_position = {}
self.temp_grad_buffer_for_cpu_offload = torch.zeros(
-                largest_param_numel).half().pin_memory()
+                largest_param_numel,
+                device=self.device).half().pin_memory()
self.temp_grad_buffer_for_gpu_offload = torch.zeros(
largest_param_numel,
device=torch.cuda.current_device()).half()
@@ -478,10 +479,10 @@ def independent_gradient_partition_epilogue(self):

if self.overlap_comm:
torch.cuda.synchronize()

if self.cpu_offload is False:
for i, _ in enumerate(self.fp16_groups):

if not i in self.averaged_gradients or self.averaged_gradients[i] is None:
self.averaged_gradients[i] = self.get_flat_partition(
self.params_in_partition[i],
@@ -500,8 +501,7 @@ def independent_gradient_partition_epilogue(self):

for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i],avg_new):
accumulated_grad.add_(new_avg_grad)



self._release_ipg_buffers()

# No need to keep the gradients anymore.
@@ -788,7 +788,8 @@ def async_accumulate_grad_in_cpu(self, param):
if param_id not in self.accumulated_grads_in_cpu:
self.accumulated_grads_in_cpu[param_id] = torch.zeros(
param.numel(),
-                dtype=param.dtype).pin_memory()
+                dtype=param.dtype,
+                device=self.device).pin_memory()

self.accumulated_grads_in_cpu[param_id].add_(dest_buffer)

@@ -804,7 +805,8 @@ def async_accumulate_grad_in_cpu_via_gpu(self, param):
if param_id not in self.accumulated_grads_in_cpu:
self.accumulated_grads_in_cpu[param_id] = torch.zeros(
param.numel(),
-                dtype=param.dtype).pin_memory()
+                dtype=param.dtype,
+                device=self.device).pin_memory()

if self.micro_step_id > 0:
dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1),
@@ -871,7 +873,7 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):

src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements).float()
dest_tensor.copy_(src_tensor, non_blocking=True)
-        param.grad=None
+        param.grad = None

def complete_grad_norm_calculation_for_cpu_offload(self, params):
total_norm = 0.0
@@ -904,7 +906,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):

def copy_grads_in_partition(self, param):
if self.cpu_offload:

if self.gradient_accumulation_steps > 1:
self.async_accumulate_grad_in_cpu_via_gpu(param)

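The buffer allocations in this file follow the same principle: scratch buffers that live on the CPU are created with an explicit device argument and pinned in place rather than being allocated elsewhere and copied over, and pinned memory also lets the later device-to-host copies use non_blocking=True. A rough sketch of that pattern, with illustrative names and sizes not taken from the repository:

import torch

numel = 1 << 20   # stands in for largest_param_numel / param.numel()
device = 'cpu'    # stands in for self.device under CPU offload

# Allocate the offload buffer directly on the CPU and pin it; creating the
# buffer involves no GPU-side allocation or transfer.
cpu_buffer = torch.zeros(numel, dtype=torch.float16, device=device).pin_memory()

if torch.cuda.is_available():
    grad = torch.randn(numel, dtype=torch.float16, device='cuda')
    # Pinned host memory allows the device-to-host copy to be issued asynchronously.
    cpu_buffer.copy_(grad.view(-1), non_blocking=True)
    torch.cuda.synchronize()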
7 changes: 4 additions & 3 deletions tests/unit/test_adam_acuracy.py
@@ -8,6 +8,7 @@

from deepspeed.ops.adam import DeepSpeedCPUAdam


def check_equal(first, second, atol=1e-2, verbose=False):
if verbose:
print(first)
@@ -36,11 +37,11 @@ def test_adam_opt(model_size):

for i in range(10):
rng_state = torch.get_rng_state()
-        param.grad=torch.randn(model_size, device=device)
+        param.grad = torch.randn(model_size, device=device)
torch.set_rng_state(rng_state)
-        param1.grad=torch.randn(model_size, device=device)
+        param1.grad = torch.randn(model_size, device=device)

optimizer.step()
optimizer1.step()

-    check_equal(param, param1, atol = 1e-2, verbose=True)
+    check_equal(param, param1, atol=1e-2, verbose=True)
