
Commit

Merge pull request #15 from BirkhoffG/dm-from-np
Add tutorials on methods; Implement DataModule.from_numpy
BirkhoffG authored Oct 12, 2023
2 parents f182040 + 746838f commit 7099840
Showing 12 changed files with 314 additions and 89 deletions.
53 changes: 43 additions & 10 deletions README.md
@@ -54,11 +54,24 @@
right version for GPU or TPU.
## Dive into `ReLax`

`ReLax` is a recourse explanation library for explaining (any) JAX-based
ML models. We believe it is important to give users the flexibility to
choose how to use `ReLax`. You can:

- use only the methods implemented in `ReLax` (as a recourse methods
  library);
- build a pipeline with `ReLax` to define the data module, train ML
  models, and generate CF explanations (for constructing a recourse
  benchmarking pipeline).

### `ReLax` as a Recourse Explanation Library

We introduce basic use cases of the methods in `ReLax` for generating
recourse explanations. For more advanced usage of the methods in
`ReLax`, see this [tutorial](tutorials/methods.ipynb).

``` python
from relax.methods import VanillaCF
from relax import DataModule, MLModule, generate_cf_explanations, benchmark_cfs
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import functools as ft
@@ -72,22 +85,22 @@
xs, ys = make_classification(n_samples=1000, n_features=10, random_state=42)
train_xs, test_xs, train_ys, test_ys = train_test_split(xs, ys, random_state=42)
```
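For reference (not shown in the diff), `train_test_split` holds out 25%
of the data by default, so the split shapes come out as below; a quick
check relying on scikit-learn's documented default:

``` python
# sklearn's default test_size is 0.25, so 1000 samples split 750/250.
assert train_xs.shape == (750, 10)
assert test_xs.shape == (250, 10)
```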

Next, we fit an MLP model for this data. Note that this model can be any
model implemented in JAX. We will use the
[`MLModule`](https://birkhoffg.github.io/jax-relax/ml_model.html#mlmodule)
in `ReLax` as an example.

``` python
model = MLModule()
model.train((train_xs, train_ys), epochs=10, batch_size=64)
```
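After training, the model exposes its forward pass as `model.pred_fn`,
which the recourse methods below consume. A minimal sanity check, under
the assumption that `pred_fn` accepts a batch of inputs and returns one
prediction row per sample (the README itself only shows per-instance
use):

``` python
# Assumption: pred_fn handles batched inputs, one prediction per row.
preds = model.pred_fn(test_xs)
assert preds.shape[0] == test_xs.shape[0]
```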

Generating a recourse explanation is straightforward: simply call
`generate_cf` of an implemented recourse method to generate *one*
recourse explanation:

``` python
vcf = VanillaCF(config={'n_steps': 1000, 'lr': 0.05})
cf = vcf.generate_cf(test_xs[0], model.pred_fn)
assert cf.shape == test_xs[0].shape
```
@@ -100,13 +113,33 @@

``` python
cfs = jax.vmap(generate_fn)(test_xs)
assert cfs.shape == test_xs.shape
```
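The hunk collapsed above hides how `generate_fn` is constructed. A
plausible sketch using the `functools` import from earlier; the keyword
name `pred_fn` is an assumption, since only the positional call
`vcf.generate_cf(test_xs[0], model.pred_fn)` appears above:

``` python
import functools as ft
import jax

# Bind the model's prediction function so generate_fn maps a single
# input to its counterfactual, then vectorize it over the batch axis.
generate_fn = ft.partial(vcf.generate_cf, pred_fn=model.pred_fn)
cfs = jax.vmap(generate_fn)(test_xs)
```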

### `ReLax` for Building Recourse Explanation Pipelines

The example above illustrates how to use the decoupled `relax.methods`
to generate recourse explanations. However, this requires users to write
boilerplate code for tasks such as data preprocessing, model training,
and generating recourse explanations under feature constraints.

`ReLax` additionally offers a one-liner framework that streamlines this
process and helps users build a standardized pipeline for generating
recourse explanations. You can benchmark recourse explanations with
three lines of code:

``` python
data_module = DataModule.from_numpy(xs, ys)
exps = generate_cf_explanations(vcf, data_module, model.pred_fn)
benchmark_cfs([exps])
```
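`DataModule.from_numpy` is the constructor added by this commit (its
signature appears in the `nbs/01_data.ipynb` diff below): it treats
every column of `xs` as a continuous feature and applies per-feature
min-max scaling by default. A small sketch of the optional arguments;
the name string here is purely illustrative:

``` python
# `name` labels the DataModule; `transformation` defaults to 'minmax',
# matching the signature shown in the notebook diff below.
data_module = DataModule.from_numpy(xs, ys, name="synthetic")
```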

See [Getting Started with
ReLax](https://birkhoffg.github.io/jax-relax/tutorials/getting_started.html)
for an end-to-end example of using `ReLax`.

## Supported Recourse Methods

`ReLax` currently provides implementations of 8 recourse explanation
methods.

| Method | Type | Paper Title | Ref |
|--------------------------------------------------------------------------------------------|-----------------|------------------------------------------------------------------------------------------------|-------------------------------------------|
| [`VanillaCF`](https://birkhoffg.github.io/jax-relax/methods/vanilla.html#vanillacf) | Non-Parametric | Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR. | [\[1\]](https://arxiv.org/abs/1711.00399) |
@@ -126,8 +159,8 @@

To cite this repository:
@software{relax2023github,
author = {Hangzhi Guo and Xinchang Xiong and Amulya Yadav},
title = {{R}e{L}ax: Recourse Explanation Library in Jax},
  url = {http://github.com/birkhoffg/jax-relax},
  version = {0.2.0},
year = {2023},
}
```
46 changes: 36 additions & 10 deletions nbs/01_data.ipynb
@@ -45,6 +45,13 @@
"text": [
"Using JAX backend.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
"source": [
@@ -74,7 +81,8 @@
"source": [
"#| hide\n",
"from fastcore.test import *\n",
"from copy import deepcopy"
"from copy import deepcopy\n",
"from sklearn.datasets import make_classification"
]
},
{
@@ -203,15 +211,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"outputs": [],
"source": [
"#| hide\n",
"config = DataModuleConfig(data_name=\"TabularDataModule\", data_dir=\"data\", continous_cols=[], discret_cols=[], imutable_cols=[])\n",
@@ -502,6 +502,20 @@
" return cls(features=features, label=label, config=config, data=data)\n",
" \n",
" @classmethod\n",
" def from_numpy(\n",
" cls,\n",
" xs: np.ndarray, # Input data\n",
" ys: np.ndarray, # Labels\n",
" name: str = None, # Name of `DataModule`\n",
" transformation='minmax'\n",
" ) -> DataModule: # Initialized `DataModule` from numpy arrays\n",
" \"\"\"Create `DataModule` from numpy arrays. Note that the `xs` are treated as continuous features.\"\"\"\n",
" \n",
" features = FeaturesList([Feature(f\"feature_{i}\", xs[:, i].reshape(-1, 1), transformation=transformation) for i in range(xs.shape[1])])\n",
" labels = FeaturesList([Feature(f\"label\", ys.reshape(-1, 1), transformation='identity')])\n",
" return cls(features=features, label=labels, name=name)\n",
" \n",
" @classmethod\n",
" def from_features(\n",
" cls, \n",
" features: FeaturesList, # Features of `DataModule`\n",
@@ -613,6 +627,18 @@
"# chex.assert_trees_all_equal(dm['train'], dm_1['train'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"xs, ys = make_classification(1000, n_features=10)\n",
"dm = DataModule.from_numpy(xs, ys)\n",
"assert np.array_equal(dm.data.iloc[:, :10].to_numpy(), xs)\n",
"assert np.array_equal(dm.data.iloc[:, -1].to_numpy(), ys)"
]
},