Skip to content

Commit

Permalink
feat: add augmentations inside the fit method
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox authored and eduardocarvp committed Mar 23, 2022
1 parent 5b19091 commit 6d0485f
Show file tree
Hide file tree
Showing 6 changed files with 310 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ NO_COLOR=\\e[39m
OK_COLOR=\\e[32m
ERROR_COLOR=\\e[31m
WARN_COLOR=\\e[33m
PORT=8889
PORT=8887
.SILENT: ;
default: help; # default target

Expand Down
15 changes: 13 additions & 2 deletions census_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@
"max_epochs = 100 if not os.getenv(\"CI\", False) else 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pytorch_tabnet.augmentations import ClassificationSMOTE\n",
"aug = ClassificationSMOTE(p=0.2)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -225,10 +235,11 @@
" batch_size=1024, virtual_batch_size=128,\n",
" num_workers=0,\n",
" weights=1,\n",
" drop_last=False\n",
" drop_last=False,\n",
" augmentations=aug, #aug, None\n",
" )\n",
" save_history.append(clf.history[\"valid_auc\"])\n",
" \n",
"\n",
"assert(np.all(np.array(save_history[0]==np.array(save_history[1]))))"
]
},
Expand Down
8 changes: 6 additions & 2 deletions forest_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@
"metadata": {},
"outputs": [],
"source": [
"max_epochs = 5 if not os.getenv(\"CI\", False) else 2"
"max_epochs = 50 if not os.getenv(\"CI\", False) else 2"
]
},
{
Expand All @@ -248,12 +248,16 @@
},
"outputs": [],
"source": [
"from pytorch_tabnet.augmentations import ClassificationSMOTE\n",
"aug = ClassificationSMOTE(p=0.2)\n",
"\n",
"clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" max_epochs=max_epochs, patience=100,\n",
" batch_size=16384, virtual_batch_size=256\n",
" batch_size=16384, virtual_batch_size=256,\n",
" augmentations=aug\n",
") "
]
},
Expand Down
13 changes: 11 additions & 2 deletions pytorch_tabnet/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def fit(
callbacks=None,
pin_memory=True,
from_unsupervised=None,
warm_start=False
warm_start=False,
augmentations=None,
):
"""Train a neural network stored in self.network
Using train_dataloader for training data and
Expand Down Expand Up @@ -183,6 +184,11 @@ def fit(
self.input_dim = X_train.shape[1]
self._stop_training = False
self.pin_memory = pin_memory and (self.device.type != "cpu")
self.augmentations = augmentations

if self.augmentations is not None:
# This ensures reproducibility
self.augmentations._set_seed()

eval_set = eval_set if eval_set else []

Expand Down Expand Up @@ -477,9 +483,12 @@ def _train_batch(self, X, y):
"""
batch_logs = {"batch_size": X.shape[0]}

X = X.to(self.device).float()
X = X.to(self.device).float() # Is this .float() needed ?
y = y.to(self.device).float()

if self.augmentations is not None:
X, y = self.augmentations(X, y)

for param in self.network.parameters():
param.grad = None

Expand Down
90 changes: 90 additions & 0 deletions pytorch_tabnet/augmentations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import torch
from pytorch_tabnet.utils import define_device
import numpy as np

# TODO : change this so that p would be the proportion of rows that are changed
# add a beta argument (beta distribution)
class RegressionSMOTE():
    """
    Apply a SMOTE-like augmentation to a batch, for regression tasks.

    Every row of the batch is selected independently with probability ``p``.
    A selected row is replaced by a weighted average of itself and another
    row of the same batch (chosen through a random permutation). The weight
    of the original row is drawn from a Beta(alpha, beta) distribution and
    rescaled to [0.5, 1], so the original row always dominates the mix.
    The targets are averaged with the same weights and the same partners
    (this might work with binary classification and certain losses).
    """

    def __init__(self, device_name="auto", p=0.8, alpha=0.5, beta=0.5, seed=0):
        """
        Parameters
        ----------
        device_name : str
            Forwarded to ``define_device``; the result is stored in
            ``self.device``. Random tensors are created on the device of the
            incoming batch, so this attribute is informational only.
        p : float
            Probability, in [0, 1], that a given row is mixed.
        alpha, beta : float
            Parameters of the Beta distribution the mixing weights come from.
        seed : int
            Seed applied to torch and numpy by ``_set_seed``.

        Raises
        ------
        ValueError
            If ``p`` is not between 0 and 1.
        """
        # Validate first, so an invalid p never reseeds the global RNGs.
        if not 0.0 <= p <= 1.0:
            raise ValueError("Value of p should be between 0. and 1.")
        self.seed = seed
        self._set_seed()
        self.device = define_device(device_name)
        self.alpha = alpha
        self.beta = beta
        self.p = p

    def _set_seed(self):
        """Reseed the global torch and numpy generators with ``self.seed``."""
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)

    @staticmethod
    def _mix(rows, partners, betas):
        """Return ``betas * rows + (1 - betas) * partners``, row by row."""
        # One weight per row: add trailing singleton axes so the weights
        # broadcast over any number of feature/target dimensions, including
        # none at all (1-D targets).
        weights = betas.view(-1, *([1] * (rows.dim() - 1)))
        return weights * rows + (1 - weights) * partners

    def __call__(self, X, y):
        """
        Augment a batch. ``X`` and ``y`` are modified in place and returned.

        ``X`` and ``y`` must share their first dimension (the batch size) and
        live on the same device. ``y`` may be 1-D or have extra dimensions.
        """
        batch_size = X.shape[0]
        # Sample on the batch's own device so that masks, indices and weights
        # can always be combined with X and y.
        device = X.device

        # The order of the three random draws is kept stable for
        # reproducibility of seeded runs.
        random_values = torch.rand(batch_size, device=device)
        idx_to_change = random_values < self.p

        # Beta samples lie in [0, 1]; halving and shifting maps them to
        # [0.5, 1] so the original row keeps at least half of the weight.
        np_betas = np.random.beta(self.alpha, self.beta, batch_size) / 2 + 0.5
        random_betas = torch.from_numpy(np_betas).to(device).float()
        index_permute = torch.randperm(batch_size, device=device)

        betas = random_betas[idx_to_change]
        partner_idx = index_permute[idx_to_change]

        # Right-hand sides are fully computed before the in-place writes.
        X[idx_to_change] = self._mix(X[idx_to_change], X[partner_idx], betas)
        y[idx_to_change] = self._mix(y[idx_to_change], y[partner_idx], betas)

        return X, y

class ClassificationSMOTE():
    """
    Apply a SMOTE-like augmentation to a batch, for classification tasks.

    Every row of the batch is selected independently with probability ``p``.
    A selected row is replaced by a weighted average of itself and another
    row of the same batch (chosen through a random permutation). The weight
    of the original row is drawn from a Beta(alpha, beta) distribution and
    rescaled to [0.5, 1].
    The target stays unchanged: it keeps the value of the row that carries
    the largest weight in the mix, which is always the original row.
    """

    def __init__(self, device_name="auto", p=0.8, alpha=0.5, beta=0.5, seed=0):
        """
        Parameters
        ----------
        device_name : str
            Forwarded to ``define_device``; the result is stored in
            ``self.device``. Random tensors are created on the device of the
            incoming batch, so this attribute is informational only.
        p : float
            Probability, in [0, 1], that a given row is mixed.
        alpha, beta : float
            Parameters of the Beta distribution the mixing weights come from.
        seed : int
            Seed applied to torch and numpy by ``_set_seed``.

        Raises
        ------
        ValueError
            If ``p`` is not between 0 and 1.
        """
        # Validate first, so an invalid p never reseeds the global RNGs.
        if not 0.0 <= p <= 1.0:
            raise ValueError("Value of p should be between 0. and 1.")
        self.seed = seed
        self._set_seed()
        self.device = define_device(device_name)
        self.alpha = alpha
        self.beta = beta
        self.p = p

    def _set_seed(self):
        """Reseed the global torch and numpy generators with ``self.seed``."""
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)

    def __call__(self, X, y):
        """
        Augment a batch. ``X`` is modified in place; ``y`` is returned as is.

        Returns the tuple ``(X, y)``.
        """
        batch_size = X.shape[0]
        # Sample on the batch's own device so that the mask, the permutation
        # and the weights can always be combined with X.
        device = X.device

        # The order of the three random draws is kept stable for
        # reproducibility of seeded runs.
        random_values = torch.rand(batch_size, device=device)
        idx_to_change = random_values < self.p

        # Beta samples lie in [0, 1]; halving and shifting maps them to
        # [0.5, 1] so the original row keeps at least half of the weight.
        np_betas = np.random.beta(self.alpha, self.beta, batch_size) / 2 + 0.5
        random_betas = torch.from_numpy(np_betas).to(device).float()
        index_permute = torch.randperm(batch_size, device=device)

        rows = X[idx_to_change]
        partners = X[index_permute[idx_to_change]]
        # One weight per selected row, with trailing singleton axes so it
        # broadcasts over every feature dimension of X.
        weights = random_betas[idx_to_change].view(
            -1, *([1] * (rows.dim() - 1))
        )

        # The right-hand side is fully computed before the in-place write.
        X[idx_to_change] = weights * rows + (1 - weights) * partners

        return X, y



Loading

0 comments on commit 6d0485f

Please sign in to comment.