Parametrization Functionality (pytorch#33344)
Summary:
Provides the implementation for feature request issue pytorch#28937.

Adds the `Parametrization` functionality and implements `Pruning` on top of it.
It adds an `auto` mode, in which the parametrization is computed just once per forward pass. The previous pruning implementation recomputed the pruned tensor on every forward pass, which is not optimal when pruning RNNs, for example.
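
As a quick illustration, here is a minimal sketch of the registration API this PR adds; the `Skew` module mirrors the one used in the tests of this diff:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

# A parametrization is an nn.Module whose forward maps the stored tensor
# to the tensor the layer actually uses.
class Skew(nn.Module):
    def forward(self, X):
        X = X.tril(-1)
        return X - X.T  # skew-symmetric

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", Skew())

# `layer.weight` is now computed on the fly from the unconstrained tensor
# stored in `layer.parametrizations.weight.original`.
print(torch.allclose(layer.weight, -layer.weight.T))  # True

# Removing the parametrization restores a plain Parameter on the module.
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=False)
print(parametrize.is_parametrized(layer))  # False
```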

It implements a caching mechanism for parameters, following the approach proposed at the end of the discussion in pytorch#7313. In particular, it assumes that the user will not manually change the parameters between the call to `backward()` and `optimizer.step()`. If they do, they need to manually call the `.invalidate()` function provided in the implementation. This could be turned into a function that takes a model and invalidates all of its parameters. It might be the case that this function has to be called from `.cuda()`, `.to()`, and related methods.
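
The tests in this diff exercise the cache through a `parametrize.cached()` context manager (see `test_caching_parametrization` below). A minimal sketch of that usage, reusing the `Skew` module from the previous sketch:

```python
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

model = nn.Linear(5, 5)
parametrize.register_parametrization(model, "weight", Skew())  # Skew as defined above

# Inside the context manager the parametrized tensor is computed once and reused,
# so repeated accesses return the very same cached tensor object.
with parametrize.cached():
    X = model.weight
    Y = model.weight
    assert X is Y
```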

As described in pytorch#7313, this could be used to implement the `weight_norm` and `spectral_norm` functions in a cleaner way. As described in pytorch#28937, it also allows implementing constrained optimization on manifolds (e.g. orthogonal constraints, positive-definite matrices, invertible matrices, weights on the sphere or in hyperbolic space, ...).
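
For instance, the orthogonal constraint from pytorch#28937 can be obtained by chaining a skew-symmetric parametrization with the Cayley map, as the `Skew`/`Orthogonal` modules in the tests of this diff do. A sketch of that composition (`torch.solve` is used here because that is what the tests call at the time of this commit):

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Skew(nn.Module):
    def forward(self, X):
        X = X.tril(-1)
        return X - X.T

class Orthogonal(nn.Module):
    def forward(self, X):
        # Cayley map: a skew-symmetric X is mapped to an orthogonal matrix
        Id = torch.eye(X.size(0), device=X.device)
        return torch.solve(Id - X, Id + X).solution

layer = nn.Linear(8, 8)
# Parametrizations compose: the stored tensor goes through Skew, then Orthogonal
parametrize.register_parametrization(layer, "weight", Skew())
parametrize.register_parametrization(layer, "weight", Orthogonal())

Q = layer.weight
print(torch.allclose(Q.T @ Q, torch.eye(8), atol=1e-5))  # True: the weight is orthogonal
```

An ordinary gradient step on `layer.parametrizations.weight.original` then keeps the effective weight orthogonal, which is the pattern the gradient-step loops in the tests rely on.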

TODO (when implementation is validated):
- More thorough tests
- Documentation

Resolves pytorch#28937

albanD

Pull Request resolved: pytorch#33344

Reviewed By: zhangguanheng66

Differential Revision: D26816708

Pulled By: albanD

fbshipit-source-id: 07c8f0da661f74e919767eae31335a9c60d9e8fe
lezcano authored and Sacha Refshauge committed Mar 31, 2021
1 parent 9dc6c53 commit 8bef3ff
Showing 3 changed files with 723 additions and 0 deletions.
10 changes: 10 additions & 0 deletions docs/source/nn.rst
@@ -346,11 +346,21 @@ From the ``torch.nn.utils`` module
parameters_to_vector
vector_to_parameters

.. autosummary::
:toctree: generated
:nosignatures:

parametrize.register_parametrization
parametrize.remove_parametrizations
parametrize.cached
parametrize.is_parametrized

.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst

parametrize.ParametrizationList
prune.BasePruningMethod

.. autosummary::
349 changes: 349 additions & 0 deletions test/test_nn.py
@@ -27,6 +27,7 @@
import torch.nn.init as init
import torch.nn.utils.rnn as rnn_utils
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
import torch.nn.utils.parametrize as parametrize
import torch.nn.utils.prune as prune
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.nn import Parameter
@@ -1939,6 +1940,354 @@ def test_vector_to_parameters(self):
sample = next(model.parameters())[0, 0, 0]
self.assertTrue(torch.equal(sample.data, vec.data[:5]))

# torch/nn/utils/parametrize
def test_register_and_remove_parametrization(self):
r"""Test that it is possible to add a few parametrizations
on a parameter or a buffer, and that removing them restores the initial state.
It also tests that backpropagating through them works as expected.
"""
# Define a couple matrix parametrizations
class Skew(nn.Module):
def forward(self, X):
X = X.tril(-1)
return X - X.T

class Orthogonal(nn.Module):
def forward(self, X):
# Cayley map
# If X is skew-symmetric it returns an orthogonal matrix
Id = torch.eye(X.size(0), device=X.device)
return torch.solve(Id - X, Id + X).solution

# Define a couple vector parametrizations
class FirstZero(nn.Module):
def forward(self, x):
return torch.cat([x.new_zeros(1), x[1:]])

class LastZero(nn.Module):
def forward(self, x):
return torch.cat([x[:-1], x.new_zeros(1)])

model = nn.Linear(8, 8)
initial_weight_id = id(model.weight)
initial_bias_id = id(model.bias)
initial_model = deepcopy(model)

# Test one parametrization
parametrize.register_parametrization(model, "weight", Skew())
self.assertTrue(hasattr(model, "parametrizations"))
self.assertTrue(parametrize.is_parametrized(model))
self.assertTrue(parametrize.is_parametrized(model, "weight"))
self.assertFalse(parametrize.is_parametrized(model, "bias"))
self.assertNotIn("weight", model._parameters)
# Result should be skew-symmetric
A = model.weight
self.assertTrue(torch.allclose(A, -A.T))
# Remove and check consistency
parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
self.assertFalse(hasattr(model, "parametrizations"))
self.assertEqual(model.weight, initial_model.weight)
self.assertEqual(id(model.weight), initial_weight_id)
self.assertEqual(model.__class__, nn.Linear)

# Test two parametrizations at the same time and removing them
parametrize.register_parametrization(model, "weight", Skew())
parametrize.register_parametrization(model, "weight", Orthogonal())
# Result should be orthogonal
X = model.weight
Id = torch.eye(X.size(0), device=X.device)
self.assertTrue(torch.allclose(X.T @ X, Id))
# Structure tests
self.assertTrue(hasattr(model, "parametrizations"))
self.assertTrue(parametrize.is_parametrized(model))
self.assertTrue(parametrize.is_parametrized(model, "weight"))
self.assertFalse(parametrize.is_parametrized(model, "bias"))
self.assertIn("weight", model.parametrizations)
self.assertNotIn("weight", model._parameters)
# Remove
parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
self.assertEqual(model.weight, initial_model.weight)
self.assertEqual(id(model.weight), initial_weight_id)
self.assertFalse(hasattr(model, "parametrizations"))
self.assertEqual(model.__class__, nn.Linear)

# Add everything
parametrize.register_parametrization(model, "weight", Skew())
parametrize.register_parametrization(model, "weight", Orthogonal())
parametrize.register_parametrization(model, "bias", FirstZero())
parametrize.register_parametrization(model, "bias", LastZero())

# Basic tests
self.assertTrue(parametrize.is_parametrized(model))
self.assertTrue(parametrize.is_parametrized(model, "weight"))
self.assertTrue(parametrize.is_parametrized(model, "bias"))
self.assertEqual(model.bias[0].item(), 0.)
self.assertEqual(model.bias[-1].item(), 0.)
self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happened
# Should not throw
(model.weight.T @ model.bias).sum().backward()
with torch.no_grad():
for p in model.parameters():
p.add_(- p.grad, alpha=0.01)

# Remove first parametrization.
# Check that the model is still parametrized and so is the second parameter
parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
self.assertTrue(parametrize.is_parametrized(model)) # Still parametrized
self.assertFalse(parametrize.is_parametrized(model, "weight")) # Parametrization removed
self.assertTrue(parametrize.is_parametrized(model, "bias")) # Still parametrized
self.assertEqual(model.bias[0].item(), 0.) # Still parametrized
self.assertEqual(model.bias[-1].item(), 0.) # Still parametrized
self.assertNotEqual(model.weight, initial_model.weight) # Has been updated
self.assertEqual(id(model.weight), initial_weight_id) # Keeps the same id
self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happened
# Should not throw
(model.weight.T @ model.bias).sum().backward()
with torch.no_grad():
for p in model.parameters():
p.add_(- p.grad, alpha=0.01)

# Remove the second parametrization.
# Check that the module is not parametrized
parametrize.remove_parametrizations(model, "bias", leave_parametrized=False)
self.assertFalse(parametrize.is_parametrized(model)) # Not parametrized anymore
self.assertNotEqual(model.bias, initial_model.bias) # Has been updated
self.assertNotEqual(model.bias[0].item(), 0.) # Not parametrized anymore
self.assertNotEqual(model.bias[-1].item(), 0.) # Not parametrized anymore
self.assertEqual(id(model.bias), initial_bias_id)
self.assertFalse(hasattr(model, "parametrizations"))
self.assertEqual(model.__class__, nn.Linear)
self.assertEqual(len(list(model.parameters())), 2)
# Should not throw
(model.weight.T @ model.bias).sum().backward()
with torch.no_grad():
for p in model.parameters():
p.add_(- p.grad, alpha=0.01)

def test_register_and_remove_buffer_parametrization(self):
r"""Test that it is possible to add and remove parametrizations on buffers"""
# Define a couple vector parametrizations
class FirstZero(nn.Module):
def forward(self, x):
return torch.cat([x.new_zeros(1), x[1:]])

class LastZero(nn.Module):
def forward(self, x):
return torch.cat([x[:-1], x.new_zeros(1)])

model = nn.Linear(8, 8)

# Instantiate parametrizations on buffers. It should work as expected
delattr(model, "bias")
model.register_buffer("bias", torch.ones(8))
parametrize.register_parametrization(model, "bias", FirstZero())
parametrize.register_parametrization(model, "bias", LastZero())
self.assertTrue(parametrize.is_parametrized(model))
self.assertTrue(parametrize.is_parametrized(model, "bias"))
self.assertEqual(model.bias[0].item(), 0.)
self.assertEqual(model.bias[-1].item(), 0.)
self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
self.assertEqual(len(list(model.parameters())), 1)

# Remove parametrizations on buffers. It should work as expected
parametrize.remove_parametrizations(model, "bias", leave_parametrized=True)
self.assertFalse(parametrize.is_parametrized(model))
self.assertFalse(parametrize.is_parametrized(model, "bias"))
self.assertEqual(model.bias[0].item(), 0.)
self.assertEqual(model.bias[-1].item(), 0.)
self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
self.assertEqual(len(list(model.parameters())), 1)

def test_serialization_parametrization(self):
r"""Test that it is possible to serialize a parametrized model via state_dict"""
# A stateful parametrization
class Orthogonal(nn.Module):
def __init__(self, n):
super().__init__()
self.register_buffer("id", torch.eye(n))
self.register_buffer("B", torch.empty(n, n))
init.orthogonal_(self.B)

def forward(self, X):
A = X.triu(1)
A = A - A.T
return self.B @ torch.solve(self.id - A, self.id + A).solution

def get_model():
model = torch.nn.Sequential(
torch.nn.Linear(5, 5),
torch.nn.ReLU(),
torch.nn.Linear(5, 1),
)

parametrize.register_parametrization(model[0], "weight", Orthogonal(5))
return model

model = get_model()

prev_weight = model[0].weight
prev_B = model[0].parametrizations.weight[0].B

new_model = get_model()
with TemporaryFileName() as fname:
torch.save(model.state_dict(), fname)
new_model.load_state_dict(torch.load(fname))

# Integrity tests
self.assertTrue(parametrize.is_parametrized(new_model[0], "weight"))
self.assertEqual(prev_weight, new_model[0].weight)
self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B)

# Trying to save the whole parametrized model raises
with self.assertRaisesRegex(RuntimeError, "state_dict"):
with TemporaryFileName() as fname:
torch.save(model, fname)

def test_initialization_parametrization(self):
r"""Test that it is possible to initialize a parametrization when it
implements a `right_inverse` method
"""
class Skew(nn.Module):
def forward(self, X):
A = X.triu(1)
return A - A.T

def is_skew(self, A):
return torch.allclose(A, -A.T, atol=1e-6)

def right_inverse(self, X):
if not self.is_skew(X):
raise ValueError("The matrix is not skew-symmetric.")
return X.triu(1)

# Implements a Cayley map where right_inverse is not quite the inverse of forward
class Orthogonal(nn.Module):
def __init__(self, n):
super().__init__()
self.register_buffer("B", torch.eye(n))

def forward(self, A):
Id = torch.eye(A.size(0))
return self.B @ torch.solve(Id - A, Id + A).solution

def is_orthogonal(self, X):
Id = torch.eye(X.size(0))
return torch.allclose(X.T @ X, Id, atol=1e-4)

def right_inverse(self, X):
if not self.is_orthogonal(X):
raise ValueError("The input is not orthogonal.")
# cayley(0) == Id, so B @ cayley(0) == B
self.B = X
return torch.zeros_like(X)

N = 5
model = nn.Linear(N, N)
# Register the skew-symmetric constraint. The result is now skew-symmetric
parametrize.register_parametrization(model, "weight", Skew())
X = torch.rand(N, N)
# X is not skew-symmetric, so it throws an error
with self.assertRaises(ValueError):
model.weight = X
# Make X skew-symmetric
X = X - X.T
model.weight = X
self.assertEqual(model.parametrizations.weight.original, X.triu(1))
self.assertEqual(model.weight, X)

# Having several parametrizations registered should work in the same way
parametrize.register_parametrization(model, "weight", Orthogonal(N))
# Register now the Cayley map. The result is now orthogonal
X = torch.rand(N, N)
# X is not orthogonal, so it throws an error
with self.assertRaises(ValueError):
model.weight = X
init.orthogonal_(X)
model.weight = X
self.assertEqual(model.weight, X)
self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X))

def test_errors_parametrization(self):
# A parametrization shall not change the size of the parameter
class ChangeSize(nn.Module):
def forward(self, x):
return x[:-1]

# A simple parametrization that does not implement a right_inverse
class Double(nn.Module):
def forward(self, x):
return 2 * x

module = nn.Linear(3, 4)
# This should not throw when registering
parametrize.register_parametrization(module, "weight", ChangeSize())
# It throws in the forward
with self.assertRaisesRegex(RuntimeError, "may not change the size"):
module(torch.rand(2))
# Undo
parametrize.remove_parametrizations(module, "weight", leave_parametrized=False)
self.assertFalse(parametrize.is_parametrized(module))

# Removing a parametrization from an unparametrized tensor throws
with self.assertRaisesRegex(ValueError, "does not have a parametrization"):
parametrize.remove_parametrizations(module, "bias")
# Nothing odd happens
self.assertFalse(parametrize.is_parametrized(module))

# Register a parametrization on a non-existing parameter breaks
with self.assertRaisesRegex(ValueError, "does not have a parameter"):
parametrize.register_parametrization(module, "foo", ChangeSize())
self.assertFalse(parametrize.is_parametrized(module))

# Try to assign to a parametrization that does not implement `right_inverse`
parametrize.register_parametrization(module, "weight", Double())
with self.assertRaisesRegex(RuntimeError, "right_inverse"):
module.weight = torch.rand(4, 3)
# Undo
parametrize.remove_parametrizations(module, "weight", leave_parametrized=False)
self.assertFalse(parametrize.is_parametrized(module))

def test_caching_parametrization(self):
r"""Test the caching system of a parametrization"""
# Define a couple matrix parametrizations
class Skew(nn.Module):
def forward(self, X):
X = X.tril(-1)
return X - X.T

class Orthogonal(nn.Module):
def forward(self, X):
Id = torch.eye(X.size(0), device=X.device)
return torch.solve(Id - X, Id + X).solution

model = nn.Linear(5, 5)
parametrize.register_parametrization(model, "weight", Skew())
parametrize.register_parametrization(model, "weight", Orthogonal())

# Test that the caching system works
with parametrize.cached():
X = model.weight
Y = model.weight
self.assertEqual(id(X), id(Y))

def test_dtype_parametrization(self):
r"""Test a case that is not allowed when removing a parametrization"""
class ChangeType(nn.Module):
def forward(self, X):
return X.double()

module = nn.Linear(4, 4).float()
input_ = torch.rand(4).double()
# It is allowed to register a parametrization that changes the dtype
parametrize.register_parametrization(module, "weight", ChangeType())
module(input_)
# We can remove it leaving the original tensor
parametrize.remove_parametrizations(module, "weight", leave_parametrized=False)
# But leaving it parametrized breaks
parametrize.register_parametrization(module, "weight", ChangeType())
with self.assertRaisesRegex(ValueError, "changes the dtype"):
parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

# torch/nn/utils/prune.py
@unittest.skipIf(not TEST_NUMPY, "numpy not found")
def test_validate_pruning_amount_init(self):
