Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge nnx.errors to flax.errors #4186

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions flax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ def __reduce__(self):
return (FlaxError, (str(self),))


#################################################
# NNX errors #
#################################################


class TraceContextError(FlaxError):
pass


#################################################
# lazy_init.py errors #
#################################################
Expand Down
17 changes: 0 additions & 17 deletions flax/nnx/errors.py

This file was deleted.

2 changes: 1 addition & 1 deletion flax/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
import numpy as np

from flax.nnx import (
errors,
reprlib,
tracers,
)
from flax.nnx import graph
from flax.nnx.variables import Variable, VariableState
from flax.typing import Key
from flax import errors

G = tp.TypeVar('G', bound='Object')

Expand Down
4 changes: 2 additions & 2 deletions flax/nnx/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import jax

from flax import nnx
from flax import errors
from flax.nnx import reprlib, tracers
from flax.typing import Missing
import jax.tree_util as jtu
Expand Down Expand Up @@ -259,7 +259,7 @@ def __setattr__(self, name: str, value: Any) -> None:

def _setattr(self, name: str, value: tp.Any):
if not self._trace_state.is_valid():
raise nnx.errors.TraceContextError(
raise errors.TraceContextError(
f'Cannot mutate {type(self).__name__} from a different trace level'
)

Expand Down
4 changes: 2 additions & 2 deletions tests/nnx/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any, TypeVar

from absl.testing import absltest
from flax import nnx
from flax import nnx, errors
import jax
import jax.numpy as jnp
import numpy as np
Expand All @@ -39,7 +39,7 @@ def test_trace_level(self):
@jax.jit
def f():
with self.assertRaisesRegex(
nnx.errors.TraceContextError,
errors.TraceContextError,
"Cannot mutate 'Dict' from different trace level",
):
m.a = 2
Expand Down
5 changes: 3 additions & 2 deletions tests/nnx/rngs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from absl.testing import absltest

from flax import nnx
from flax import errors


class TestRngs(absltest.TestCase):
Expand Down Expand Up @@ -58,7 +59,7 @@ def test_rng_trace_level_constraints(self):
@jax.jit
def f():
with self.assertRaisesRegex(
nnx.errors.TraceContextError,
errors.TraceContextError,
'Cannot call RngStream from a different trace level',
):
rngs.params()
Expand All @@ -76,7 +77,7 @@ def h():

self.assertIsInstance(rngs1, nnx.Rngs)
with self.assertRaisesRegex(
nnx.errors.TraceContextError,
errors.TraceContextError,
'Cannot call RngStream from a different trace level',
):
rngs1.params()
Expand Down
Loading