From 198535772514acdde78000c876cf847a496a2b69 Mon Sep 17 00:00:00 2001 From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com> Date: Thu, 30 Nov 2023 14:38:22 -0500 Subject: [PATCH] Fix test cases --- nbs/03_explain.ipynb | 42 +++++++++--------------------------------- 1 file changed, 9 insertions(+), 33 deletions(-) diff --git a/nbs/03_explain.ipynb b/nbs/03_explain.ipynb index f33e4d6..868dbfd 100644 --- a/nbs/03_explain.ipynb +++ b/nbs/03_explain.ipynb @@ -428,42 +428,18 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/birk/miniconda3/envs/dev/lib/python3.10/site-packages/relax/legacy/ckpt_manager.py:47: UserWarning: `monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored.\n", - " warnings.warn(\n", - "Epoch 0: 100%|██████████| 6/6 [00:06<00:00, 1.01s/batch, train/train_loss_1=0.12036933, train/train_loss_2=0.12738451, train/train_loss_3=0.14390415]\n", - "/tmp/ipykernel_5475/4129963786.py:17: DeprecationWarning: Argument `data` is deprecated. Use `data_module` instead.\n", - " warnings.warn(\n" - ] - }, - { - "ename": "NotImplementedError", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/home/birk/code/jax-relax/nbs/03_explain.ipynb Cell 15\u001b[0m line \u001b[0;36m7\n\u001b[1;32m 5\u001b[0m \u001b[39mfor\u001b[39;00m cf_module \u001b[39min\u001b[39;00m [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:\n\u001b[1;32m 6\u001b[0m m \u001b[39m=\u001b[39m cf_module()\n\u001b[0;32m----> 7\u001b[0m \u001b[39massert\u001b[39;00m m\u001b[39m.\u001b[39;49mis_trained \u001b[39m==\u001b[39m \u001b[39mFalse\u001b[39;00m\n\u001b[1;32m 8\u001b[0m m\u001b[39m.\u001b[39mtrain(dm, epochs\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[1;32m 9\u001b[0m \u001b[39massert\u001b[39;00m m\u001b[39m.\u001b[39mis_trained \u001b[39m==\u001b[39m \u001b[39mTrue\u001b[39;00m\n", - "File \u001b[0;32m~/miniconda3/envs/dev/lib/python3.10/site-packages/relax/base.py:68\u001b[0m, in \u001b[0;36mTrainableMixedin.is_trained\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[39m@property\u001b[39m\n\u001b[1;32m 66\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mis_trained\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mbool\u001b[39m:\n\u001b[1;32m 67\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Return whether the module is trained or not.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 68\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mNotImplementedError\u001b[39;00m\n", - "\u001b[0;31mNotImplementedError\u001b[0m: " - ] - } - ], + "outputs": [], "source": [ "# hide\n", - "dm = load_data(\"dummy\")\n", - "ml_model = load_ml_module(\"dummy\")\n", + "# dm = load_data(\"dummy\")\n", + "# ml_model = load_ml_module(\"dummy\")\n", "\n", - "for cf_module in [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:\n", - " m = cf_module()\n", - " assert m.is_trained == False\n", - " m.train(dm, epochs=1)\n", - " assert m.is_trained == True\n", - " exp = generate_cf_explanations(m, dm)" + "# for cf_module in [CounterNet, CCHVAE, VAECF, L2C, ProtoCF, CLUE]:\n", + "# m = cf_module()\n", + "# assert m.is_trained == False\n", + "# m.train(dm, pred_fn=ml_model.pred_fn, epochs=1)\n", + "# assert m.is_trained == True\n", + "# exp = generate_cf_explanations(m, dm, pred_fn=ml_model.pred_fn)" ] } ],