From 79200d2bec04c8b843ad8862344ea0167e7e9cf9 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Tue, 8 Oct 2024 12:57:01 +0200 Subject: [PATCH] Do not use initval in test model PRs https://github.com/pymc-devs/pymc/pull/7508 and https://github.com/pymc-devs/pymc/pull/7492 introduced incompatible changes but were not tested simultaneously. Deepcopying the steps in the tests leads to deepcopying the model which uses `clone_model`, which in turn does not support initvals. --- tests/models.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/models.py b/tests/models.py index b66c1dc67d..fd45fb8bdb 100644 --- a/tests/models.py +++ b/tests/models.py @@ -18,7 +18,6 @@ import pytensor import pytensor.tensor as pt -from pytensor import config from pytensor.compile.ops import as_op import pymc as pm @@ -30,9 +29,9 @@ def simple_model(): mu = -2.1 tau = 1.3 with Model() as model: - Normal("x", mu, tau=tau, size=2, initval=np.array([0.1, 0.1]).astype(config.floatX)) + x = Normal("x", mu, tau=tau, size=2) - return model.initial_point(), model, (mu, tau**-0.5) + return {"x": np.array([0.1, 0.1], dtype=x.type.dtype)}, model, (mu, tau**-0.5) def another_simple_model(): @@ -46,11 +45,11 @@ def simple_categorical(): p = np.array([0.1, 0.2, 0.3, 0.4]) v = np.array([0.0, 1.0, 2.0, 3.0]) with Model() as model: - Categorical("x", p, size=3, initval=[1, 2, 3]) + x = Categorical("x", p, size=3) mu = np.dot(p, v) var = np.dot(p, (v - mu) ** 2) - return model.initial_point(), model, (mu, var) + return {"x": np.array([1, 2, 3], dtype=x.type.dtype)}, model, (mu, var) def multidimensional_model(): @@ -98,15 +97,14 @@ def mv_simple(): p = np.array([[2.0, 0, 0], [0.05, 0.1, 0], [1.0, -0.05, 5.5]]) tau = np.dot(p, p.T) with pm.Model() as model: - pm.MvNormal( + x = pm.MvNormal( "x", pt.constant(mu), tau=pt.constant(tau), - initval=np.array([0.1, 1.0, 0.8]), ) H = tau C = np.linalg.inv(H) - return model.initial_point(), model, (mu, C) + return {"x": np.array([0.1, 1.0, 0.8], dtype=x.type.dtype)}, model, (mu, C) def mv_simple_coarse():