Skip to content

Commit

Permalink
[UnitTests] Added meta-tests for tvm.testing functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Jun 23, 2021
1 parent 0da04a8 commit 68947c8
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 10 deletions.
48 changes: 38 additions & 10 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,26 @@ def requires_cuda(*args):
return _compose(args, _requires_cuda)


def requires_nvptx(*args):
"""Mark a test as requiring the NVPTX compilation on the CUDA runtime
This also marks the test as requiring a cuda gpu, and requiring
LLVM support.
Parameters
----------
f : function
Function to mark
"""
_requires_nvptx = [
pytest.mark.skipif(not device_enabled("nvptx"), reason="NVPTX support not enabled"),
*requires_llvm(),
*requires_gpu(),
]
return _compose(args, _requires_nvptx)


def requires_cudagraph(*args):
"""Mark a test as requiring the CUDA Graph Feature
Expand Down Expand Up @@ -718,7 +738,7 @@ def _target_to_requirement(target):
if target.startswith("vulkan"):
return requires_vulkan()
if target.startswith("nvptx"):
return [*requires_llvm(), *requires_gpu()]
return requires_nvptx()
if target.startswith("metal"):
return requires_metal()
if target.startswith("opencl"):
Expand Down Expand Up @@ -753,6 +773,7 @@ def _pytest_target_params(targets, excluded_targets=None, xfail_targets=None):
reason='Known failing test for target "{}"'.format(t["target_kind"])
)
)

target_marks.append((t["target"], extra_marks))

else:
Expand Down Expand Up @@ -971,8 +992,10 @@ def parameter(*values, ids=None):
"""

# Optional cls parameter in case a parameter is defined inside a
# class scope.
@pytest.fixture(params=values, ids=ids)
def as_fixture(request):
def as_fixture(*_cls, request):
return request.param

return as_fixture
Expand Down Expand Up @@ -1030,7 +1053,9 @@ def parameters(*value_sets):
outputs = []
for param_values in zip(*value_sets):

def fixture_func(request):
# Optional cls parameter in case a parameter is defined inside a
# class scope.
def fixture_func(*_cls, request):
return request.param

fixture_func.parametrize_group = parametrize_group
Expand Down Expand Up @@ -1137,7 +1162,11 @@ def wraps(func):

def _fixture_cache(func):
cache = {}
num_uses = 0

# Can't use += on a bound method's property. Therefore, this is a
# list rather than a variable so that it can be accessed from the
# pytest_collection_modifyitems().
num_uses_remaining = [0]

# Using functools.lru_cache would require the function arguments
# to be hashable, which wouldn't allow caching fixtures that
Expand Down Expand Up @@ -1186,13 +1215,12 @@ def wrapper(*args, **kwargs):
finally:
# Clear the cache once all tests that use a particular fixture
# have completed.
nonlocal num_uses
num_uses += 1
if num_uses == wrapper.num_tests_use_this:
num_uses_remaining[0] -= 1
if not num_uses_remaining[0]:
cache.clear()

# Set in the pytest_collection_modifyitems()
wrapper.num_tests_use_this = 0
wrapper.num_uses_remaining = num_uses_remaining

return wrapper

Expand All @@ -1210,8 +1238,8 @@ def _count_num_fixture_uses(items):
for fixturedefs in item._fixtureinfo.name2fixturedefs.values():
# Only increment the active fixturedef, in a name has been overridden.
fixturedef = fixturedefs[-1]
if hasattr(fixturedef.func, "num_tests_use_this"):
fixturedef.func.num_tests_use_this += 1
if hasattr(fixturedef.func, "num_uses_remaining"):
fixturedef.func.num_uses_remaining[0] += 1


def _remove_global_fixture_definitions(items):
Expand Down
149 changes: 149 additions & 0 deletions tests/python/unittest/test_tvm_testing_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import sys

import pytest

import tvm.testing

# This file tests features in tvm.testing, such as verifying that
# cached fixtures are run an appropriate number of times. As a
# result, the order of the tests is important. Use of --last-failed
# or --failed-first while debugging this file is not advised.


class TestTargetAutoParametrization:
targets_used = []
devices_used = []
enabled_targets = [target for target, dev in tvm.testing.enabled_targets()]
enabled_devices = [dev for target, dev in tvm.testing.enabled_targets()]

def test_target_parametrization(self, target):
assert target in self.enabled_targets
self.targets_used.append(target)

def test_device_parametrization(self, dev):
assert dev in self.enabled_devices
self.devices_used.append(dev)

def test_all_targets_used(self):
assert self.targets_used == self.enabled_targets
assert self.devices_used == self.enabled_devices

targets_with_explicit_list = []

@tvm.testing.parametrize_targets("llvm")
def test_explicit_list(self, target):
assert target == "llvm"
self.targets_with_explicit_list.append(target)

def test_no_repeats_in_explicit_list(self):
assert self.targets_with_explicit_list == ["llvm"]

targets_with_exclusion = []

@tvm.testing.exclude_targets("llvm")
def test_exclude_target(self, target):
assert "llvm" not in target
self.targets_with_exclusion.append(target)

def test_all_nonexcluded_targets_ran(self):
assert self.targets_with_exclusion == [
target for target in self.enabled_targets if not target.startswith("llvm")
]

run_targets_with_known_failure = []

@tvm.testing.known_failing_targets("llvm")
def test_known_failing_target(self, target):
# This test runs for all targets, but intentionally fails for
# llvm. The behavior is working correctly if this test shows
# up as an expected failure, xfail.
self.run_targets_with_known_failure.append(target)
assert "llvm" not in target

def test_all_targets_ran(self):
assert self.run_targets_with_known_failure == self.enabled_targets


class TestJointParameter:
param1_vals = [1, 2, 3]
param2_vals = ["a", "b", "c"]

independent_usages = 0
param1 = tvm.testing.parameter(*param1_vals)
param2 = tvm.testing.parameter(*param2_vals)

joint_usages = 0
joint_param_vals = list(zip(param1_vals, param2_vals))
joint_param1, joint_param2 = tvm.testing.parameters(*joint_param_vals)

def test_using_independent(self, param1, param2):
type(self).independent_usages += 1

def test_independent(self):
assert self.independent_usages == len(self.param1_vals) * len(self.param2_vals)

def test_using_joint(self, joint_param1, joint_param2):
type(self).joint_usages += 1
assert (joint_param1, joint_param2) in self.joint_param_vals

def test_joint(self):
assert self.joint_usages == len(self.joint_param_vals)


class TestFixtureCaching:
param1_vals = [1, 2, 3]
param2_vals = ["a", "b", "c"]

param1 = tvm.testing.parameter(*param1_vals)
param2 = tvm.testing.parameter(*param2_vals)

uncached_calls = 0
cached_calls = 0

@tvm.testing.fixture
def uncached_fixture(self, param1):
type(self).uncached_calls += 1
return 2 * param1

def test_use_uncached(self, param1, param2, uncached_fixture):
assert 2 * param1 == uncached_fixture

def test_uncached_count(self):
assert self.uncached_calls == len(self.param1_vals) * len(self.param2_vals)

@tvm.testing.fixture(cache_return_value=True)
def cached_fixture(self, param1):
type(self).cached_calls += 1
return 3 * param1

def test_use_cached(self, param1, param2, cached_fixture):
assert 3 * param1 == cached_fixture

def test_cached_count(self):
cache_disabled = bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0")))
if cache_disabled:
assert self.cached_calls == len(self.param1_vals) * len(self.param2_vals)
else:
assert self.cached_calls == len(self.param1_vals)


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))

0 comments on commit 68947c8

Please sign in to comment.