Skip to content

Commit

Permalink
Merge pull request #32 from FirdausChoudhury/training_param_bugs
Browse files Browse the repository at this point in the history
Implemented is_trained() and fixed a bug that caused duplicate training of Parametric Models
  • Loading branch information
BirkhoffG authored Nov 30, 2023
2 parents fde7e17 + 1985357 commit 700e591
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 39 deletions.
5 changes: 3 additions & 2 deletions nbs/00_base.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@
" @property\n",
" def is_trained(self) -> bool:\n",
" \"\"\"Return whether the module is trained or not.\"\"\"\n",
" raise NotImplementedError\n",
" self._is_trained = getattr(self, '_is_trained', False)\n",
" return self._is_trained\n",
" \n",
" def train(self, data, **kwargs):\n",
" \"\"\"Train the module.\"\"\"\n",
Expand All @@ -191,5 +192,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
106 changes: 71 additions & 35 deletions nbs/03_explain.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"#| hide\n",
"%load_ext autoreload\n",
Expand All @@ -25,22 +34,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using JAX backend.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
"outputs": [],
"source": [
"#| export\n",
"from __future__ import annotations\n",
Expand Down Expand Up @@ -295,7 +289,8 @@
" cf_module.set_compute_reg_loss_fn(data_module.compute_reg_loss)\n",
" train_config = train_config or {}\n",
" if isinstance(cf_module, ParametricCFModule):\n",
" cf_module.train(data_module, pred_fn=pred_fn, **train_config)\n",
" if not cf_module.is_trained:\n",
" cf_module.train(data_module, pred_fn=pred_fn, **train_config)\n",
" cf_module.before_generate_cf()\n",
" return cf_module\n",
"\n",
Expand Down Expand Up @@ -360,23 +355,10 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - accuracy: 0.7335 - loss: 0.5497 \n",
"Epoch 2/3\n",
"\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 988us/step - accuracy: 0.8045 - loss: 0.4213\n",
"Epoch 3/3\n",
"\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 989us/step - accuracy: 0.8149 - loss: 0.4002\n"
]
}
],
"outputs": [],
"source": [
"dm = load_data(\"adult\")\n",
"ml_model = MLModule().train(dm, epochs=3)"
"ml_model = load_ml_module(\"adult\")"
]
},
{
Expand All @@ -387,7 +369,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7946c12d4e3647eb88907625d0e2678d",
"model_id": "6632bec477ef4a689db15bc0028f2cde",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -397,6 +379,14 @@
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_5475/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n",
" warnings.warn(\n"
]
}
],
"source": [
Expand All @@ -405,6 +395,52 @@
" dm, ml_model.pred_fn,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/birk/miniconda3/envs/dev/lib/python3.10/site-packages/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.\n",
" warnings.warn(\n",
"Epoch 0: 100%|██████████| 191/191 [00:08<00:00, 22.21batch/s, train/train_loss_1=0.06329722, train/train_loss_2=0.07011371, train/train_loss_3=0.101814255] \n",
"/tmp/ipykernel_5475/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n",
" warnings.warn(\n"
]
}
],
"source": [
"cfnet = CounterNet()\n",
"cfnet.train(dm, epochs=1)\n",
"# Test cases for checking if ParametricCFModule is trained twice.\n",
"# If it is trained twice, cfs will be different.\n",
"cfs = jax.vmap(cfnet.generate_cf)(dm.xs)\n",
"assert cfnet.is_trained == True\n",
"exp = generate_cf_explanations(cfnet, dm)\n",
"assert np.allclose(einops.rearrange(exp.cfs, 'N 1 K -> N K'), cfs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hide\n",
"# dm = load_data(\"dummy\")\n",
"# ml_model = load_ml_module(\"dummy\")\n",
"\n",
"# for cf_module in [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:\n",
"# m = cf_module()\n",
"# assert m.is_trained == False\n",
"# m.train(dm, pred_fn=ml_model.pred_fn, epochs=1)\n",
"# assert m.is_trained == True\n",
"# exp = generate_cf_explanations(m, dm, pred_fn=ml_model.pred_fn)"
]
}
],
"metadata": {
Expand All @@ -415,5 +451,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
3 changes: 2 additions & 1 deletion relax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ class TrainableMixedin:
@property
def is_trained(self) -> bool:
"""Return whether the module is trained or not."""
raise NotImplementedError
self._is_trained = getattr(self, '_is_trained', False)
return self._is_trained

def train(self, data, **kwargs):
"""Train the module."""
Expand Down
3 changes: 2 additions & 1 deletion relax/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def prepare_cf_module(
cf_module.set_compute_reg_loss_fn(data_module.compute_reg_loss)
train_config = train_config or {}
if isinstance(cf_module, ParametricCFModule):
cf_module.train(data_module, pred_fn=pred_fn, **train_config)
if not cf_module.is_trained:
cf_module.train(data_module, pred_fn=pred_fn, **train_config)
cf_module.before_generate_cf()
return cf_module

Expand Down

0 comments on commit 700e591

Please sign in to comment.