From 57b495d179745ee92764c8fd1e5ee84587f5a02d Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 10 May 2021 08:36:14 -0700
Subject: [PATCH] [UnitTests] Added tvm_known_failing_targets option for the
 unittests.

Intended to mark tests that fail for a particular target and are
expected to be fixed in the future. Typically, these failures result
either from implementing a new test or from an in-progress
implementation of a new target.
---
 python/tvm/testing.py | 109 ++++++++++++++++++++++++++++++++++++------
 1 file changed, 94 insertions(+), 15 deletions(-)

diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index c4bdf3c4b3cb4..59f08d62beb7b 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -708,17 +708,35 @@ def _target_to_requirement(target):
     return []
 
 
-def _pytest_target_params(targets, excluded_targets=None):
+def _pytest_target_params(targets, excluded_targets=None, known_failing_targets=None):
     # Include unrunnable targets here. They get skipped by the
     # pytest.mark.skipif in _target_to_requirement(), showing up as
     # skipped tests instead of being hidden entirely.
-    if targets == None:
+    if targets is None:
         if excluded_targets is None:
             excluded_targets = set()
 
-        targets = [t["target"] for t in _get_targets() if t["target_kind"] not in excluded_targets]
-
-    return [pytest.param(target, marks=_target_to_requirement(target)) for target in targets]
+        if known_failing_targets is None:
+            known_failing_targets = set()
+
+        target_marks = []
+        for t in _get_targets():
+            if t["target_kind"] not in excluded_targets:
+                extra_marks = [
+                    pytest.mark.skipif(
+                        t["target_kind"] in known_failing_targets,
+                        reason='Known failing test for target "{}"'.format(t["target_kind"]),
+                    )
+                ]
+                target_marks.append((t["target"], extra_marks))
+
+    else:
+        target_marks = [(target, []) for target in targets]
+
+    return [
+        pytest.param(target, marks=_target_to_requirement(target) + extra_marks)
+        for target, extra_marks in target_marks
+    ]
 
 
 def _auto_parametrize_target(metafunc):
@@ -730,6 +748,19 @@ def _auto_parametrize_target(metafunc):
     file.
 
     """
+
+    def get_param_list(plugins, param_name):
+        output = set()
+        for plugin in plugins:
+            if hasattr(plugin, param_name):
+                param = getattr(plugin, param_name, [])
+                # Can be defined either as a string, or a list of strings.
+                if isinstance(param, str):
+                    output.add(param)
+                else:
+                    output |= set(param)
+        return output
+
     if "target" in metafunc.fixturenames:
         mark = metafunc.definition.get_closest_marker("parametrize")
         if not mark or "target" not in mark.args[0]:
@@ -740,16 +771,12 @@ def _auto_parametrize_target(metafunc):
                 metafunc.function,
                 metafunc.module,
             ]
-            excluded_targets = set()
-            for plugin in plugins:
-                if hasattr(plugin, "tvm_excluded_targets"):
-                    # Can be defined either as a string, or a list of strings.
-                    if isinstance(plugin.tvm_excluded_targets, str):
-                        excluded_targets.add(plugin.tvm_excluded_targets)
-                    else:
-                        excluded_targets |= set(plugin.tvm_excluded_targets)
+            excluded_targets = get_param_list(plugins, "tvm_excluded_targets")
+            known_failing_targets = get_param_list(plugins, "tvm_known_failing_targets")
 
-            metafunc.parametrize("target", _pytest_target_params(None, excluded_targets))
+            metafunc.parametrize(
+                "target", _pytest_target_params(None, excluded_targets, known_failing_targets)
+            )
 
 
 def parametrize_targets(*args):
@@ -815,7 +842,7 @@ def exclude_targets(*args):
     f : function
         Function to parametrize. Must be of the form `def test_xxxxxxxxx(target, dev)`:,
        where `xxxxxxxxx` is any name.
-    targets : list[str], optional
+    targets : list[str]
         Set of targets to exclude.
 
     Example
@@ -839,6 +866,58 @@ def wraps(func):
     return wraps
 
 
+def known_failing_targets(*args):
+    """Skip a test that is known to fail on a particular target.
+
+    Use this decorator when you want your test to be run over a
+    variety of targets and devices (including cpu and gpu devices),
+    but know that it fails for some targets. For example, a newly
+    implemented runtime may not support all features being tested, and
+    the test should be skipped for that target.
+
+    Alternatively, this can be specified in the conftest.py or the
+    file containing the test by setting the global variable
+    "tvm_known_failing_targets".
+
+    This is distinct from :py:func:`exclude_targets`, as these known
+    failing tests are still included in the final report as being
+    skipped, and show up in detailed views. Where
+    :py:func:`exclude_targets` is intended to mark targets that are
+    inherently incompatible with the test being run,
+    :py:func:`known_failing_targets` is intended to mark known failure
+    modes that are either exposed by newly implemented tests or caused
+    by missing features in a newly implemented runtime, and that will
+    be resolved in the future.
+
+    Parameters
+    ----------
+    f : function
+        Function to parametrize. Must be of the form `def test_xxxxxxxxx(target, dev)`:,
+        where `xxxxxxxxx` is any name.
+    targets : list[str]
+        Set of targets to skip.
+
+    Example
+    -------
+    >>> @tvm.testing.known_failing_targets("cuda")
+    >>> def test_mytest(target, dev):
+    >>>     ...  # do something
+
+    Or
+
+    >>> @tvm.testing.known_failing_targets("llvm", "cuda")
+    >>> def test_mytest(target, dev):
+    >>>     ...  # do something
+
+    """
+
+    def wraps(func):
+        func.tvm_known_failing_targets = args
+        return func
+
+    return wraps
+
+
 def identity_after(x, sleep):
     """Testing function to return identity after sleep
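
For reference, a minimal sketch of how a test file could use the new option, based
only on the docstrings above. The test name and the specific target strings are
illustrative, not taken from the patch.

    import tvm.testing

    # Module-level form: picked up by _auto_parametrize_target() for every
    # auto-parametrized test in this file. May be a single string or a list
    # of strings, as handled by get_param_list().
    tvm_known_failing_targets = ["llvm"]

    # Decorator form: applies to a single test. The test is still
    # parametrized over all enabled targets, but the listed targets are
    # reported as skipped rather than failing.
    @tvm.testing.known_failing_targets("cuda")
    def test_mytest(target, dev):
        ...  # do something

The decorator simply stores the names on func.tvm_known_failing_targets, which
_auto_parametrize_target() reads via get_param_list() and turns into a
pytest.mark.skipif on the matching target parameters.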