diff --git a/mcbackend/__init__.py b/mcbackend/__init__.py index 5b9800b..7ffdf23 100644 --- a/mcbackend/__init__.py +++ b/mcbackend/__init__.py @@ -12,7 +12,7 @@ except ModuleNotFoundError: pass -__version__ = "0.5.1" +__version__ = "0.5.2" __all__ = [ "NumPyBackend", "Backend", diff --git a/mcbackend/core.py b/mcbackend/core.py index 30593fa..85970b8 100644 --- a/mcbackend/core.py +++ b/mcbackend/core.py @@ -23,6 +23,9 @@ _log = logging.getLogger(__file__) +__all__ = ("is_rigid", "chain_id", "Chain", "Run", "Backend") + + def is_rigid(nshape: Optional[Shape]): """Determines wheather the shape is constant. @@ -133,6 +136,20 @@ def sample_stats(self) -> Dict[str, Variable]: return {var.name: var for var in self.rmeta.sample_stats} +def get_tune_mask(chain: Chain, slc: slice = slice(None)) -> numpy.ndarray: + """Load the tuning mask from either a ``"tune"``, or a ``"*__tune"`` stat. + + Raises + ------ + KeyError + When no matching stat is found. + """ + for sname in chain.sample_stats: + if sname.endswith("__tune") or sname == "tune": + return chain.get_stats(sname, slc).astype(bool) + raise KeyError("No tune stat found.") + + class Run: """A handle on one MCMC run.""" @@ -231,14 +248,15 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) -> slc = slice(0, min_clen) # Obtain a mask by which draws can be split into warmup/posterior - if "tune" in chain.sample_stats: - tune = chain.get_stats("tune", slc).astype(bool) - else: + try: + # Use the same slice to avoid shape issues in case the chain is still active + tune = get_tune_mask(chain, slc) + except KeyError: if c == 0: _log.warning( "No 'tune' stat found. Assuming all iterations are posterior draws." ) - tune = numpy.full((chain_lengths[chain.cid],), False) + tune = numpy.full((slc.stop,), False) # Split all variables draws into warmup/posterior for var in variables: diff --git a/mcbackend/test_utils.py b/mcbackend/test_utils.py index c8062d7..3c3bf0b 100644 --- a/mcbackend/test_utils.py +++ b/mcbackend/test_utils.py @@ -277,11 +277,12 @@ def test__get_chains(self): assert len(chain) == 1 pass - def test__to_inferencedata(self): + @pytest.mark.parametrize("tstatname", ["tune", "sampler__tune", "nottune"]) + def test__to_inferencedata(self, tstatname, caplog): rmeta = make_runmeta( flexibility=False, sample_stats=[ - Variable("tune", "bool"), + Variable(tstatname, "bool"), Variable("sampler_0__logp", "float32"), Variable("warning", "str"), ], @@ -294,15 +295,22 @@ def test__to_inferencedata(self): draws = [make_draw(rmeta.variables) for _ in range(n)] stats = [make_draw(rmeta.sample_stats) for _ in range(n)] for i, (d, s) in enumerate(zip(draws, stats)): - s["tune"] = i < 4 + s[tstatname] = i < 4 chain.append(d, s) idata = run.to_inferencedata() assert isinstance(idata, arviz.InferenceData) assert idata.warmup_posterior.dims["chain"] == 1 - assert idata.warmup_posterior.dims["draw"] == 4 assert idata.posterior.dims["chain"] == 1 - assert idata.posterior.dims["draw"] == 6 + if tstatname == "nottune": + # Splitting into warmup/posterior requires a tune stat! + assert any("No 'tune' stat" in r.message for r in caplog.records) + assert idata.warmup_posterior.dims["draw"] == 0 + assert idata.posterior.dims["draw"] == 10 + else: + assert idata.warmup_posterior.dims["draw"] == 4 + assert idata.posterior.dims["draw"] == 6 + for var in rmeta.variables: assert var.name in set(idata.posterior.keys()) for svar in rmeta.sample_stats: