diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index 67ff11fcd2..d57d5b170f 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -393,14 +393,19 @@ class AR(SymbolicDistribution): """ - def __new__(cls, *args, steps=None, **kwargs): + def __new__(cls, name, rho, *args, steps=None, constant=False, ar_order=None, **kwargs): + rhos = at.atleast_1d(at.as_tensor_variable(floatX(rho))) + ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order) steps = get_steps( steps=steps, shape=None, # Shape will be checked in `cls.dist` dims=kwargs.get("dims", None), observed=kwargs.get("observed", None), + step_shape_offset=ar_order, + ) + return super().__new__( + cls, name, rhos, *args, steps=steps, constant=constant, ar_order=ar_order, **kwargs ) - return super().__new__(cls, *args, steps=steps, **kwargs) @classmethod def dist( @@ -426,34 +431,12 @@ def dist( ) init_dist = kwargs["init"] - steps = get_steps(steps=steps, shape=kwargs.get("shape", None)) + ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order) + steps = get_steps(steps=steps, shape=kwargs.get("shape", None), step_shape_offset=ar_order) if steps is None: raise ValueError("Must specify steps or shape parameter") steps = at.as_tensor_variable(intX(steps), ndim=0) - if ar_order is None: - # If ar_order is not specified we do constant folding on the shape of rhos - # to retrieve it. For example, this will detect that - # Normal(size=(5, 3)).shape[-1] == 3, which is not known by Aesara before. - shape_fg = FunctionGraph( - outputs=[rhos.shape[-1]], - features=[ShapeFeature()], - clone=True, - ) - (folded_shape,) = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs - folded_shape = getattr(folded_shape, "data", None) - if folded_shape is None: - raise ValueError( - "Could not infer ar_order from last dimension of rho. Pass it " - "explictily or make sure rho have a static shape" - ) - ar_order = int(folded_shape) - int(constant) - if ar_order < 1: - raise ValueError( - "Inferred ar_order is smaller than 1. Increase the last dimension " - "of rho or remove constant_term" - ) - if init_dist is not None: if not isinstance(init_dist, TensorVariable) or not isinstance( init_dist.owner.op, RandomVariable @@ -477,6 +460,41 @@ def dist( return super().dist([rhos, sigma, init_dist, steps, ar_order, constant], **kwargs) + @classmethod + def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant: bool) -> int: + """Compute ar_order given inputs + + If ar_order is not specified we do constant folding on the shape of rhos + to retrieve it. For example, this will detect that + Normal(size=(5, 3)).shape[-1] == 3, which is not known by Aesara before. + + Raises + ------ + ValueError + If inferred ar_order cannot be inferred from rhos or if it is less than 1 + """ + if ar_order is None: + shape_fg = FunctionGraph( + outputs=[rhos.shape[-1]], + features=[ShapeFeature()], + clone=True, + ) + (folded_shape,) = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs + folded_shape = getattr(folded_shape, "data", None) + if folded_shape is None: + raise ValueError( + "Could not infer ar_order from last dimension of rho. Pass it " + "explictily or make sure rho have a static shape" + ) + ar_order = int(folded_shape) - int(constant) + if ar_order < 1: + raise ValueError( + "Inferred ar_order is smaller than 1. Increase the last dimension " + "of rho or remove constant_term" + ) + + return ar_order + @classmethod def num_rngs(cls, *args, **kwargs): return 2 @@ -540,7 +558,7 @@ def step(*args): fn=step, outputs_info=[{"initial": init_.T, "taps": range(-ar_order, 0)}], non_sequences=[rhos_bcast_.T[::-1], sigma_.T, noise_rng], - n_steps=at.max((0, steps_ - ar_order)), + n_steps=steps_, strict=True, ) (noise_next_rng,) = tuple(innov_updates_.values()) diff --git a/pymc/tests/test_distributions_timeseries.py b/pymc/tests/test_distributions_timeseries.py index d6df46d61f..eadafc2d63 100644 --- a/pymc/tests/test_distributions_timeseries.py +++ b/pymc/tests/test_distributions_timeseries.py @@ -363,7 +363,7 @@ def test_batched_sigma(self): beta_tp.set_value(np.zeros((ar_order,))) # Should always be close to zero sigma_tp = np.full(batch_size, [0.01, 0.1, 1, 10, 100]) y_eval = t0["y"].eval({t0["sigma"]: sigma_tp}) - assert y_eval.shape == (*batch_size, steps) + assert y_eval.shape == (*batch_size, steps + ar_order) assert np.allclose(y_eval.std(axis=(0, 2)), [0.01, 0.1, 1, 10, 100], rtol=0.1) def test_batched_init_dist(self): @@ -389,7 +389,7 @@ def test_batched_init_dist(self): init_dist = t0["y"].owner.inputs[2] init_dist_tp = np.full((batch_size, ar_order), (np.arange(batch_size) * 100)[:, None]) y_eval = t0["y"].eval({init_dist: init_dist_tp}) - assert y_eval.shape == (batch_size, steps) + assert y_eval.shape == (batch_size, steps + ar_order) assert np.allclose( y_eval[:, -10:].mean(-1), np.arange(batch_size) * 100, rtol=0.1, atol=0.5 ) @@ -429,7 +429,7 @@ def test_multivariate_init_dist(self): def test_moment(self, size, expected): with Model() as model: init_dist = Constant.dist([[1.0, 2.0], [3.0, 4.0]]) - AR("x", rho=[0, 0], init_dist=init_dist, steps=7, size=size) + AR("x", rho=[0, 0], init_dist=init_dist, steps=5, size=size) assert_moment_is_expected(model, expected, check_finite_logp=False)