Skip to content

Commit

Permalink
Resolved issues with transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
Sayam753 committed Jul 8, 2020
1 parent 1e037de commit 2aa6347
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 25 deletions.
7 changes: 7 additions & 0 deletions pymc4/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
UnitContinuousDistribution,
BoundedContinuousDistribution,
)
from pymc4.distributions import transforms
from .half_student_t import HalfStudentT as TFPHalfStudentT


Expand Down Expand Up @@ -1080,6 +1081,12 @@ class Pareto(BoundedContinuousDistribution):
Scale parameter (scale > 0).
"""

def _init_transform(self, transform):
if transform is None:
return transforms.LowerBound(self.lower_limit())
else:
return transform

def __init__(self, name, concentration, scale, **kwargs):
super().__init__(name, concentration=concentration, scale=scale, **kwargs)

Expand Down
21 changes: 14 additions & 7 deletions pymc4/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
PositiveDiscreteDistribution,
BoundedDiscreteDistribution,
)
from pymc4.distributions import transforms

__all__ = [
"Bernoulli",
Expand Down Expand Up @@ -67,10 +68,10 @@ def _init_distribution(conditions, **kwargs):
return tfd.Bernoulli(probs=probs, **kwargs)

def lower_limit(self):
return 0
return 0.0

def upper_limit(self):
return 1
return 1.0


class Binomial(BoundedDiscreteDistribution):
Expand Down Expand Up @@ -124,7 +125,7 @@ def _init_distribution(conditions, **kwargs):
return tfd.Binomial(total_count=total_count, probs=probs, **kwargs)

def lower_limit(self):
return 0
return 0.0

def upper_limit(self):
return self._distribution.total_count
Expand Down Expand Up @@ -200,7 +201,7 @@ def _init_distribution(conditions, **kwargs):
)

def lower_limit(self):
return 0
return 0.0

def upper_limit(self):
return self._distribution.total_count
Expand Down Expand Up @@ -306,7 +307,7 @@ def _init_distribution(conditions, **kwargs):
return tfd.FiniteDiscrete(outcomes, probs=probs, **kwargs)

def lower_limit(self):
return 0
return 0.0

def upper_limit(self):
return len(self._distribution.probs)
Expand Down Expand Up @@ -353,6 +354,12 @@ class Geometric(BoundedDiscreteDistribution):
# Another example for a wrong type used on the tensorflow side
_test_value = 2.0 # type: ignore

def _init_transform(self, transform):
if transform is None:
return transforms.LowerBound(self.lower_limit())
else:
return transform

def __init__(self, name, probs, **kwargs):
super().__init__(name, probs=probs, **kwargs)

Expand All @@ -362,7 +369,7 @@ def _init_distribution(conditions, **kwargs):
return tfd.Geometric(probs=probs, **kwargs)

def lower_limit(self):
return 1
return 1.0

def upper_limit(self):
return float("inf")
Expand Down Expand Up @@ -760,7 +767,7 @@ def _init_distribution(conditions, **kwargs):
return tfd.OrderedLogistic(cutpoints=cutpoints, loc=loc, **kwargs)

def lower_limit(self):
return 0
return 0.0

def upper_limit(self):
return len(self._distribution.cutpoints)
12 changes: 12 additions & 0 deletions pymc4/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,12 @@ class DiscreteDistribution(Distribution):


class BoundedDistribution(Distribution):
def _init_transform(self, transform):
if transform is None:
return transforms.Interval(self.lower_limit(), self.upper_limit())
else:
return transform

@abc.abstractmethod
def lower_limit(self):
raise NotImplementedError
Expand Down Expand Up @@ -331,6 +337,12 @@ def upper_limit(self):
class PositiveDiscreteDistribution(BoundedDiscreteDistribution):
_test_value = 1

def _init_transform(self, transform):
if transform is None:
return transforms.Log()
else:
return transform

def lower_limit(self):
return 0

Expand Down
61 changes: 44 additions & 17 deletions pymc4/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from tensorflow_probability import bijectors as tfb

__all__ = ["Log", "Sigmoid"]
__all__ = ["Log", "Sigmoid", "LowerBound", "UpperBound", "Interval"]


class JacobianPreference(enum.Enum):
Expand Down Expand Up @@ -113,13 +113,13 @@ def inverse_log_det_jacobian(self, z):
return self._transform.forward_log_det_jacobian(z)


class Log(Transform):
name = "log"
class BackwardTransform(Transform):
"""Base class for Transforms with Jacobian Preference as Backward"""

JacobianPreference = JacobianPreference.Backward

def __init__(self):
# NOTE: We actually need the inverse to match PyMC3, do we?
self._transform = tfb.Exp()
def __init__(self, transform):
self._transform = transform

def forward(self, x):
return self._transform.inverse(x)
Expand All @@ -134,21 +134,48 @@ def inverse_log_det_jacobian(self, z):
return self._transform.forward_log_det_jacobian(z, self._transform.forward_min_event_ndims)


class Sigmoid(Transform):
class Log(BackwardTransform):
name = "log"

def __init__(self):
# NOTE: We actually need the inverse to match PyMC3, do we?
transform = tfb.Exp()
super().__init__(transform)


class Sigmoid(BackwardTransform):
name = "sigmoid"
JacobianPreference = JacobianPreference.Backward

def __init__(self):
self._transform = tfb.Sigmoid()
transform = tfb.Sigmoid()
super().__init__(transform)

def forward(self, x):
return self._transform.inverse(x)

def inverse(self, z):
return self._transform.forward(z)
class LowerBound(BackwardTransform):
""""Transformation to interval [lower_limit, inf]"""

def forward_log_det_jacobian(self, x):
return self._transform.inverse_log_det_jacobian(x, self._transform.inverse_min_event_ndims)
name = "lowerbound"

def inverse_log_det_jacobian(self, z):
return self._transform.forward_log_det_jacobian(z, self._transform.forward_min_event_ndims)
def __init__(self, lower_limit):
transform = tfb.Chain([tfb.Shift(lower_limit), tfb.Exp()])
super().__init__(transform)


class UpperBound(BackwardTransform):
""""Transformation to interval [-inf, upper_limit]"""

name = "upperbound"

def __init__(self, upper_limit):
transform = tfb.Chain([tfb.Shift(upper_limit), tfb.Scale(-1), tfb.Exp()])
super().__init__(transform)


class Interval(BackwardTransform):
""""Transformation to interval [lower_limit, upper_limit]"""

name = "interval"

def __init__(self, lower_limit, upper_limit):
transform = tfb.Sigmoid(low=lower_limit, high=upper_limit)
super().__init__(transform)
2 changes: 1 addition & 1 deletion pymc4/flow/transformed_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def make_transformed_model(dist, transform, state):

# Important:
# I have no idea yet, how to make that beautiful.
# Here we indicate the distribution is already autotransformed nto to get in the infinite loop
# Here we indicate the distribution is already autotransformed not to get in the infinite loop
dist.model_info["autotransformed"] = True

# 2. here decide on logdet computation, this might be effective
Expand Down

0 comments on commit 2aa6347

Please sign in to comment.