
New unifying structure for coupling architectures #139

Draft: fdraxler wants to merge 4 commits into master

Conversation

fdraxler (Collaborator) commented:

No description provided.

fdraxler marked this pull request as draft on November 11, 2022 at 15:50

```python
class Positive(Parameter):
    def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor:
        return torch.exp(unconstrained)
```
fdraxler (Collaborator, Author) commented:

I think we should have something like SoftPositive that saturates at a certain value.
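A minimal sketch of that idea (the name SoftPositive, the `ceiling` argument, and the scaled sigmoid are assumptions, not code from this branch; the Parameter base class is the one from this PR):

```python
import torch

class SoftPositive(Parameter):
    """Sketch: positive output that saturates at `ceiling` instead of
    growing exponentially like Positive."""

    def __init__(self, ceiling: float = 10.0):
        self.ceiling = ceiling

    def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor:
        # Scaled sigmoid: output lies in (0, ceiling) and flattens out
        # for large inputs, keeping scales bounded.
        return self.ceiling * torch.sigmoid(unconstrained)
```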


```python
class Increasing(Parameter):
    def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor:
        # Anchor at the first entry, then add strictly positive increments.
        # Slicing with [:, :1] (not [:, 0]) keeps the channel dimension so
        # the sum broadcasts against the (batch, d - 1) cumsum result.
        return unconstrained[:, :1] + torch.cumsum(torch.exp(unconstrained[:, 1:]), dim=1)
```
fdraxler (Collaborator, Author) commented:

Same here
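Analogously, a saturating variant of Increasing might cap each increment (again only a sketch; the name SoftIncreasing and the `max_step` argument are assumptions):

```python
import torch

class SoftIncreasing(Parameter):
    """Sketch: increasing output whose per-step increments saturate at
    `max_step` instead of growing exponentially."""

    def __init__(self, max_step: float = 5.0):
        self.max_step = max_step

    def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor:
        # Bounded positive increments keep the sequence strictly
        # increasing without letting any single step explode.
        steps = self.max_step * torch.sigmoid(unconstrained[:, 1:])
        return unconstrained[:, :1] + torch.cumsum(steps, dim=1)
```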


```python
def parameterize(**parameters):
    def wrap(cls):
        def construct(*args, split=EvenSplit(), subnet_constructor, **kwargs):
```
fdraxler (Collaborator, Author) commented:

@wraps(cls) for future docstring
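Concretely, something like this sketch around the hunk above (assumes EvenSplit from this branch, and that the decorator returns construct in place of the class):

```python
from functools import wraps

def parameterize(**parameters):
    def wrap(cls):
        @wraps(cls)  # construct inherits cls.__name__ and cls.__doc__
        def construct(*args, split=EvenSplit(), subnet_constructor, **kwargs):
            ...
        return construct
    return wrap
```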

```python
self.split = split
self.transform = transform
# TODO: 2 subnets? or just singular coupling?
self.subnet = subnet
```
fdraxler (Collaborator, Author) commented:

I am in favor of just one, but maybe we should check the most recent literature.


```python
def get_parameters(self, condition: torch.Tensor) -> Dict[str, torch.Tensor]:
    parameters = self.subnet(condition)
    parameters = torch.split(parameters, self.parameter_counts, dim=1)
```
fdraxler (Collaborator, Author) commented:

How should we handle shapes? Always split at first non-batch dimension? Could also be a parameter to the Coupling.
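If it becomes a parameter of the Coupling, it could look like this (channel_dim and parameter_names are illustrative names, not in the diff):

```python
from typing import Dict
import torch

def get_parameters(self, condition: torch.Tensor) -> Dict[str, torch.Tensor]:
    raw = self.subnet(condition)
    # self.channel_dim defaults to 1 (first non-batch dimension) but can
    # be overridden per Coupling, e.g. for channels-last data.
    chunks = torch.split(raw, self.parameter_counts, dim=self.channel_dim)
    return dict(zip(self.parameter_names, chunks))
```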




```python
class Coupling(Transform):
```
fdraxler (Collaborator, Author) commented:

We should think about the final API we want to have. What goes here?

```python
inn = SequenceINN()
inn.append(...)
```
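For comparison, current FrEIA usage is roughly the following (recalled from the existing docs, not part of this branch):

```python
import torch.nn as nn
import FrEIA.framework as Ff
import FrEIA.modules as Fm

def subnet_fc(dims_in, dims_out):
    # Small fully-connected subnet producing the coupling parameters.
    return nn.Sequential(nn.Linear(dims_in, 128), nn.ReLU(),
                         nn.Linear(128, dims_out))

inn = Ff.SequenceINN(2)
inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc)
```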

playground.py (outdated):
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    x1, x2 = self.split.forward(x)
    parameters = self.get_parameters(x2)
    z1 = self.transform.forward(x1, **parameters)
```
fdraxler (Collaborator, Author) commented:

log det J!
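A sketch of the fix (the tuple return from transform.forward and the recombination call are assumptions about this branch's interfaces):

```python
from typing import Tuple
import torch

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    x1, x2 = self.split.forward(x)
    parameters = self.get_parameters(x2)
    # Assumed: transform.forward reports its log-Jacobian-determinant so
    # the coupling can pass it on (the untouched half contributes zero).
    z1, log_jac_det = self.transform.forward(x1, **parameters)
    return self.split.inverse(z1, x2), log_jac_det
```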

```python
parameters = self.get_parameters(x2)
z1 = self.transform.forward(x1, **parameters)
parameters = self.get_parameters(z1)
z2 = self.transform.forward(x2, **parameters)
```
fdraxler (Collaborator, Author) commented:

Make consistent with one/double coupling choice above.
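If the single-coupling option wins, the body above would presumably shrink to (sketch):

```python
parameters = self.get_parameters(x2)
z1 = self.transform.forward(x1, **parameters)
z2 = x2  # in a single coupling, the conditioning half passes through unchanged
```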

mjack3 commented on Nov 21, 2023:

Is anyone still working on this repository? It looks abandoned.
