-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] improved graph update mechanism #3759
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #3759 +/- ##
==========================================
+ Coverage 58.75% 59.49% +0.74%
==========================================
Files 101 101
Lines 12413 12621 +208
==========================================
+ Hits 7293 7509 +216
+ Misses 5120 5112 -8 ☔ View full report in Codecov by Sentry. |
673efa1
to
5e5ca14
Compare
c636565
to
10a6b2d
Compare
|
||
m = Foo() | ||
|
||
static: nnx.graph_utils.GraphDef[Foo] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I'd expect the type checker to be able to infer this type with no hints.
If you decide to remove this declaration, please remove other similar ones in this file as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is here to force the static checker to treat static
as a GraphDef
everywhere, else its reset to Any
when its redefined:
static, idx_out_idx_in = static_out.value
@@ -24,7 +25,7 @@ def test_flatten(self): | |||
a = {'a': 1, 'b': nnx.Param(2)} | |||
g = [a, 3, a, nnx.Param(4)] | |||
|
|||
state, static = nnx.graph_utils.graph_flatten(g) | |||
state, static = nnx.graph_utils.graph_flatten(g)[:2] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I think it might be more clear to do
state, static, _ = ...
instead, since you already do unpacking.
10a6b2d
to
7694ab8
Compare
7694ab8
to
c8cf3f5
Compare
What does this PR do?
Adds a mechanism that allows performing arbitrary mutations on a cloned graph and recreating them on the original graph, this is useful to faithfully update a graph across jax boundaries.
Changes
MutableNodeImpl.clear
API to enablegraph_unflatten
to update existing referencesgraph_flatten
now returns aref_to_index: Mapping[Any, int]
that maps the node references to their flatten index.graph_unflatten
now returns aindex_to_ref: dict[int, Any]
maps index flatten references to their node reference.graph_unflatten
now accepts aref_cache: dict[int, Any]
that maps flatten indexes to existing node references to use instead of creating new ones.clear
is used to empty the existing node's state so they can be treated as new nodes during the unflattening process.