diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 6eeb9ab03f602..10a02350a821b 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -56,6 +56,8 @@ def test_something(): """ import collections import copy +import copyreg +import ctypes import functools import logging import os @@ -1229,6 +1231,60 @@ def wraps(func): return wraps(func) +class _DeepCopyAllowedClasses(dict): + def __init__(self, allowed_class_list): + self.allowed_class_list = allowed_class_list + super().__init__() + + def get(self, key, *args, **kwargs): + """Overrides behavior of copy.deepcopy to avoid implicit copy. + + By default, copy.deepcopy uses a dict of id->object to track + all objects that it has seen, which is passed as the second + argument to all recursive calls. This class is intended to be + passed in instead, and inspects the type of all objects being + copied. + + Where copy.deepcopy does a best-effort attempt at copying an + object, for unit tests we would rather have all objects either + be copied correctly, or to throw an error. Classes that + define an explicit method to perform a copy are allowed, as + are any explicitly listed classes. Classes that would fall + back to using object.__reduce__, and are not explicitly listed + as safe, will throw an exception. + + """ + obj = ctypes.cast(key, ctypes.py_object).value + cls = type(obj) + if ( + cls in copy._deepcopy_dispatch + or issubclass(cls, type) + or getattr(obj, "__deepcopy__", None) + or copyreg.dispatch_table.get(cls) + or cls.__reduce__ is not object.__reduce__ + or cls.__reduce_ex__ is not object.__reduce_ex__ + or cls in self.allowed_class_list + ): + return super().get(key, *args, **kwargs) + + rfc_url = ( + "https://github.com/apache/tvm-rfcs/blob/main/rfcs/0007-parametrized-unit-tests.md" + ) + raise TypeError( + ( + f"Cannot copy fixture of type {cls.__name__}. TVM fixture caching " + "is limited to objects that explicitly provide the ability " + "to be copied (e.g. through __deepcopy__, __getstate__, or __setstate__)," + "and forbids the use of the default `object.__reduce__` and " + "`object.__reduce_ex__`. For third-party classes that are " + "safe to use with copy.deepcopy, please add the class to " + "the arguments of _DeepCopyAllowedClasses in tvm.testing._fixture_cache.\n" + "\n" + f"For discussion on this restriction, please see {rfc_url}." + ) + ) + + def _fixture_cache(func): cache = {} @@ -1268,18 +1324,14 @@ def wrapper(*args, **kwargs): except KeyError: cached_value = cache[cache_key] = func(*args, **kwargs) - try: - yield copy.deepcopy(cached_value) - except TypeError as e: - rfc_url = ( - "https://github.com/apache/tvm-rfcs/blob/main/rfcs/" - "0007-parametrized-unit-tests.md#unresolved-questions" - ) - message = ( - "TVM caching of fixtures can only be used on serializable data types, not {}.\n" - "Please see {} for details/discussion." - ).format(type(cached_value), rfc_url) - raise TypeError(message) from e + yield copy.deepcopy( + cached_value, + # allowed_class_list should be a list of classes that + # are safe to copy using copy.deepcopy, but do not + # implement __deepcopy__, __reduce__, or + # __reduce_ex__. + _DeepCopyAllowedClasses(allowed_class_list=[]), + ) finally: # Clear the cache once all tests that use a particular fixture diff --git a/tests/python/unittest/test_tvm_testing_features.py b/tests/python/unittest/test_tvm_testing_features.py index 4b699096e96ae..df3ccaca5cc64 100644 --- a/tests/python/unittest/test_tvm_testing_features.py +++ b/tests/python/unittest/test_tvm_testing_features.py @@ -210,5 +210,44 @@ def test_pytest_mark_covariant(self, request, target, other_param): self.check_marks(request, target) +@pytest.mark.skipif( + bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0"))), + reason="Cannot test cache behavior while caching is disabled", +) +class TestCacheableTypes: + class EmptyClass: + pass + + @tvm.testing.fixture(cache_return_value=True) + def uncacheable_fixture(self): + return self.EmptyClass() + + @pytest.mark.xfail(reason="Requests cached fixture of uncacheable type", strict=True) + def test_uses_uncacheable(self, uncacheable_fixture): + pass + + class ImplementsReduce: + def __reduce__(self): + return super().__reduce__() + + @tvm.testing.fixture(cache_return_value=True) + def fixture_with_reduce(self): + return self.ImplementsReduce() + + def test_uses_reduce(self, fixture_with_reduce): + pass + + class ImplementsDeepcopy: + def __deepcopy__(self, memo): + return type(self)() + + @tvm.testing.fixture(cache_return_value=True) + def fixture_with_deepcopy(self): + return self.ImplementsDeepcopy() + + def test_uses_deepcopy(self, fixture_with_deepcopy): + pass + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv))