Implemented is_trained() and fixed a bug causing duplicate training of Parametric Models #32

Merged
merged 9 commits on Nov 30, 2023
5 changes: 3 additions & 2 deletions nbs/00_base.ipynb
@@ -173,7 +173,8 @@
" @property\n",
" def is_trained(self) -> bool:\n",
" \"\"\"Return whether the module is trained or not.\"\"\"\n",
" raise NotImplementedError\n",
" self._is_trained = getattr(self, '_is_trained', False)\n",
" return self._is_trained\n",
" \n",
" def train(self, data, **kwargs):\n",
" \"\"\"Train the module.\"\"\"\n",
@@ -191,5 +192,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
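The hunk above replaces the abstract `is_trained` property with a lazily initialized flag. A minimal sketch of how the new behavior plays out, assuming a hypothetical subclass `DummyModule` (not part of relax) that sets the flag at the end of `train`:

```python
# Minimal sketch of the lazy `is_trained` pattern introduced above.
# `TrainableMixedin` mirrors the name in relax/base.py; `DummyModule`
# is hypothetical and only illustrates how a subclass sets the flag.

class TrainableMixedin:
    @property
    def is_trained(self) -> bool:
        """Return whether the module is trained or not."""
        # Lazily default to False instead of raising NotImplementedError,
        # so an untrained module simply reports "not trained".
        self._is_trained = getattr(self, '_is_trained', False)
        return self._is_trained

    def train(self, data, **kwargs):
        """Train the module."""
        raise NotImplementedError


class DummyModule(TrainableMixedin):
    def train(self, data, **kwargs):
        # ... fit parameters here ...
        self._is_trained = True  # mark as trained so callers can skip retraining
        return self


m = DummyModule()
assert m.is_trained is False  # lazy default; no NotImplementedError raised
m.train(data=None)
assert m.is_trained is True
```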
106 changes: 71 additions & 35 deletions nbs/03_explain.ipynb
@@ -13,7 +13,16 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"#| hide\n",
"%load_ext autoreload\n",
@@ -25,22 +34,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using JAX backend.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
"outputs": [],
"source": [
"#| export\n",
"from __future__ import annotations\n",
@@ -295,7 +289,8 @@
" cf_module.set_compute_reg_loss_fn(data_module.compute_reg_loss)\n",
" train_config = train_config or {}\n",
" if isinstance(cf_module, ParametricCFModule):\n",
" cf_module.train(data_module, pred_fn=pred_fn, **train_config)\n",
" if not cf_module.is_trained:\n",
" cf_module.train(data_module, pred_fn=pred_fn, **train_config)\n",
" cf_module.before_generate_cf()\n",
" return cf_module\n",
"\n",
@@ -360,23 +355,10 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - accuracy: 0.7335 - loss: 0.5497 \n",
"Epoch 2/3\n",
"\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 988us/step - accuracy: 0.8045 - loss: 0.4213\n",
"Epoch 3/3\n",
"\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 989us/step - accuracy: 0.8149 - loss: 0.4002\n"
]
}
],
"outputs": [],
"source": [
"dm = load_data(\"adult\")\n",
"ml_model = MLModule().train(dm, epochs=3)"
"ml_model = load_ml_module(\"adult\")"
]
},
{
@@ -387,7 +369,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7946c12d4e3647eb88907625d0e2678d",
"model_id": "6632bec477ef4a689db15bc0028f2cde",
"version_major": 2,
"version_minor": 0
},
@@ -397,6 +379,14 @@
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_5475/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n",
" warnings.warn(\n"
]
}
],
"source": [
@@ -405,6 +395,52 @@
" dm, ml_model.pred_fn,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/birk/miniconda3/envs/dev/lib/python3.10/site-packages/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.\n",
" warnings.warn(\n",
"Epoch 0: 100%|██████████| 191/191 [00:08<00:00, 22.21batch/s, train/train_loss_1=0.06329722, train/train_loss_2=0.07011371, train/train_loss_3=0.101814255] \n",
"/tmp/ipykernel_5475/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n",
" warnings.warn(\n"
]
}
],
"source": [
"cfnet = CounterNet()\n",
"cfnet.train(dm, epochs=1)\n",
"# Test cases for checking if ParametricCFModule is trained twice.\n",
"# If it is trained twice, cfs will be different.\n",
"cfs = jax.vmap(cfnet.generate_cf)(dm.xs)\n",
"assert cfnet.is_trained == True\n",
"exp = generate_cf_explanations(cfnet, dm)\n",
"assert np.allclose(einops.rearrange(exp.cfs, 'N 1 K -> N K'), cfs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hide\n",
"# dm = load_data(\"dummy\")\n",
"# ml_model = load_ml_module(\"dummy\")\n",
"\n",
"# for cf_module in [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:\n",
"# m = cf_module()\n",
"# assert m.is_trained == False\n",
"# m.train(dm, pred_fn=ml_model.pred_fn, epochs=1)\n",
"# assert m.is_trained == True\n",
"# exp = generate_cf_explanations(m, dm, pred_fn=ml_model.pred_fn)"
]
}
],
"metadata": {
@@ -415,5 +451,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
3 changes: 2 additions & 1 deletion relax/base.py
@@ -65,7 +65,8 @@ class TrainableMixedin:
@property
def is_trained(self) -> bool:
"""Return whether the module is trained or not."""
raise NotImplementedError
self._is_trained = getattr(self, '_is_trained', False)
return self._is_trained

def train(self, data, **kwargs):
"""Train the module."""
3 changes: 2 additions & 1 deletion relax/explain.py
@@ -191,7 +191,8 @@ def prepare_cf_module(
cf_module.set_compute_reg_loss_fn(data_module.compute_reg_loss)
train_config = train_config or {}
if isinstance(cf_module, ParametricCFModule):
cf_module.train(data_module, pred_fn=pred_fn, **train_config)
if not cf_module.is_trained:
cf_module.train(data_module, pred_fn=pred_fn, **train_config)
cf_module.before_generate_cf()
return cf_module

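Taken together, the lazy flag in relax/base.py and the guard in relax/explain.py mean a caller can train a parametric CF module once up front and the explainer will not silently retrain it. A self-contained sketch of that interaction under stated assumptions: only `ParametricCFModule`, `is_trained`, `train`, and `prepare_cf_module` mirror names from this diff, and the stub bodies below are illustrative, not the real implementations.

```python
# Sketch of the duplicate-training guard added to prepare_cf_module.
# The stub classes stand in for the real relax types.

class ParametricCFModule:
    def __init__(self):
        self.train_calls = 0  # counts how often `train` runs, for the demo

    @property
    def is_trained(self) -> bool:
        self._is_trained = getattr(self, '_is_trained', False)
        return self._is_trained

    def train(self, data_module, pred_fn=None, **train_config):
        self.train_calls += 1
        self._is_trained = True
        return self


def prepare_cf_module(cf_module, data_module, pred_fn=None, train_config=None):
    train_config = train_config or {}
    if isinstance(cf_module, ParametricCFModule):
        # Before this PR the module was retrained unconditionally here,
        # discarding any training the caller had already done.
        if not cf_module.is_trained:
            cf_module.train(data_module, pred_fn=pred_fn, **train_config)
    return cf_module


pre_trained = ParametricCFModule().train(data_module=None)
prepare_cf_module(pre_trained, data_module=None)
assert pre_trained.train_calls == 1  # not retrained by the explainer

fresh = ParametricCFModule()
prepare_cf_module(fresh, data_module=None)
assert fresh.train_calls == 1  # an untrained module is still trained once
```

This matches the notebook regression test above: CounterNet is trained once, counterfactuals are generated, and `generate_cf_explanations` is then expected to reuse the trained module, so the regenerated counterfactuals come out identical.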