Use of sparse matrices #239

Closed
carlogeertse opened this issue Aug 6, 2021 · 11 comments · Fixed by #240

Comments

@carlogeertse

carlogeertse commented Aug 6, 2021

Hi, it's nice to see someone continuing development on a wrapper like this after the TensorFlow team decided to discontinue theirs.

I have run into an issue with the use of sparse matrices.

The API documentation mentions that the fit and predict functions of the KerasClassifier wrapper should work with array-like, sparse matrix, and dataframe inputs. However, when I pass a sparse matrix, I get the following exception:

TypeError: A sparse matrix was passed, but dense data is required. Use X.toarray() to convert to a dense numpy array.

I used the quickstart guide to create a simple reproducible example. I simply converted the ndarrays in the example into a scipy.sparse coo_matrix:

import numpy as np
from sklearn.datasets import make_classification
from tensorflow import keras
from scipy.sparse import coo_matrix

from scikeras.wrappers import KerasClassifier


X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)

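# convert the dense arrays to scipy sparse COO format; this is what triggers the error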
X = coo_matrix(X)
y = coo_matrix(y)

def get_model(hidden_layer_dim, meta):
    # note that meta is a special argument that will be
    # handed a dict containing input metadata
    n_features_in_ = meta["n_features_in_"]
    X_shape_ = meta["X_shape_"]
    n_classes_ = meta["n_classes_"]

    model = keras.models.Sequential()
    model.add(keras.layers.Dense(n_features_in_, input_shape=X_shape_[1:]))
    model.add(keras.layers.Activation("relu"))
    model.add(keras.layers.Dense(hidden_layer_dim))
    model.add(keras.layers.Activation("relu"))
    model.add(keras.layers.Dense(n_classes_))
    model.add(keras.layers.Activation("softmax"))
    return model

clf = KerasClassifier(
    get_model,
    loss="sparse_categorical_crossentropy",
    hidden_layer_dim=100,
)

clf.fit(X, y)
y_proba = clf.predict_proba(X)

A potential cause: when the inputs are validated via sklearn.utils.check_X_y, the accept_sparse parameter is left at its default of False.

Setting this parameter to True might solve the issue (I will test that soon).
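For illustration, here is the sklearn utility in isolation (a minimal sketch of the flag, not SciKeras's actual internals):

from scipy.sparse import coo_matrix
from sklearn.datasets import make_classification
from sklearn.utils import check_X_y

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = coo_matrix(X)

# the default accept_sparse=False raises the TypeError above;
# accept_sparse=True (or a list of allowed formats, e.g. ["csr", "csc"])
# lets sparse input through validation
X_checked, y_checked = check_X_y(X, y, accept_sparse=True)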
I am running this on python=3.7.10, scikit-learn=0.24.2 and tensorflow=2.5.0

@adriangb
Owner

Hi 👋, sorry for the delay, busy week.

I really appreciate the clear example, it made things easy to reproduce.

So firstly, I'm sorry that the docstring says we accept sparse matrices but, as you point out, we obviously don't because of the way the parameters to sklearn.utils.check_X_y are set. That must have been confusing.

You are right that we could probably just add the parameter and some things would start working (and in fact I opened #240 to do so).

Sparse values for y seem to be less well supported (at least my trivial attempt failed), and are arguably less useful in the first place, so I don't think we can support that. I don't even think Sklearn generally supports it.

This said, at a high level, I'm not sure that working with sparse matrices in this way is a great idea: it's highly likely that somewhere along the pipeline the entire dataset will have to be copied in memory or, worse, cast to a dense format. This could be SciKeras converting from DOK -> LIL (TF does not support DOK), or it could be TF converting LIL -> dense tensor (which depends on the ops/layers used). I fear that this could lead to unexpected and tricky-to-debug results. Maybe you can give #240 a whirl and let me know how performance/memory usage goes with real world data?

Also, TensorFlow supports much more advanced data ingestion by means of tf.data.Datasets. Is that something you've looked into @carlogeertse? Or would you mind elaborating on the actual use case / shape of the data a bit more and maybe we can brainstorm another way to go about this?
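For reference, a rough sketch of what such a tf.data pipeline could look like (hypothetical stand-in data; batch size and shapes are placeholders), densifying one small batch at a time so the full matrix is never materialized:

import numpy as np
import tensorflow as tf
from scipy.sparse import random as sparse_random

X = sparse_random(10_000, 200, density=0.01, format="csr", dtype=np.float32)
y = np.random.randint(0, 2, size=X.shape[0]).astype(np.int64)

def gen():
    batch_size = 256
    for start in range(0, X.shape[0], batch_size):
        stop = start + batch_size
        # densify one small batch at a time instead of the whole matrix
        yield X[start:stop].toarray(), y[start:stop]

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(None, X.shape[1]), dtype=tf.float32),
        tf.TensorSpec(shape=(None,), dtype=tf.int64),
    ),
)
# a Keras model could then consume the batches directly via model.fit(ds)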

Thanks!

@carlogeertse
Author

carlogeertse commented Aug 13, 2021

Hi,

I appreciate the fairly quick response. I would love to check and see if #240 solves the issue. I won't be able to do this with the data I tried it with before, but I will try to find or create a similar dataset in terms of size. I will be travelling for the coming three weeks, however, so it will be a while until I have time.

As for why I want to use sparse matrices: the first reason is that I have very sparse data, which becomes very large in memory when stored in a dense format.
Secondly, I am comparing different base classifiers for sklearn.semi_supervised.SelfTrainingClassifier. On top of that, I am calibrating the classifiers using sklearn.calibration.CalibratedClassifierCV. Lastly, all of this is done within an sklearn pipeline. All these methods support sparse matrices, but I don't think they support a TensorFlow dataset. The pipeline also requires the base classifier to adhere to the sklearn API. So far I have successfully used it with an sklearn classifier (sklearn.ensemble.RandomForestClassifier) and an external implementation of the API (xgboost.XGBClassifier).
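Roughly, the stack looks like this (a sketch; RandomForestClassifier stands in for the base classifier, and TfidfVectorizer is a hypothetical sparse-producing feature step, not my actual one):

from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
from sklearn.semi_supervised import SelfTrainingClassifier

# the idea is to swap KerasClassifier in as the base classifier here
base = CalibratedClassifierCV(RandomForestClassifier())
pipe = Pipeline([
    ("features", TfidfVectorizer()),  # produces a scipy CSR matrix
    ("clf", SelfTrainingClassifier(base)),
])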

I wanted to use a relatively simple fully connected network for the comparison as well. I will admit that I don't know much about sparse inputs to these networks, so I am not sure if the process is memory efficient the whole way through. In the end I ran out of time to include it in my thesis anyway, though I am interested in exploring this use case a bit further.

Lastly, I agree that support for a sparse y isn't needed. I have been using a dense array for that in practice too.

@mattalhonte-srm

Any chance this could get merged in? I made a quick & dirty fork (just passing accept_sparse=True) but I'd like to be able to keep up with the main branch!

@adriangb
Owner

@mattalhonte-srm does #240 work for you? As in not just run, but actually give you performance (runtime or memory) benefits?

@carlogeertse
Author

Hi again, apologies for not getting back to this sooner.

I just performed some tests with dummy data: randomly generated sparse data of 1 million x 200, using a GPU.
I had some trouble measuring memory usage, so I stuck to Windows Task Manager for some rough estimates.

For a density of 5% (so only 5% of all values in the data are non-zero):
non-sparse memory usage: 3.6 GB
sparse memory usage: 3.6 GB

Density of 1%:
non-sparse: 3.6 GB
sparse: 3.0 GB

Density of 0.01%:
non-sparse: 3.6 GB
sparse: 2.6 GB

So as you can see, using sparse matrices does help reduce overall memory usage, especially as datasets get sparser.
In terms of runtime, using sparse matrices is about twice as slow, which is in line with other machine learning models within sklearn that accept sparse matrices.
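Roughly how I set this up (a sketch; get_model is the model-building function from my original example above, and the sparse fit requires the #240 branch):

import numpy as np
from scipy.sparse import random as sparse_random
from scikeras.wrappers import KerasClassifier

# 1 million x 200 random sparse data; density is the knob I varied
X = sparse_random(1_000_000, 200, density=0.01, format="csr", dtype=np.float32)
y = np.random.randint(0, 2, size=X.shape[0]).astype(np.int64)

clf = KerasClassifier(
    get_model,  # as defined in my original example
    loss="sparse_categorical_crossentropy",
    hidden_layer_dim=100,
)
clf.fit(X, y)               # sparse path (needs the #240 branch)
# clf.fit(X.toarray(), y)   # dense path, for comparison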

I do, however, still feel that using tf.data.Datasets would probably be better, though sparse matrices might help in some niche situations. I'm also not sure why I measure better memory performance, when it seems quite likely that somewhere along the process the sparse matrix is indeed converted into a dense one.

@mattalhonte-srm

Yeah, I get better memory performance, and since "not blowing up containers" is a concern for my use case, this is kind of a dealbreaker feature for me.

@adriangb
Owner

Would you mind posting an example of how you tested this?

I just want to make sure this actually works for real-world use cases, so that we don't add a feature that has no upside and only serves to confuse users.

@mattalhonte-srm

The main way I know is that calling .todense() made my container crash, while passing the sparse matrix didn't.

@adriangb
Owner

Awesome, that's about as real-world useful as it gets. I think I'll move forward with #240 tomorrow.

@adriangb
Owner

@carlogeertse and/or @mattalhonte-srm can you folks take a look at #240 (PR review is appreciated), especially the new tutorial (https://www.adriangb.com/scikeras/refs/pull/240/merge/notebooks/sparse.html) and see if it all looks good to you? If it does, we can merge and cut a release as soon as this weekend.

@mattalhonte-srm

Heya! Just tested this and it doesn't work for me: converting to LIL makes the container blow up when I try to train. I need the CSR matrix to stay in CSR format (TF can do much more efficient math on those!)
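For reference, here's roughly the kind of conversion I mean, going CSR -> COO -> tf.sparse.SparseTensor (TF's native COO-style sparse format) with no LIL detour (a sketch with placeholder shapes):

import numpy as np
import tensorflow as tf
from scipy.sparse import random as sparse_random

X = sparse_random(100_000, 200, density=0.01, format="csr", dtype=np.float32)

# CSR -> COO keeps compact flat-array storage; LIL stores per-row Python lists,
# which is far heavier for large matrices
coo = X.tocoo()
st = tf.sparse.SparseTensor(
    indices=np.stack([coo.row, coo.col], axis=1).astype(np.int64),
    values=coo.data,
    dense_shape=coo.shape,
)
st = tf.sparse.reorder(st)  # SparseTensor indices must be in canonical row-major order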
