diff --git a/tests/python/test_transformer_engine.py b/tests/python/test_transformer_engine.py index 19e5ca90fa1..239ab0a4df2 100644 --- a/tests/python/test_transformer_engine.py +++ b/tests/python/test_transformer_engine.py @@ -32,8 +32,11 @@ class ComputeType(Enum): @pytest.mark.mpi @pytest.mark.parametrize( "compute_type", - [ComputeType.FORWARD, ComputeType.BACKWARD], - ids=["forward", "backward"], + # TODO(#3119): add the backward test back. + # [ComputeType.FORWARD, ComputeType.BACKWARD], + # ids=["forward", "backward"], + [ComputeType.FORWARD], + ids=["forward"], ) def test_transformer_layer(mpi_test, benchmark, compute_type): # Hyperparameters for GPT-3