Skip to content

Commit

Permalink
Notebook fixes (#333)
Browse files Browse the repository at this point in the history
* Fix predict_batch import errors

* Import os to avoid NameError

* Update ipython kernel name to default

* Fix import path to scale_by_instance

* Clarify the installation of wilds library before executing the first cell

* Update import path to avoid ModuleNotFoundError

* Change flag DOWNLOAD=True to ensure the notebook can run, update attributes

* Fix typo, update attribues

* Remove unused predict_batch import

* Update predict_batch calls
  • Loading branch information
jklaise authored Sep 15, 2021
1 parent 55beb90 commit c664bbe
Show file tree
Hide file tree
Showing 13 changed files with 74 additions and 128 deletions.
40 changes: 16 additions & 24 deletions examples/ad_ae_cifar10.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,7 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"ERROR:fbprophet:Importing plotly failed. Interactive plots will not work.\n"
]
}
],
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
Expand All @@ -72,7 +64,7 @@
"\n",
"from alibi_detect.ad import AdversarialAE\n",
"from alibi_detect.utils.fetching import fetch_detector, fetch_tf_model\n",
"from alibi_detect.utils.prediction import predict_batch\n",
"from alibi_detect.utils.tensorflow.prediction import predict_batch\n",
"from alibi_detect.utils.saving import save_detector, load_detector\n",
"from alibi_detect.datasets import fetch_attack, fetch_cifar10c, corruption_types_cifar10c"
]
Expand Down Expand Up @@ -293,7 +285,7 @@
}
],
"source": [
"y_pred = predict_batch(clf, X_test, batch_size=32, return_class=True)\n",
"y_pred = predict_batch(X_test, clf, batch_size=32).argmax(axis=1)\n",
"acc_y_pred = accuracy(y_test, y_pred)\n",
"print('Accuracy: {:.4f}'.format(acc_y_pred))"
]
Expand Down Expand Up @@ -353,8 +345,8 @@
"metadata": {},
"outputs": [],
"source": [
"y_pred_cw = predict_batch(clf, X_test_cw, batch_size=32, return_class=True)\n",
"y_pred_slide = predict_batch(clf, X_test_slide, batch_size=32, return_class=True)"
"y_pred_cw = predict_batch(X_test_cw, clf, batch_size=32).argmax(axis=1)\n",
"y_pred_slide = predict_batch(X_test_slide, clf, batch_size=32).argmax(axis=1)"
]
},
{
Expand Down Expand Up @@ -554,8 +546,8 @@
"metadata": {},
"outputs": [],
"source": [
"X_recon_cw = predict_batch(ad.ae, X_test_cw, batch_size=32)\n",
"X_recon_slide = predict_batch(ad.ae, X_test_slide, batch_size=32)"
"X_recon_cw = predict_batch(X_test_cw, ad.ae, batch_size=32)\n",
"X_recon_slide = predict_batch(X_test_slide, ad.ae, batch_size=32)"
]
},
{
Expand All @@ -564,8 +556,8 @@
"metadata": {},
"outputs": [],
"source": [
"y_recon_cw = predict_batch(clf, X_recon_cw, batch_size=32, return_class=True)\n",
"y_recon_slide = predict_batch(clf, X_recon_slide, batch_size=32, return_class=True)"
"y_recon_cw = predict_batch(X_recon_cw, clf, batch_size=32).argmax(axis=1)\n",
"y_recon_slide = predict_batch(X_recon_slide, clf, batch_size=32).argmax(axis=1)"
]
},
{
Expand Down Expand Up @@ -840,7 +832,7 @@
}
],
"source": [
"y_pred_mix = predict_batch(clf, X_mix, batch_size=32, return_class=True)\n",
"y_pred_mix = predict_batch(X_mix, clf, batch_size=32).argmax(axis=1)\n",
"acc_y_pred_mix = accuracy(y_mix, y_pred_mix)\n",
"print('Accuracy {:.4f}'.format(acc_y_pred_mix))"
]
Expand Down Expand Up @@ -989,12 +981,12 @@
],
"source": [
"# reconstructed adversarial instances\n",
"X_recon_cw_t = predict_batch(ad_t.ae, X_test_cw, batch_size=32)\n",
"X_recon_slide_t = predict_batch(ad_t.ae, X_test_slide, batch_size=32)\n",
"X_recon_cw_t = predict_batch(X_test_cw, ad_t.ae, batch_size=32)\n",
"X_recon_slide_t = predict_batch(X_test_slide, ad_t.ae, batch_size=32)\n",
"\n",
"# make predictions on reconstructed instances and compute accuracy\n",
"y_recon_cw_t = predict_batch(clf, X_recon_cw_t, batch_size=32, return_class=True)\n",
"y_recon_slide_t = predict_batch(clf, X_recon_slide_t, batch_size=32, return_class=True)\n",
"y_recon_cw_t = predict_batch(X_recon_cw_t, clf, batch_size=32).argmax(axis=1)\n",
"y_recon_slide_t = predict_batch(X_recon_slide_t, clf, batch_size=32).argmax(axis=1)\n",
"acc_y_recon_cw_t = accuracy(y_test, y_recon_cw_t)\n",
"acc_y_recon_slide_t = accuracy(y_test, y_recon_slide_t)\n",
"print('Accuracy after C&W attack {:.4f} -- reconstruction {:.4f}'.format(acc_y_pred_cw, acc_y_recon_cw_t))\n",
Expand Down Expand Up @@ -1284,7 +1276,7 @@
" X_corr, mean_test, std_test = scale_by_instance(X_corr)\n",
" \n",
" print('Make predictions on corrupted dataset...')\n",
" y_pred_corr = predict_batch(clf, X_corr, batch_size=32, return_class=True)\n",
" y_pred_corr = predict_batch(X_corr, clf, batch_size=32).argmax(axis=1)\n",
" \n",
" print('Compute adversarial scores on corrupted dataset...')\n",
" score_corr = ad_t.score(X_corr, batch_size=32)\n",
Expand Down Expand Up @@ -1401,7 +1393,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.8.5"
}
},
"nbformat": 4,
Expand Down
20 changes: 6 additions & 14 deletions examples/cd_distillation_cifar10.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Importing plotly failed. Interactive plots will not work.\n"
]
}
],
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
Expand All @@ -52,7 +44,7 @@
"\n",
"from alibi_detect.models.tensorflow.resnet import scale_by_instance\n",
"from alibi_detect.utils.fetching import fetch_tf_model, fetch_detector\n",
"from alibi_detect.utils.prediction import predict_batch\n",
"from alibi_detect.utils.tensorflow.prediction import predict_batch\n",
"from alibi_detect.utils.saving import save_detector\n",
"from alibi_detect.datasets import fetch_cifar10c, corruption_types_cifar10c"
]
Expand Down Expand Up @@ -500,7 +492,7 @@
" 4: {'all': [], 'harm': [], 'noharm': [], 'acc': 0},\n",
" 5: {'all': [], 'harm': [], 'noharm': [], 'acc': 0},\n",
"}\n",
"y_pred = predict_batch(clf, X_test, batch_size=256, return_class=True)\n",
"y_pred = predict_batch(X_test, clf, batch_size=256).argmax(axis=1)\n",
"score_x = ad.score(X_test, batch_size=256)\n",
"\n",
"for s in severities:\n",
Expand All @@ -511,7 +503,7 @@
" X_corr = scale_by_instance(X_corr)\n",
" \n",
" print('Make predictions on corrupted dataset...')\n",
" y_pred_corr = predict_batch(clf, X_corr, batch_size=1000, return_class=True)\n",
" y_pred_corr = predict_batch(X_corr, clf, batch_size=1000).argmax(axis=1)\n",
" \n",
" print('Compute adversarial scores on corrupted dataset...')\n",
" score_corr = ad.score(X_corr, batch_size=256)\n",
Expand Down Expand Up @@ -780,7 +772,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.8.5"
}
},
"nbformat": 4,
Expand Down
17 changes: 13 additions & 4 deletions examples/cd_ks_cifar10.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"import tensorflow as tf\n",
"\n",
"from alibi_detect.cd import KSDrift\n",
"from alibi_detect.models.resnet import scale_by_instance\n",
"from alibi_detect.models.tensorflow.resnet import scale_by_instance\n",
"from alibi_detect.utils.fetching import fetch_tf_model, fetch_detector\n",
"from alibi_detect.utils.saving import save_detector, load_detector\n",
"from alibi_detect.datasets import fetch_cifar10c, corruption_types_cifar10c"
Expand Down Expand Up @@ -1327,14 +1327,23 @@
"hash": "ffba93b5284319fb7a107c8eacae647f441487dcc7e0323a4c0d3feb66ea8c5e"
},
"kernelspec": {
"display_name": "Python 3.8.5 64-bit ('.venv': venv)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": ""
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
}
Loading

0 comments on commit c664bbe

Please sign in to comment.