
Fix test_compile_static_cache #30991

Merged (4 commits into main) on Jun 3, 2024
Conversation

ydshieh (Collaborator) commented May 23, 2024

What does this PR do?

Fix or skip test_compile_static_cache on Nvidia T4 GPUs. These tests were never run on T4 until we switched to torch 2.3.

  • For Llama: just update the expected output values for T4.
  • For Mistral: torch.compile failed. Let's skip it on T4 and focus on the A10 CI (the important-model testing CI). A minimal sketch of the skip logic follows below.
  • (Gemma doesn't need any update.)
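
For reference, here is a minimal sketch (not the exact transformers test code; the class and setup are illustrative assumptions) of how a test can skip itself based on the GPU's CUDA compute capability: a T4 reports major version 7 (compute capability 7.5), while an A10 reports 8 (compute capability 8.6).

    import unittest

    import torch


    class ExampleStaticCacheTest(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            # (major, minor) of the current CUDA device, e.g. (7, 5) on a T4 or (8, 6) on an A10
            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

        def test_compile_static_cache(self):
            # Skip on Turing-class GPUs (major version 7), where `torch.compile` errors out here
            if self.cuda_compute_capability_major_version == 7:
                self.skipTest("`torch.compile` fails on Nvidia T4 GPU.")
            # ... run generation with a compiled model and static cache, then compare outputs ...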

cc @ArthurZucker @zhenglongjiepheonix

Comment on lines +669 to +671
if self.cuda_compute_capability_major_version == 7:
    self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU.")

Collaborator:
by fail do you mean errors out?

ydshieh (Collaborator Author) replied May 23, 2024:

Yes (I think)

Short error log (from /usr/local/lib/python3.8/dist-packages/torch/_inductor/utils.py:170, in do_bench):

E           You can suppress this exception and fall back to eager by setting:
E               import torch._dynamo
E               torch._dynamo.config.suppress_errors = True
Full error log
tests/models/mistral/test_modeling_mistral.py:711:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py:115: in decorate_context
    return func(*args, **kwargs)
src/transformers/generation/utils.py:1785: in generate
    result = self._sample(
src/transformers/generation/utils.py:2424: in _sample
    outputs = self(
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py:451: in _fn
    return fn(*args, **kwargs)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py:921: in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py:400: in _convert_frame_assert
    return _compile(
/usr/lib/python3.8/contextlib.py:75: in inner
    return func(*args, **kwds)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py:676: in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py:262: in time_wrapper
    r = func(*args, **kwargs)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py:535: in compile_inner
    out_code = transform_code_object(code, transform)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/bytecode_transformation.py:1036: in transform_code_object
    transformations(instructions, code_options)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py:165: in _fn
    return fn(*args, **kwargs)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py:500: in transform
    tracer.run()
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py:2149: in run
    super().run()
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py:810: in run
    and self.step()
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py:773: in step
    getattr(self, inst.opname)(inst)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py:2268: in RETURN_VALUE
    self.output.compile_subgraph(
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py:991: in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
/usr/lib/python3.8/contextlib.py:75: in inner
    return func(*args, **kwds)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py:1168: in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py:262: in time_wrapper
    r = func(*args, **kwargs)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py:1241: in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py:1222: in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/repro/after_dynamo.py:117: in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/repro/after_dynamo.py:117: in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
/usr/local/lib/python3.8/dist-packages/torch/__init__.py:1729: in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
/usr/lib/python3.8/contextlib.py:75: in inner
    return func(*args, **kwds)
/usr/local/lib/python3.8/dist-packages/torch/_inductor/compile_fx.py:1102: in compile_fx
    return compile_fx(
/usr/lib/python3.8/contextlib.py:75: in inner
    return func(*args, **kwds)
/usr/local/lib/python3.8/dist-packages/torch/_inductor/compile_fx.py:1330: in compile_fx
    return aot_autograd(
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/backends/common.py:58: in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py:903: in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py:262: in time_wrapper
    r = func(*args, **kwargs)
/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py:628: in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
/usr/local/lib/python3.8/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:443: in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
/usr/local/lib/python3.8/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:648: in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
/usr/local/lib/python3.8/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:119: in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py:262: in time_wrapper
    r = func(*args, **kwargs)
/usr/local/lib/python3.8/dist-packages/torch/_inductor/compile_fx.py:1199: in fw_compiler_base
    _recursive_joint_graph_passes(model)
/usr/local/lib/python3.8/dist-packages/torch/_inductor/compile_fx.py:207: in _recursive_joint_graph_passes
    joint_graph_passes(gm)
/usr/local/lib/python3.8/dist-packages/torch/_inductor/fx_passes/joint_graph.py:292: in joint_graph_passes
    count += patterns.apply(graph.graph)  # type: ignore[arg-type]
/usr/local/lib/python3.8/dist-packages/torch/_inductor/pattern_matcher.py:1267: in apply
    if is_match(m) and entry.extra_check(m):
/usr/local/lib/python3.8/dist-packages/torch/_inductor/pattern_matcher.py:1055: in check_fn
    if specific_pattern_match and extra_check(specific_pattern_match):
/usr/local/lib/python3.8/dist-packages/torch/_inductor/fx_passes/pad_mm.py:420: in should_pad_mm
    return should_pad_common(mat1, mat2) and should_pad_bench(
/usr/local/lib/python3.8/dist-packages/torch/_inductor/fx_passes/pad_mm.py:352: in should_pad_bench
    ori_time = do_bench(
/usr/local/lib/python3.8/dist-packages/torch/_inductor/utils.py:170: in do_bench
    return triton_do_bench(*args, **kwargs)[0]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

fn = <function should_pad_bench.<locals>.<lambda> at 0x7f6c8786c280>, warmup = 5, rep = 100, grad_to_none = None, quantiles = (0.5, 0.2, 0.8), fast_flush = True, return_mode = 'mean'

    def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
        assert return_mode in ["min", "max", "mean", "median"]
        import torch
        """
        Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
        the 20-th and 80-th performance percentile.

        :param fn: Function to benchmark
        :type fn: Callable
        :param warmup: Warmup time (in ms)
        :type warmup: int
        :param rep: Repetition time (in ms)
        :type rep: int
        :param grad_to_none: Reset the gradient of the provided tensor to None
        :type grad_to_none: torch.tensor, optional
        :param quantiles: Performance percentile to return in addition to the median.
        :type quantiles: list[float]
        :param fast_flush: Use faster kernel to flush L2 between measurements
        :type fast_flush: bool
        """

        fn()
        torch.cuda.synchronize()

        # We maintain a buffer of 256 MB that we clear
        # before each kernel call to make sure that the L2
        # doesn't contain any input data before the run
E           
E           
E           You can suppress this exception and fall back to eager by setting:
E               import torch._dynamo
E               torch._dynamo.config.suppress_errors = True

Collaborator:
So it's not failing but erroring out!

Collaborator:

Also, what do the dynamo logs tell you?

zhenglongjiepheonix (Contributor) commented May 23, 2024:

I will try investigating it on T4; probably memory-constraint-related issues.
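
If it is memory related, a quick way to check (a generic sketch, not something from this PR) is to look at the device's free memory while the compile step runs; a T4 has roughly 16 GB of device memory versus roughly 24 GB on an A10:

    import torch

    # Device-level view: free vs. total memory on the current CUDA device
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    print(f"free: {free_bytes / 1e9:.1f} GB / total: {total_bytes / 1e9:.1f} GB")

    # Allocator view: memory currently held by PyTorch tensors / its caching allocator
    print(f"allocated: {torch.cuda.memory_allocated() / 1e9:.1f} GB")
    print(f"reserved:  {torch.cuda.memory_reserved() / 1e9:.1f} GB")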

ydshieh (Collaborator Author):

> Also, what do the dynamo logs tell you?

The full error log above is all I got, but in short:

        # We maintain a buffer of 256 MB that we clear
        # before each kernel call to make sure that the L2
        # doesn't contain any input data before the run
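
(As an aside, not something discussed in this thread: if more detail were needed than the traceback above, recent PyTorch releases can emit verbose dynamo/inductor logs, roughly like this.)

    import logging

    import torch._logging

    # Programmatic: verbose logging from dynamo and inductor
    torch._logging.set_logs(dynamo=logging.DEBUG, inductor=logging.DEBUG)

    # Or via an environment variable before launching the tests, e.g.:
    #   TORCH_LOGS="+dynamo,recompiles,graph_breaks" python -m pytest tests/models/mistral -k compile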

Collaborator:

that looks like it did some re-compilation

ydshieh (Collaborator Author):

I will open a GitHub issue for @zhenglongjiepheonix to investigate and come to a conclusion and/or a fix; let's merge this for now? (It works on A10, IIRC.)

As discussed before, we would like to start moving failing tests from CI reports to GitHub issues, so we track them in another (and arguably better) way and our CI reports remain in a reasonably good state that lets us focus on new failures.

ydshieh (Collaborator Author):

@ArthurZucker

Opened an issue here #31015

Collaborator:

cc @fxmarty for visibility!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

ydshieh (Collaborator Author) commented May 28, 2024:

Hi @ArthurZucker, any comment?

ydshieh changed the title from Fix ci 0002 to Fix test_compile_static_cache on May 29, 2024

ydshieh merged commit df848ac into main on Jun 3, 2024 (21 checks passed).
ydshieh deleted the fix_ci_0002 branch on June 3, 2024 at 13:16.
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jun 11, 2024
* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <[email protected]>