diff --git a/nbs/00_utils.ipynb b/nbs/00_utils.ipynb index 1b0614e..c3ea2ac 100644 --- a/nbs/00_utils.ipynb +++ b/nbs/00_utils.ipynb @@ -493,6 +493,33 @@ " return upt_params, opt_state" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Functional Utils" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def gumbel_softmax(\n", + " key: jrand.PRNGKey, # Random key\n", + " logits: Array, # Logits for each class. Shape (batch_size, num_classes)\n", + " tau: float, # Temperature for the Gumbel softmax\n", + " axis: int | tuple[int, ...] = -1, # The axis or axes along which the gumbel softmax should be computed\n", + "):\n", + " \"\"\"The Gumbel softmax function.\"\"\"\n", + "\n", + " gumbel_noise = jrand.gumbel(key, shape=logits.shape)\n", + " y = logits + gumbel_noise\n", + " return jax.nn.softmax(y / tau, axis=axis)" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/nbs/01_data.utils.ipynb b/nbs/01_data.utils.ipynb index 31264d2..20e330e 100644 --- a/nbs/01_data.utils.ipynb +++ b/nbs/01_data.utils.ipynb @@ -20,15 +20,22 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" - ] - } - ], + "outputs": [], + "source": [ + "#| hide\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "from ipynb_path import *\n", + "import warnings\n", + "\n", + "warnings.simplefilter(action='ignore', category=FutureWarning)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "#| export\n", "from __future__ import annotations\n", @@ -40,8 +47,10 @@ "import einops\n", "import os, sys, json, pickle\n", "import shutil\n", - "from relax.utils import *\n", - "import chex" + "from relax.utils import gumbel_softmax, load_pytree, save_pytree, get_config\n", + "import chex\n", + "import functools as ft\n", + "import warnings" ] }, { @@ -467,8 +476,8 @@ "transformed_xs = scaler.fit_transform(xs)\n", "assert scaler.is_categorical is False\n", "\n", - "cfs = np.random.randn(100, 1)\n", - "cf_constrained = scaler.apply_constraints(xs, cfs)\n", + "x = np.random.randn(100, 1)\n", + "cf_constrained = scaler.apply_constraints(xs, x)\n", "assert np.all(cf_constrained >= 0) and np.all(cf_constrained <= 1)\n", "\n", "# Test from_dict and to_dict\n", @@ -483,20 +492,27 @@ "outputs": [], "source": [ "#| export\n", - "class OneHotTransformation(Transformation):\n", + "class _OneHotTransformation(Transformation):\n", " def __init__(self):\n", " super().__init__(\"ohe\", OneHotEncoder())\n", "\n", " @property\n", " def num_categories(self) -> int:\n", " return len(self.transformer.categories_)\n", + " \n", + " def hard_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]): \n", + " x, rng_key, kwargs = operand\n", + " return jax.nn.one_hot(jnp.argmax(x, axis=-1), self.num_categories)\n", + " \n", + " def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n", + " raise NotImplementedError\n", "\n", - " def apply_constraints(self, xs, cfs, hard: bool = False, **kwargs):\n", + " def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):\n", " return jax.lax.cond(\n", " hard,\n", - " true_fun=lambda x: jax.nn.one_hot(jnp.argmax(x, axis=-1), self.num_categories),\n", - " false_fun=lambda x: 
jax.nn.softmax(x, axis=-1),\n", - " operand=cfs,\n", + " true_fun=self.hard_constraints,\n", + " false_fun=self.soft_constraints,\n", + " operand=(cfs, rng_key, kwargs),\n", " )\n", " \n", " def compute_reg_loss(self, xs, cfs, hard: bool = False):\n", @@ -510,31 +526,74 @@ "metadata": {}, "outputs": [], "source": [ - "xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))\n", - "ohe_t = OneHotTransformation().fit(xs)\n", - "transformed_xs = ohe_t.transform(xs)\n", + "#| export\n", + "class SoftmaxTransformation(_OneHotTransformation):\n", + " def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n", + " x, rng_key, kwargs = operand\n", + " return jax.nn.softmax(x, axis=-1)\n", + " \n", + "class GumbelSoftmaxTransformation(_OneHotTransformation):\n", + " \"\"\"Apply the Gumbel-softmax trick for categorical transformation.\"\"\"\n", "\n", - "cfs = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 3))\n", - "# Test hard=True which applies softmax function.\n", - "soft = ohe_t.apply_constraints(transformed_xs, cfs, hard=False)\n", - "assert jnp.allclose(soft.sum(axis=-1), 1)\n", - "assert jnp.all(soft >= 0)\n", - "assert jnp.all(soft <= 1)\n", - "assert jnp.allclose(jnp.zeros((len(cfs), 1)), ohe_t.compute_reg_loss(xs, soft, hard=False))\n", + " def __init__(self, tau: float = 1.):\n", + " super().__init__()\n", + " self.tau = tau\n", + " \n", + " def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n", + " x, rng_key, _ = operand\n", + " if rng_key is None: # No key: use the fixed global seed (deterministic noise)\n", + " rng_key = jax.random.PRNGKey(get_config().global_seed)\n", + " return gumbel_softmax(rng_key, x, self.tau)\n", + " \n", + " def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):\n", + " \"\"\"Apply constraints to the counterfactuals. If `rng_key` is None, a fixed global seed is used, so the Gumbel noise is deterministic.\"\"\"\n", + " return super().apply_constraints(xs, cfs, hard, rng_key, **kwargs)\n", + " \n", + "def OneHotTransformation():\n", + " warnings.warn(\"OneHotTransformation is deprecated since v0.2.5. 
"\n", + " \"Use `SoftmaxTransformation` (same functionality) \"\n", + " \"or GumbelSoftmaxTransformation instead.\", DeprecationWarning)\n", + " return SoftmaxTransformation()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def test_ohe_t(ohe_cls):\n", + " xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))\n", + " ohe_t = ohe_cls().fit(xs)\n", + " transformed_xs = ohe_t.transform(xs)\n", + " rng_key = jax.random.PRNGKey(get_config().global_seed)\n", "\n", - "# Test hard=True which enforce one-hot constraint.\n", - "hard = ohe_t.apply_constraints(transformed_xs, cfs, hard=True)\n", - "assert np.all([1 in x for x in hard])\n", - "assert np.all([0 in x for x in hard])\n", - "assert jnp.allclose(hard.sum(axis=-1), 1)\n", - "assert jnp.allclose(jnp.zeros((len(cfs), 1)), ohe_t.compute_reg_loss(xs, hard, hard=False))\n", + " x = jax.random.uniform(rng_key, shape=(100, 3))\n", + " # Test hard=False which applies the softmax function.\n", + " soft = ohe_t.apply_constraints(transformed_xs, x, hard=False, rng_key=rng_key)\n", + " assert jnp.allclose(soft.sum(axis=-1), 1)\n", + " assert jnp.all(soft >= 0)\n", + " assert jnp.all(soft <= 1)\n", + " assert jnp.allclose(jnp.zeros((len(x), 1)), ohe_t.compute_reg_loss(xs, soft, hard=False))\n", + " assert jnp.allclose(soft, ohe_t.apply_constraints(transformed_xs, x, hard=False))\n", "\n", - "# Test compute_reg_loss\n", - "assert jnp.ndim(ohe_t.compute_reg_loss(xs, soft, hard=False)) == 0\n", + " # Test hard=True which enforces the one-hot constraint.\n", + " hard = ohe_t.apply_constraints(transformed_xs, x, hard=True, rng_key=rng_key)\n", + " assert np.all([1 in row for row in hard])\n", + " assert np.all([0 in row for row in hard])\n", + " assert jnp.allclose(hard.sum(axis=-1), 1)\n", + " assert jnp.allclose(jnp.zeros((len(x), 1)), ohe_t.compute_reg_loss(xs, hard, hard=False))\n", "\n", - "# Test from_dict and to_dict\n", - "ohe_t_1 = OneHotTransformation().from_dict(ohe_t.to_dict())\n", - "assert np.allclose(ohe_t.transform(xs), ohe_t_1.transform(xs))" + " # Test compute_reg_loss\n", + " assert jnp.ndim(ohe_t.compute_reg_loss(xs, soft, hard=False)) == 0\n", + "\n", + " # Test from_dict and to_dict\n", + " ohe_t_1 = ohe_cls().from_dict(ohe_t.to_dict())\n", + " assert np.allclose(ohe_t.transform(xs), ohe_t_1.transform(xs))\n", + "\n", + "\n", + "test_ohe_t(SoftmaxTransformation)\n", + "test_ohe_t(GumbelSoftmaxTransformation)" ] }, { @@ -616,7 +675,9 @@ "source": [ "#| export\n", "PREPROCESSING_TRANSFORMATIONS = {\n", - " 'ohe': OneHotTransformation,\n", + " 'ohe': SoftmaxTransformation,\n", + " 'softmax': SoftmaxTransformation,\n", + " 'gumbel': GumbelSoftmaxTransformation,\n", " 'minmax': MinMaxTransformation,\n", " 'ordinal': OrdinalTransformation,\n", " 'identity': IdentityTransformation,\n", @@ -1157,8 +1218,8 @@ " assert feat.is_categorical is False\n", " assert feat.is_immutable is False\n", "\n", - "cfs = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 29))\n", - "feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], cfs, hard=True)" + "x = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 29))\n", + "feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=True)" ] }, { @@ -1189,17 +1250,17 @@ "outputs": [], "source": [ "# Test apply_constraints and compute_reg_loss\n", - "cfs = np.random.randn(10, 29)\n", - "constraint_cfs = feats_list.apply_constraints(feats_list.transformed_data[:10, :], cfs, hard=False)\n", + 
"constraint_cfs = feats_list.apply_constraints(feats_list.transformed_data[:10, :], x, hard=False)\n", "assert constraint_cfs.shape == (10, 29)\n", "assert np.allclose(\n", " constraint_cfs[:, 2:].sum(axis=-1),\n", " np.ones((10,)) * 6\n", ")\n", "assert constraint_cfs[: :2].min() >= 0 and constraint_cfs[: :2].max() <= 1\n", - "assert feats_list.apply_constraints(feats_list.transformed_data[:10, :], cfs, hard=True).shape == (10, 29)\n", + "assert feats_list.apply_constraints(feats_list.transformed_data[:10, :], x, hard=True).shape == (10, 29)\n", "\n", - "reg_loss = feats_list.compute_reg_loss(feats_list.transformed_data, cfs)\n", + "reg_loss = feats_list.compute_reg_loss(feats_list.transformed_data, x)\n", "assert jnp.ndim(reg_loss) == 0\n", "assert np.all(reg_loss > 0)\n", "assert np.allclose(feats_list.compute_reg_loss(xs, constraint_cfs), 0)" diff --git a/relax/_modidx.py b/relax/_modidx.py index 7e5eeae..703dad4 100644 --- a/relax/_modidx.py +++ b/relax/_modidx.py @@ -182,6 +182,14 @@ 'relax/data_utils.py'), 'relax.data_utils.FeaturesList.with_transformed_data': ( 'data.utils.html#featureslist.with_transformed_data', 'relax/data_utils.py'), + 'relax.data_utils.GumbelSoftmaxTransformation': ( 'data.utils.html#gumbelsoftmaxtransformation', + 'relax/data_utils.py'), + 'relax.data_utils.GumbelSoftmaxTransformation.__init__': ( 'data.utils.html#gumbelsoftmaxtransformation.__init__', + 'relax/data_utils.py'), + 'relax.data_utils.GumbelSoftmaxTransformation.apply_constraints': ( 'data.utils.html#gumbelsoftmaxtransformation.apply_constraints', + 'relax/data_utils.py'), + 'relax.data_utils.GumbelSoftmaxTransformation.soft_constraints': ( 'data.utils.html#gumbelsoftmaxtransformation.soft_constraints', + 'relax/data_utils.py'), 'relax.data_utils.IdentityTransformation': ( 'data.utils.html#identitytransformation', 'relax/data_utils.py'), 'relax.data_utils.IdentityTransformation.__init__': ( 'data.utils.html#identitytransformation.__init__', @@ -221,14 +229,6 @@ 'relax.data_utils.OneHotEncoder.transform': ( 'data.utils.html#onehotencoder.transform', 'relax/data_utils.py'), 'relax.data_utils.OneHotTransformation': ('data.utils.html#onehottransformation', 'relax/data_utils.py'), - 'relax.data_utils.OneHotTransformation.__init__': ( 'data.utils.html#onehottransformation.__init__', - 'relax/data_utils.py'), - 'relax.data_utils.OneHotTransformation.apply_constraints': ( 'data.utils.html#onehottransformation.apply_constraints', - 'relax/data_utils.py'), - 'relax.data_utils.OneHotTransformation.compute_reg_loss': ( 'data.utils.html#onehottransformation.compute_reg_loss', - 'relax/data_utils.py'), - 'relax.data_utils.OneHotTransformation.num_categories': ( 'data.utils.html#onehottransformation.num_categories', - 'relax/data_utils.py'), 'relax.data_utils.OrdinalPreprocessor': ('data.utils.html#ordinalpreprocessor', 'relax/data_utils.py'), 'relax.data_utils.OrdinalPreprocessor.fit': ( 'data.utils.html#ordinalpreprocessor.fit', 'relax/data_utils.py'), @@ -242,6 +242,10 @@ 'relax/data_utils.py'), 'relax.data_utils.OrdinalTransformation.num_categories': ( 'data.utils.html#ordinaltransformation.num_categories', 'relax/data_utils.py'), + 'relax.data_utils.SoftmaxTransformation': ( 'data.utils.html#softmaxtransformation', + 'relax/data_utils.py'), + 'relax.data_utils.SoftmaxTransformation.soft_constraints': ( 'data.utils.html#softmaxtransformation.soft_constraints', + 'relax/data_utils.py'), 'relax.data_utils.Transformation': ('data.utils.html#transformation', 'relax/data_utils.py'), 
'relax.data_utils.Transformation.__init__': ( 'data.utils.html#transformation.__init__', 'relax/data_utils.py'), @@ -262,6 +266,20 @@ 'relax/data_utils.py'), 'relax.data_utils.Transformation.transform': ( 'data.utils.html#transformation.transform', 'relax/data_utils.py'), + 'relax.data_utils._OneHotTransformation': ( 'data.utils.html#_onehottransformation', + 'relax/data_utils.py'), + 'relax.data_utils._OneHotTransformation.__init__': ( 'data.utils.html#_onehottransformation.__init__', + 'relax/data_utils.py'), + 'relax.data_utils._OneHotTransformation.apply_constraints': ( 'data.utils.html#_onehottransformation.apply_constraints', + 'relax/data_utils.py'), + 'relax.data_utils._OneHotTransformation.compute_reg_loss': ( 'data.utils.html#_onehottransformation.compute_reg_loss', + 'relax/data_utils.py'), + 'relax.data_utils._OneHotTransformation.hard_constraints': ( 'data.utils.html#_onehottransformation.hard_constraints', + 'relax/data_utils.py'), + 'relax.data_utils._OneHotTransformation.num_categories': ( 'data.utils.html#_onehottransformation.num_categories', + 'relax/data_utils.py'), + 'relax.data_utils._OneHotTransformation.soft_constraints': ( 'data.utils.html#_onehottransformation.soft_constraints', + 'relax/data_utils.py'), 'relax.data_utils._check_xs': ('data.utils.html#_check_xs', 'relax/data_utils.py'), 'relax.data_utils._unique': ('data.utils.html#_unique', 'relax/data_utils.py')}, 'relax.docs': { 'relax.docs.CustomizedMarkdownRenderer': ('docs.html#customizedmarkdownrenderer', 'relax/docs.py'), @@ -797,6 +815,7 @@ 'relax.utils.auto_reshaping': ('utils.html#auto_reshaping', 'relax/utils.py'), 'relax.utils.get_config': ('utils.html#get_config', 'relax/utils.py'), 'relax.utils.grad_update': ('utils.html#grad_update', 'relax/utils.py'), + 'relax.utils.gumbel_softmax': ('utils.html#gumbel_softmax', 'relax/utils.py'), 'relax.utils.load_json': ('utils.html#load_json', 'relax/utils.py'), 'relax.utils.load_pytree': ('utils.html#load_pytree', 'relax/utils.py'), 'relax.utils.save_pytree': ('utils.html#save_pytree', 'relax/utils.py'), diff --git a/relax/data_utils.py b/relax/data_utils.py index 3dddb7d..260d885 100644 --- a/relax/data_utils.py +++ b/relax/data_utils.py @@ -1,6 +1,6 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_data.utils.ipynb. 
-# %% ../nbs/01_data.utils.ipynb 2 +# %% ../nbs/01_data.utils.ipynb 3 from __future__ import annotations from fastcore.test import * import pandas as pd @@ -10,15 +10,18 @@ import einops import os, sys, json, pickle import shutil -from .utils import * +from .utils import gumbel_softmax, load_pytree, save_pytree, get_config import chex +import functools as ft +import warnings # %% auto 0 __all__ = ['PREPROCESSING_TRANSFORMATIONS', 'DataPreprocessor', 'MinMaxScaler', 'EncoderPreprocessor', 'OrdinalPreprocessor', - 'OneHotEncoder', 'Transformation', 'MinMaxTransformation', 'OneHotTransformation', 'OrdinalTransformation', - 'IdentityTransformation', 'Feature', 'FeaturesList'] + 'OneHotEncoder', 'Transformation', 'MinMaxTransformation', 'SoftmaxTransformation', + 'GumbelSoftmaxTransformation', 'OneHotTransformation', 'OrdinalTransformation', 'IdentityTransformation', + 'Feature', 'FeaturesList'] -# %% ../nbs/01_data.utils.ipynb 5 +# %% ../nbs/01_data.utils.ipynb 6 def _check_xs(xs: np.ndarray, name: str): if xs.ndim > 2 or (xs.ndim == 2 and xs.shape[1] != 1): raise ValueError(f"`{name}` only supports array with a single feature, but got shape={xs.shape}.") @@ -60,7 +63,7 @@ def from_dict(self, params: dict): __ALL__ = ["fit", "transform", "fit_transform", "inverse_transform", "to_dict", "from_dict"] -# %% ../nbs/01_data.utils.ipynb 6 +# %% ../nbs/01_data.utils.ipynb 7 class MinMaxScaler(DataPreprocessor): def __init__(self): super().__init__(name="minmax") @@ -85,7 +88,7 @@ def from_dict(self, params: dict): def to_dict(self) -> dict: return {"min_": self.min_, "max_": self.max_} -# %% ../nbs/01_data.utils.ipynb 12 +# %% ../nbs/01_data.utils.ipynb 13 def _unique(xs): if xs.dtype == object: # Note: np.unique does not work with object dtype @@ -95,7 +98,7 @@ def _unique(xs): return np.unique(xs.astype(str)) return np.unique(xs) -# %% ../nbs/01_data.utils.ipynb 13 +# %% ../nbs/01_data.utils.ipynb 14 class EncoderPreprocessor(DataPreprocessor): """Encode categorical features as an integer array.""" def _fit(self, xs, y=None): @@ -121,7 +124,7 @@ def from_dict(self, params: dict): def to_dict(self) -> dict: return {"categories_": self.categories_} -# %% ../nbs/01_data.utils.ipynb 14 +# %% ../nbs/01_data.utils.ipynb 15 class OrdinalPreprocessor(EncoderPreprocessor): """Ordinal encoder for a single feature.""" @@ -138,7 +141,7 @@ def transform(self, xs): def inverse_transform(self, xs): return self._inverse_transform(xs) -# %% ../nbs/01_data.utils.ipynb 16 +# %% ../nbs/01_data.utils.ipynb 17 class OneHotEncoder(EncoderPreprocessor): """One-hot encoder for a single categorical feature.""" @@ -158,7 +161,7 @@ def inverse_transform(self, xs): xs_int = np.argmax(xs, axis=-1) return self._inverse_transform(xs_int).reshape(-1, 1) -# %% ../nbs/01_data.utils.ipynb 19 +# %% ../nbs/01_data.utils.ipynb 20 class Transformation: def __init__(self, name, transformer): self.name = name @@ -202,7 +205,7 @@ def from_dict(self, params: dict): def to_dict(self) -> dict: return {"name": self.name, "transformer": self.transformer.to_dict()} -# %% ../nbs/01_data.utils.ipynb 20 +# %% ../nbs/01_data.utils.ipynb 21 class MinMaxTransformation(Transformation): def __init__(self): super().__init__("minmax", MinMaxScaler()) @@ -210,21 +213,28 @@ def __init__(self): def apply_constraints(self, xs, cfs, **kwargs): return jnp.clip(cfs, 0., 1.) 
-# %% ../nbs/01_data.utils.ipynb 22 -class OneHotTransformation(Transformation): +# %% ../nbs/01_data.utils.ipynb 23 +class _OneHotTransformation(Transformation): def __init__(self): super().__init__("ohe", OneHotEncoder()) @property def num_categories(self) -> int: return len(self.transformer.categories_) + + def hard_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]): + x, rng_key, kwargs = operand + return jax.nn.one_hot(jnp.argmax(x, axis=-1), self.num_categories) + + def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]): + raise NotImplementedError - def apply_constraints(self, xs, cfs, hard: bool = False, **kwargs): + def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs): return jax.lax.cond( hard, - true_fun=lambda x: jax.nn.one_hot(jnp.argmax(x, axis=-1), self.num_categories), - false_fun=lambda x: jax.nn.softmax(x, axis=-1), - operand=cfs, + true_fun=self.hard_constraints, + false_fun=self.soft_constraints, + operand=(cfs, rng_key, kwargs), ) def compute_reg_loss(self, xs, cfs, hard: bool = False): @@ -232,6 +242,35 @@ def compute_reg_loss(self, xs, cfs, hard: bool = False): return reg_loss_per_xs.mean() # %% ../nbs/01_data.utils.ipynb 24 +class SoftmaxTransformation(_OneHotTransformation): + def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]): + x, rng_key, kwargs = operand + return jax.nn.softmax(x, axis=-1) + +class GumbelSoftmaxTransformation(_OneHotTransformation): + """Apply the Gumbel-softmax trick for categorical transformation.""" + + def __init__(self, tau: float = 1.): + super().__init__() + self.tau = tau + + def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]): + x, rng_key, _ = operand + if rng_key is None: # No key: use the fixed global seed (deterministic noise) + rng_key = jax.random.PRNGKey(get_config().global_seed) + return gumbel_softmax(rng_key, x, self.tau) + + def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs): + """Apply constraints to the counterfactuals. If `rng_key` is None, a fixed global seed is used, so the Gumbel noise is deterministic.""" + return super().apply_constraints(xs, cfs, hard, rng_key, **kwargs) + +def OneHotTransformation(): + warnings.warn("OneHotTransformation is deprecated since v0.2.5. 
" + "Use `SoftmaxTransformation` (same functionality) " + "or GumbelSoftmaxTransformation instead.", DeprecationWarning) + return SoftmaxTransformation() + +# %% ../nbs/01_data.utils.ipynb 26 class OrdinalTransformation(Transformation): def __init__(self): super().__init__("ordinal", OrdinalPreprocessor()) @@ -263,15 +302,17 @@ def from_dict(self, params: dict): self.name = params["name"] return self -# %% ../nbs/01_data.utils.ipynb 27 +# %% ../nbs/01_data.utils.ipynb 29 PREPROCESSING_TRANSFORMATIONS = { - 'ohe': OneHotTransformation, + 'ohe': SoftmaxTransformation, + 'softmax': SoftmaxTransformation, + 'gumbel': GumbelSoftmaxTransformation, 'minmax': MinMaxTransformation, 'ordinal': OrdinalTransformation, 'identity': IdentityTransformation, } -# %% ../nbs/01_data.utils.ipynb 28 +# %% ../nbs/01_data.utils.ipynb 30 class Feature: """THe feature class which represents a column in the dataset.""" @@ -413,7 +454,7 @@ def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs) def compute_reg_loss(self, xs, cfs, hard: bool = False): return self.transformation.compute_reg_loss(xs, cfs, hard) -# %% ../nbs/01_data.utils.ipynb 31 +# %% ../nbs/01_data.utils.ipynb 33 class FeaturesList: def __init__( self, diff --git a/relax/utils.py b/relax/utils.py index ce9733c..74a91af 100644 --- a/relax/utils.py +++ b/relax/utils.py @@ -11,8 +11,8 @@ from jax.core import InconclusiveDimensionOperation # %% auto 0 -__all__ = ['validate_configs', 'save_pytree', 'load_pytree', 'auto_reshaping', 'grad_update', 'load_json', 'get_config', - 'set_config'] +__all__ = ['validate_configs', 'save_pytree', 'load_pytree', 'auto_reshaping', 'grad_update', 'gumbel_softmax', 'load_json', + 'get_config', 'set_config'] # %% ../nbs/00_utils.ipynb 5 def validate_configs( @@ -119,12 +119,25 @@ def grad_update( return upt_params, opt_state # %% ../nbs/00_utils.ipynb 32 +def gumbel_softmax( + key: jrand.PRNGKey, # Random key + logits: Array, # Logits for each class. Shape (batch_size, num_classes) + tau: float, # Temperature for the Gumbel softmax + axis: int | tuple[int, ...] = -1, # The axis or axes along which the gumbel softmax should be computed +): + """The Gumbel softmax function.""" + + gumbel_noise = jrand.gumbel(key, shape=logits.shape) + y = logits + gumbel_noise + return jax.nn.softmax(y / tau, axis=axis) + +# %% ../nbs/00_utils.ipynb 34 def load_json(f_name: str) -> Dict[str, Any]: # file name with open(f_name) as f: return json.load(f) -# %% ../nbs/00_utils.ipynb 34 +# %% ../nbs/00_utils.ipynb 36 @dataclass class Config: rng_reserve_size: int @@ -136,11 +149,11 @@ def default(cls) -> Config: main_config = Config.default() -# %% ../nbs/00_utils.ipynb 35 +# %% ../nbs/00_utils.ipynb 37 def get_config() -> Config: return main_config -# %% ../nbs/00_utils.ipynb 36 +# %% ../nbs/00_utils.ipynb 38 def set_config( *, rng_reserve_size: int = None, # The number of random number generators to reserve.