Implement __reduce__ with cache_hash
This fixes GH issues #613 and #494. It turns out that the hash
cache-clearing implementation for non-slots classes was flawed and never
quite worked properly. This switches away from using __setstate__ and
instead adds a custom __reduce__ that removes the cached hash value from
the default serialized output.

This commit also refactors some of the tests, to more cleanly organize
those related to this issue.
pganssle committed Jan 13, 2020
1 parent 9b5e988 commit c5b1dfd
Showing 3 changed files with 141 additions and 100 deletions.
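For context, object.__reduce__(self) returns a tuple of (callable, args) or (callable, args, state); when state is a dict, pickle restores it onto the rebuilt instance, so nulling the cached-hash key in that state is enough to keep a stale hash out of the stream. A minimal standalone sketch of the mechanism (the Point class and the _cached_hash name are illustrative, not attrs internals):

import pickle


class Point(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self._cached_hash = None  # stand-in for attrs' hash-cache attribute

    def __hash__(self):
        # Compute the hash once and reuse it afterwards.
        if self._cached_hash is None:
            self._cached_hash = hash((self.x, self.y))
        return self._cached_hash

    def __reduce__(self):
        rv = object.__reduce__(self)
        if len(rv) > 2 and isinstance(rv[2], dict):
            # Null the cache in the state that gets pickled. For dict
            # classes this state may be the live __dict__, so it can also
            # reset the cache on the original object; that is harmless,
            # because an empty cache is simply recomputed on the next call.
            rv[2]["_cached_hash"] = None
        return rv


p = Point(1, 2)
hash(p)  # populate the cache
q = pickle.loads(pickle.dumps(p))
assert q._cached_hash is None  # the copy recomputes its hash on demand

Because the default __reduce_ex__ delegates to an overridden __reduce__, defining __reduce__ alone covers every pickle protocol, and copy.copy / copy.deepcopy go through the same machinery, which is why a single hook covers both pickling and copying.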
46 changes: 18 additions & 28 deletions src/attr/_make.py
@@ -426,6 +426,17 @@ def _frozen_delattrs(self, name):
     raise FrozenInstanceError()
 
 
+def _cache_hash_reduce(self):
+    obj_reduce = object.__reduce__(self)
+    if len(obj_reduce) > 2:
+        state = obj_reduce[2]
+
+        if isinstance(state, dict) and _hash_cache_field in state:
+            state[_hash_cache_field] = None
+
+    return obj_reduce
+
+
 class _ClassBuilder(object):
     """
     Iteratively build *one* class.
@@ -483,6 +494,13 @@ def __init__(
             self._cls_dict["__setattr__"] = _frozen_setattrs
             self._cls_dict["__delattr__"] = _frozen_delattrs
 
+        if (
+            cache_hash
+            and cls.__reduce__ is object.__reduce__
+            and cls.__reduce_ex__ is object.__reduce_ex__
+        ):
+            self._cls_dict["__reduce__"] = _cache_hash_reduce
+
     def __repr__(self):
         return "<_ClassBuilder(cls={cls})>".format(cls=self._cls.__name__)
 
@@ -523,34 +541,6 @@ def _patch_original_class(self):
         for name, value in self._cls_dict.items():
             setattr(cls, name, value)
 
-        # Attach __setstate__. This is necessary to clear the hash code
-        # cache on deserialization. See issue
-        # https://github.com/python-attrs/attrs/issues/482 .
-        # Note that this code only handles setstate for dict classes.
-        # For slotted classes, see similar code in _create_slots_class .
-        if self._cache_hash:
-            existing_set_state_method = getattr(cls, "__setstate__", None)
-            if existing_set_state_method:
-                raise NotImplementedError(
-                    "Currently you cannot use hash caching if "
-                    "you specify your own __setstate__ method."
-                    "See https://github.com/python-attrs/attrs/issues/494 ."
-                )
-
-            # Clears the cached hash state on serialization; for frozen
-            # classes we need to bypass the class's setattr method.
-            if self._frozen:
-
-                def cache_hash_set_state(chss_self, _):
-                    object.__setattr__(chss_self, _hash_cache_field, None)
-
-            else:
-
-                def cache_hash_set_state(chss_self, _):
-                    setattr(chss_self, _hash_cache_field, None)
-
-            cls.__setstate__ = cache_hash_set_state
-
         return cls
 
     def _create_slots_class(self):
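The two identity checks above mean the cache-clearing __reduce__ is only installed when the class still has the inherited pickling behavior; a user-supplied __reduce__ or __reduce_ex__ wins and is left untouched. This replaces the old NotImplementedError path: a class that customizes only __setstate__ keeps hash caching and still has its __setstate__ honored on unpickling. A hypothetical example that the removed code rejected and that now round-trips (class and field names invented for illustration):

import copy

import attr


@attr.attrs(hash=True, cache_hash=True)
class CustomSetState(object):
    x = attr.ib(default=1)

    def __setstate__(self, state):
        # Called by pickle/copy with the state produced by __reduce__.
        self.__dict__.update(state)


obj = CustomSetState()
hash(obj)  # warm the hash cache
assert copy.deepcopy(obj) == obj  # previously: NotImplementedError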
135 changes: 68 additions & 67 deletions tests/test_dunders.py
@@ -310,6 +310,23 @@ def test_str_no_repr(self):
         ) == e.value.args[0]
 
 
+# these are for use in TestAddHash.test_cache_hash_serialization
+# they need to be out here so they can be un-pickled
+@attr.attrs(hash=True, cache_hash=False)
+class HashCacheSerializationTestUncached(object):
+    foo_value = attr.ib(default=20)
+
+
+@attr.attrs(hash=True, cache_hash=True)
+class HashCacheSerializationTestCached(object):
+    foo_value = attr.ib(default=20)
+
+
+@attr.attrs(slots=True, hash=True, cache_hash=True)
+class HashCacheSerializationTestCachedSlots(object):
+    foo_value = attr.ib(default=20)
+
+
 class TestAddHash(object):
     """
     Tests for `_add_hash`.
@@ -492,85 +509,69 @@ def __hash__(self):
         assert 2 == uncached_instance.hash_counter.times_hash_called
         assert 1 == cached_instance.hash_counter.times_hash_called
 
-    def test_cache_hash_serialization(self):
-        """
-        Tests that the hash cache is cleared on deserialization to fix
-        https://github.com/python-attrs/attrs/issues/482 .
-        """
-
-        # First, check that our fix didn't break serialization without
-        # hash caching.
-        # We don't care about the result of this; we just want to make sure we
-        # can do it without exceptions.
-        hash(pickle.loads(pickle.dumps(HashCacheSerializationTestUncached)))
-
-        def assert_hash_code_not_cached_across_serialization(original):
-            # Now check our fix for #482 for when hash caching is enabled.
-            original_hash = hash(original)
-            round_tripped = pickle.loads(pickle.dumps(original))
-            # What we want to guard against is having a stale hash code
-            # when a field's hash code differs in a new interpreter after
-            # deserialization. This is tricky to test because we are,
-            # of course, still running in the same interpreter. So
-            # after deserialization we reach in and change the value of
-            # a field to simulate the field changing its hash code. We then
-            # check that the object's hash code changes, indicating that we
-            # don't have a stale hash code.
-            # This could fail in two ways: (1) pickle.loads could get the hash
-            # code of the deserialized value (triggering it to cache) before
-            # we alter the field value. This doesn't happen in our tested
-            # Python versions. (2) "foo" and "something different" could
-            # have a hash collision on this interpreter run. But this is
-            # extremely improbable and would just result in one buggy test run.
-            round_tripped.foo_string = "something different"
-            assert original_hash != hash(round_tripped)
-
-        # Slotted and dict classes implement __setstate__ differently,
-        # so we need to test both cases.
-        assert_hash_code_not_cached_across_serialization(
-            HashCacheSerializationTestCached()
-        )
-        assert_hash_code_not_cached_across_serialization(
-            HashCacheSerializationTestCachedSlots()
-        )
-
-    def test_caching_and_custom_setstate(self):
-        """
-        The combination of a custom __setstate__ and cache_hash=True is caught
-        with a helpful message.
-
-        This is needed because we handle clearing the cache after
-        deserialization with a custom __setstate__. It is possible to make both
-        work, but it requires some thought about how to go about it, so it has
-        not yet been implemented.
-        """
-        with pytest.raises(
-            NotImplementedError,
-            match="Currently you cannot use hash caching if you "
-            "specify your own __setstate__ method.",
-        ):
-
-            @attr.attrs(hash=True, cache_hash=True)
-            class NoCacheHashAndCustomSetState(object):
-                def __setstate__(self, state):
-                    pass
+    @pytest.mark.parametrize("cache_hash", [True, False])
+    @pytest.mark.parametrize("frozen", [True, False])
+    @pytest.mark.parametrize("slots", [True, False])
+    def test_copy_hash_cleared(self, cache_hash, frozen, slots):
+        """
+        Test that the default hash is recalculated after a copy operation.
+        """
+        kwargs = dict(frozen=frozen, slots=slots, cache_hash=cache_hash,)
+
+        # Ensure that we can mutate the copied object if it's frozen
+        # and that the hash can be calculated if not.
+        if frozen:
+            _setattr = object.__setattr__
+        else:
+            kwargs["hash"] = True
+            _setattr = setattr
+
+        @attr.s(**kwargs)
+        class C(object):
+            x = attr.ib()
+
+        a = C(1)
+        hash(a)  # Ensure that any hash cache would be calculated before copy
+        b = copy.deepcopy(a)
+
+        _setattr(b, "x", 100)
+
+        assert hash(a) != hash(b)
+
+    @pytest.mark.parametrize(
+        "klass",
+        [
+            HashCacheSerializationTestUncached,
+            HashCacheSerializationTestCached,
+            HashCacheSerializationTestCachedSlots,
+        ],
+    )
+    def test_cache_hash_serialization_hash_cleared(self, klass):
+        """
+        Tests that the hash cache is cleared on deserialization to fix
+        https://github.com/python-attrs/attrs/issues/482 .
+
+        This test is intended to guard against a stale hash code surviving
+        across serialization (which may cause problems when the hash value
+        is different in different interpreters).
+        """
+        obj = klass()
+        original_hash = hash(obj)
+        obj_rt = self._roundtrip_pickle(obj)
+
+        # Modify an attribute of the object used in the hash calculation; this
+        # assumes that pickle doesn't call `hash` before this point, and that
+        # there is no hash collision between the integers 20 and 40.
+        obj_rt.foo_value = 40
+
+        assert original_hash == hash(obj)
+        assert original_hash != hash(obj_rt)
+
+    def _roundtrip_pickle(self, obj):
+        pickle_str = pickle.dumps(obj)
+        return pickle.loads(pickle_str)
 
-
-# these are for use in TestAddHash.test_cache_hash_serialization
-# they need to be out here so they can be un-pickled
-@attr.attrs(hash=True, cache_hash=False)
-class HashCacheSerializationTestUncached(object):
-    foo_string = attr.ib(default="foo")
-
-
-@attr.attrs(hash=True, cache_hash=True)
-class HashCacheSerializationTestCached(object):
-    foo_string = attr.ib(default="foo")
-
-
-@attr.attrs(slots=True, hash=True, cache_hash=True)
-class HashCacheSerializationTestCachedSlots(object):
-    foo_string = attr.ib(default="foo")
 
 
 class TestAddInit(object):
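One detail worth noting in these tests: pickle serializes an instance's class by its module-qualified name, which is why the HashCacheSerializationTest* classes live at module scope rather than inside the test methods. A quick illustration (hypothetical, not part of the diff):

import pickle


def make_class():
    class Local(object):
        pass

    return Local


try:
    pickle.dumps(make_class()())
except (pickle.PicklingError, AttributeError) as exc:
    # The class cannot be found by name at module scope, so pickling
    # an instance of it fails.
    print(exc)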
60 changes: 55 additions & 5 deletions tests/test_make.py
@@ -1466,16 +1466,66 @@ class C2(C):
 
         assert [C2] == C.__subclasses__()
 
-    def test_cache_hash_with_frozen_serializes(self):
+    def _get_copy_kwargs(include_slots=True):
         """
-        Frozen classes with cache_hash should be serializable.
+        Get copies
         """
+        options = ["frozen", "hash", "cache_hash"]
 
-        @attr.s(cache_hash=True, frozen=True)
+        if include_slots:
+            options.extend(["slots", "weakref_slot"])
+
+        out_kwargs = []
+        for args in itertools.product([True, False], repeat=len(options)):
+            kwargs = dict(zip(options, args))
+
+            kwargs["hash"] = kwargs["hash"] or None
+
+            if kwargs["cache_hash"] and not (
+                kwargs["frozen"] or kwargs["hash"]
+            ):
+                continue
+
+            out_kwargs.append(kwargs)
+
+        return out_kwargs
+
+    @pytest.mark.parametrize("kwargs", _get_copy_kwargs())
+    def test_copy(self, kwargs):
+        """
+        Ensure that an attrs class can be copied successfully.
+        """
+
+        @attr.s(eq=True, **kwargs)
         class C(object):
-            pass
+            x = attr.ib()
+
+        a = C(1)
+        b = copy.deepcopy(a)
+
+        assert a == b
+
+    @pytest.mark.parametrize("kwargs", _get_copy_kwargs(include_slots=False))
+    def test_copy_custom_setstate(self, kwargs):
+        """
+        Ensure that non-slots classes respect a custom __setstate__.
+        """
+
+        @attr.s(eq=True, **kwargs)
+        class C(object):
+            x = attr.ib()
+
+            def __getstate__(self):
+                return self.__dict__
+
+            def __setstate__(self, state):
+                state["x"] *= 5
+                self.__dict__.update(state)
+
+        expected = C(25)
+        actual = copy.copy(C(5))
 
-        copy.deepcopy(C())
+        assert actual == expected
 
 
 class TestMakeOrder:
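The continue in _get_copy_kwargs mirrors a constraint attrs enforces at class creation: cache_hash is only valid when a __hash__ will actually be generated, either explicitly (hash=True) or implicitly (an eq, frozen class with hash=None). The line kwargs["hash"] = kwargs["hash"] or None maps False to None, the default, because an explicit hash=False would conflict with cache_hash in the same way. A sketch of the rejected combination (assuming current attrs behavior; the exact message wording may differ between versions):

import attr

try:

    @attr.s(cache_hash=True)  # hash defaults to None and frozen is False
    class Invalid(object):
        x = attr.ib()

except TypeError as exc:
    # attrs refuses cache_hash when hashing is not enabled.
    print(exc)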
