-
Notifications
You must be signed in to change notification settings - Fork 4.2k
/
partition_parameters.py
executable file
·2257 lines (1804 loc) · 104 KB
/
partition_parameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import math
import os
import types
from typing import Callable, Iterable
from enum import Enum
import functools
import itertools
from typing import List
from collections import defaultdict
import logging
import torch
from torch import Tensor
from deepspeed import comm as dist
from torch.nn import Module
from torch.nn import Parameter
from .linear import zero3_linear_wrap
from deepspeed.utils import groups
import deepspeed
from ..utils import see_memory_usage, get_only_unique_item
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks, is_zero_param
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.config_utils import get_config_default
from deepspeed.utils import instrument_w_nvtx, logger
from deepspeed.comm.comm import init_distributed
from deepspeed.utils.debug import (debug_param2name_id_shape, debug_param2name_id_shape_device, debug_module2name,
debug_param2name_id, debug_param2name_id_shape_status)
from deepspeed.accelerator import get_accelerator
from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus
from deepspeed.inference.quantization.utils import _quantize_param, WEIGHT_QUANTIZATION_LAYERS, wrap_quantized_functional, wrap_load_from_state_dict
# NOTE(review): not referenced in the visible part of this file; presumably read or
# updated elsewhere in the module — confirm before changing or removing.
partitioned_param_data_shape = [0]
# Nesting depth of active InsertPostInitMethodToModuleSubClasses (e.g. Init) contexts.
# __enter__ patches torch only when this is 0 and then increments it; __exit__
# decrements it and unpatches once it drops back to 0.
zero_init_context = 0
# The outermost active context: set by __enter__ at depth 0 and reset to None by
# __exit__; shutdown_init_context()/restore_init_context() use it to remove and
# re-install the patches.
top_level_context = None
class NoGatherHandle:
    """Handle for one parameter that is materialized from its local shard.

    No collective is issued. The constructor copies ``param.ds_tensor`` to the
    current accelerator device with a non-blocking copy and views it as
    ``param.ds_shape``. If the shard carries ``ds_quant_scale``, it is first
    dequantized with ``Init.quantizer_module``. ``wait`` then makes the copy
    visible (when required) and marks the parameter AVAILABLE.
    """

    def __init__(self, param: Parameter) -> None:
        """
        Args:
            param: ZeRO parameter whose ``ds_status`` must be INFLIGHT.

        Raises:
            RuntimeError: if ``param`` is not INFLIGHT.
        """
        if param.ds_status != ZeroParamStatus.INFLIGHT:
            # The requirement is INFLIGHT, so the message must say so.
            raise RuntimeError(f"expected param {param.ds_summary()} to be inflight")
        if hasattr(param.ds_tensor, "ds_quant_scale"):
            param.data = Init.quantizer_module.dequantize(param.ds_tensor.data, param.ds_tensor.ds_quant_scale).to(
                device=get_accelerator().current_device_name(), non_blocking=True).view(param.ds_shape)
        else:
            param.data = param.ds_tensor.data.to(device=get_accelerator().current_device_name(),
                                                 non_blocking=True).view(param.ds_shape)
        self.__param = param

    def wait(self, **kwargs) -> None:
        """Make the non-blocking copy visible, then mark the parameter AVAILABLE.

        Keyword arguments are accepted and ignored.
        """
        # Accelerators that do not resolve data dependencies themselves need an
        # explicit sync before the copied data may be read.
        if not get_accelerator().resolves_data_dependency():
            get_accelerator().current_stream().synchronize()
        self.__param.ds_status = ZeroParamStatus.AVAILABLE
class NoGatherCoalescedHandle:
    """Coalesced counterpart of a no-gather handle for several parameters.

    Each parameter is validated and then materialized from its local shard, one
    after another, in the constructor. A shard carrying ``ds_quant_scale`` is
    dequantized first. The result goes to the current accelerator device via a
    non-blocking copy and is viewed as ``ds_shape``. ``wait`` promotes every
    parameter to AVAILABLE exactly once.
    """

    def __init__(self, params: List[Parameter]) -> None:
        self.__params = params
        self.__complete = False

        for p in self.__params:
            if p.ds_status != ZeroParamStatus.INFLIGHT:
                raise RuntimeError(f"expected param {p.ds_summary()} to not be available")
            shard = p.ds_tensor
            if hasattr(shard, "ds_quant_scale"):
                source = Init.quantizer_module.dequantize(shard.data, shard.ds_quant_scale)
            else:
                source = shard.data
            p.data = source.to(device=get_accelerator().current_device_name(), non_blocking=True).view(p.ds_shape)

    @instrument_w_nvtx
    def wait(self, **kwargs) -> None:
        """Idempotently finish the copies and mark all parameters AVAILABLE.

        Keyword arguments are accepted and ignored.
        """
        if not self.__complete:
            if not get_accelerator().resolves_data_dependency():
                get_accelerator().current_stream().synchronize()
            for p in self.__params:
                assert p.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {p.ds_summary()} to be inflight"
                p.ds_status = ZeroParamStatus.AVAILABLE
            self.__complete = True
def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):
    """Start an asynchronous all-gather of ``input_tensor`` into ``output_tensor``.

    ``dist.allgather_fn`` takes the output tensor first, so the arguments are
    swapped when forwarding. The call is wrapped with NVTX instrumentation, and
    whatever ``dist.allgather_fn`` returns for ``async_op=True`` is passed back
    (presumably a work handle).
    """
    traced_allgather = instrument_w_nvtx(dist.allgather_fn)
    return traced_allgather(output_tensor, input_tensor, group=group, async_op=True)
def print_rank_0(message, debug=False, force=False):
    """Print ``message`` on global rank 0 when ``debug`` or ``force`` is set.

    The flags are checked before the rank is queried. The common silent case
    (both flags False, as in the per-module calls made while ``__init__`` is
    patched) therefore neither pays for nor depends on ``dist.get_rank()``.

    Args:
        message: Object to print.
        debug: Print if True.
        force: Print if True.
    """
    if not (debug or force):
        return
    if dist.get_rank() == 0:
        print(message)
    # Other possible variations:
    # - print on all ranks without interleaving:
    #       printflock(f"[{rank}] {message}")
    # - print to a per-rank log file:
    #       log_rank_file(rank, message)
def debug_rank0(msg: str) -> None:
    """Log ``msg`` at DEBUG level through the module logger, on global rank 0 only."""
    on_rank_zero = dist.get_rank() == 0
    if on_rank_zero:
        logger.debug(msg)
def _init_external_params(module):
if not hasattr(module, '_external_params'):
module._external_params = {}
def external_parameters(self):
return self._external_params.items()
def all_parameters(self):
return itertools.chain(self.named_parameters(self, recurse=False), external_parameters(self))
module.ds_external_parameters = types.MethodType(external_parameters, module)
module.all_parameters = types.MethodType(all_parameters, module)
def register_external_parameter(module, parameter):
    """Tell DeepSpeed to gather and re-partition ``parameter`` around the forward
    and backward passes of ``module``.

    Use this when a parameter is accessed outside the ``forward()`` of the module
    that owns it. DeepSpeed must know when to collect it from its partitioned
    state and when it may release the memory again. Registration is keyed by
    ``id(parameter)``, so registering the same parameter twice is harmless.

    .. note::
        This is only applicable to training with ZeRO stage 3.

    Args:
        module (``torch.nn.Module``): The module whose forward pass needs ``parameter``.
        parameter (``torch.nn.Parameter``): The parameter to register.

    Raises:
        RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.


    Examples
    ========

    #. Register a weight that is used in another module's forward pass (line 6).
       Parameter ``layer1.weight`` is used by ``layer2`` (line 11).

        .. code-block:: python
            :linenos:
            :emphasize-lines: 6,11

            class ModuleZ3(torch.nn.Module):
                def __init__(self, *args):
                    super().__init__(*args)
                    self.layer1 = SomeLayer()
                    self.layer2 = OtherLayer()
                    deepspeed.zero.register_external_parameter(self, self.layer1.weight)

                def forward(self, input):
                    x = self.layer1(input)
                    # self.layer1.weight is required by self.layer2.forward
                    y = self.layer2(x, self.layer1.weight)
                    return y
    """
    if not isinstance(parameter, torch.nn.Parameter):
        raise RuntimeError('Parameter is not a torch.nn.Parameter')

    # Create the registry (and its helper methods) lazily on first registration.
    if not hasattr(module, '_external_params'):
        _init_external_params(module)

    registry = module._external_params
    registry[id(parameter)] = parameter
def unregister_external_parameter(module, parameter):
    """Undo a previous :meth:`register_external_parameter` call.

    Args:
        module (``torch.nn.Module``): The module the parameter was registered with.
        parameter (``torch.nn.Parameter``): The parameter to remove.

    Raises:
        RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.
        RuntimeError: If ``parameter`` is not currently registered as an external
            parameter of ``module``.
    """
    if not isinstance(parameter, torch.nn.Parameter):
        raise RuntimeError('Parameter is not a torch.nn.Parameter')

    param_key = id(parameter)
    is_registered = hasattr(module, '_external_params') and param_key in module._external_params
    if not is_registered:
        raise RuntimeError('Parameter is not a registered external parameter of module.')

    del module._external_params[param_key]
class ZeroParamType(Enum):
    """How a ZeRO parameter's data is distributed across processes."""

    NORMAL = 1  # a regular PyTorch parameter, held in full
    PARTITIONED = 2  # split across the data-parallel processes
    REMOTE = 3  # held by a single process rank and absent on all others
class ZeroParamStatus(Enum):
    """Availability state of a ZeRO parameter on the local process."""

    AVAILABLE = 1  # fully present and ready for use on all processes
    NOT_AVAILABLE = 2  # partitioned or remote on some or all processes
    INFLIGHT = 3  # currently being gathered
# Unpatched torch tensor constructors, captured once at import time.
# While a zero-init context is active, _add_tensor_creation_wrappers() replaces
# the corresponding torch.* attributes with wrappers around these originals, and
# _remove_tensor_creation_wrappers() puts these exact objects back. Capturing
# them here, before any patching, is what makes the restore possible.
_orig_torch_tensor = torch.tensor
_orig_torch_empty = torch.empty
_orig_torch_zeros = torch.zeros
_orig_torch_ones = torch.ones
_orig_torch_full = torch.full
_orig_torch_arange = torch.arange
_orig_torch_eye = torch.eye
_orig_torch_randn = torch.randn
def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable:
    """Wrap a torch tensor constructor for use inside a zero-init context.

    The returned callable behaves like ``fn`` except that:

    * when the caller passes no ``device`` (or ``device=None``), it defaults to
      the accelerator device selected by the ``LOCAL_RANK`` environment variable;
    * floating-point results are cast to ``target_fp_dtype``; other results are
      returned unchanged.

    ``functools.wraps`` keeps the original ``__name__``/``__doc__`` (and exposes
    ``__wrapped__``), so the temporarily patched ``torch.*`` functions remain
    recognisable in tracebacks and introspection.

    Args:
        fn: The original constructor, e.g. ``torch.empty``.
        target_fp_dtype: dtype forced onto floating-point results.

    Returns:
        The wrapping callable.
    """

    @functools.wraps(fn)
    def wrapped_fn(*args, **kwargs) -> Tensor:
        if kwargs.get("device", None) is None:
            kwargs['device'] = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
        tensor: Tensor = fn(*args, **kwargs)
        if tensor.is_floating_point():
            tensor.data = tensor.data.to(target_fp_dtype)

        return tensor

    return wrapped_fn
def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable:
    """Build a replacement for ``torch.Tensor.__new__`` that targets ``dtype``.

    The produced function allocates via ``new_empty`` on a zero-sized tensor
    placed on the accelerator device chosen by ``LOCAL_RANK``. With no size
    arguments it creates an empty (size-0) tensor. Floating-point results are
    converted to ``dtype``.
    """

    def new_tensor(cls, *args, **kwargs) -> Tensor:
        target_device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
        size_args = args if args else (0, )
        result = _orig_torch_empty(0, device=target_device).new_empty(*size_args, **kwargs)
        return result.to(dtype) if result.is_floating_point() else result

    return new_tensor
# https://stackoverflow.com/a/63851681/9201239
def get_all_subclasses(cls, include_root=True):
    """Return the set of all direct and indirect subclasses of ``cls``.

    ``cls`` itself is included unless ``include_root`` is False. The hierarchy is
    walked iteratively with a visited set, so each class is expanded only once
    even when diamond inheritance makes it reachable along several paths.
    """
    found = set()
    pending = [cls]
    while pending:
        current = pending.pop()
        for child in current.__subclasses__():
            if child not in found:
                found.add(child)
                pending.append(child)
    if include_root:
        found.add(cls)
    return found
@instrument_w_nvtx
def free_param(param: Parameter) -> None:
    """Release the storage backing ``param.data`` and mark it NOT_AVAILABLE.

    No submodule may still be using the parameter (``ds_active_sub_modules``
    must be empty). If the data lives on an accelerator without implicit stream
    synchronization, the tensor is first recorded on the current stream, so it
    is not freed while still used by pending computation.
    """
    assert not param.ds_active_sub_modules, param.ds_summary()

    if get_accelerator().on_accelerator(param.data) and not get_accelerator().is_synchronized_device():
        param.data.record_stream(get_accelerator().current_stream())

    # In the partitioned state param.data holds nothing meaningful; keep only an
    # empty placeholder with the same dtype and device.
    param.data = torch.empty(0, dtype=param.dtype, device=param.device)
    param.ds_status = ZeroParamStatus.NOT_AVAILABLE
# NOTE(review): none of these three module-level globals is referenced in the
# visible part of this file; presumably buffer-reuse state used further down or
# by other modules — confirm before relying on or removing them.
reuse_buffers = False
temp_contiguous_tensor = None
empty_buffers = {}
# Inserts _post_init_method at the end of init method
# for all sub classes of torch.nn.Module
class InsertPostInitMethodToModuleSubClasses(object):
    """Context manager that hooks ``__init__`` of every ``torch.nn.Module`` subclass.

    While active (and ``enabled``):

    * each module's ``__init__`` is wrapped so that ``_post_init_method`` runs
      once, after the most-derived ``__init__`` finishes. It is a no-op here and
      is implemented by subclasses such as ``Init``;
    * tensor constructors (``torch.tensor``/``empty``/``zeros``/... and
      ``Tensor.__new__``) are patched so floating-point results use ``self.dtype``;
    * ``torch.nn.functional.linear`` is optionally replaced.

    Contexts may nest; only the outermost one patches and unpatches.
    """
    num_module_parameters = 0
    num_module_elements = 0

    def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtype=None):
        self.mem_efficient_linear = mem_efficient_linear
        self.enabled = enabled
        # Whether patch_init_and_builtins() is currently in effect. Set up front
        # so unpatch_init_and_builtins() is safe to call before any patching.
        self.patched = False
        self._set_dtype(ds_config, dtype)
        assert self.dtype in [
            torch.half, torch.bfloat16, torch.float
        ], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]"
        self.wrapped_cls = set()
        # > 0 while inside a module __init__ called with device='meta' (skip_init).
        # post-init is then deferred until the module's to_empty() call.
        self.skip_init_depth = 0

        self.quantized_initialization = None
        if ds_config is not None and ds_config.weight_quantization_config and ds_config.weight_quantization_config.quantized_initialization:
            self.quantized_initialization = ds_config.weight_quantization_config.quantized_initialization

    def __enter__(self):
        global zero_init_context, top_level_context
        if not self.enabled:
            return

        # Only the outermost context installs the patches.
        if zero_init_context == 0:
            self.patch_init_and_builtins()
            top_level_context = self

        zero_init_context += 1

    def __exit__(self, exc_type, exc_value, traceback):
        global zero_init_context, top_level_context
        if not self.enabled:
            return

        zero_init_context -= 1

        # Exiting the top level context
        if zero_init_context == 0:
            self.unpatch_init_and_builtins()
            top_level_context = None

            if dist.get_rank() == 0:
                billion_elems = InsertPostInitMethodToModuleSubClasses.num_module_elements / 1e9
                num_params = InsertPostInitMethodToModuleSubClasses.num_module_parameters
                logger.info(
                    f"finished initializing model - num_params = {num_params}, num_elems = {billion_elems:.2f}B")

        # Now that we cleaned up the metaclass injection, raise the exception.
        if exc_type is not None:
            return False

    # To be implemented by inheriting classes
    def _post_init_method(self, module):
        pass

    def _set_dtype(self, ds_config, dtype):
        """Resolve ``self.dtype``.

        Precedence:

        1. An explicit ``dtype`` argument.
        2. If ``ds_config`` is given, its bf16/fp16 flags (fp32 when neither is set).
        3. Otherwise fp16 if the accelerator supports it, then bf16 if supported,
           else fp32.

        Raises:
            RuntimeError: if ``ds_config`` enables both bf16 and fp16.
        """
        if ds_config is not None and dtype is None:
            if ds_config.bfloat16_enabled and ds_config.fp16_enabled:
                raise RuntimeError("bfloat16 and fp16 cannot be enabled at once")

            if ds_config.bfloat16_enabled:
                self.dtype = torch.bfloat16
            elif ds_config.fp16_enabled:
                self.dtype = torch.half
            else:
                self.dtype = torch.float
        elif dtype is not None:
            # Previously an explicit dtype was silently replaced by bf16 on
            # accelerators without fp16 support (ternary precedence bug).
            self.dtype = dtype
        elif get_accelerator().is_fp16_supported():
            self.dtype = torch.float16
        elif get_accelerator().is_bf16_supported():
            # is_bf16_supported must be *called*; the bound method itself is always truthy.
            self.dtype = torch.bfloat16
        else:
            self.dtype = torch.float32

    def patch_init_and_builtins(self):

        def apply_with_gather(orig_module_apply_fn: Callable) -> Callable:
            """many models make use of child modules like Linear or Embedding which
            perform their own weight initialization in their __init__ methods,
            but will then have more weight initialization in a parent module's __init__
            method that modifies weights of child modules, which is typically done
            using the Module.apply method.

            since the Init context manager partitions child modules immediately after
            they are initialized, without modifying apply we would entirely skip
            any initialization done by parent modules.

            to get around this issue, we wrap the function passed to Module.apply
            so that the applied function is applied to child modules correctly.
            """

            def get_wrapped_fn_to_apply(fn_to_apply: Callable) -> Callable:
                # Avoid double-wrapping: Module.apply recurses into children, which
                # call the patched apply again with the already wrapped function.
                if hasattr(fn_to_apply, "wrapped"):
                    return fn_to_apply

                @functools.wraps(fn_to_apply)
                def wrapped_fn_to_apply(module_to_apply_fn_to: Module) -> None:
                    """gathers parameters before calling apply function. afterwards
                    parameters are broadcasted to ensure consistency across all ranks
                    then re-partitioned.

                    takes the following steps:
                    1. allgathers parameters for the current module being worked on
                    2. calls the original function
                    3. broadcasts root rank's parameters to the other ranks
                    4. re-partitions the parameters
                    """

                    # TODO Delay error checking for dangling partitioned parameters to post module init
                    # raise RuntimeError(f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, "
                    #                    f"were zero params, is it possible that the parameters were "
                    #                    f"overwritten after they were initialized? "
                    #                    f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} ")

                    # Sorted by ds_id so every rank issues the collectives in the same order.
                    params_to_apply_fn_to: Iterable[Parameter] = list(
                        sorted([p for p in module_to_apply_fn_to.parameters(recurse=False) if is_zero_param(p)],
                               key=lambda p: p.ds_id))

                    for param in params_to_apply_fn_to:
                        param.all_gather()

                    fn_to_apply(module_to_apply_fn_to)

                    for param in params_to_apply_fn_to:
                        dist.broadcast(param.data, 0, group=param.ds_process_group)

                    for param in params_to_apply_fn_to:
                        param.partition(has_been_updated=True)

                wrapped_fn_to_apply.wrapped = True

                return wrapped_fn_to_apply

            @functools.wraps(orig_module_apply_fn)
            def wrapped_apply(module: Module, fn_to_apply: Callable) -> Module:
                # Module.apply returns the module; propagate it so chained usage
                # such as ``model = Model().apply(init_fn)`` keeps working.
                return orig_module_apply_fn(module, get_wrapped_fn_to_apply(fn_to_apply))

            return wrapped_apply

        def hook_for_skip_init(module):
            # this function is intended for handling the logic of torch.nn.utils.skip_init
            # skip_init:module_cls(*args, **kwargs).to_empty(device=final_device), where kwargs['device']='meta'
            # the function call occurs between module_cls(*args, **kwargs) and to_empty(device=final_device).
            def partition_after_empty_init(f):

                @functools.wraps(f)
                def wrapper(module, *args, **kwargs):
                    _module = f(module, *args, **kwargs)
                    # here is the post-hook for module.apply(empty_like...)
                    # after module.apply(empty_like...), the module has completed its empty init on real device
                    # since skip_init won't involve any computations or weight adjustments, we can directly utilize post_init
                    self._post_init_method(_module)
                    return _module

                return wrapper

            def post_wrapper_to_empty(f):
                # append some wrapper restoration after to_empty() call
                @functools.wraps(f)
                def wrapper(*args, **kwargs):
                    res = f(*args, **kwargs)
                    # restore _apply hook
                    for subclass in get_all_subclasses(torch.nn.modules.module.Module):
                        _disable_class_apply(subclass)
                    # self restore
                    module.to_empty = f
                    return res

                return wrapper

            def _enable_class_apply(cls):
                if '_apply' in cls.__dict__:
                    cls._old_apply_of_skip_init_hook = cls._apply
                    cls._apply = partition_after_empty_init(cls._apply)

            def _disable_class_apply(cls):
                if hasattr(cls, '_old_apply_of_skip_init_hook'):
                    cls._apply = cls._old_apply_of_skip_init_hook

            # add hooks for to_empty: apply_(empty_like)
            for subclass in get_all_subclasses(torch.nn.modules.module.Module):
                _enable_class_apply(subclass)

            # add a restore hook when exiting skip_init
            module.to_empty = post_wrapper_to_empty(module.to_empty)

        def partition_after(f):

            @functools.wraps(f)
            def wrapper(module, *args, **kwargs):

                # important logic: We want to run post_init only after child's __init__ is
                # completed, and do nothing after __init__ of any of its parents and grandparents in
                # the inheritance ancestry. This way the partitioning will need to happen only once
                # when the whole object is ready to be partitioned and not before. This is because
                # often the child module will need to tweak the weights - for example running a
                # custom weights init function. So if a parent created the weights param, the child
                # won't need to gather it in order to tweak it

                print_rank_0(f'Before initializing {module.__class__.__name__}', force=False)

                is_child_module = False
                if not hasattr(module, "_ds_child_entered"):
                    # child's __init__ was called, since parents all see the same object they can now skip post_init
                    is_child_module = True
                    setattr(module, "_ds_child_entered", True)

                init_on_meta = 'device' in kwargs and kwargs['device'] == 'meta'
                if init_on_meta:
                    self.skip_init_depth += 1

                try:
                    f(module, *args, **kwargs)
                    if init_on_meta and self.skip_init_depth == 1:
                        # check and handle the logic of empty_init
                        hook_for_skip_init(module)
                    if is_child_module:
                        # child's __init__ is done, now we can run a single post_init on the child object
                        delattr(module, "_ds_child_entered")

                        print_rank_0(f'Running post_init for {module.__class__.__name__}', force=False)
                        if self.skip_init_depth == 0:
                            self._post_init_method(module)

                    print_rank_0(f'After initializing followed by post init for {module.__class__.__name__}',
                                 force=False)
                finally:
                    # Always undo the depth bump, even if __init__ raised. A stale
                    # non-zero depth would make every later module built in this
                    # context skip _post_init_method.
                    if init_on_meta:
                        self.skip_init_depth -= 1

            return wrapper

        def _enable_class(cls):
            if '__init__' in cls.__dict__:
                cls._old_init = cls.__init__
                cls.__init__ = partition_after(cls.__init__)

        def _init_subclass(cls, **kwargs):
            if '__init__' in cls.__dict__:
                cls._old_init = cls.__init__
                cls.__init__ = partition_after(cls.__init__)

        # Replace .__init__() for all existing subclasses of torch.nn.Module recursively
        for subclass in get_all_subclasses(torch.nn.modules.module.Module):
            _enable_class(subclass)

        # holding onto some methods so we can put them back the way they were in __exit__
        torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__
        torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply
        torch.Tensor.__old_new__ = torch.Tensor.__new__

        # Replace .__init__() for future subclasses of torch.nn.Module
        torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
        if Init.override_module_apply:
            torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply)

        self._add_tensor_creation_wrappers()

        if self.mem_efficient_linear:
            print_rank_0(
                "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
                force=False)
            self.linear_bk = torch.nn.functional.linear
            torch.nn.functional.linear = zero3_linear_wrap

            if self.quantized_initialization:
                print_rank_0("nn.functional.linear has been overridden with quantized linear version.", force=False)
                torch.nn.functional.linear = wrap_quantized_functional(torch.nn.functional.linear)
                torch.nn.functional.embedding = wrap_quantized_functional(torch.nn.functional.embedding)
                for cls in WEIGHT_QUANTIZATION_LAYERS:
                    cls._load_from_state_dict = wrap_load_from_state_dict(cls._load_from_state_dict)

                logger.info("Enable Zero3 engine with INT4 quantization.")

        self.patched = True

    def unpatch_init_and_builtins(self):
        """Undo patch_init_and_builtins(); a no-op when nothing is patched.

        Note that the ``torch.nn.functional.linear`` override is not reverted here.
        """
        if self.patched:

            def _disable_class(cls):
                if hasattr(cls, '_old_init'):
                    cls.__init__ = cls._old_init

            for subclass in get_all_subclasses(torch.nn.modules.module.Module):
                _disable_class(subclass)

            # putting methods back the way we found them
            torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
            if Init.override_module_apply:
                torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply

            self._remove_tensor_creation_wrappers()

            self.patched = False

    def _add_tensor_creation_wrappers(self):
        torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype)
        torch.tensor = zero_wrapper_for_fp_tensor_constructor(_orig_torch_tensor, self.dtype)
        torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, self.dtype)
        torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, self.dtype)
        torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype)
        torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype)
        torch.arange = zero_wrapper_for_fp_tensor_constructor(_orig_torch_arange, self.dtype)
        torch.eye = zero_wrapper_for_fp_tensor_constructor(_orig_torch_eye, self.dtype)
        torch.randn = zero_wrapper_for_fp_tensor_constructor(_orig_torch_randn, self.dtype)

    def _remove_tensor_creation_wrappers(self):
        torch.Tensor.__new__ = torch.Tensor.__old_new__
        torch.tensor = _orig_torch_tensor
        torch.empty = _orig_torch_empty
        torch.zeros = _orig_torch_zeros
        torch.ones = _orig_torch_ones
        torch.full = _orig_torch_full
        torch.arange = _orig_torch_arange
        torch.eye = _orig_torch_eye
        torch.randn = _orig_torch_randn
def shutdown_init_context():
    """
    Temporarily remove the Init patches while keeping the context itself open.

    Used when a DeepSpeed engine is initialized inside an ``Init`` context. The
    outermost context's wrappers are removed, but the context nesting state is
    left untouched. Does nothing when no context is active.
    """
    active = top_level_context
    if not active:
        return
    active.unpatch_init_and_builtins()
def restore_init_context():
    """
    Re-install the Init patches on the outermost active context, e.g. once the
    DeepSpeed engine has been initialized. Does nothing when no context is active.
    """
    active = top_level_context
    if active:
        active.patch_init_and_builtins()
class AllGatherHandle:
    """Handle for an in-flight all-gather of a single parameter.

    When ``quantization`` is given, ``wait`` also waits on the quantized
    transfer. It then dequantizes the result into ``param.data`` on the
    parameter's device.
    """

    def __init__(self, handle, param: Parameter, quantization=None) -> None:
        """
        Raises:
            RuntimeError: if ``param`` is not INFLIGHT.
        """
        if param.ds_status != ZeroParamStatus.INFLIGHT:
            # The requirement is INFLIGHT, so the message must say so.
            raise RuntimeError(f"expected param {param.ds_summary()} to be inflight")
        self.__handle = handle
        self.__param = param
        self.__quantization = quantization

    def wait(self, handle_dependency=True) -> None:
        """Wait for the gather (and dequantize if needed); mark the param AVAILABLE.

        ``handle_dependency`` is accepted for interface parity with the coalesced
        handle and is unused here.
        """
        instrument_w_nvtx(self.__handle.wait)()
        if self.__quantization:
            instrument_w_nvtx(self.__quantization.quant_handle.wait)()
            self.__param.data = self.__quantization.backend.dequantize(
                self.__quantization.quantized_param, self.__quantization.scale_buffer).to(self.__param.device)
        self.__param.ds_status = ZeroParamStatus.AVAILABLE
class AllGatherCoalescedHandle:
    """Handle for one coalesced all-gather covering several parameters.

    ``partitions`` holds one tensor per rank. Each parameter's shard occupies
    the slice ``[param_offset, param_offset + ds_tensor_numel)`` of every rank's
    tensor, laid out in the order of ``params``. ``wait`` stitches those slices
    back into full parameters.
    """

    # Source tensors kept alive for callers that pass handle_dependency=False on
    # devices without implicit stream synchronization; cleared by free_buffer().
    data_buffer = []

    def __init__(
        self,
        allgather_handle,
        params: List[Parameter],
        partitions: List[Tensor],
        world_size: int,
        use_secondary_tensor=False,
        quantization=None,
    ) -> None:
        self.allgather_handle = allgather_handle
        self.params = params
        self.partitions = partitions
        self.world_size = world_size
        self.use_secondary_tensor = use_secondary_tensor
        self.complete = False
        self.quantization = quantization

        for param in self.params:
            if param.ds_status != ZeroParamStatus.INFLIGHT:
                raise RuntimeError(f"expected param {param.ds_summary()} to not be available")

    @instrument_w_nvtx
    def wait(self, handle_dependency=True) -> None:
        """Finish the gather and rebuild every parameter; idempotent.

        Args:
            handle_dependency: if True (and the device is not synchronized), each
                slice is recorded on the current stream. If False, the caller
                takes care of the dependency and the source tensors are parked in
                ``data_buffer`` until ``free_buffer()`` is called.
        """
        if self.complete:
            return

        instrument_w_nvtx(self.allgather_handle.wait)()

        if self.quantization:
            instrument_w_nvtx(self.quantization.quant_handle.wait)()
            flat_tensor = self.quantization.backend.dequantize(
                self.quantization.quantized_param, self.quantization.scale_buffer).to(self.params[0].device)

            # Re-slice the dequantized flat buffer into one view per rank.
            self.partitions: List[Tensor] = []
            for i in range(self.world_size):
                self.partitions.append(
                    flat_tensor.narrow(0, self.quantization.partition_sz * i, self.quantization.partition_sz))

        # split the single tensor out into individual tensors
        param_offset = 0
        for param in self.params:
            assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
            partitions: List[Tensor] = []
            ds_tensor_numel = param.ds_tensor.ds_numel
            if self.use_secondary_tensor:
                ds_tensor_numel *= param.ds_secondary_tensor_num_of_groups
            for rank in range(self.world_size):
                param_start = rank * ds_tensor_numel
                # Ranks whose slice starts past the end of the param hold only padding.
                if param_start < param.ds_numel:
                    part_to_copy = self.partitions[rank].narrow(0, param_offset,
                                                                min(param.ds_numel - param_start, ds_tensor_numel))
                    partitions.append(part_to_copy)
            param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
            param.ds_status = ZeroParamStatus.AVAILABLE

            if not get_accelerator().is_synchronized_device() and handle_dependency:
                for part_to_copy in partitions:
                    part_to_copy.record_stream(get_accelerator().current_stream())

            param_offset += ds_tensor_numel

        self.complete = True

        if not get_accelerator().is_synchronized_device() and not handle_dependency:
            # The caller handles stream dependencies explicitly. Keep every per-rank
            # source tensor alive, not just the slices of the last parameter (the
            # previous behaviour, which also raised NameError for empty ``params``),
            # until free_buffer() is called after synchronization.
            AllGatherCoalescedHandle.data_buffer.append(self.partitions)

    @staticmethod
    def free_buffer():
        """Drop all tensors parked in ``data_buffer``."""
        AllGatherCoalescedHandle.data_buffer = []
class MultipleAllGatherHandles:
    """Group several all-gather handles behind a single ``wait``."""

    def __init__(self, handles: List[AllGatherCoalescedHandle]):
        self.handles = handles

    def wait(self, handle_dependency=True) -> None:
        """Wait on every grouped handle in order, forwarding ``handle_dependency`` positionally."""
        for pending_handle in self.handles:
            pending_handle.wait(handle_dependency)
class AllReduceCoalescedHandle:
    """Handle for an in-flight coalesced all-reduce over a list of parameters.

    Every parameter must be INFLIGHT on construction. ``wait`` waits on the
    communication handle and then marks all parameters AVAILABLE; repeat calls
    are no-ops.
    """

    def __init__(self, handle, params: List[Parameter]) -> None:
        self.handle = handle
        self.params = params
        self.complete = False

        for p in self.params:
            if p.ds_status != ZeroParamStatus.INFLIGHT:
                raise RuntimeError(f"expected param {p.ds_summary()} to not be available")

    @instrument_w_nvtx
    def wait(self) -> None:
        if not self.complete:
            instrument_w_nvtx(self.handle.wait)()

            for p in self.params:
                assert p.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {p.ds_summary()} to be inflight"
                p.ds_status = ZeroParamStatus.AVAILABLE

            self.complete = True
class QuantizationInfo:
    """Placeholder bag for the quantization state used by the gather handles.

    All four fields start out as ``None`` and are presumably filled in by the
    code that launches a quantized all-gather. ``AllGatherCoalescedHandle.wait``
    additionally reads ``partition_sz``, which is not initialized here and must
    be attached by the producer.
    """

    def __init__(self) -> None:
        self.quantized_param = self.backend = self.quant_handle = self.scale_buffer = None
class CUDAQuantizer:
    """Symmetric 8-bit quantizer backed by DeepSpeed's ``QuantizerBuilder`` kernels.

    Unless ``groups`` is passed explicitly, ``quantize`` picks the number of
    quantization groups adaptively and caches it per element count (the cache
    is shared by all instances).
    """
    async_flag = True
    target_group_size = 8000  # the optimal size is 4k, so we set the target to be below 8k
    group_size_cache = dict()
    quantizer_cuda_module = None

    def __init__(self) -> None:
        # Build/load the kernel module once per process, shared by all instances.
        if CUDAQuantizer.quantizer_cuda_module is None:
            CUDAQuantizer.quantizer_cuda_module = deepspeed.ops.op_builder.QuantizerBuilder().load()

    def quantize(self, param, groups=None):
        """Quantize ``param`` on the accelerator device.

        Args:
            param: Tensor to quantize.
            groups: Number of quantization groups. If None, choose a count such
                that ``numel`` is divisible by ``8 * groups`` and the group size
                stays below 16k.

        Returns:
            Whatever the kernel module's ``quantize`` returns.

        Raises:
            AssertionError: if no valid group count exists for ``param``.
        """
        if groups is None:
            numel = param.numel()
            try:
                groups = self.group_size_cache[numel]
            except KeyError:
                # Every valid group count needs numel % (8 * groups) == 0, which is
                # impossible unless numel is a positive multiple of 8. Fail fast
                # instead of scanning up to `numel` candidates first.
                assert numel > 0 and numel % 8 == 0, (
                    f"Quantized weight requires the number of weights be a positive multiple of 8. Yet got {numel}")
                groups = math.ceil(numel / self.target_group_size)
                while groups < numel:
                    if numel % (8 * groups) == 0:
                        break
                    groups += 1
                # Halve the group size (double the count) while it is above the 8k
                # target and the split stays a multiple of 8; the 16k hard limit is
                # enforced by the assertion below.
                while True:
                    if numel % (8 * groups * 2) == 0 and numel / groups > self.target_group_size:
                        groups *= 2
                    else:
                        break
                assert (
                    numel % (8 * groups) == 0
                ), f"Quantized weight requires the number of weights be a multiple of 8. Yet {numel} cannot be divided by 8*{groups}"
                assert (numel / groups < 16000), f"{numel} / {groups} is larger than 16k"
                assert numel > groups, f"Adaptive grouping algorithm cannot find a group size for input tensor of size {numel}"
                self.group_size_cache[numel] = groups
        return self.quantizer_cuda_module.quantize(param.to(get_accelerator().device_name()), groups, 8,
                                                   self.quantizer_cuda_module.Symmetric)

    def dequantize(self, quantized_param, scale):
        """Inverse of ``quantize``; the group count is taken from ``scale.numel()``."""
        return self.quantizer_cuda_module.dequantize(quantized_param, scale, scale.numel(), 8,
                                                     self.quantizer_cuda_module.Symmetric)
def _no_gather_coalesced(params: Iterable[Parameter]) -> "NoGatherHandle | NoGatherCoalescedHandle":
    """Mark ``params`` INFLIGHT and materialize them from their local shards.

    No collective communication is issued.

    Args:
        params: ZeRO parameters, all of which must currently be NOT_AVAILABLE.
            Any iterable is accepted; it is consumed exactly once.

    Returns:
        A ``NoGatherHandle`` when there is exactly one parameter, otherwise a
        ``NoGatherCoalescedHandle`` over the parameters sorted by ``ds_id``.

    Raises:
        RuntimeError: if a parameter is not NOT_AVAILABLE. Parameters visited
            before it have already been switched to INFLIGHT.
    """
    # Materialize first: the input is walked twice (status update, then sort).
    # A generator would yield nothing the second time, leaving params INFLIGHT
    # behind an empty handle.
    params = list(params)
    for param in params:
        if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
            raise RuntimeError(f"expect param.ds_status == ZeroParamStatus.NOT_AVAILABLE, got {param.ds_summary()}")
        param.ds_status = ZeroParamStatus.INFLIGHT

    # Sort by ds_id so the materialization order is deterministic.
    params = sorted(params, key=lambda p: p.ds_id)
    if len(params) == 1:
        param, = params
        return NoGatherHandle(param)
    return NoGatherCoalescedHandle(params)
# Replaces all parameters in module with Scattered Parameters
class Init(InsertPostInitMethodToModuleSubClasses):
param_id = 0
param_persistence_threshold = get_config_default(DeepSpeedZeroConfig, "param_persistence_threshold")
model_persistence_threshold = get_config_default(DeepSpeedZeroConfig, "model_persistence_threshold")
num_persisted_parameters = 0
num_persisted_elements = 0
apply_param_persistence = False
override_module_apply = get_config_default(DeepSpeedZeroConfig, "override_module_apply")
def __init__(self,
module=None,
data_parallel_group=None,
mem_efficient_linear=True,
remote_device=None,
pin_memory=False,
config_dict_or_path=None,
config=None,
enabled=True,
dtype=None,
mpu=None,
zero_param_parallel_group=None,
zero_quantized_weights=False,
zero_quantized_nontrainable_weights=False,
sequence_data_parallel_group=None,
param_swapper=None):
"""A context to enable massive model construction for training with
ZeRO-3. Models are automatically partitioned (or, sharded) across the
system and converted to half precision.
Args:
module (``torch.nn.Module``, optional): If provided, partition the model as
if it was constructed in the context.
data_parallel_group (``deepspeed.comm`` process group, optional):
The group of processes to partition among. Defaults to all processes.
Synonymous with sequence data parallel group for param partitioning
across both sequence and data parallel groups.
mem_efficient_linear (bool, optional): Replace
torch.nn.functional.linear with an implementation that allows
DeepSpeed to partition parameters. Defaults to ``True``.
remote_device (string, optional): The initial device to store model
weights e.g., ``cpu``, ``nvme``. Passing ``"cpu"`` will create the model in CPU
memory. The model may still be moved to GPU based on the
offload settings for training. Defaults to param offload device if a config is
defined, otherwise GPU.
pin_memory (bool, optional): Potentially increase performance by
using pinned memory for model weights. ``remote_device`` must be
``"cpu"``. Defaults to pin_memory value in config, otherwise ``False``.
config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration
for swapping fp16 params to NVMe.
config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
enabled (bool, optional): If ``False``, this context has no
effect. Defaults to ``True``.
dtype (``dtype``, optional): Can be used to change the data type of the parameters.
Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``
mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}.
zero_param_parallel_group(``object``, optional): Parallel (comm) group for dual partitioning of ZeRO params.
zero_quantized_weights (bool, optional): If ``True``, turn on quantized weights in all gather weights. Default is ``False``
zero_quantized_nontrainable_weights (bool, optional): If ``True``, nontrainable weights will be stored in quantized format. Default is ``False``
param_swapper (``deepspeed.runtime.swap_tensor.partitioned_param_swapper.AsyncPartitionedParameterSwapper``, optional): [Experimental] Use existing parameter swapper. Defaults to ``None``.
This argument will be removed in the near future.
This context accelerates model initialization and enables models that
are too large to allocate in their entirety in CPU memory. It has the
following effects:
#. allocates tensors to either GPU or CPU memory or NVMe
#. converts floating point tensors to half precision
#. immediately partitions tensors among the group of data-parallel devices
#. (*optional*) replaces ``torch.nn.functional.linear`` with a more
memory-efficient implementation
These modifications allow for models that exceed the size of local CPU/GPU
memory/NVMe, but fit within the total NVMe capacity (*i.e.*, aggregate CPU
or GPU memory or NVMe) across all nodes. Consider initializing a model with one
trillion parameters, whose weights occupy two terabytes (TB) in half
precision. The initial CPU allocation in full precision requires 4TB of
memory *per process*, and so a system with 8 GPUs per node would need 32TB of
CPU memory due to data-parallel redundancies. Instead, by immediately
partitioning tensors we remove the redundancies. The result is that
regardless of the number of GPUs, we still only require the original 4TB. This
allows for a linear increase in model size with the aggregate system memory.
For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion
parameter model with 4 nodes and 32 GPUs.
Important: If the fp16 weights of the model can't fit onto a single GPU memory
this feature must be used.
.. note::
Initializes ``deepspeed.comm`` if it has not already been done so.
See :meth:`deepspeed.init_distributed` for more information.
.. note::
Only applicable to training with ZeRO-3.
Examples
--------
#. Allocate a model and partition it among all processes:
.. code-block:: python
with deepspeed.zero.Init():
model = MyLargeModel()
#. Allocate a model in pinned CPU memory and partition it among a subgroup of processes:
.. code-block:: python
with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
remote_device="cpu",
pin_memory=True):
model = MyLargeModel()
#. Partition an already-allocated model in CPU memory:
.. code-block:: python
model = deepspeed.zero.Init(module=model)
"""
if config is not None:
config_dict_or_path = config
logger.warning(
f'zero.Init: the `config` argument is deprecated. Please use `config_dict_or_path` instead.')
_ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path,
mpu) if config_dict_or_path is not None else None
if _ds_config is not None:
mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear
super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, ds_config=_ds_config, dtype=dtype)
if not dist.is_initialized():
init_distributed()
assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm"
if data_parallel_group is None:
self.ds_process_group = dist.get_world_group()
else:
self.ds_process_group = data_parallel_group
if sequence_data_parallel_group is not None:
logger.warning(
f"sequence_data_parallel_group' is deprecated and will be removed. Use 'data_parallel_group' instead.")
if data_parallel_group is not None:
raise ValueError(
"Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments."
)
self.ds_process_group = sequence_data_parallel_group
self.rank = dist.get_rank(group=self.ds_process_group)
self.dp_world_size = dist.get_world_size(group=self.ds_process_group)
self.zero_param_process_group = zero_param_parallel_group
if _ds_config is not None and _ds_config.zero_config.zero_hpz_partition_size > 1 and self.zero_param_process_group is None:
groups._create_zero_param_parallel_group(_ds_config.zero_config.zero_hpz_partition_size)
self.zero_param_process_group = groups._get_zero_param_intra_parallel_group()
self.num_ranks_in_param_group = self.dp_world_size
self.rank_in_group = self.rank
self.num_param_groups = 1
if self.zero_param_process_group is not None:
self.num_ranks_in_param_group = groups._get_zero_param_intra_parallel_group_world_size()
self.num_param_groups = int(self.dp_world_size / self.num_ranks_in_param_group)
self.rank_in_group = groups._get_zero_param_intra_parallel_rank_in_mygroup()
print_rank_0(f"hpZeRO group size: {self.num_ranks_in_param_group}", force=True)
logger.debug(
"hpZeRO partition parameter my rank in world {} my rank in group {} ranks in my param partition group: {} "
.format(self.rank, self.rank_in_group, groups._get_zero_param_intra_parallel_group_ranks()))
# Local device is the device where the parameters are consumed, must be default device.
# It is the device where parameters are fully instantiated using allgather
self.local_device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
get_accelerator().set_device(self.local_device)
self.quantized_weights = zero_quantized_weights