Manually calling step with DeepSpeed Adam passing in fp16 params #361

Merged 1 commit on Sep 4, 2020
deepspeed/runtime/engine.py (9 additions, 1 deletion)
@@ -13,6 +13,7 @@

 from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
 from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
+from deepspeed.ops.adam import DeepSpeedCPUAdam
 from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
 from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
@@ -292,6 +293,9 @@ def zero_overlap_comm(self):
     def zero_cpu_offload(self):
         return self._config.zero_config.cpu_offload

+    def deepspeed_adam(self):
+        return self._config.zero_config.deepspeed_adam
+
     def zero_optimization_stage(self):
         return self._config.zero_optimization_stage

@@ -532,7 +536,11 @@ def _configure_basic_optimizer(self, model_parameters):
             )
         if self.optimizer_name() == ADAM_OPTIMIZER:
             if self.zero_cpu_offload():
-                optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
+                if False:  #self.deepspeed_adam():
+                    optimizer = DeepSpeedCPUAdam(model_parameters, **optimizer_parameters)
+                else:
+                    optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
+
             else:
                 from apex.optimizers.fused_adam import FusedAdam
                 optimizer = FusedAdam(model_parameters, **optimizer_parameters)
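Taken together, the engine.py changes gate optimizer selection under ZeRO CPU offload: with the (currently hard-disabled) deepspeed_adam flag, the engine would construct DeepSpeedCPUAdam instead of torch.optim.Adam. Below is a minimal standalone sketch of that selection logic, assuming a toy model and standard Adam hyperparameters; the use_deepspeed_adam variable and the model are illustrative stand-ins, not part of the PR.

```python
# Sketch of the optimizer selection added above. The toy model and the
# use_deepspeed_adam flag are stand-ins; in the PR the DeepSpeedCPUAdam branch
# is still disabled ("if False:  #self.deepspeed_adam():") and the real flag
# would be read from the ZeRO config via the new deepspeed_adam() accessor.
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

model = torch.nn.Linear(16, 16)
optimizer_parameters = {"lr": 1e-3, "betas": (0.9, 0.999), "eps": 1e-8, "weight_decay": 0.0}

use_deepspeed_adam = False  # mirrors the disabled self.deepspeed_adam() check
if use_deepspeed_adam:
    # Adam implemented in C++ and run on the host CPU, intended for offloaded optimizer state
    optimizer = DeepSpeedCPUAdam(model.parameters(), **optimizer_parameters)
else:
    optimizer = torch.optim.Adam(model.parameters(), **optimizer_parameters)
```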
deepspeed/runtime/zero/stage2.py (3 additions, 3 deletions)
@@ -1405,14 +1405,14 @@ def step(self, closure=None):
         self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
         #torch.set_num_threads(12)
         timers('optimizer_step').start()
-        self.optimizer.step()
+        self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups)
         #get rid of the fp32 gradients. Not needed anymore
         if not self.cpu_offload:
             for group in self.single_partition_of_fp32_groups:
                 group.grad = None

-        for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
-            fp16_partitions[partition_id].data.copy_(fp32_partition.data)
+        #for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
+        #    fp16_partitions[partition_id].data.copy_(fp32_partition.data)
         timers('optimizer_step').stop()

         if self.cpu_offload:
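This stage-2 change is the core of the PR: step() now passes the partitioned fp16 parameter groups into the wrapped optimizer via fp16_param_groups, so the optimizer can write the updated fp32 master partitions directly into the fp16 partitions, and the explicit copy loop that used to follow step() is commented out. A rough sketch of the copy that moves from the loop into the optimizer, using plain tensors as stand-ins for a ZeRO partition:

```python
# Illustrative only: the fp32-master -> fp16 copy that the removed loop did explicitly
# and that optimizer.step(fp16_param_groups=...) is now expected to perform internally.
# The tensors below are stand-ins for one ZeRO partition, not DeepSpeed objects.
import torch

fp32_partition = torch.randn(1024, dtype=torch.float32)  # master weights the optimizer updates
fp16_partition = fp32_partition.half()                    # working copy the model computes with

# stand-in for the Adam update applied to the fp32 master partition inside step()
fp32_partition -= 1e-3 * torch.randn_like(fp32_partition)

# the copy that used to run after step(); copy_() handles the fp32 -> fp16 cast
fp16_partition.data.copy_(fp32_partition.data)
```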