Added SparseTensor check on convert_to_np_if_not_ragged. #20151

Merged on Aug 26, 2024 (2 commits)
2 changes: 2 additions & 0 deletions keras/src/backend/tensorflow/trainer.py
@@ -905,6 +905,8 @@ def is_tpu_strat(k):
 def convert_to_np_if_not_ragged(x):
     if isinstance(x, tf.RaggedTensor):
         return x
+    elif isinstance(x, tf.SparseTensor):
+        return x

Member

Would it make sense to instead return a scipy sparse matrix?

Contributor Author

In my case, this was a test to see if my layer would behave properly inside a model. So I am not using the output directly. The output will be sent to another layer.

I see no need for scipy in my use case, but it all depends on what people expect from the Model.predict() function.

What would have been really useful is an error message saying that, when using predict(), the output of the model needs to be a RaggedTensor, a SparseTensor, or something with a "numpy" attribute. I spent quite some time trying to find out what I was doing wrong. However, I have no clue where to place that check.
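
For illustration only, the kind of check described above might look like the sketch below. The helper name `to_numpy_or_explain` and its `passthrough_types` parameter are hypothetical and are not part of Keras or of this PR. In the TensorFlow backend one would presumably pass `(tf.RaggedTensor, tf.SparseTensor)`; the sketch takes the types as an argument so it runs without TensorFlow installed.

```python
def to_numpy_or_explain(x, passthrough_types=()):
    """Hypothetical helper: convert x via .numpy(), or fail with a readable error.

    passthrough_types is a tuple of classes that are returned unchanged
    (assumed to be tf.RaggedTensor and tf.SparseTensor in practice).
    """
    if isinstance(x, passthrough_types):
        return x
    if not hasattr(x, "numpy"):
        names = ", ".join(t.__name__ for t in passthrough_types) or "none"
        raise TypeError(
            f"predict() output of type {type(x).__name__} cannot be converted. "
            f"Expected one of the pass-through types ({names}) "
            "or an object with a .numpy() method."
        )
    return x.numpy()
```

With a check like this, an unsupported output type would fail with a message naming the accepted types instead of a bare attribute error on `numpy`.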

Member

I think it makes sense to return a TF SparseTensor in this case. The underlying issue is basically a TF limitation -- .numpy() is a standard API and it's not normal that it isn't available on a SparseTensor.

Can you add a unit test for this change?
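
A minimal sketch of what such a unit test could look like, assuming plain `unittest` rather than the Keras test utilities. It exercises a local copy of the helper as it reads after this change instead of importing it from `keras/src/backend/tensorflow/trainer.py`. The test class name and `run_tests` are made up for illustration, and the tests are skipped when TensorFlow is not installed.

```python
import unittest

try:
    import tensorflow as tf
except ImportError:  # allows importing this file where TensorFlow is missing
    tf = None


def convert_to_np_if_not_ragged(x):
    # Local copy of the helper as it reads after this change.
    if isinstance(x, tf.RaggedTensor):
        return x
    elif isinstance(x, tf.SparseTensor):
        return x
    return x.numpy()


@unittest.skipIf(tf is None, "TensorFlow is required for this test")
class ConvertToNpIfNotRaggedTest(unittest.TestCase):
    def test_sparse_tensor_is_returned_unchanged(self):
        x = tf.SparseTensor(
            indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=[2, 3]
        )
        self.assertIs(convert_to_np_if_not_ragged(x), x)

    def test_ragged_tensor_is_returned_unchanged(self):
        x = tf.ragged.constant([[1], [2, 3]])
        self.assertIs(convert_to_np_if_not_ragged(x), x)

    def test_dense_tensor_becomes_numpy(self):
        out = convert_to_np_if_not_ragged(tf.constant([1.0, 2.0]))
        self.assertEqual(out.tolist(), [1.0, 2.0])


def run_tests():
    # Run the suite programmatically and return the unittest result object.
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(
        ConvertToNpIfNotRaggedTest
    )
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

An end-to-end test that calls `predict()` on a model with a sparse output would cover the reported scenario more directly; the sketch above only covers the helper itself.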

     return x.numpy()

