diff --git a/pymc3/tests/test_model.py b/pymc3/tests/test_model.py index 302571edf8..048258f85e 100644 --- a/pymc3/tests/test_model.py +++ b/pymc3/tests/test_model.py @@ -23,6 +23,7 @@ from pymc3.distributions import HalfCauchy, Normal, transforms from pymc3 import Potential, Deterministic from pymc3.model import ValueGradFunction +from .helpers import select_by_precision class NewModel(pm.Model): @@ -192,17 +193,33 @@ def test_matrix_multiplication(): tune=0, compute_convergence_checks=False, progressbar=False) + decimal = select_by_precision(7, 5) for point in posterior.points(): - npt.assert_almost_equal(point['matrix'] @ point['transformed'], - point['rv_rv']) - npt.assert_almost_equal(np.ones((2, 2)) @ point['transformed'], - point['np_rv']) - npt.assert_almost_equal(point['matrix'] @ np.ones(2), - point['rv_np']) - npt.assert_almost_equal(point['matrix'] @ point['rv_rv'], - point['rv_det']) - npt.assert_almost_equal(point['rv_rv'] @ point['transformed'], - point['det_rv']) + npt.assert_almost_equal( + point['matrix'] @ point['transformed'], + point['rv_rv'], + decimal=decimal, + ) + npt.assert_almost_equal( + np.ones((2, 2)) @ point['transformed'], + point['np_rv'], + decimal=decimal, + ) + npt.assert_almost_equal( + point['matrix'] @ np.ones(2), + point['rv_np'], + decimal=decimal, + ) + npt.assert_almost_equal( + point['matrix'] @ point['rv_rv'], + point['rv_det'], + decimal=decimal, + ) + npt.assert_almost_equal( + point['rv_rv'] @ point['transformed'], + point['det_rv'], + decimal=decimal, + ) def test_duplicate_vars():