diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 71ab0770d64e..04a235b64fdf 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -56,6 +56,8 @@ def test_something(): """ import collections import copy +import copyreg +import ctypes import functools import logging import os @@ -386,8 +388,14 @@ def _get_targets(target_str=None): targets = [] for target in target_names: target_kind = target.split()[0] - is_enabled = tvm.runtime.enabled(target_kind) - is_runnable = is_enabled and tvm.device(target_kind).exist + + if target_kind == "cuda" and "cudnn" in tvm.target.Target(target).attrs.get("libs", []): + is_enabled = tvm.support.libinfo()["USE_CUDNN"].lower() in ["on", "true", "1"] + is_runnable = is_enabled and cudnn.exists() + else: + is_enabled = tvm.runtime.enabled(target_kind) + is_runnable = is_enabled and tvm.device(target_kind).exist + targets.append( { "target": target, @@ -1251,6 +1259,60 @@ def wraps(func): return wraps(func) +class _DeepCopyAllowedClasses(dict): + def __init__(self, allowed_class_list): + self.allowed_class_list = allowed_class_list + super().__init__() + + def get(self, key, *args, **kwargs): + """Overrides behavior of copy.deepcopy to avoid implicit copy. + + By default, copy.deepcopy uses a dict of id->object to track + all objects that it has seen, which is passed as the second + argument to all recursive calls. This class is intended to be + passed in instead, and inspects the type of all objects being + copied. + + Where copy.deepcopy does a best-effort attempt at copying an + object, for unit tests we would rather have all objects either + be copied correctly, or to throw an error. Classes that + define an explicit method to perform a copy are allowed, as + are any explicitly listed classes. Classes that would fall + back to using object.__reduce__, and are not explicitly listed + as safe, will throw an exception. + + """ + obj = ctypes.cast(key, ctypes.py_object).value + cls = type(obj) + if ( + cls in copy._deepcopy_dispatch + or issubclass(cls, type) + or getattr(obj, "__deepcopy__", None) + or copyreg.dispatch_table.get(cls) + or cls.__reduce__ is not object.__reduce__ + or cls.__reduce_ex__ is not object.__reduce_ex__ + or cls in self.allowed_class_list + ): + return super().get(key, *args, **kwargs) + + rfc_url = ( + "https://github.com/apache/tvm-rfcs/blob/main/rfcs/0007-parametrized-unit-tests.md" + ) + raise TypeError( + ( + f"Cannot copy fixture of type {cls.__name__}. TVM fixture caching " + "is limited to objects that explicitly provide the ability " + "to be copied (e.g. through __deepcopy__, __getstate__, or __setstate__)," + "and forbids the use of the default `object.__reduce__` and " + "`object.__reduce_ex__`. For third-party classes that are " + "safe to use with copy.deepcopy, please add the class to " + "the arguments of _DeepCopyAllowedClasses in tvm.testing._fixture_cache.\n" + "\n" + f"For discussion on this restriction, please see {rfc_url}." + ) + ) + + def _fixture_cache(func): cache = {} @@ -1290,18 +1352,14 @@ def wrapper(*args, **kwargs): except KeyError: cached_value = cache[cache_key] = func(*args, **kwargs) - try: - yield copy.deepcopy(cached_value) - except TypeError as e: - rfc_url = ( - "https://github.com/apache/tvm-rfcs/blob/main/rfcs/" - "0007-parametrized-unit-tests.md#unresolved-questions" - ) - message = ( - "TVM caching of fixtures can only be used on serializable data types, not {}.\n" - "Please see {} for details/discussion." - ).format(type(cached_value), rfc_url) - raise TypeError(message) from e + yield copy.deepcopy( + cached_value, + # allowed_class_list should be a list of classes that + # are safe to copy using copy.deepcopy, but do not + # implement __deepcopy__, __reduce__, or + # __reduce_ex__. + _DeepCopyAllowedClasses(allowed_class_list=[]), + ) finally: # Clear the cache once all tests that use a particular fixture diff --git a/tests/python/unittest/test_tvm_testing_features.py b/tests/python/unittest/test_tvm_testing_features.py index 4b699096e96a..8885f55bbf4b 100644 --- a/tests/python/unittest/test_tvm_testing_features.py +++ b/tests/python/unittest/test_tvm_testing_features.py @@ -148,6 +148,22 @@ def test_cached_count(self): assert self.cached_calls == len(self.param1_vals) +class TestCachedFixtureIsCopy: + param = tvm.testing.parameter(1, 2, 3, 4) + + @tvm.testing.fixture(cache_return_value=True) + def cached_mutable_fixture(self): + return {"val": 0} + + def test_modifies_fixture(self, param, cached_mutable_fixture): + assert cached_mutable_fixture["val"] == 0 + + # The tests should receive a copy of the fixture value. If + # the test receives the original and not a copy, then this + # will cause the next parametrization to fail. + cached_mutable_fixture["val"] = param + + class TestBrokenFixture: # Tests that use a fixture that throws an exception fail, and are # marked as setup failures. The tests themselves are never run. @@ -210,5 +226,44 @@ def test_pytest_mark_covariant(self, request, target, other_param): self.check_marks(request, target) +@pytest.mark.skipif( + bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0"))), + reason="Cannot test cache behavior while caching is disabled", +) +class TestCacheableTypes: + class EmptyClass: + pass + + @tvm.testing.fixture(cache_return_value=True) + def uncacheable_fixture(self): + return self.EmptyClass() + + def test_uses_uncacheable(self, request): + with pytest.raises(TypeError): + request.getfixturevalue("uncacheable_fixture") + + class ImplementsReduce: + def __reduce__(self): + return super().__reduce__() + + @tvm.testing.fixture(cache_return_value=True) + def fixture_with_reduce(self): + return self.ImplementsReduce() + + def test_uses_reduce(self, fixture_with_reduce): + pass + + class ImplementsDeepcopy: + def __deepcopy__(self, memo): + return type(self)() + + @tvm.testing.fixture(cache_return_value=True) + def fixture_with_deepcopy(self): + return self.ImplementsDeepcopy() + + def test_uses_deepcopy(self, fixture_with_deepcopy): + pass + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv))