-
Notifications
You must be signed in to change notification settings - Fork 493
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add augmentations inside the fit method
- Loading branch information
1 parent
5b19091
commit 6d0485f
Showing
6 changed files
with
310 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import torch | ||
from pytorch_tabnet.utils import define_device | ||
import numpy as np | ||
|
||
# TODO : change this so that p would be the proportion of rows that are changed | ||
# add a beta argument (beta distribution) | ||
class RegressionSMOTE(): | ||
""" | ||
Apply SMOTE | ||
This will average a percentage p of the elements in the batch with other elements. | ||
The target will be averaged as well (this might work with binary classification and certain loss), | ||
following a beta distribution. | ||
""" | ||
def __init__(self, device_name="auto", p=0.8, alpha=0.5, beta=0.5, seed=0): | ||
"" | ||
self.seed = seed | ||
self._set_seed() | ||
self.device = define_device(device_name) | ||
self.alpha = alpha | ||
self.beta = beta | ||
self.p = p | ||
if (p < 0.) or (p > 1.0): | ||
raise ValueError("Value of p should be between 0. and 1.") | ||
|
||
def _set_seed(self): | ||
torch.manual_seed(self.seed) | ||
np.random.seed(self.seed) | ||
return | ||
|
||
def __call__(self, X, y): | ||
batch_size = X.shape[0] | ||
random_values = torch.rand(batch_size, device=self.device) | ||
idx_to_change = random_values < self.p | ||
|
||
# ensure that first element to switch has probability > 0.5 | ||
np_betas = np.random.beta(self.alpha, self.beta, batch_size) / 2 + 0.5 | ||
random_betas = torch.from_numpy(np_betas).to(self.device).float() | ||
index_permute = torch.randperm(batch_size, device=self.device) | ||
|
||
X[idx_to_change] = random_betas[idx_to_change, None]*X[idx_to_change] + \ | ||
(1 - random_betas[idx_to_change, None])*X[index_permute][idx_to_change].view(X[idx_to_change].size()) | ||
|
||
y[idx_to_change] = random_betas[idx_to_change, None]*y[idx_to_change] + \ | ||
(1 - random_betas[idx_to_change, None])*y[index_permute][idx_to_change].view(y[idx_to_change].size()) | ||
|
||
return X, y | ||
|
||
class ClassificationSMOTE(): | ||
""" | ||
Apply SMOTE for classification tasks. | ||
This will average a percentage p of the elements in the batch with other elements. | ||
The target will stay unchanged and keep the value of the most important row in the mix. | ||
""" | ||
def __init__(self, device_name="auto", p=0.8, alpha=0.5, beta=0.5, seed=0): | ||
"" | ||
self.seed = seed | ||
self._set_seed() | ||
self.device = define_device(device_name) | ||
self.alpha = alpha | ||
self.beta = beta | ||
self.p = p | ||
if (p < 0.) or (p > 1.0): | ||
raise ValueError("Value of p should be between 0. and 1.") | ||
|
||
def _set_seed(self): | ||
torch.manual_seed(self.seed) | ||
np.random.seed(self.seed) | ||
return | ||
|
||
def __call__(self, X, y): | ||
batch_size = X.shape[0] | ||
random_values = torch.rand(batch_size, device=self.device) | ||
idx_to_change = random_values < self.p | ||
|
||
# ensure that first element to switch has probability > 0.5 | ||
np_betas = np.random.beta(self.alpha, self.beta, batch_size) / 2 + 0.5 | ||
random_betas = torch.from_numpy(np_betas).to(self.device).float() | ||
index_permute = torch.randperm(batch_size, device=self.device) | ||
|
||
X[idx_to_change] = random_betas[idx_to_change, None]*X[idx_to_change] + \ | ||
(1 - random_betas[idx_to_change, None])*X[index_permute][idx_to_change].view(X[idx_to_change].size()) | ||
|
||
|
||
return X, y | ||
|
||
|
||
|
Oops, something went wrong.