From f6b9005f36c0aefad85475a48f73b4b1b5951c05 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 17 Mar 2023 10:17:37 -0500 Subject: [PATCH] Revert "Add PytreeRef + ref_field (#3)" This reverts commit d0a4248443020ccf2f46a552346d47bc3302777f. --- simple_pytree/__init__.py | 9 -- simple_pytree/ids.py | 38 ------ simple_pytree/ref.py | 246 -------------------------------------- simple_pytree/tracers.py | 26 ---- tests/test_pytree.py | 2 +- tests/test_ref.py | 138 --------------------- 6 files changed, 1 insertion(+), 458 deletions(-) delete mode 100644 simple_pytree/ids.py delete mode 100644 simple_pytree/ref.py delete mode 100644 simple_pytree/tracers.py delete mode 100644 tests/test_ref.py diff --git a/simple_pytree/__init__.py b/simple_pytree/__init__.py index 56c8dd2..e3ca45e 100644 --- a/simple_pytree/__init__.py +++ b/simple_pytree/__init__.py @@ -1,12 +1,3 @@ __version__ = "0.1.5" from .pytree import Pytree, field, static_field -from .ref import ( - PytreeRef, - Ref, - RefField, - clone_references, - cross_barrier, - incremented_ref, - ref_field, -) diff --git a/simple_pytree/ids.py b/simple_pytree/ids.py deleted file mode 100644 index 30f59d5..0000000 --- a/simple_pytree/ids.py +++ /dev/null @@ -1,38 +0,0 @@ -# Taken from flax/ids.py 🏴‍☠️ - -import threading - - -class UUIDManager: - def __init__(self): - self._lock = threading.Lock() - self._id = 0 - - def __call__(self): - with self._lock: - self._id += 1 - return Id(self._id) - - -uuid = UUIDManager() - - -class Id: - def __init__(self, rawid): - self.id = rawid - - def __eq__(self, other): - return isinstance(other, Id) and other.id == self.id - - def __hash__(self): - return hash(self.id) - - def __repr__(self): - return f"Id({self.id})" - - def __deepcopy__(self, memo): - del memo - return uuid() - - def __copy__(self): - return uuid() diff --git a/simple_pytree/ref.py b/simple_pytree/ref.py deleted file mode 100644 index 6426607..0000000 --- a/simple_pytree/ref.py +++ /dev/null @@ -1,246 +0,0 @@ -import contextlib -import dataclasses -import functools -import threading -import typing as tp - -import jax - -from simple_pytree import ids, tracers -from simple_pytree.pytree import field - -A = tp.TypeVar("A") -F = tp.TypeVar("F", bound=tp.Callable[..., tp.Any]) - - -@dataclasses.dataclass(frozen=True) -class _RefContext: - level: int - - -@dataclasses.dataclass -class _Context(threading.local): - ref_context_stack: tp.List[_RefContext] = dataclasses.field( - default_factory=lambda: [_RefContext(0)] - ) - is_crossing_barrier: bool = False - # NOTE: `barrier_cache` is not used for now but left as an optimization - # opportunity for the future. `unflatten_pytree_ref` already has the logic - # to use the cache, using `barrier_cache` would activate it. - barrier_cache: tp.Optional[tp.Dict[ids.Id, "Ref[tp.Any]"]] = None - - @property - def current_ref_context(self) -> _RefContext: - return self.ref_context_stack[-1] - - -_CONTEXT = _Context() - - -@contextlib.contextmanager -def incremented_ref(): - _CONTEXT.ref_context_stack.append( - _RefContext(_CONTEXT.current_ref_context.level + 1) - ) - try: - yield - finally: - _CONTEXT.ref_context_stack.pop() - - -def clone_references(pytree: tp.Any) -> tp.Any: - cache: tp.Dict[ids.Id, Ref[tp.Any]] = {} - - def clone_ref(pytree: tp.Any): - if isinstance(pytree, PytreeRef): - if pytree.id not in cache: - cache[pytree.id] = Ref(pytree.value, id=pytree.id) - return PytreeRef(cache[pytree.id]) - return pytree - - return jax.tree_map(clone_ref, pytree, is_leaf=lambda x: isinstance(x, PytreeRef)) - - -@contextlib.contextmanager -def barrier_cache(): - _CONTEXT.barrier_cache = {} - try: - yield - finally: - _CONTEXT.barrier_cache = None - - -@contextlib.contextmanager -def crossing_barrier(): - _CONTEXT.is_crossing_barrier = True - try: - yield - finally: - _CONTEXT.is_crossing_barrier = False - - -def _update_ref_context(pytree: tp.Any) -> tp.Any: - for pytree in jax.tree_util.tree_leaves( - pytree, is_leaf=lambda x: isinstance(x, PytreeRef) - ): - if isinstance(pytree, PytreeRef): - pytree.ref._trace_level = tracers.current_trace_level() - pytree.ref._ref_context = _CONTEXT.current_ref_context - - -def cross_barrier( - decorator, *decorator_args, **decorator_kwargs -) -> tp.Callable[[F], F]: - @functools.wraps(decorator) - def decorator_wrapper(f): - @functools.wraps(f) - def inner_wrapper(*args, **kwargs): - _CONTEXT.is_crossing_barrier = False - # _CONTEXT.barrier_cache = None # Note: barrier_cache is not used for now - with incremented_ref(): - _update_ref_context((args, kwargs)) - out = f(*args, **kwargs) - _CONTEXT.is_crossing_barrier = True - # _CONTEXT.barrier_cache = {} # Note: barrier_cache is not used for now - return out - - decorated = decorator(inner_wrapper, *decorator_args, **decorator_kwargs) - - @functools.wraps(f) - def outer_wrapper(*args, **kwargs): - args, kwargs = clone_references((args, kwargs)) - with crossing_barrier(): - out = decorated(*args, **kwargs) - out = clone_references(out) - return out - - return outer_wrapper - - return decorator_wrapper - - -class Ref(tp.Generic[A]): - def __init__(self, value: A, id: tp.Optional[ids.Id] = None): - self._value = value - self._ref_context = _CONTEXT.current_ref_context - self._id = ids.uuid() if id is None else id - self._trace_level = tracers.current_trace_level() - - @property - def id(self) -> ids.Id: - return self._id - - @property - def value(self) -> A: - return self._value - - @value.setter - def value(self, value: A): - if ( - self._ref_context is not _CONTEXT.current_ref_context - and not _CONTEXT.is_crossing_barrier - ): - raise ValueError("Cannot mutate ref from different context") - if ( - self._trace_level != tracers.current_trace_level() - and not _CONTEXT.is_crossing_barrier - ): - raise ValueError("Cannot mutate ref from different trace level") - self._value = value - - -class PytreeRef(tp.Generic[A]): - def __init__(self, ref_or_value: tp.Union[Ref[A], A]): - if isinstance(ref_or_value, Ref): - self._ref = ref_or_value - else: - self._ref = Ref(ref_or_value) - - @property - def ref(self) -> Ref[A]: - return self._ref - - @property - def id(self) -> ids.Id: - return self.ref.id - - @property - def value(self) -> A: - return self.ref.value - - @value.setter - def value(self, value: A): - self.ref.value = value - - -def flatten_pytree_ref(pytree: PytreeRef[A]) -> tp.Tuple[tp.Tuple[A], Ref[A]]: - return (pytree.value,), pytree.ref - - -def unflatten_pytree_ref(ref: Ref[A], children: tp.Tuple[A]) -> PytreeRef[A]: - value = children[0] - if _CONTEXT.barrier_cache is not None: - if ref.id not in _CONTEXT.barrier_cache: - _CONTEXT.barrier_cache[ref.id] = Ref(value, id=ref.id) - - ref = _CONTEXT.barrier_cache[ref.id] - else: - ref.value = value - return PytreeRef(ref) - - -jax.tree_util.register_pytree_node(PytreeRef, flatten_pytree_ref, unflatten_pytree_ref) - - -@dataclasses.dataclass -class RefField(tp.Generic[A]): - default: tp.Any = dataclasses.MISSING - name: str = "" - - def __set_name__(self, owner, name): - self.name = name - - def __get__(self, obj, objtype=None): - if obj is None: - return self - if not hasattr(obj, f"_ref_{self.name}"): - if self.default is not dataclasses.MISSING: - obj.__dict__[f"_ref_{self.name}"] = PytreeRef(self.default) - else: - raise AttributeError(f"Attribute {self.name} is not set") - return getattr(obj, f"_ref_{self.name}").value - - def __set__(self, obj, value: tp.Union[A, Ref[A], PytreeRef[A], "RefField[A]"]): - if isinstance(value, RefField): - return - - if hasattr(obj, f"_ref_{self.name}"): - if isinstance(value, (Ref, PytreeRef, RefField)): - raise AttributeError(f"Cannot change reference of {self.name}") - getattr(obj, f"_ref_{self.name}").value = value - elif isinstance(value, PytreeRef): - setattr(obj, f"_ref_{self.name}", value) - else: - setattr(obj, f"_ref_{self.name}", PytreeRef(value)) - - -def ref_field( - default: tp.Any = dataclasses.MISSING, - *, - default_factory: tp.Any = dataclasses.MISSING, - init: bool = True, - repr: bool = True, - hash: tp.Optional[bool] = None, - compare: bool = True, - metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, -): - return field( - default=RefField(default=default), - pytree_node=True, - default_factory=default_factory, - init=init, - repr=repr, - hash=hash, - compare=compare, - metadata=metadata, - ) diff --git a/simple_pytree/tracers.py b/simple_pytree/tracers.py deleted file mode 100644 index 18341b6..0000000 --- a/simple_pytree/tracers.py +++ /dev/null @@ -1,26 +0,0 @@ -# Taken from flax/core/tracer.py 🏴‍☠️ - -import jax - - -def current_trace(): - """Returns the innermost Jax tracer.""" - return jax.core.find_top_trace(()) - - -def trace_level(main): - """Returns the level of the trace of -infinity if it is None.""" - if main: - return main.level - return float("-inf") - - -def current_trace_level(): - """Returns the level of the current trace.""" - return trace_level(current_trace()) - - -def check_trace_level(base_level, error): - level = current_trace_level() - if level != base_level: - raise error diff --git a/tests/test_pytree.py b/tests/test_pytree.py index 545c1eb..74b3cfd 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -134,7 +134,7 @@ def __init__(self, y) -> None: self.x = 2 self.y = y - pytree: Foo = Foo(y=3) + pytree = Foo(y=3) leaves = jax.tree_util.tree_leaves(pytree) assert leaves == [3] diff --git a/tests/test_ref.py b/tests/test_ref.py deleted file mode 100644 index 9c9872a..0000000 --- a/tests/test_ref.py +++ /dev/null @@ -1,138 +0,0 @@ -import dataclasses - -import jax -import pytest - -from simple_pytree import ( - Pytree, - PytreeRef, - clone_references, - cross_barrier, - incremented_ref, - ref_field, -) - - -class TestPytreeRef: - def test_ref(self): - p1 = PytreeRef(1) - assert p1.value == 1 - - p2 = jax.tree_map(lambda x: x + 1, p1) - - assert p1.value == 2 - assert p2.value == 2 - assert p1 is not p2 - assert p1.ref is p2.ref - - p1.value = 3 - - assert p1.value == 3 - assert p2.value == 3 - - def test_ref_context(self): - p1 = PytreeRef(1) - p2 = jax.tree_map(lambda x: x, p1) # copy - assert p1.value == 1 - assert p2.value == 1 - p1.value = 2 # OK - assert p2.value == 2 - - with incremented_ref(): - with pytest.raises( - ValueError, match="Cannot mutate ref from different context" - ): - p1.value = 3 - - p1, p2 = clone_references((p1, p2)) - assert p1.value == 2 - p2.value = 3 # OK - assert p1.value == 3 - - with pytest.raises( - ValueError, match="Cannot mutate ref from different context" - ): - p1.value = 4 - - p1, p2 = clone_references((p1, p2)) - assert p1.value == 3 - p1.value = 4 # OK - assert p2.value == 4 - - def test_ref_trace_level(self): - p1: PytreeRef[int] = PytreeRef(1) - - @jax.jit - def f(): - with pytest.raises( - ValueError, match="Cannot mutate ref from different trace level" - ): - p1.value = 2 - return 1 - - f() - - @cross_barrier(jax.jit) - def g(p2: PytreeRef[int]): - p2.value = 2 - assert p1.ref is not p2.ref - return p2 - - p2 = g(p1) - p2_ref = p2.ref - - assert p1.value == 1 - assert p2.value == 2 - - p2.value = 3 - assert p1.value == 1 - assert p2.value == 3 - - p3 = g(p1) - p3_ref = p3.ref - - assert p3_ref is not p2_ref - assert p3.value == 2 - - def test_barrier(self): - p1: PytreeRef[int] = PytreeRef(1) - - @cross_barrier(jax.jit) - def g(p2: PytreeRef[int]): - p2.value = 2 - assert p1.ref is not p2.ref - return p2 - - p2 = g(p1) - assert p1.ref is not p2.ref - assert p1.value == 1 - assert p2.value == 2 - - # test passing a reference to a jitted function without cross_barrier - @jax.jit - def f(p1): - return None - - with pytest.raises( - ValueError, match="Cannot mutate ref from different trace level" - ): - f(p1) - - assert isinstance(p1.value, int) - assert p1.value == 1 - - -class TestRefField: - def test_ref_field(self): - @dataclasses.dataclass - class Foo(Pytree, mutable=True): - a: int = ref_field() - - foo1 = Foo() - foo1.a = 2 - assert foo1.a == 2 - - foo2 = jax.tree_map(lambda x: x + 1, foo1) - - assert foo1.a == 3 - assert foo2.a == 3