diff --git a/nbs/00_base.ipynb b/nbs/00_base.ipynb
index a153a9c..d1d698d 100644
--- a/nbs/00_base.ipynb
+++ b/nbs/00_base.ipynb
@@ -173,7 +173,8 @@
     "    @property\n",
     "    def is_trained(self) -> bool:\n",
     "        \"\"\"Return whether the module is trained or not.\"\"\"\n",
-    "        raise NotImplementedError\n",
+    "        self._is_trained = getattr(self, '_is_trained', False)\n",
+    "        return self._is_trained\n",
     "    \n",
     "    def train(self, data, **kwargs):\n",
     "        \"\"\"Train the module.\"\"\"\n",
@@ -191,5 +192,5 @@
   }
  },
  "nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
 }
diff --git a/nbs/03_explain.ipynb b/nbs/03_explain.ipynb
index 0ec8d23..868dbfd 100644
--- a/nbs/03_explain.ipynb
+++ b/nbs/03_explain.ipynb
@@ -13,7 +13,16 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The autoreload extension is already loaded. To reload it, use:\n",
+      "  %reload_ext autoreload\n"
+     ]
+    }
+   ],
    "source": [
     "#| hide\n",
     "%load_ext autoreload\n",
@@ -25,22 +34,7 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Using JAX backend.\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "#| export\n",
     "from __future__ import annotations\n",
@@ -295,7 +289,8 @@
     "    cf_module.set_compute_reg_loss_fn(data_module.compute_reg_loss)\n",
     "    train_config = train_config or {}\n",
     "    if isinstance(cf_module, ParametricCFModule):\n",
-    "        cf_module.train(data_module, pred_fn=pred_fn, **train_config)\n",
+    "        if not cf_module.is_trained:\n",
+    "            cf_module.train(data_module, pred_fn=pred_fn, **train_config)\n",
     "    cf_module.before_generate_cf()\n",
     "    return cf_module\n",
     "\n",
@@ -360,23 +355,10 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Epoch 1/3\n",
-      "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - accuracy: 0.7335 - loss: 0.5497 \n",
-      "Epoch 2/3\n",
-      "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 988us/step - accuracy: 0.8045 - loss: 0.4213\n",
-      "Epoch 3/3\n",
-      "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 989us/step - accuracy: 0.8149 - loss: 0.4002\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "dm = load_data(\"adult\")\n",
-    "ml_model = MLModule().train(dm, epochs=3)"
+    "ml_model = load_ml_module(\"adult\")"
    ]
   },
   {
@@ -387,7 +369,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "7946c12d4e3647eb88907625d0e2678d",
+       "model_id": "6632bec477ef4a689db15bc0028f2cde",
        "version_major": 2,
        "version_minor": 0
       },
      "metadata": {},
      "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/tmp/ipykernel_5475/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n",
+      "  warnings.warn(\n"
+     ]
     }
    ],
    "source": [
     "exps = generate_cf_explanations(\n",
     "    vanilla_cf,\n",
@@ -405,6 +395,52 @@
     "    dm, ml_model.pred_fn,\n",
     ")"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/birk/miniconda3/envs/dev/lib/python3.10/site-packages/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.\n",
+      "  warnings.warn(\n",
+      "Epoch 0: 100%|██████████| 191/191 [00:08<00:00, 22.21batch/s, train/train_loss_1=0.06329722, train/train_loss_2=0.07011371, train/train_loss_3=0.101814255] \n",
+      "/tmp/ipykernel_5475/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n",
+      "  warnings.warn(\n"
+     ]
+    }
+   ],
+   "source": [
+    "cfnet = CounterNet()\n",
+    "cfnet.train(dm, epochs=1)\n",
+    "# Test cases for checking if ParametricCFModule is trained twice.\n",
+    "# If it is trained twice, cfs will be different.\n",
+    "cfs = jax.vmap(cfnet.generate_cf)(dm.xs)\n",
+    "assert cfnet.is_trained == True\n",
+    "exp = generate_cf_explanations(cfnet, dm)\n",
+    "assert np.allclose(einops.rearrange(exp.cfs, 'N 1 K -> N K'), cfs)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# hide\n",
+    "# dm = load_data(\"dummy\")\n",
+    "# ml_model = load_ml_module(\"dummy\")\n",
+    "\n",
+    "# for cf_module in [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:\n",
+    "#     m = cf_module()\n",
+    "#     assert m.is_trained == False\n",
+    "#     m.train(dm, pred_fn=ml_model.pred_fn, epochs=1)\n",
+    "#     assert m.is_trained == True\n",
+    "#     exp = generate_cf_explanations(m, dm, pred_fn=ml_model.pred_fn)"
+   ]
   }
  ],
  "metadata": {
@@ -415,5 +451,5 @@
   }
  },
  "nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
 }
diff --git a/relax/base.py b/relax/base.py
index 5c3fcbf..a81df6f 100644
--- a/relax/base.py
+++ b/relax/base.py
@@ -65,7 +65,8 @@ class TrainableMixedin:
     @property
     def is_trained(self) -> bool:
         """Return whether the module is trained or not."""
-        raise NotImplementedError
+        self._is_trained = getattr(self, '_is_trained', False)
+        return self._is_trained

     def train(self, data, **kwargs):
         """Train the module."""
diff --git a/relax/explain.py b/relax/explain.py
index bf44c8cc..1773729 100644
--- a/relax/explain.py
+++ b/relax/explain.py
@@ -191,7 +191,8 @@ def prepare_cf_module(
     cf_module.set_compute_reg_loss_fn(data_module.compute_reg_loss)
     train_config = train_config or {}
     if isinstance(cf_module, ParametricCFModule):
-        cf_module.train(data_module, pred_fn=pred_fn, **train_config)
+        if not cf_module.is_trained:
+            cf_module.train(data_module, pred_fn=pred_fn, **train_config)
     cf_module.before_generate_cf()
     return cf_module
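
Taken together, these hunks make `is_trained` default to `False` (instead of raising `NotImplementedError`) and make `prepare_cf_module` skip `ParametricCFModule.train` when the module is already trained. Below is a minimal, self-contained sketch of that behavior in plain Python, with no relax imports; the stub `train` body and the `prepare` helper are illustrative stand-ins for the real training loop and `prepare_cf_module`, not the library's actual implementation.

```python
class TrainableMixedin:
    @property
    def is_trained(self) -> bool:
        """Return whether the module is trained or not."""
        # Default to False if train() has never set the flag.
        self._is_trained = getattr(self, '_is_trained', False)
        return self._is_trained

    def train(self, data, **kwargs):
        """Train the module (stub for illustration)."""
        self._is_trained = True
        return self


def prepare(module, data):
    # Mirrors the guard added to prepare_cf_module: only train once.
    if not module.is_trained:
        module.train(data)
    return module


m = TrainableMixedin()
assert m.is_trained is False
prepare(m, data=None)
assert m.is_trained is True
prepare(m, data=None)  # second call is a no-op; the module is not trained twice
```

The new test cell in nbs/03_explain.ipynb checks the same property end to end: counterfactuals generated before and after `generate_cf_explanations` match because the pre-trained `CounterNet` is not trained a second time.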