diff --git a/keras/src/trainers/data_adapters/array_data_adapter_test.py b/keras/src/trainers/data_adapters/array_data_adapter_test.py index 46eb4fcc194..80b4462e407 100644 --- a/keras/src/trainers/data_adapters/array_data_adapter_test.py +++ b/keras/src/trainers/data_adapters/array_data_adapter_test.py @@ -52,11 +52,10 @@ def make_array(self, array_type, shape, dtype): "scipy_sparse", ], array_dtype=["float32", "float64"], - iterator_type=["np", "tf", "jax", "torch"], shuffle=[False, "batch", True], ) ) - def test_basic_flow(self, array_type, array_dtype, iterator_type, shuffle): + def test_basic_flow(self, array_type, array_dtype, shuffle): x = self.make_array(array_type, (34, 4), array_dtype) y = self.make_array(array_type, (34, 2), "int32") xdim1 = 1 if array_type == "pandas_series" else 4 @@ -75,10 +74,10 @@ def test_basic_flow(self, array_type, array_dtype, iterator_type, shuffle): self.assertEqual(adapter.has_partial_batch, True) self.assertEqual(adapter.partial_batch_size, 2) - if iterator_type == "np": + if backend.backend() == "numpy": it = adapter.get_numpy_iterator() expected_class = np.ndarray - elif iterator_type == "tf": + elif backend.backend() == "tensorflow": it = adapter.get_tf_dataset() if array_type == "tf_ragged": expected_class = tf.RaggedTensor @@ -88,13 +87,13 @@ def test_basic_flow(self, array_type, array_dtype, iterator_type, shuffle): expected_class = tf.SparseTensor else: expected_class = tf.Tensor - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() if array_type in ("tf_sparse", "jax_sparse", "scipy_sparse"): expected_class = jax_sparse.JAXSparse else: expected_class = jax.Array - elif iterator_type == "torch": + elif backend.backend() == "torch": it = adapter.get_torch_dataloader() expected_class = torch.Tensor diff --git a/keras/src/trainers/data_adapters/generator_data_adapter_test.py b/keras/src/trainers/data_adapters/generator_data_adapter_test.py index 4d6ebdc5597..76839308c7f 100644 --- a/keras/src/trainers/data_adapters/generator_data_adapter_test.py +++ b/keras/src/trainers/data_adapters/generator_data_adapter_test.py @@ -3,12 +3,14 @@ import jax import jax.experimental.sparse as jax_sparse import numpy as np +import pytest import scipy import tensorflow as tf import torch from absl.testing import parameterized from jax import numpy as jnp +from keras.src import backend from keras.src import testing from keras.src.testing.test_utils import named_product from keras.src.trainers.data_adapters import generator_data_adapter @@ -37,10 +39,9 @@ class GeneratorDataAdapterTest(testing.TestCase, parameterized.TestCase): {"testcase_name": "no_weight", "use_sample_weight": False}, ], generator_type=["np", "tf", "jax", "torch"], - iterator_type=["np", "tf", "jax", "torch"], ) ) - def test_basic_flow(self, use_sample_weight, generator_type, iterator_type): + def test_basic_flow(self, use_sample_weight, generator_type): x = np.random.random((34, 4)).astype("float32") y = np.array([[i, i] for i in range(34)], dtype="float32") sw = np.random.random((34,)).astype("float32") @@ -64,16 +65,16 @@ def test_basic_flow(self, use_sample_weight, generator_type, iterator_type): ) adapter = generator_data_adapter.GeneratorDataAdapter(make_generator()) - if iterator_type == "np": + if backend.backend() == "numpy": it = adapter.get_numpy_iterator() expected_class = np.ndarray - elif iterator_type == "tf": + elif backend.backend() == "tensorflow": it = adapter.get_tf_dataset() expected_class = tf.Tensor - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() expected_class = jax.Array - elif iterator_type == "torch": + elif backend.backend() == "torch": it = adapter.get_torch_dataloader() expected_class = torch.Tensor @@ -101,10 +102,7 @@ def test_basic_flow(self, use_sample_weight, generator_type, iterator_type): sample_order.append(by[i, 0]) self.assertAllClose(sample_order, list(range(34))) - @parameterized.named_parameters( - named_product(iterator_type=["np", "tf", "jax", "torch"]) - ) - def test_with_different_shapes(self, iterator_type): + def test_with_different_shapes(self): def generator(): yield np.ones([16, 4], "float32"), np.ones([16, 2], "float32") yield np.ones([16, 5], "float32"), np.ones([16, 2], "float32") @@ -112,13 +110,13 @@ def generator(): adapter = generator_data_adapter.GeneratorDataAdapter(generator()) - if iterator_type == "np": + if backend.backend() == "numpy": it = adapter.get_numpy_iterator() - elif iterator_type == "tf": + elif backend.backend() == "tensorflow": it = adapter.get_tf_dataset() - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() - elif iterator_type == "torch": + elif backend.backend() == "torch": it = adapter.get_torch_dataloader() for i, batch in enumerate(it): @@ -137,11 +135,13 @@ def generator(): self.assertEqual(by.shape, (2, 2)) @parameterized.named_parameters( - named_product( - generator_type=["tf", "jax", "scipy"], iterator_type=["tf", "jax"] - ) + named_product(generator_type=["tf", "jax", "scipy"]) + ) + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors", ) - def test_scipy_sparse_tensors(self, generator_type, iterator_type): + def test_scipy_sparse_tensors(self, generator_type): if generator_type == "tf": x = tf.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], (2, 4)) y = tf.SparseTensor([[0, 0], [1, 1]], [3.0, 4.0], (2, 2)) @@ -158,10 +158,10 @@ def generate(): adapter = generator_data_adapter.GeneratorDataAdapter(generate()) - if iterator_type == "tf": + if backend.backend() == "tensorflow": it = adapter.get_tf_dataset() expected_class = tf.SparseTensor - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() expected_class = jax_sparse.BCOO diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter.py b/keras/src/trainers/data_adapters/py_dataset_adapter.py index daa56a1313f..19b48570577 100644 --- a/keras/src/trainers/data_adapters/py_dataset_adapter.py +++ b/keras/src/trainers/data_adapters/py_dataset_adapter.py @@ -342,11 +342,6 @@ def get_worker_id_queue(): return _WORKER_ID_QUEUE -def init_pool(seqs): - global _SHARED_SEQUENCES - _SHARED_SEQUENCES = seqs - - def get_index(uid, i): """Get the value from the PyDataset `uid` at index `i`. diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter_test.py b/keras/src/trainers/data_adapters/py_dataset_adapter_test.py index 7c41971db56..ac661c2047a 100644 --- a/keras/src/trainers/data_adapters/py_dataset_adapter_test.py +++ b/keras/src/trainers/data_adapters/py_dataset_adapter_test.py @@ -3,10 +3,12 @@ import jax import numpy as np +import pytest import tensorflow as tf import torch from absl.testing import parameterized +from keras.src import backend from keras.src import testing from keras.src.testing.test_utils import named_product from keras.src.trainers.data_adapters import py_dataset_adapter @@ -77,6 +79,21 @@ def __getitem__(self, idx): return batch +class ExceptionPyDataset(py_dataset_adapter.PyDataset): + + @property + def num_batches(self): + return 4 + + def __getitem__(self, index): + if index < 2: + return ( + np.random.random((64, 4)).astype("float32"), + np.random.random((64, 2)).astype("float32"), + ) + raise ValueError("Expected exception") + + class PyDatasetAdapterTest(testing.TestCase, parameterized.TestCase): @parameterized.named_parameters( named_product( @@ -113,7 +130,6 @@ class PyDatasetAdapterTest(testing.TestCase, parameterized.TestCase): }, ], infinite=[True, False], - iterator_type=["np", "tf", "jax", "torch"], shuffle=[True, False], ) ) @@ -122,11 +138,13 @@ def test_basic_flow( shuffle, dataset_type, infinite, - iterator_type, workers=0, use_multiprocessing=False, max_queue_size=0, ): + if use_multiprocessing and (infinite or shuffle): + pytest.skip("Starting processes is slow, only test one variant") + set_random_seed(1337) x = np.random.random((64, 4)).astype("float32") y = np.array([[i, i] for i in range(64)], dtype="float32") @@ -149,16 +167,16 @@ def test_basic_flow( py_dataset, shuffle=shuffle ) - if iterator_type == "np": + if backend.backend() == "numpy": it = adapter.get_numpy_iterator() expected_class = np.ndarray - elif iterator_type == "tf": + elif backend.backend() == "tensorflow": it = adapter.get_tf_dataset() expected_class = tf.Tensor - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() expected_class = jax.Array - elif iterator_type == "torch": + elif backend.backend() == "torch": it = adapter.get_torch_dataloader() expected_class = torch.Tensor @@ -257,10 +275,7 @@ def test_dict_inputs(self): self.assertEqual(tuple(bx.shape), (4, 4)) self.assertEqual(tuple(by.shape), (4, 2)) - @parameterized.named_parameters( - named_product(iterator_type=["np", "tf", "jax", "torch"]) - ) - def test_with_different_shapes(self, iterator_type): + def test_with_different_shapes(self): class TestPyDataset(py_dataset_adapter.PyDataset): @property @@ -285,13 +300,13 @@ def __getitem__(self, idx): TestPyDataset(), shuffle=False ) - if iterator_type == "np": + if backend.backend() == "numpy": it = adapter.get_numpy_iterator() - elif iterator_type == "tf": + elif backend.backend() == "tensorflow": it = adapter.get_tf_dataset() - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() - elif iterator_type == "torch": + elif backend.backend() == "torch": it = adapter.get_torch_dataloader() for i, batch in enumerate(it): @@ -310,67 +325,52 @@ def __getitem__(self, idx): self.assertEqual(by.shape, (2, 2)) @parameterized.named_parameters( - named_product( - [ - { - "testcase_name": "multiprocessing", - "workers": 2, - "use_multiprocessing": True, - "max_queue_size": 10, - }, - { - "testcase_name": "multithreading", - "workers": 2, - "max_queue_size": 10, - }, - { - "testcase_name": "single", - }, - ], - iterator_type=["np", "tf", "jax", "torch"], - ) + [ + { + "testcase_name": "multiprocessing", + "workers": 2, + "use_multiprocessing": True, + "max_queue_size": 10, + }, + { + "testcase_name": "multithreading", + "workers": 2, + "max_queue_size": 10, + }, + { + "testcase_name": "single", + }, + ] ) def test_exception_reported( self, - iterator_type, workers=0, use_multiprocessing=False, max_queue_size=0, ): - class ExceptionPyDataset(py_dataset_adapter.PyDataset): - - @property - def num_batches(self): - return 4 - - def __getitem__(self, index): - if index < 2: - return ( - np.random.random((64, 4)).astype("float32"), - np.random.random((64, 2)).astype("float32"), - ) - raise ValueError("Excepted exception") - - adapter = py_dataset_adapter.PyDatasetAdapter( - ExceptionPyDataset(), shuffle=False + dataset = ExceptionPyDataset( + workers=workers, + use_multiprocessing=use_multiprocessing, + max_queue_size=max_queue_size, ) + adapter = py_dataset_adapter.PyDatasetAdapter(dataset, shuffle=False) expected_exception_class = ValueError - if iterator_type == "np": + if backend.backend() == "numpy": it = adapter.get_numpy_iterator() - elif iterator_type == "tf": + elif backend.backend() == "tensorflow": it = adapter.get_tf_dataset() # tf.data wraps the exception expected_exception_class = tf.errors.InvalidArgumentError - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() - elif iterator_type == "torch": + elif backend.backend() == "torch": it = adapter.get_torch_dataloader() it = iter(it) next(it) next(it) with self.assertRaisesRegex( - expected_exception_class, "Excepted exception" + expected_exception_class, "Expected exception" ): next(it) diff --git a/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py b/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py index ad48c2d3c24..2535e505d61 100644 --- a/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py +++ b/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py @@ -2,20 +2,18 @@ import jax import numpy as np +import pytest import tensorflow as tf import torch from absl.testing import parameterized +from keras.src import backend from keras.src import testing -from keras.src.testing.test_utils import named_product from keras.src.trainers.data_adapters import tf_dataset_adapter class TestTFDatasetAdapter(testing.TestCase, parameterized.TestCase): - @parameterized.named_parameters( - named_product(iterator_type=["np", "tf", "jax", "torch"]) - ) - def test_basic_flow(self, iterator_type): + def test_basic_flow(self): x = tf.random.normal((34, 4)) y = tf.random.normal((34, 2)) base_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(16) @@ -26,16 +24,16 @@ def test_basic_flow(self, iterator_type): self.assertEqual(adapter.has_partial_batch, None) self.assertEqual(adapter.partial_batch_size, None) - if iterator_type == "np": + if backend.backend() == "numpy": it = adapter.get_numpy_iterator() expected_class = np.ndarray - elif iterator_type == "tf": + elif backend.backend() == "tensorflow": it = adapter.get_tf_dataset() expected_class = tf.Tensor - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() expected_class = jax.Array - elif iterator_type == "torch": + elif backend.backend() == "torch": it = adapter.get_torch_dataloader() expected_class = torch.Tensor @@ -258,10 +256,11 @@ def test_distribute_dataset(self): self.assertEqual(tuple(bx.shape), (2, 4)) self.assertEqual(tuple(by.shape), (2, 2)) - @parameterized.named_parameters( - named_product(iterator_type=["np", "tf", "jax"]) + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS and backend.backend() != "numpy", + reason="Backend does not support sparse tensors", ) - def test_tf_sparse_tensors(self, iterator_type): + def test_tf_sparse_tensors(self): x = tf.SparseTensor( indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 4) ) @@ -271,13 +270,13 @@ def test_tf_sparse_tensors(self, iterator_type): base_ds = tf.data.Dataset.from_tensors((x, y)) adapter = tf_dataset_adapter.TFDatasetAdapter(base_ds) - if iterator_type == "np": + if backend.backend() == "numpy": it = adapter.get_numpy_iterator() expected_class = np.ndarray - elif iterator_type == "tf": + elif backend.backend() == "tensorflow": it = adapter.get_tf_dataset() expected_class = tf.SparseTensor - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() expected_class = jax.experimental.sparse.BCOO diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py index e86f570d692..4d02f5592f6 100644 --- a/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py +++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py @@ -6,6 +6,7 @@ import torch from absl.testing import parameterized +from keras.src import backend from keras.src import testing from keras.src.testing.test_utils import named_product from keras.src.trainers.data_adapters.torch_data_loader_adapter import ( @@ -14,10 +15,7 @@ class TestTorchDataLoaderAdapter(testing.TestCase, parameterized.TestCase): - @parameterized.named_parameters( - named_product(iterator_type=["np", "tf", "jax", "torch"]) - ) - def test_basic_dataloader(self, iterator_type): + def test_basic_dataloader(self): x = torch.normal(2, 3, size=(34, 4)) y = torch.normal(1, 3, size=(34, 2)) ds = torch.utils.data.TensorDataset(x, y) @@ -29,16 +27,16 @@ def test_basic_dataloader(self, iterator_type): self.assertEqual(adapter.has_partial_batch, True) self.assertEqual(adapter.partial_batch_size, 2) - if iterator_type == "np": + if backend.backend() == "numpy": it = adapter.get_numpy_iterator() expected_class = np.ndarray - elif iterator_type == "tf": + elif backend.backend() == "tensorflow": it = adapter.get_tf_dataset() expected_class = tf.Tensor - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() expected_class = jax.Array - elif iterator_type == "torch": + elif backend.backend() == "torch": it = adapter.get_torch_dataloader() expected_class = torch.Tensor @@ -57,15 +55,9 @@ def test_basic_dataloader(self, iterator_type): self.assertEqual(by.shape, (2, 2)) @parameterized.named_parameters( - named_product( - batch_size=[None, 3], - implements_len=[True, False], - iterator_type=["np", "tf", "jax", "torch"], - ) + named_product(batch_size=[None, 3], implements_len=[True, False]) ) - def test_dataloader_iterable_dataset( - self, batch_size, implements_len, iterator_type - ): + def test_dataloader_iterable_dataset(self, batch_size, implements_len): class TestIterableDataset(torch.utils.data.IterableDataset): def __init__(self): @@ -104,16 +96,16 @@ def __len__(self): self.assertIsNone(adapter.has_partial_batch) self.assertIsNone(adapter.partial_batch_size) - if iterator_type == "np": + if backend.backend() == "numpy": it = adapter.get_numpy_iterator() expected_class = np.ndarray - elif iterator_type == "tf": + elif backend.backend() == "tensorflow": it = adapter.get_tf_dataset() expected_class = tf.Tensor - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() expected_class = jax.Array - elif iterator_type == "torch": + elif backend.backend() == "torch": it = adapter.get_torch_dataloader() expected_class = torch.Tensor @@ -142,10 +134,7 @@ def __len__(self): else: self.assertEqual(batch_count, 10) - @parameterized.named_parameters( - named_product(iterator_type=["np", "tf", "jax", "torch"]) - ) - def test_with_different_shapes(self, iterator_type): + def test_with_different_shapes(self): x = ( [np.ones([4], "float32")] * 16 + [np.ones([5], "float32")] * 16 @@ -161,13 +150,13 @@ def test_with_different_shapes(self, iterator_type): self.assertEqual(adapter.has_partial_batch, True) self.assertEqual(adapter.partial_batch_size, 2) - if iterator_type == "np": + if backend.backend() == "numpy": it = adapter.get_numpy_iterator() - elif iterator_type == "tf": + elif backend.backend() == "tensorflow": it = adapter.get_tf_dataset() - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() - elif iterator_type == "torch": + elif backend.backend() == "torch": it = adapter.get_torch_dataloader() for i, batch in enumerate(it):