Skip to content

Commit

Permalink
[UnitTests] Require cached fixtures to be copy-able, with opt-in. (ap…
Browse files Browse the repository at this point in the history
…ache#8451)

* [UnitTests] Require cached fixtures to be copy-able, with opt-in.

Previously, any class that doesn't raise a TypeError in copy.deepcopy
could be used as a return value in a @tvm.testing.fixture.  This has
the possibility of incorrectly copying classes inherit the default
object.__reduce__ implementation.  Therefore, only classes that
explicitly implement copy functionality (e.g. __deepcopy__ or
__getstate__/__setstate__), or that are explicitly listed in
tvm.testing._fixture_cache are allowed to be cached.

* [UnitTests] Added TestCachedFixtureIsCopy

Verifies that tvm.testing.fixture caching returns copy of object, not
the original object.

* [UnitTests] Correct parametrization of cudnn target.

Previous checks for enabled runtimes were based only on the target
kind.  CuDNN is the same target kind as "cuda", and therefore needs
special handling.

* Change test on uncacheable to check for explicit TypeError
  • Loading branch information
Lunderberg authored and ylc committed Sep 29, 2021
1 parent 0a11c04 commit 9486faa
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 14 deletions.
86 changes: 72 additions & 14 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def test_something():
"""
import collections
import copy
import copyreg
import ctypes
import functools
import logging
import os
Expand Down Expand Up @@ -386,8 +388,14 @@ def _get_targets(target_str=None):
targets = []
for target in target_names:
target_kind = target.split()[0]
is_enabled = tvm.runtime.enabled(target_kind)
is_runnable = is_enabled and tvm.device(target_kind).exist

if target_kind == "cuda" and "cudnn" in tvm.target.Target(target).attrs.get("libs", []):
is_enabled = tvm.support.libinfo()["USE_CUDNN"].lower() in ["on", "true", "1"]
is_runnable = is_enabled and cudnn.exists()
else:
is_enabled = tvm.runtime.enabled(target_kind)
is_runnable = is_enabled and tvm.device(target_kind).exist

targets.append(
{
"target": target,
Expand Down Expand Up @@ -1251,6 +1259,60 @@ def wraps(func):
return wraps(func)


class _DeepCopyAllowedClasses(dict):
def __init__(self, allowed_class_list):
self.allowed_class_list = allowed_class_list
super().__init__()

def get(self, key, *args, **kwargs):
"""Overrides behavior of copy.deepcopy to avoid implicit copy.
By default, copy.deepcopy uses a dict of id->object to track
all objects that it has seen, which is passed as the second
argument to all recursive calls. This class is intended to be
passed in instead, and inspects the type of all objects being
copied.
Where copy.deepcopy does a best-effort attempt at copying an
object, for unit tests we would rather have all objects either
be copied correctly, or to throw an error. Classes that
define an explicit method to perform a copy are allowed, as
are any explicitly listed classes. Classes that would fall
back to using object.__reduce__, and are not explicitly listed
as safe, will throw an exception.
"""
obj = ctypes.cast(key, ctypes.py_object).value
cls = type(obj)
if (
cls in copy._deepcopy_dispatch
or issubclass(cls, type)
or getattr(obj, "__deepcopy__", None)
or copyreg.dispatch_table.get(cls)
or cls.__reduce__ is not object.__reduce__
or cls.__reduce_ex__ is not object.__reduce_ex__
or cls in self.allowed_class_list
):
return super().get(key, *args, **kwargs)

rfc_url = (
"https://github.com/apache/tvm-rfcs/blob/main/rfcs/0007-parametrized-unit-tests.md"
)
raise TypeError(
(
f"Cannot copy fixture of type {cls.__name__}. TVM fixture caching "
"is limited to objects that explicitly provide the ability "
"to be copied (e.g. through __deepcopy__, __getstate__, or __setstate__),"
"and forbids the use of the default `object.__reduce__` and "
"`object.__reduce_ex__`. For third-party classes that are "
"safe to use with copy.deepcopy, please add the class to "
"the arguments of _DeepCopyAllowedClasses in tvm.testing._fixture_cache.\n"
"\n"
f"For discussion on this restriction, please see {rfc_url}."
)
)


def _fixture_cache(func):
cache = {}

Expand Down Expand Up @@ -1290,18 +1352,14 @@ def wrapper(*args, **kwargs):
except KeyError:
cached_value = cache[cache_key] = func(*args, **kwargs)

try:
yield copy.deepcopy(cached_value)
except TypeError as e:
rfc_url = (
"https://github.com/apache/tvm-rfcs/blob/main/rfcs/"
"0007-parametrized-unit-tests.md#unresolved-questions"
)
message = (
"TVM caching of fixtures can only be used on serializable data types, not {}.\n"
"Please see {} for details/discussion."
).format(type(cached_value), rfc_url)
raise TypeError(message) from e
yield copy.deepcopy(
cached_value,
# allowed_class_list should be a list of classes that
# are safe to copy using copy.deepcopy, but do not
# implement __deepcopy__, __reduce__, or
# __reduce_ex__.
_DeepCopyAllowedClasses(allowed_class_list=[]),
)

finally:
# Clear the cache once all tests that use a particular fixture
Expand Down
55 changes: 55 additions & 0 deletions tests/python/unittest/test_tvm_testing_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,22 @@ def test_cached_count(self):
assert self.cached_calls == len(self.param1_vals)


class TestCachedFixtureIsCopy:
param = tvm.testing.parameter(1, 2, 3, 4)

@tvm.testing.fixture(cache_return_value=True)
def cached_mutable_fixture(self):
return {"val": 0}

def test_modifies_fixture(self, param, cached_mutable_fixture):
assert cached_mutable_fixture["val"] == 0

# The tests should receive a copy of the fixture value. If
# the test receives the original and not a copy, then this
# will cause the next parametrization to fail.
cached_mutable_fixture["val"] = param


class TestBrokenFixture:
# Tests that use a fixture that throws an exception fail, and are
# marked as setup failures. The tests themselves are never run.
Expand Down Expand Up @@ -210,5 +226,44 @@ def test_pytest_mark_covariant(self, request, target, other_param):
self.check_marks(request, target)


@pytest.mark.skipif(
bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0"))),
reason="Cannot test cache behavior while caching is disabled",
)
class TestCacheableTypes:
class EmptyClass:
pass

@tvm.testing.fixture(cache_return_value=True)
def uncacheable_fixture(self):
return self.EmptyClass()

def test_uses_uncacheable(self, request):
with pytest.raises(TypeError):
request.getfixturevalue("uncacheable_fixture")

class ImplementsReduce:
def __reduce__(self):
return super().__reduce__()

@tvm.testing.fixture(cache_return_value=True)
def fixture_with_reduce(self):
return self.ImplementsReduce()

def test_uses_reduce(self, fixture_with_reduce):
pass

class ImplementsDeepcopy:
def __deepcopy__(self, memo):
return type(self)()

@tvm.testing.fixture(cache_return_value=True)
def fixture_with_deepcopy(self):
return self.ImplementsDeepcopy()

def test_uses_deepcopy(self, fixture_with_deepcopy):
pass


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))

0 comments on commit 9486faa

Please sign in to comment.