Skip to content

Commit

Permalink
Gluon.probability (apache#18403)
Browse files Browse the repository at this point in the history
* package created

* mvn WIP

* normal wip, to be tested

* update

* docstring added, normal mostly done

* add test file

* Bernoulli WIP

* bernoulli wip

* bernoulli doc done

* dense variational WIP

* add kl infra

* implement normal kl method

* refactor kl

* add not implemented handling, rename kl_storage

* add  abstract method and Categorical class

* rewrite logit2prob prob2logit for multiclass support

* normal broadcast_to implemented

* categorical mostly done

* update distributions/utils.py

* add dot ahead of import

* fix normal F

* bernoulli, normal brief tests implemented

* add hybridize tests

* transformation infras done

* affine transformation, implemented tested

* add tests cases

* add sum_right_most

* fix get F bug

* compose transform implemented, tested

* fix

* add event_dim

* fetch mvn from upstremm

* clean code, implement normal cdf and tests

* constraint in bernoulli done

* fix constraint

* finish half normal

* add cached_property

* add test on cached_property

* add more features to distribution and constratins

* change constraint

* fix bernoulli

* add independent

* add independent tests

* update naming of cached_property

* revert

* add constraints

* add Cat

* add Stack for imperative mode

* add Stack for imperative mode

* add bernoulli entropy

* categorical WIP

* categorical sampling implemented

* finish categorical log_prob, sampling

* enumerate_support finished

* polish StochasticBlock, add test

* add test for stochastic sequential

* clean loss list in __call__

* fix affine, implement sigmoid, softmax

* add gumbel, relaxed bernoulli

* relaxed one-hot sampling implemented

* gamma done

* gamma, dirichlet implemented

* beta done

* gumbel softmax log-likelihood implemented

* refactor tests, implement exponential, fix compose transform

* weibull implemented, transformed distribution cdf icdf added

* pareto implemented

* uniform wip

* uniform done

* rewrite lgamma, implement chi2

* fix chi2 scale

* F distributiion done

* t implemented

* fix tiny problem

* cauchy done

* add half cauchy

* multinomial done, tests to be added

* add multinomial test

* MVN done, tests todo

* mvn polished

* fix a few precison issues

* add erf, erfinv unified api and learnable transform

* fix mvn attribute check

* MVN done

* poisson done

* hack poisson for size support

* geometric finished

* negative binomial done

* binomial done

* implement some kl

* add more kl

* refactor kl test

* add more kl

* binomial kl todo

* change constraint logical op implement

* implement gamma entropy

* finish beta dirchlet entropy

* finishi all entropy

* kl finished

* add constraint test

* domain map done

* remove bayesian dense

* fix tiny problems

* add kl uniform normal

* add kl tests

* acquire patch from upstream

* add some doc

* finish doc

* refactor kl test(WIP)

* add more kl, fix float32 underflow issue

* make sampling more stable

* handle inconsistent mode

* replace boolean idx with np.where

* fix file name

* add more doc

* add constraint check

* add half_normal/cauchy pdf cdf support check

* fix import problem

* change nosetest to pytest

* remove buggy lines

* change alias register path

* attempt to fix ci

* fix lint, change a few tests

* fix lint

* modify hybrid sequential

* fix lint

* change import order

* add test gluon probability v2

* fix hybridize flag

* change implementation of stochastic block

* fix lint

* fix comments

* fix block

* modify domain map

* add raises for improper add_loss

* add raises for improper add_loss

* add extra cases

* change collectLoss decorator to mandatory

* skip stochastic block tests

* remove test cases

* put gpu tests back

* add test_gluon_stochastic_block back

* remove export test

* put a test back

* tiny refactor

* add memory leak flag

* small changes

Co-authored-by: Zheng <[email protected]>
  • Loading branch information
xidulu and Zheng authored Jul 7, 2020
1 parent 54c0155 commit b4b8b80
Show file tree
Hide file tree
Showing 49 changed files with 10,320 additions and 11 deletions.
2 changes: 2 additions & 0 deletions python/mxnet/gluon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,5 @@
from . import model_zoo

from . import contrib

from . import probability
26 changes: 26 additions & 0 deletions python/mxnet/gluon/probability/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import
"""Probability module"""

from .block import *

from .distributions import *

from .transformation import *
22 changes: 22 additions & 0 deletions python/mxnet/gluon/probability/block/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import
"""Stochastic block."""

from .stochastic_block import *
134 changes: 134 additions & 0 deletions python/mxnet/gluon/probability/block/stochastic_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=abstract-method
"""Stochastic block class."""
__all__ = ['StochasticBlock', 'StochasticSequential']

from functools import wraps
from ...block import HybridBlock
from ...utils import _indent


class StochasticBlock(HybridBlock):
"""`StochasticBlock` extends `HybridBlock` to support accumulating loss
in the forward phase, which is extremely useful in building Bayesian Neural Network,
where the loss function is composed of a classification loss and a KL loss.
"""

def __init__(self, **kwargs):
super(StochasticBlock, self).__init__(**kwargs)
self._losses = []
self._losscache = []
# Recording whether collectLoss is invoked.
self._flag = False

def add_loss(self, loss):
self._losscache.append(loss)

@staticmethod
def collectLoss(func):
"""To accumulate loss during the forward phase, one could first decorate
hybrid_forward with `StochasticBlock.collectLoss,
and then collect the loss tensor `x` by calling self.add_loss(x).
For example, in the following forward function,
we generate samples from a Gaussian parameterized by `loc` and `scale` and
accumulate the KL-divergence between it and its prior into the block's loss storage.:
@StochasticBlock.collectLoss
def forward(self, loc, scale):
qz = mgp.Normal(loc, scale)
# prior
pz = mgp.Normal(np.zeros_like(loc), np.ones_like(scale))
self.add_loss(mgp.kl_divergence(qz, pz))
return qz.sample()
"""
@wraps(func)
def inner(self, *args, **kwargs):
# Loss from hybrid_forward
func_out = func(self, *args, **kwargs)
collected_loss = self._losscache
self._losscache = []
self._flag = True
return (func_out, collected_loss)

return inner

def __call__(self, *args, **kwargs):
# pylint: disable=arguments-differ
self._flag = False
out = super().__call__(*args, **kwargs)
if not self._flag:
raise ValueError("The forward function should be decorated by " +
"StochasticBlock.collectLoss")
self._losses = out[1]
return out[0]

@property
def losses(self):
return self._losses


class StochasticSequential(StochasticBlock):
"""Stack StochasticBlock sequentially.
"""

def __init__(self, **kwargs):
super(StochasticSequential, self).__init__(**kwargs)
self._layers = []

def add(self, *blocks):
"""Adds block on top of the stack."""
for block in blocks:
self._layers.append(block)
self.register_child(block)

@StochasticBlock.collectLoss
def forward(self, x, *args):
# pylint: disable=arguments-differ
for block in self._children.values():
x = block()(x, *args)
args = []
if isinstance(x, (tuple, list)):
args = x[1:]
x = x[0]
if args:
x = tuple([x] + list(args))
for block in self._layers:
if hasattr(block, '_losses'):
self.add_loss(block._losses)
return x

def __repr__(self):
s = '{name}(\n{modstr}\n)'
modstr = '\n'.join([' ({key}): {block}'.format(key=key,
block=_indent(block().__repr__(), 2))
for key, block in self._children.items()])
return s.format(name=self.__class__.__name__, modstr=modstr)

def __getitem__(self, key):
layers = list(self._children.values())[key]
if isinstance(layers, list):
net = type(self)()
net.add(*(l() for l in layers))
return net
else:
return layers()

def __len__(self):
return len(self._children)
86 changes: 86 additions & 0 deletions python/mxnet/gluon/probability/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import
"""Distribution classes."""

from .distribution import *

from .exp_family import *

from .exponential import *

from .weibull import *

from .pareto import *

from .uniform import *

from .normal import *

from .laplace import *

from .cauchy import *

from .half_cauchy import *

from .poisson import *

from .geometric import *

from .negative_binomial import *

from .gamma import *

from .dirichlet import *

from .beta import *

from .chi2 import *

from .fishersnedecor import *

from .studentT import *

from .half_normal import *

from .independent import *

from .bernoulli import *

from .binomial import *

from .relaxed_bernoulli import *

from .gumbel import *

from .categorical import *

from .one_hot_categorical import *

from .relaxed_one_hot_categorical import *

from .multinomial import *

from .multivariate_normal import *

from .transformed_distribution import *

from .divergence import *

from .utils import *
Loading

0 comments on commit b4b8b80

Please sign in to comment.