-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add StateVariablesMapping #3523
Conversation
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## main #3523 +/- ##
==========================================
+ Coverage 56.21% 56.23% +0.02%
==========================================
Files 100 100
Lines 11849 11882 +33
==========================================
+ Hits 6661 6682 +21
- Misses 5188 5200 +12 ☔ View full report in Codecov by Sentry. |
raise AttributeError(f"Variable '{name}' not found.") | ||
return value | ||
|
||
def __setattr__(self, name: str, value: Variable[tp.Any]) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would recommend to constrain the API and allow either ["x"]
or .x
but not both (possibly in a separate PR).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have this dual API in a couple of place in NNX, the main idea is to make intuitive both hard-coded and programmatic access, e.g:
# hard-coded
x = state.linear.kernel
# programmatic
x = state[module_name][param_name]
@@ -72,9 +124,13 @@ def __init__( | |||
super().__setattr__('_mapping', dict(mapping)) | |||
|
|||
@property | |||
def variables(self) -> dict[Key, dict[str, tp.Any] | tp.Any]: | |||
def raw_mapping(self) -> dict[Key, dict[str, tp.Any] | tp.Any]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another naming idea: storage
.
@@ -72,9 +124,13 @@ def __init__( | |||
super().__setattr__('_mapping', dict(mapping)) | |||
|
|||
@property | |||
def variables(self) -> dict[Key, dict[str, tp.Any] | tp.Any]: | |||
def raw_mapping(self) -> dict[Key, dict[str, tp.Any] | tp.Any]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it make sense to return a read-only view of the underlying mapping, e.g. wrapped in types.MappingProxyType
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case we a embracing mutability in case the user wants to do some surgery.
What does this PR do?
VariablesMapping
toModuleVariablesMapping
State.variables
now returns aStateVariablesMapping
with a similar behavior toModuleVariablesMapping
.State.raw_mapping
which exposes._mapping