-
Notifications
You must be signed in to change notification settings - Fork 102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP,POC] Faster functional modules #983
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like the approach -- it's awesome that it speeds up small neural nets a lot. I'm curious about your thoughts on parameter tying
for module_name, m in model.named_modules(): | ||
for param_name, p in list(m.named_parameters(recurse=False)): | ||
delattr(m, param_name) | ||
setattr(m, param_name, None) | ||
yield (module_name, m, param_name, p) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason we previously used create_names_map was for parameter tying. If someone creates a module that looks like:
class Foo(nn.Module):
def __init__(self):
super().__init__()
self.bias = nn.Parameter(torch.randn(3))
self.linear = nn.Linear(3, 3)
self.linear.bias = self.bias
then fmodel, params = make_functional(Foo())
returns 2 Tensors (self.linear.weight and self.bias) instead of 3 Tensors. When the user calls fmodel([w, b], x)
, then b gets loaded to self.bias and self.linear.bias and w gets loaded to self.linear.weight.
Under the new strategy, it seems like params
would have 3 tensors: [self.bias, self.linear.weight, self.linear.bias].
In general I'm not really sure what the interaction between parameter tying and make_functional should be. Thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see your point. If we want to keep things as they are we could link params to a list of modules and a list of names (instead of a single module and a single name). That will come with a slight overhead though...
It's the kind of design choice where you will always make someone unhappy (there will be someone out there that wants multiple copies of the same param), but it's probably not the majority of users.
old_states = _swap_state( | ||
self.param_modules + self.buffer_modules, | ||
self.param_names + self.buffer_names, | ||
list(params) + list(buffers) | ||
) | ||
old_params = old_states[:len(self.param_modules)] | ||
old_buffers = old_states[len(self.param_modules):] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to make sure I understand why this is faster: is it because we no longer need to traverse through the module to find the submodules; we've already made the submodules directly available to swap their parameters out?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Exactly!
Instead of going through a tree of param names, we just flatten it and go through a single list of modules, one-level names and values.
param_module_names, param_modules, param_names, params = zip(*param_container) | ||
else: | ||
param_module_names, param_modules, param_names, params = tuple(), tuple(), tuple(), tuple() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Previously, we guaranteed that params is returned in the same order as what gets returned by original_model.parameters(). After this change, is that still true?
(Side note) To be honest, we've been thinking of changing the API so that params isn't returned as a flat list; instead we probably want to return some sort of dictionary or object so that one can easily figure out which params
corresponds to which parameters on the original module. This is something that a couple of users have asked us for. If we returned a dictionary then it doesn't matter that params isn't the same as what gets returned by original_module.parameters()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh man this would be so great! I could definitely use that feature.
To be honest, I was thinking about using TensorDict from torchrl to pass params to functorch stateless modules. We could nest the dicts (eg d["module"]["param"] to d["module.param"]), expand the params, change device or whatever, in batch and with little or no effort since all those ops are built-in tensordict methods. I think there's a good synergy that we could get from TensorDict functorch. At the moment, TensorDict isn't torchscriptable though, I don't know how much trouble it is for you.
@nairbv @shagunsodhani
Proposes a new method to load weights in
FunctionalModule
andFunctionalModuleWithBuffers
.A map
module
<->param_name
<->param_value
is created and used to set attributes.Test:
The following test runs twice as fast on CPU than current implementation:
Other metrics:
On torchrl's DDPG, the new in a full forward-backard pass, the old implementation of
_swap_state
takes approx. 20% of the runtime with small neural nets (2 layers MLP with 256 cells) on CPU. The new implementation takes approx. 6% of runtime.