Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Truncation of CustomDist #6947

Merged
merged 3 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 26 additions & 28 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from pymc.printing import str_for_dist
from pymc.pytensorf import (
collect_default_updates,
collect_default_updates_inner_fgraph,
constant_fold,
convert_observed_data,
floatX,
Expand Down Expand Up @@ -298,16 +299,17 @@ def __init__(
raise ValueError("ndim_supp or gufunc_signature must be provided")

kwargs.setdefault("inline", True)
kwargs.setdefault("strict", True)
super().__init__(*args, **kwargs)

def update(self, node: Node):
def update(self, node: Node) -> dict[Variable, Variable]:
"""Symbolic update expression for input random state variables

Returns a dictionary with the symbolic expressions required for correct updating
of random state input variables across repeated function evaluations. This is used by
`pytensorf.compile_pymc`.
"""
return {}
return collect_default_updates_inner_fgraph(node)
jessegrabowski marked this conversation as resolved.
Show resolved Hide resolved

def batch_ndim(self, node: Node) -> int:
"""Number of dimensions of the distribution's batch shape."""
Expand Down Expand Up @@ -701,24 +703,10 @@ class CustomSymbolicDistRV(SymbolicRandomVariable):
symbolic random methods.
"""

default_output = -1
default_output = 0

_print_name = ("CustomSymbolicDist", "\\operatorname{CustomSymbolicDist}")

def update(self, node: Node):
op = node.op
inner_updates = collect_default_updates(
inputs=op.inner_inputs, outputs=op.inner_outputs, must_be_shared=False
)

# Map inner updates to outer inputs/outputs
updates = {}
for rng, update in inner_updates.items():
inp_idx = op.inner_inputs.index(rng)
out_idx = op.inner_outputs.index(update)
updates[node.inputs[inp_idx]] = node.outputs[out_idx]
return updates


@_support_point.register(CustomSymbolicDistRV)
def dist_support_point(op, rv, *args):
Expand Down Expand Up @@ -818,14 +806,17 @@ def rv_op(
if logp is not None:

@_logprob.register(rv_type)
def custom_dist_logp(op, values, size, *params, **kwargs):
return logp(values[0], *params[: len(dist_params)])
def custom_dist_logp(op, values, size, *inputs, **kwargs):
[value] = values
rv_params = inputs[: len(dist_params)]
return logp(value, *rv_params)

if logcdf is not None:

@_logcdf.register(rv_type)
def custom_dist_logcdf(op, value, size, *params, **kwargs):
return logcdf(value, *params[: len(dist_params)])
def custom_dist_logcdf(op, value, size, *inputs, **kwargs):
rv_params = inputs[: len(dist_params)]
return logcdf(value, *rv_params)

if support_point is not None:

Expand Down Expand Up @@ -858,22 +849,29 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
dummy_dist_params = [dist_param.type() for dist_param in old_dist_params]
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
dummy_params = [dummy_size_param, *dummy_dist_params]
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
rngs = updates_dict.keys()
rngs_updates = updates_dict.values()
new_rv_op = rv_type(
inputs=dummy_params,
outputs=[*dummy_updates_dict.values(), dummy_rv],
inputs=[*dummy_params, *rngs],
outputs=[dummy_rv, *rngs_updates],
signature=signature,
)
new_rv = new_rv_op(new_size, *dist_params)
new_rv = new_rv_op(new_size, *dist_params, *rngs)

return new_rv

# RNGs are not passed as explicit inputs (because we usually don't know how many are needed)
# We retrieve them here
updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
rngs = updates_dict.keys()
rngs_updates = updates_dict.values()
rv_op = rv_type(
inputs=dummy_params,
outputs=[*dummy_updates_dict.values(), dummy_rv],
inputs=[*dummy_params, *rngs],
outputs=[dummy_rv, *rngs_updates],
signature=signature,
)
return rv_op(size, *dist_params)
return rv_op(size, *dist_params, *rngs)

@staticmethod
def _infer_final_signature(signature: str, n_inputs, n_updates) -> str:
Expand Down
20 changes: 8 additions & 12 deletions pymc/distributions/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,6 @@ def __init__(self, *args, ar_order, constant_term, **kwargs):

def update(self, node: Node):
"""Return the update mapping for the noise RV."""
# Since noise is a shared variable it shows up as the last node input
return {node.inputs[-1]: node.outputs[0]}


Expand Down Expand Up @@ -658,13 +657,13 @@ def step(*args):
ar_ = pt.concatenate([init_, innov_.T], axis=-1)

ar_op = AutoRegressiveRV(
inputs=[rhos_, sigma_, init_, steps_],
inputs=[rhos_, sigma_, init_, steps_, noise_rng],
outputs=[noise_next_rng, ar_],
ar_order=ar_order,
constant_term=constant_term,
)

ar = ar_op(rhos, sigma, init_dist, steps)
ar = ar_op(rhos, sigma, init_dist, steps, noise_rng)
return ar


Expand Down Expand Up @@ -731,7 +730,6 @@ class GARCH11RV(SymbolicRandomVariable):

def update(self, node: Node):
"""Return the update mapping for the noise RV."""
# Since noise is a shared variable it shows up as the last node input
return {node.inputs[-1]: node.outputs[0]}


Expand Down Expand Up @@ -797,7 +795,6 @@ def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None)
# In this case the size of the init_dist depends on the parameters shape
batch_size = pt.broadcast_shape(omega, alpha_1, beta_1, initial_vol)
init_dist = change_dist_size(init_dist, batch_size)
# initial_vol = initial_vol * pt.ones(batch_size)

# Create OpFromGraph representing random draws from GARCH11 process
# Variables with underscore suffix are dummy inputs into the OpFromGraph
Expand All @@ -819,7 +816,7 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):

(y_t, _), innov_updates_ = pytensor.scan(
fn=step,
outputs_info=[init_, initial_vol_ * pt.ones(batch_size)],
outputs_info=[init_, pt.broadcast_to(initial_vol_.astype("floatX"), init_.shape)],
non_sequences=[omega_, alpha_1_, beta_1_, noise_rng],
n_steps=steps_,
strict=True,
Expand All @@ -831,11 +828,11 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
)

garch11_op = GARCH11RV(
inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_],
inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_, noise_rng],
outputs=[noise_next_rng, garch11_],
)

garch11 = garch11_op(omega, alpha_1, beta_1, initial_vol, init_dist, steps)
garch11 = garch11_op(omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng)
return garch11


Expand Down Expand Up @@ -891,14 +888,13 @@ class EulerMaruyamaRV(SymbolicRandomVariable):
ndim_supp = 1
_print_name = ("EulerMaruyama", "\\operatorname{EulerMaruyama}")

def __init__(self, *args, dt, sde_fn, **kwargs):
def __init__(self, *args, dt: float, sde_fn: Callable, **kwargs):
self.dt = dt
self.sde_fn = sde_fn
super().__init__(*args, **kwargs)

def update(self, node: Node):
"""Return the update mapping for the noise RV."""
# Since noise is a shared variable it shows up as the last node input
return {node.inputs[-1]: node.outputs[0]}


Expand Down Expand Up @@ -1010,14 +1006,14 @@ def step(*prev_args):
)

eulermaruyama_op = EulerMaruyamaRV(
inputs=[init_, steps_, *sde_pars_],
inputs=[init_, steps_, *sde_pars_, noise_rng],
outputs=[noise_next_rng, sde_out_],
dt=dt,
sde_fn=sde_fn,
signature=f"(),(s),{','.join('()' for _ in sde_pars_)}->(),(t)",
)

eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars)
eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars, noise_rng)
return eulermaruyama


Expand Down
Loading
Loading