[proposal] [discussion] Refactor pruning/weight_norm using new Reparametrization functionality + actually deprecate old impl of SpectralNorm #7313
Comments
Could you illustrate your proposal with some code snippets? It will be easier to understand.
For the first point, spectral norm has a compute_weight method, but it's defined in an internal structure. I think a global method like constrain_weight_with_spectral_norm would be more discoverable. For the second point, to come up with a good abstraction, one needs to know why the current module object patching was chosen over defining a new wrapping module. I don't know the answer. @t-vi may have some ideas; he also wondered about a better abstraction in #7298 (comment)
We're trying to figure out what this would look like -- code snippets would definitely help.
So what it looks like to me as a user is that I would want to have "computed parameters".
I would propose to
This does not solve how to call the hook when someone else overwrites the parameter, but hey. I'm not set on any of the specifics. I can whip up a prototype if that helps. So the main functional difference to how spectral normalization is currently implemented is that it replaces the forward_pre hook with an update hook. The advantage to the user is that she can see this as a modification of parameters rather than the modules having the parameters.
So to put some code here (not remotely tested enough): note how the spectral norm class now looks like a regular module, to the point where you would not need the utility spectral_norm and remove_spectral_norm functions (the latter could just be replaced …). While it does add some trickery in …
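The snippet referenced in this comment did not survive extraction. Below is a minimal sketch in its place (class and attribute names are assumptions of mine, not the actual prototype) of what spectral norm expressed as a regular module that computes a weight could look like:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpectralNormWeight(nn.Module):
    """Sketch of a 'computed parameter': an ordinary module whose forward()
    returns the spectrally normalized weight derived from a raw parameter."""

    def __init__(self, weight, n_power_iterations=1, eps=1e-12):
        super().__init__()
        self.weight_orig = nn.Parameter(weight.detach().clone())
        self.n_power_iterations = n_power_iterations
        self.eps = eps
        # running estimate of the leading left singular vector, refined on every call
        self.register_buffer(
            "u", F.normalize(weight.new_empty(weight.size(0)).normal_(), dim=0, eps=eps)
        )

    def forward(self):
        w_mat = self.weight_orig.reshape(self.weight_orig.size(0), -1)
        u = self.u
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                v = F.normalize(torch.mv(w_mat.t(), u), dim=0, eps=self.eps)
                u = F.normalize(torch.mv(w_mat, v), dim=0, eps=self.eps)
            self.u.copy_(u)
        # sigma still depends on weight_orig, so gradients flow through the division
        sigma = torch.dot(u, torch.mv(w_mat, v))
        return self.weight_orig / sigma
```

Under the proposal, a layer would then bind its weight to such a module (e.g. something like `layer.weight = CalculatedParam(SpectralNormWeight(layer.weight))`, with `CalculatedParam` being the hypothetical machinery discussed below), instead of installing a forward-pre hook on the layer.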
An alternative to trying to chase parameter changes might be to use a forward-pre hook within the module. It would seem to depend on the application whether you want something that gives a new result at every invocation even when the inputs didn't change, or whether that is wasteful. This could be mitigated by having the update only be done in training mode or some such. One could also simplify the logic between the module and the CalculatedParam by keeping a calculation counter and checking for None or the calculation counter to indicate out-of-date, instead of having a hook that sets the parameter to None.
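A minimal sketch of the calculation-counter idea (names are assumptions of mine, not an existing API): cache the computed tensor together with the counter value it was computed at, and recompute only when the counter has moved on.

```python
import torch.nn as nn

class CalculatedParam(nn.Module):
    """Sketch: lazily recompute a derived tensor, invalidated via a counter."""

    def __init__(self, compute):
        super().__init__()
        self.compute = compute        # an nn.Module whose forward() returns the tensor
        self._counter = 0             # bumped whenever the underlying parameters change
        self._cache = None
        self._cache_counter = -1

    def invalidate(self):
        # called from an update hook (or manually) after the raw parameters change
        self._counter += 1

    def value(self):
        if self._cache is None or self._cache_counter != self._counter:
            self._cache = self.compute()
            self._cache_counter = self._counter
        return self._cache
```

Combined with the spectral-norm sketch above, one would do `w = CalculatedParam(SpectralNormWeight(weight))` and call `w.invalidate()` after each parameter update.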
Incidentally, we have also been working on abstracting reparameterization: https://github.com/NVIDIA/apex/tree/master/apex/reparameterization. The current implementation puts the boilerplate code in reparameterization.py, with weight norm defining just the particular weight-recomputation functions. The recomputation of the weight is controlled by …
@ngimel Nice, that looks like a clean variant of how weight_norm is currently done in master. One thing I am wondering (and I think that is the main difference to my proposed implementation strategy): is this something that works at the module level or at the parameter level? I see it more naturally at the parameter level, while the other solutions seem to look more at the module level. One thing that isn't easily achieved with my demo implementation is having multiple parameters that need to be calculated jointly. This could be achieved by returning a dict and having getattr check for that. In terms of usability ("less boilerplate code") I think that allowing the parameter computation to be in an nn.Module subclass is a more natural solution than wrappers.
@t-vi I've asked the author of that implementation to come here to give his perspective. The CalculatedParameters in nn.Module that you have look nice, but they do require modifications to the core and optimizers, whereas the apex approach is more self-contained. As you note, joint recalculation of multiple parameters does not come naturally, neither to apex nor to CalculatedParameters. Also, I think all current approaches break parameter sharing, and this is something that has to be figured out going forward.
Actually, I could move the caching to the … (Edit: I did now, and also provided a tentative interface for multiple parameters by returning a dict from forward, with the caching matching the suffix - not 100% clean, but it might be a start.)
Do we want parameter calculation to be able to provide gradients by itself?
If you want to manipulate gradients, you probably want to do that using Function or a backward hook.
Chiming in on this discussion. I'm the author of the reparameterization @ngimel shared earlier. My hooking implementation:
Cons:
Optimizer handles reparameterization:
Cons:
Calculated/Lazy Parameter:
Cons:
Module wrapper:
Cons:
The ideal sort of setup I envision is either one module wrapper class that handles everything for the user, or a combination of CalculatedParameters and hooks/module wrappers (preferably the latter). I think, though, that in general PyTorch needs some sort of module-wrapper nn.Module subclass that absorbs most if not all of the functionality/attributes of the wrapped module instance. This would alleviate a lot of the code considerations that have to be made when using things like (Distributed)DataParallel, which forces us to add …
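A small sketch of the "absorbing" wrapper idea in plain Python (no special PyTorch support assumed): attribute lookups that miss on the wrapper fall through to the wrapped module, so user code does not need to reach for `.module`.

```python
import torch.nn as nn

class AbsorbingWrapper(nn.Module):
    """Forwards unknown attribute lookups to the wrapped module."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def __getattr__(self, name):
        try:
            # parameters/buffers/submodules registered on the wrapper itself
            return super().__getattr__(name)
        except AttributeError:
            # fall through to the wrapped module (e.g. wrapper.weight)
            return getattr(self.module, name)
```

DataParallel-style wrappers only forward `forward`, which is why one ends up typing `model.module.weight`; an absorbing wrapper avoids that at the cost of some attribute magic.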
@raulpuric Thank you for commenting. If you have a moment to look at it, what do you think are the key things missing from the calculated param prototype ( https://github.com/t-vi/pytorch/tree/calculated_param ) relative to your vision? |
Imagine we have a parameter precalc for an op (pseudo-code):
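The pseudo-code from this comment was not preserved; below is a hedged guess at the kind of usage being described (the helper name and shapes are mine, not from the original): precompute the constrained weight once, then feed it to the functional op explicitly.

```python
import torch
import torch.nn.functional as F

def weight_norm_precalc(v, g):
    # hypothetical precalc step, weight-norm style: w = g * v / ||v|| (norm per output channel)
    norm = v.reshape(v.size(0), -1).norm(dim=1).view(-1, 1, 1, 1)
    return g * v / norm

v = torch.randn(16, 3, 3, 3, requires_grad=True)   # direction
g = torch.randn(16, 1, 1, 1, requires_grad=True)   # magnitude
x = torch.randn(8, 3, 32, 32)

w = weight_norm_precalc(v, g)      # precompute the constrained weight once...
y = F.conv2d(x, w, padding=1)      # ...then pass it to the functional op explicitly
```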
For the question of wrapper vs patching, I think the wrapper is still cleaner (even if the user has to type …).
For the calculated parameter, you just have a module that uses … I'm not sure I quite understand the wrapper vs. patching argument. I see CalculatedParameter as doing neither (you just assign the parameter). If you wanted to stack things, you would assign the one CalculatedParameter's parameters with those of another (this needs a look at propagating the calling …). In a different area: I looked a bit into what to do about calling the Parameter.updated hook and I think it can be pushed to TensorBase in C and called automatically in in-place operations. That only works when using …
@t-vi For the module wrapper vs patching, I think the proposed CalculatedParameter stands somewhere in the middle, and I agree it makes stacking clearer than hooks do. My argument was to address @raulpuric's Cons section for the module wrapper.
@apaszke @soumith Do you have opinions regarding the questions in #7313 (comment)? (Especially about the ease of use of parameter precalc together with the functional interface.)
So to give an update from my side: I implemented in-place hooks in Variables and am updating CalculatedParameter to use them. This would enable the usual idiom (in SGD, see the sketch below) to correctly invalidate the CalculatedParameter. From my POV this removes the last major obstacle to using CalculatedParameter to cleanly implement cached calculated parameters.
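A minimal illustration of the idiom being referred to, assuming the standard SGD inner loop (this is not the exact code from the comment): the in-place `add_` on `p.data` is the kind of mutation the new in-place hook would intercept.

```python
import torch

def sgd_step(params, lr):
    for p in params:
        if p.grad is None:
            continue
        # in-place update on p.data: exactly the kind of mutation an inplace hook
        # on the tensor would catch to invalidate a cached CalculatedParameter
        p.data.add_(p.grad.data, alpha=-lr)

w = torch.nn.Parameter(torch.randn(3))
w.sum().backward()
sgd_step([w], lr=0.1)
```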
So I pushed an implementation using inplace hooks on the tensor base to … For multiple parameters, I think the main alternative to passing the parameter to the CalculatedParam call would be to specify the name when assigning. So it would be …, but you could equally do … if your forward returns those in a dict. Regarding the functional interface, I think it is a bit of a red herring, but without a Module subclass instance and its getattr method one would have to do …, i.e. add the "()" to make it a call. Given that it gives you an error message rather than silently failing, I would think it is OK. As discussed above, it would need deprecation of using p.data.do_() in favour of with no_grad(): p.do_(). In theory, one could try to move the hooks to the storage level, but why would we - it adds complications for not much gain. I look forward to your comments.
@t-vi kudos for this prototype supporting the functional interface and multiple-parameter pre-calc!
Thanks. So in your opinion, would this satisfy the requirements for the desired abstraction?
@apaszke Any chance that we get some sort of input from you or one of the other devs on this?
@t-vi what does usage look like with a regular Linear module? Sth like …? Also, what's your opinion on replacing …?
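A hypothetical sketch of such usage, reusing the assumed names from the sketches above (this is not an actual PyTorch API, and today the assignment below would be rejected by nn.Module; supporting it is precisely what the proposal is about):

```python
import torch
import torch.nn as nn

linear = nn.Linear(20, 30)
# hypothetical: bind a computed parameter in place of the raw weight
linear.weight = CalculatedParam(SpectralNormWeight(linear.weight))

x = torch.randn(5, 20)
y = linear(x)   # reading linear.weight during forward would yield the normalized weight
```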
My reason for suggesting get_output is the "explicit is better than implicit" Python Zen.
Given the effort in #8806, maybe moving the hook to the storage is a good idea, too.
@vadimkantorov @lezcano this proposal looks a lot like Pyro's abstraction for transformed parameters, which mostly relies on machinery in torch.distributions.

**Summary of torch Transforms for constrained parameters**

Our pattern is to use a Transform object t that implements the pseudo-inverse identities

```python
t(t.inv(t(x))) == t(x)
t.inv(t(t.inv(y))) == t.inv(y)
```

Note that PyTorch's … The transform machinery also plays well with the Constraint machinery for constrained parameters (e.g. learnable correlation matrices). The usual syntax uses transform_to(-)(-), as in

```python
y = transform_to(constraints.lower_cholesky)(torch.randn(4, 4))
```

We've put a lot of work into tracking shapes for static shape computations in transforms and constraints, since shapes are needed by … On top of PyTorch's …
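For concreteness, a small example with the existing torch.distributions machinery showing the pseudo-inverse pattern above:

```python
import torch
from torch.distributions import constraints, transform_to

t = transform_to(constraints.positive)   # an ExpTransform under the hood

x = torch.randn(4)    # unconstrained value
y = t(x)              # constrained (positive) value
x2 = t.inv(y)         # pseudo-inverse back to the unconstrained space

# the identities from the summary hold (here exactly, in general up to the pseudo-inverse)
assert torch.allclose(t(t.inv(t(x))), t(x))
assert torch.allclose(t.inv(t(t.inv(y))), t.inv(y))
```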
Thank you for the very nice summary @fritzo! So, the … When it comes to … If it did, then it would be direct either to adapt its API directly in their definition or via a function, so that it could be used as a parametrisation for … Having this, it looks like it could be reasonable to have a hierarchy of the form …
Hehe, a recent request for function/module inverses: #52553 :)
@vadimkantorov
Interestingly enough, it looks like parametrisations would allow solving issues 1-3 pointed out in the distributions issue, wouldn't they? I think that would be quite a good synergy, and a direct application of constrained optimisation via parametrisations, which is what #33344 brings to the table. I think it'd be worth looking into that in the future.
@lezcano Yes, I can see how point 1 of the distribution issue might be solved by parametrisations. Could you elaborate on your proposed strategies for addressing issue 2 (cache flushing on parameter mutation) and issue 3 (optional constructor args) in the distributions issue? Issue 2 on caching seems especially complex and warrants detailed discussion.
Potential implementation clash with attribute caching or whatever's going on in #50431.
@fritzo Issue 2: the current parametrisation PR offers an opt-in caching system. We opted to go for explicit rather than implicit when it comes to activating it. At the beginning, I implemented an automatic one of sorts that would work in most situations, but, of course, making it work in all situations is remarkably tricky with the current API. As such, you can activate this caching system via a context manager as:

```python
with torch.cached():  # it won't be in the torch namespace, but still
    outputs = model(inputs)
```

or you could go with a finer grain and simply wrap the loop of your RNN, if your RNN happens to be parametrised:

```python
with torch.cached():
    for x in xs:
        out_rnn = self.rnn_cell(x, out_rnn)
```

Issue 3 (how to deal with the optional constructor args): one solution would be to go as described in the section "High Level API" of the summary in #7313 (comment). This would be exactly the same idea as having functions that do the dispatching of several classes, as you described in the issue that you linked. Another idea that expresses the intent better would be to implement the relation between these constraints directly. When there is a diffeomorphic relation between two spaces, as is the case for logits vs probabilities via the exponential, we can express these isomorphisms explicitly. In these cases, it may be reasonable to have two parametrisations parametrising the same unconstrained argument. As such, an object …

```python
import torch
import torch.nn as nn

class Exp(nn.Module):
    def forward(self, x):
        return x.exp()

    def right_inverse(self, x):
        return x.log()


class Bernoulli(nn.Module):
    def __init__(self, prob=None, logits=None):
        super().__init__()
        if (prob is None) == (logits is None):
            raise ValueError("Provide one and only one!")
        # We omit the checks that prob / logits are in the correct domains
        t = prob if prob is not None else logits
        # Assume that we have decided to implement prob in terms of logits
        if isinstance(t, nn.Parameter):
            self.register_parameter("logits", t)
        else:
            self.register_buffer("logits", t)
        # Parametrise logits
        register_parametrization(self, "logits", LogitParametrization())
        # We now put a parametrisation on top of the first parametrisation
        # parametrized_name is not currently supported, but it is direct to implement
        register_parametrization(self, "logits", Exp(), parametrized_name="prob")
        # Initialise using the parameter that we have
        with torch.no_grad():
            if prob is not None:
                self.prob = prob
            else:
                self.logits = logits


distr = Bernoulli(prob=...)
distr2 = Bernoulli(logits=...)
# From here on, working with `distr.prob` or `distr.logits` is equivalent.
# If we wrap the forward pass in a `with cached()`, both results are cached the
# first time they are needed and reused throughout the forward pass.
# Assigning to `distr.prob` or `distr.logits` also works as expected.
```

If there are distribution classes that admit two very different implementations, it would still be reasonable to have this solution together with the first one via the factory pattern. This last case also happens more generally in the setting of constraints, as there might be some constraints that do not have a parametrisation that's better than the others. For example, there are plenty of ways to perform optimisation with orthogonal constraints (consider any differentiable map that maps an unconstrained matrix to an orthogonal matrix, e.g. the matrix exponential, Cayley map, Householder reflections...). In this case, there is no "one size fits all" parametrisation, so we will implement several of these and we will dispatch them with an …
@vadimkantorov the solution we're probably going to go with (#52576) won't affect this, so no need to worry about this :P
Summary: Provides the implementation for feature request issue #28937. Adds the `Parametrization` functionality and implements `Pruning` on top of it. It adds the `auto` mode, in which the parametrization is computed just once per forward pass. The previous implementation computed the pruning on every forward, which is not optimal when pruning RNNs, for example.

It implements a caching mechanism for parameters. This is implemented through the mechanism proposed at the end of the discussion in #7313. In particular, it assumes that the user will not manually change the updated parameters between the call to `backward()` and `optimizer.step()`. If they do so, they would need to manually call the `.invalidate()` function provided in the implementation. This could be made into a function that takes a model and invalidates all the parameters in it. It might be the case that this function has to be called in the `.cuda()`, `.to()`, and related functions.

As described in #7313, this could be used to implement the `weight_norm` and `spectral_norm` functions in a cleaner way. It also allows, as described in #28937, for the implementation of constrained optimization on manifolds (i.e. orthogonal constraints, positive definite matrices, invertible matrices, weights on the sphere or the hyperbolic space...).

TODO (when implementation is validated):
- More thorough tests
- Documentation

Resolves #28937
albanD
Pull Request resolved: #33344
Reviewed By: zhangguanheng66
Differential Revision: D26816708
Pulled By: albanD
fbshipit-source-id: 07c8f0da661f74e919767eae31335a9c60d9e8fe
The only detail is that the name is "parametrizations". We went for this one over "reparametrizations" as it is a less verbose version of an already very long word.
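For reference, the merged functionality lives in torch.nn.utils.parametrize; usage looks roughly like this:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    def forward(self, X):
        # build a symmetric matrix from an unconstrained square matrix
        return X.triu() + X.triu(1).transpose(-1, -2)

layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())

print(layer.weight)                              # always symmetric, recomputed on access
print(layer.parametrizations.weight.original)    # the raw unconstrained parameter

with parametrize.cached():                       # opt-in caching, as discussed above
    out = layer(torch.randn(2, 3))
```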
Related: #55368
#103001 for weight_norm
Currently weight_norm and spectral_norm patch a passed module + implement special functions for adding/removing these from a module.
Some ideas for refactoring to make it less tricky:
torch.matmul and F.conv2d
cc @jerryzh168 @jianyuh @dzhulgakov