Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add GraphNode base class #3790

Merged
merged 1 commit into from
Mar 27, 2024
Merged

[nnx] add GraphNode base class #3790

merged 1 commit into from
Mar 27, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Mar 27, 2024

What does this PR do?

  • Adds GraphNode as a base class for Module and other types that need a graph flatten/unflatten implementation according to Module's current definition.
  • Adds nnx.split and nnx.update to work with any graph node.
  • Redefines nnx.merge to match the spread signature returned by .split
  • Orders the leaves when flattening Module.

@cgarciae cgarciae force-pushed the nnx-graph-node-base branch 4 times, most recently from 5d9c822 to e217b27 Compare March 27, 2024 11:47
@codecov-commenter
Copy link

codecov-commenter commented Mar 27, 2024

Codecov Report

Attention: Patch coverage is 82.06522% with 33 lines in your changes are missing coverage. Please review.

Project coverage is 59.62%. Comparing base (0ab0365) to head (6948138).

Files Patch % Lines
flax/experimental/nnx/nnx/graph_utils.py 80.89% 30 Missing ⚠️
flax/experimental/nnx/nnx/helpers.py 0.00% 1 Missing ⚠️
flax/experimental/nnx/nnx/module.py 90.00% 1 Missing ⚠️
flax/experimental/nnx/tests/test_spmd.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3790      +/-   ##
==========================================
+ Coverage   59.56%   59.62%   +0.06%     
==========================================
  Files         101      101              
  Lines       12624    12655      +31     
==========================================
+ Hits         7519     7546      +27     
- Misses       5105     5109       +4     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

flax/experimental/nnx/nnx/graph_utils.py Outdated Show resolved Hide resolved
flax/experimental/nnx/nnx/graph_utils.py Outdated Show resolved Hide resolved
state = deepcopy(state)
return graphdef.merge(state)

def __hash__(self) -> int:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also override __eq__?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm more tented to remove this implementation of __hash__ and just have the base one.

flax/experimental/nnx/nnx/graph_utils.py Outdated Show resolved Hide resolved
flax/experimental/nnx/nnx/graph_utils.py Outdated Show resolved Hide resolved
@@ -860,6 +887,101 @@ def _graph_update_static(
node_impl.set_key(node, name, value_updates)


@tp.overload
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally would probably leave the overloading out. It is verbose and only marginally useful.

I understand that you are trying to preserve tuple cardinality, but I'm not sure how common such a mistake is to justify all this boilerplate.

flax/experimental/nnx/nnx/graph_utils.py Outdated Show resolved Hide resolved
@cgarciae cgarciae force-pushed the nnx-graph-node-base branch 4 times, most recently from 34de63f to 0db3362 Compare March 27, 2024 17:33
@copybara-service copybara-service bot merged commit 6f1f1ef into main Mar 27, 2024
21 checks passed
@copybara-service copybara-service bot deleted the nnx-graph-node-base branch March 27, 2024 18:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants