diff --git a/nbs/tutorials/methods.ipynb b/nbs/tutorials/methods.ipynb
index a59149c..db2b0b5 100644
--- a/nbs/tutorials/methods.ipynb
+++ b/nbs/tutorials/methods.ipynb
@@ -11,10 +11,43 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "`ReLax` is a recourse explanation library which provides implementations of various recourse methods.\n",
-    "In other words, you can use implemented methods in `ReLax` without relying on the entire pipeline of `ReLax`.\n",
+    "`ReLax` contains implementations of various recourse methods, which are decoupled from the rest of the `ReLax` library.\n",
+    "We give users flexibility in how to use `ReLax`:\n",
     "\n",
-    "At a high level, you can use the implemented methods in `ReLax` to generate a recourse explanation via three lines of code:\n",
+    "* You can use the recourse pipeline in `ReLax` (a \"one-liner\" for easily benchmarking recourse methods; see this [tutorial](getting_started.ipynb)).\n",
+    "* You can use any recourse method in `ReLax` on its own, without relying on the entire `ReLax` pipeline.\n",
+    "\n",
+    "In this tutorial, we explore the second option: using the recourse methods under `relax.methods`\n",
+    "for debugging, diagnosing, and interpreting your JAX models.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Types of Recourse Methods\n",
+    "\n",
+    "1. Non-parametric methods: These methods do not rely on any learned parameters. They generate counterfactuals solely based on the model's predictions and gradients. Examples in `ReLax` include `VanillaCF`, `DiverseCF`, and `GrowingSphere`. These methods inherit from `CFModule`.\n",
+    "\n",
+    "2. Semi-parametric methods: These methods learn some parameters to aid counterfactual generation, but do not learn a full counterfactual generation model. Examples in `ReLax` include `ProtoCF`, `CCHVAE`, and `CLUE`. These methods inherit from `ParametricCFModule`.\n",
+    "\n",
+    "3. Parametric methods: These methods learn a full parametric model for counterfactual generation. This generator is trained to produce counterfactuals that change the underlying model's prediction. Examples in `ReLax` include `CounterNet` and `VAECF`. These methods inherit from `ParametricCFModule`.\n",
+    "\n",
+    "\n",
+    "|Method Type | Learned Parameters | Training Required | Example Methods |\n",
+    "|-----|:-----|:---:|:-----:|\n",
+    "|Non-parametric | None |No |`VanillaCF`, `DiverseCF`, `GrowingSphere` |\n",
+    "|Semi-parametric| Some (θ) |Modest amount |`ProtoCF`, `CCHVAE`, `CLUE` |\n",
+    "|Parametric|Full generator model (φ)|Substantial amount|`CounterNet`, `VAECF` |"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Basic Usage\n",
+    "\n",
+    "At a high level, you can use the implemented methods in `ReLax` to generate *one* recourse explanation via three lines of code:\n",
     "\n",
     "```python\n",
     "from relax.methods import VanillaCF\n",
@@ -30,11 +63,97 @@
     "...\n",
     "import functools as ft\n",
     "\n",
-    "generate_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)\n",
+    "vcf_gen_fn = ft.partial(vcf.generate_cf, pred_fn=pred_fn)\n",
     "# xs is a batched data. Shape: `(N, K)`\n",
-    "cfs = jax.vmap(generate_fn)(xs)\n",
+    "cfs = jax.vmap(vcf_gen_fn)(xs)\n",
+    "```\n"
+   ]
+  },
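+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The snippets above assume you already have a `pred_fn`, i.e. the prediction function of the JAX model being explained. The sketch below is only an illustration of what such a function might look like: the toy logistic model, its parameters `w` and `b`, and the two-column probability output are assumptions, not part of `ReLax`. Substitute your own model's prediction function, and check each method's documentation for the exact signature it expects.\n",
+    "\n",
+    "```python\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "\n",
+    "# Hypothetical parameters of a toy logistic-regression model;\n",
+    "# replace these with your trained model's parameters.\n",
+    "key = jax.random.PRNGKey(0)\n",
+    "w = jax.random.normal(key, (xs.shape[-1],))  # xs: batched data of shape (N, K)\n",
+    "b = 0.0\n",
+    "\n",
+    "def pred_fn(x):\n",
+    "    # x: a single instance (K,) or a batch (N, K); returns class probabilities.\n",
+    "    prob = jax.nn.sigmoid(x @ w + b)\n",
+    "    return jnp.stack([1 - prob, prob], axis=-1)\n",
+    "```\n"
+   ]
+  },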
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To use parametric and semi-parametric methods, you can first train the model\n",
+    "by calling `ParametricCFModule.train`, and then generate recourse explanations.\n",
+    "Here is an example of using `CCHVAE` in `ReLax`.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "```python\n",
+    "from relax.methods import CCHVAE\n",
+    "\n",
+    "cchvae = CCHVAE()\n",
+    "cchvae.train(train_data)  # Train the CVAE before generation\n",
+    "cf = cchvae.generate_cf(x, pred_fn=pred_fn)\n",
+    "```\n",
+    "\n",
+    "Or generate a batch of recourse explanations via the `jax.vmap` primitive:\n",
+    "\n",
+    "```python\n",
+    "...\n",
+    "import functools as ft\n",
+    "\n",
+    "cchvae_gen_fn = ft.partial(cchvae.generate_cf, pred_fn=pred_fn)\n",
+    "cfs = jax.vmap(cchvae_gen_fn)(xs)  # Generate a batch of counterfactuals\n",
+    "\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Configuring Recourse Methods\n",
+    "\n",
+    "Each recourse method in `ReLax` has an associated Config class that defines the set of supported configuration parameters. To configure a method, import and instantiate its Config class and pass the instance as the `config` parameter.\n",
+    "\n",
+    "For example, to configure `VanillaCF`:\n",
+    "\n",
+    "```python\n",
+    "from relax.methods import VanillaCF\n",
+    "from relax.methods.vanilla import VanillaCFConfig\n",
+    "\n",
+    "config = VanillaCFConfig(\n",
+    "    n_steps=100,\n",
+    "    lr=0.1,\n",
+    "    lambda_=0.1\n",
+    ")\n",
+    "\n",
+    "vcf = VanillaCF(config)\n",
+    "```\n",
+    "\n",
+    "Each Config class inherits from a `BaseConfig` that defines common options like `n_steps`. Method-specific parameters are defined on the individual Config classes.\n",
+    "\n",
+    "See the documentation for each recourse method for details on its supported configuration parameters. The Config class for a method can be imported from `relax.methods.[method_name]`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Alternatively, you can specify this config via a dictionary.\n",
+    "\n",
+    "```python\n",
+    "from relax.methods import VanillaCF\n",
+    "\n",
+    "config = {\n",
+    "    \"n_steps\": 10,\n",
+    "    \"lambda_\": 0.1,\n",
+    "    \"lr\": 0.1\n",
+    "}\n",
+    "\n",
+    "vcf = VanillaCF(config)\n",
+    "```\n",
+    "\n",
+    "This config dictionary is passed to `VanillaCF`'s `__init__` method, which sets the specified parameters. Now our `VanillaCF` instance is configured to:\n",
+    "\n",
+    " * Run 10 optimization steps (`n_steps=10`)\n",
+    " * Use a validity regularization weight of 0.1 (`lambda_=0.1`)\n",
+    " * Use a learning rate of 0.1 for optimization (`lr=0.1`)"
    ]
   }
  ],
  "metadata": {
@@ -45,5 +164,5 @@
   }
  },
 "nbformat": 4,
-"nbformat_minor": 2
+"nbformat_minor": 4
 }