
Commit

Hack fix event shape computation for SequenceINN in PushForwardDistribution

fdraxler committed Jan 20, 2023
1 parent 01f9699 commit 76a7d0a
Showing 4 changed files with 44 additions and 18 deletions.
6 changes: 4 additions & 2 deletions FrEIA/distributions/transformed.py
@@ -5,15 +5,17 @@

 from FrEIA.modules import InvertibleModule
 from FrEIA.modules.inverse import Inverse
-from FrEIA.utils import force_to
+from FrEIA.utils import force_to, output_dims_compatible


 class PushForwardDistribution(Distribution):
     arg_constraints = {}

     def __init__(self, base_distribution: Distribution,
                  transform: InvertibleModule):
-        super().__init__(torch.Size(), transform.output_dims(transform.dims_in)[0])
+        # Hack as SequenceINN and GraphINN do not work with input/output shape API
+        event_shape = output_dims_compatible(transform)
+        super().__init__(torch.Size(), event_shape)
         self.base_distribution = base_distribution
         self.transform = transform

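With this change a PushForwardDistribution can be built directly on top of a SequenceINN and report a correct event shape. The following is a minimal usage sketch, not part of the commit; the subnet constructor and the 2-dimensional setup are illustrative assumptions mirroring tests/test_pushed.py below.

import torch
from FrEIA.framework import SequenceINN
from FrEIA.modules import AllInOneBlock
from FrEIA.distributions import StandardNormalDistribution, PushForwardDistribution

# Illustrative subnet constructor; the one used by the test suite is not shown in this diff.
def subnet(dim_in, dim_out):
    return torch.nn.Sequential(
        torch.nn.Linear(dim_in, 64), torch.nn.ReLU(), torch.nn.Linear(64, dim_out)
    )

inn = SequenceINN(2)
inn.append(AllInOneBlock((inn.shapes[-1],), subnet_constructor=subnet))

latent = StandardNormalDistribution(2)
# The constructor now resolves the event shape via output_dims_compatible(inn) instead of
# calling inn.output_dims(inn.dims_in), which SequenceINN does not support.
push_forward = PushForwardDistribution(latent, inn)

samples = push_forward.sample((16,))               # base samples pushed through the INN
log_p = push_forward.log_prob(torch.randn(16, 2))  # per-sample log densities
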
13 changes: 2 additions & 11 deletions FrEIA/modules/inverse.py
@@ -3,6 +3,7 @@
 from torch import Tensor

 from FrEIA.modules import InvertibleModule
+from FrEIA.utils import output_dims_compatible


 class Inverse(InvertibleModule):
@@ -11,17 +12,7 @@ class Inverse(InvertibleModule):
     """
     def __init__(self, module: InvertibleModule):
         # Hack as SequenceINN and GraphINN do not work with input/output shape API
-        no_output_dims = (
-            hasattr(module, "force_tuple_output")
-            and not module.force_tuple_output
-        )
-        if not no_output_dims:
-            input_dims = module.output_dims(module.dims_in)
-        else:
-            try:
-                input_dims = module.output_dims(None)
-            except TypeError:
-                raise NotImplementedError(f"Can't determine output dimensions for {module.__class__}.")
+        input_dims = output_dims_compatible(module)
         super().__init__(input_dims, module.dims_c)
         self.module = module

18 changes: 18 additions & 0 deletions FrEIA/utils.py
@@ -3,6 +3,24 @@
 from typing import Callable, Any


+def output_dims_compatible(invertible_module):
+    """
+    Hack to get output dimensions from any module as
+    SequenceINN and GraphINN do not work with input/output shape API.
+    """
+    no_output_dims = (
+        hasattr(invertible_module, "force_tuple_output")
+        and not invertible_module.force_tuple_output
+    )
+    if not no_output_dims:
+        return invertible_module.output_dims(invertible_module.dims_in)
+    else:
+        try:
+            return invertible_module.output_dims(None)
+        except TypeError:
+            raise NotImplementedError(f"Can't determine output dimensions for {invertible_module.__class__}.")
+
+
 def f_except(f: Callable, x: torch.Tensor, *dim, **kwargs):
     """ Apply f on all dimensions except those specified in dim """
     result = x
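The new helper centralizes the shape hack that Inverse.__init__ previously inlined, so Inverse and PushForwardDistribution can query output dimensions through one code path. Below is a hedged sketch of how it behaves, under the same illustrative subnet assumption as above.

import torch
from FrEIA.framework import SequenceINN
from FrEIA.modules import AllInOneBlock
from FrEIA.utils import output_dims_compatible

def subnet(dim_in, dim_out):  # illustrative, not taken from the repository
    return torch.nn.Sequential(
        torch.nn.Linear(dim_in, 64), torch.nn.ReLU(), torch.nn.Linear(64, dim_out)
    )

# A plain InvertibleModule follows the usual shape API, so the helper simply
# forwards to block.output_dims(block.dims_in).
block = AllInOneBlock([(2,)], subnet_constructor=subnet)
print(output_dims_compatible(block))

# SequenceINN exposes a force_tuple_output flag; when it is False the helper takes
# the fallback branch and calls inn.output_dims(None) instead.
inn = SequenceINN(2)
inn.append(AllInOneBlock((inn.shapes[-1],), subnet_constructor=subnet))
print(output_dims_compatible(inn))
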
25 changes: 20 additions & 5 deletions tests/test_pushed.py
@@ -2,7 +2,7 @@

 import torch

-from FrEIA.distributions import StandardNormalDistribution, PullBackDistribution
+from FrEIA.distributions import StandardNormalDistribution, PullBackDistribution, PushForwardDistribution
 from FrEIA.framework import SequenceINN
 from FrEIA.modules import AllInOneBlock

@@ -15,22 +15,37 @@ def subnet(dim_in, dim_out):


 class PushedDistributionTest(unittest.TestCase):
-    def create_distribution(self):
+    def create_pull_back_distribution(self):
         inn = SequenceINN(2)
         inn.append(AllInOneBlock((inn.shapes[-1],), subnet_constructor=subnet))
         latent = StandardNormalDistribution(2)
         distribution = PullBackDistribution(latent, inn)
         return distribution

+    def create_push_forward_distribution(self):
+        inn = SequenceINN(2)
+        inn.append(AllInOneBlock((inn.shapes[-1],), subnet_constructor=subnet))
+        latent = StandardNormalDistribution(2)
+        distribution = PushForwardDistribution(latent, inn)
+        return distribution
+
     def test_log_prob(self):
-        self.create_distribution().log_prob(torch.randn(16, 2))
+        self.create_pull_back_distribution().log_prob(torch.randn(16, 2))
+        self.create_push_forward_distribution().log_prob(torch.randn(16, 2))

     def test_log_prob_shape_mismatch(self):
         with self.assertRaises(RuntimeError):
-            self.create_distribution().log_prob(torch.randn(16, 3))
+            self.create_pull_back_distribution().log_prob(torch.randn(16, 3))
+        with self.assertRaises(RuntimeError):
+            self.create_push_forward_distribution().log_prob(torch.randn(16, 3))

     def test_sample(self):
         batch_size = 16
-        sample = self.create_distribution().sample((batch_size,))
+
+        sample = self.create_pull_back_distribution().sample((batch_size,))
+        self.assertFalse(sample.requires_grad)
+        self.assertEqual(sample.shape, (batch_size, 2))
+
+        sample = self.create_push_forward_distribution().sample((batch_size,))
         self.assertFalse(sample.requires_grad)
         self.assertEqual(sample.shape, (batch_size, 2))
