Skip to content

Commit

Permalink
Optimize DiracDelta logprob for exact equality
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 29, 2024
1 parent d693156 commit c7840ae
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from pytensor.graph.op import HasInnerGraph
from pytensor.raise_op import CheckAndRaise
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.basic import AllocEmpty, get_underlying_scalar_constant_value
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.random.op import RandomVariable
Expand Down Expand Up @@ -244,7 +244,7 @@ class DiracDelta(MeasurableOp, Op):

__props__ = ("rtol", "atol")

def __init__(self, rtol=1e-5, atol=1e-8):
def __init__(self, rtol, atol):
self.rtol = rtol
self.atol = atol

Expand All @@ -267,15 +267,25 @@ def infer_shape(self, fgraph, node, input_shapes):
return input_shapes


dirac_delta = DiracDelta()
def dirac_delta(x, rtol=1e-5, atol=1e-8):
return DiracDelta(rtol, atol)(x)


@_logprob.register(DiracDelta)
def diracdelta_logprob(op, values, *inputs, **kwargs):
(values,) = values
(const_value,) = inputs
values, const_value = pt.broadcast_arrays(values, const_value)
return pt.switch(pt.isclose(values, const_value, rtol=op.rtol, atol=op.atol), 0.0, -np.inf)
def diracdelta_logprob(op, values, const_value, **kwargs):
[value] = values

if const_value.owner and isinstance(const_value.owner.op, AllocEmpty):
# Any value is considered valid for an AllocEmpty array
return pt.zeros_like(value)

if op.rtol == 0 and op.atol == 0:
# Strict equality, cheaper logp
match = pt.eq(value, const_value)
else:
# Loose equality, more expensive logp
match = pt.isclose(values, const_value, rtol=op.rtol, atol=op.atol)
return pt.switch(match, 0.0, -np.inf)


def find_negated_var(var):
Expand Down

0 comments on commit c7840ae

Please sign in to comment.