Skip to content

Commit

Permalink
Make AR steps extend shape beyond initial_dist
Browse files Browse the repository at this point in the history
This is consistent with the meaning of steps in the GaussianRandomWalk and translates directly to the number of scan steps taken
  • Loading branch information
ricardoV94 committed May 5, 2022
1 parent 614bb06 commit 504da82
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 30 deletions.
72 changes: 45 additions & 27 deletions pymc/distributions/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,14 +393,19 @@ class AR(SymbolicDistribution):
"""

def __new__(cls, *args, steps=None, **kwargs):
def __new__(cls, name, rho, *args, steps=None, constant=False, ar_order=None, **kwargs):
rhos = at.atleast_1d(at.as_tensor_variable(floatX(rho)))
ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order)
steps = get_steps(
steps=steps,
shape=None, # Shape will be checked in `cls.dist`
dims=kwargs.get("dims", None),
observed=kwargs.get("observed", None),
step_shape_offset=ar_order,
)
return super().__new__(
cls, name, rhos, *args, steps=steps, constant=constant, ar_order=ar_order, **kwargs
)
return super().__new__(cls, *args, steps=steps, **kwargs)

@classmethod
def dist(
Expand All @@ -426,34 +431,12 @@ def dist(
)
init_dist = kwargs["init"]

steps = get_steps(steps=steps, shape=kwargs.get("shape", None))
ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order)
steps = get_steps(steps=steps, shape=kwargs.get("shape", None), step_shape_offset=ar_order)
if steps is None:
raise ValueError("Must specify steps or shape parameter")
steps = at.as_tensor_variable(intX(steps), ndim=0)

if ar_order is None:
# If ar_order is not specified we do constant folding on the shape of rhos
# to retrieve it. For example, this will detect that
# Normal(size=(5, 3)).shape[-1] == 3, which is not known by Aesara before.
shape_fg = FunctionGraph(
outputs=[rhos.shape[-1]],
features=[ShapeFeature()],
clone=True,
)
(folded_shape,) = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs
folded_shape = getattr(folded_shape, "data", None)
if folded_shape is None:
raise ValueError(
"Could not infer ar_order from last dimension of rho. Pass it "
"explictily or make sure rho have a static shape"
)
ar_order = int(folded_shape) - int(constant)
if ar_order < 1:
raise ValueError(
"Inferred ar_order is smaller than 1. Increase the last dimension "
"of rho or remove constant_term"
)

if init_dist is not None:
if not isinstance(init_dist, TensorVariable) or not isinstance(
init_dist.owner.op, RandomVariable
Expand All @@ -477,6 +460,41 @@ def dist(

return super().dist([rhos, sigma, init_dist, steps, ar_order, constant], **kwargs)

@classmethod
def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant: bool) -> int:
"""Compute ar_order given inputs
If ar_order is not specified we do constant folding on the shape of rhos
to retrieve it. For example, this will detect that
Normal(size=(5, 3)).shape[-1] == 3, which is not known by Aesara before.
Raises
------
ValueError
If inferred ar_order cannot be inferred from rhos or if it is less than 1
"""
if ar_order is None:
shape_fg = FunctionGraph(
outputs=[rhos.shape[-1]],
features=[ShapeFeature()],
clone=True,
)
(folded_shape,) = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs
folded_shape = getattr(folded_shape, "data", None)
if folded_shape is None:
raise ValueError(
"Could not infer ar_order from last dimension of rho. Pass it "
"explictily or make sure rho have a static shape"
)
ar_order = int(folded_shape) - int(constant)
if ar_order < 1:
raise ValueError(
"Inferred ar_order is smaller than 1. Increase the last dimension "
"of rho or remove constant_term"
)

return ar_order

@classmethod
def num_rngs(cls, *args, **kwargs):
return 2
Expand Down Expand Up @@ -540,7 +558,7 @@ def step(*args):
fn=step,
outputs_info=[{"initial": init_.T, "taps": range(-ar_order, 0)}],
non_sequences=[rhos_bcast_.T[::-1], sigma_.T, noise_rng],
n_steps=at.max((0, steps_ - ar_order)),
n_steps=steps_,
strict=True,
)
(noise_next_rng,) = tuple(innov_updates_.values())
Expand Down
6 changes: 3 additions & 3 deletions pymc/tests/test_distributions_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def test_batched_sigma(self):
beta_tp.set_value(np.zeros((ar_order,))) # Should always be close to zero
sigma_tp = np.full(batch_size, [0.01, 0.1, 1, 10, 100])
y_eval = t0["y"].eval({t0["sigma"]: sigma_tp})
assert y_eval.shape == (*batch_size, steps)
assert y_eval.shape == (*batch_size, steps + ar_order)
assert np.allclose(y_eval.std(axis=(0, 2)), [0.01, 0.1, 1, 10, 100], rtol=0.1)

def test_batched_init_dist(self):
Expand All @@ -389,7 +389,7 @@ def test_batched_init_dist(self):
init_dist = t0["y"].owner.inputs[2]
init_dist_tp = np.full((batch_size, ar_order), (np.arange(batch_size) * 100)[:, None])
y_eval = t0["y"].eval({init_dist: init_dist_tp})
assert y_eval.shape == (batch_size, steps)
assert y_eval.shape == (batch_size, steps + ar_order)
assert np.allclose(
y_eval[:, -10:].mean(-1), np.arange(batch_size) * 100, rtol=0.1, atol=0.5
)
Expand Down Expand Up @@ -429,7 +429,7 @@ def test_multivariate_init_dist(self):
def test_moment(self, size, expected):
with Model() as model:
init_dist = Constant.dist([[1.0, 2.0], [3.0, 4.0]])
AR("x", rho=[0, 0], init_dist=init_dist, steps=7, size=size)
AR("x", rho=[0, 0], init_dist=init_dist, steps=5, size=size)
assert_moment_is_expected(model, expected, check_finite_logp=False)


Expand Down

0 comments on commit 504da82

Please sign in to comment.