Fix GumbelSoftmaxTransformation.name; add more tests for set_transformation #42

Merged
merged 1 commit on Feb 22, 2024
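
For reference, a minimal sketch of the behavior this PR targets (a hypothetical check, not part of the committed diff): previously both one-hot subclasses inherited the name "ohe" from _OneHotTransformation, whereas after this change each subclass passes its own name up the chain.

from relax.data_utils import SoftmaxTransformation, GumbelSoftmaxTransformation

# SoftmaxTransformation keeps the original "ohe" name; the Gumbel variant
# now reports its own name instead of inheriting "ohe".
assert SoftmaxTransformation().name == "ohe"
assert GumbelSoftmaxTransformation(tau=1.0).name == "gumbel"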
105 changes: 44 additions & 61 deletions nbs/01_data.utils.ipynb
@@ -35,7 +35,15 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
"source": [
"#| export\n",
"from __future__ import annotations\n",
@@ -493,8 +501,8 @@
"source": [
"#| export\n",
"class _OneHotTransformation(Transformation):\n",
" def __init__(self):\n",
" super().__init__(\"ohe\", OneHotEncoder())\n",
" def __init__(self, name: str = None):\n",
" super().__init__(name, OneHotEncoder())\n",
"\n",
" @property\n",
" def num_categories(self) -> int:\n",
@@ -528,6 +536,9 @@
"source": [
"#| export\n",
"class SoftmaxTransformation(_OneHotTransformation):\n",
" def __init__(self): \n",
" super().__init__(\"ohe\")\n",
"\n",
" def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n",
" x, rng_key, kwargs = operand\n",
" return jax.nn.softmax(x, axis=-1)\n",
@@ -536,7 +547,7 @@
" \"\"\"Apply Gumbel softmax tricks for categorical transformation.\"\"\"\n",
"\n",
" def __init__(self, tau: float = 1.):\n",
" super().__init__()\n",
" super().__init__(\"gumbel\")\n",
" self.tau = tau\n",
" \n",
" def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n",
@@ -1155,71 +1166,43 @@
"metadata": {},
"outputs": [],
"source": [
"# Test set_transformations\n",
"feats_list_2 = deepcopy(feats_list)\n",
"feats_list_2.set_transformations({\n",
" feat: 'ordinal' for feat in cat_feats\n",
"})\n",
"assert feats_list_2.transformed_data.shape == (32561, 8)\n",
"def test_set_transformations(transformation, correct_shape):\n",
" T = transformation\n",
" feats_list_2 = deepcopy(feats_list)\n",
" feats_list_2.set_transformations({\n",
" feat: T for feat in cat_feats\n",
" })\n",
" assert feats_list_2.transformed_data.shape == correct_shape\n",
" name = T.name if isinstance(T, Transformation) else T\n",
"\n",
"for feat in feats_list_2:\n",
" if feat.name in cat_feats: \n",
" assert feat.transformation.name == 'ordinal'\n",
" assert feat.is_categorical\n",
" else:\n",
" assert feat.transformation.name == 'minmax' \n",
" assert feat.is_categorical is False\n",
" assert feat.is_immutable is False\n",
"del feats_list_2"
" for feat in feats_list_2:\n",
" if feat.name in cat_feats: \n",
" assert feat.transformation.name == name\n",
" assert feat.is_categorical\n",
" else:\n",
" assert feat.transformation.name == 'minmax' \n",
" assert feat.is_categorical is False\n",
" assert feat.is_immutable is False\n",
"\n",
" x = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, correct_shape[-1]))\n",
" _ = feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=False)\n",
" _ = feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array([[0.22846544, 0.1524936 , 0. , ..., 1. , 1. ,\n",
" 0. ],\n",
" [0.17293715, 0.5796174 , 0. , ..., 0. , 0. ,\n",
" 1. ],\n",
" [0.17434704, 0.8137592 , 0. , ..., 1. , 1. ,\n",
" 0. ],\n",
" ...,\n",
" [0.68356454, 0.65396845, 0. , ..., 0. , 1. ,\n",
" 0. ],\n",
" [0.73027587, 0.4722154 , 1. , ..., 1. , 0. ,\n",
" 1. ],\n",
" [0.8495003 , 0.04826355, 1. , ..., 1. , 0. ,\n",
" 1. ]], dtype=float32)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Test set_transformations\n",
"feats_list_2 = FeaturesList.from_dict(feats_list.to_dict())\n",
"feats_list_2.set_transformations({\n",
" feat: OneHotTransformation() for feat in cat_feats\n",
"})\n",
"assert feats_list_2.transformed_data.shape == (32561, 29)\n",
"\n",
"for feat in feats_list_2:\n",
" if feat.name in cat_feats: \n",
" assert feat.transformation.name == 'ohe'\n",
" assert feat.is_categorical\n",
" else:\n",
" assert feat.transformation.name == 'minmax' \n",
" assert feat.is_categorical is False\n",
" assert feat.is_immutable is False\n",
"\n",
"x = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 29))\n",
"feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=True)"
"test_set_transformations('ordinal', (32561, 8))\n",
"test_set_transformations('ohe', (32561, 29))\n",
"test_set_transformations('gumbel', (32561, 29))\n",
"# TODO: [bug] raise error when set_transformations is called with \n",
"# SoftmaxTransformation() or GumbelSoftmaxTransformation(),\n",
"# instead of \"ohe\" or \"gumbel\".\n",
"# test_set_transformations(SoftmaxTransformation(), (32561, 29))\n",
"# test_set_transformations(GumbelSoftmaxTransformation(), (32561, 29))"
]
},
{
2 changes: 2 additions & 0 deletions relax/_modidx.py
@@ -244,6 +244,8 @@
'relax/data_utils.py'),
'relax.data_utils.SoftmaxTransformation': ( 'data.utils.html#softmaxtransformation',
'relax/data_utils.py'),
'relax.data_utils.SoftmaxTransformation.__init__': ( 'data.utils.html#softmaxtransformation.__init__',
'relax/data_utils.py'),
'relax.data_utils.SoftmaxTransformation.soft_constraints': ( 'data.utils.html#softmaxtransformation.soft_constraints',
'relax/data_utils.py'),
'relax.data_utils.Transformation': ('data.utils.html#transformation', 'relax/data_utils.py'),
9 changes: 6 additions & 3 deletions relax/data_utils.py
@@ -215,8 +215,8 @@ def apply_constraints(self, xs, cfs, **kwargs):

# %% ../nbs/01_data.utils.ipynb 23
class _OneHotTransformation(Transformation):
def __init__(self):
super().__init__("ohe", OneHotEncoder())
def __init__(self, name: str = None):
super().__init__(name, OneHotEncoder())

@property
def num_categories(self) -> int:
@@ -243,6 +243,9 @@ def compute_reg_loss(self, xs, cfs, hard: bool = False):

# %% ../nbs/01_data.utils.ipynb 24
class SoftmaxTransformation(_OneHotTransformation):
def __init__(self):
super().__init__("ohe")

def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):
x, rng_key, kwargs = operand
return jax.nn.softmax(x, axis=-1)
@@ -251,7 +254,7 @@ class GumbelSoftmaxTransformation(_OneHotTransformation):
"""Apply Gumbel softmax tricks for categorical transformation."""

def __init__(self, tau: float = 1.):
super().__init__()
super().__init__("gumbel")
self.tau = tau

def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):
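
As background on the "Gumbel softmax tricks" mentioned in the docstring above, here is a rough sketch of the standard Gumbel-softmax relaxation in JAX. It illustrates the general technique only and is an assumption about its usual form, not the soft_constraints implementation in this repository.

import jax
import jax.numpy as jnp

def gumbel_softmax(logits, rng_key, tau: float = 1.0):
    # Draw Gumbel(0, 1) noise via the inverse-CDF trick, add it to the logits,
    # and apply a temperature-scaled softmax so categorical sampling becomes
    # differentiable; lower tau pushes the output closer to a one-hot vector.
    u = jax.random.uniform(rng_key, logits.shape, minval=1e-20, maxval=1.0)
    gumbel = -jnp.log(-jnp.log(u))
    return jax.nn.softmax((logits + gumbel) / tau, axis=-1)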