Commit

fix double linear override; spelling (#954)
stas00 authored Apr 14, 2021
1 parent e6999eb · commit adac058
Showing 1 changed file with 1 addition and 8 deletions.
deepspeed/runtime/zero/partition_parameters.py (9 changes: 1 addition & 8 deletions)
@@ -191,7 +191,7 @@ def _init_subclass(cls, **kwargs):
 
         if self.mem_efficient_linear:
             print_rank_0(
-                f"Your linear layers are being patched with more memory efficient version. This will persit unless manually reset.",
+                "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
                 force=True)
             self.linear_bk = torch.nn.functional.linear
             torch.nn.functional.linear = LinearFunctionForZeroStage3.apply
@@ -361,13 +361,6 @@ def get_model():
             self._convert_to_deepspeed_param(param)
             param.partition()
 
-        if mem_efficient_linear:
-            print_rank_0(
-                f"Your linear layers are being patched with more memory efficient version. This will persit unless manually turned reset.",
-                force=True)
-            self.linear_bk = torch.nn.functional.linear
-            torch.nn.functional.linear = LinearFunctionForZeroStage3.apply
-
     def _post_init_method(self, module):
         #see_memory_usage(f"Before converting parmas in {module.__class__.__name__}", force=False)
         print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)
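Context on the fix: the override pattern saves the current torch.nn.functional.linear into linear_bk before replacing it, so the stock function can be restored later. Running the same patch twice, as the removed block did, would overwrite that backup with the already-patched function, making the stock implementation unrecoverable. A minimal sketch of the pattern and the failure mode, assuming plain PyTorch (memory_efficient_linear here is a hypothetical stand-in for LinearFunctionForZeroStage3.apply, not DeepSpeed's actual code):

    import torch
    import torch.nn.functional as F

    def memory_efficient_linear(input, weight, bias=None):
        # Hypothetical stand-in for LinearFunctionForZeroStage3.apply;
        # it delegates to the backed-up stock implementation.
        return linear_bk(input, weight, bias)

    # Correct, single override: back up the stock function, then patch.
    linear_bk = F.linear
    torch.nn.functional.linear = memory_efficient_linear

    # A second override (the duplicate this commit removes) would do:
    #     linear_bk = torch.nn.functional.linear  # now memory_efficient_linear!
    # after which "restoring" from linear_bk re-installs the patch, and any
    # call to it recurses into itself instead of reaching the stock linear.

    # Restore; this works only because the backup was taken exactly once.
    torch.nn.functional.linear = linear_bk

    x, w = torch.randn(2, 4), torch.randn(3, 4)
    assert torch.nn.functional.linear(x, w).shape == (2, 3)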
