Skip to content

Commit

Permalink
Merge pull request #40 from BirkhoffG/gumbel_softmax
Browse files Browse the repository at this point in the history
Support gumbel softmax transformation
  • Loading branch information
BirkhoffG authored Feb 20, 2024
2 parents 6171081 + 0ffc980 commit aaa9346
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 81 deletions.
27 changes: 27 additions & 0 deletions nbs/00_utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,33 @@
" return upt_params, opt_state"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Functional Utils"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def gumbel_softmax(\n",
" key: jrand.PRNGKey, # Random key\n",
" logits: Array, # Logits for each class. Shape (batch_size, num_classes)\n",
" tau: float, # Temperature for the Gumbel softmax\n",
" axis: int | tuple[int, ...] = -1, # The axis or axes along which the gumbel softmax should be computed\n",
"):\n",
" \"\"\"Gumbel softmax: add Gumbel noise to `logits`, divide by the temperature `tau`, and apply softmax along `axis`.\"\"\"\n",
"\n",
" gumbel_noise = jrand.gumbel(key, shape=logits.shape)\n",
" y = logits + gumbel_noise\n",
" return jax.nn.softmax(y / tau, axis=axis)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
153 changes: 107 additions & 46 deletions nbs/01_data.utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,22 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
"outputs": [],
"source": [
"#| hide\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"from ipynb_path import *\n",
"import warnings\n",
"\n",
"warnings.simplefilter(action='ignore', category=FutureWarning)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"from __future__ import annotations\n",
Expand All @@ -40,8 +47,10 @@
"import einops\n",
"import os, sys, json, pickle\n",
"import shutil\n",
"from relax.utils import *\n",
"import chex"
"from relax.utils import gumbel_softmax, load_pytree, save_pytree, get_config\n",
"import chex\n",
"import functools as ft\n",
"import warnings"
]
},
{
Expand Down Expand Up @@ -467,8 +476,8 @@
"transformed_xs = scaler.fit_transform(xs)\n",
"assert scaler.is_categorical is False\n",
"\n",
"cfs = np.random.randn(100, 1)\n",
"cf_constrained = scaler.apply_constraints(xs, cfs)\n",
"x = np.random.randn(100, 1)\n",
"cf_constrained = scaler.apply_constraints(xs, x)\n",
"assert np.all(cf_constrained >= 0) and np.all(cf_constrained <= 1)\n",
"\n",
"# Test from_dict and to_dict\n",
Expand All @@ -483,20 +492,27 @@
"outputs": [],
"source": [
"#| export\n",
"class OneHotTransformation(Transformation):\n",
"class _OneHotTransformation(Transformation):\n",
" def __init__(self):\n",
" super().__init__(\"ohe\", OneHotEncoder())\n",
"\n",
" @property\n",
" def num_categories(self) -> int:\n",
" return len(self.transformer.categories_)\n",
" \n",
" def hard_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]): \n",
" x, rng_key, kwargs = operand\n",
" return jax.nn.one_hot(jnp.argmax(x, axis=-1), self.num_categories)\n",
" \n",
" def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n",
" raise NotImplementedError\n",
"\n",
" def apply_constraints(self, xs, cfs, hard: bool = False, **kwargs):\n",
" def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):\n",
" return jax.lax.cond(\n",
" hard,\n",
" true_fun=lambda x: jax.nn.one_hot(jnp.argmax(x, axis=-1), self.num_categories),\n",
" false_fun=lambda x: jax.nn.softmax(x, axis=-1),\n",
" operand=cfs,\n",
" true_fun=self.hard_constraints,\n",
" false_fun=self.soft_constraints,\n",
" operand=(cfs, rng_key, kwargs),\n",
" )\n",
" \n",
" def compute_reg_loss(self, xs, cfs, hard: bool = False):\n",
Expand All @@ -510,31 +526,74 @@
"metadata": {},
"outputs": [],
"source": [
"xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))\n",
"ohe_t = OneHotTransformation().fit(xs)\n",
"transformed_xs = ohe_t.transform(xs)\n",
"#| export\n",
"class SoftmaxTransformation(_OneHotTransformation):\n",
" def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n",
" x, rng_key, kwargs = operand\n",
" return jax.nn.softmax(x, axis=-1)\n",
" \n",
"class GumbelSoftmaxTransformation(_OneHotTransformation):\n",
" \"\"\"Apply the Gumbel-softmax trick for categorical transformation.\"\"\"\n",
"\n",
"cfs = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 3))\n",
"# Test hard=True which applies softmax function.\n",
"soft = ohe_t.apply_constraints(transformed_xs, cfs, hard=False)\n",
"assert jnp.allclose(soft.sum(axis=-1), 1)\n",
"assert jnp.all(soft >= 0)\n",
"assert jnp.all(soft <= 1)\n",
"assert jnp.allclose(jnp.zeros((len(cfs), 1)), ohe_t.compute_reg_loss(xs, soft, hard=False))\n",
" def __init__(self, tau: float = 1.):\n",
" super().__init__()\n",
" self.tau = tau\n",
" \n",
" def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n",
" x, rng_key, _ = operand\n",
" if rng_key is None: # Fall back to the same key (built from the configured global seed) on every call\n",
" rng_key = jax.random.PRNGKey(get_config().global_seed)\n",
" return gumbel_softmax(rng_key, x, self.tau)\n",
" \n",
" def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):\n",
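" \"\"\"Apply constraints to the counterfactuals. If `rng_key` is None, the soft (Gumbel-softmax) branch falls back to a key built from the configured global seed, so the Gumbel noise is reproducible rather than absent.\"\"\"\n",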
" return super().apply_constraints(xs, cfs, hard, rng_key, **kwargs)\n",
" \n",
"def OneHotTransformation():\n",
" warnings.warn(\"OneHotTransformation is deprecated since v0.2.5. \"\n",
" \"Use `SoftmaxTransformation` (same functionality) \"\n",
" \"or GumbelSoftmaxTransformation instead.\", DeprecationWarning)\n",
" return SoftmaxTransformation()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def test_ohe_t(ohe_cls):\n",
" xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))\n",
" ohe_t = ohe_cls().fit(xs)\n",
" transformed_xs = ohe_t.transform(xs)\n",
" rng_key = jax.random.PRNGKey(get_config().global_seed)\n",
"\n",
"# Test hard=True which enforce one-hot constraint.\n",
"hard = ohe_t.apply_constraints(transformed_xs, cfs, hard=True)\n",
"assert np.all([1 in x for x in hard])\n",
"assert np.all([0 in x for x in hard])\n",
"assert jnp.allclose(hard.sum(axis=-1), 1)\n",
"assert jnp.allclose(jnp.zeros((len(cfs), 1)), ohe_t.compute_reg_loss(xs, hard, hard=False))\n",
" x = jax.random.uniform(rng_key, shape=(100, 3))\n",
" # Test hard=False which applies the soft (softmax-style) constraint.\n",
" soft = ohe_t.apply_constraints(transformed_xs, x, hard=False, rng_key=rng_key)\n",
" assert jnp.allclose(soft.sum(axis=-1), 1)\n",
" assert jnp.all(soft >= 0)\n",
" assert jnp.all(soft <= 1)\n",
" assert jnp.allclose(jnp.zeros((len(x), 1)), ohe_t.compute_reg_loss(xs, soft, hard=False))\n",
" assert jnp.allclose(soft, ohe_t.apply_constraints(transformed_xs, x, hard=False))\n",
"\n",
"# Test compute_reg_loss\n",
"assert jnp.ndim(ohe_t.compute_reg_loss(xs, soft, hard=False)) == 0\n",
" # Test hard=True which enforces the one-hot constraint.\n",
" hard = ohe_t.apply_constraints(transformed_xs, x, hard=True, rng_key=rng_key)\n",
" assert np.all([1 in x for x in hard])\n",
" assert np.all([0 in x for x in hard])\n",
" assert jnp.allclose(hard.sum(axis=-1), 1)\n",
" assert jnp.allclose(jnp.zeros((len(x), 1)), ohe_t.compute_reg_loss(xs, hard, hard=False))\n",
"\n",
"# Test from_dict and to_dict\n",
"ohe_t_1 = OneHotTransformation().from_dict(ohe_t.to_dict())\n",
"assert np.allclose(ohe_t.transform(xs), ohe_t_1.transform(xs))"
" # Test compute_reg_loss\n",
" assert jnp.ndim(ohe_t.compute_reg_loss(xs, soft, hard=False)) == 0\n",
"\n",
" # Test from_dict and to_dict\n",
" ohe_t_1 = ohe_cls().from_dict(ohe_t.to_dict())\n",
" assert np.allclose(ohe_t.transform(xs), ohe_t_1.transform(xs))\n",
"\n",
"\n",
"test_ohe_t(SoftmaxTransformation)\n",
"test_ohe_t(GumbelSoftmaxTransformation)"
]
},
{
Expand Down Expand Up @@ -616,7 +675,9 @@
"source": [
"#| export\n",
"PREPROCESSING_TRANSFORMATIONS = {\n",
" 'ohe': OneHotTransformation,\n",
" 'ohe': SoftmaxTransformation,\n",
" 'softmax': SoftmaxTransformation,\n",
" 'gumbel': GumbelSoftmaxTransformation,\n",
" 'minmax': MinMaxTransformation,\n",
" 'ordinal': OrdinalTransformation,\n",
" 'identity': IdentityTransformation,\n",
Expand Down Expand Up @@ -1157,8 +1218,8 @@
" assert feat.is_categorical is False\n",
" assert feat.is_immutable is False\n",
"\n",
"cfs = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 29))\n",
"feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], cfs, hard=True)"
"x = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 29))\n",
"feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=True)"
]
},
{
Expand Down Expand Up @@ -1189,17 +1250,17 @@
"outputs": [],
"source": [
"# Test apply_constraints and compute_reg_loss\n",
"cfs = np.random.randn(10, 29)\n",
"constraint_cfs = feats_list.apply_constraints(feats_list.transformed_data[:10, :], cfs, hard=False)\n",
"x = np.random.randn(10, 29)\n",
"constraint_cfs = feats_list.apply_constraints(feats_list.transformed_data[:10, :], x, hard=False)\n",
"assert constraint_cfs.shape == (10, 29)\n",
"assert np.allclose(\n",
" constraint_cfs[:, 2:].sum(axis=-1),\n",
" np.ones((10,)) * 6\n",
")\n",
"assert constraint_cfs[: :2].min() >= 0 and constraint_cfs[: :2].max() <= 1\n",
"assert feats_list.apply_constraints(feats_list.transformed_data[:10, :], cfs, hard=True).shape == (10, 29)\n",
"assert feats_list.apply_constraints(feats_list.transformed_data[:10, :], x, hard=True).shape == (10, 29)\n",
"\n",
"reg_loss = feats_list.compute_reg_loss(feats_list.transformed_data, cfs)\n",
"reg_loss = feats_list.compute_reg_loss(feats_list.transformed_data, x)\n",
"assert jnp.ndim(reg_loss) == 0\n",
"assert np.all(reg_loss > 0)\n",
"assert np.allclose(feats_list.compute_reg_loss(xs, constraint_cfs), 0)"
Expand Down
35 changes: 27 additions & 8 deletions relax/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,14 @@
'relax/data_utils.py'),
'relax.data_utils.FeaturesList.with_transformed_data': ( 'data.utils.html#featureslist.with_transformed_data',
'relax/data_utils.py'),
'relax.data_utils.GumbelSoftmaxTransformation': ( 'data.utils.html#gumbelsoftmaxtransformation',
'relax/data_utils.py'),
'relax.data_utils.GumbelSoftmaxTransformation.__init__': ( 'data.utils.html#gumbelsoftmaxtransformation.__init__',
'relax/data_utils.py'),
'relax.data_utils.GumbelSoftmaxTransformation.apply_constraints': ( 'data.utils.html#gumbelsoftmaxtransformation.apply_constraints',
'relax/data_utils.py'),
'relax.data_utils.GumbelSoftmaxTransformation.soft_constraints': ( 'data.utils.html#gumbelsoftmaxtransformation.soft_constraints',
'relax/data_utils.py'),
'relax.data_utils.IdentityTransformation': ( 'data.utils.html#identitytransformation',
'relax/data_utils.py'),
'relax.data_utils.IdentityTransformation.__init__': ( 'data.utils.html#identitytransformation.__init__',
Expand Down Expand Up @@ -221,14 +229,6 @@
'relax.data_utils.OneHotEncoder.transform': ( 'data.utils.html#onehotencoder.transform',
'relax/data_utils.py'),
'relax.data_utils.OneHotTransformation': ('data.utils.html#onehottransformation', 'relax/data_utils.py'),
'relax.data_utils.OneHotTransformation.__init__': ( 'data.utils.html#onehottransformation.__init__',
'relax/data_utils.py'),
'relax.data_utils.OneHotTransformation.apply_constraints': ( 'data.utils.html#onehottransformation.apply_constraints',
'relax/data_utils.py'),
'relax.data_utils.OneHotTransformation.compute_reg_loss': ( 'data.utils.html#onehottransformation.compute_reg_loss',
'relax/data_utils.py'),
'relax.data_utils.OneHotTransformation.num_categories': ( 'data.utils.html#onehottransformation.num_categories',
'relax/data_utils.py'),
'relax.data_utils.OrdinalPreprocessor': ('data.utils.html#ordinalpreprocessor', 'relax/data_utils.py'),
'relax.data_utils.OrdinalPreprocessor.fit': ( 'data.utils.html#ordinalpreprocessor.fit',
'relax/data_utils.py'),
Expand All @@ -242,6 +242,10 @@
'relax/data_utils.py'),
'relax.data_utils.OrdinalTransformation.num_categories': ( 'data.utils.html#ordinaltransformation.num_categories',
'relax/data_utils.py'),
'relax.data_utils.SoftmaxTransformation': ( 'data.utils.html#softmaxtransformation',
'relax/data_utils.py'),
'relax.data_utils.SoftmaxTransformation.soft_constraints': ( 'data.utils.html#softmaxtransformation.soft_constraints',
'relax/data_utils.py'),
'relax.data_utils.Transformation': ('data.utils.html#transformation', 'relax/data_utils.py'),
'relax.data_utils.Transformation.__init__': ( 'data.utils.html#transformation.__init__',
'relax/data_utils.py'),
Expand All @@ -262,6 +266,20 @@
'relax/data_utils.py'),
'relax.data_utils.Transformation.transform': ( 'data.utils.html#transformation.transform',
'relax/data_utils.py'),
'relax.data_utils._OneHotTransformation': ( 'data.utils.html#_onehottransformation',
'relax/data_utils.py'),
'relax.data_utils._OneHotTransformation.__init__': ( 'data.utils.html#_onehottransformation.__init__',
'relax/data_utils.py'),
'relax.data_utils._OneHotTransformation.apply_constraints': ( 'data.utils.html#_onehottransformation.apply_constraints',
'relax/data_utils.py'),
'relax.data_utils._OneHotTransformation.compute_reg_loss': ( 'data.utils.html#_onehottransformation.compute_reg_loss',
'relax/data_utils.py'),
'relax.data_utils._OneHotTransformation.hard_constraints': ( 'data.utils.html#_onehottransformation.hard_constraints',
'relax/data_utils.py'),
'relax.data_utils._OneHotTransformation.num_categories': ( 'data.utils.html#_onehottransformation.num_categories',
'relax/data_utils.py'),
'relax.data_utils._OneHotTransformation.soft_constraints': ( 'data.utils.html#_onehottransformation.soft_constraints',
'relax/data_utils.py'),
'relax.data_utils._check_xs': ('data.utils.html#_check_xs', 'relax/data_utils.py'),
'relax.data_utils._unique': ('data.utils.html#_unique', 'relax/data_utils.py')},
'relax.docs': { 'relax.docs.CustomizedMarkdownRenderer': ('docs.html#customizedmarkdownrenderer', 'relax/docs.py'),
Expand Down Expand Up @@ -797,6 +815,7 @@
'relax.utils.auto_reshaping': ('utils.html#auto_reshaping', 'relax/utils.py'),
'relax.utils.get_config': ('utils.html#get_config', 'relax/utils.py'),
'relax.utils.grad_update': ('utils.html#grad_update', 'relax/utils.py'),
'relax.utils.gumbel_softmax': ('utils.html#gumbel_softmax', 'relax/utils.py'),
'relax.utils.load_json': ('utils.html#load_json', 'relax/utils.py'),
'relax.utils.load_pytree': ('utils.html#load_pytree', 'relax/utils.py'),
'relax.utils.save_pytree': ('utils.html#save_pytree', 'relax/utils.py'),
Expand Down
Loading

0 comments on commit aaa9346

Please sign in to comment.