New unifying structure for coupling architectures #139
base: master
Conversation
```python
class Positive(Parameter):
    def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor:
        return torch.exp(unconstrained)
```
I think we should have something like `SoftPositive` that saturates at a certain value.
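A minimal sketch of what that could look like, assuming the `Parameter.constrain` interface from this PR (the name `SoftPositive`, the `cap` argument, and the sigmoid mapping are my suggestions, not part of the diff):

```python
import torch

class SoftPositive(Parameter):
    """Positive constraint that saturates at `cap` instead of growing
    without bound, so large unconstrained values cannot overflow."""

    def __init__(self, cap: float = 10.0):
        self.cap = cap

    def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor:
        # sigmoid maps R -> (0, 1), so the result stays in (0, cap)
        return self.cap * torch.sigmoid(unconstrained)
```

The existing GLOW couplings in FrEIA do something similar via `exp(clamp * 0.636 * atan(s / clamp))`, which saturates at `exp(±clamp)` and keeps `constrain(0) == 1`.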
```python
class Increasing(Parameter):
    def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor:
        # keep the dim so broadcasting works: (B, 1) + (B, D-1)
        return unconstrained[:, :1] + torch.cumsum(torch.exp(unconstrained[:, 1:]), dim=1)
```
Same here
```python
def parameterize(**parameters):
    def wrap(cls):
        def construct(*args, split=EvenSplit(), subnet_constructor, **kwargs):
            ...
```
Add `@wraps(cls)` here, so the future docstring (and name) of the wrapped class carries over to the constructor.
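Concretely, something like this sketch; `functools.wraps` copies `__doc__`, `__name__`, and `__qualname__` from the decorated class onto the replacement constructor:

```python
from functools import wraps

def parameterize(**parameters):
    def wrap(cls):
        @wraps(cls)  # carry cls.__doc__ / __name__ over to the wrapper
        def construct(*args, split=EvenSplit(), subnet_constructor, **kwargs):
            ...
        return construct
    return wrap
```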
```python
self.split = split
self.transform = transform
# TODO: 2 subnets? or just singular coupling?
self.subnet = subnet
```
I am in favor of just one, but maybe we should check the most recent literature.
```python
def get_parameters(self, condition: torch.Tensor) -> Dict[str, torch.Tensor]:
    parameters = self.subnet(condition)
    parameters = torch.split(parameters, self.parameter_counts, dim=1)
```
How should we handle shapes? Always split at first non-batch dimension? Could also be a parameter to the Coupling.
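One possible scheme, as a standalone sketch (the helper name, its parameters, and the `dim` default are assumptions, not part of the diff): make the split axis explicit and default it to the first non-batch dimension.

```python
from typing import Dict, Sequence
import torch

def split_parameters(
    raw: torch.Tensor,
    names: Sequence[str],
    counts: Sequence[int],
    dim: int = 1,  # first non-batch dimension by default
) -> Dict[str, torch.Tensor]:
    """Split the subnet output into named parameter tensors along `dim`."""
    chunks = torch.split(raw, list(counts), dim=dim)
    return dict(zip(names, chunks))

# e.g. a (batch, 8) subnet output split into scale (4) and shift (4):
# split_parameters(torch.randn(16, 8), ["scale", "shift"], [4, 4], dim=1)
```

Making `dim` a constructor argument of `Coupling` would also cover image-shaped tensors, where the channel axis is the natural split axis.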
```python
class Coupling(Transform):
```
We should think about the final API we want to have. What goes here?

```python
inn = SequenceINN()
inn.append(...)
```
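For example, keeping the existing `SequenceINN.append(module_class, **kwargs)` convention, it might look like this (a sketch; `AffineTransform` and the keyword names are assumptions, and `subnet_fc` is a user-supplied factory in the style of the current examples):

```python
import torch.nn as nn

def subnet_fc(dims_in: int, dims_out: int) -> nn.Module:
    # simple fully-connected subnet factory
    return nn.Sequential(
        nn.Linear(dims_in, 128), nn.ReLU(),
        nn.Linear(128, dims_out),
    )

inn = SequenceINN(2)
for _ in range(4):
    inn.append(
        Coupling,
        split=EvenSplit(),
        transform=AffineTransform(),   # hypothetical transform name
        subnet_constructor=subnet_fc,
    )
```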
playground.py
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    x1, x2 = self.split.forward(x)
    parameters = self.get_parameters(x2)
    z1 = self.transform.forward(x1, **parameters)
```
log det J! `forward` also needs to compute and return the log-determinant of the Jacobian.
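Something along these lines (a sketch; it assumes the elementwise transform returns its log-Jacobian alongside the output, and `split.inverse` as the recombination step is hypothetical):

```python
from typing import Tuple
import torch

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    x1, x2 = self.split.forward(x)
    parameters = self.get_parameters(x2)
    # assumed: transform returns (output, log|det J|) per sample
    z1, log_det = self.transform.forward(x1, **parameters)
    z = self.split.inverse(z1, x2)  # hypothetical: recombine the two halves
    return z, log_det
```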
```python
parameters = self.get_parameters(x2)
z1 = self.transform.forward(x1, **parameters)
parameters = self.get_parameters(z1)
z2 = self.transform.forward(x2, **parameters)
```
Make this consistent with the one-vs-two coupling choice above.
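For reference, the single-coupling variant would transform only one half per block and pass the other through untouched (sketch, with the assumed log-det return from above):

```python
x1, x2 = self.split.forward(x)
parameters = self.get_parameters(x2)
z1, log_det = self.transform.forward(x1, **parameters)
z2 = x2  # passive half passes through unchanged, so the block stays invertible
```

Stacking blocks with alternating splits then gives every dimension a chance to be transformed.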
Is anyone still working on this repository? It looks like development has stalled.