Skip to content

Commit

Permalink
Share nnx node registry between threads
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed May 3, 2024
1 parent 859435e commit ba935a0
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
20 changes: 10 additions & 10 deletions flax/experimental/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@ def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:

@dataclasses.dataclass
class GraphUtilsContext(threading.local):
node_types: dict[
type, 'NodeImpl[tp.Any, tp.Any, tp.Any]'
] = dataclasses.field(default_factory=dict)
seen_modules_repr: set[int] | None = None


Expand Down Expand Up @@ -170,6 +167,9 @@ class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
]


_node_impl_for_type: dict[type, 'NodeImpl[tp.Any, tp.Any, tp.Any]'] = {}


def register_graph_node_type(
type: type,
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]],
Expand All @@ -178,7 +178,7 @@ def register_graph_node_type(
create_empty: tp.Callable[[AuxData], Node],
clear: tp.Callable[[Node, AuxData], None],
):
CONTEXT.node_types[type] = GraphNodeImpl(
_node_impl_for_type[type] = GraphNodeImpl(
type=type,
flatten=flatten,
set_key=set_key,
Expand All @@ -189,17 +189,17 @@ def register_graph_node_type(


def is_node(x: tp.Any) -> bool:
if type(x) in CONTEXT.node_types:
if type(x) in _node_impl_for_type:
return True
return is_pytree_node(x)


def is_graph_node(x: tp.Any) -> bool:
return type(x) in CONTEXT.node_types
return type(x) in _node_impl_for_type


def is_node_type(x: type[tp.Any]) -> bool:
return x in CONTEXT.node_types or x is PytreeType
return x in _node_impl_for_type or x is PytreeType


def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]:
Expand All @@ -208,19 +208,19 @@ def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]:

node_type = type(x)

if node_type not in CONTEXT.node_types:
if node_type not in _node_impl_for_type:
if is_pytree_node(x):
return PYTREE_NODE_IMPL
else:
raise ValueError(f'Unknown node type: {x}')

return CONTEXT.node_types[node_type]
return _node_impl_for_type[node_type]


def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:
if x is PytreeType:
return PYTREE_NODE_IMPL
return CONTEXT.node_types[x]
return _node_impl_for_type[x]


class _HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
Expand Down
20 changes: 20 additions & 0 deletions flax/experimental/nnx/tests/test_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from functools import partial
from threading import Thread
import jax
import pytest

Expand Down Expand Up @@ -399,3 +400,22 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
m2, _ = nnx.graph.unflatten(graphdef, state, idxmap=idx_in_ref_out)
assert m2 is m
assert m2.ref is m2


class SimpleModule(nnx.Module):
pass


class SimplePyTreeModule(nnx.Module, experimental_pytree=True):
pass


@pytest.mark.parametrize(['x'], [(SimpleModule(),), (SimplePyTreeModule(),)])
def test_threading(x: nnx.Module):
class MyThread(Thread):
def run(self) -> None:
nnx.graph.split(x)

thread = MyThread()
thread.start()
thread.join()

0 comments on commit ba935a0

Please sign in to comment.