From b33e1c2ac43f7bd23f39b0678fbe85cdf8b2d22d Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Fri, 26 Jul 2024 20:40:13 +0200 Subject: [PATCH] Fixing failed tests --- sbi/inference/posteriors/base_posterior.py | 15 +++ tutorials/20_nspe.ipynb | 127 +++++++++++++++++++++ 2 files changed, 142 insertions(+) diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 31bdc257a..a0c88a534 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -264,6 +264,21 @@ def potential( theta.to(self._device), track_gradients=track_gradients ) + def _x_else_default_x(self, x: Optional[Array]) -> Tensor: + if x is not None: + # New x, reset posterior sampler. + self._posterior_sampler = None + return process_x( + x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x + ) + elif self.default_x is None: + raise ValueError( + "Context `x` needed when a default has not been set." + "If you'd like to have a default, use the `.set_default_x()` method." + ) + else: + return self.default_x + def _calculate_map( self, num_iter: int = 1_000, diff --git a/tutorials/20_nspe.ipynb b/tutorials/20_nspe.ipynb index 77c2e494e..0feb8c388 100644 --- a/tutorials/20_nspe.ipynb +++ b/tutorials/20_nspe.ipynb @@ -1490,6 +1490,133 @@ ")\n", "assert posterior is not None, \"Using 'infer' with keyword arguments failed\"" ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "import pytest\n", + "from pyro.infer.mcmc import MCMC\n", + "from torch import Tensor, eye, zeros\n", + "from torch.distributions import MultivariateNormal\n", + "\n", + "from sbi.inference import (\n", + " SNL,\n", + " MCMCPosterior,\n", + " likelihood_estimator_based_potential,\n", + " simulate_for_sbi,\n", + ")\n", + "from sbi.samplers.mcmc import PyMCSampler, SliceSamplerSerial, SliceSamplerVectorized\n", + "from sbi.simulators.linear_gaussian import diagonal_linear_gaussian\n", + "from sbi.utils.user_input_checks import process_prior, process_simulator\n" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "65cf0dca4e1741d99999557f6653aeb8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00 x_o = [[1]]\n > x_o = [[1, 2, 3]]\n\n are interpreted as single observations with a leading batch dimension of\n one. However\n\n > x_o = [ [1], [2] ]\n > x_o = [ [1,2,3], [4,5,6] ]\n\n are interpreted as a batch of two scalar or vector observations, which\n is not supported yet. The following is interpreted as a matrix-shaped\n observation, e.g. a monochromatic image:\n\n > x_o = [ [[1,2,3], [4,5,6]] ]\n\n Finally, for convenience,\n\n > x_o = [1]\n > x_o = [1, 2, 3]\n\n will be interpreted as a single scalar or single vector observation\n respectively, without the user needing to wrap or unsqueeze them.\n ", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[25], line 31\u001b[0m\n\u001b[1;32m 26\u001b[0m posterior \u001b[38;5;241m=\u001b[39m MCMCPosterior(\n\u001b[1;32m 27\u001b[0m potential_fn, theta_transform\u001b[38;5;241m=\u001b[39mtransform, method\u001b[38;5;241m=\u001b[39msampling_method, proposal\u001b[38;5;241m=\u001b[39mprior\n\u001b[1;32m 28\u001b[0m )\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m posterior\u001b[38;5;241m.\u001b[39mposterior_sampler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m---> 31\u001b[0m samples \u001b[38;5;241m=\u001b[39m \u001b[43mposterior\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 32\u001b[0m \u001b[43m \u001b[49m\u001b[43msample_shape\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mnum_samples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_chains\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 33\u001b[0m \u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mx_o\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 34\u001b[0m \u001b[43m \u001b[49m\u001b[43mmcmc_parameters\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minit_strategy\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mprior\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmcmc_params_fast\u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 35\u001b[0m \u001b[43m)\u001b[49m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;66;03m# assert isinstance(samples, Tensor)\u001b[39;00m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;66;03m# assert samples.shape == (num_samples, num_chains, num_dim)\u001b[39;00m\n\u001b[1;32m 38\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;66;03m# else: # sampling_method == \"slice_np_vectorized\"\u001b[39;00m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;66;03m# assert type(posterior.posterior_sampler) is SliceSamplerVectorized\u001b[39;00m\n", + "File \u001b[0;32m~/sbi/sbi/inference/posteriors/mcmc_posterior.py:248\u001b[0m, in \u001b[0;36mMCMCPosterior.sample\u001b[0;34m(self, sample_shape, x, method, thin, warmup_steps, num_chains, init_strategy, init_strategy_parameters, init_strategy_num_candidates, mcmc_parameters, mcmc_method, sample_with, num_workers, mp_context, show_progress_bars)\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msample\u001b[39m(\n\u001b[1;32m 211\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 212\u001b[0m sample_shape: Shape \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mSize(),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 226\u001b[0m show_progress_bars: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 227\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m 228\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Return samples from posterior distribution $p(\\theta|x)$ with MCMC.\u001b[39;00m\n\u001b[1;32m 229\u001b[0m \n\u001b[1;32m 230\u001b[0m \u001b[38;5;124;03m Check the `__init__()` method for a description of all arguments as well as\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 246\u001b[0m \u001b[38;5;124;03m Samples from posterior.\u001b[39;00m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 248\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpotential_fn\u001b[38;5;241m.\u001b[39mset_x(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_x_else_default_x\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 250\u001b[0m \u001b[38;5;66;03m# Replace arguments that were not passed with their default.\u001b[39;00m\n\u001b[1;32m 251\u001b[0m method \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmethod \u001b[38;5;28;01mif\u001b[39;00m method \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m method\n", + "File \u001b[0;32m~/sbi/sbi/inference/posteriors/base_posterior.py:120\u001b[0m, in \u001b[0;36mNeuralPosterior._x_else_default_x\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 118\u001b[0m \u001b[38;5;66;03m# New x, reset posterior sampler.\u001b[39;00m\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_posterior_sampler \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 120\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mprocess_x\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx_event_shape\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefault_x \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 122\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 123\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mContext `x` needed when a default has not been set.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 124\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIf you\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124md like to have a default, use the `.set_default_x()` method.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 125\u001b[0m )\n", + "File \u001b[0;32m~/sbi/sbi/utils/user_input_checks.py:589\u001b[0m, in \u001b[0;36mprocess_x\u001b[0;34m(x, x_event_shape, allow_iid_x)\u001b[0m\n\u001b[1;32m 587\u001b[0m input_x_shape \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 588\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m allow_iid_x:\n\u001b[0;32m--> 589\u001b[0m \u001b[43mcheck_for_possibly_batched_x_shape\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_x_shape\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 590\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 591\u001b[0m warn_on_iid_x(num_trials\u001b[38;5;241m=\u001b[39minput_x_shape[\u001b[38;5;241m0\u001b[39m])\n", + "File \u001b[0;32m~/sbi/sbi/utils/user_input_checks.py:278\u001b[0m, in \u001b[0;36mcheck_for_possibly_batched_x_shape\u001b[0;34m(x_shape)\u001b[0m\n\u001b[1;32m 276\u001b[0m \u001b[38;5;66;03m# Reject multidimensional data with batch_shape > 1.\u001b[39;00m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x_ndim \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m inferred_batch_shape \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m--> 278\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 279\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"The `x` passed to condition the posterior for evaluation or sampling\u001b[39;00m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;124;03m has an inferred batch shape larger than one. This is not supported in\u001b[39;00m\n\u001b[1;32m 281\u001b[0m \u001b[38;5;124;03m some sbi methods for reasons depending on the scenario:\u001b[39;00m\n\u001b[1;32m 282\u001b[0m \n\u001b[1;32m 283\u001b[0m \u001b[38;5;124;03m - in case you want to evaluate or sample conditioned on several iid\u001b[39;00m\n\u001b[1;32m 284\u001b[0m \u001b[38;5;124;03m xs e.g., (p(theta | [x1, x2, x3])), this is fully supported only\u001b[39;00m\n\u001b[1;32m 285\u001b[0m \u001b[38;5;124;03m for likelihood based SNLE and SNRE. For SNPE it is supported only\u001b[39;00m\n\u001b[1;32m 286\u001b[0m \u001b[38;5;124;03m for a fixed number of trials and using an appropriate embedding\u001b[39;00m\n\u001b[1;32m 287\u001b[0m \u001b[38;5;124;03m net, i.e., by treating the trials as additional data dimension. In\u001b[39;00m\n\u001b[1;32m 288\u001b[0m \u001b[38;5;124;03m that case, make sure to pass xo with a leading batch dimensionen.\u001b[39;00m\n\u001b[1;32m 289\u001b[0m \n\u001b[1;32m 290\u001b[0m \u001b[38;5;124;03m - in case you trained with a single round to do amortized inference\u001b[39;00m\n\u001b[1;32m 291\u001b[0m \u001b[38;5;124;03m and now you want to evaluate or sample a given theta conditioned on\u001b[39;00m\n\u001b[1;32m 292\u001b[0m \u001b[38;5;124;03m several xs, one after the other, e.g, p(theta | x1), p(theta | x2),\u001b[39;00m\n\u001b[1;32m 293\u001b[0m \u001b[38;5;124;03m p(theta| x3): this broadcasting across xs is not supported in sbi.\u001b[39;00m\n\u001b[1;32m 294\u001b[0m \u001b[38;5;124;03m Instead, what you can do it to call posterior.log_prob(theta, xi)\u001b[39;00m\n\u001b[1;32m 295\u001b[0m \u001b[38;5;124;03m multiple times with different xi.\u001b[39;00m\n\u001b[1;32m 296\u001b[0m \n\u001b[1;32m 297\u001b[0m \u001b[38;5;124;03m - finally, if your observation is multidimensional, e.g., an image,\u001b[39;00m\n\u001b[1;32m 298\u001b[0m \u001b[38;5;124;03m make sure to pass it with a leading batch dimension, e.g., with\u001b[39;00m\n\u001b[1;32m 299\u001b[0m \u001b[38;5;124;03m shape (1, xdim1, xdim2). Beware that the current implementation\u001b[39;00m\n\u001b[1;32m 300\u001b[0m \u001b[38;5;124;03m of sbi might not provide stable support for this and result in\u001b[39;00m\n\u001b[1;32m 301\u001b[0m \u001b[38;5;124;03m shape mismatches.\u001b[39;00m\n\u001b[1;32m 302\u001b[0m \n\u001b[1;32m 303\u001b[0m \u001b[38;5;124;03m NOTE: below we use list notation to reduce clutter, but `x` should be of\u001b[39;00m\n\u001b[1;32m 304\u001b[0m \u001b[38;5;124;03m type torch.Tensor or ndarray.\u001b[39;00m\n\u001b[1;32m 305\u001b[0m \n\u001b[1;32m 306\u001b[0m \u001b[38;5;124;03m For example:\u001b[39;00m\n\u001b[1;32m 307\u001b[0m \n\u001b[1;32m 308\u001b[0m \u001b[38;5;124;03m > x_o = [[1]]\u001b[39;00m\n\u001b[1;32m 309\u001b[0m \u001b[38;5;124;03m > x_o = [[1, 2, 3]]\u001b[39;00m\n\u001b[1;32m 310\u001b[0m \n\u001b[1;32m 311\u001b[0m \u001b[38;5;124;03m are interpreted as single observations with a leading batch dimension of\u001b[39;00m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;124;03m one. However\u001b[39;00m\n\u001b[1;32m 313\u001b[0m \n\u001b[1;32m 314\u001b[0m \u001b[38;5;124;03m > x_o = [ [1], [2] ]\u001b[39;00m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;124;03m > x_o = [ [1,2,3], [4,5,6] ]\u001b[39;00m\n\u001b[1;32m 316\u001b[0m \n\u001b[1;32m 317\u001b[0m \u001b[38;5;124;03m are interpreted as a batch of two scalar or vector observations, which\u001b[39;00m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;124;03m is not supported yet. The following is interpreted as a matrix-shaped\u001b[39;00m\n\u001b[1;32m 319\u001b[0m \u001b[38;5;124;03m observation, e.g. a monochromatic image:\u001b[39;00m\n\u001b[1;32m 320\u001b[0m \n\u001b[1;32m 321\u001b[0m \u001b[38;5;124;03m > x_o = [ [[1,2,3], [4,5,6]] ]\u001b[39;00m\n\u001b[1;32m 322\u001b[0m \n\u001b[1;32m 323\u001b[0m \u001b[38;5;124;03m Finally, for convenience,\u001b[39;00m\n\u001b[1;32m 324\u001b[0m \n\u001b[1;32m 325\u001b[0m \u001b[38;5;124;03m > x_o = [1]\u001b[39;00m\n\u001b[1;32m 326\u001b[0m \u001b[38;5;124;03m > x_o = [1, 2, 3]\u001b[39;00m\n\u001b[1;32m 327\u001b[0m \n\u001b[1;32m 328\u001b[0m \u001b[38;5;124;03m will be interpreted as a single scalar or single vector observation\u001b[39;00m\n\u001b[1;32m 329\u001b[0m \u001b[38;5;124;03m respectively, without the user needing to wrap or unsqueeze them.\u001b[39;00m\n\u001b[1;32m 330\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[1;32m 331\u001b[0m )\n\u001b[1;32m 332\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 333\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n", + "\u001b[0;31mValueError\u001b[0m: The `x` passed to condition the posterior for evaluation or sampling\n has an inferred batch shape larger than one. This is not supported in\n some sbi methods for reasons depending on the scenario:\n\n - in case you want to evaluate or sample conditioned on several iid\n xs e.g., (p(theta | [x1, x2, x3])), this is fully supported only\n for likelihood based SNLE and SNRE. For SNPE it is supported only\n for a fixed number of trials and using an appropriate embedding\n net, i.e., by treating the trials as additional data dimension. In\n that case, make sure to pass xo with a leading batch dimensionen.\n\n - in case you trained with a single round to do amortized inference\n and now you want to evaluate or sample a given theta conditioned on\n several xs, one after the other, e.g, p(theta | x1), p(theta | x2),\n p(theta| x3): this broadcasting across xs is not supported in sbi.\n Instead, what you can do it to call posterior.log_prob(theta, xi)\n multiple times with different xi.\n\n - finally, if your observation is multidimensional, e.g., an image,\n make sure to pass it with a leading batch dimension, e.g., with\n shape (1, xdim1, xdim2). Beware that the current implementation\n of sbi might not provide stable support for this and result in\n shape mismatches.\n\n NOTE: below we use list notation to reduce clutter, but `x` should be of\n type torch.Tensor or ndarray.\n\n For example:\n\n > x_o = [[1]]\n > x_o = [[1, 2, 3]]\n\n are interpreted as single observations with a leading batch dimension of\n one. However\n\n > x_o = [ [1], [2] ]\n > x_o = [ [1,2,3], [4,5,6] ]\n\n are interpreted as a batch of two scalar or vector observations, which\n is not supported yet. The following is interpreted as a matrix-shaped\n observation, e.g. a monochromatic image:\n\n > x_o = [ [[1,2,3], [4,5,6]] ]\n\n Finally, for convenience,\n\n > x_o = [1]\n > x_o = [1, 2, 3]\n\n will be interpreted as a single scalar or single vector observation\n respectively, without the user needing to wrap or unsqueeze them.\n " + ] + } + ], + "source": [ + "\n", + "sampling_method: str = \"slice_np_vectorized\"\n", + "num_chains: int = 4,\n", + "mcmc_params_fast: dict = {}\n", + "num_dim: int = 2\n", + "num_samples: int = 42\n", + "num_trials: int = 2\n", + "num_simulations: int = 10\n", + "\n", + "x_o = zeros((num_trials, num_dim))\n", + "mcmc_params_fast[\"num_chains\"] = num_chains\n", + "\n", + "prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))\n", + "simulator = diagonal_linear_gaussian\n", + "\n", + "inference = SNL(prior, show_progress_bars=False)\n", + "\n", + "prior, _, prior_returns_numpy = process_prior(prior)\n", + "simulator = process_simulator(simulator, prior, prior_returns_numpy)\n", + "theta, x = simulate_for_sbi(\n", + " simulator, prior, num_simulations, simulation_batch_size=10\n", + ")\n", + "estimator = inference.append_simulations(theta, x).train(max_num_epochs=5)\n", + "potential_fn, transform = likelihood_estimator_based_potential(\n", + " estimator, prior, x_o\n", + ")\n", + "posterior = MCMCPosterior(\n", + " potential_fn, theta_transform=transform, method=sampling_method, proposal=prior\n", + ")\n", + "\n", + "assert posterior.posterior_sampler is None\n", + "samples = posterior.sample(\n", + " sample_shape=(num_samples, num_chains),\n", + " x=x_o,\n", + " mcmc_parameters={\"init_strategy\": \"prior\", **mcmc_params_fast},\n", + ")\n", + "# assert isinstance(samples, Tensor)\n", + "# assert samples.shape == (num_samples, num_chains, num_dim)\n", + "\n", + "# if \"pyro\" in sampling_method:\n", + "# assert type(posterior.posterior_sampler) is MCMC\n", + "# elif \"pymc\" in sampling_method:\n", + "# assert type(posterior.posterior_sampler) is PyMCSampler\n", + "# elif sampling_method == \"slice_np\":\n", + "# assert type(posterior.posterior_sampler) is SliceSamplerSerial\n", + "# else: # sampling_method == \"slice_np_vectorized\"\n", + "# assert type(posterior.posterior_sampler) is SliceSamplerVectorized\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 2])" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "posterior.default_x.shape" + ] } ], "metadata": {