Skip to content

Commit

Permalink
Add ImputationWarning class.
Browse files Browse the repository at this point in the history
The idea is that a programmer be able to ignore imputation warnings if they know that data is being imputed. It's easier to do this with a distinct class than with just UserWarning.
  • Loading branch information
rpgoldman committed Oct 29, 2019
1 parent 15eb75e commit ee4bd61
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
17 changes: 15 additions & 2 deletions pymc3/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
__all__ = ['SamplingError', 'IncorrectArgumentsError', 'TraceDirectoryError']
__all__ = [
"SamplingError",
"IncorrectArgumentsError",
"TraceDirectoryError",
"ImputationWarning",
]


class SamplingError(RuntimeError):
Expand All @@ -8,6 +13,14 @@ class SamplingError(RuntimeError):
class IncorrectArgumentsError(ValueError):
pass


class TraceDirectoryError(ValueError):
'''Error from trying to load a trace from an incorrectly-structured directory,'''
"""Error from trying to load a trace from an incorrectly-structured directory,"""

pass


class ImputationWarning(UserWarning):
"""Warning that there are missing values that will be imputed."""

pass
3 changes: 2 additions & 1 deletion pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .vartypes import typefilter, discrete_types, continuous_types, isgenerator
from .blocking import DictToArrayBijection, ArrayOrdering
from .util import get_transformed_name
from .exceptions import ImputationWarning

__all__ = [
'Model', 'Factor', 'compilef', 'fn', 'fastfn', 'modelcontext',
Expand Down Expand Up @@ -1341,7 +1342,7 @@ def as_tensor(data, name, model, distribution):
impute_message = ('Data in {name} contains missing values and'
' will be automatically imputed from the'
' sampling distribution.'.format(name=name))
warnings.warn(impute_message, UserWarning)
warnings.warn(impute_message, ImputationWarning)
from .distributions import NoDistribution
testval = np.broadcast_to(distribution.default(), data.shape)[data.mask]
fakedist = NoDistribution.dist(shape=data.mask.sum(), dtype=dtype,
Expand Down

0 comments on commit ee4bd61

Please sign in to comment.