diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 0acc675985ca..0f2741251bb6 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -191,7 +191,7 @@ def _init_subclass(cls, **kwargs):
 
         if self.mem_efficient_linear:
             print_rank_0(
-                f"Your linear layers are being patched with more memory efficient version. This will persit unless manually reset.",
+                "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
                 force=True)
             self.linear_bk = torch.nn.functional.linear
             torch.nn.functional.linear = LinearFunctionForZeroStage3.apply
@@ -361,13 +361,6 @@ def get_model():
                 self._convert_to_deepspeed_param(param)
                 param.partition()
 
-        if mem_efficient_linear:
-            print_rank_0(
-                f"Your linear layers are being patched with more memory efficient version. This will persit unless manually turned reset.",
-                force=True)
-            self.linear_bk = torch.nn.functional.linear
-            torch.nn.functional.linear = LinearFunctionForZeroStage3.apply
-
     def _post_init_method(self, module):
         #see_memory_usage(f"Before converting parmas in {module.__class__.__name__}", force=False)
         print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)
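
The block this diff keeps (and de-duplicates) monkey-patches torch.nn.functional.linear: the original function is backed up in self.linear_bk and the module attribute is rebound to LinearFunctionForZeroStage3.apply, which is why the new log message warns that the override persists until manually reset. The snippet below is a minimal, self-contained sketch of that backup-and-override pattern, not the DeepSpeed implementation; _MemoryEfficientLinear is a hypothetical stand-in for LinearFunctionForZeroStage3.

import torch


class _MemoryEfficientLinear(torch.autograd.Function):
    # Hypothetical stand-in for LinearFunctionForZeroStage3; the real class
    # implements the memory-efficient forward/backward used by ZeRO stage 3.

    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight, bias)
        output = input.matmul(weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        grad_weight = grad_output.flatten(0, -2).t().matmul(input.flatten(0, -2))
        grad_bias = grad_output.flatten(0, -2).sum(0) if bias is not None else None
        return grad_input, grad_weight, grad_bias


# Back up the original and rebind the module attribute, mirroring the diff:
#   self.linear_bk = torch.nn.functional.linear
#   torch.nn.functional.linear = LinearFunctionForZeroStage3.apply
linear_bk = torch.nn.functional.linear
torch.nn.functional.linear = _MemoryEfficientLinear.apply

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
b = torch.randn(16, requires_grad=True)
y = torch.nn.functional.linear(x, w, b)  # now routed through the override
y.sum().backward()

# The override persists for every later caller of torch.nn.functional.linear;
# restoring the backup is the "manual reset" the log message refers to.
torch.nn.functional.linear = linear_bk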