Support sparse matrices for inputs/X (#240)
adriangb authored Jul 22, 2022
1 parent 71cabed commit 8d5e1a9
Showing 4 changed files with 198 additions and 22 deletions.
164 changes: 164 additions & 0 deletions docs/source/notebooks/sparse.md
@@ -0,0 +1,164 @@
---
jupyter:
  jupytext:
    formats: ipynb,md
    text_representation:
      extension: .md
      format_name: markdown
      format_version: '1.3'
    jupytext_version: 1.14.0
  kernelspec:
    display_name: Python 3 (ipykernel)
    language: python
    name: python3
---

<!-- #raw -->
<a href="https://colab.research.google.com/github/adriangb/scikeras/blob/docs-deploy/refs/heads/master/notebooks/AutoEncoders.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png">Run in Google Colab</a>
<!-- #endraw -->

# Sparse Inputs


SciKeras supports sparse inputs (`X`/features).
You don't have to do anything special for this to work: just pass a sparse matrix to `fit()`.

In this notebook, we'll demonstrate how this works and compare memory consumption of sparse inputs to dense inputs.
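
For example, a minimal sketch of the idea (a toy model and a small random CSR matrix, shown here only for illustration; the model-building convention mirrors the one used later in this notebook) might look like:

```python
import numpy as np
from scipy import sparse
from scikeras.wrappers import KerasRegressor
from tensorflow import keras


def get_model(meta):
    # a tiny linear model, just enough to show the call signature
    model = keras.models.Sequential()
    model.add(keras.layers.Input(shape=(meta["n_features_in_"],)))
    model.add(keras.layers.Dense(1))
    return model


X_sparse = sparse.random(1_000, 500, density=0.01, format="csr")  # sparse feature matrix
y = np.random.uniform(size=(X_sparse.shape[0],))

reg = KerasRegressor(get_model, loss="mse", epochs=1, verbose=False)
reg.fit(X_sparse, y)  # the sparse matrix is passed to fit() as-is
```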


## Setup

```python
!pip install memory_profiler
%load_ext memory_profiler
```

```python
import warnings
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow import get_logger
get_logger().setLevel('ERROR')
warnings.filterwarnings("ignore", message="Setting the random state for TF")
```

```python
try:
    import scikeras
except ImportError:
    !python -m pip install scikeras
```

```python
import scipy
import numpy as np
from scikeras.wrappers import KerasRegressor
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline
from tensorflow import keras
```

## Data

The dataset we'll be using is designed to demonstrate a worst-case/best-case scenario for dense and sparse input features, respectively.
It consists of a single categorical feature with as many categories as there are rows.
This means the one-hot encoded representation will require as many columns as it does rows, making it very inefficient to store as a dense matrix but very efficient to store as a sparse matrix.
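
As a rough back-of-the-envelope check (assuming float64 values, `OneHotEncoder`'s default dtype, and CSR-style storage for the sparse case):

```python
N = 20_000  # same size as the dataset built below

# dense one-hot matrix: N x N float64 values
dense_bytes = N * N * 8

# sparse (CSR-like) storage: one nonzero per row -> value + column index,
# plus a row-pointer array with N + 1 entries
sparse_bytes = N * (8 + 4) + (N + 1) * 4

print(f"dense:  ~{dense_bytes / 1e9:.1f} GB")   # ~3.2 GB
print(f"sparse: ~{sparse_bytes / 1e6:.2f} MB")  # ~0.32 MB
```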

```python
N_SAMPLES = 20_000 # hand tuned to be ~4GB peak

X = np.arange(0, N_SAMPLES).reshape(-1, 1)
y = np.random.uniform(0, 1, size=(X.shape[0],))
```

## Model

The model here is nothing special, just a basic multilayer perceptron with one hidden layer.

```python
def get_clf(meta) -> keras.Model:
    n_features_in_ = meta["n_features_in_"]
    model = keras.models.Sequential()
    model.add(keras.layers.Input(shape=(n_features_in_,)))
    # a single hidden layer
    model.add(keras.layers.Dense(100, activation="relu"))
    model.add(keras.layers.Dense(1))
    return model
```

## Pipelines

Here is where it gets interesting.
We make two Scikit-Learn pipelines that use `OneHotEncoder`: one that uses `sparse=False` to force a dense matrix as the output and another that uses `sparse=True` (the default).

```python
dense_pipeline = Pipeline(
    [
        ("encoder", OneHotEncoder(sparse=False)),
        ("model", KerasRegressor(get_clf, loss="mse", epochs=5, verbose=False))
    ]
)

sparse_pipeline = Pipeline(
    [
        ("encoder", OneHotEncoder(sparse=True)),
        ("model", KerasRegressor(get_clf, loss="mse", epochs=5, verbose=False))
    ]
)
```

## Benchmark

Our benchmark is simply to train each of these pipelines and measure peak memory consumption.

```python
%memit dense_pipeline.fit(X, y)
```

```python
%memit sparse_pipeline.fit(X, y)
```

You should see a memory **increment** (the increase in memory usage during the call) that is at least ~100x larger for the dense pipeline.


### Runtime

Using sparse inputs can have a drastic impact on memory usage, but it often (though not always) hurts overall runtime.

```python
%timeit dense_pipeline.fit(X, y)
```

```python
%timeit sparse_pipeline.fit(X, y)
```

## TensorFlow Datasets

TensorFlow provides a whole suite of functionality around the [Dataset].
Datasets are lazily evaluated, can be sparse, and minimize the transformations required to feed data into the model.
They are _a lot_ more performant and efficient at scale than NumPy data structures, even sparse ones.

SciKeras does not (and cannot) support Datasets directly because Scikit-Learn itself does not support them, and SciKeras' outward-facing API is Scikit-Learn's API.
You may want to explore breaking out of SciKeras and using TensorFlow/Keras directly to see whether Datasets make a large difference for your use case.

[Dataset]: https://www.tensorflow.org/api_docs/python/tf/data/Dataset
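
For illustration only, here is a rough sketch (outside of SciKeras, with made-up shapes and a placeholder model) of one way to stream a SciPy sparse matrix into Keras through `tf.data`, densifying a single batch at a time:

```python
import numpy as np
import tensorflow as tf
from scipy import sparse
from tensorflow import keras

X_sp = sparse.random(10_000, 5_000, density=0.001, format="coo")
y = np.random.uniform(size=(X_sp.shape[0],)).astype("float32")

# SciPy COO -> tf.sparse.SparseTensor (indices must be int64 and in canonical order)
st = tf.sparse.reorder(
    tf.sparse.SparseTensor(
        indices=np.stack([X_sp.row, X_sp.col], axis=1).astype(np.int64),
        values=X_sp.data.astype(np.float32),
        dense_shape=X_sp.shape,
    )
)

dataset = (
    tf.data.Dataset.from_tensor_slices((st, y))
    .batch(256)
    # densify one batch at a time, so the full matrix is never materialized
    .map(lambda features, target: (tf.sparse.to_dense(features), target))
)

model = keras.Sequential(
    [keras.layers.Input(shape=(X_sp.shape[1],)), keras.layers.Dense(1)]
)
model.compile(loss="mse")
model.fit(dataset, epochs=1, verbose=0)
```

Whether the extra plumbing is worth it depends on your data and pipeline; the point is only that Datasets let you keep the data sparse until the last possible moment.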


## Bonus: dtypes

You might be able to save even more memory by changing the output dtype of `OneHotEncoder`.

```python
sparse_pipline_uint8 = Pipeline(
    [
        ("encoder", OneHotEncoder(sparse=True, dtype=np.uint8)),
        ("model", KerasRegressor(get_clf, loss="mse", epochs=5, verbose=False))
    ]
)
```

```python
%memit sparse_pipline_uint8.fit(X, y)
```
3 changes: 2 additions & 1 deletion docs/source/tutorials.rst
@@ -13,4 +13,5 @@ Tutorials
notebooks/Meta_Estimators
notebooks/DataTransformers
notebooks/AutoEncoders
notebooks/Benchmarks
notebooks/sparse
52 changes: 31 additions & 21 deletions scikeras/wrappers.py
@@ -10,6 +10,7 @@
import numpy as np
import tensorflow as tf

from scipy.sparse import isspmatrix, lil_matrix
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.exceptions import NotFittedError
from sklearn.metrics import accuracy_score as sklearn_accuracy_score
@@ -579,11 +580,11 @@ def _validate_data(
Parameters
----------
X : Union[array-like, sparse matrix, dataframe] of shape \
(n_samples, n_features)
The input samples. If None, ``check_array`` is called on y and
``check_X_y`` is called otherwise.
y : Union[array-like, sparse matrix, dataframe] of shape \
y : Union[array-like, dataframe] of shape \
(n_samples,), default=None
The targets. If None, ``check_array`` is called on X and
``check_X_y`` is called otherwise.
@@ -617,6 +618,7 @@ def _check_array_dtype(arr, force_numeric):
X, y = check_X_y(
    X,
    y,
    accept_sparse=True,
    allow_nd=True,  # allow X to have more than 2 dimensions
    multi_output=True,  # allow y to be 2D
    dtype=None,
@@ -648,8 +650,16 @@ def _check_array_dtype(arr, force_numeric):
f" is expecting {self.y_ndim_} dimensions in y."
)
if X is not None:
if isspmatrix(X):
# TensorFlow does not support several of SciPy's sparse formats
# use SciPy to reformat here so at least the cost is known
X = lil_matrix(X) # no-copy reformat

X = check_array(
X, allow_nd=True, dtype=_check_array_dtype(X, force_numeric=True)
X,
allow_nd=True,
dtype=_check_array_dtype(X, force_numeric=True),
accept_sparse=True,
)
X_dtype_ = X.dtype
X_shape_ = X.shape
@@ -719,10 +729,10 @@ def fit(self, X, y, sample_weight=None, **kwargs) -> "BaseWrapper":
Parameters
----------
X : Union[array-like, sparse matrix, dataframe] of shape (n_samples, n_features)
Training samples, where n_samples is the number of samples
and n_features is the number of features.
y : Union[array-like, sparse matrix, dataframe] of shape (n_samples,) or (n_samples, n_outputs)
y : Union[array-like, dataframe] of shape (n_samples,) or (n_samples, n_outputs)
True labels for X.
sample_weight : array-like of shape (n_samples,), default=None
Array of weights that are assigned to individual samples.
@@ -856,10 +866,10 @@ def initialize(self, X, y=None) -> "BaseWrapper":
Parameters
----------
X : Union[array-like, sparse matrix, dataframe] of shape (n_samples, n_features)
Training samples where n_samples is the number of samples
and `n_features` is the number of features.
y : Union[array-like, sparse matrix, dataframe] of shape \
y : Union[array-like, dataframe] of shape \
(n_samples,) or (n_samples, n_outputs), default None
True labels for X.
@@ -885,10 +895,10 @@ def _fit(
Parameters
----------
X : Union[array-like, sparse matrix, dataframe] of shape (n_samples, n_features)
Training samples where `n_samples` is the number of samples
and `n_features` is the number of features.
y :Union[array-like, sparse matrix, dataframe] of shape (n_samples,) or (n_samples, n_outputs)
y : Union[array-like, dataframe] of shape (n_samples,) or (n_samples, n_outputs)
True labels for X.
sample_weight : array-like of shape (n_samples,), default=None
Array of weights that are assigned to individual samples.
@@ -934,10 +944,10 @@ def partial_fit(self, X, y, sample_weight=None, **kwargs) -> "BaseWrapper":
Parameters
----------
X : Union[array-like, sparse matrix, dataframe] of shape (n_samples, n_features)
Training samples where n_samples is the number of samples
and n_features is the number of features.
y : Union[array-like, sparse matrix, dataframe] of shape \
y : Union[array-like, dataframe] of shape \
(n_samples,) or (n_samples, n_outputs)
True labels for X.
sample_weight : array-like of shape (n_samples,), default=None
@@ -1018,7 +1028,7 @@ def predict(self, X, **kwargs):
Parameters
----------
X : Union[array-like, sparse matrix, dataframe] of shape (n_samples, n_features)
Training samples where n_samples is the number of samples
and n_features is the number of features.
**kwargs : Dict[str, Any]
@@ -1076,10 +1086,10 @@ def score(self, X, y, sample_weight=None) -> float:
Parameters
----------
X : Union[array-like, sparse matrix, dataframe] of shape (n_samples, n_features)
Test input samples, where n_samples is the number of samples
and n_features is the number of features.
y : Union[array-like, sparse matrix, dataframe] of shape \
y : Union[array-like, dataframe] of shape \
(n_samples,) or (n_samples, n_outputs)
True labels for X.
sample_weight : array-like of shape (n_samples,), default=None
@@ -1432,10 +1442,10 @@ def initialize(self, X, y) -> "KerasClassifier":
Parameters
----------
X : Union[array-like, sparse matrix, dataframe] of shape (n_samples, n_features)
Training samples where n_samples is the number of samples
and `n_features` is the number of features.
y : Union[array-like, sparse matrix, dataframe] of shape \
y : Union[array-like, dataframe] of shape \
(n_samples,) or (n_samples, n_outputs), default None
True labels for X.
@@ -1453,10 +1463,10 @@ def fit(self, X, y, sample_weight=None, **kwargs) -> "KerasClassifier":
Parameters
----------
X : Union[array-like, sparse matrix, dataframe] of shape (n_samples, n_features)
Training samples, where n_samples is the number of samples
and n_features is the number of features.
y : Union[array-like, sparse matrix, dataframe] of shape (n_samples,) or (n_samples, n_outputs)
y : Union[array-like, dataframe] of shape (n_samples,) or (n_samples, n_outputs)
True labels for X.
sample_weight : array-like of shape (n_samples,), default=None
Array of weights that are assigned to individual samples.
@@ -1492,10 +1502,10 @@ def partial_fit(
Parameters
----------
X : Union[array-like, sparse matrix, dataframe] of shape (n_samples, n_features)
Training samples, where n_samples is the number of samples
and n_features is the number of features.
y : Union[array-like, sparse matrix, dataframe] of shape (n_samples,) or (n_samples, n_outputs)
y : Union[array-like, dataframe] of shape (n_samples,) or (n_samples, n_outputs)
True labels for X.
classes: ndarray of shape (n_classes,), default=None
Classes across all calls to partial_fit. Can be obtained by via
@@ -1530,7 +1540,7 @@ def predict_proba(self, X, **kwargs):
Parameters
----------
X : Union[array-like, sparse matrix, dataframe] of shape (n_samples, n_features)
Training samples, where n_samples is the number of samples
and n_features is the number of features.
**kwargs : Dict[str, Any]
1 change: 1 addition & 0 deletions tests/test_input_outputs.py
@@ -7,6 +7,7 @@
import pytest
import tensorflow as tf

from scipy.sparse import coo_matrix
from sklearn.base import BaseEstimator
from sklearn.metrics import accuracy_score, r2_score
from sklearn.model_selection import train_test_split
