From 68947c829a0f6c3af7c8183b1ca8ae18e690001b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 22 Jun 2021 14:26:38 -0700 Subject: [PATCH] [UnitTests] Added meta-tests for tvm.testing functionality --- python/tvm/testing.py | 48 ++++-- .../unittest/test_tvm_testing_features.py | 149 ++++++++++++++++++ 2 files changed, 187 insertions(+), 10 deletions(-) create mode 100644 tests/python/unittest/test_tvm_testing_features.py diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 896f3b3c521a4..8178b0a14b292 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -548,6 +548,26 @@ def requires_cuda(*args): return _compose(args, _requires_cuda) +def requires_nvptx(*args): + """Mark a test as requiring the NVPTX compilation on the CUDA runtime + + This also marks the test as requiring a cuda gpu, and requiring + LLVM support. + + Parameters + ---------- + f : function + Function to mark + + """ + _requires_nvptx = [ + pytest.mark.skipif(not device_enabled("nvptx"), reason="NVPTX support not enabled"), + *requires_llvm(), + *requires_gpu(), + ] + return _compose(args, _requires_nvptx) + + def requires_cudagraph(*args): """Mark a test as requiring the CUDA Graph Feature @@ -718,7 +738,7 @@ def _target_to_requirement(target): if target.startswith("vulkan"): return requires_vulkan() if target.startswith("nvptx"): - return [*requires_llvm(), *requires_gpu()] + return requires_nvptx() if target.startswith("metal"): return requires_metal() if target.startswith("opencl"): @@ -753,6 +773,7 @@ def _pytest_target_params(targets, excluded_targets=None, xfail_targets=None): reason='Known failing test for target "{}"'.format(t["target_kind"]) ) ) + target_marks.append((t["target"], extra_marks)) else: @@ -971,8 +992,10 @@ def parameter(*values, ids=None): """ + # Optional cls parameter in case a parameter is defined inside a + # class scope. @pytest.fixture(params=values, ids=ids) - def as_fixture(request): + def as_fixture(*_cls, request): return request.param return as_fixture @@ -1030,7 +1053,9 @@ def parameters(*value_sets): outputs = [] for param_values in zip(*value_sets): - def fixture_func(request): + # Optional cls parameter in case a parameter is defined inside a + # class scope. + def fixture_func(*_cls, request): return request.param fixture_func.parametrize_group = parametrize_group @@ -1137,7 +1162,11 @@ def wraps(func): def _fixture_cache(func): cache = {} - num_uses = 0 + + # Can't use += on a bound method's property. Therefore, this is a + # list rather than a variable so that it can be accessed from the + # pytest_collection_modifyitems(). + num_uses_remaining = [0] # Using functools.lru_cache would require the function arguments # to be hashable, which wouldn't allow caching fixtures that @@ -1186,13 +1215,12 @@ def wrapper(*args, **kwargs): finally: # Clear the cache once all tests that use a particular fixture # have completed. - nonlocal num_uses - num_uses += 1 - if num_uses == wrapper.num_tests_use_this: + num_uses_remaining[0] -= 1 + if not num_uses_remaining[0]: cache.clear() # Set in the pytest_collection_modifyitems() - wrapper.num_tests_use_this = 0 + wrapper.num_uses_remaining = num_uses_remaining return wrapper @@ -1210,8 +1238,8 @@ def _count_num_fixture_uses(items): for fixturedefs in item._fixtureinfo.name2fixturedefs.values(): # Only increment the active fixturedef, in a name has been overridden. fixturedef = fixturedefs[-1] - if hasattr(fixturedef.func, "num_tests_use_this"): - fixturedef.func.num_tests_use_this += 1 + if hasattr(fixturedef.func, "num_uses_remaining"): + fixturedef.func.num_uses_remaining[0] += 1 def _remove_global_fixture_definitions(items): diff --git a/tests/python/unittest/test_tvm_testing_features.py b/tests/python/unittest/test_tvm_testing_features.py new file mode 100644 index 0000000000000..1a7595aac5c78 --- /dev/null +++ b/tests/python/unittest/test_tvm_testing_features.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import sys + +import pytest + +import tvm.testing + +# This file tests features in tvm.testing, such as verifying that +# cached fixtures are run an appropriate number of times. As a +# result, the order of the tests is important. Use of --last-failed +# or --failed-first while debugging this file is not advised. + + +class TestTargetAutoParametrization: + targets_used = [] + devices_used = [] + enabled_targets = [target for target, dev in tvm.testing.enabled_targets()] + enabled_devices = [dev for target, dev in tvm.testing.enabled_targets()] + + def test_target_parametrization(self, target): + assert target in self.enabled_targets + self.targets_used.append(target) + + def test_device_parametrization(self, dev): + assert dev in self.enabled_devices + self.devices_used.append(dev) + + def test_all_targets_used(self): + assert self.targets_used == self.enabled_targets + assert self.devices_used == self.enabled_devices + + targets_with_explicit_list = [] + + @tvm.testing.parametrize_targets("llvm") + def test_explicit_list(self, target): + assert target == "llvm" + self.targets_with_explicit_list.append(target) + + def test_no_repeats_in_explicit_list(self): + assert self.targets_with_explicit_list == ["llvm"] + + targets_with_exclusion = [] + + @tvm.testing.exclude_targets("llvm") + def test_exclude_target(self, target): + assert "llvm" not in target + self.targets_with_exclusion.append(target) + + def test_all_nonexcluded_targets_ran(self): + assert self.targets_with_exclusion == [ + target for target in self.enabled_targets if not target.startswith("llvm") + ] + + run_targets_with_known_failure = [] + + @tvm.testing.known_failing_targets("llvm") + def test_known_failing_target(self, target): + # This test runs for all targets, but intentionally fails for + # llvm. The behavior is working correctly if this test shows + # up as an expected failure, xfail. + self.run_targets_with_known_failure.append(target) + assert "llvm" not in target + + def test_all_targets_ran(self): + assert self.run_targets_with_known_failure == self.enabled_targets + + +class TestJointParameter: + param1_vals = [1, 2, 3] + param2_vals = ["a", "b", "c"] + + independent_usages = 0 + param1 = tvm.testing.parameter(*param1_vals) + param2 = tvm.testing.parameter(*param2_vals) + + joint_usages = 0 + joint_param_vals = list(zip(param1_vals, param2_vals)) + joint_param1, joint_param2 = tvm.testing.parameters(*joint_param_vals) + + def test_using_independent(self, param1, param2): + type(self).independent_usages += 1 + + def test_independent(self): + assert self.independent_usages == len(self.param1_vals) * len(self.param2_vals) + + def test_using_joint(self, joint_param1, joint_param2): + type(self).joint_usages += 1 + assert (joint_param1, joint_param2) in self.joint_param_vals + + def test_joint(self): + assert self.joint_usages == len(self.joint_param_vals) + + +class TestFixtureCaching: + param1_vals = [1, 2, 3] + param2_vals = ["a", "b", "c"] + + param1 = tvm.testing.parameter(*param1_vals) + param2 = tvm.testing.parameter(*param2_vals) + + uncached_calls = 0 + cached_calls = 0 + + @tvm.testing.fixture + def uncached_fixture(self, param1): + type(self).uncached_calls += 1 + return 2 * param1 + + def test_use_uncached(self, param1, param2, uncached_fixture): + assert 2 * param1 == uncached_fixture + + def test_uncached_count(self): + assert self.uncached_calls == len(self.param1_vals) * len(self.param2_vals) + + @tvm.testing.fixture(cache_return_value=True) + def cached_fixture(self, param1): + type(self).cached_calls += 1 + return 3 * param1 + + def test_use_cached(self, param1, param2, cached_fixture): + assert 3 * param1 == cached_fixture + + def test_cached_count(self): + cache_disabled = bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0"))) + if cache_disabled: + assert self.cached_calls == len(self.param1_vals) * len(self.param2_vals) + else: + assert self.cached_calls == len(self.param1_vals) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv))