Skip to content

Commit

Permalink
[UnitTests] Added tvm_known_failing_targets option for the unittests.
Browse files Browse the repository at this point in the history
Intended to mark tests that fail for a particular target, and are
intended to be fixed in the future.  Typically, these would result
either from implementing a new test, or from an in-progress
implementation of a new target.
  • Loading branch information
Lunderberg committed May 10, 2021
1 parent 2692113 commit 57b495d
Showing 1 changed file with 94 additions and 15 deletions.
109 changes: 94 additions & 15 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,17 +708,35 @@ def _target_to_requirement(target):
return []


def _pytest_target_params(targets, excluded_targets=None):
def _pytest_target_params(targets, excluded_targets=None, known_failing_targets=None):
# Include unrunnable targets here. They get skipped by the
# pytest.mark.skipif in _target_to_requirement(), showing up as
# skipped tests instead of being hidden entirely.
if targets == None:
if targets is None:
if excluded_targets is None:
excluded_targets = set()

targets = [t["target"] for t in _get_targets() if t["target_kind"] not in excluded_targets]

return [pytest.param(target, marks=_target_to_requirement(target)) for target in targets]
if known_failing_targets is None:
known_failing_targets = set()

target_marks = []
for t in _get_targets():
if t["target_kind"] not in excluded_targets:
extra_marks = [
pytest.mark.skipif(
t["target_kind"] in known_failing_targets,
reason='Known failing test for target "{}"'.format(t["target_kind"]),
)
]
target_marks.append((t["target"], extra_marks))

else:
target_marks = [(target, []) for target in targets]

return [
pytest.param(target, marks=_target_to_requirement(target) + extra_marks)
for target, extra_marks in target_marks
]


def _auto_parametrize_target(metafunc):
Expand All @@ -730,6 +748,19 @@ def _auto_parametrize_target(metafunc):
file.
"""

def get_param_list(plugins, param_name):
output = set()
for plugin in plugins:
if hasattr(plugin, param_name):
param = getattr(plugin, param_name, [])
# Can be defined either as a string, or a list of strings.
if isinstance(param, str):
output.add(param)
else:
output |= set(param)
return output

if "target" in metafunc.fixturenames:
mark = metafunc.definition.get_closest_marker("parametrize")
if not mark or "target" not in mark.args[0]:
Expand All @@ -740,16 +771,12 @@ def _auto_parametrize_target(metafunc):
metafunc.function,
metafunc.module,
]
excluded_targets = set()
for plugin in plugins:
if hasattr(plugin, "tvm_excluded_targets"):
# Can be defined either as a string, or a list of strings.
if isinstance(plugin.tvm_excluded_targets, str):
excluded_targets.add(plugin.tvm_excluded_targets)
else:
excluded_targets |= set(plugin.tvm_excluded_targets)
excluded_targets = get_param_list(plugins, "tvm_excluded_targets")
known_failing_targets = get_param_list(plugins, "tvm_known_failing_targets")

metafunc.parametrize("target", _pytest_target_params(None, excluded_targets))
metafunc.parametrize(
"target", _pytest_target_params(None, excluded_targets, known_failing_targets)
)


def parametrize_targets(*args):
Expand Down Expand Up @@ -815,7 +842,7 @@ def exclude_targets(*args):
f : function
Function to parametrize. Must be of the form `def test_xxxxxxxxx(target, dev)`:,
where `xxxxxxxxx` is any name.
targets : list[str], optional
targets : list[str]
Set of targets to exclude.
Example
Expand All @@ -839,6 +866,58 @@ def wraps(func):
return wraps


def known_failing_targets(*args):
"""Skip a test that is known to fail on a particular target.
Use this decorator when you want your test to be run over a
variety of targets and devices (including cpu and gpu devices),
but know that it fails for some targets. For example, a newly
implemented runtime may not support all features being tested, and
should be excluded.
Alternatively, this can be specified in the conftest.py or the
file containing the test by setting the global variable
"tvm_known_failing_targets".
This is distinct from :py:func:`exclude_targets`, as these known
failing tests are still included in the final report as being
skipped, and show up in detailed views. Where
:py:func:`exclude_targets` is intended to mark targets that are
inherently incompatible with the test being run,
:py:func:`known_failing_targets` is intended to mark known failure
modes that are either exposed by implementing new tests, or lack
of feature implementation in a newly-implemented runtime, and will
be resolved in the future.
Parameters
----------
f : function
Function to parametrize. Must be of the form `def test_xxxxxxxxx(target, dev)`:,
where `xxxxxxxxx` is any name.
targets : list[str]
Set of targets to skip.
Example
-------
>>> @tvm.testing.known_failing_targets("cuda")
>>> def test_mytest(target, dev):
>>> ... # do something
Or
>>> @tvm.testing.known_failing_targets("llvm", "cuda")
>>> def test_mytest(target, dev):
>>> ... # do something
"""

def wraps(func):
func.tvm_known_failing_targets = args
return func

return wraps


def identity_after(x, sleep):
"""Testing function to return identity after sleep
Expand Down

0 comments on commit 57b495d

Please sign in to comment.