diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 4afb93a7fe..1a1fc45e6a 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -39,7 +39,8 @@ import numpy as np import pytensor.gradient as tg -from arviz import InferenceData +from arviz import InferenceData, dict_to_dataset +from arviz.data.base import make_attrs from fastprogress.fastprogress import progress_bar from pytensor.graph.basic import Variable from typing_extensions import Protocol, TypeAlias @@ -47,6 +48,11 @@ import pymc as pm from pymc.backends import RunType, TraceOrBackend, init_traces +from pymc.backends.arviz import ( + coords_and_dims_for_inferencedata, + find_constants, + find_observations, +) from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError @@ -293,8 +299,8 @@ def _sample_external_nuts( "`idata_kwargs` are currently ignored by the nutpie sampler", UserWarning, ) - compiled_model = nutpie.compile_pymc_model(model) + t_start = time.time() idata = nutpie.sample( compiled_model, draws=draws, @@ -305,6 +311,37 @@ def _sample_external_nuts( progress_bar=progressbar, **nuts_sampler_kwargs, ) + t_sample = time.time() - t_start + # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed + # gather observed and constant data as nutpie.sample() has no access to the PyMC model + coords, dims = coords_and_dims_for_inferencedata(model) + constant_data = dict_to_dataset( + find_constants(model), + library=pm, + coords=coords, + dims=dims, + default_dims=[], + ) + observed_data = dict_to_dataset( + find_observations(model), + library=pm, + coords=coords, + dims=dims, + default_dims=[], + ) + attrs = make_attrs( + { + "sampling_time": t_sample, + }, + library=nutpie, + ) + for k, v in attrs.items(): + idata.posterior.attrs[k] = v + idata.add_groups( + {"constant_data": constant_data, "observed_data": observed_data}, + coords=coords, + dims=dims, + ) return idata elif sampler == "numpyro": diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index fc3e6e1551..d7370a560d 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -16,7 +16,7 @@ import numpy.testing as npt import pytest -from pymc import Model, Normal, sample +from pymc import ConstantData, Model, Normal, sample @pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) @@ -25,7 +25,11 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): pytest.importorskip(nuts_sampler) with Model(): - Normal("x") + x = Normal("x", 100, 5) + y = ConstantData("y", [1, 2, 3, 4]) + ConstantData("z", [100, 190, 310, 405]) + + Normal("L", mu=x, sigma=0.1, observed=y) kwargs = dict( nuts_sampler=nuts_sampler, @@ -55,7 +59,9 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): ) ) assert warns == expected - + assert "y" in idata1.constant_data + assert "z" in idata1.constant_data + assert "L" in idata1.observed_data assert idata1.posterior.chain.size == 2 assert idata1.posterior.draw.size == 500 np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x)