From 13762fcedaf89d165284c1fbec62f6cddc25ba8f Mon Sep 17 00:00:00 2001 From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com> Date: Sun, 29 Oct 2023 17:49:45 +0000 Subject: [PATCH 1/4] Tutorials on `ReLax` as a Recourse Library --- nbs/tutorials/methods.ipynb | 50 ++++++++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/nbs/tutorials/methods.ipynb b/nbs/tutorials/methods.ipynb index a59149c..8090ffc 100644 --- a/nbs/tutorials/methods.ipynb +++ b/nbs/tutorials/methods.ipynb @@ -11,10 +11,36 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "`ReLax` is a recourse explanation library which provides implementations of various recourse methods.\n", - "In other words, you can use implemented methods in `ReLax` without relying on the entire pipeline of `ReLax`.\n", + "`ReLax` contains implementations of various recourse methods, which are decoupled from the rest of `ReLax` library.\n", + "We give users flexibility on how to use `ReLax`: \n", "\n", - "At a high level, you can use the implemented methods in `ReLax` to generate a recourse explanation via three lines of code:\n", + "* You can use the recourse pipeline in `ReLax` (\"one-liner\" for easy benchmarking recourse methods; see this [tutorial](getting_started.ipynb)).\n", + "* You can use all of the recourse methods in `ReLax` without relying on the entire pipeline of `ReLax`.\n", + "\n", + "In this tutorial, we uncover the possibility of the second option by using recourse methods under `relax.methods` \n", + "for debugging, diagnoising, interpreting your JAX models.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Types of Recourse Methods\n", + "\n", + "TODO: Describe the difference between non-parametric, semi-parametric, and parametric methods. \n", + "What it means conceptually (text and formula), and what it means in terms of code \n", + "(e.g., parametric methods inherites `ParametricCFModule`). \n", + "\n", + "TODO: Include a table to describe the difference." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basis \n", + "\n", + "At a high level, you can use the implemented methods in `ReLax` to generate *one* recourse explanation via three lines of code:\n", "\n", "```python\n", "from relax.methods import VanillaCF\n", @@ -30,10 +56,22 @@ "...\n", "import functools as ft\n", "\n", - "generate_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)\n", + "vcf_gen_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)\n", "# xs is a batched data. Shape: `(N, K)`\n", - "cfs = jax.vmap(generate_fn)(xs)\n", - "```" + "cfs = jax.vmap(vcf_gen_fn)(xs)\n", + "```\n", + "\n", + "TODO: Also show examples of using parametric methods (using `CCHVAE` as an example)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Config Recourse Methods\n", + "\n", + "TODO: Refer to [this link](https://birkhoffg.github.io/jax-relax/tutorials/getting_started.html#load-dataset-with-datamodule)\n", + "on how config works in `ReLax`. It is similar on how to config recourse methods here." ] } ], From a2c426bdfcef44a82d4c0ccae84c51d998a0cdac Mon Sep 17 00:00:00 2001 From: Praneyg Date: Wed, 8 Nov 2023 15:08:42 -0500 Subject: [PATCH 2/4] explained types of recourse methods --- nbs/tutorials/methods.ipynb | 116 +++++++++++++++++++++++++++++++++--- 1 file changed, 107 insertions(+), 9 deletions(-) diff --git a/nbs/tutorials/methods.ipynb b/nbs/tutorials/methods.ipynb index 8090ffc..bbc55b8 100644 --- a/nbs/tutorials/methods.ipynb +++ b/nbs/tutorials/methods.ipynb @@ -18,7 +18,7 @@ "* You can use all of the recourse methods in `ReLax` without relying on the entire pipeline of `ReLax`.\n", "\n", "In this tutorial, we uncover the possibility of the second option by using recourse methods under `relax.methods` \n", - "for debugging, diagnoising, interpreting your JAX models.\n" + "for debugging, diagnosing, interpreting your JAX models.\n" ] }, { @@ -27,11 +27,18 @@ "source": [ "## Types of Recourse Methods\n", "\n", - "TODO: Describe the difference between non-parametric, semi-parametric, and parametric methods. \n", - "What it means conceptually (text and formula), and what it means in terms of code \n", - "(e.g., parametric methods inherites `ParametricCFModule`). \n", + "1. Non-parametric methods: These methods do not rely on any learned parameters. They generate counterfactuals solely based on the model's predictions and gradients. Examples in ReLax include VanillaCF and GrowingSpheres. These methods inherit from NonParametricCFModule.\n", "\n", - "TODO: Include a table to describe the difference." + "2. Semi-parametric methods: These methods learn some parameters to aid in counterfactual generation, but do not learn a full counterfactual generation model. Examples in ReLax include DiverseCF and ProtoCF. These methods inherit from SemiParametricCFModule.\n", + "\n", + "3. Parametric methods: These methods learn a full parametric model for counterfactual generation. The model is trained to generate counterfactuals that fool the model. Examples in ReLax include CounterNet, CCHVAE, VAECF and CLUE. These methods inherit from ParametricCFModule.\n", + "\n", + "\n", + "|Method Type | Learned Parameters | Training Required | Example Methods | \n", + "|-----|:-----|:---:|:-----:|\n", + "|Non-parametric |None |No |VanillaCF, GrowingSpheres |\n", + "|Semi-parametric|Some (θ)|Modest amount |DiverseCF, ProtoCF |\n", + "|Parametric|Full generator model (φ)|Substantial amount|CounterNet, CCHVAE, VAECF, CLUE|" ] }, { @@ -59,9 +66,39 @@ "vcf_gen_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)\n", "# xs is a batched data. Shape: `(N, K)`\n", "cfs = jax.vmap(vcf_gen_fn)(xs)\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Example of using ReLax for parametric methods (using CCHVAE)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```python\n", + "from relax.methods import CCHVAE\n", + "\n", + "cchvae = CCHVAE()\n", + "# x is one data point. Shape: `(K)` or `(1, K)`\n", + "cf = vcf.generate_cf(x, pred_fn=pred_fn)\n", "```\n", "\n", - "TODO: Also show examples of using parametric methods (using `CCHVAE` as an example)." + "Or generate a batch of recourse explanation via the `jax.vmap` primitive:\n", + "\n", + "```python\n", + "...\n", + "import functools as ft\n", + "\n", + "cchvae_gen_fn = ft.partial(cchvae.generate_cf, pred_fn=pred_fn)\n", + "cfs = jax.vmap(cchvae_gen_fn)(xs) # Generate counterfactuals\n", + "\n", + "```" ] }, { @@ -70,8 +107,69 @@ "source": [ "## Config Recourse Methods\n", "\n", - "TODO: Refer to [this link](https://birkhoffg.github.io/jax-relax/tutorials/getting_started.html#load-dataset-with-datamodule)\n", - "on how config works in `ReLax`. It is similar on how to config recourse methods here." + "Each recourse method in ReLax has an associated Config class that defines the set of supported configuration parameters. To configure a method, import and instantiate its Config class and pass it as the config parameter.\n", + "\n", + "For example, to configure VanillaCF:\n", + "\n", + "```Python\n", + "from relax.methods import VanillaCF \n", + "from relax.methods.vanilla import VanillaCFConfig\n", + "\n", + "config = VanillaCFConfig(\n", + " n_steps=100,\n", + " lr=0.1,\n", + " lambda_=0.1\n", + ")\n", + "\n", + "vcf = VanillaCF(config)\n", + "\n", + "```\n", + "Each Config class inherits from a BaseConfig that defines common options like num_cfs. Method-specific parameters are defined on the individual Config classes.\n", + "\n", + "See the documentation for each recourse method for details on its supported configuration parameters. The Config class for a method can be imported from relax.methods.[method_name]." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Alternatively, we can also specify this config via a dictionary.\n", + "\n", + "```Python\n", + "from relax.methods import VanillaCF\n", + "\n", + "config = {\n", + " \"num_cfs\": 10, \n", + " \"epsilon\": 0.01,\n", + " \"lr\": 0.1 \n", + "}\n", + "\n", + "vcf = VanillaCF(config)\n", + "```\n", + "\n", + "This config dictionary is passed to VanillaCF's __init__ method, which will set the specified parameters. Now our VanillaCF instance is configured to:\n", + "\n", + " * Generate 10 counterfactuals per input (num_cfs=10)\n", + " * Use a maximum perturbation of 0.01 (epsilon=0.01)\n", + " * Use a learning rate of 0.1 for optimization (lr=0.1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Modifying at Runtime:\n", + "\n", + "The configuration can also be updated after constructing the recourse method:\n", + "\n", + "```Python\n", + "vcf = VanillaCF() \n", + "\n", + "# Later, modify config\n", + "vcf.config[\"lr\"] = 0.5\n", + "\n", + "```\n", + "This allows dynamically adjusting the configuration as needed." ] } ], @@ -83,5 +181,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } From c1523606404e19b3402bd5e4752b6ae257f07604 Mon Sep 17 00:00:00 2001 From: Praneyg Date: Wed, 8 Nov 2023 17:39:03 -0500 Subject: [PATCH 3/4] resolved errors --- nbs/tutorials/methods.ipynb | 51 ++++++++++++------------------------- 1 file changed, 16 insertions(+), 35 deletions(-) diff --git a/nbs/tutorials/methods.ipynb b/nbs/tutorials/methods.ipynb index bbc55b8..b5e821c 100644 --- a/nbs/tutorials/methods.ipynb +++ b/nbs/tutorials/methods.ipynb @@ -27,18 +27,18 @@ "source": [ "## Types of Recourse Methods\n", "\n", - "1. Non-parametric methods: These methods do not rely on any learned parameters. They generate counterfactuals solely based on the model's predictions and gradients. Examples in ReLax include VanillaCF and GrowingSpheres. These methods inherit from NonParametricCFModule.\n", + "1. Non-parametric methods: These methods do not rely on any learned parameters. They generate counterfactuals solely based on the model's predictions and gradients. Examples in ReLax include `VanillaCF`, `DiverseCF` and `GrowingSpheres` . These methods inherit from NonParametricCFModule.\n", "\n", - "2. Semi-parametric methods: These methods learn some parameters to aid in counterfactual generation, but do not learn a full counterfactual generation model. Examples in ReLax include DiverseCF and ProtoCF. These methods inherit from SemiParametricCFModule.\n", + "2. Semi-parametric methods: These methods learn some parameters to aid in counterfactual generation, but do not learn a full counterfactual generation model. Examples in ReLax include `ProtoCF`, `CCHVAE` and `CLUE`. These methods inherit from SemiParametricCFModule.\n", "\n", - "3. Parametric methods: These methods learn a full parametric model for counterfactual generation. The model is trained to generate counterfactuals that fool the model. Examples in ReLax include CounterNet, CCHVAE, VAECF and CLUE. These methods inherit from ParametricCFModule.\n", + "3. Parametric methods: These methods learn a full parametric model for counterfactual generation. The model is trained to generate counterfactuals that fool the model. Examples in ReLax include `CounterNet` and `VAECF`. These methods inherit from ParametricCFModule.\n", "\n", "\n", "|Method Type | Learned Parameters | Training Required | Example Methods | \n", "|-----|:-----|:---:|:-----:|\n", - "|Non-parametric |None |No |VanillaCF, GrowingSpheres |\n", - "|Semi-parametric|Some (θ)|Modest amount |DiverseCF, ProtoCF |\n", - "|Parametric|Full generator model (φ)|Substantial amount|CounterNet, CCHVAE, VAECF, CLUE|" + "|Non-parametric | None |No |`VanillaCF`, `DiverseCF`, `GrowingSpheres` |\n", + "|Semi-parametric| Some (θ) |Modest amount |`ProtoCF`, `CCHVAE`, `CLUE` |\n", + "|Parametric|Full generator model (φ)|Substantial amount|`CounterNet`, `VAECF` |" ] }, { @@ -73,7 +73,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Example of using ReLax for parametric methods (using CCHVAE)\n", + "Example of using `ReLax` for parametric methods (using `CCHVAE`)\n", "\n" ] }, @@ -85,8 +85,7 @@ "from relax.methods import CCHVAE\n", "\n", "cchvae = CCHVAE()\n", - "# x is one data point. Shape: `(K)` or `(1, K)`\n", - "cf = vcf.generate_cf(x, pred_fn=pred_fn)\n", + "cchvae.train(train_data) # Train counterfactual generator \n", "```\n", "\n", "Or generate a batch of recourse explanation via the `jax.vmap` primitive:\n", @@ -107,9 +106,9 @@ "source": [ "## Config Recourse Methods\n", "\n", - "Each recourse method in ReLax has an associated Config class that defines the set of supported configuration parameters. To configure a method, import and instantiate its Config class and pass it as the config parameter.\n", + "Each recourse method in `ReLax` has an associated Config class that defines the set of supported configuration parameters. To configure a method, import and instantiate its Config class and pass it as the config parameter.\n", "\n", - "For example, to configure VanillaCF:\n", + "For example, to configure `VanillaCF`:\n", "\n", "```Python\n", "from relax.methods import VanillaCF \n", @@ -124,7 +123,7 @@ "vcf = VanillaCF(config)\n", "\n", "```\n", - "Each Config class inherits from a BaseConfig that defines common options like num_cfs. Method-specific parameters are defined on the individual Config classes.\n", + "Each Config class inherits from a BaseConfig that defines common options like n_steps. Method-specific parameters are defined on the individual Config classes.\n", "\n", "See the documentation for each recourse method for details on its supported configuration parameters. The Config class for a method can be imported from relax.methods.[method_name]." ] @@ -139,38 +138,20 @@ "from relax.methods import VanillaCF\n", "\n", "config = {\n", - " \"num_cfs\": 10, \n", - " \"epsilon\": 0.01,\n", + " \"n_steps\": 10, \n", + " \"lambda_\": 0.1,\n", " \"lr\": 0.1 \n", "}\n", "\n", "vcf = VanillaCF(config)\n", "```\n", "\n", - "This config dictionary is passed to VanillaCF's __init__ method, which will set the specified parameters. Now our VanillaCF instance is configured to:\n", + "This config dictionary is passed to VanillaCF's __init__ method, which will set the specified parameters. Now our `VanillaCF` instance is configured to:\n", "\n", - " * Generate 10 counterfactuals per input (num_cfs=10)\n", - " * Use a maximum perturbation of 0.01 (epsilon=0.01)\n", + " * Number 10 optimization steps (n_steps=100)\n", + " * Use 0.1 validity regularization for counterfactuals (lambda_=0.1)\n", " * Use a learning rate of 0.1 for optimization (lr=0.1)" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Modifying at Runtime:\n", - "\n", - "The configuration can also be updated after constructing the recourse method:\n", - "\n", - "```Python\n", - "vcf = VanillaCF() \n", - "\n", - "# Later, modify config\n", - "vcf.config[\"lr\"] = 0.5\n", - "\n", - "```\n", - "This allows dynamically adjusting the configuration as needed." - ] } ], "metadata": { From fa8a532b11d46d2e46b273cf61f3f9af38cde7b5 Mon Sep 17 00:00:00 2001 From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com> Date: Thu, 9 Nov 2023 10:35:56 -0500 Subject: [PATCH 4/4] Fix typos and update documentation for ReLax methods --- nbs/tutorials/methods.ipynb | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/nbs/tutorials/methods.ipynb b/nbs/tutorials/methods.ipynb index b5e821c..db2b0b5 100644 --- a/nbs/tutorials/methods.ipynb +++ b/nbs/tutorials/methods.ipynb @@ -27,16 +27,16 @@ "source": [ "## Types of Recourse Methods\n", "\n", - "1. Non-parametric methods: These methods do not rely on any learned parameters. They generate counterfactuals solely based on the model's predictions and gradients. Examples in ReLax include `VanillaCF`, `DiverseCF` and `GrowingSpheres` . These methods inherit from NonParametricCFModule.\n", + "1. Non-parametric methods: These methods do not rely on any learned parameters. They generate counterfactuals solely based on the model's predictions and gradients. Examples in ReLax include `VanillaCF`, `DiverseCF` and `GrowingSphere` . These methods inherit from `CFModule`.\n", "\n", - "2. Semi-parametric methods: These methods learn some parameters to aid in counterfactual generation, but do not learn a full counterfactual generation model. Examples in ReLax include `ProtoCF`, `CCHVAE` and `CLUE`. These methods inherit from SemiParametricCFModule.\n", + "2. Semi-parametric methods: These methods learn some parameters to aid in counterfactual generation, but do not learn a full counterfactual generation model. Examples in ReLax include `ProtoCF`, `CCHVAE` and `CLUE`. These methods inherit from `ParametricCFModule `.\n", "\n", - "3. Parametric methods: These methods learn a full parametric model for counterfactual generation. The model is trained to generate counterfactuals that fool the model. Examples in ReLax include `CounterNet` and `VAECF`. These methods inherit from ParametricCFModule.\n", + "3. Parametric methods: These methods learn a full parametric model for counterfactual generation. The model is trained to generate counterfactuals that fool the model. Examples in ReLax include `CounterNet` and `VAECF`. These methods inherit from `ParametricCFModule`.\n", "\n", "\n", "|Method Type | Learned Parameters | Training Required | Example Methods | \n", "|-----|:-----|:---:|:-----:|\n", - "|Non-parametric | None |No |`VanillaCF`, `DiverseCF`, `GrowingSpheres` |\n", + "|Non-parametric | None |No |`VanillaCF`, `DiverseCF`, `GrowingSphere` |\n", "|Semi-parametric| Some (θ) |Modest amount |`ProtoCF`, `CCHVAE`, `CLUE` |\n", "|Parametric|Full generator model (φ)|Substantial amount|`CounterNet`, `VAECF` |" ] @@ -45,7 +45,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Basis \n", + "## Basic Usages\n", "\n", "At a high level, you can use the implemented methods in `ReLax` to generate *one* recourse explanation via three lines of code:\n", "\n", @@ -73,8 +73,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Example of using `ReLax` for parametric methods (using `CCHVAE`)\n", - "\n" + "To use parametric and semi-parametric methods, you can first train the model\n", + "by calling `ParametricCF.train`, and then generate recourse explanations.\n", + "Here is an example of using `ReLax` for `CCHVAE`.\n" ] }, { @@ -85,7 +86,8 @@ "from relax.methods import CCHVAE\n", "\n", "cchvae = CCHVAE()\n", - "cchvae.train(train_data) # Train counterfactual generator \n", + "cchvae.train(train_data) # Train CVAE before generation\n", + "cf = cchvae.generate_cf(x, pred_fn=pred_fn) \n", "```\n", "\n", "Or generate a batch of recourse explanation via the `jax.vmap` primitive:\n", @@ -123,9 +125,9 @@ "vcf = VanillaCF(config)\n", "\n", "```\n", - "Each Config class inherits from a BaseConfig that defines common options like n_steps. Method-specific parameters are defined on the individual Config classes.\n", + "Each Config class inherits from a `BaseConfig` that defines common options like n_steps. Method-specific parameters are defined on the individual Config classes.\n", "\n", - "See the documentation for each recourse method for details on its supported configuration parameters. The Config class for a method can be imported from relax.methods.[method_name]." + "See the documentation for each recourse method for details on its supported configuration parameters. The Config class for a method can be imported from `relax.methods.[method_name]`." ] }, {