diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py index fa9220f4fcd0..be778b83f8bb 100644 --- a/deepspeed/runtime/compiler.py +++ b/deepspeed/runtime/compiler.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from deepspeed.utils.torch import required_torch_version try: from torch.compiler import is_compiling as torch_is_compiling @@ -16,7 +17,7 @@ def is_compile_supported(): - return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile") + return required_torch_version(min_version=2.1) def disable(func): diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 459cffce52c8..28f91cb9b3ab 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -16,6 +16,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.utils import logger +from deepspeed.utils.torch import register_grad_hook from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item @@ -1159,7 +1160,6 @@ def overlapping_partition_gradients_reduce_epilogue(self): def create_reduce_and_remove_grad_hooks(self): print_rank_0(f'[Begin] Create gradient reduction hooks') - self.grad_accs = [] self.leaf_parameters = defaultdict(list) for i, param_group in enumerate(self.fp16_groups): for param in param_group: @@ -1172,15 +1172,12 @@ def create_reduce_and_remove_grad_hooks(self): #print(f"After all gather {param.device}, {param.shape}") def wrapper(param): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] @instrument_w_nvtx def reduce_partition_and_remove_grads(*notneeded): self.reduce_ready_partitions_and_remove_grads(param) - self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads)) - self.grad_accs.append(grad_acc) + self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads)) #print(f"param grad fn {param.expand_as(param).grad_fn}") if z3_leaf_parameter(param): diff --git a/deepspeed/utils/torch.py b/deepspeed/utils/torch.py index eb22d3561035..1d32775fe64a 100644 --- a/deepspeed/utils/torch.py +++ b/deepspeed/utils/torch.py @@ -20,3 +20,12 @@ def required_torch_version(min_version=None, max_version=None): return False return True + + +def register_grad_hook(param, hook): + if required_torch_version(min_version=2.1): + return param.register_post_accumulate_grad_hook(hook) + else: + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + return grad_acc.register_hook(hook) diff --git a/tests/unit/ops/transformer/inference/test_bias_add.py b/tests/unit/ops/transformer/inference/test_bias_add.py index f25bbc1be692..eb283924f73c 100644 --- a/tests/unit/ops/transformer/inference/test_bias_add.py +++ b/tests/unit/ops/transformer/inference/test_bias_add.py @@ -15,8 +15,6 @@ if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -torch_minor_version = None - def run_bias_add_reference(activations, bias): return activations + bias diff --git a/tests/unit/ops/transformer/inference/test_bias_gelu.py b/tests/unit/ops/transformer/inference/test_bias_gelu.py index e3a3bad63961..f0a09245e890 100644 --- a/tests/unit/ops/transformer/inference/test_bias_gelu.py +++ b/tests/unit/ops/transformer/inference/test_bias_gelu.py @@ -10,8 +10,8 @@ from deepspeed.ops.op_builder import InferenceBuilder from deepspeed.ops.transformer import DeepSpeedInferenceConfig from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp +from deepspeed.utils.torch import required_torch_version from .inference_test_utils import allclose, get_dtypes -from packaging import version as pkg_version if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) @@ -34,7 +34,7 @@ def run_bias_gelu_ds(activations, bias): @pytest.mark.parametrize("channels", [512, 1232, 4096]) @pytest.mark.parametrize("dtype", get_dtypes()) def test_bias_gelu(batch, sequence, channels, dtype): - if pkg_version.parse(torch.__version__) < pkg_version.parse("1.12"): + if not required_torch_version(min_version=1.12): pytest.skip("gelu implementation matches only after torch 1.12") activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=get_accelerator().device_name()) diff --git a/tests/unit/ops/transformer/inference/test_matmul.py b/tests/unit/ops/transformer/inference/test_matmul.py index 2ab195ee0115..6f5173bbc827 100644 --- a/tests/unit/ops/transformer/inference/test_matmul.py +++ b/tests/unit/ops/transformer/inference/test_matmul.py @@ -11,8 +11,6 @@ if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None - def allclose(x, y): assert x.dtype == y.dtype