
Commit 5c000c2: feat/492-sparse-compatibility
CesarLeblanc authored and Optimox committed Jul 19, 2023
1 parent 9ba8991
Showing 8 changed files with 293 additions and 122 deletions.
README.md: 2 changes (1 addition, 1 deletion)
@@ -336,7 +336,7 @@ loaded_clf.load_model(saved_filepath)

## Fit parameters

- `X_train` : np.array
- `X_train` : np.array or scipy.sparse.csr_matrix

Training features

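To make the documented change concrete, here is a minimal, hedged sketch of the new sparse path; the toy data and the `max_epochs` choice are illustrative assumptions, not part of this diff:

    import numpy as np
    import scipy.sparse
    from pytorch_tabnet.tab_model import TabNetClassifier

    # Toy data, for illustration only
    X_train = scipy.sparse.csr_matrix(np.random.rand(100, 10))
    y_train = np.random.randint(0, 2, size=100)

    clf = TabNetClassifier()
    clf.fit(X_train=X_train, y_train=y_train, max_epochs=5)  # CSR input accepted after this commit
    preds = clf.predict(X_train)  # inference accepts CSR as well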
census_example.ipynb: 114 changes (70 additions, 44 deletions)
@@ -16,6 +16,7 @@
"import numpy as np\n",
"np.random.seed(0)\n",
"\n",
"import scipy\n",
"\n",
"import os\n",
"import wget\n",
@@ -265,57 +266,82 @@
"aug = ClassificationSMOTE(p=0.2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This illustrates the behaviour of the model's fit method using Compressed Sparse Row matrices\n",
"sparse_X_train = scipy.sparse.csr_matrix(X_train) # Create a CSR matrix from X_train\n",
"sparse_X_valid = scipy.sparse.csr_matrix(X_valid) # Create a CSR matrix from X_valid\n",
"\n",
"# Fitting the model\n",
"clf.fit(\n",
" X_train=sparse_X_train, y_train=y_train,\n",
" eval_set=[(sparse_X_train, y_train), (sparse_X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" eval_metric=['auc'],\n",
" max_epochs=max_epochs , patience=20,\n",
" batch_size=1024, virtual_batch_size=128,\n",
" num_workers=0,\n",
" weights=1,\n",
" drop_last=False,\n",
" augmentations=aug, #aug, None\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# This illustrates the warm_start=False behaviour\n",
"save_history = []\n",
"\n",
"# Fitting the model without starting from a warm start nor computing the feature importance\n",
"for _ in range(2):\n",
" clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" eval_metric=['auc'],\n",
" max_epochs=max_epochs , patience=20,\n",
" batch_size=1024, virtual_batch_size=128,\n",
" num_workers=0,\n",
" weights=1,\n",
" drop_last=False,\n",
" augmentations=aug, #aug, None\n",
" compute_importance=False\n",
" )\n",
" save_history.append(clf.history[\"valid_auc\"])\n",
"\n",
"assert(np.all(np.array(save_history[0]==np.array(save_history[1]))))\n",
"\n",
"save_history = [] # Resetting the list to show that it also works when computing feature importance\n",
"\n",
"# Fitting the model without starting from a warm start but with the computing of the feature importance activated\n",
"for _ in range(2):\n",
" clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" eval_metric=['auc'],\n",
" max_epochs=max_epochs , patience=20,\n",
" batch_size=1024, virtual_batch_size=128,\n",
" num_workers=0,\n",
" weights=1,\n",
" drop_last=False,\n",
" augmentations=aug, #aug, None\n",
" compute_importance=True # True by default so not needed\n",
" )\n",
" save_history.append(clf.history[\"valid_auc\"])\n",
"\n",
"assert(np.all(np.array(save_history[0]==np.array(save_history[1]))))"
]
"source": [
"# This illustrates the warm_start=False behaviour\n",
"save_history = []\n",
"\n",
"# Fitting the model without starting from a warm start nor computing the feature importance\n",
"for _ in range(2):\n",
" clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" eval_metric=['auc'],\n",
" max_epochs=max_epochs , patience=20,\n",
" batch_size=1024, virtual_batch_size=128,\n",
" num_workers=0,\n",
" weights=1,\n",
" drop_last=False,\n",
" augmentations=aug, #aug, None\n",
" compute_importance=False\n",
" )\n",
" save_history.append(clf.history[\"valid_auc\"])\n",
"\n",
"assert(np.all(np.array(save_history[0]==np.array(save_history[1]))))\n",
"\n",
"save_history = [] # Resetting the list to show that it also works when computing feature importance\n",
"\n",
"# Fitting the model without starting from a warm start but with the computing of the feature importance activated\n",
"for _ in range(2):\n",
" clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" eval_metric=['auc'],\n",
" max_epochs=max_epochs , patience=20,\n",
" batch_size=1024, virtual_batch_size=128,\n",
" num_workers=0,\n",
" weights=1,\n",
" drop_last=False,\n",
" augmentations=aug, #aug, None\n",
" compute_importance=True # True by default so not needed\n",
" )\n",
" save_history.append(clf.history[\"valid_auc\"])\n",
"\n",
"assert(np.all(np.array(save_history[0]==np.array(save_history[1]))))"
]
},
{
"cell_type": "code",
pytorch_tabnet/abstract_model.py: 41 changes (29 additions, 12 deletions)
@@ -7,6 +7,7 @@
from abc import abstractmethod
from pytorch_tabnet import tab_network
from pytorch_tabnet.utils import (
SparsePredictDataset,
PredictDataset,
create_explain_matrix,
validate_eval_set,
@@ -35,6 +36,7 @@
import zipfile
import warnings
import copy
import scipy


@dataclass
@@ -281,7 +283,7 @@ def predict(self, X):
Parameters
----------
X : a :tensor: `torch.Tensor`
X : a :tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix`
Input data
Returns
@@ -290,11 +292,19 @@ def predict(self, X):
Predictions of the regression problem
"""
self.network.eval()
dataloader = DataLoader(
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)

if scipy.sparse.issparse(X):
dataloader = DataLoader(
SparsePredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)
else:
dataloader = DataLoader(
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)

results = []
for batch_nb, data in enumerate(dataloader):
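
`SparsePredictDataset` is imported from `pytorch_tabnet.utils`, whose diff is not loaded on this page. A plausible sketch of such a dataset, assuming it densifies one CSR row per `__getitem__` so the full matrix is never materialized (hypothetical code, not the committed implementation):

    import scipy.sparse
    import torch
    from torch.utils.data import Dataset

    class SparsePredictDataset(Dataset):
        """Yield dense float rows from a scipy.sparse.csr_matrix."""

        def __init__(self, x: scipy.sparse.csr_matrix):
            self.x = x

        def __len__(self):
            return self.x.shape[0]

        def __getitem__(self, index):
            # Densify a single row only; the full matrix stays sparse in memory.
            return torch.from_numpy(self.x[index].toarray().squeeze()).float()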
@@ -311,7 +321,7 @@ def explain(self, X, normalize=False):
Parameters
----------
X : tensor: `torch.Tensor`
X : tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix`
Input data
normalize : bool (default False)
Whether to normalize so that the sum of features is equal to 1
@@ -325,11 +335,18 @@
"""
self.network.eval()

dataloader = DataLoader(
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)
if scipy.sparse.issparse(X):
dataloader = DataLoader(
SparsePredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)
else:
dataloader = DataLoader(
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)

res_explain = []

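With the same dispatch in `explain`, a fitted model can compute feature attributions directly from sparse input. Hedged usage sketch, reusing `clf` and `sparse_X_valid` from the notebook cell above:

    # explain() returns the aggregated importance matrix and the per-step masks;
    # CSR input now takes the SparsePredictDataset code path.
    explain_matrix, masks = clf.explain(sparse_X_valid)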
pytorch_tabnet/multitask.py: 42 changes (29 additions, 13 deletions)
@@ -1,10 +1,11 @@
import torch
import numpy as np
from scipy.special import softmax
from pytorch_tabnet.utils import PredictDataset, filter_weights
from pytorch_tabnet.utils import SparsePredictDataset, PredictDataset, filter_weights
from pytorch_tabnet.abstract_model import TabModel
from pytorch_tabnet.multiclass_utils import infer_multitask_output, check_output_dim
from torch.utils.data import DataLoader
import scipy


class TabNetMultiTaskClassifier(TabModel):
@@ -87,7 +88,7 @@ def predict(self, X):
Parameters
----------
X : a :tensor: `torch.Tensor`
X : a :tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix`
Input data
Returns
@@ -96,11 +97,19 @@
Predictions of the most probable class
"""
self.network.eval()
dataloader = DataLoader(
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)

if scipy.sparse.issparse(X):
dataloader = DataLoader(
SparsePredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)
else:
dataloader = DataLoader(
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)

results = {}
for data in dataloader:
@@ -132,7 +141,7 @@ def predict_proba(self, X):
Parameters
----------
X : a :tensor: `torch.Tensor`
X : a :tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix`
Input data
Returns
@@ -142,11 +151,18 @@
"""
self.network.eval()

dataloader = DataLoader(
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)
if scipy.sparse.issparse(X):
dataloader = DataLoader(
SparsePredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)
else:
dataloader = DataLoader(
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)

results = {}
for data in dataloader:
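A hedged usage sketch of the multitask path; `multitask_clf` and `X_test` are assumed here, not part of this diff:

    import scipy.sparse

    # Both inference entry points now accept CSR matrices.
    preds = multitask_clf.predict(scipy.sparse.csr_matrix(X_test))          # one array per task
    probas = multitask_clf.predict_proba(scipy.sparse.csr_matrix(X_test))   # one probability array per task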
pytorch_tabnet/pretraining.py: 22 changes (16 additions, 6 deletions)
@@ -5,6 +5,7 @@
from pytorch_tabnet.utils import (
create_explain_matrix,
filter_weights,
SparsePredictDataset,
PredictDataset,
check_input,
create_group_matrix,
@@ -20,6 +21,7 @@
UnsupervisedLoss,
)
from pytorch_tabnet.abstract_model import TabModel
import scipy


class TabNetPretrainer(TabModel):
@@ -390,7 +392,7 @@ def predict(self, X):
Parameters
----------
X : a :tensor: `torch.Tensor`
X : a :tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix`
Input data
Returns
@@ -399,11 +401,19 @@
Predictions of the regression problem
"""
self.network.eval()
dataloader = DataLoader(
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)

if scipy.sparse.issparse(X):
dataloader = DataLoader(
SparsePredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)
else:
dataloader = DataLoader(
PredictDataset(X),
batch_size=self.batch_size,
shuffle=False,
)

results = []
embedded_res = []
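The pretrainer follows the same pattern. A hedged sketch, assuming `unsupervised_model` is a fitted `TabNetPretrainer` and `X` matches its training schema; the two return values mirror the `results`/`embedded_res` pair visible above:

    import scipy.sparse

    # Reconstruction and embedded representation, now computable from CSR input.
    reconstructed, embedded = unsupervised_model.predict(scipy.sparse.csr_matrix(X))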
(Diffs for the remaining 3 changed files are not shown here.)
