From f89c0593b7268fa904190083eddd15cc44b29045 Mon Sep 17 00:00:00 2001 From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com> Date: Wed, 21 Feb 2024 22:47:43 -0500 Subject: [PATCH] Fix `GumbelSoftmaxTransformation.name`; add more tests for set_transformation --- nbs/01_data.utils.ipynb | 105 +++++++++++++++++----------------------- relax/_modidx.py | 2 + relax/data_utils.py | 9 ++-- 3 files changed, 52 insertions(+), 64 deletions(-) diff --git a/nbs/01_data.utils.ipynb b/nbs/01_data.utils.ipynb index 20e330e..7cb56e3 100644 --- a/nbs/01_data.utils.ipynb +++ b/nbs/01_data.utils.ipynb @@ -35,7 +35,15 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" + ] + } + ], "source": [ "#| export\n", "from __future__ import annotations\n", @@ -493,8 +501,8 @@ "source": [ "#| export\n", "class _OneHotTransformation(Transformation):\n", - " def __init__(self):\n", - " super().__init__(\"ohe\", OneHotEncoder())\n", + " def __init__(self, name: str = None):\n", + " super().__init__(name, OneHotEncoder())\n", "\n", " @property\n", " def num_categories(self) -> int:\n", @@ -528,6 +536,9 @@ "source": [ "#| export\n", "class SoftmaxTransformation(_OneHotTransformation):\n", + " def __init__(self): \n", + " super().__init__(\"ohe\")\n", + "\n", " def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n", " x, rng_key, kwargs = operand\n", " return jax.nn.softmax(x, axis=-1)\n", @@ -536,7 +547,7 @@ " \"\"\"Apply Gumbel softmax tricks for categorical transformation.\"\"\"\n", "\n", " def __init__(self, tau: float = 1.):\n", - " super().__init__()\n", + " super().__init__(\"gumbel\")\n", " self.tau = tau\n", " \n", " def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n", @@ -1155,71 +1166,43 @@ "metadata": {}, "outputs": [], "source": [ - "# Test set_transformations\n", - "feats_list_2 = deepcopy(feats_list)\n", - "feats_list_2.set_transformations({\n", - " feat: 'ordinal' for feat in cat_feats\n", - "})\n", - "assert feats_list_2.transformed_data.shape == (32561, 8)\n", + "def test_set_transformations(transformation, correct_shape):\n", + " T = transformation\n", + " feats_list_2 = deepcopy(feats_list)\n", + " feats_list_2.set_transformations({\n", + " feat: T for feat in cat_feats\n", + " })\n", + " assert feats_list_2.transformed_data.shape == correct_shape\n", + " name = T.name if isinstance(T, Transformation) else T\n", "\n", - "for feat in feats_list_2:\n", - " if feat.name in cat_feats: \n", - " assert feat.transformation.name == 'ordinal'\n", - " assert feat.is_categorical\n", - " else:\n", - " assert feat.transformation.name == 'minmax' \n", - " assert feat.is_categorical is False\n", - " assert feat.is_immutable is False\n", - "del feats_list_2" + " for feat in feats_list_2:\n", + " if feat.name in cat_feats: \n", + " assert feat.transformation.name == name\n", + " assert feat.is_categorical\n", + " else:\n", + " assert feat.transformation.name == 'minmax' \n", + " assert feat.is_categorical is False\n", + " assert feat.is_immutable is False\n", + "\n", + " x = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, correct_shape[-1]))\n", + " _ = feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=False)\n", + " _ = feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[0.22846544, 0.1524936 , 0. , ..., 1. , 1. ,\n", - " 0. ],\n", - " [0.17293715, 0.5796174 , 0. , ..., 0. , 0. ,\n", - " 1. ],\n", - " [0.17434704, 0.8137592 , 0. , ..., 1. , 1. ,\n", - " 0. ],\n", - " ...,\n", - " [0.68356454, 0.65396845, 0. , ..., 0. , 1. ,\n", - " 0. ],\n", - " [0.73027587, 0.4722154 , 1. , ..., 1. , 0. ,\n", - " 1. ],\n", - " [0.8495003 , 0.04826355, 1. , ..., 1. , 0. ,\n", - " 1. ]], dtype=float32)" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "# Test set_transformations\n", - "feats_list_2 = FeaturesList.from_dict(feats_list.to_dict())\n", - "feats_list_2.set_transformations({\n", - " feat: OneHotTransformation() for feat in cat_feats\n", - "})\n", - "assert feats_list_2.transformed_data.shape == (32561, 29)\n", - "\n", - "for feat in feats_list_2:\n", - " if feat.name in cat_feats: \n", - " assert feat.transformation.name == 'ohe'\n", - " assert feat.is_categorical\n", - " else:\n", - " assert feat.transformation.name == 'minmax' \n", - " assert feat.is_categorical is False\n", - " assert feat.is_immutable is False\n", - "\n", - "x = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 29))\n", - "feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=True)" + "test_set_transformations('ordinal', (32561, 8))\n", + "test_set_transformations('ohe', (32561, 29))\n", + "test_set_transformations('gumbel', (32561, 29))\n", + "# TODO: [bug] raise error when set_transformations is called with \n", + "# SoftmaxTransformation() or GumbelSoftmaxTransformation(),\n", + "# instead of \"ohe\" or \"gumbel\".\n", + "# test_set_transformations(SoftmaxTransformation(), (32561, 29))\n", + "# test_set_transformations(GumbelSoftmaxTransformation(), (32561, 29))" ] }, { diff --git a/relax/_modidx.py b/relax/_modidx.py index 703dad4..d45668f 100644 --- a/relax/_modidx.py +++ b/relax/_modidx.py @@ -244,6 +244,8 @@ 'relax/data_utils.py'), 'relax.data_utils.SoftmaxTransformation': ( 'data.utils.html#softmaxtransformation', 'relax/data_utils.py'), + 'relax.data_utils.SoftmaxTransformation.__init__': ( 'data.utils.html#softmaxtransformation.__init__', + 'relax/data_utils.py'), 'relax.data_utils.SoftmaxTransformation.soft_constraints': ( 'data.utils.html#softmaxtransformation.soft_constraints', 'relax/data_utils.py'), 'relax.data_utils.Transformation': ('data.utils.html#transformation', 'relax/data_utils.py'), diff --git a/relax/data_utils.py b/relax/data_utils.py index 260d885..aa87e95 100644 --- a/relax/data_utils.py +++ b/relax/data_utils.py @@ -215,8 +215,8 @@ def apply_constraints(self, xs, cfs, **kwargs): # %% ../nbs/01_data.utils.ipynb 23 class _OneHotTransformation(Transformation): - def __init__(self): - super().__init__("ohe", OneHotEncoder()) + def __init__(self, name: str = None): + super().__init__(name, OneHotEncoder()) @property def num_categories(self) -> int: @@ -243,6 +243,9 @@ def compute_reg_loss(self, xs, cfs, hard: bool = False): # %% ../nbs/01_data.utils.ipynb 24 class SoftmaxTransformation(_OneHotTransformation): + def __init__(self): + super().__init__("ohe") + def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]): x, rng_key, kwargs = operand return jax.nn.softmax(x, axis=-1) @@ -251,7 +254,7 @@ class GumbelSoftmaxTransformation(_OneHotTransformation): """Apply Gumbel softmax tricks for categorical transformation.""" def __init__(self, tau: float = 1.): - super().__init__() + super().__init__("gumbel") self.tau = tau def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):