diff --git a/nbs/methods/05_sphere.ipynb b/nbs/methods/05_sphere.ipynb index 0d094a0..176a0e1 100644 --- a/nbs/methods/05_sphere.ipynb +++ b/nbs/methods/05_sphere.ipynb @@ -37,22 +37,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using JAX backend.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" - ] - } - ], + "outputs": [], "source": [ "#| export\n", "from __future__ import annotations\n", @@ -424,7 +409,7 @@ " num_categories=num_categories,\n", " cat_perturb_fn=perturb_fn\n", " )\n", - " self.apply_constraints = default_apply_constraints_fn\n", + " # self.apply_constraints = default_apply_constraints_fn\n", " else:\n", " self.perturb_fn = default_perturb_function\n", " \n", @@ -477,17 +462,10 @@ "execution_count": null, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n" - ] - }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "cb643ef8358941f9b44a6485d4c8c6d8", + "model_id": "9b878d10d889461a8ee545d83be209e1", "version_major": 2, "version_minor": 0 }, @@ -518,7 +496,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "6bc22bdc99d44ef9940b7b3df1887b37", + "model_id": "56b5fff2be5048f1be1169e3f9100854", "version_major": 2, "version_minor": 0 }, @@ -554,7 +532,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e6cff5dd3fe34fb0ab66def5f5368e85", + "model_id": "6daed510cacd455eb98ffeedff9e739a", "version_major": 2, "version_minor": 0 }, @@ -579,6 +557,7 @@ "cfs = jax.jit(jax.vmap(partial_gen))(xs_test, rng_key=jrand.split(jrand.PRNGKey(0), len(xs_test)))\n", "\n", "assert cfs.shape == (x_shape[0], x_shape[1])\n", + "assert cfs.min() >= 0 and cfs.max() <= 1\n", "\n", "print(\"Validity: \", keras.metrics.binary_accuracy(\n", " (1 - model.pred_fn(xs_test)).round(),\n", diff --git a/relax/methods/sphere.py b/relax/methods/sphere.py index 77bb513..55767ae 100644 --- a/relax/methods/sphere.py +++ b/relax/methods/sphere.py @@ -287,7 +287,7 @@ def before_generate_cf(self, *args, **kwargs): num_categories=num_categories, cat_perturb_fn=perturb_fn ) - self.apply_constraints = default_apply_constraints_fn + # self.apply_constraints = default_apply_constraints_fn else: self.perturb_fn = default_perturb_function