From c664bbed62a9f3daf1027331708f47835ad71f89 Mon Sep 17 00:00:00 2001 From: Janis Klaise Date: Wed, 15 Sep 2021 16:41:57 +0100 Subject: [PATCH] Notebook fixes (#333) * Fix predict_batch import errors * Import os to avoid NameError * Update ipython kernel name to default * Fix import path to scale_by_instance * Clarify the installation of wilds library before executing the first cell * Update import path to avoid ModuleNotFoundError * Change flag DOWNLOAD=True to ensure the notebook can run, update attributes * Fix typo, update attribues * Remove unused predict_batch import * Update predict_batch calls --- examples/ad_ae_cifar10.ipynb | 40 +++++++---------- examples/cd_distillation_cifar10.ipynb | 20 +++------ examples/cd_ks_cifar10.ipynb | 17 +++++-- examples/cd_mol.ipynb | 52 ++-------------------- examples/cd_online_camelyon.ipynb | 12 +++-- examples/cd_spot_the_diff_mnist_wine.ipynb | 6 +-- examples/cd_text_amazon.ipynb | 6 +-- examples/od_aegmm_kddcup.ipynb | 4 +- examples/od_llr_genome.ipynb | 3 +- examples/od_llr_mnist.ipynb | 20 ++++----- examples/od_seq2seq_synth.ipynb | 6 +-- examples/od_vae_cifar10.ipynb | 3 +- examples/od_vae_kddcup.ipynb | 13 ++---- 13 files changed, 74 insertions(+), 128 deletions(-) diff --git a/examples/ad_ae_cifar10.ipynb b/examples/ad_ae_cifar10.ipynb index 9c23d810d..22f014662 100644 --- a/examples/ad_ae_cifar10.ipynb +++ b/examples/ad_ae_cifar10.ipynb @@ -51,15 +51,7 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "ERROR:fbprophet:Importing plotly failed. Interactive plots will not work.\n" - ] - } - ], + "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", @@ -72,7 +64,7 @@ "\n", "from alibi_detect.ad import AdversarialAE\n", "from alibi_detect.utils.fetching import fetch_detector, fetch_tf_model\n", - "from alibi_detect.utils.prediction import predict_batch\n", + "from alibi_detect.utils.tensorflow.prediction import predict_batch\n", "from alibi_detect.utils.saving import save_detector, load_detector\n", "from alibi_detect.datasets import fetch_attack, fetch_cifar10c, corruption_types_cifar10c" ] @@ -293,7 +285,7 @@ } ], "source": [ - "y_pred = predict_batch(clf, X_test, batch_size=32, return_class=True)\n", + "y_pred = predict_batch(X_test, clf, batch_size=32).argmax(axis=1)\n", "acc_y_pred = accuracy(y_test, y_pred)\n", "print('Accuracy: {:.4f}'.format(acc_y_pred))" ] @@ -353,8 +345,8 @@ "metadata": {}, "outputs": [], "source": [ - "y_pred_cw = predict_batch(clf, X_test_cw, batch_size=32, return_class=True)\n", - "y_pred_slide = predict_batch(clf, X_test_slide, batch_size=32, return_class=True)" + "y_pred_cw = predict_batch(X_test_cw, clf, batch_size=32).argmax(axis=1)\n", + "y_pred_slide = predict_batch(X_test_slide, clf, batch_size=32).argmax(axis=1)" ] }, { @@ -554,8 +546,8 @@ "metadata": {}, "outputs": [], "source": [ - "X_recon_cw = predict_batch(ad.ae, X_test_cw, batch_size=32)\n", - "X_recon_slide = predict_batch(ad.ae, X_test_slide, batch_size=32)" + "X_recon_cw = predict_batch(X_test_cw, ad.ae, batch_size=32)\n", + "X_recon_slide = predict_batch(X_test_slide, ad.ae, batch_size=32)" ] }, { @@ -564,8 +556,8 @@ "metadata": {}, "outputs": [], "source": [ - "y_recon_cw = predict_batch(clf, X_recon_cw, batch_size=32, return_class=True)\n", - "y_recon_slide = predict_batch(clf, X_recon_slide, batch_size=32, return_class=True)" + "y_recon_cw = predict_batch(X_recon_cw, clf, batch_size=32).argmax(axis=1)\n", + "y_recon_slide = predict_batch(X_recon_slide, clf, batch_size=32).argmax(axis=1)" ] }, { @@ -840,7 +832,7 @@ } ], "source": [ - "y_pred_mix = predict_batch(clf, X_mix, batch_size=32, return_class=True)\n", + "y_pred_mix = predict_batch(X_mix, clf, batch_size=32).argmax(axis=1)\n", "acc_y_pred_mix = accuracy(y_mix, y_pred_mix)\n", "print('Accuracy {:.4f}'.format(acc_y_pred_mix))" ] @@ -989,12 +981,12 @@ ], "source": [ "# reconstructed adversarial instances\n", - "X_recon_cw_t = predict_batch(ad_t.ae, X_test_cw, batch_size=32)\n", - "X_recon_slide_t = predict_batch(ad_t.ae, X_test_slide, batch_size=32)\n", + "X_recon_cw_t = predict_batch(X_test_cw, ad_t.ae, batch_size=32)\n", + "X_recon_slide_t = predict_batch(X_test_slide, ad_t.ae, batch_size=32)\n", "\n", "# make predictions on reconstructed instances and compute accuracy\n", - "y_recon_cw_t = predict_batch(clf, X_recon_cw_t, batch_size=32, return_class=True)\n", - "y_recon_slide_t = predict_batch(clf, X_recon_slide_t, batch_size=32, return_class=True)\n", + "y_recon_cw_t = predict_batch(X_recon_cw_t, clf, batch_size=32).argmax(axis=1)\n", + "y_recon_slide_t = predict_batch(X_recon_slide_t, clf, batch_size=32).argmax(axis=1)\n", "acc_y_recon_cw_t = accuracy(y_test, y_recon_cw_t)\n", "acc_y_recon_slide_t = accuracy(y_test, y_recon_slide_t)\n", "print('Accuracy after C&W attack {:.4f} -- reconstruction {:.4f}'.format(acc_y_pred_cw, acc_y_recon_cw_t))\n", @@ -1284,7 +1276,7 @@ " X_corr, mean_test, std_test = scale_by_instance(X_corr)\n", " \n", " print('Make predictions on corrupted dataset...')\n", - " y_pred_corr = predict_batch(clf, X_corr, batch_size=32, return_class=True)\n", + " y_pred_corr = predict_batch(X_corr, clf, batch_size=32).argmax(axis=1)\n", " \n", " print('Compute adversarial scores on corrupted dataset...')\n", " score_corr = ad_t.score(X_corr, batch_size=32)\n", @@ -1401,7 +1393,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/examples/cd_distillation_cifar10.ipynb b/examples/cd_distillation_cifar10.ipynb index 46a895743..5bb31e8f9 100644 --- a/examples/cd_distillation_cifar10.ipynb +++ b/examples/cd_distillation_cifar10.ipynb @@ -30,17 +30,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Importing plotly failed. Interactive plots will not work.\n" - ] - } - ], + "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", @@ -52,7 +44,7 @@ "\n", "from alibi_detect.models.tensorflow.resnet import scale_by_instance\n", "from alibi_detect.utils.fetching import fetch_tf_model, fetch_detector\n", - "from alibi_detect.utils.prediction import predict_batch\n", + "from alibi_detect.utils.tensorflow.prediction import predict_batch\n", "from alibi_detect.utils.saving import save_detector\n", "from alibi_detect.datasets import fetch_cifar10c, corruption_types_cifar10c" ] @@ -500,7 +492,7 @@ " 4: {'all': [], 'harm': [], 'noharm': [], 'acc': 0},\n", " 5: {'all': [], 'harm': [], 'noharm': [], 'acc': 0},\n", "}\n", - "y_pred = predict_batch(clf, X_test, batch_size=256, return_class=True)\n", + "y_pred = predict_batch(X_test, clf, batch_size=256).argmax(axis=1)\n", "score_x = ad.score(X_test, batch_size=256)\n", "\n", "for s in severities:\n", @@ -511,7 +503,7 @@ " X_corr = scale_by_instance(X_corr)\n", " \n", " print('Make predictions on corrupted dataset...')\n", - " y_pred_corr = predict_batch(clf, X_corr, batch_size=1000, return_class=True)\n", + " y_pred_corr = predict_batch(X_corr, clf, batch_size=1000).argmax(axis=1)\n", " \n", " print('Compute adversarial scores on corrupted dataset...')\n", " score_corr = ad.score(X_corr, batch_size=256)\n", @@ -780,7 +772,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/examples/cd_ks_cifar10.ipynb b/examples/cd_ks_cifar10.ipynb index 94a67b1ed..a5f606c25 100644 --- a/examples/cd_ks_cifar10.ipynb +++ b/examples/cd_ks_cifar10.ipynb @@ -36,7 +36,7 @@ "import tensorflow as tf\n", "\n", "from alibi_detect.cd import KSDrift\n", - "from alibi_detect.models.resnet import scale_by_instance\n", + "from alibi_detect.models.tensorflow.resnet import scale_by_instance\n", "from alibi_detect.utils.fetching import fetch_tf_model, fetch_detector\n", "from alibi_detect.utils.saving import save_detector, load_detector\n", "from alibi_detect.datasets import fetch_cifar10c, corruption_types_cifar10c" @@ -1327,14 +1327,23 @@ "hash": "ffba93b5284319fb7a107c8eacae647f441487dcc7e0323a4c0d3feb66ea8c5e" }, "kernelspec": { - "display_name": "Python 3.8.5 64-bit ('.venv': venv)", + "display_name": "Python 3", + "language": "python", "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - "version": "" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +} diff --git a/examples/cd_mol.ipynb b/examples/cd_mol.ipynb index d3fad5672..9cd6db769 100644 --- a/examples/cd_mol.ipynb +++ b/examples/cd_mol.ipynb @@ -2,7 +2,6 @@ "cells": [ { "cell_type": "markdown", - "id": "8c5da571", "metadata": {}, "source": [ "# Drift detection on molecular graphs\n", @@ -38,7 +37,6 @@ { "cell_type": "code", "execution_count": 1, - "id": "d2b3d26f", "metadata": {}, "outputs": [], "source": [ @@ -56,7 +54,6 @@ }, { "cell_type": "markdown", - "id": "d8affe38", "metadata": {}, "source": [ "## Load and analyze data" @@ -65,7 +62,6 @@ { "cell_type": "code", "execution_count": 2, - "id": "21ff6e7c", "metadata": {}, "outputs": [], "source": [ @@ -76,7 +72,6 @@ { "cell_type": "code", "execution_count": 3, - "id": "ba205de0", "metadata": { "scrolled": true }, @@ -90,7 +85,6 @@ }, { "cell_type": "markdown", - "id": "73ea2395", "metadata": {}, "source": [ "We set some samples apart to serve as the reference data for our drift detectors. Note that the allowed format of the reference data is very flexible and can be `np.ndarray` or `List[Any]`:" @@ -99,7 +93,6 @@ { "cell_type": "code", "execution_count": 4, - "id": "1a5c0793", "metadata": {}, "outputs": [ { @@ -128,7 +121,6 @@ { "cell_type": "code", "execution_count": 5, - "id": "9f171e20", "metadata": {}, "outputs": [ { @@ -149,7 +141,6 @@ { "cell_type": "code", "execution_count": 6, - "id": "72702782", "metadata": {}, "outputs": [ { @@ -201,7 +192,6 @@ }, { "cell_type": "markdown", - "id": "d40f62a0", "metadata": {}, "source": [ "Let's plot some graph summary statistics such as the distribution of the node degrees, number of nodes and edges as well as the clustering coefficients:" @@ -210,7 +200,6 @@ { "cell_type": "code", "execution_count": 7, - "id": "334f9994", "metadata": { "scrolled": true }, @@ -261,7 +250,6 @@ { "cell_type": "code", "execution_count": 8, - "id": "cf5be599", "metadata": {}, "outputs": [ { @@ -283,7 +271,6 @@ { "cell_type": "code", "execution_count": 9, - "id": "9f15a27a", "metadata": {}, "outputs": [ { @@ -336,7 +323,6 @@ { "cell_type": "code", "execution_count": 10, - "id": "19ca6a44", "metadata": {}, "outputs": [ { @@ -397,7 +383,6 @@ }, { "cell_type": "markdown", - "id": "5669c433", "metadata": {}, "source": [ "While the average number of nodes and edges are similar across the splits, the histograms show that the tails are slightly heavier for the training graphs." @@ -405,7 +390,6 @@ }, { "cell_type": "markdown", - "id": "280f094b", "metadata": {}, "source": [ "## Plot molecules\n", @@ -416,7 +400,6 @@ { "cell_type": "code", "execution_count": 11, - "id": "c13691e8", "metadata": {}, "outputs": [], "source": [ @@ -458,7 +441,6 @@ { "cell_type": "code", "execution_count": 12, - "id": "1337f924", "metadata": {}, "outputs": [ { @@ -481,7 +463,6 @@ }, { "cell_type": "markdown", - "id": "49dd1e14", "metadata": {}, "source": [ "## Train and evaluate a GNN classification model\n", @@ -492,7 +473,6 @@ { "cell_type": "code", "execution_count": 13, - "id": "caa2ae1e", "metadata": {}, "outputs": [], "source": [ @@ -571,7 +551,6 @@ { "cell_type": "code", "execution_count": 14, - "id": "4a34b43e", "metadata": { "scrolled": true }, @@ -602,7 +581,6 @@ }, { "cell_type": "markdown", - "id": "2783ba1f", "metadata": {}, "source": [ "Train and evaluate the model. Evaluation is done using [ROC-AUC](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html). If you already have a trained model saved, you can directly load it by specifying the `load_path`:" @@ -611,7 +589,6 @@ { "cell_type": "code", "execution_count": 15, - "id": "1b828be6", "metadata": {}, "outputs": [], "source": [ @@ -621,7 +598,6 @@ { "cell_type": "code", "execution_count": 16, - "id": "86cdaf3d", "metadata": { "scrolled": true }, @@ -705,7 +681,6 @@ }, { "cell_type": "markdown", - "id": "a9338346", "metadata": {}, "source": [ "## Detect drift\n", @@ -718,7 +693,6 @@ { "cell_type": "code", "execution_count": 17, - "id": "c7a42195", "metadata": {}, "outputs": [], "source": [ @@ -747,7 +721,6 @@ }, { "cell_type": "markdown", - "id": "653a4878", "metadata": {}, "source": [ "Because we pass lists with `torch_geometric.data.Data` objects to the detector, we need to preprocess the data using the `batch_fn` into `torch_geometric.data.Batch` objects which can be fed to the model. Then we detect drift on the model prediction distribution." @@ -756,7 +729,6 @@ { "cell_type": "code", "execution_count": 18, - "id": "82af4b5f", "metadata": {}, "outputs": [ { @@ -786,7 +758,6 @@ }, { "cell_type": "markdown", - "id": "23035daf", "metadata": {}, "source": [ "Since the dataset is heavily imbalanced, we will test the detectors on a sample which oversamples from the minority class (molecules which inhibit HIV virus replication):" @@ -795,7 +766,6 @@ { "cell_type": "code", "execution_count": 19, - "id": "9dfac4c7", "metadata": {}, "outputs": [ { @@ -820,7 +790,6 @@ { "cell_type": "code", "execution_count": 20, - "id": "fb709d47", "metadata": {}, "outputs": [ { @@ -849,7 +818,6 @@ }, { "cell_type": "markdown", - "id": "8979bdfa", "metadata": {}, "source": [ "As expected, prediction distribution shift is detected for the imbalanced sample but not for the random test sample with similar label distribution as the reference data.\n", @@ -862,7 +830,6 @@ { "cell_type": "code", "execution_count": 21, - "id": "6302f0f4", "metadata": { "scrolled": false }, @@ -885,7 +852,6 @@ { "cell_type": "code", "execution_count": 22, - "id": "034c3581", "metadata": {}, "outputs": [ { @@ -913,7 +879,6 @@ }, { "cell_type": "markdown", - "id": "515be9b4", "metadata": {}, "source": [ "Although we didn't pick up drift in the GIN model prediction distribution for the test sample, we can see that the model is less certain about the predictions on the test set, illustrated by the lower ROC-AUC.\n", @@ -926,7 +891,6 @@ { "cell_type": "code", "execution_count": 23, - "id": "50adc300", "metadata": {}, "outputs": [], "source": [ @@ -965,7 +929,6 @@ { "cell_type": "code", "execution_count": 24, - "id": "ebad8153", "metadata": {}, "outputs": [ { @@ -987,7 +950,6 @@ { "cell_type": "code", "execution_count": 25, - "id": "27be77c0", "metadata": {}, "outputs": [ { @@ -1015,7 +977,6 @@ }, { "cell_type": "markdown", - "id": "45b48bbb", "metadata": {}, "source": [ "### Input data drift using a learned kernel\n", @@ -1026,7 +987,6 @@ { "cell_type": "code", "execution_count": 26, - "id": "5809031b", "metadata": {}, "outputs": [], "source": [ @@ -1041,7 +1001,6 @@ { "cell_type": "code", "execution_count": 27, - "id": "8c53917d", "metadata": {}, "outputs": [ { @@ -1069,7 +1028,6 @@ }, { "cell_type": "markdown", - "id": "50fd4b05", "metadata": {}, "source": [ "Since the molecular scaffolds are different across the train, validation and test sets, we expect that this type of data shift is picked up in the input data (technically not the input but the graph embedding).\n", @@ -1082,7 +1040,6 @@ { "cell_type": "code", "execution_count": 28, - "id": "0e6ff107", "metadata": {}, "outputs": [], "source": [ @@ -1097,7 +1054,6 @@ { "cell_type": "code", "execution_count": 29, - "id": "f47ee016", "metadata": {}, "outputs": [ { @@ -1115,7 +1071,6 @@ { "cell_type": "code", "execution_count": 30, - "id": "954e3904", "metadata": {}, "outputs": [ { @@ -1143,7 +1098,6 @@ }, { "cell_type": "markdown", - "id": "fd8f4643", "metadata": {}, "source": [ "The 3 returned p-values correspond to respectively the p-values for the number of nodes, edges and clustering coefficient. We already saw in the EDA that the distributions of the node, edge and clustering coefficients look similar across the train, validation and test sets except for the tails. This is confirmed by running the drift detector on the graph statistics which cannot seem to pick up on the differences in molecular scaffolds between the datasets, unless we heavily oversample from the minority class where the number of nodes and edges but not the clustering coefficient significantly differ." @@ -1152,9 +1106,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python [conda env:detect]", + "display_name": "Python 3", "language": "python", - "name": "conda-env-detect-py" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -1166,7 +1120,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/examples/cd_online_camelyon.ipynb b/examples/cd_online_camelyon.ipynb index f06a224c5..7aebbe90c 100644 --- a/examples/cd_online_camelyon.ipynb +++ b/examples/cd_online_camelyon.ipynb @@ -11,7 +11,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This notebook demonstrates a typical workflow for applying online drift detectors to streams of image data. For those unfamiliar with how the online drift detectors operate in `alibi_detect` we recommend first checking out the more introductory example [Online Drift Detection on the Wine Quality Dataset](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_online_wine.html) where online drift detection is performed for the wine quality dataset." + "This notebook demonstrates a typical workflow for applying online drift detectors to streams of image data. For those unfamiliar with how the online drift detectors operate in `alibi_detect` we recommend first checking out the more introductory example [Online Drift Detection on the Wine Quality Dataset](https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_online_wine.html) where online drift detection is performed for the wine quality dataset.\n", + "\n", + "Install the `wilds` library to fetch the dataset used in the example:\n", + "\n", + "```bash\n", + "pip install wilds\n", + "```" ] }, { @@ -529,7 +535,7 @@ "hash": "26d4efd8bf86ae199e0cff801fa58ff781ca69d267a2f4141eff4295422fc53d" }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -543,7 +549,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.11" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/examples/cd_spot_the_diff_mnist_wine.ipynb b/examples/cd_spot_the_diff_mnist_wine.ipynb index ebec41ee5..cf514439c 100644 --- a/examples/cd_spot_the_diff_mnist_wine.ipynb +++ b/examples/cd_spot_the_diff_mnist_wine.ipynb @@ -51,13 +51,13 @@ "outputs": [], "source": [ "MNIST_PATH = 'my_path'\n", - "DOWNLOAD = False\n", + "DOWNLOAD = True\n", "MISSING_NUMBER = 0\n", "N = 10000\n", "\n", "# Load and shuffle data\n", "mnist_train_ds = torchvision.datasets.MNIST(MNIST_PATH, train=True, download=DOWNLOAD)\n", - "all_x, all_y = mnist_train_ds.train_data, mnist_train_ds.train_labels\n", + "all_x, all_y = mnist_train_ds.data, mnist_train_ds.targets\n", "perm = np.random.permutation(len(all_x))\n", "all_x, all_y = all_x[perm], all_y[perm]\n", "all_x = all_x[:, None, : , :].numpy().astype(np.float32)/255.\n", @@ -804,7 +804,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.10" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/examples/cd_text_amazon.ipynb b/examples/cd_text_amazon.ipynb index 74821f646..fc9f3e0e6 100644 --- a/examples/cd_text_amazon.ipynb +++ b/examples/cd_text_amazon.ipynb @@ -531,9 +531,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python [conda env:detect] *", + "display_name": "Python 3", "language": "python", - "name": "conda-env-detect-py" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -545,7 +545,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/examples/od_aegmm_kddcup.ipynb b/examples/od_aegmm_kddcup.ipynb index be0d5680b..5149038fc 100644 --- a/examples/od_aegmm_kddcup.ipynb +++ b/examples/od_aegmm_kddcup.ipynb @@ -48,7 +48,7 @@ "from tensorflow.keras.layers import Dense, InputLayer\n", "\n", "from alibi_detect.datasets import fetch_kdd\n", - "from alibi_detect.models.autoencoder import eucl_cosim_features\n", + "from alibi_detect.models.tensorflow.autoencoder import eucl_cosim_features\n", "from alibi_detect.od import OutlierAEGMM, OutlierVAEGMM\n", "from alibi_detect.utils.data import create_outlier_batch\n", "from alibi_detect.utils.fetching import fetch_detector\n", @@ -820,7 +820,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/examples/od_llr_genome.ipynb b/examples/od_llr_genome.ipynb index e8165b4dc..0c5804ec6 100644 --- a/examples/od_llr_genome.ipynb +++ b/examples/od_llr_genome.ipynb @@ -39,7 +39,6 @@ "from alibi_detect.datasets import fetch_genome\n", "from alibi_detect.utils.fetching import fetch_detector\n", "from alibi_detect.utils.saving import save_detector, load_detector\n", - "from alibi_detect.utils.prediction import predict_batch\n", "from alibi_detect.utils.visualize import plot_roc" ] }, @@ -562,7 +561,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/examples/od_llr_mnist.ipynb b/examples/od_llr_mnist.ipynb index 29a027e26..77b701027 100644 --- a/examples/od_llr_mnist.ipynb +++ b/examples/od_llr_mnist.ipynb @@ -40,7 +40,7 @@ "from alibi_detect.models.tensorflow import PixelCNN\n", "from alibi_detect.utils.fetching import fetch_detector\n", "from alibi_detect.utils.saving import save_detector, load_detector\n", - "from alibi_detect.utils.prediction import predict_batch\n", + "from alibi_detect.utils.tensorflow.prediction import predict_batch\n", "from alibi_detect.utils.visualize import plot_roc" ] }, @@ -425,12 +425,12 @@ "outputs": [], "source": [ "# semantic model\n", - "logp_s_in = predict_batch(od.dist_s.log_prob, X_test_in, batch_size=32, shape=shape_in)\n", - "logp_s_ood = predict_batch(od.dist_s.log_prob, X_test_ood, batch_size=32, shape=shape_ood)\n", + "logp_s_in = predict_batch(X_test_in, od.dist_s.log_prob, batch_size=32, shape=shape_in)\n", + "logp_s_ood = predict_batch(X_test_ood, od.dist_s.log_prob, batch_size=32, shape=shape_ood)\n", "logp_s = np.concatenate([logp_s_in, logp_s_ood])\n", "# background model\n", - "logp_b_in = predict_batch(od.dist_b.log_prob, X_test_in, batch_size=32, shape=shape_in)\n", - "logp_b_ood = predict_batch(od.dist_b.log_prob, X_test_ood, batch_size=32, shape=shape_ood)" + "logp_b_in = predict_batch(X_test_in, od.dist_b.log_prob, batch_size=32, shape=shape_in)\n", + "logp_b_ood = predict_batch(X_test_ood, od.dist_b.log_prob, batch_size=32, shape=shape_ood)" ] }, { @@ -727,13 +727,13 @@ "source": [ "# semantic model\n", "logp_fn_s = partial(od.dist_s.log_prob, return_per_feature=True)\n", - "logp_s_pixel_in = predict_batch(logp_fn_s, X_test_in[:n_plot], batch_size=32)\n", - "logp_s_pixel_ood = predict_batch(logp_fn_s, X_test_ood[:n_plot], batch_size=32)\n", + "logp_s_pixel_in = predict_batch(X_test_in[:n_plot], logp_fn_s, batch_size=32)\n", + "logp_s_pixel_ood = predict_batch(X_test_ood[:n_plot], logp_fn_s, batch_size=32)\n", "\n", "# background model\n", "logp_fn_b = partial(od.dist_b.log_prob, return_per_feature=True)\n", - "logp_b_pixel_in = predict_batch(logp_fn_b, X_test_in[:n_plot], batch_size=32)\n", - "logp_b_pixel_ood = predict_batch(logp_fn_b, X_test_ood[:n_plot], batch_size=32)\n", + "logp_b_pixel_in = predict_batch(X_test_in[:n_plot], logp_fn_b, batch_size=32)\n", + "logp_b_pixel_ood = predict_batch(X_test_ood[:n_plot], logp_fn_b, batch_size=32)\n", "\n", "# pixel-wise likelihood ratios\n", "llr_pixel_in = logp_s_pixel_in - logp_b_pixel_in\n", @@ -817,7 +817,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/examples/od_seq2seq_synth.ipynb b/examples/od_seq2seq_synth.ipynb index 4852abed0..d61bc7286 100644 --- a/examples/od_seq2seq_synth.ipynb +++ b/examples/od_seq2seq_synth.ipynb @@ -513,9 +513,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python [conda env:detect] *", + "display_name": "Python 3", "language": "python", - "name": "conda-env-detect-py" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -527,7 +527,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/examples/od_vae_cifar10.ipynb b/examples/od_vae_cifar10.ipynb index 78f8be91c..3cd2576b1 100644 --- a/examples/od_vae_cifar10.ipynb +++ b/examples/od_vae_cifar10.ipynb @@ -21,6 +21,7 @@ "metadata": {}, "outputs": [], "source": [ + "import os\n", "import logging\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", @@ -706,7 +707,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/examples/od_vae_kddcup.ipynb b/examples/od_vae_kddcup.ipynb index 2b79dbeff..62a147214 100644 --- a/examples/od_vae_kddcup.ipynb +++ b/examples/od_vae_kddcup.ipynb @@ -34,16 +34,9 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "ERROR:fbprophet:Importing plotly failed. Interactive plots will not work.\n" - ] - } - ], + "outputs": [], "source": [ + "import os\n", "import logging\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", @@ -535,7 +528,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.5" } }, "nbformat": 4,