[TransformerEngine] Support backward(retain_graph=True)
#701
One reason for the issue could be that we have our own version of Context, which probably behaves differently when backward is called with `retain_graph=True`.
I think this may be the famous clearing of collections...
PyTorch requires all tensor data to be saved with `ctx.save_for_backward`.
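For reference, here is a minimal plain-PyTorch sketch (no TransformerEngine or Thunder involved; `ScaleBy` is a hypothetical example function) of a custom `torch.autograd.Function` that routes its saved inputs through `ctx.save_for_backward`. Because autograd owns those tensors, it keeps them alive when `retain_graph=True` is passed, so backward can be run a second time.

```python
import torch


class ScaleBy(torch.autograd.Function):
    """Hypothetical example: computes x * w and saves both inputs for backward."""

    @staticmethod
    def forward(ctx, x, w):
        # Tensors saved this way are managed by autograd, which keeps them
        # alive across backward calls when retain_graph=True is used.
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        return grad_out * w, grad_out * x


def run_backward_twice():
    x = torch.randn(4, requires_grad=True)
    w = torch.randn(4, requires_grad=True)
    loss = ScaleBy.apply(x, w).sum()
    loss.backward(retain_graph=True)  # graph and saved tensors are kept
    first = x.grad.clone()
    loss.backward()  # second pass works; gradients accumulate into x.grad
    return first, x.grad.clone(), w


if __name__ == "__main__":
    first, second, w = run_backward_twice()
    print(torch.allclose(second, 2 * first))  # True: gradient accumulated twice
```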
Even if we fix this, TransformerEngine itself has a problem with running backward multiple times with `retain_graph=True`:

```python
import torch
from transformer_engine.pytorch import Linear as TELinear, fp8_autocast

m = TELinear(16, 16)
x = torch.randn(16, 16, device='cuda')
with fp8_autocast(True):
    o = m(x).sum()
o.backward(retain_graph=True)
# The second backward below fails with
# AssertionError: FP8 execution requires 2D input matrices with height divisible by 8 and width divisible by 16, but got tensor with dims=[0]
# It looks like TELinear.backward mutates the context object such that it is not reusable.
o.backward()
```
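The suspected failure mode, a backward pass that mutates its context object so that it cannot be reused, can be reproduced without TransformerEngine. The sketch below is a hypothetical pure-PyTorch analogue (`ConsumesCtx` is invented for illustration and is not TE's code): it stashes a tensor as a plain attribute on `ctx` and clears it during backward, so a second `backward` call fails even though `retain_graph=True` was passed the first time.

```python
import torch


class ConsumesCtx(torch.autograd.Function):
    """Hypothetical analogue of a backward whose context is not reusable."""

    @staticmethod
    def forward(ctx, x):
        # Stashed as a plain attribute instead of via ctx.save_for_backward,
        # so autograd does not manage its lifetime.
        ctx.stash = x.detach()
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        x = ctx.stash
        if x is None:
            raise RuntimeError("context was already consumed by a previous backward")
        # Mutating ctx here is what makes a second backward impossible.
        ctx.stash = None
        return 2 * x * grad_out


def try_backward_twice():
    x = torch.randn(3, requires_grad=True)
    out = ConsumesCtx.apply(x).sum()
    out.backward(retain_graph=True)  # first pass succeeds
    try:
        out.backward()  # second pass finds the mutated context
    except RuntimeError as err:
        return str(err)
    return None


if __name__ == "__main__":
    print(try_backward_twice())
```

Note that `retain_graph=True` only controls what autograd itself frees; it cannot undo state that the function's own backward discards.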
Let's do the work anyway, so that we hit the same error as in pure TE usage.
Currently, we send the data that forward saves for backward through a mock context object backed by a dictionary (lightning-thunder/thunder/executors/transformer_engineex.py, lines 125 to 140 in ab514fc).
A dictionary was used instead of a Context object so that clear_mutable_collections could clear the dict and thereby release the memory held by its tensors. Faithfully modeling the properties of the saved tensors (their number, types, and shapes) was also considered too much work. However, to ensure correct behavior from PyTorch, we need to pass tensor objects through PyTorch's save_for_backward(*tensors) function. That means we need to modify TransformerEngine's meta function to return TensorProxies explicitly rather than hiding them behind a generic CollectionProxy.
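The direction described above, passing the saved tensors explicitly through `save_for_backward` instead of hiding them in a generic collection, can be sketched in plain PyTorch. This is a hypothetical illustration, not Thunder's or TransformerEngine's code: `fwd_impl` and `bwd_impl` stand in for a forward that returns its saved state as a dictionary (like the mock context) and a backward that reads from it. The wrapper splits the dictionary into tensors, routed through `ctx.save_for_backward`, and non-tensor metadata, kept on `ctx`.

```python
import torch


def fwd_impl(x, w):
    """Hypothetical forward returning an output plus a dict of saved state."""
    out = x @ w
    saved = {"x": x, "w": w, "scale": 1.0}
    return out, saved


def bwd_impl(saved, grad_out):
    """Hypothetical backward that reads its inputs from the saved dict."""
    x, w, scale = saved["x"], saved["w"], saved["scale"]
    return (grad_out @ w.t()) * scale, (x.t() @ grad_out) * scale


class ExplicitSaved(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        out, saved = fwd_impl(x, w)
        # Tensors go through save_for_backward so autograd owns their lifetime
        # (and keeps them when retain_graph=True); everything else is plain
        # Python metadata stored on ctx.
        tensor_keys = [k for k, v in saved.items() if isinstance(v, torch.Tensor)]
        ctx.tensor_keys = tensor_keys
        ctx.other = {k: v for k, v in saved.items() if k not in tensor_keys}
        ctx.save_for_backward(*(saved[k] for k in tensor_keys))
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # Rebuild the dict on every call instead of consuming it, so backward
        # can run more than once on the same graph.
        saved = dict(zip(ctx.tensor_keys, ctx.saved_tensors), **ctx.other)
        return bwd_impl(saved, grad_out)


def run():
    x = torch.randn(2, 3, requires_grad=True)
    w = torch.randn(3, 4, requires_grad=True)
    loss = ExplicitSaved.apply(x, w).sum()
    loss.backward(retain_graph=True)
    first = w.grad.clone()
    loss.backward()  # works because nothing was consumed in the first pass
    return first, w.grad.clone()
```

The cost of this design is the one noted above: the number, types, and shapes of the saved tensors must be known to whoever calls `save_for_backward`, instead of being hidden inside an opaque collection.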
(#3119) …ckward]. One of the ranks failed with the following error. This test was added as a performance baseline and doesn't verify any nvFuser functionality, so I am disabling it to clean up CI and will investigate it further.

```
______________________________ test_transformer_layer[backward] ______________________________

mpi_test = <mpi_fixtures.MpiTest object at 0x7f594feccdf0>, benchmark = <pytest_benchmark.fixture.BenchmarkFixture object at 0x7f594feccbb0>, compute_type = <ComputeType.BACKWARD: 2>

    @pytest.mark.mpi
    @pytest.mark.parametrize(
        "compute_type",
        [ComputeType.FORWARD, ComputeType.BACKWARD],
        ids=["forward", "backward"],
    )
    def test_transformer_layer(mpi_test, benchmark, compute_type):
        # Hyperparameters for GPT-3
        hidden_size = 12288
        num_heads = 96
        ffn_hidden_size = hidden_size * 4
        batch_size = 1
        sequence_length = 2048
        dtype = torch.bfloat16

        size = mpi_test.size
        rank = mpi_test.rank

        torch.cuda.set_device(rank)
        os.environ["MASTER_ADDR"] = "localhost"
        # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51.
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=size,
            rank=rank,
        )
        tp_group = dist.new_group()

        transformer_layer = te.TransformerLayer(
            hidden_size,
            ffn_hidden_size,
            num_heads,
            set_parallel_mode=True,
            tp_group=tp_group,
        )
        transformer_layer.to(dtype).to("cuda")

        x = torch.randn(
            batch_size, sequence_length, hidden_size, dtype=dtype, device="cuda"
        )

        match compute_type:
            case ComputeType.FORWARD:

                def benchmark_fn(profile):
                    if profile:
                        torch.cuda.cudart().cudaProfilerStart()

                    y = transformer_layer(x)
                    torch.cuda.synchronize()

                    if profile:
                        torch.cuda.cudart().cudaProfilerStop()
                    return y

                # Warmup.
                y = benchmark_fn(False)
                assert y.size() == torch.Size([batch_size, sequence_length, hidden_size])

                benchmark.pedantic(benchmark_fn, args=(True,), rounds=5)
            case ComputeType.BACKWARD:
                # Due to
                # Lightning-AI/lightning-thunder#701, a
                # limitation in TransformerEngine, we can't repeatedly call
                # torch.autograd.backward to benchmark just backprop. As a
                # workaround, the code below runs forward before each backprop but
                # only measure the backprop time.
                def setup_fn(profile):
                    y = transformer_layer(x)
                    dy = torch.rand_like(y)
                    torch.cuda.synchronize()
                    # Unlike for forward, I can't pass `profile` directly to
                    # `benchmark_fn` because `benchmark.pedantic` is not allowed to
                    # take both `setup` and `args`. Therefore, we pass `profile` to
                    # `setup_fn`, which in turn passes iit through to
                    # `benchmark_fn`.
                    return (y, dy, profile), {}

                def benchmark_fn(y, dy, profile):
                    if profile:
                        torch.cuda.cudart().cudaProfilerStart()

                    torch.autograd.backward(y, dy)
                    torch.cuda.synchronize()

                    if profile:
                        torch.cuda.cudart().cudaProfilerStop()

                # Warmup.
>               args, kwargs = setup_fn(False)

tests/python/test_transformer_engine.py:123:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/python/test_transformer_engine.py:102: in setup_fn
    y = transformer_layer(x)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/transformer.py:677: in forward
    self_attention_outputs = self.self_attention(
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py:8535: in forward
    projection_output = self.proj(
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:629: in _fn
    return fn(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/linear.py:1017: in forward
    out = linear_fn(*args)
/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py:575: in apply
    return super().apply(*args, **kwargs) # type: ignore[misc]
/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/linear.py:360: in forward
    out, _ = allreduce(out, tp_group)
/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/distributed.py:886: in allreduce
    handle = torch.distributed.all_reduce(input_, group=tp_group, async_op=async_op)
/usr/local/lib/python3.10/dist-packages/torch/distributed/c10d_logger.py:83: in wrapper
    return func(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

tensor = tensor([[ 2.3438, -1.4609, -1.5234, ..., 2.1250, -3.9375, -2.3906],
        [ 2.4844, 0.3125, 2.2812, ..., -1.906...],
        [-1.4297, -8.3125, -1.8516, ..., -7.2812, 1.1094, 2.5625]],
       device='cuda:3', dtype=torch.bfloat16)
op = <RedOpType.SUM: 0>, group = <torch.distributed.distributed_c10d.ProcessGroup object at 0x7f5afd67e430>, async_op = False

    @_exception_logger
    def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
        """
        Reduces the tensor data across all machines in a way that all get the final result.

        After the call ``tensor`` is going to be bitwise identical in all processes.

        Complex tensors are supported.

        Args:
            tensor (Tensor): Input and output of the collective. The function operates in-place.
            op (optional): One of the values from ``torch.distributed.ReduceOp`` enum. Specifies an operation used for element-wise reductions.
            group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used.
            async_op (bool, optional): Whether this op should be an async op

        Returns:
            Async work handle, if async_op is set to True.
            None, if not async_op or if not part of the group

        Examples:
            >>> # xdoctest: +SKIP("no rank")
            >>> # All tensors below are of torch.int64 type.
            >>> # We have 2 process groups, 2 ranks.
            >>> device = torch.device(f'cuda:{rank}')
            >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
            >>> tensor
            tensor([1, 2], device='cuda:0') # Rank 0
            tensor([3, 4], device='cuda:1') # Rank 1
            >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
            >>> tensor
            tensor([4, 6], device='cuda:0') # Rank 0
            tensor([4, 6], device='cuda:1') # Rank 1

            >>> # All tensors below are of torch.cfloat type.
            >>> # We have 2 process groups, 2 ranks.
            >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j)
            >>> tensor
            tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
            tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
            >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
            >>> tensor
            tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0
            tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1

        """
        _check_single_tensor(tensor, "tensor")
        if _rank_not_in_group(group):
            _warn_not_in_group("all_reduce")
            return

        if tensor.is_complex():
            if not supports_complex(op):
                raise ValueError(f"all_reduce does not support {op} on complex tensors")
            tensor = torch.view_as_real(tensor)

        opts = AllreduceOptions()
        opts.reduceOp = op
        if group is None:
            group = _get_default_group()

        if group in _world.pg_coalesce_state.keys():
            # We are in coalescing context, do not issue single operation, just append a collective representation
            coll = _CollOp(all_reduce, tensor, None, op, None)
            _world.pg_coalesce_state[group].append(coll)
            if async_op:
                return _IllegalWork()
            else:
                return None

>       work = group.allreduce([tensor], opts)
E       torch.distributed.DistBackendError: NCCL error in: /opt/pytorch/pytorch/torch/csrc/distributed/c10d/NCCLUtils.hpp:317, unhandled system error (run with NCCL_DEBUG=INFO for details), NCCL version 2.22.3
E       ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error.
E       Last error:
E       socketStartConnect: Connect to 10.120.104.58<41447> failed : Software caused connection abort

/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:2603: DistBackendError
```
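The benchmark workaround in the commit above (rerun forward before every backward, but time only backward) does not depend on pytest-benchmark. A minimal sketch using `time.perf_counter`, assuming a small CPU `torch.nn.Linear` as a stand-in for `te.TransformerLayer`; `time_backward_only` is a hypothetical helper name, not part of either project:

```python
import time

import torch


def time_backward_only(module, x, rounds=5):
    """Hypothetical helper: time backward only, with a fresh forward per round.

    Each round builds a new autograd graph, so backward never runs twice on
    the same graph and retain_graph=True is not needed.
    """
    use_cuda = x.is_cuda
    timings = []
    for _ in range(rounds):
        # Setup (not timed): clear old grads, run forward, draw an upstream grad.
        module.zero_grad(set_to_none=True)
        y = module(x)
        dy = torch.rand_like(y)
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        torch.autograd.backward(y, dy)
        if use_cuda:
            # CUDA kernels run asynchronously; wait before stopping the clock.
            torch.cuda.synchronize()
        timings.append(time.perf_counter() - start)
    return timings


if __name__ == "__main__":
    layer = torch.nn.Linear(64, 64)
    inp = torch.randn(8, 64)
    print(min(time_backward_only(layer, inp)))
```

The trade-off is the same as in the test: every round pays for an extra forward, but because it happens outside the timed region, the measurement covers backprop alone.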
cc: @IvanYashchuk