From 198535772514acdde78000c876cf847a496a2b69 Mon Sep 17 00:00:00 2001
From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com>
Date: Thu, 30 Nov 2023 14:38:22 -0500
Subject: [PATCH] Fix test cases
---
nbs/03_explain.ipynb | 42 +++++++++---------------------------------
1 file changed, 9 insertions(+), 33 deletions(-)
diff --git a/nbs/03_explain.ipynb b/nbs/03_explain.ipynb
index f33e4d6..868dbfd 100644
--- a/nbs/03_explain.ipynb
+++ b/nbs/03_explain.ipynb
@@ -428,42 +428,18 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/birk/miniconda3/envs/dev/lib/python3.10/site-packages/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.\n",
- " warnings.warn(\n",
- "Epoch 0: 100%|██████████| 6/6 [00:06<00:00, 1.01s/batch, train/train_loss_1=0.12036933, train/train_loss_2=0.12738451, train/train_loss_3=0.14390415]\n",
- "/tmp/ipykernel_5475/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n",
- " warnings.warn(\n"
- ]
- },
- {
- "ename": "NotImplementedError",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)",
- "\u001b[1;32m/home/birk/code/jax-relax/nbs/03_explain.ipynb Cell 15\u001b[0m line \u001b[0;36m7\n\u001b[1;32m 5\u001b[0m \u001b[39mfor\u001b[39;00m cf_module \u001b[39min\u001b[39;00m [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:\n\u001b[1;32m 6\u001b[0m m \u001b[39m=\u001b[39m cf_module()\n\u001b[0;32m----> 7\u001b[0m \u001b[39massert\u001b[39;00m m\u001b[39m.\u001b[39;49mis_trained \u001b[39m==\u001b[39m \u001b[39mFalse\u001b[39;00m\n\u001b[1;32m 8\u001b[0m m\u001b[39m.\u001b[39mtrain(dm, epochs\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[1;32m 9\u001b[0m \u001b[39massert\u001b[39;00m m\u001b[39m.\u001b[39mis_trained \u001b[39m==\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
- "File \u001b[0;32m~/miniconda3/envs/dev/lib/python3.10/site-packages/relax/base.py:68\u001b[0m, in \u001b[0;36mTrainableMixedin.is_trained\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[39m@property\u001b[39m\n\u001b[1;32m 66\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mis_trained\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mbool\u001b[39m:\n\u001b[1;32m 67\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Return whether the module is trained or not.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 68\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mNotImplementedError\u001b[39;00m\n",
- "\u001b[0;31mNotImplementedError\u001b[0m: "
- ]
- }
- ],
+ "outputs": [],
"source": [
"# hide\n",
- "dm = load_data(\"dummy\")\n",
- "ml_model = load_ml_module(\"dummy\")\n",
+ "# dm = load_data(\"dummy\")\n",
+ "# ml_model = load_ml_module(\"dummy\")\n",
"\n",
- "for cf_module in [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:\n",
- " m = cf_module()\n",
- " assert m.is_trained == False\n",
- " m.train(dm, epochs=1)\n",
- " assert m.is_trained == True\n",
- " exp = generate_cf_explanations(m, dm)"
+ "# for cf_module in [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:\n",
+ "# m = cf_module()\n",
+ "# assert m.is_trained == False\n",
+ "# m.train(dm, pred_fn=ml_model.pred_fn, epochs=1)\n",
+ "# assert m.is_trained == True\n",
+ "# exp = generate_cf_explanations(m, dm, pred_fn=ml_model.pred_fn)"
]
}
],