Skip to content

Commit

Permalink
Move all NNX content up a level to be equal with Linen, to make pytho…
Browse files Browse the repository at this point in the history
…n packaging more consistent.

This will simplify imports for all NNX users - no more `flax.nnx.nnx` paths.

Also renamed `state.py` to `statelib.py` so that `nnx.state()` function keep working.

PiperOrigin-RevId: 671520869
  • Loading branch information
IvyZX authored and Flax Authors committed Sep 6, 2024
1 parent e848a99 commit 6157d13
Show file tree
Hide file tree
Showing 113 changed files with 234 additions and 264 deletions.
10 changes: 5 additions & 5 deletions docs/api_reference/flax.nnx/bridge.rst
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
bridge
------------------------

.. automodule:: flax.nnx.nnx.bridge
.. currentmodule:: flax.nnx.nnx.bridge
.. automodule:: flax.nnx.bridge
.. currentmodule:: flax.nnx.bridge

.. flax_module::
:module: flax.nnx.nnx.bridge
:module: flax.nnx.bridge
:class: ToNNX

.. flax_module::
:module: flax.nnx.nnx.bridge
:module: flax.nnx.bridge
:class: ToLinen

.. autofunction:: to_linen

.. flax_module::
:module: flax.nnx.nnx.bridge
:module: flax.nnx.bridge
:class: NNXMeta
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
270 changes: 131 additions & 139 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,147 +18,139 @@
from flax.linen.pooling import pool as pool
from flax.typing import Initializer as Initializer

from .nnx.bridge import wrappers as wrappers
from .nnx.bridge.variables import (
from .bridge import wrappers as wrappers
from .bridge.variables import (
register_variable_name_type_pair as register_variable_name_type_pair,
)
from .nnx import graph as graph
from .nnx import errors as errors
from .nnx import helpers as helpers
from .nnx import bridge as bridge
from .nnx import traversals as traversals
from .nnx import filterlib as filterlib
from .nnx import transforms as transforms
from .nnx import extract as extract
from .nnx.filterlib import WithTag as WithTag
from .nnx.filterlib import PathContains as PathContains
from .nnx.filterlib import OfType as OfType
from .nnx.filterlib import Any as Any
from .nnx.filterlib import All as All
from .nnx.filterlib import Not as Not
from .nnx.filterlib import Everything as Everything
from .nnx.filterlib import Nothing as Nothing
from .nnx.graph import GraphDef as GraphDef
from .nnx.graph import GraphState as GraphState
from .nnx.graph import PureState as PureState
from .nnx.object import Object as Object
from .nnx.helpers import Dict as Dict
from .nnx.helpers import List as List
from .nnx.helpers import Sequential as Sequential
from .nnx.helpers import TrainState as TrainState
from .nnx.module import M as M
from .nnx.module import Module as Module
from .nnx.graph import merge as merge
from .nnx.graph import UpdateContext as UpdateContext
from .nnx.graph import update_context as update_context
from .nnx.graph import current_update_context as current_update_context
from .nnx.graph import split as split
from .nnx.graph import update as update
from .nnx.graph import clone as clone
from .nnx.graph import pop as pop
from .nnx.graph import state as state
from .nnx.graph import graphdef as graphdef
from .nnx.graph import iter_graph as iter_graph
from .nnx.graph import call as call
from .nnx.graph import SplitContext as SplitContext
from .nnx.graph import split_context as split_context
from .nnx.graph import MergeContext as MergeContext
from .nnx.graph import merge_context as merge_context
from .nnx.nn import initializers as initializers
from .nnx.nn.activations import celu as celu
from .nnx.nn.activations import elu as elu
from .nnx.nn.activations import gelu as gelu
from .nnx.nn.activations import glu as glu
from .nnx.nn.activations import hard_sigmoid as hard_sigmoid
from .nnx.nn.activations import hard_silu as hard_silu
from .nnx.nn.activations import hard_swish as hard_swish
from .nnx.nn.activations import hard_tanh as hard_tanh
from .nnx.nn.activations import leaky_relu as leaky_relu
from .nnx.nn.activations import log_sigmoid as log_sigmoid
from .nnx.nn.activations import log_softmax as log_softmax
from .nnx.nn.activations import logsumexp as logsumexp
from .nnx.nn.activations import one_hot as one_hot
from .nnx.nn.activations import relu as relu
from .nnx.nn.activations import relu6 as relu6
from .nnx.nn.activations import selu as selu
from .nnx.nn.activations import sigmoid as sigmoid
from .nnx.nn.activations import silu as silu
from .nnx.nn.activations import soft_sign as soft_sign
from .nnx.nn.activations import softmax as softmax
from .nnx.nn.activations import softplus as softplus
from .nnx.nn.activations import standardize as standardize
from .nnx.nn.activations import swish as swish
from .nnx.nn.activations import tanh as tanh
from .nnx.nn.attention import MultiHeadAttention as MultiHeadAttention
from .nnx.nn.attention import combine_masks as combine_masks
from .nnx.nn.attention import dot_product_attention as dot_product_attention
from .nnx.nn.attention import make_attention_mask as make_attention_mask
from .nnx.nn.attention import make_causal_mask as make_causal_mask
from .nnx.nn.linear import Conv as Conv
from .nnx.nn.linear import ConvTranspose as ConvTranspose
from .nnx.nn.linear import Embed as Embed
from .nnx.nn.linear import Linear as Linear
from .nnx.nn.linear import LinearGeneral as LinearGeneral
from .nnx.nn.linear import Einsum as Einsum
from .nnx.nn.lora import LoRA as LoRA
from .nnx.nn.lora import LoRALinear as LoRALinear
from .nnx.nn.lora import LoRAParam as LoRAParam
from .nnx.nn.normalization import BatchNorm as BatchNorm
from .nnx.nn.normalization import LayerNorm as LayerNorm
from .nnx.nn.normalization import RMSNorm as RMSNorm
from .nnx.nn.normalization import GroupNorm as GroupNorm
from .nnx.nn.stochastic import Dropout as Dropout
from .nnx.rnglib import Rngs as Rngs
from .nnx.rnglib import RngStream as RngStream
from .nnx.rnglib import RngState as RngState
from .nnx.rnglib import RngKey as RngKey
from .nnx.rnglib import RngCount as RngCount
from .nnx.rnglib import ForkStates as ForkStates
from .nnx.rnglib import fork as fork
from .nnx.rnglib import reseed as reseed
from .nnx.rnglib import split_rngs as split_rngs
from .nnx.rnglib import restore_rngs as restore_rngs
from .nnx.spmd import PARTITION_NAME as PARTITION_NAME
from .nnx.spmd import get_partition_spec as get_partition_spec
from .nnx.spmd import get_named_sharding as get_named_sharding
from .nnx.spmd import with_partitioning as with_partitioning
from .nnx.spmd import with_sharding_constraint as with_sharding_constraint
from .nnx.state import State as State
from .nnx.training import metrics as metrics
from .nnx.variables import (
from .filterlib import WithTag as WithTag
from .filterlib import PathContains as PathContains
from .filterlib import OfType as OfType
from .filterlib import Any as Any
from .filterlib import All as All
from .filterlib import Not as Not
from .filterlib import Everything as Everything
from .filterlib import Nothing as Nothing
from .graph import GraphDef as GraphDef
from .graph import GraphState as GraphState
from .graph import PureState as PureState
from .object import Object as Object
from .helpers import Dict as Dict
from .helpers import List as List
from .helpers import Sequential as Sequential
from .helpers import TrainState as TrainState
from .module import M as M
from .module import Module as Module
from .graph import merge as merge
from .graph import UpdateContext as UpdateContext
from .graph import update_context as update_context
from .graph import current_update_context as current_update_context
from .graph import split as split
from .graph import update as update
from .graph import clone as clone
from .graph import pop as pop
from .graph import state as state
from .graph import graphdef as graphdef
from .graph import iter_graph as iter_graph
from .graph import call as call
from .graph import SplitContext as SplitContext
from .graph import split_context as split_context
from .graph import MergeContext as MergeContext
from .graph import merge_context as merge_context
from .nn import initializers as initializers
from .nn.activations import celu as celu
from .nn.activations import elu as elu
from .nn.activations import gelu as gelu
from .nn.activations import glu as glu
from .nn.activations import hard_sigmoid as hard_sigmoid
from .nn.activations import hard_silu as hard_silu
from .nn.activations import hard_swish as hard_swish
from .nn.activations import hard_tanh as hard_tanh
from .nn.activations import leaky_relu as leaky_relu
from .nn.activations import log_sigmoid as log_sigmoid
from .nn.activations import log_softmax as log_softmax
from .nn.activations import logsumexp as logsumexp
from .nn.activations import one_hot as one_hot
from .nn.activations import relu as relu
from .nn.activations import relu6 as relu6
from .nn.activations import selu as selu
from .nn.activations import sigmoid as sigmoid
from .nn.activations import silu as silu
from .nn.activations import soft_sign as soft_sign
from .nn.activations import softmax as softmax
from .nn.activations import softplus as softplus
from .nn.activations import standardize as standardize
from .nn.activations import swish as swish
from .nn.activations import tanh as tanh
from .nn.attention import MultiHeadAttention as MultiHeadAttention
from .nn.attention import combine_masks as combine_masks
from .nn.attention import dot_product_attention as dot_product_attention
from .nn.attention import make_attention_mask as make_attention_mask
from .nn.attention import make_causal_mask as make_causal_mask
from .nn.linear import Conv as Conv
from .nn.linear import ConvTranspose as ConvTranspose
from .nn.linear import Embed as Embed
from .nn.linear import Linear as Linear
from .nn.linear import LinearGeneral as LinearGeneral
from .nn.linear import Einsum as Einsum
from .nn.lora import LoRA as LoRA
from .nn.lora import LoRALinear as LoRALinear
from .nn.lora import LoRAParam as LoRAParam
from .nn.normalization import BatchNorm as BatchNorm
from .nn.normalization import LayerNorm as LayerNorm
from .nn.normalization import RMSNorm as RMSNorm
from .nn.normalization import GroupNorm as GroupNorm
from .nn.stochastic import Dropout as Dropout
from .rnglib import Rngs as Rngs
from .rnglib import RngStream as RngStream
from .rnglib import RngState as RngState
from .rnglib import RngKey as RngKey
from .rnglib import RngCount as RngCount
from .rnglib import ForkStates as ForkStates
from .rnglib import fork as fork
from .rnglib import reseed as reseed
from .rnglib import split_rngs as split_rngs
from .rnglib import restore_rngs as restore_rngs
from .spmd import PARTITION_NAME as PARTITION_NAME
from .spmd import get_partition_spec as get_partition_spec
from .spmd import get_named_sharding as get_named_sharding
from .spmd import with_partitioning as with_partitioning
from .spmd import with_sharding_constraint as with_sharding_constraint
from .statelib import State as State
from .training import metrics as metrics
from .variables import (
Param as Param,
)
# this needs to be imported before optimizer to prevent circular import
from .nnx.training import optimizer as optimizer
from .nnx.training.metrics import Metric as Metric
from .nnx.training.metrics import MultiMetric as MultiMetric
from .nnx.training.optimizer import Optimizer as Optimizer
from .nnx.transforms.deprecated import Jit as Jit
from .nnx.transforms.deprecated import Remat as Remat
from .nnx.transforms.deprecated import Scan as Scan
from .nnx.transforms.deprecated import Vmap as Vmap
from .nnx.transforms.deprecated import Pmap as Pmap
from .nnx.transforms.autodiff import DiffState as DiffState
from .nnx.transforms.autodiff import grad as grad
from .nnx.transforms.autodiff import value_and_grad as value_and_grad
from .nnx.transforms.autodiff import custom_vjp as custom_vjp
from .nnx.transforms.autodiff import remat as remat
from .nnx.transforms.compilation import jit as jit
from .nnx.transforms.compilation import StateSharding as StateSharding
from .nnx.transforms.iteration import Carry as Carry
from .nnx.transforms.iteration import scan as scan
from .nnx.transforms.iteration import vmap as vmap
from .nnx.transforms.iteration import pmap as pmap
from .nnx.transforms.transforms import eval_shape as eval_shape
from .nnx.transforms.transforms import cond as cond
from .nnx.transforms.iteration import StateAxes as StateAxes
from .nnx.variables import A as A
from .nnx.variables import BatchStat as BatchStat
from .nnx.variables import Cache as Cache
from .nnx.variables import Intermediate as Intermediate
from .nnx.variables import Variable as Variable
from .nnx.variables import VariableState as VariableState
from .nnx.variables import VariableMetadata as VariableMetadata
from .nnx.variables import with_metadata as with_metadata
from .nnx.visualization import display as display
from .nnx.extract import to_tree, from_tree, TreeNode
from .training import optimizer as optimizer
from .training.metrics import Metric as Metric
from .training.metrics import MultiMetric as MultiMetric
from .training.optimizer import Optimizer as Optimizer
from .transforms.deprecated import Jit as Jit
from .transforms.deprecated import Remat as Remat
from .transforms.deprecated import Scan as Scan
from .transforms.deprecated import Vmap as Vmap
from .transforms.deprecated import Pmap as Pmap
from .transforms.autodiff import DiffState as DiffState
from .transforms.autodiff import grad as grad
from .transforms.autodiff import value_and_grad as value_and_grad
from .transforms.autodiff import custom_vjp as custom_vjp
from .transforms.autodiff import remat as remat
from .transforms.compilation import jit as jit
from .transforms.compilation import StateSharding as StateSharding
from .transforms.iteration import Carry as Carry
from .transforms.iteration import scan as scan
from .transforms.iteration import vmap as vmap
from .transforms.iteration import pmap as pmap
from .transforms.transforms import eval_shape as eval_shape
from .transforms.transforms import cond as cond
from .transforms.iteration import StateAxes as StateAxes
from .variables import A as A
from .variables import BatchStat as BatchStat
from .variables import Cache as Cache
from .variables import Intermediate as Intermediate
from .variables import Variable as Variable
from .variables import VariableState as VariableState
from .variables import VariableMetadata as VariableMetadata
from .variables import with_metadata as with_metadata
from .visualization import display as display
from .extract import to_tree, from_tree, TreeNode
File renamed without changes.
8 changes: 4 additions & 4 deletions flax/nnx/nnx/bridge/module.py → flax/nnx/bridge/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
import typing as tp
import typing_extensions as tpe

from flax.nnx.nnx import graph, rnglib
import flax.nnx.nnx.module as nnx_module
from flax.nnx.nnx.proxy_caller import (
from flax.nnx import graph, rnglib
import flax.nnx.module as nnx_module
from flax.nnx.proxy_caller import (
CallableProxy,
DelayedAccessor,
)
from flax.nnx.nnx.object import Object
from flax.nnx.object import Object

M = tp.TypeVar('M', bound='Module')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import jax
from flax import struct
from flax.core import meta
from flax.nnx.nnx import variables as variableslib
from flax.nnx import variables as variableslib
import typing as tp


Expand Down
16 changes: 8 additions & 8 deletions flax/nnx/nnx/bridge/wrappers.py → flax/nnx/bridge/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
from flax import nnx
from flax import linen
from flax.core import meta
from flax.nnx.nnx import graph
from flax.nnx.nnx.bridge import variables as bv
from flax.nnx.nnx.module import GraphDef, Module
from flax.nnx.nnx.rnglib import Rngs
from flax.nnx.nnx.state import State
from flax.nnx.nnx.object import Object
from flax.nnx import graph
from flax.nnx.bridge import variables as bv
from flax.nnx.module import GraphDef, Module
from flax.nnx.rnglib import Rngs
from flax.nnx.statelib import State
from flax.nnx.object import Object
import jax
from jax import tree_util as jtu

Expand Down Expand Up @@ -74,7 +74,7 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs):
module = fn
assert callable(fn)
else:
if not hasattr(fn, '__self__') and isinstance(fn.__self__, Module):
if not (hasattr(fn, '__self__') and isinstance(fn.__self__, Module)):
raise ValueError(f'{fn = } needs to be a method of an NNX Module.')
module = fn.__self__
_set_initializing(module, True)
Expand Down Expand Up @@ -207,7 +207,7 @@ class ToLinen(linen.Module):
>>> variables.keys()
dict_keys(['nnx', 'params'])
>>> type(variables['nnx']['graphdef'])
<class 'flax.nnx.nnx.graph.NodeDef'>
<class 'flax.nnx.graph.NodeDef'>
Args:
nnx_class: The NNX Module class (not instance!).
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions flax/nnx/nnx/extract.py → flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
# from jax._src.tree_util import broadcast_prefix

from flax import struct
from flax.nnx.nnx.object import Object
from flax.nnx.object import Object
from flax.typing import MISSING, PathParts
from flax.nnx.nnx import graph
from flax.nnx import graph


A = tp.TypeVar('A')
Expand Down
File renamed without changes.
8 changes: 4 additions & 4 deletions flax/nnx/nnx/graph.py → flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
import typing_extensions as tpe

from flax.core.frozen_dict import FrozenDict
from flax.nnx.nnx import filterlib, reprlib
from flax.nnx.nnx.proxy_caller import (
from flax.nnx import filterlib, reprlib
from flax.nnx.proxy_caller import (
ApplyCaller,
CallableProxy,
DelayedAccessor,
)
from flax.nnx.nnx.state import FlatState, State
from flax.nnx.nnx.variables import Variable, VariableState
from flax.nnx.statelib import FlatState, State
from flax.nnx.variables import Variable, VariableState
from flax.typing import Key, PathParts

A = tp.TypeVar('A')
Expand Down
Loading

0 comments on commit 6157d13

Please sign in to comment.