Skip to content

Commit

Permalink
ArrayDataAdapter no longer converts to NumPy and supports sparse tens…
Browse files Browse the repository at this point in the history
…ors. (#19298)

Instead, the passed arrays can be sliced or indexed in their native format.
- This addresses #18408 and improves performance, especially with TensorFlow and Torch. It improves TF -> TF and Torch -> Torch, but also TF -> Torch and Torch -> TF.
- This allows the support of sparse tensors (`tf.SparseTensor`, `jax.experimental.sparse.BCOO` and `scipy.sparse`). These sparse tensors are sliced as sparse and the iterators yield sparse tensors in the requested format (either TF or JAX).
- The `validation_split` argument of `Model.fit()` can now be used with anything supported by `ArrayDataAdapter`, in particular, sparse tensors are now supported.

In summary, `ArrayDataAdapter` now supports:
- native Python arrays
- NumPy arrays
- TensorFlow tensors, ragged tensors, sparse tensors (new)
- JAX arrays and BCOO sparse tensors (new)
- pandas DataFrames
- pandas Series
- scipy sparse matrices (new)

Also:
- Fixed bug where batch-level shuffling would shuffle the different arrays (in particular inputs and labels) inconsistently when using a TF dataset or a NumPy iterator.
- Fixed bug where `tf.RaggedTensor`s would only work when using a TF dataset.
- Fixed bug where `tf.RaggedTensor`s would not work when doing batch level shuffling.
- Added a workaround for a bug where `tf.cast`ing a `tf.SparseTensor` would lose the static shape.
- Added test coverage for `tf.RaggedTensor`s and `pandas.Series`.
- Added verification in tests that inputs and labels are shuffled consistently.
  • Loading branch information
hertschuh authored Mar 14, 2024
1 parent 2cb0521 commit 818c9fa
Show file tree
Hide file tree
Showing 12 changed files with 800 additions and 396 deletions.
8 changes: 5 additions & 3 deletions keras/backend/jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,23 @@ def __jax_array__(self):
def convert_to_tensor(x, dtype=None, sparse=True):
if dtype is not None:
dtype = standardize_dtype(dtype)
if isinstance(x, (jnp.ndarray, jax.Array)) and dtype == x.dtype:
if isinstance(x, (jnp.ndarray, jax.Array)) and (
dtype is None or x.dtype == dtype
):
# Skip the conversion early if the instance is already a JAX array.
# This is important in the multi-process context since jax.array(x) for
# an existing distributed jax array will raise error.
return x

if isinstance(x, Variable):
if dtype and dtype != x.dtype:
if dtype is not None and x.dtype != dtype:
return x.value.astype(dtype)
return x.value

if isinstance(x, jax_sparse.JAXSparse):
if sparse is not None and not sparse:
x = x.todense()
elif dtype and dtype != x.dtype:
elif dtype is not None and x.dtype != dtype:
return x.astype(dtype)
else:
return x
Expand Down
3 changes: 2 additions & 1 deletion keras/backend/jax/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from keras.backend import distribution_lib as jax_distribution_lib
from keras.distribution import distribution_lib
from keras.trainers import trainer as base_trainer
from keras.trainers.data_adapters import array_slicing
from keras.trainers.data_adapters import data_adapter_utils
from keras.trainers.epoch_iterator import EpochIterator
from keras.utils import traceback_utils
Expand Down Expand Up @@ -330,7 +331,7 @@ def fit(
x,
y,
sample_weight,
), validation_data = data_adapter_utils.train_validation_split(
), validation_data = array_slicing.train_validation_split(
(x, y, sample_weight), validation_split=validation_split
)

Expand Down
10 changes: 9 additions & 1 deletion keras/backend/tensorflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ def convert_to_numpy(x):
x.set_shape(x_shape)
elif isinstance(x, tf.IndexedSlices):
x = tf.convert_to_tensor(x)
elif isinstance(x, tf.RaggedTensor):
x = x.to_tensor()
return np.asarray(x)


Expand Down Expand Up @@ -161,7 +163,13 @@ def shape(x):

def cast(x, dtype):
    """Cast `x` to `dtype`, preserving static shape for sparse tensors.

    Args:
        x: A tensor-like object (dense `tf.Tensor` or `tf.SparseTensor`).
        dtype: The target dtype; normalized via `standardize_dtype` before
            being passed to `tf.cast`.

    Returns:
        `x` cast to the standardized `dtype`.
    """
    dtype = standardize_dtype(dtype)
    if isinstance(x, tf.SparseTensor):
        # Workaround for a TF bug: `tf.cast` on a `tf.SparseTensor` loses
        # the static shape, so capture it first and restore it afterwards.
        x_shape = x.shape
        x = tf.cast(x, dtype)
        x.set_shape(x_shape)
        return x
    else:
        return tf.cast(x, dtype=dtype)


def compute_output_spec(fn, *args, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion keras/backend/tensorflow/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from keras import metrics as metrics_module
from keras import optimizers as optimizers_module
from keras.trainers import trainer as base_trainer
from keras.trainers.data_adapters import array_slicing
from keras.trainers.data_adapters import data_adapter_utils
from keras.trainers.epoch_iterator import EpochIterator
from keras.utils import traceback_utils
Expand Down Expand Up @@ -273,7 +274,7 @@ def fit(
x,
y,
sample_weight,
), validation_data = data_adapter_utils.train_validation_split(
), validation_data = array_slicing.train_validation_split(
(x, y, sample_weight), validation_split=validation_split
)

Expand Down
3 changes: 2 additions & 1 deletion keras/backend/torch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from keras import callbacks as callbacks_module
from keras import optimizers as optimizers_module
from keras.trainers import trainer as base_trainer
from keras.trainers.data_adapters import array_slicing
from keras.trainers.data_adapters import data_adapter_utils
from keras.trainers.epoch_iterator import EpochIterator
from keras.utils import traceback_utils
Expand Down Expand Up @@ -193,7 +194,7 @@ def fit(
x,
y,
sample_weight,
), validation_data = data_adapter_utils.train_validation_split(
), validation_data = array_slicing.train_validation_split(
(x, y, sample_weight), validation_split=validation_split
)

Expand Down
Loading

0 comments on commit 818c9fa

Please sign in to comment.