From 76a7d0ac82b07851582e4b0731b2f325a31b4c1d Mon Sep 17 00:00:00 2001
From: Felix Draxle
Date: Fri, 20 Jan 2023 09:18:28 +0100
Subject: [PATCH] Hack fix event shape computation for SequenceINN in
 PushForwardDistribution

---
 FrEIA/distributions/transformed.py |  6 ++++--
 FrEIA/modules/inverse.py           | 13 ++-----------
 FrEIA/utils.py                     | 18 ++++++++++++++++++
 tests/test_pushed.py               | 25 ++++++++++++++++++++-----
 4 files changed, 44 insertions(+), 18 deletions(-)

diff --git a/FrEIA/distributions/transformed.py b/FrEIA/distributions/transformed.py
index 33ed0e2..7722ff5 100644
--- a/FrEIA/distributions/transformed.py
+++ b/FrEIA/distributions/transformed.py
@@ -5,7 +5,7 @@
 
 from FrEIA.modules import InvertibleModule
 from FrEIA.modules.inverse import Inverse
-from FrEIA.utils import force_to
+from FrEIA.utils import force_to, output_dims_compatible
 
 
 class PushForwardDistribution(Distribution):
@@ -13,7 +13,9 @@ class PushForwardDistribution(Distribution):
 
     def __init__(self, base_distribution: Distribution,
                  transform: InvertibleModule):
-        super().__init__(torch.Size(), transform.output_dims(transform.dims_in)[0])
+        # Hack as SequenceINN and GraphINN do not work with input/output shape API
+        event_shape = output_dims_compatible(transform)
+        super().__init__(torch.Size(), event_shape)
         self.base_distribution = base_distribution
         self.transform = transform
 
diff --git a/FrEIA/modules/inverse.py b/FrEIA/modules/inverse.py
index 15c92f4..74a53d1 100644
--- a/FrEIA/modules/inverse.py
+++ b/FrEIA/modules/inverse.py
@@ -3,6 +3,7 @@
 from torch import Tensor
 
 from FrEIA.modules import InvertibleModule
+from FrEIA.utils import output_dims_compatible
 
 
 class Inverse(InvertibleModule):
@@ -11,17 +12,7 @@ class Inverse(InvertibleModule):
     """
     def __init__(self, module: InvertibleModule):
         # Hack as SequenceINN and GraphINN do not work with input/output shape API
-        no_output_dims = (
-            hasattr(module, "force_tuple_output")
-            and not module.force_tuple_output
-        )
-        if not no_output_dims:
-            input_dims = module.output_dims(module.dims_in)
-        else:
-            try:
-                input_dims = module.output_dims(None)
-            except TypeError:
-                raise NotImplementedError(f"Can't determine output dimensions for {module.__class__}.")
+        input_dims = output_dims_compatible(module)
         super().__init__(input_dims, module.dims_c)
         self.module = module
 
diff --git a/FrEIA/utils.py b/FrEIA/utils.py
index fa71477..f57b8ae 100644
--- a/FrEIA/utils.py
+++ b/FrEIA/utils.py
@@ -3,6 +3,24 @@
 from typing import Callable, Any
 
 
+def output_dims_compatible(invertible_module):
+    """
+    Hack to get output dimensions from any module as
+    SequenceINN and GraphINN do not work with input/output shape API.
+    """
+    no_output_dims = (
+        hasattr(invertible_module, "force_tuple_output")
+        and not invertible_module.force_tuple_output
+    )
+    if not no_output_dims:
+        return invertible_module.output_dims(invertible_module.dims_in)
+    else:
+        try:
+            return invertible_module.output_dims(None)
+        except TypeError:
+            raise NotImplementedError(f"Can't determine output dimensions for {invertible_module.__class__}.")
+
+
 def f_except(f: Callable, x: torch.Tensor, *dim, **kwargs):
     """ Apply f on all dimensions except those specified in dim """
     result = x
diff --git a/tests/test_pushed.py b/tests/test_pushed.py
index f877a52..941bb32 100644
--- a/tests/test_pushed.py
+++ b/tests/test_pushed.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from FrEIA.distributions import StandardNormalDistribution, PullBackDistribution
+from FrEIA.distributions import StandardNormalDistribution, PullBackDistribution, PushForwardDistribution
 from FrEIA.framework import SequenceINN
 from FrEIA.modules import AllInOneBlock
 
@@ -15,22 +15,37 @@ def subnet(dim_in, dim_out):
 
 
 class PushedDistributionTest(unittest.TestCase):
-    def create_distribution(self):
+    def create_pull_back_distribution(self):
         inn = SequenceINN(2)
         inn.append(AllInOneBlock((inn.shapes[-1],), subnet_constructor=subnet))
         latent = StandardNormalDistribution(2)
         distribution = PullBackDistribution(latent, inn)
         return distribution
 
+    def create_push_forward_distribution(self):
+        inn = SequenceINN(2)
+        inn.append(AllInOneBlock((inn.shapes[-1],), subnet_constructor=subnet))
+        latent = StandardNormalDistribution(2)
+        distribution = PushForwardDistribution(latent, inn)
+        return distribution
+
     def test_log_prob(self):
-        self.create_distribution().log_prob(torch.randn(16, 2))
+        self.create_pull_back_distribution().log_prob(torch.randn(16, 2))
+        self.create_push_forward_distribution().log_prob(torch.randn(16, 2))
 
     def test_log_prob_shape_mismatch(self):
         with self.assertRaises(RuntimeError):
-            self.create_distribution().log_prob(torch.randn(16, 3))
+            self.create_pull_back_distribution().log_prob(torch.randn(16, 3))
+        with self.assertRaises(RuntimeError):
+            self.create_push_forward_distribution().log_prob(torch.randn(16, 3))
 
     def test_sample(self):
         batch_size = 16
-        sample = self.create_distribution().sample((batch_size,))
+
+        sample = self.create_pull_back_distribution().sample((batch_size,))
+        self.assertFalse(sample.requires_grad)
+        self.assertEqual(sample.shape, (batch_size, 2))
+
+        sample = self.create_push_forward_distribution().sample((batch_size,))
         self.assertFalse(sample.requires_grad)
         self.assertEqual(sample.shape, (batch_size, 2))