Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allocating CPU memory directly on CPU without transfering them from GPU #360

Merged
merged 2 commits into from
Sep 4, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepspeed/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
if __installed_ops__['sparse-attn']:
from . import sparse_attention
if __installed_ops__['cpu-adam']:
from . import adam
from . import adam
8 changes: 4 additions & 4 deletions deepspeed/runtime/zero/stage2.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def __init__(self,
# a partition of the fp32 master weights that will be updated by this process
self.single_partition_of_fp32_groups.append(
self.parallel_partitioned_fp16_groups[i]
[partition_id].clone().float().detach().to(self.device))
[partition_id].to(self.device).clone().float().detach())

# modify optimizer of have flat master weight
self.single_partition_of_fp32_groups[
Expand Down Expand Up @@ -330,7 +330,7 @@ def __init__(self,
self.local_overflow = False
self.grad_position = {}
self.temp_grad_buffer_for_cpu_offload = torch.zeros(
largest_param_numel).half().pin_memory()
largest_param_numel, device=self.device).half().pin_memory()
self.temp_grad_buffer_for_gpu_offload = torch.zeros(
largest_param_numel,
device=torch.cuda.current_device()).half()
Expand Down Expand Up @@ -787,7 +787,7 @@ def async_accumulate_grad_in_cpu(self, param):
if param_id not in self.accumulated_grads_in_cpu:
self.accumulated_grads_in_cpu[param_id] = torch.zeros(
param.numel(),
dtype=param.dtype).pin_memory()
dtype=param.dtype, device=self.device).pin_memory()

self.accumulated_grads_in_cpu[param_id].add_(dest_buffer)

Expand All @@ -803,7 +803,7 @@ def async_accumulate_grad_in_cpu_via_gpu(self, param):
if param_id not in self.accumulated_grads_in_cpu:
self.accumulated_grads_in_cpu[param_id] = torch.zeros(
param.numel(),
dtype=param.dtype).pin_memory()
dtype=param.dtype, device=self.device).pin_memory()

if self.micro_step_id > 0:
dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1),
Expand Down