diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f02afdc623..764c44a2d2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -108,6 +108,7 @@ jobs: - | tests/backends/test_mcbackend.py + tests/backends/test_zarr.py tests/distributions/test_truncated.py tests/logprob/test_abstract.py tests/logprob/test_basic.py @@ -284,6 +285,7 @@ jobs: - | tests/backends/test_arviz.py + tests/backends/test_zarr.py tests/variational/test_updates.py fail-fast: false runs-on: ${{ matrix.os }} diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 85e6694a95..da98eb54cb 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -19,6 +19,7 @@ dependencies: - scipy>=1.4.1 - typing-extensions>=3.7.4 - threadpoolctl>=3.1.0 +- zarr>=2.5.0,<3 # Extra dependencies for dev, testing and docs build - ipython>=7.16 - jax diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml index 97d25dd5b8..a3fa60660c 100644 --- a/conda-envs/environment-jax.yml +++ b/conda-envs/environment-jax.yml @@ -10,6 +10,7 @@ dependencies: - cachetools>=4.2.1 - cloudpickle - h5py>=2.7 +- zarr>=2.5.0,<3 # Jaxlib version must not be greater than jax version! - blackjax>=1.2.2 - jax>=0.4.28 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 58cde0d327..d95621a1bc 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -23,6 +23,7 @@ dependencies: - scipy>=1.4.1 - typing-extensions>=3.7.4 - threadpoolctl>=3.1.0 +- zarr>=2.5.0,<3 # Extra dependencies for testing - ipython>=7.16 - pre-commit>=2.8.0 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 6d785e2cac..b9e672d587 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -20,6 +20,7 @@ dependencies: - scipy>=1.4.1 - typing-extensions>=3.7.4 - threadpoolctl>=3.1.0 +- zarr>=2.5.0,<3 # Extra dependencies for dev, testing and docs build - ipython>=7.16 - myst-nb<=1.0.0 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index fd17c31711..f98272a7f9 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -23,6 +23,7 @@ dependencies: - scipy>=1.4.1 - typing-extensions>=3.7.4 - threadpoolctl>=3.1.0 +- zarr>=2.5.0,<3 # Extra dependencies for testing - ipython>=7.16 - pre-commit>=2.8.0 diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index 986a34f4ba..897a73f190 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -72,6 +72,7 @@ from pymc.backends.arviz import predictions_to_inference_data, to_inference_data from pymc.backends.base import BaseTrace, IBaseTrace from pymc.backends.ndarray import NDArray +from pymc.backends.zarr import ZarrTrace from pymc.model import Model from pymc.step_methods.compound import BlockedStep, CompoundStep @@ -118,15 +119,27 @@ def _init_trace( def init_traces( *, - backend: TraceOrBackend | None, + backend: TraceOrBackend | ZarrTrace | None, chains: int, expected_length: int, step: BlockedStep | CompoundStep, - initial_point: Mapping[str, np.ndarray], + initial_point: dict[str, np.ndarray], model: Model, trace_vars: list[TensorVariable] | None = None, + tune: int = 0, ) -> tuple[RunType | None, Sequence[IBaseTrace]]: """Initialize a trace recorder for each chain.""" + if isinstance(backend, ZarrTrace): + backend.init_trace( + chains=chains, + draws=expected_length - tune, + tune=tune, + step=step, + model=model, + vars=trace_vars, + test_point=initial_point, + ) + return None, backend.straces if HAS_MCB and isinstance(backend, Backend): return init_chain_adapters( backend=backend, diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py new file mode 100644 index 0000000000..e46629bb54 --- /dev/null +++ b/pymc/backends/zarr.py @@ -0,0 +1,541 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Mapping, MutableMapping, Sequence +from typing import Any + +import arviz as az +import numcodecs +import numpy as np +import xarray as xr +import zarr + +from arviz.data.base import make_attrs +from arviz.data.inference_data import WARMUP_TAG +from numcodecs.abc import Codec +from pytensor.tensor.variable import TensorVariable +from zarr.storage import BaseStore, default_compressor +from zarr.sync import Synchronizer + +import pymc + +from pymc.backends.arviz import ( + coords_and_dims_for_inferencedata, + find_constants, + find_observations, +) +from pymc.backends.base import BaseTrace +from pymc.blocking import StatDtype, StatShape +from pymc.model.core import Model, modelcontext +from pymc.step_methods.compound import ( + BlockedStep, + CompoundStep, + StatsBijection, + get_stats_dtypes_shapes_from_steps, +) +from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name + + +class ZarrChain(BaseTrace): + def __init__( + self, + store: BaseStore | MutableMapping, + stats_bijection: StatsBijection, + synchronizer: Synchronizer | None = None, + model: Model | None = None, + vars: Sequence[TensorVariable] | None = None, + test_point: dict[str, np.ndarray] | None = None, + ): + super().__init__(name="zarr", model=model, vars=vars, test_point=test_point) + self.unconstrained_variables = { + var.name for var in self.vars if is_transformed_name(var.name) + } + self.draw_idx = 0 + self._posterior = zarr.open_group( + store, synchronizer=synchronizer, path="posterior", mode="a" + ) + if self.unconstrained_variables: + self._unconstrained_posterior = zarr.open_group( + store, synchronizer=synchronizer, path="unconstrained_posterior", mode="a" + ) + self._sample_stats = zarr.open_group( + store, synchronizer=synchronizer, path="sample_stats", mode="a" + ) + self._sampling_state = zarr.open_group( + store, synchronizer=synchronizer, path="_sampling_state", mode="a" + ) + self.stats_bijection = stats_bijection + + def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None): # type: ignore[override] + self.chain = chain + + def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]): + chain = self.chain + draw_idx = self.draw_idx + unconstrained_variables = self.unconstrained_variables + for var_name, var_value in zip(self.varnames, self.fn(draw)): + if var_name in unconstrained_variables: + self._unconstrained_posterior[var_name].set_orthogonal_selection( + (chain, draw_idx), + var_value, + ) + else: + self._posterior[var_name].set_orthogonal_selection( + (chain, draw_idx), + var_value, + ) + for var_name, var_value in self.stats_bijection.map(stats).items(): + self._sample_stats[var_name].set_orthogonal_selection( + (chain, draw_idx), + var_value, + ) + self.draw_idx += 1 + + def record_sampling_state(self, step): + self._sampling_state.sampling_state.set_coordinate_selection( + self.chain, np.array([step.sampling_state], dtype="object") + ) + self._sampling_state.draw_idx.set_coordinate_selection(self.chain, self.draw_idx) + + +FILL_VALUE_TYPE = float | int | bool | str | np.datetime64 | np.timedelta64 | None +DEFAULT_FILL_VALUES: dict[Any, FILL_VALUE_TYPE] = { + np.floating: np.nan, + np.integer: 0, + np.bool_: False, + np.str_: "", + np.datetime64: np.datetime64(0, "Y"), + np.timedelta64: np.timedelta64(0, "Y"), +} + + +def get_initial_fill_value_and_codec( + dtype: Any, +) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]: + _dtype = np.dtype(dtype) + fill_value: FILL_VALUE_TYPE = None + codec = None + try: + fill_value = DEFAULT_FILL_VALUES[_dtype] + except KeyError: + for key in DEFAULT_FILL_VALUES: + if np.issubdtype(_dtype, key): + fill_value = DEFAULT_FILL_VALUES[key] + break + else: + codec = numcodecs.Pickle() + return fill_value, _dtype, codec + + +class ZarrTrace: + def __init__( + self, + store: BaseStore | MutableMapping | None = None, + synchronizer: Synchronizer | None = None, + compressor: Codec | None | _UnsetType = UNSET, + draws_per_chunk: int = 1, + include_transformed: bool = False, + ): + self.synchronizer = synchronizer + if compressor is UNSET: + compressor = default_compressor + self.compressor = compressor + self.root = zarr.group( + store=store, + overwrite=True, + synchronizer=synchronizer, + ) + + self.draws_per_chunk = int(draws_per_chunk) + assert self.draws_per_chunk >= 1 + + self.include_transformed = include_transformed + + self._is_base_setup = False + + def groups(self) -> list[str]: + return [str(group_name) for group_name, _ in self.root.groups()] + + @property + def posterior(self) -> zarr.Group: + return self.root.posterior + + @property + def unconstrained_posterior(self) -> zarr.Group: + return self.root.unconstrained_posterior + + @property + def sample_stats(self) -> zarr.Group: + return self.root.sample_stats + + @property + def constant_data(self) -> zarr.Group: + return self.root.constant_data + + @property + def observed_data(self) -> zarr.Group: + return self.root.observed_data + + @property + def _sampling_state(self) -> zarr.Group: + return self.root._sampling_state + + def init_trace( + self, + chains: int, + draws: int, + tune: int, + step: BlockedStep | CompoundStep, + model: Model | None = None, + vars: Sequence[TensorVariable] | None = None, + test_point: dict[str, np.ndarray] | None = None, + ): + if self._is_base_setup: + raise RuntimeError("The ZarrTrace has already been initialized") # pragma: no cover + model = modelcontext(model) + self.model = model + self.coords, self.vars_to_dims = coords_and_dims_for_inferencedata(model) + if vars is None: + vars = model.unobserved_value_vars + + unnamed_vars = {var for var in vars if var.name is None} + assert not unnamed_vars, f"Can't trace unnamed variables: {unnamed_vars}" + self.varnames = get_default_varnames( + [var.name for var in vars], include_transformed=self.include_transformed + ) + self.vars = [var for var in vars if var.name in self.varnames] + + self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore") + + # Get variable shapes. Most backends will need this + # information. + if test_point is None: + test_point = model.initial_point() + var_values = list(zip(self.varnames, self.fn(test_point))) + self.var_dtype_shapes = { + var: (value.dtype, value.shape) + for var, value in var_values + if not is_transformed_name(var) + } + self.unc_var_dtype_shapes = { + var: (value.dtype, value.shape) for var, value in var_values if is_transformed_name(var) + } + + self.create_group( + name="constant_data", + data_dict=find_constants(self.model), + ) + + self.create_group( + name="observed_data", + data_dict=find_observations(self.model), + ) + + # Create the posterior that includes warmup draws + self.init_group_with_empty( + group=self.root.create_group(name="posterior", overwrite=True), + var_dtype_and_shape=self.var_dtype_shapes, + chains=chains, + draws=tune + draws, + ) + + # Create the unconstrained posterior group that includes warmup draws + if self.include_transformed and self.unc_var_dtype_shapes: + self.init_group_with_empty( + group=self.root.create_group(name="unconstrained_posterior", overwrite=True), + var_dtype_and_shape=self.unc_var_dtype_shapes, + chains=chains, + draws=tune + draws, + ) + + # Create the sample stats that include warmup draws + stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps( + [step] if isinstance(step, BlockedStep) else step.methods + ) + self.init_group_with_empty( + group=self.root.create_group(name="sample_stats", overwrite=True), + var_dtype_and_shape=stats_dtypes_shapes, + chains=chains, + draws=tune + draws, + ) + + self.init_sampling_state_group(tune=tune, chains=chains) + + self.straces = [ + ZarrChain( + store=self.root.store, + synchronizer=self.synchronizer, + model=self.model, + vars=self.vars, + test_point=test_point, + stats_bijection=StatsBijection(step.stats_dtypes), + ) + for _ in range(chains) + ] + for chain, strace in enumerate(self.straces): + strace.setup(draws=draws, chain=chain, sampler_vars=None) + + def split_warmup_groups(self): + if "warmup_posterior" not in self.groups(): + self.split_warmup("posterior", error_if_already_split=False) + self.split_warmup("sample_stats", error_if_already_split=False) + try: + self.split_warmup("unconstrained_posterior", error_if_already_split=False) + except KeyError: + pass + + @property + def tuning_steps(self): + try: + return int(self._sampling_state.tuning_steps.get_basic_selection()) + except AttributeError: # pragma: no cover + raise ValueError( + "ZarrTrace has not been initialized and there is no tuning step information available" + ) + + @property + def sampling_time(self): + try: + return float(self._sampling_state.sampling_time.get_basic_selection()) + except AttributeError: # pragma: no cover + raise ValueError( + "ZarrTrace has not been initialized and there is no sampling time information available" + ) + + def init_sampling_state_group(self, tune: int, chains: int): + state = self.root.create_group(name="_sampling_state", overwrite=True) + sampling_state = state.empty( + name="sampling_state", + overwrite=True, + shape=(chains,), + chunks=(1,), + dtype="object", + object_codec=numcodecs.Pickle(), + compressor=self.compressor, + ) + sampling_state.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]}) + draw_idx = state.array( + name="draw_idx", + overwrite=True, + data=np.zeros(chains, dtype="int"), + chunks=(1,), + dtype="int", + fill_value=-1, + compressor=self.compressor, + ) + draw_idx.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]}) + + state.array( + name="tuning_steps", + data=tune, + overwrite=True, + dtype="int", + fill_value=0, + compressor=self.compressor, + ) + state.array( + name="sampling_time", + data=0.0, + dtype="float", + fill_value=0.0, + compressor=self.compressor, + ) + state.array( + name="sampling_start_time", + data=0.0, + dtype="float", + fill_value=0.0, + compressor=self.compressor, + ) + + chain = state.array( + name="chain", + data=np.arange(chains), + compressor=self.compressor, + ) + + chain.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]}) + + state.empty( + name="global_warnings", + dtype="object", + object_codec=numcodecs.Pickle(), + shape=(0,), + ) + + def init_group_with_empty( + self, + group: zarr.Group, + var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]], + chains: int, + draws: int, + ) -> zarr.Group: + group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)} + for name, (_dtype, shape) in var_dtype_and_shape.items(): + fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype) + shape = shape or () + array = group.full( + name=name, + dtype=dtype, + fill_value=fill_value, + object_codec=object_codec, + shape=(chains, draws, *shape), + chunks=(1, self.draws_per_chunk, *shape), + compressor=self.compressor, + ) + try: + dims = self.vars_to_dims[name] + for dim in dims: + group_coords[dim] = self.coords[dim] + except KeyError: + dims = [] + for i, shape_i in enumerate(shape): + dim = f"{name}_dim_{i}" + dims.append(dim) + group_coords[dim] = np.arange(shape_i, dtype="int") + dims = ("chain", "draw", *dims) + array.attrs.update({"_ARRAY_DIMENSIONS": dims}) + for dim, coord in group_coords.items(): + array = group.array( + name=dim, + data=coord, + fill_value=None, + compressor=self.compressor, + ) + array.attrs.update({"_ARRAY_DIMENSIONS": [dim]}) + return group + + def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> zarr.Group | None: + group: zarr.Group | None = None + if data_dict: + group_coords = {} + group = self.root.create_group(name=name, overwrite=True) + for var_name, var_value in data_dict.items(): + fill_value, dtype, object_codec = get_initial_fill_value_and_codec(var_value.dtype) + array = group.array( + name=var_name, + data=var_value, + fill_value=fill_value, + dtype=dtype, + object_codec=object_codec, + compressor=self.compressor, + ) + try: + dims = self.vars_to_dims[var_name] + for dim in dims: + group_coords[dim] = self.coords[dim] + except KeyError: + dims = [] + for i in range(var_value.ndim): + dim = f"{var_name}_dim_{i}" + dims.append(dim) + group_coords[dim] = np.arange(var_value.shape[i], dtype="int") + array.attrs.update({"_ARRAY_DIMENSIONS": dims}) + for dim, coord in group_coords.items(): + array = group.array( + name=dim, + data=coord, + fill_value=None, + compressor=self.compressor, + ) + array.attrs.update({"_ARRAY_DIMENSIONS": [dim]}) + return group + + def split_warmup(self, group_name, error_if_already_split=True): + if error_if_already_split and f"{WARMUP_TAG}{group_name}" in { + group_name for group_name, _ in self.root.groups() + }: + raise RuntimeError(f"Warmup data for {group_name} has already been split") + posterior_group = self.root[group_name] + tune = self.tuning_steps + warmup_group = self.root.create_group(f"{WARMUP_TAG}{group_name}", overwrite=True) + if tune == 0: + try: + self.root.pop(f"{WARMUP_TAG}{group_name}") + except KeyError: + pass + return + for name, array in posterior_group.arrays(): + array_attrs = array.attrs.asdict() + if name == "draw": + warmup_array = warmup_group.array( + name="draw", + data=np.arange(tune), + dtype="int", + compressor=self.compressor, + ) + posterior_array = posterior_group.array( + name=name, + data=np.arange(len(array) - tune), + dtype="int", + overwrite=True, + compressor=self.compressor, + ) + posterior_array.attrs.update(array_attrs) + else: + dims = array.attrs["_ARRAY_DIMENSIONS"] + if len(dims) >= 2 and dims[:2] == ["chain", "draw"]: + must_overwrite_posterior = True + warmup_idx = (slice(None), slice(None, tune, None)) + posterior_idx = (slice(None), slice(tune, None, None)) + else: + must_overwrite_posterior = False + warmup_idx = slice(None) + fill_value, dtype, object_codec = get_initial_fill_value_and_codec(array.dtype) + warmup_array = warmup_group.array( + name=name, + data=array[warmup_idx], + chunks=array.chunks, + dtype=dtype, + fill_value=fill_value, + object_codec=object_codec, + compressor=self.compressor, + ) + if must_overwrite_posterior: + posterior_array = posterior_group.array( + name=name, + data=array[posterior_idx], + chunks=array.chunks, + dtype=dtype, + fill_value=fill_value, + object_codec=object_codec, + overwrite=True, + compressor=self.compressor, + ) + posterior_array.attrs.update(array_attrs) + warmup_array.attrs.update(array_attrs) + + def to_inferencedata(self, save_warmup=False) -> az.InferenceData: + self.split_warmup_groups() + # Xarray complains if we try to open a zarr hierarchy that doesn't have consolidated metadata + consolidated_root = zarr.consolidate_metadata(self.root.store) + # The ConsolidatedMetadataStore looks like an empty store from xarray's point of view + # we need to actually grab the underlying store so that xarray doesn't produce completely + # empty arrays + store = consolidated_root.store.store + groups = {} + try: + global_attrs = { + "tuning_steps": self.tuning_steps, + "sampling_time": self.sampling_time, + } + except AttributeError: + global_attrs = {} # pragma: no cover + for name, _ in self.root.groups(): + if name.startswith("_") or (not save_warmup and name.startswith(WARMUP_TAG)): + continue + data = xr.open_zarr(store, group=name, mask_and_scale=False) + attrs = {**data.attrs, **global_attrs} + data.attrs = make_attrs(attrs=attrs, library=pymc) + groups[name] = data.load() if az.rcParams["data.load"] == "eager" else data + return az.InferenceData(**groups) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 4ee79607b7..a1f29148e7 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -50,6 +50,7 @@ find_observations, ) from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains +from pymc.backends.zarr import ZarrTrace from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain @@ -475,7 +476,7 @@ def sample( blas_cores: int | None | Literal["auto"] = "auto", model: Model | None = None, **kwargs, -) -> InferenceData | MultiTrace: +) -> InferenceData | MultiTrace | ZarrTrace: r"""Draw samples from the posterior using the given step methods. Multiple step methods are supported via compound step methods. @@ -702,7 +703,7 @@ def joined_blas_limiter(): rngs = get_random_generator(random_seed).spawn(chains) random_seed_list = [rng.integers(2**30) for rng in rngs] - if not discard_tuned_samples and not return_inferencedata: + if not discard_tuned_samples and not return_inferencedata and not isinstance(trace, ZarrTrace): warnings.warn( "Tuning samples will be included in the returned `MultiTrace` object, which can lead to" " complications in your downstream analysis. Please consider to switch to `InferenceData`:\n" @@ -808,6 +809,7 @@ def joined_blas_limiter(): trace_vars=trace_vars, initial_point=ip, model=model, + tune=tune, ) sample_args = { @@ -890,7 +892,7 @@ def joined_blas_limiter(): # into a function to make it easier to test and refactor. return _sample_return( run=run, - traces=traces, + traces=trace if isinstance(trace, ZarrTrace) else traces, tune=tune, t_sampling=t_sampling, discard_tuned_samples=discard_tuned_samples, @@ -905,7 +907,7 @@ def joined_blas_limiter(): def _sample_return( *, run: RunType | None, - traces: Sequence[IBaseTrace], + traces: Sequence[IBaseTrace] | ZarrTrace, tune: int, t_sampling: float, discard_tuned_samples: bool, @@ -914,18 +916,69 @@ def _sample_return( keep_warning_stat: bool, idata_kwargs: dict[str, Any], model: Model, -) -> InferenceData | MultiTrace: +) -> InferenceData | MultiTrace | ZarrTrace: """Pick/slice chains, run diagnostics and convert to the desired return type. Final step of `pm.sampler`. """ + if isinstance(traces, ZarrTrace): + # Split warmup from posterior samples + traces.split_warmup_groups() + + # Set sampling time + traces._sampling_state.sampling_time.set_basic_selection((), t_sampling) + + # Compute number of actual draws per chain + total_draws_per_chain = traces._sampling_state.draw_idx[:] + n_chains = len(traces.straces) + desired_tune = traces.tuning_steps + desired_draw = len(traces.posterior.draw) + tuning_steps_per_chain = np.clip(total_draws_per_chain, 0, desired_tune) + draws_per_chain = total_draws_per_chain - tuning_steps_per_chain + + total_n_tune = tuning_steps_per_chain.sum() + total_draws = draws_per_chain.sum() + + _log.info( + f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations ' + f"(Actually sampled {total_n_tune:_d} tune and {total_draws:_d} draws total) " + f"took {t_sampling:.0f} seconds." + ) + + if compute_convergence_checks or return_inferencedata: + idata = traces.to_inferencedata(save_warmup=not discard_tuned_samples) + log_likelihood = idata_kwargs.pop("log_likelihood", False) + if log_likelihood: + from pymc.stats.log_density import compute_log_likelihood + + idata = compute_log_likelihood( + idata, + var_names=None if log_likelihood is True else log_likelihood, + extend_inferencedata=True, + model=model, + sample_dims=["chain", "draw"], + progressbar=False, + ) + if compute_convergence_checks: + warns = run_convergence_checks(idata, model) + for warn in warns: + traces._sampling_state.global_warnings.append(np.array([warn])) + log_warnings(warns) + + if return_inferencedata: + # By default we drop the "warning" stat which contains `SamplerWarning` + # objects that can not be stored with `.to_netcdf()`. + if not keep_warning_stat: + return drop_warning_stat(idata) + return idata + return traces + # Pick and slice chains to keep the maximum number of samples if discard_tuned_samples: traces, length = _choose_chains(traces, tune) else: traces, length = _choose_chains(traces, 0) mtrace = MultiTrace(traces)[:length] - # count the number of tune/draw iterations that happened # ideally via the "tune" statistic, but not all samplers record it! if "tune" in mtrace.stat_names: @@ -954,7 +1007,6 @@ def _sample_return( f"took {t_sampling:.0f} seconds." ) - idata = None if compute_convergence_checks or return_inferencedata: ikwargs: dict[str, Any] = {"model": model, "save_warmup": not discard_tuned_samples} ikwargs.update(idata_kwargs) diff --git a/pymc/util.py b/pymc/util.py index 8ec8aa84de..881aedeef9 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +import re import warnings from collections.abc import Sequence @@ -275,7 +276,12 @@ def drop_warning_stat(idata: arviz.InferenceData) -> arviz.InferenceData: nidata = arviz.InferenceData(attrs=idata.attrs) for gname, group in idata.items(): if "sample_stat" in gname: - group = group.drop_vars(names=["warning", "warning_dim_0"], errors="ignore") + warning_vars = [ + name + for name in group.data_vars + if name == "warning" or re.match(r"sampler_\d+__warning", str(name)) + ] + group = group.drop_vars(names=[*warning_vars, "warning_dim_0"], errors="ignore") nidata.add_groups({gname: group}, coords=group.coords, dims=group.dims) return nidata diff --git a/requirements-dev.txt b/requirements-dev.txt index 082eab73ce..890b87e19b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -32,3 +32,4 @@ threadpoolctl>=3.1.0 types-cachetools typing-extensions>=3.7.4 watermark +zarr>=2.5.0,<3 diff --git a/requirements.txt b/requirements.txt index b59ca29127..0a2c224706 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ rich>=13.7.1 scipy>=1.4.1 threadpoolctl>=3.1.0,<4.0.0 typing-extensions>=3.7.4 +zarr>=2.5.0,<3 diff --git a/tests/backends/test_zarr.py b/tests/backends/test_zarr.py new file mode 100644 index 0000000000..eb95947587 --- /dev/null +++ b/tests/backends/test_zarr.py @@ -0,0 +1,448 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools + +from dataclasses import asdict + +import numpy as np +import pytest +import zarr + +from arviz import InferenceData + +import pymc as pm + +from pymc.backends.zarr import ZarrTrace +from pymc.stats.convergence import SamplerWarning +from pymc.step_methods import NUTS, CompoundStep, Metropolis +from pymc.step_methods.state import equal_dataclass_values +from tests.helpers import equal_sampling_states + + +@pytest.fixture(scope="module") +def model(): + time_int = np.array([np.timedelta64(np.timedelta64(i, "h"), "ns") for i in range(25)]) + coords = { + "dim_int": range(3), + "dim_str": ["A", "B"], + "dim_time": np.datetime64("2024-10-16") + time_int, + "dim_interval": time_int, + } + rng = np.random.default_rng(42) + with pm.Model(coords=coords) as model: + data1 = pm.Data("data1", np.ones(3, dtype="bool"), dims=["dim_int"]) + data2 = pm.Data("data2", np.ones(3, dtype="bool")) + time = pm.Data("time", time_int / np.timedelta64(1, "h"), dims="dim_time") + + a = pm.Normal("a", shape=(len(coords["dim_int"]), len(coords["dim_str"]))) + b = pm.Normal("b", dims=["dim_int", "dim_str"]) + c = pm.Deterministic("c", a + b, dims=["dim_int", "dim_str"]) + + d = pm.LogNormal("d", dims="dim_time") + e = pm.Deterministic("e", (time + d)[:, None] + c[0], dims=["dim_interval", "dim_str"]) + + obs = pm.Normal( + "obs", + mu=e, + observed=rng.normal(size=(len(coords["dim_time"]), len(coords["dim_str"]))), + dims=["dim_time", "dim_str"], + ) + + return model + + +@pytest.fixture(params=[True, False]) +def include_transformed(request): + return request.param + + +@pytest.fixture(params=["single_step", "compound_step"]) +def model_step(request, model): + rng = np.random.default_rng(42) + with model: + if request.param == "single_step": + step = NUTS(rng=rng) + else: + rngs = rng.spawn(2) + step = CompoundStep( + [ + Metropolis(vars=model["a"], rng=rngs[0]), + NUTS(vars=[rv for rv in model.value_vars if rv.name != "a"], rng=rngs[1]), + ] + ) + return step + + +def test_record(model, model_step, include_transformed): + store = zarr.MemoryStore() + trace = ZarrTrace(store=store, include_transformed=include_transformed) + draws = 5 + tune = 5 + trace.init_trace(chains=1, draws=draws, tune=tune, model=model, step=model_step) + + # Assert that init was successful + expected_groups = { + "_sampling_state", + "sample_stats", + "posterior", + "constant_data", + "observed_data", + } + if include_transformed: + expected_groups.add("unconstrained_posterior") + assert {group_name for group_name, _ in trace.root.groups()} == expected_groups + + # Record samples from the ZarrChain + manually_collected_warmup_draws = [] + manually_collected_warmup_stats = [] + manually_collected_draws = [] + manually_collected_stats = [] + point = model.initial_point() + for draw in range(tune + draws): + tuning = draw < tune + if not tuning: + model_step.stop_tuning() + point, stats = model_step.step(point) + if tuning: + manually_collected_warmup_draws.append(point) + manually_collected_warmup_stats.append(stats) + else: + manually_collected_draws.append(point) + manually_collected_stats.append(stats) + trace.straces[0].record(point, stats) + trace.straces[0].record_sampling_state(model_step) + assert {group_name for group_name, _ in trace.root.groups()} == expected_groups + + # Assert split warmup + trace.split_warmup("posterior") + trace.split_warmup("sample_stats") + expected_groups = { + "_sampling_state", + "sample_stats", + "posterior", + "warmup_sample_stats", + "warmup_posterior", + "constant_data", + "observed_data", + } + if include_transformed: + trace.split_warmup("unconstrained_posterior") + expected_groups.add("unconstrained_posterior") + expected_groups.add("warmup_unconstrained_posterior") + assert {group_name for group_name, _ in trace.root.groups()} == expected_groups + # trace.consolidate() + + # Assert observed data is correct + assert set(dict(trace.observed_data.arrays())) == {"obs", "dim_time", "dim_str"} + assert list(trace.observed_data.obs.attrs["_ARRAY_DIMENSIONS"]) == ["dim_time", "dim_str"] + np.testing.assert_array_equal(trace.observed_data.dim_time[:], model.coords["dim_time"]) + np.testing.assert_array_equal(trace.observed_data.dim_str[:], model.coords["dim_str"]) + + # Assert constant data is correct + assert set(dict(trace.constant_data.arrays())) == { + "data1", + "data2", + "data2_dim_0", + "time", + "dim_time", + "dim_int", + } + assert list(trace.constant_data.data1.attrs["_ARRAY_DIMENSIONS"]) == ["dim_int"] + assert list(trace.constant_data.data2.attrs["_ARRAY_DIMENSIONS"]) == ["data2_dim_0"] + assert list(trace.constant_data.time.attrs["_ARRAY_DIMENSIONS"]) == ["dim_time"] + np.testing.assert_array_equal(trace.constant_data.dim_time[:], model.coords["dim_time"]) + np.testing.assert_array_equal(trace.constant_data.dim_int[:], model.coords["dim_int"]) + + # Assert unconstrained posterior has correct shapes + assert {rv.name for rv in model.free_RVs + model.deterministics} <= set( + dict(trace.posterior.arrays()) + ) + if include_transformed: + assert {"d_log__", "chain", "draw", "d_log___dim_0"} == set( + dict(trace.unconstrained_posterior.arrays()) + ) + assert list(trace.unconstrained_posterior.d_log__.attrs["_ARRAY_DIMENSIONS"]) == [ + "chain", + "draw", + "d_log___dim_0", + ] + np.testing.assert_array_equal(trace.unconstrained_posterior.chain, np.arange(1)) + np.testing.assert_array_equal(trace.unconstrained_posterior.draw, np.arange(draws)) + np.testing.assert_array_equal( + trace.unconstrained_posterior.d_log___dim_0, np.arange(len(model.coords["dim_time"])) + ) + + # Assert posterior has correct shape + posterior_dims = set() + for rv_name in [rv.name for rv in model.free_RVs + model.deterministics]: + if rv_name == "a": + expected_dims = ["a_dim_0", "a_dim_1"] + else: + expected_dims = model.named_vars_to_dims[rv_name] + posterior_dims |= set(expected_dims) + assert list(trace.posterior[rv_name].attrs["_ARRAY_DIMENSIONS"]) == [ + "chain", + "draw", + *expected_dims, + ] + for posterior_dim in posterior_dims: + try: + model_coord = model.coords[posterior_dim] + except KeyError: + model_coord = { + "a_dim_0": np.arange(len(model.coords["dim_int"])), + "a_dim_1": np.arange(len(model.coords["dim_str"])), + "chain": np.arange(1), + "draw": np.arange(draws), + }[posterior_dim] + np.testing.assert_array_equal(trace.posterior[posterior_dim][:], model_coord) + + # Assert sample stats have correct shape + stats_bijection = trace.straces[0].stats_bijection + for draw_idx, (draw, stat) in enumerate( + zip(manually_collected_draws, manually_collected_stats) + ): + stat = stats_bijection.map(stat) + for var, value in draw.items(): + if var in trace.posterior.arrays(): + assert np.array_equal(trace.posterior[var][0, draw_idx], value) + for var, value in stat.items(): + sample_stats = trace.root["sample_stats"] + stat_val = sample_stats[var][0, draw_idx] + if not isinstance(stat_val, SamplerWarning): + unequal_stats = stat_val != value + else: + unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value)) + if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)): + raise AssertionError(f"{stat_val} != {value}") + + # Assert manually collected warmup samples match + for draw_idx, (draw, stat) in enumerate( + zip(manually_collected_warmup_draws, manually_collected_warmup_stats) + ): + stat = stats_bijection.map(stat) + for var, value in draw.items(): + if var == "d_log__": + if not include_transformed: + continue + posterior = trace.root["warmup_unconstrained_posterior"] + else: + posterior = trace.root["warmup_posterior"] + if var in posterior.arrays(): + assert np.array_equal(posterior[var][0, draw_idx], value) + for var, value in stat.items(): + sample_stats = trace.root["warmup_sample_stats"] + stat_val = sample_stats[var][0, draw_idx] + if not isinstance(stat_val, SamplerWarning): + unequal_stats = stat_val != value + else: + unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value)) + if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)): + raise AssertionError(f"{stat_val} != {value}") + + # Assert manually collected posterior samples match + for draw_idx, (draw, stat) in enumerate( + zip(manually_collected_draws, manually_collected_stats) + ): + stat = stats_bijection.map(stat) + for var, value in draw.items(): + if var == "d_log__": + if not include_transformed: + continue + posterior = trace.root["unconstrained_posterior"] + else: + posterior = trace.root["posterior"] + if var in posterior.arrays(): + assert np.array_equal(posterior[var][0, draw_idx], value) + for var, value in stat.items(): + sample_stats = trace.root["sample_stats"] + stat_val = sample_stats[var][0, draw_idx] + if not isinstance(stat_val, SamplerWarning): + unequal_stats = stat_val != value + else: + unequal_stats = not equal_dataclass_values(asdict(stat_val), asdict(value)) + if unequal_stats and not (np.isnan(stat_val) and np.isnan(value)): + raise AssertionError(f"{stat_val} != {value}") + + # Assert sampling_state is correct + assert list(trace._sampling_state.draw_idx[:]) == [draws + tune] + assert equal_sampling_states( + trace._sampling_state.sampling_state[0], + model_step.sampling_state, + ) + + # Assert to inference data returns the expected groups + idata = trace.to_inferencedata(save_warmup=True) + expected_groups = { + "posterior", + "constant_data", + "observed_data", + "sample_stats", + "warmup_posterior", + "warmup_sample_stats", + } + if include_transformed: + expected_groups.add("unconstrained_posterior") + expected_groups.add("warmup_unconstrained_posterior") + assert set(idata.groups()) == expected_groups + for group in idata.groups(): + for name, value in itertools.chain( + idata[group].data_vars.items(), idata[group].coords.items() + ): + try: + array = getattr(trace, group)[name][:] + except AttributeError: + array = trace.root[group][name][:] + if "sample_stats" in group and "warning" in name: + continue + np.testing.assert_array_equal(array, value) + + +@pytest.mark.parametrize("tune", [0, 5, 10]) +def test_split_warmup(tune, model, model_step, include_transformed): + store = zarr.MemoryStore() + trace = ZarrTrace(store=store, include_transformed=include_transformed) + draws = 10 - tune + trace.init_trace(chains=1, draws=draws, tune=tune, model=model, step=model_step) + + trace.split_warmup("posterior") + trace.split_warmup("sample_stats") + assert len(trace.root.posterior.draw) == draws + assert len(trace.root.sample_stats.draw) == draws + if tune == 0: + with pytest.raises(KeyError): + trace.root["warmup_posterior"] + else: + assert len(trace.root["warmup_posterior"].draw) == tune + assert len(trace.root["warmup_sample_stats"].draw) == tune + + with pytest.raises(RuntimeError): + trace.split_warmup("posterior") + + for var_name, posterior_array in trace.posterior.arrays(): + dims = posterior_array.attrs["_ARRAY_DIMENSIONS"] + if len(dims) >= 2 and dims[1] == "draw": + assert posterior_array.shape[1] == draws + assert trace.root["warmup_posterior"][var_name].shape[1] == tune + for var_name, sample_stats_array in trace.sample_stats.arrays(): + dims = sample_stats_array.attrs["_ARRAY_DIMENSIONS"] + if len(dims) >= 2 and dims[1] == "draw": + assert sample_stats_array.shape[1] == draws + assert trace.root["warmup_sample_stats"][var_name].shape[1] == tune + + +@pytest.fixture(scope="function", params=[True, False]) +def discard_tuned_samples(request): + return request.param + + +@pytest.fixture(scope="function", params=[True, False]) +def return_inferencedata(request): + return request.param + + +@pytest.fixture(scope="function", params=[True, False]) +def keep_warning_stat(request): + return request.param + + +@pytest.fixture(scope="function", params=[True, False]) +def parallel(request): + return request.param + + +@pytest.fixture(scope="function", params=[True, False]) +def log_likelihood(request): + return request.param + + +def test_sample( + model, + model_step, + include_transformed, + discard_tuned_samples, + return_inferencedata, + keep_warning_stat, + parallel, + log_likelihood, +): + if not return_inferencedata and not log_likelihood: + pytest.skip( + reason="log_likelihood is only computed if an inference data object is returned" + ) + store = zarr.MemoryStore() + trace = ZarrTrace(store=store, include_transformed=include_transformed) + tune = 2 + draws = 3 + if parallel: + chains = 2 + cores = 2 + else: + chains = 1 + cores = 1 + with model: + out_trace = pm.sample( + draws=draws, + tune=tune, + chains=chains, + cores=cores, + trace=trace, + step=model_step, + discard_tuned_samples=discard_tuned_samples, + return_inferencedata=return_inferencedata, + keep_warning_stat=keep_warning_stat, + idata_kwargs={"log_likelihood": log_likelihood}, + ) + + if not return_inferencedata: + assert isinstance(out_trace, ZarrTrace) + assert out_trace.root.store is trace.root.store + else: + assert isinstance(out_trace, InferenceData) + + expected_groups = {"posterior", "constant_data", "observed_data", "sample_stats"} + if include_transformed: + expected_groups |= {"unconstrained_posterior"} + if not return_inferencedata or not discard_tuned_samples: + expected_groups |= {"warmup_posterior", "warmup_sample_stats"} + if include_transformed: + expected_groups |= {"warmup_unconstrained_posterior"} + if not return_inferencedata: + expected_groups |= {"_sampling_state"} + elif log_likelihood: + expected_groups |= {"log_likelihood"} + assert set(out_trace.groups()) == expected_groups + + if return_inferencedata: + warning_stat = ( + "sampler_1__warning" if isinstance(model_step, CompoundStep) else "sampler_0__warning" + ) + if keep_warning_stat: + assert warning_stat in out_trace.sample_stats + else: + assert warning_stat not in out_trace.sample_stats + + # Assert that all variables have non empty samples (not NaNs) + if return_inferencedata: + assert all( + (not np.any(np.isnan(v))) and v.shape[:2] == (chains, draws) + for v in out_trace.posterior.data_vars.values() + ) + else: + dimensions = {*model.coords, "a_dim_0", "a_dim_1", "chain", "draw"} + assert all( + (not np.any(np.isnan(v[:]))) and v.shape[:2] == (chains, draws) + for name, v in out_trace.posterior.arrays() + if name not in dimensions + )