Skip to content

Commit

Permalink
[nnx] add extract APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jul 15, 2024
1 parent 5c97143 commit 2c4cb91
Show file tree
Hide file tree
Showing 6 changed files with 314 additions and 77 deletions.
269 changes: 269 additions & 0 deletions flax/nnx/nnx/extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
import abc
import contextlib
import dataclasses
import threading
import typing as tp

import jax
from jax._src.tree_util import broadcast_prefix

from flax import struct
from flax.nnx.nnx.state import State
from flax.typing import PathParts
from flax.nnx.nnx import graph


class Missing:
  """Marker type whose single module-level instance means "no value given".

  Instances carry no state; callers are expected to compare against the
  shared instance by identity rather than by equality.
  """


# Sentinel used as the default of the ``prefix`` keyword of
# ``extract_graph_nodes``; it is checked with ``is`` so that "no prefix was
# passed" can be told apart from any real prefix value.
MISSING = Missing()
# Generic pytree type: functions taking ``A`` return a pytree of the same type.
A = tp.TypeVar('A')
# Any concrete ``Extractable`` subtype (see ``extract_indexes``).
E = tp.TypeVar('E', bound='Extractable')
# Integer position used to look an extracted node up again.
Index = int
# A single hashable key-path entry, and a key path made of such entries.
KeyEntry = tp.TypeVar('KeyEntry', bound=tp.Hashable)
KeyPath = tuple[KeyEntry, ...]
# Readability aliases for the ``validate_fn`` callback signature; both are
# unconstrained.
Prefix = tp.Any
Leaf = tp.Any


class Extractable(abc.ABC):
  """Abstract base for placeholder objects that expose an integer ``index``.

  The helpers in this module recognise placeholders via
  ``isinstance(x, Extractable)`` and use ``index`` to look up the value the
  placeholder stands for.
  """

  @property
  @abc.abstractmethod
  def index(self) -> Index:
    """Return the integer index carried by this placeholder."""


class ExtractableStates(Extractable):
  """Abstract ``Extractable`` that additionally carries states and a graphdef.

  Objects of this kind are what ``merge_extractable_states`` consumes.
  """

  @property
  @abc.abstractmethod
  def states(self) -> tp.Iterable[State]:
    """Return the ``State`` objects held by this placeholder."""

  @property
  @abc.abstractmethod
  def graphdef(self) -> graph.GraphDef[tp.Any]:
    """Return the ``GraphDef`` associated with this placeholder."""


class ExtractionIndex(struct.PyTreeNode, Extractable):
  """Index of a graph node in a Pytree structure.

  ``extract_graph_nodes`` puts one of these where a graph node used to be;
  ``insert_graph_nodes`` uses ``index`` to put the node back.
  """

  # Declared with ``pytree_node=False`` -- presumably so the integer is kept
  # as static metadata rather than a pytree leaf (confirm against flax.struct).
  _index: Index = struct.field(pytree_node=False)

  @property
  def index(self) -> Index:
    # Concrete implementation of the abstract ``Extractable.index`` property.
    return self._index


@tp.overload
def extract_graph_nodes(
  pytree: A,
  /,
  *,
  validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None,
) -> tuple[A, tuple[tp.Any, ...]]: ...


@tp.overload
def extract_graph_nodes(
  pytree: A,
  /,
  *,
  prefix: tp.Any,
  validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None,
) -> tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]]: ...


def extract_graph_nodes(
  pytree: A,
  /,
  *,
  prefix: tp.Any = MISSING,
  validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None,
) -> (
  tuple[A, tuple[tp.Any, ...]]
  | tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]]
):
  """Replace every graph node leaf of ``pytree`` with an ``ExtractionIndex``.

  Each distinct graph node gets one index (its order of first appearance);
  repeated occurrences of the same node reuse that index.

  Args:
    pytree: The pytree to scan.
    prefix: Optional prefix tree that is broadcast over the leaves of
      ``pytree``. When given, every occurrence of a node must map to the same
      prefix leaf.
    validate_fn: Optional callback invoked as
      ``validate_fn(keypath, prefix_leaf, leaf)`` for every leaf before it is
      processed.

  Returns:
    ``(new_pytree, nodes)`` when ``prefix`` is omitted, otherwise
    ``(new_pytree, nodes, node_prefixes)`` where ``node_prefixes[i]`` is the
    prefix leaf recorded for ``nodes[i]``.

  Raises:
    ValueError: If the same node is reached under two different prefixes.
  """
  index_of = graph.RefMap[tp.Any, Index]()
  recorded_prefixes = []
  out_leaves = []

  broadcasted = broadcast_prefix(
    prefix,
    pytree,
    is_leaf=lambda x: x is None,
  )
  path_leaf_pairs, treedef = jax.tree_util.tree_flatten_with_path(pytree)

  # One broadcast prefix entry is expected per flattened leaf.
  assert len(path_leaf_pairs) == len(broadcasted)

  for (keypath, leaf), prefix_leaf in zip(path_leaf_pairs, broadcasted):
    if validate_fn:
      validate_fn(keypath, prefix_leaf, leaf)

    # Ordinary leaves pass through untouched.
    if not graph.is_graph_node(leaf):
      out_leaves.append(leaf)
      continue

    if leaf in index_of:
      # Node already seen: reuse its index, but insist on the same prefix.
      position = index_of[leaf]
      if prefix_leaf != recorded_prefixes[position]:
        path_str = jax.tree_util.keystr(keypath)
        raise ValueError(
          f'Inconsistent aliasing detected. Node {type(leaf)} at path {path_str} '
          f'has different prefixes: {prefix_leaf} and {recorded_prefixes[position]}.'
        )
    else:
      # First sighting: allocate the next index and remember its prefix.
      position = len(index_of)
      index_of[leaf] = position
      recorded_prefixes.append(prefix_leaf)

    out_leaves.append(ExtractionIndex(position))

  pytree_out = jax.tree.unflatten(treedef, out_leaves)
  extracted_nodes = tuple(index_of)

  if prefix is MISSING:
    return pytree_out, extracted_nodes
  return pytree_out, extracted_nodes, tuple(recorded_prefixes)


def insert_graph_nodes(pytree: A, nodes: tuple[tp.Any, ...], /) -> A:
  """Swap every ``Extractable`` placeholder in ``pytree`` for its node.

  Args:
    pytree: A pytree that may contain ``Extractable`` placeholders.
    nodes: Sequence indexed by each placeholder's ``index``.

  Returns:
    A pytree of the same structure where each placeholder ``p`` has been
    replaced by ``nodes[p.index]``; all other leaves are left unchanged.
  """

  def _is_placeholder(candidate) -> bool:
    return isinstance(candidate, Extractable)

  def _restore(candidate):
    return nodes[candidate.index] if _is_placeholder(candidate) else candidate

  # Placeholders are treated as leaves so tree_map does not descend into them.
  return jax.tree_util.tree_map(_restore, pytree, is_leaf=_is_placeholder)


def extract_indexes(
  pytree,
  /,
  types: tuple[type[E], ...] | type[E] = Extractable,  # type: ignore[assignment]
) -> tuple[E, ...]:
  """Collect every ``Extractable`` found among the leaves of ``pytree``.

  Args:
    pytree: The pytree to scan.
    types: Type (or tuple of types) each found ``Extractable`` must be an
      instance of.

  Returns:
    The placeholders in leaf order.

  Raises:
    ValueError: If an ``Extractable`` that is not an instance of ``types`` is
      encountered.
  """
  found: list[E] = []
  candidates = jax.tree.leaves(
    pytree, is_leaf=lambda x: isinstance(x, Extractable)
  )
  for candidate in candidates:
    if not isinstance(candidate, Extractable):
      continue
    if isinstance(candidate, types):
      found.append(candidate)  # type: ignore[arg-type]
    else:
      raise ValueError(
        f'Expected Extractable of type {types}, got {type(candidate)}'
      )
  return tuple(found)


def replace_indexes(
  pytree: A,
  replace_fn: tp.Callable[[Extractable], tp.Any],
  /,
  clear: bool = False,
) -> A:
  """Map ``replace_fn`` over the ``Extractable`` placeholders of ``pytree``.

  Args:
    pytree: The pytree to transform.
    replace_fn: Called with each ``Extractable``; its result takes the
      placeholder's position.
    clear: When true, every leaf that is not an ``Extractable`` is replaced by
      ``None``; otherwise such leaves are kept as they are.

  Returns:
    The transformed pytree.
  """

  def _is_placeholder(candidate) -> bool:
    return isinstance(candidate, Extractable)

  def _transform(candidate):
    if _is_placeholder(candidate):
      return replace_fn(candidate)
    return None if clear else candidate

  return jax.tree_util.tree_map(_transform, pytree, is_leaf=_is_placeholder)


def merge_extractable_states(
  extractable_states: tp.Sequence[ExtractableStates], /
):
  """Combine several ``ExtractableStates`` into one graphdef and one ``State``.

  Every flat path of every state is prefixed with the ``index`` of the object
  it came from, and the resulting entries are assembled into a single
  ``State``. The graphdef is taken from the first element only.

  Args:
    extractable_states: Non-empty sequence of ``ExtractableStates``.

  Returns:
    A ``(graphdef, state)`` pair.

  Raises:
    ValueError: If ``extractable_states`` is empty.
  """
  if len(extractable_states) == 0:
    raise ValueError('Expected at least one ExtractableStates object')

  graphdef = extractable_states[0].graphdef

  entries: list[tuple[PathParts, tp.Any]] = []
  for holder in extractable_states:
    for sub_state in holder.states:
      for path, value in sub_state.flat_state().items():
        # Namespace each path under the index of its owner.
        entries.append(((holder.index, *path), value))

  return graphdef, State.from_flat_path(entries)


def check_consistent_aliasing(
  nodes: tuple[tp.Any, ...], prefixes: tuple[tp.Any, ...]
):
  """Verify that no graph node is reachable under two different prefixes.

  ``nodes`` and ``prefixes`` are paired positionally. Each node is traversed
  with ``graph.iter_graph`` and every graph node found along the way is
  recorded together with its path and the prefix of the root it was reached
  from.

  Args:
    nodes: Root graph nodes to traverse.
    prefixes: Prefix associated with each root node.

  Raises:
    ValueError: If some graph node was recorded with more than one distinct
      prefix; the message lists each such node with its paths and prefixes.
  """
  occurrences = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]()

  # Pass 1: record (path, prefix) for every graph node reachable from a root.
  for root, root_prefix in zip(nodes, prefixes):
    for path, value in graph.iter_graph(root):
      if not graph.is_graph_node(value):
        continue
      entry = (path, root_prefix)
      if value not in occurrences:
        occurrences[value] = [entry]
      else:
        occurrences[value].append(entry)

  # Pass 2: build one report section per node seen under several prefixes.
  reports = []
  for shared_node, recorded in occurrences.items():
    if len({prefix for _, prefix in recorded}) <= 1:
      continue
    lines = []
    for path, prefix in recorded:
      path_str = '/'.join(map(str, path)) if path else '<root>'
      lines.append(f' {path_str}: {prefix}')
    details = '\n'.join(lines)
    reports.append(f'Node: {type(shared_node)}\n{details}')

  if reports:
    raise ValueError(
      'Inconsistent aliasing detected. The following nodes have different prefixes:\n'
      + '\n'.join(reports)
    )

# -----------------------------
# broadcast
# -----------------------------


# Container for the broadcast state stacks. It subclasses ``threading.local``,
# so attribute values are kept separately for each thread.
@dataclasses.dataclass
class BroadcastContext(threading.local):
  # Maps a tag to the stack (list) of states pushed under that tag; the last
  # element is the most recently pushed one.
  broadcast_state_stacks: dict[str, list[tp.Any]] = dataclasses.field(
    default_factory=dict
  )


# Module-wide context instance read and mutated by ``broadcast_state`` and
# ``get_broadcast_state``.
BROADCAST_CONTEXT = BroadcastContext()


@contextlib.contextmanager
def broadcast_state(tag: str, state: tp.Any):
  """Context manager that makes ``state`` the current broadcast state of ``tag``.

  ``state`` is pushed onto the stack stored under ``tag`` on entry and popped
  again on exit (also when the body raises). Once the stack becomes empty its
  entry is removed from the context.

  Args:
    tag: Key identifying the stack to push onto.
    state: Value to expose through ``get_broadcast_state(tag)`` while the
      context is active.
  """
  # Fetch the stack for this tag, creating and registering it if absent.
  stack = BROADCAST_CONTEXT.broadcast_state_stacks.setdefault(tag, [])
  stack.append(state)
  try:
    yield
  finally:
    stack.pop()
    if len(stack) == 0:
      del BROADCAST_CONTEXT.broadcast_state_stacks[tag]


def get_broadcast_state(tag: str) -> tp.Any:
  """Return the most recently pushed broadcast state for ``tag``.

  Args:
    tag: Key identifying the stack to read from.

  Returns:
    The top element of the stack stored under ``tag``.

  Raises:
    ValueError: If no stack is registered for ``tag``.
    RuntimeError: If a stack is registered but empty, which indicates an
      internal inconsistency.
  """
  stacks = BROADCAST_CONTEXT.broadcast_state_stacks
  if tag not in stacks:
    raise ValueError(f'No broadcast state found for {tag!r}')

  stack = stacks[tag]
  if stack:
    return stack[-1]

  raise RuntimeError(
    f'Empty broadcast state stack for {tag!r}, this is a bug'
  )
44 changes: 0 additions & 44 deletions flax/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,50 +1595,6 @@ class Static(tp.Generic[A]):

jax.tree_util.register_static(Static)

# ---------------------------------------------------------
# insert/extract_graph_nodes API
# ---------------------------------------------------------


@dataclasses.dataclass(frozen=True)
class GraphNodeIndex:
  """Index of a graph node in a Pytree structure."""

  # Position of the node in the tuple returned by ``extract_graph_nodes``.
  index: Index


# Register the class with JAX via ``register_static`` -- presumably so that
# instances are carried as static data with no array leaves (confirm against
# the JAX documentation).
jax.tree_util.register_static(GraphNodeIndex)


def extract_graph_nodes(pytree: A, /) -> tuple[A, tuple[tp.Any, ...]]:
  """Replace each graph node leaf of ``pytree`` with a ``GraphNodeIndex``.

  Distinct nodes are numbered in order of first appearance; a node that
  occurs several times always maps to the same index.

  Returns:
    ``(new_pytree, nodes)`` where ``nodes[i]`` is the node given index ``i``.
  """
  index_of = RefMap[tp.Any, Index]()

  def _to_placeholder(leaf):
    if not is_graph_node(leaf):
      return leaf
    if leaf in index_of:
      return GraphNodeIndex(index_of[leaf])
    # First sighting: allocate the next free index.
    position = len(index_of)
    index_of[leaf] = position
    return GraphNodeIndex(position)

  extracted = jax.tree_util.tree_map(_to_placeholder, pytree)
  return extracted, tuple(index_of)


def insert_graph_nodes(pytree: A, nodes: tuple[tp.Any, ...], /) -> A:
  """Replace each ``GraphNodeIndex`` in ``pytree`` with ``nodes[index]``.

  All leaves that are not a ``GraphNodeIndex`` are returned unchanged.
  """

  def _is_placeholder(candidate) -> bool:
    return isinstance(candidate, GraphNodeIndex)

  def _restore(candidate):
    return nodes[candidate.index] if _is_placeholder(candidate) else candidate

  # Placeholders are treated as leaves so tree_map hands them to _restore.
  return jax.tree_util.tree_map(_restore, pytree, is_leaf=_is_placeholder)


# ---------------------------------------------------------
# Pytree
# ---------------------------------------------------------
Expand Down
13 changes: 12 additions & 1 deletion flax/nnx/nnx/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,13 @@ def flat_state(self) -> FlatState[V]:
return traversals.flatten_mapping(self._mapping)

@classmethod
def from_flat_path(cls, flat_state: tp.Mapping[PathParts, V], /) -> State:
def from_flat_path(
cls,
flat_state: tp.Mapping[PathParts, V] | tp.Iterable[tuple[PathParts, V]],
/,
) -> State:
if not isinstance(flat_state, tp.Mapping):
flat_state = dict(flat_state)
nested_state = traversals.unflatten_mapping(flat_state)
return cls(nested_state)

Expand All @@ -176,7 +182,12 @@ def split(
*filters: filterlib.Filter,
) -> tuple[State[K, V], ...]: ...

@tp.overload
def split(
self, /, *filters: filterlib.Filter
) -> tp.Union[State[K, V], tuple[State[K, V], ...]]: ...

def split( # type: ignore[misc]
self, first: filterlib.Filter, /, *filters: filterlib.Filter
) -> tp.Union[State[K, V], tuple[State[K, V], ...]]:
"""Split a ``State`` into one or more ``State``'s. The
Expand Down
15 changes: 7 additions & 8 deletions flax/nnx/nnx/transforms/looping.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from flax import struct
from flax.core.frozen_dict import FrozenDict
from flax.nnx.nnx import filterlib, graph, rnglib, spmd
from flax.nnx.nnx import extract, filterlib, graph, rnglib, spmd
from flax.nnx.nnx.module import GraphDef, Module
from flax.nnx.nnx.proxy_caller import DelayedAccessor
from flax.nnx.nnx.state import State
Expand Down Expand Up @@ -254,7 +254,7 @@ def scan_fn(
input_graph_nodes = ctx.merge(
graphdef, *scan_states, carry_state, split_rng_state, broadcast_rng_state
)
(args, kwargs) = graph.insert_graph_nodes((args, kwargs), input_graph_nodes)
(args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes)

out = f(*args, **kwargs)

Expand All @@ -271,10 +271,9 @@ def scan_fn(
carry_arg_out = out
scan_args_out = None

(
(carry_arg_out, scan_args_out),
output_graph_nodes,
) = graph.extract_graph_nodes((carry_arg_out, scan_args_out))
((carry_arg_out, scan_args_out), output_graph_nodes) = (
extract.extract_graph_nodes((carry_arg_out, scan_args_out))
)

# split module state
(
Expand Down Expand Up @@ -353,7 +352,7 @@ def scan(
@graph.update_context('scan')
def scan_apply_wrapper(*args, **kwargs):
# extract nodes
(args, kwargs), input_graph_nodes = graph.extract_graph_nodes(
(args, kwargs), input_graph_nodes = extract.extract_graph_nodes(
(args, kwargs)
)
input_rng_streams = rnglib.backup_keys(input_graph_nodes)
Expand Down Expand Up @@ -465,7 +464,7 @@ def scan_apply_wrapper(*args, **kwargs):
broadcast_rng_state_out,
)

carry_arg_out, scan_args_out = graph.insert_graph_nodes(
carry_arg_out, scan_args_out = extract.insert_graph_nodes(
(carry_arg_out, scan_args_out), output_graph_nodes
)

Expand Down
Loading

0 comments on commit 2c4cb91

Please sign in to comment.