parametrization.py
import torch
import torch.nn as nn


def get_parameters(model):
    parametrized_params = []

    def get_parametrized_params(mod):
        nonlocal parametrized_params
        if isinstance(mod, Parametrization):
            parametrized_params.append(mod.A)

    def not_in(elem, l):
        return all(elem is not x for x in l)

    model.apply(get_parametrized_params)
    unconstrained_params = (param for param in model.parameters() if not_in(param, parametrized_params))
    return unconstrained_params, parametrized_params
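

# Usage sketch (an illustrative assumption, not part of the original module):
# `get_parameters` splits a model's parameters into the unconstrained ones and
# the `A` tensors living inside Parametrization submodules, so that the two
# groups can be handed to different optimizers, e.g.
#
#     unconstrained, parametrized = get_parameters(model)
#     optim_unconstrained = torch.optim.Adam(unconstrained, lr=1e-3)
#     optim_parametrized = torch.optim.SGD(parametrized, lr=1e-4)
#
# The optimizer classes and learning rates above are placeholders.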


class Parametrization(nn.Module):
    """
    Implements the parametrization of a manifold in terms of a Euclidean space

    It gives the parametrized matrix through the attribute `B`

    To use it, subclass it and implement the method `retraction` and the method `forward`
    (and optionally `project`). See the documentation in these methods for details

    You can find an example in the file `orthogonal.py`, where we implement the Orthogonal
    class to optimize over the Stiefel manifold using an arbitrary retraction
    """

    def __init__(self, A, base, mode):
        """
        mode: "static" or a tuple ("dynamic", K, M) such that:
            mode[0] == "dynamic"
            mode[1]: K, the number of optimization steps after which we change the basis
                of the dynamic trivialization. It may also be the string "infty" to never
                change the basis.
            mode[2]: M, the number of changes of basis after which we project the basis
                back onto the manifold. This is particularly helpful for small values of K.
        """
        super(Parametrization, self).__init__()
        assert mode == "static" or (isinstance(mode, tuple) and len(mode) == 3 and mode[0] == "dynamic")

        self.A = nn.Parameter(A)
        self.register_buffer("_B", None)
        self.register_buffer('base', base)
        # This is necessary, as it will be generated again the first time that self.B is called
        # We still need to register the buffer though
        if mode == "static":
            self.mode = mode
        else:
            self.mode = mode[0]
            self.K = mode[1]
            self.M = mode[2]
            self.k = 0
            self.m = 0

        # This implements the parametrization trick in a rather slick way.
        # We put a hook on A, such that, whenever its gradients are computed, we
        # get rid of self._B so that it has to be recomputed the next time that
        # self.B is accessed
        def hook(grad):
            nonlocal self
            self._B = None

        self.A.register_hook(hook)
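
        # Note (illustrative, not from the original file): a dynamic mode would be
        # passed as a tuple, e.g. mode=("dynamic", 100, 10), meaning "change the
        # basis every 100 steps and project the basis back onto the manifold every
        # 10 changes of basis"; the concrete numbers here are placeholders.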

    def rebase(self):
        with torch.no_grad():
            self.base.data.copy_(self._B.data)
            self.A.data.zero_()

    @property
    def B(self):
        not_B = self._B is None
        if not_B or (not self._B.grad_fn and torch.is_grad_enabled()):
            self._B = self.retraction(self.A, self.base)
            # Just to be safe
            self._B.requires_grad_()
            # self._B is not a leaf tensor, so we ask autograd to keep its gradient around
            self._B.retain_grad()

            # Increment the counters for the dyntriv algorithm if we have generated B
            if self.mode == "dynamic" and not_B:
                if self.k == 0:
                    self.rebase()
                    # Project the base back onto the manifold every M changes of basis
                    # Increment the counter before, as we do not project the first time
                    self.m = (self.m + 1) % self.M
                    # It is optional to implement this method
                    if self.m == 0 and hasattr(self, "project"):
                        with torch.no_grad():
                            self.base = self.project(self.base)
                # Change the basis after K optimization steps
                # Increment the counter afterwards, as we change the basis in the first iteration
                if self.K != "infty":
                    self.k = (self.k + 1) % self.K
                else:
                    # Make sure that we just update the base once
                    if self.k == 0:
                        self.k = 1
        return self._B

    def retraction(self, A, base):
        """
        It computes r_{base}(A).
        Notice that A will not always be in the tangent space of our manifold.
        For this reason, we first have to use A to parametrize the tangent space,
        and then compute the retraction.
        When dealing with Lie groups, A is always projected onto the Lie algebra,
        as an optimization (cf. Section E in the paper)
        """
        raise NotImplementedError

    def project(self, base):
        """
        This method is OPTIONAL
        It returns the base projected back onto the manifold
        """
        raise NotImplementedError

    def forward(self, input):
        """
        It uses the attribute self.B to implement the layer itself (e.g. Linear, CNN, ...)
        """
        raise NotImplementedError
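

# ---------------------------------------------------------------------------
# A minimal, hypothetical subclass sketch (not part of the original module).
# It only illustrates the contract described above: `retraction` maps (A, base)
# to a point on the manifold, and `forward` uses `self.B`. The real example for
# the Stiefel manifold lives in `orthogonal.py`; the name `SkewExpLinear` and
# the choice of the matrix exponential as retraction are assumptions made here
# for illustration.
# ---------------------------------------------------------------------------
class SkewExpLinear(Parametrization):
    def __init__(self, n, mode="static"):
        # Start at the identity with a zero unconstrained parameter
        A = torch.zeros(n, n)
        base = torch.eye(n)
        super().__init__(A, base, mode)

    def retraction(self, A, base):
        # Use A to parametrize the tangent space (skew-symmetric matrices)
        # and move away from `base` with the matrix exponential
        return base.mm(torch.matrix_exp(A - A.t()))

    def forward(self, input):
        # Use the parametrized orthogonal matrix as the layer's weight
        return input.mm(self.B)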