diff --git a/tensorflow_probability/python/distributions/inflated.py b/tensorflow_probability/python/distributions/inflated.py index c10d8d9163..68a5e04e8c 100644 --- a/tensorflow_probability/python/distributions/inflated.py +++ b/tensorflow_probability/python/distributions/inflated.py @@ -224,12 +224,16 @@ def inflated_factory(default_name, distribution_class, inflated_loc, def my_init(self, inflated_loc_logits=None, inflated_loc_probs=None, name=default_name, **kwargs): + parameters = dict(locals()) if 'distribution' in kwargs: dist = kwargs['distribution'] else: dist = distribution_class(**{**kwargs, **more_kwargs}) Inflated.__init__(self, dist, inflated_loc_logits, inflated_loc_probs, inflated_loc, name=name) + # pylint: disable=protected-access + self._parameters = {**parameters, **more_kwargs} + # pylint: enable=protected-access def my_parameter_properties(unused_cls, dtype, num_classes=None): return dict(