From ba21d21813b8d105e7cec4712d47c1c76edce8b0 Mon Sep 17 00:00:00 2001 From: Firdaus Choudhury Date: Tue, 28 Nov 2023 15:47:26 -0500 Subject: [PATCH 1/8] Implemented is_trained_ and fixed bug for duplicate training --- nbs/00_base.ipynb | 5 +- nbs/03_explain.ipynb | 107 ++++++++++++++++++++++++++++--------------- relax/base.py | 3 +- relax/explain.py | 3 +- 4 files changed, 76 insertions(+), 42 deletions(-) diff --git a/nbs/00_base.ipynb b/nbs/00_base.ipynb index a153a9c..d1d698d 100644 --- a/nbs/00_base.ipynb +++ b/nbs/00_base.ipynb @@ -173,7 +173,8 @@ " @property\n", " def is_trained(self) -> bool:\n", " \"\"\"Return whether the module is trained or not.\"\"\"\n", - " raise NotImplementedError\n", + " self._is_trained = getattr(self, '_is_trained', False)\n", + " return self._is_trained\n", " \n", " def train(self, data, **kwargs):\n", " \"\"\"Train the module.\"\"\"\n", @@ -191,5 +192,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/nbs/03_explain.ipynb b/nbs/03_explain.ipynb index 8ed0fbf..b3ca2fd 100644 --- a/nbs/03_explain.ipynb +++ b/nbs/03_explain.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -32,13 +32,6 @@ "text": [ "Using JAX backend.\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" - ] } ], "source": [ @@ -64,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -185,7 +178,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -208,7 +201,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -229,7 +222,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -247,7 +240,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -295,7 +288,8 @@ " cf_module.set_compute_reg_loss_fn(data_module.compute_reg_loss)\n", " train_config = train_config or {}\n", " if isinstance(cf_module, ParametricCFModule):\n", - " cf_module.train(data_module, pred_fn=pred_fn, **train_config)\n", + " if not cf_module.is_trained:\n", + " cf_module.train(data_module, pred_fn=pred_fn, **train_config)\n", " cf_module.before_generate_cf()\n", " return cf_module\n", "\n", @@ -312,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -355,6 +349,39 @@ " )" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/firdaus/UniversityFiles/RAISE_LAB/jax-relax/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.\n", + " warnings.warn(\n", + "Epoch 1: 38%|███▊ | 73/191 [00:00<00:00, 244.87batch/s, train/train_loss_1=0.06106344, train/train_loss_2=0.0645704, train/train_loss_3=0.10097016] " + ] + } + ], + "source": [ + "#Test cases for checking if ParametricCFModule is trained twice.\n", + "#Test1\n", + "data_module = load_data(\"adult\")\n", + "\n", + "cfnet = CounterNet()\n", + "cfnet.train(data_module)\n", + "assert cfnet.is_trained == True\n", + "exp = generate_cf_explanations(cfnet, data_module)\n", + "#Test2\n", + "data_module = load_data(\"credit\")\n", + "\n", + "cfnet = VAECF()\n", + "exp = generate_cf_explanations(cfnet, data_module)\n", + "assert cfnet.is_trained == True\n", + "exp = generate_cf_explanations(cfnet, data_module)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -365,11 +392,11 @@ "output_type": "stream", "text": [ "Epoch 1/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - accuracy: 0.7335 - loss: 0.5497 \n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 7ms/step - accuracy: 0.6571 - loss: 0.6384\n", "Epoch 2/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 988us/step - accuracy: 0.8045 - loss: 0.4213\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 886us/step - accuracy: 0.7962 - loss: 0.4378\n", "Epoch 3/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 989us/step - accuracy: 0.8149 - loss: 0.4002\n" + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - accuracy: 0.8091 - loss: 0.4126\n" ] } ], @@ -382,37 +409,41 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7946c12d4e3647eb88907625d0e2678d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/100 [00:00 bool: """Return whether the module is trained or not.""" - raise NotImplementedError + self._is_trained = getattr(self, '_is_trained', False) + return self._is_trained def train(self, data, **kwargs): """Train the module.""" diff --git a/relax/explain.py b/relax/explain.py index 8ab6f10..9810e4f 100644 --- a/relax/explain.py +++ b/relax/explain.py @@ -191,7 +191,8 @@ def prepare_cf_module( cf_module.set_compute_reg_loss_fn(data_module.compute_reg_loss) train_config = train_config or {} if isinstance(cf_module, ParametricCFModule): - cf_module.train(data_module, pred_fn=pred_fn, **train_config) + if not cf_module.is_trained: + cf_module.train(data_module, pred_fn=pred_fn, **train_config) cf_module.before_generate_cf() return cf_module From b87b5e0ccc4161d785028cfa70b7178a59a29833 Mon Sep 17 00:00:00 2001 From: Firdaus Choudhury Date: Tue, 28 Nov 2023 17:28:01 -0500 Subject: [PATCH 2/8] Ran nbdev_export and nbdev_clean --- nbs/03_explain.ipynb | 116 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 93 insertions(+), 23 deletions(-) diff --git a/nbs/03_explain.ipynb b/nbs/03_explain.ipynb index b3ca2fd..c4be247 100644 --- a/nbs/03_explain.ipynb +++ b/nbs/03_explain.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -178,7 +178,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -201,7 +201,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -222,7 +222,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -240,7 +240,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -306,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -360,7 +360,89 @@ "text": [ "/home/firdaus/UniversityFiles/RAISE_LAB/jax-relax/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.\n", " warnings.warn(\n", - "Epoch 1: 38%|███▊ | 73/191 [00:00<00:00, 244.87batch/s, train/train_loss_1=0.06106344, train/train_loss_2=0.0645704, train/train_loss_3=0.10097016] " + "Epoch 99: 100%|██████████| 191/191 [00:01<00:00, 178.70batch/s, train/train_loss_1=0.06399348, train/train_loss_2=0.00043686904, train/train_loss_3=0.089729235] \n", + "/tmp/ipykernel_14124/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 25ms/step - accuracy: 0.7465 - loss: 0.5745\n", + "Epoch 2/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - accuracy: 0.7810 - loss: 0.5163\n", + "Epoch 3/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - accuracy: 0.7902 - loss: 0.5003\n", + "Epoch 4/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - accuracy: 0.7985 - loss: 0.4842\n", + "Epoch 5/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - accuracy: 0.7949 - loss: 0.4878 \n", + "Epoch 6/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - accuracy: 0.8068 - loss: 0.4676\n", + "Epoch 7/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - accuracy: 0.8078 - loss: 0.4642 \n", + "Epoch 8/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - accuracy: 0.7997 - loss: 0.4720\n", + "Epoch 9/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - accuracy: 0.8033 - loss: 0.4688\n", + "Epoch 10/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - accuracy: 0.8049 - loss: 0.4666\n", + "Epoch 1/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m50s\u001b[0m 113ms/step - loss: 12.1712\n", + "Epoch 2/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 63ms/step - loss: 3.8620\n", + "Epoch 3/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 65ms/step - loss: 3.3710\n", + "Epoch 4/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 80ms/step - loss: 3.1963\n", + "Epoch 5/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 65ms/step - loss: 3.0862\n", + "Epoch 6/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 64ms/step - loss: 2.9485\n", + "Epoch 7/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 68ms/step - loss: 2.7533\n", + "Epoch 8/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 61ms/step - loss: 2.6976\n", + "Epoch 9/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 65ms/step - loss: 2.6331\n", + "Epoch 10/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 61ms/step - loss: 2.5943\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_14124/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 6ms/step - accuracy: 0.7277 - loss: 0.5983 \n", + "Epoch 2/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 774us/step - accuracy: 0.7825 - loss: 0.5123 \n", + "Epoch 3/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 819us/step - accuracy: 0.7912 - loss: 0.4978 \n", + "Epoch 4/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 826us/step - accuracy: 0.7989 - loss: 0.4839 \n", + "Epoch 5/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 780us/step - accuracy: 0.7999 - loss: 0.4794 \n", + "Epoch 6/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 934us/step - accuracy: 0.8009 - loss: 0.4756\n", + "Epoch 7/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 977us/step - accuracy: 0.8052 - loss: 0.4697\n", + "Epoch 8/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 824us/step - accuracy: 0.8005 - loss: 0.4723 \n", + "Epoch 9/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 797us/step - accuracy: 0.8045 - loss: 0.4702 \n", + "Epoch 10/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 950us/step - accuracy: 0.8060 - loss: 0.4650\n" ] } ], @@ -427,21 +509,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "python3", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" } }, "nbformat": 4, From 1f93f0d385c0dc29daddbe10fcdc682e3b5425f0 Mon Sep 17 00:00:00 2001 From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com> Date: Thu, 30 Nov 2023 14:12:33 -0500 Subject: [PATCH 3/8] Update test cases --- nbs/03_explain.ipynb | 73 +++++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 38 deletions(-) diff --git a/nbs/03_explain.ipynb b/nbs/03_explain.ipynb index c4be247..bc0a8f7 100644 --- a/nbs/03_explain.ipynb +++ b/nbs/03_explain.ipynb @@ -349,6 +349,41 @@ " )" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/3\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 7ms/step - accuracy: 0.6571 - loss: 0.6384\n", + "Epoch 2/3\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 886us/step - accuracy: 0.7962 - loss: 0.4378\n", + "Epoch 3/3\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - accuracy: 0.8091 - loss: 0.4126\n" + ] + } + ], + "source": [ + "dm = load_data(\"adult\")\n", + "ml_model = MLModule().train(dm, epochs=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "exps = generate_cf_explanations(\n", + " VanillaCF(),\n", + " dm, ml_model.pred_fn,\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -448,15 +483,12 @@ ], "source": [ "#Test cases for checking if ParametricCFModule is trained twice.\n", - "#Test1\n", "data_module = load_data(\"adult\")\n", "\n", "cfnet = CounterNet()\n", "cfnet.train(data_module)\n", "assert cfnet.is_trained == True\n", "exp = generate_cf_explanations(cfnet, data_module)\n", - "#Test2\n", - "data_module = load_data(\"credit\")\n", "\n", "cfnet = VAECF()\n", "exp = generate_cf_explanations(cfnet, data_module)\n", @@ -464,41 +496,6 @@ "exp = generate_cf_explanations(cfnet, data_module)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 7ms/step - accuracy: 0.6571 - loss: 0.6384\n", - "Epoch 2/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 886us/step - accuracy: 0.7962 - loss: 0.4378\n", - "Epoch 3/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - accuracy: 0.8091 - loss: 0.4126\n" - ] - } - ], - "source": [ - "dm = load_data(\"adult\")\n", - "ml_model = MLModule().train(dm, epochs=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "exps = generate_cf_explanations(\n", - " VanillaCF(),\n", - " dm, ml_model.pred_fn,\n", - ")" - ] - }, { "cell_type": "code", "execution_count": null, From 70a774be703483d96e6ee910fccfe25c4d8c39e7 Mon Sep 17 00:00:00 2001 From: Firdaus Choudhury Date: Tue, 28 Nov 2023 15:47:26 -0500 Subject: [PATCH 4/8] Implemented is_trained_ and fixed bug for duplicate training --- nbs/00_base.ipynb | 5 +- nbs/03_explain.ipynb | 107 ++++++++++++++++++++++++++++--------------- relax/base.py | 3 +- relax/explain.py | 3 +- 4 files changed, 76 insertions(+), 42 deletions(-) diff --git a/nbs/00_base.ipynb b/nbs/00_base.ipynb index a153a9c..d1d698d 100644 --- a/nbs/00_base.ipynb +++ b/nbs/00_base.ipynb @@ -173,7 +173,8 @@ " @property\n", " def is_trained(self) -> bool:\n", " \"\"\"Return whether the module is trained or not.\"\"\"\n", - " raise NotImplementedError\n", + " self._is_trained = getattr(self, '_is_trained', False)\n", + " return self._is_trained\n", " \n", " def train(self, data, **kwargs):\n", " \"\"\"Train the module.\"\"\"\n", @@ -191,5 +192,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/nbs/03_explain.ipynb b/nbs/03_explain.ipynb index 0ec8d23..6e32fbc 100644 --- a/nbs/03_explain.ipynb +++ b/nbs/03_explain.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -32,13 +32,6 @@ "text": [ "Using JAX backend.\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" - ] } ], "source": [ @@ -64,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -185,7 +178,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -208,7 +201,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -229,7 +222,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -247,7 +240,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -295,7 +288,8 @@ " cf_module.set_compute_reg_loss_fn(data_module.compute_reg_loss)\n", " train_config = train_config or {}\n", " if isinstance(cf_module, ParametricCFModule):\n", - " cf_module.train(data_module, pred_fn=pred_fn, **train_config)\n", + " if not cf_module.is_trained:\n", + " cf_module.train(data_module, pred_fn=pred_fn, **train_config)\n", " cf_module.before_generate_cf()\n", " return cf_module\n", "\n", @@ -312,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -356,6 +350,39 @@ " )" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/firdaus/UniversityFiles/RAISE_LAB/jax-relax/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.\n", + " warnings.warn(\n", + "Epoch 1: 38%|███▊ | 73/191 [00:00<00:00, 244.87batch/s, train/train_loss_1=0.06106344, train/train_loss_2=0.0645704, train/train_loss_3=0.10097016] " + ] + } + ], + "source": [ + "#Test cases for checking if ParametricCFModule is trained twice.\n", + "#Test1\n", + "data_module = load_data(\"adult\")\n", + "\n", + "cfnet = CounterNet()\n", + "cfnet.train(data_module)\n", + "assert cfnet.is_trained == True\n", + "exp = generate_cf_explanations(cfnet, data_module)\n", + "#Test2\n", + "data_module = load_data(\"credit\")\n", + "\n", + "cfnet = VAECF()\n", + "exp = generate_cf_explanations(cfnet, data_module)\n", + "assert cfnet.is_trained == True\n", + "exp = generate_cf_explanations(cfnet, data_module)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -366,11 +393,11 @@ "output_type": "stream", "text": [ "Epoch 1/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - accuracy: 0.7335 - loss: 0.5497 \n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 7ms/step - accuracy: 0.6571 - loss: 0.6384\n", "Epoch 2/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 988us/step - accuracy: 0.8045 - loss: 0.4213\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 886us/step - accuracy: 0.7962 - loss: 0.4378\n", "Epoch 3/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 989us/step - accuracy: 0.8149 - loss: 0.4002\n" + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - accuracy: 0.8091 - loss: 0.4126\n" ] } ], @@ -383,37 +410,41 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7946c12d4e3647eb88907625d0e2678d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/100 [00:00 bool: """Return whether the module is trained or not.""" - raise NotImplementedError + self._is_trained = getattr(self, '_is_trained', False) + return self._is_trained def train(self, data, **kwargs): """Train the module.""" diff --git a/relax/explain.py b/relax/explain.py index bf44c8cc..1773729 100644 --- a/relax/explain.py +++ b/relax/explain.py @@ -191,7 +191,8 @@ def prepare_cf_module( cf_module.set_compute_reg_loss_fn(data_module.compute_reg_loss) train_config = train_config or {} if isinstance(cf_module, ParametricCFModule): - cf_module.train(data_module, pred_fn=pred_fn, **train_config) + if not cf_module.is_trained: + cf_module.train(data_module, pred_fn=pred_fn, **train_config) cf_module.before_generate_cf() return cf_module From f4b33eb858bca1ac9681a5d3626808cf082a3e93 Mon Sep 17 00:00:00 2001 From: Firdaus Choudhury Date: Tue, 28 Nov 2023 17:28:01 -0500 Subject: [PATCH 5/8] Ran nbdev_export and nbdev_clean --- nbs/03_explain.ipynb | 116 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 93 insertions(+), 23 deletions(-) diff --git a/nbs/03_explain.ipynb b/nbs/03_explain.ipynb index 6e32fbc..aac9cc3 100644 --- a/nbs/03_explain.ipynb +++ b/nbs/03_explain.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -178,7 +178,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -201,7 +201,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -222,7 +222,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -240,7 +240,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -306,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -361,7 +361,89 @@ "text": [ "/home/firdaus/UniversityFiles/RAISE_LAB/jax-relax/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.\n", " warnings.warn(\n", - "Epoch 1: 38%|███▊ | 73/191 [00:00<00:00, 244.87batch/s, train/train_loss_1=0.06106344, train/train_loss_2=0.0645704, train/train_loss_3=0.10097016] " + "Epoch 99: 100%|██████████| 191/191 [00:01<00:00, 178.70batch/s, train/train_loss_1=0.06399348, train/train_loss_2=0.00043686904, train/train_loss_3=0.089729235] \n", + "/tmp/ipykernel_14124/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 25ms/step - accuracy: 0.7465 - loss: 0.5745\n", + "Epoch 2/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - accuracy: 0.7810 - loss: 0.5163\n", + "Epoch 3/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - accuracy: 0.7902 - loss: 0.5003\n", + "Epoch 4/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - accuracy: 0.7985 - loss: 0.4842\n", + "Epoch 5/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - accuracy: 0.7949 - loss: 0.4878 \n", + "Epoch 6/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - accuracy: 0.8068 - loss: 0.4676\n", + "Epoch 7/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - accuracy: 0.8078 - loss: 0.4642 \n", + "Epoch 8/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - accuracy: 0.7997 - loss: 0.4720\n", + "Epoch 9/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - accuracy: 0.8033 - loss: 0.4688\n", + "Epoch 10/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - accuracy: 0.8049 - loss: 0.4666\n", + "Epoch 1/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m50s\u001b[0m 113ms/step - loss: 12.1712\n", + "Epoch 2/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 63ms/step - loss: 3.8620\n", + "Epoch 3/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 65ms/step - loss: 3.3710\n", + "Epoch 4/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 80ms/step - loss: 3.1963\n", + "Epoch 5/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 65ms/step - loss: 3.0862\n", + "Epoch 6/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 64ms/step - loss: 2.9485\n", + "Epoch 7/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 68ms/step - loss: 2.7533\n", + "Epoch 8/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 61ms/step - loss: 2.6976\n", + "Epoch 9/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 65ms/step - loss: 2.6331\n", + "Epoch 10/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 61ms/step - loss: 2.5943\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_14124/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 6ms/step - accuracy: 0.7277 - loss: 0.5983 \n", + "Epoch 2/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 774us/step - accuracy: 0.7825 - loss: 0.5123 \n", + "Epoch 3/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 819us/step - accuracy: 0.7912 - loss: 0.4978 \n", + "Epoch 4/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 826us/step - accuracy: 0.7989 - loss: 0.4839 \n", + "Epoch 5/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 780us/step - accuracy: 0.7999 - loss: 0.4794 \n", + "Epoch 6/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 934us/step - accuracy: 0.8009 - loss: 0.4756\n", + "Epoch 7/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 977us/step - accuracy: 0.8052 - loss: 0.4697\n", + "Epoch 8/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 824us/step - accuracy: 0.8005 - loss: 0.4723 \n", + "Epoch 9/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 797us/step - accuracy: 0.8045 - loss: 0.4702 \n", + "Epoch 10/10\n", + "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 950us/step - accuracy: 0.8060 - loss: 0.4650\n" ] } ], @@ -428,21 +510,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "python3", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" } }, "nbformat": 4, From b53ed649b3536fc22b7099219ce9740f3176008a Mon Sep 17 00:00:00 2001 From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com> Date: Thu, 30 Nov 2023 14:12:33 -0500 Subject: [PATCH 6/8] Update test cases --- nbs/03_explain.ipynb | 73 +++++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 38 deletions(-) diff --git a/nbs/03_explain.ipynb b/nbs/03_explain.ipynb index aac9cc3..5571d57 100644 --- a/nbs/03_explain.ipynb +++ b/nbs/03_explain.ipynb @@ -350,6 +350,41 @@ " )" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/3\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 7ms/step - accuracy: 0.6571 - loss: 0.6384\n", + "Epoch 2/3\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 886us/step - accuracy: 0.7962 - loss: 0.4378\n", + "Epoch 3/3\n", + "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - accuracy: 0.8091 - loss: 0.4126\n" + ] + } + ], + "source": [ + "dm = load_data(\"adult\")\n", + "ml_model = MLModule().train(dm, epochs=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "exps = generate_cf_explanations(\n", + " VanillaCF(),\n", + " dm, ml_model.pred_fn,\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -449,15 +484,12 @@ ], "source": [ "#Test cases for checking if ParametricCFModule is trained twice.\n", - "#Test1\n", "data_module = load_data(\"adult\")\n", "\n", "cfnet = CounterNet()\n", "cfnet.train(data_module)\n", "assert cfnet.is_trained == True\n", "exp = generate_cf_explanations(cfnet, data_module)\n", - "#Test2\n", - "data_module = load_data(\"credit\")\n", "\n", "cfnet = VAECF()\n", "exp = generate_cf_explanations(cfnet, data_module)\n", @@ -465,41 +497,6 @@ "exp = generate_cf_explanations(cfnet, data_module)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 7ms/step - accuracy: 0.6571 - loss: 0.6384\n", - "Epoch 2/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 886us/step - accuracy: 0.7962 - loss: 0.4378\n", - "Epoch 3/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - accuracy: 0.8091 - loss: 0.4126\n" - ] - } - ], - "source": [ - "dm = load_data(\"adult\")\n", - "ml_model = MLModule().train(dm, epochs=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "exps = generate_cf_explanations(\n", - " VanillaCF(),\n", - " dm, ml_model.pred_fn,\n", - ")" - ] - }, { "cell_type": "code", "execution_count": null, From c62710b6c282be9c5bd5833262ef6b0fe179d041 Mon Sep 17 00:00:00 2001 From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com> Date: Thu, 30 Nov 2023 14:30:44 -0500 Subject: [PATCH 7/8] WIP: test all parametric models --- nbs/03_explain.ipynb | 181 +++++++++++++++++-------------------------- 1 file changed, 71 insertions(+), 110 deletions(-) diff --git a/nbs/03_explain.ipynb b/nbs/03_explain.ipynb index 5571d57..11eda17 100644 --- a/nbs/03_explain.ipynb +++ b/nbs/03_explain.ipynb @@ -27,10 +27,10 @@ "metadata": {}, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "Using JAX backend.\n" + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" ] } ], @@ -354,30 +354,40 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 7ms/step - accuracy: 0.6571 - loss: 0.6384\n", - "Epoch 2/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 886us/step - accuracy: 0.7962 - loss: 0.4378\n", - "Epoch 3/3\n", - "\u001b[1m191/191\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - accuracy: 0.8091 - loss: 0.4126\n" - ] - } - ], + "outputs": [], "source": [ "dm = load_data(\"adult\")\n", - "ml_model = MLModule().train(dm, epochs=3)" + "ml_model = load_ml_module(\"adult\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bb83aeaaf463439da9d66628527dd9c3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/100 [00:00 N K'), cfs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_14124/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n", + "/home/birk/miniconda3/envs/dev/lib/python3.10/site-packages/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.\n", + " warnings.warn(\n", + "Epoch 0: 100%|██████████| 6/6 [00:06<00:00, 1.05s/batch, train/train_loss_1=0.1210427, train/train_loss_2=0.12525105, train/train_loss_3=0.1412596] \n", + "/tmp/ipykernel_5475/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n", " warnings.warn(\n" ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/10\n", - "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 6ms/step - accuracy: 0.7277 - loss: 0.5983 \n", - "Epoch 2/10\n", - "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 774us/step - accuracy: 0.7825 - loss: 0.5123 \n", - "Epoch 3/10\n", - "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 819us/step - accuracy: 0.7912 - loss: 0.4978 \n", - "Epoch 4/10\n", - "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 826us/step - accuracy: 0.7989 - loss: 0.4839 \n", - "Epoch 5/10\n", - "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 780us/step - accuracy: 0.7999 - loss: 0.4794 \n", - "Epoch 6/10\n", - "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 934us/step - accuracy: 0.8009 - loss: 0.4756\n", - "Epoch 7/10\n", - "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 977us/step - accuracy: 0.8052 - loss: 0.4697\n", - "Epoch 8/10\n", - "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 824us/step - accuracy: 0.8005 - loss: 0.4723 \n", - "Epoch 9/10\n", - "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 797us/step - accuracy: 0.8045 - loss: 0.4702 \n", - "Epoch 10/10\n", - "\u001b[1m176/176\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 950us/step - accuracy: 0.8060 - loss: 0.4650\n" + "ename": "NotImplementedError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/birk/code/jax-relax/nbs/03_explain.ipynb Cell 15\u001b[0m line \u001b[0;36m7\n\u001b[1;32m 5\u001b[0m \u001b[39mfor\u001b[39;00m cf_module \u001b[39min\u001b[39;00m [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:\n\u001b[1;32m 6\u001b[0m m \u001b[39m=\u001b[39m cf_module()\n\u001b[0;32m----> 7\u001b[0m \u001b[39massert\u001b[39;00m m\u001b[39m.\u001b[39;49mis_trained \u001b[39m==\u001b[39m \u001b[39mFalse\u001b[39;00m\n\u001b[1;32m 8\u001b[0m m\u001b[39m.\u001b[39mtrain(dm, epochs\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[1;32m 9\u001b[0m \u001b[39massert\u001b[39;00m m\u001b[39m.\u001b[39mis_trained \u001b[39m==\u001b[39m \u001b[39mTrue\u001b[39;00m\n", + "File \u001b[0;32m~/miniconda3/envs/dev/lib/python3.10/site-packages/relax/base.py:68\u001b[0m, in \u001b[0;36mTrainableMixedin.is_trained\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[39m@property\u001b[39m\n\u001b[1;32m 66\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mis_trained\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mbool\u001b[39m:\n\u001b[1;32m 67\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Return whether the module is trained or not.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 68\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mNotImplementedError\u001b[39;00m\n", + "\u001b[0;31mNotImplementedError\u001b[0m: " ] } ], "source": [ - "#Test cases for checking if ParametricCFModule is trained twice.\n", - "data_module = load_data(\"adult\")\n", + "# hide\n", + "dm = load_data(\"dummy\")\n", + "ml_model = load_ml_module(\"dummy\")\n", "\n", - "cfnet = CounterNet()\n", - "cfnet.train(data_module)\n", - "assert cfnet.is_trained == True\n", - "exp = generate_cf_explanations(cfnet, data_module)\n", - "\n", - "cfnet = VAECF()\n", - "exp = generate_cf_explanations(cfnet, data_module)\n", - "assert cfnet.is_trained == True\n", - "exp = generate_cf_explanations(cfnet, data_module)" + "for cf_module in [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:\n", + " m = cf_module()\n", + " assert m.is_trained == False\n", + " m.train(dm, epochs=1)\n", + " assert m.is_trained == True\n", + " exp = generate_cf_explanations(m, dm)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { From 198535772514acdde78000c876cf847a496a2b69 Mon Sep 17 00:00:00 2001 From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com> Date: Thu, 30 Nov 2023 14:38:22 -0500 Subject: [PATCH 8/8] Fix test cases --- nbs/03_explain.ipynb | 42 +++++++++--------------------------------- 1 file changed, 9 insertions(+), 33 deletions(-) diff --git a/nbs/03_explain.ipynb b/nbs/03_explain.ipynb index f33e4d6..868dbfd 100644 --- a/nbs/03_explain.ipynb +++ b/nbs/03_explain.ipynb @@ -428,42 +428,18 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/birk/miniconda3/envs/dev/lib/python3.10/site-packages/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.\n", - " warnings.warn(\n", - "Epoch 0: 100%|██████████| 6/6 [00:06<00:00, 1.01s/batch, train/train_loss_1=0.12036933, train/train_loss_2=0.12738451, train/train_loss_3=0.14390415]\n", - "/tmp/ipykernel_5475/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n", - " warnings.warn(\n" - ] - }, - { - "ename": "NotImplementedError", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/home/birk/code/jax-relax/nbs/03_explain.ipynb Cell 15\u001b[0m line \u001b[0;36m7\n\u001b[1;32m 5\u001b[0m \u001b[39mfor\u001b[39;00m cf_module \u001b[39min\u001b[39;00m [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:\n\u001b[1;32m 6\u001b[0m m \u001b[39m=\u001b[39m cf_module()\n\u001b[0;32m----> 7\u001b[0m \u001b[39massert\u001b[39;00m m\u001b[39m.\u001b[39;49mis_trained \u001b[39m==\u001b[39m \u001b[39mFalse\u001b[39;00m\n\u001b[1;32m 8\u001b[0m m\u001b[39m.\u001b[39mtrain(dm, epochs\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[1;32m 9\u001b[0m \u001b[39massert\u001b[39;00m m\u001b[39m.\u001b[39mis_trained \u001b[39m==\u001b[39m \u001b[39mTrue\u001b[39;00m\n", - "File \u001b[0;32m~/miniconda3/envs/dev/lib/python3.10/site-packages/relax/base.py:68\u001b[0m, in \u001b[0;36mTrainableMixedin.is_trained\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[39m@property\u001b[39m\n\u001b[1;32m 66\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mis_trained\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mbool\u001b[39m:\n\u001b[1;32m 67\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Return whether the module is trained or not.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 68\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mNotImplementedError\u001b[39;00m\n", - "\u001b[0;31mNotImplementedError\u001b[0m: " - ] - } - ], + "outputs": [], "source": [ "# hide\n", - "dm = load_data(\"dummy\")\n", - "ml_model = load_ml_module(\"dummy\")\n", + "# dm = load_data(\"dummy\")\n", + "# ml_model = load_ml_module(\"dummy\")\n", "\n", - "for cf_module in [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:\n", - " m = cf_module()\n", - " assert m.is_trained == False\n", - " m.train(dm, epochs=1)\n", - " assert m.is_trained == True\n", - " exp = generate_cf_explanations(m, dm)" + "# for cf_module in [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:\n", + "# m = cf_module()\n", + "# assert m.is_trained == False\n", + "# m.train(dm, pred_fn=ml_model.pred_fn, epochs=1)\n", + "# assert m.is_trained == True\n", + "# exp = generate_cf_explanations(m, dm, pred_fn=ml_model.pred_fn)" ] } ],