diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index f41c63c978..3b3fc4737c 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -48,7 +48,7 @@ from pytensor.graph.op import HasInnerGraph from pytensor.raise_op import CheckAndRaise from pytensor.scalar.basic import Mul -from pytensor.tensor.basic import get_underlying_scalar_constant_value +from pytensor.tensor.basic import AllocEmpty, get_underlying_scalar_constant_value from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.random.op import RandomVariable @@ -244,7 +244,7 @@ class DiracDelta(MeasurableOp, Op): __props__ = ("rtol", "atol") - def __init__(self, rtol=1e-5, atol=1e-8): + def __init__(self, rtol, atol): self.rtol = rtol self.atol = atol @@ -267,15 +267,25 @@ def infer_shape(self, fgraph, node, input_shapes): return input_shapes -dirac_delta = DiracDelta() +def dirac_delta(x, rtol=1e-5, atol=1e-8): + return DiracDelta(rtol, atol)(x) @_logprob.register(DiracDelta) -def diracdelta_logprob(op, values, *inputs, **kwargs): - (values,) = values - (const_value,) = inputs - values, const_value = pt.broadcast_arrays(values, const_value) - return pt.switch(pt.isclose(values, const_value, rtol=op.rtol, atol=op.atol), 0.0, -np.inf) +def diracdelta_logprob(op, values, const_value, **kwargs): + [value] = values + + if const_value.owner and isinstance(const_value.owner.op, AllocEmpty): + # Any value is considered valid for an AllocEmpty array + return pt.zeros_like(value) + + if op.rtol == 0 and op.atol == 0: + # Strict equality, cheaper logp + match = pt.eq(value, const_value) + else: + # Loose equality, more expensive logp + match = pt.isclose(values, const_value, rtol=op.rtol, atol=op.atol) + return pt.switch(match, 0.0, -np.inf) def find_negated_var(var):