Skip to content

Commit

Permalink
Fix bug when GrowingSphere generates cfs without constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
BirkhoffG committed Apr 27, 2024
1 parent d1c3d72 commit 9928e3b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 28 deletions.
33 changes: 6 additions & 27 deletions nbs/methods/05_sphere.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using JAX backend.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
"outputs": [],
"source": [
"#| export\n",
"from __future__ import annotations\n",
Expand Down Expand Up @@ -424,7 +409,7 @@
" num_categories=num_categories,\n",
" cat_perturb_fn=perturb_fn\n",
" )\n",
" self.apply_constraints = default_apply_constraints_fn\n",
" # self.apply_constraints = default_apply_constraints_fn\n",
" else:\n",
" self.perturb_fn = default_perturb_function\n",
" \n",
Expand Down Expand Up @@ -477,17 +462,10 @@
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cb643ef8358941f9b44a6485d4c8c6d8",
"model_id": "9b878d10d889461a8ee545d83be209e1",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -518,7 +496,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6bc22bdc99d44ef9940b7b3df1887b37",
"model_id": "56b5fff2be5048f1be1169e3f9100854",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -554,7 +532,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e6cff5dd3fe34fb0ab66def5f5368e85",
"model_id": "6daed510cacd455eb98ffeedff9e739a",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -579,6 +557,7 @@
"cfs = jax.jit(jax.vmap(partial_gen))(xs_test, rng_key=jrand.split(jrand.PRNGKey(0), len(xs_test)))\n",
"\n",
"assert cfs.shape == (x_shape[0], x_shape[1])\n",
"assert cfs.min() >= 0 and cfs.max() <= 1\n",
"\n",
"print(\"Validity: \", keras.metrics.binary_accuracy(\n",
" (1 - model.pred_fn(xs_test)).round(),\n",
Expand Down
2 changes: 1 addition & 1 deletion relax/methods/sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def before_generate_cf(self, *args, **kwargs):
num_categories=num_categories,
cat_perturb_fn=perturb_fn
)
self.apply_constraints = default_apply_constraints_fn
# self.apply_constraints = default_apply_constraints_fn
else:
self.perturb_fn = default_perturb_function

Expand Down

0 comments on commit 9928e3b

Please sign in to comment.