Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] improved graph update mechanism #3759

Merged
merged 1 commit into from
Mar 19, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Mar 14, 2024

What does this PR do?

Adds a mechanism that allows performing arbitrary mutations on a cloned graph and recreating them on the original graph, this is useful to faithfully update a graph across jax boundaries.

Changes

  • adds MutableNodeImpl.clear API to enable graph_unflatten to update existing references
  • graph_flatten now returns a ref_to_index: Mapping[Any, int] that maps the node references to their flatten index.
  • graph_unflatten now returns a index_to_ref: dict[int, Any] maps index flatten references to their node reference.
  • graph_unflatten now accepts a ref_cache: dict[int, Any] that maps flatten indexes to existing node references to use instead of creating new ones. clear is used to empty the existing node's state so they can be treated as new nodes during the unflattening process.

@codecov-commenter
Copy link

codecov-commenter commented Mar 14, 2024

Codecov Report

Attention: Patch coverage is 89.85507% with 21 lines in your changes are missing coverage. Please review.

Project coverage is 59.49%. Comparing base (83d118a) to head (10a6b2d).
Report is 7 commits behind head on main.

Files Patch % Lines
flax/experimental/nnx/nnx/graph_utils.py 84.95% 17 Missing ⚠️
flax/experimental/nnx/nnx/variables.py 33.33% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3759      +/-   ##
==========================================
+ Coverage   58.75%   59.49%   +0.74%     
==========================================
  Files         101      101              
  Lines       12413    12621     +208     
==========================================
+ Hits         7293     7509     +216     
+ Misses       5120     5112       -8     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@cgarciae cgarciae force-pushed the nnx-improve-graph-update branch 8 times, most recently from 673efa1 to 5e5ca14 Compare March 15, 2024 12:43
@cgarciae cgarciae changed the title [nnx] improve graph_update [nnx] improved graph update mechanism Mar 15, 2024
@cgarciae cgarciae force-pushed the nnx-improve-graph-update branch 2 times, most recently from c636565 to 10a6b2d Compare March 15, 2024 17:31
flax/experimental/nnx/nnx/graph_utils.py Show resolved Hide resolved
flax/experimental/nnx/nnx/graph_utils.py Outdated Show resolved Hide resolved
flax/experimental/nnx/nnx/graph_utils.py Outdated Show resolved Hide resolved
flax/experimental/nnx/nnx/graph_utils.py Show resolved Hide resolved
flax/experimental/nnx/nnx/graph_utils.py Show resolved Hide resolved
flax/experimental/nnx/nnx/graph_utils.py Show resolved Hide resolved

m = Foo()

static: nnx.graph_utils.GraphDef[Foo]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I'd expect the type checker to be able to infer this type with no hints.

If you decide to remove this declaration, please remove other similar ones in this file as well.

Copy link
Collaborator Author

@cgarciae cgarciae Mar 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is here to force the static checker to treat static as a GraphDef everywhere, else its reset to Any when its redefined:

static, idx_out_idx_in = static_out.value

@@ -24,7 +25,7 @@ def test_flatten(self):
a = {'a': 1, 'b': nnx.Param(2)}
g = [a, 3, a, nnx.Param(4)]

state, static = nnx.graph_utils.graph_flatten(g)
state, static = nnx.graph_utils.graph_flatten(g)[:2]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think it might be more clear to do

state, static, _ = ...

instead, since you already do unpacking.

flax/experimental/nnx/nnx/module.py Outdated Show resolved Hide resolved
flax/experimental/nnx/nnx/graph_utils.py Outdated Show resolved Hide resolved
@copybara-service copybara-service bot merged commit b626099 into main Mar 19, 2024
21 checks passed
@copybara-service copybara-service bot deleted the nnx-improve-graph-update branch March 19, 2024 17:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants