Skip to content

Commit

Permalink
Speed up DataAdapter tests by testing only the current backend. (#1…
Browse files Browse the repository at this point in the history
…9625)

There is no use case for using an iterator for a different backend than the current backend.

Also:
- limit the number of tests using multiprocessing, the threading tests give us good coverage.
- fixed the `test_exception_reported` test, which was not actually exercising the multiprocessing / multithreading cases.
- removed unused `init_pool` method.
  • Loading branch information
hertschuh authored Apr 27, 2024
1 parent f6c4ac5 commit fe03ca5
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 129 deletions.
11 changes: 5 additions & 6 deletions keras/src/trainers/data_adapters/array_data_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,10 @@ def make_array(self, array_type, shape, dtype):
"scipy_sparse",
],
array_dtype=["float32", "float64"],
iterator_type=["np", "tf", "jax", "torch"],
shuffle=[False, "batch", True],
)
)
def test_basic_flow(self, array_type, array_dtype, iterator_type, shuffle):
def test_basic_flow(self, array_type, array_dtype, shuffle):
x = self.make_array(array_type, (34, 4), array_dtype)
y = self.make_array(array_type, (34, 2), "int32")
xdim1 = 1 if array_type == "pandas_series" else 4
Expand All @@ -75,10 +74,10 @@ def test_basic_flow(self, array_type, array_dtype, iterator_type, shuffle):
self.assertEqual(adapter.has_partial_batch, True)
self.assertEqual(adapter.partial_batch_size, 2)

if iterator_type == "np":
if backend.backend() == "numpy":
it = adapter.get_numpy_iterator()
expected_class = np.ndarray
elif iterator_type == "tf":
elif backend.backend() == "tensorflow":
it = adapter.get_tf_dataset()
if array_type == "tf_ragged":
expected_class = tf.RaggedTensor
Expand All @@ -88,13 +87,13 @@ def test_basic_flow(self, array_type, array_dtype, iterator_type, shuffle):
expected_class = tf.SparseTensor
else:
expected_class = tf.Tensor
elif iterator_type == "jax":
elif backend.backend() == "jax":
it = adapter.get_jax_iterator()
if array_type in ("tf_sparse", "jax_sparse", "scipy_sparse"):
expected_class = jax_sparse.JAXSparse
else:
expected_class = jax.Array
elif iterator_type == "torch":
elif backend.backend() == "torch":
it = adapter.get_torch_dataloader()
expected_class = torch.Tensor

Expand Down
40 changes: 20 additions & 20 deletions keras/src/trainers/data_adapters/generator_data_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import jax
import jax.experimental.sparse as jax_sparse
import numpy as np
import pytest
import scipy
import tensorflow as tf
import torch
from absl.testing import parameterized
from jax import numpy as jnp

from keras.src import backend
from keras.src import testing
from keras.src.testing.test_utils import named_product
from keras.src.trainers.data_adapters import generator_data_adapter
Expand Down Expand Up @@ -37,10 +39,9 @@ class GeneratorDataAdapterTest(testing.TestCase, parameterized.TestCase):
{"testcase_name": "no_weight", "use_sample_weight": False},
],
generator_type=["np", "tf", "jax", "torch"],
iterator_type=["np", "tf", "jax", "torch"],
)
)
def test_basic_flow(self, use_sample_weight, generator_type, iterator_type):
def test_basic_flow(self, use_sample_weight, generator_type):
x = np.random.random((34, 4)).astype("float32")
y = np.array([[i, i] for i in range(34)], dtype="float32")
sw = np.random.random((34,)).astype("float32")
Expand All @@ -64,16 +65,16 @@ def test_basic_flow(self, use_sample_weight, generator_type, iterator_type):
)

adapter = generator_data_adapter.GeneratorDataAdapter(make_generator())
if iterator_type == "np":
if backend.backend() == "numpy":
it = adapter.get_numpy_iterator()
expected_class = np.ndarray
elif iterator_type == "tf":
elif backend.backend() == "tensorflow":
it = adapter.get_tf_dataset()
expected_class = tf.Tensor
elif iterator_type == "jax":
elif backend.backend() == "jax":
it = adapter.get_jax_iterator()
expected_class = jax.Array
elif iterator_type == "torch":
elif backend.backend() == "torch":
it = adapter.get_torch_dataloader()
expected_class = torch.Tensor

Expand Down Expand Up @@ -101,24 +102,21 @@ def test_basic_flow(self, use_sample_weight, generator_type, iterator_type):
sample_order.append(by[i, 0])
self.assertAllClose(sample_order, list(range(34)))

@parameterized.named_parameters(
named_product(iterator_type=["np", "tf", "jax", "torch"])
)
def test_with_different_shapes(self, iterator_type):
def test_with_different_shapes(self):
def generator():
yield np.ones([16, 4], "float32"), np.ones([16, 2], "float32")
yield np.ones([16, 5], "float32"), np.ones([16, 2], "float32")
yield np.ones([2, 6], "float32"), np.ones([2, 2], "float32")

adapter = generator_data_adapter.GeneratorDataAdapter(generator())

if iterator_type == "np":
if backend.backend() == "numpy":
it = adapter.get_numpy_iterator()
elif iterator_type == "tf":
elif backend.backend() == "tensorflow":
it = adapter.get_tf_dataset()
elif iterator_type == "jax":
elif backend.backend() == "jax":
it = adapter.get_jax_iterator()
elif iterator_type == "torch":
elif backend.backend() == "torch":
it = adapter.get_torch_dataloader()

for i, batch in enumerate(it):
Expand All @@ -137,11 +135,13 @@ def generator():
self.assertEqual(by.shape, (2, 2))

@parameterized.named_parameters(
named_product(
generator_type=["tf", "jax", "scipy"], iterator_type=["tf", "jax"]
)
named_product(generator_type=["tf", "jax", "scipy"])
)
@pytest.mark.skipif(
not backend.SUPPORTS_SPARSE_TENSORS,
reason="Backend does not support sparse tensors",
)
def test_scipy_sparse_tensors(self, generator_type, iterator_type):
def test_scipy_sparse_tensors(self, generator_type):
if generator_type == "tf":
x = tf.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], (2, 4))
y = tf.SparseTensor([[0, 0], [1, 1]], [3.0, 4.0], (2, 2))
Expand All @@ -158,10 +158,10 @@ def generate():

adapter = generator_data_adapter.GeneratorDataAdapter(generate())

if iterator_type == "tf":
if backend.backend() == "tensorflow":
it = adapter.get_tf_dataset()
expected_class = tf.SparseTensor
elif iterator_type == "jax":
elif backend.backend() == "jax":
it = adapter.get_jax_iterator()
expected_class = jax_sparse.BCOO

Expand Down
5 changes: 0 additions & 5 deletions keras/src/trainers/data_adapters/py_dataset_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,6 @@ def get_worker_id_queue():
return _WORKER_ID_QUEUE


def init_pool(seqs):
global _SHARED_SEQUENCES
_SHARED_SEQUENCES = seqs


def get_index(uid, i):
"""Get the value from the PyDataset `uid` at index `i`.
Expand Down
110 changes: 55 additions & 55 deletions keras/src/trainers/data_adapters/py_dataset_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

import jax
import numpy as np
import pytest
import tensorflow as tf
import torch
from absl.testing import parameterized

from keras.src import backend
from keras.src import testing
from keras.src.testing.test_utils import named_product
from keras.src.trainers.data_adapters import py_dataset_adapter
Expand Down Expand Up @@ -77,6 +79,21 @@ def __getitem__(self, idx):
return batch


class ExceptionPyDataset(py_dataset_adapter.PyDataset):

@property
def num_batches(self):
return 4

def __getitem__(self, index):
if index < 2:
return (
np.random.random((64, 4)).astype("float32"),
np.random.random((64, 2)).astype("float32"),
)
raise ValueError("Expected exception")


class PyDatasetAdapterTest(testing.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
named_product(
Expand Down Expand Up @@ -113,7 +130,6 @@ class PyDatasetAdapterTest(testing.TestCase, parameterized.TestCase):
},
],
infinite=[True, False],
iterator_type=["np", "tf", "jax", "torch"],
shuffle=[True, False],
)
)
Expand All @@ -122,11 +138,13 @@ def test_basic_flow(
shuffle,
dataset_type,
infinite,
iterator_type,
workers=0,
use_multiprocessing=False,
max_queue_size=0,
):
if use_multiprocessing and (infinite or shuffle):
pytest.skip("Starting processes is slow, only test one variant")

set_random_seed(1337)
x = np.random.random((64, 4)).astype("float32")
y = np.array([[i, i] for i in range(64)], dtype="float32")
Expand All @@ -149,16 +167,16 @@ def test_basic_flow(
py_dataset, shuffle=shuffle
)

if iterator_type == "np":
if backend.backend() == "numpy":
it = adapter.get_numpy_iterator()
expected_class = np.ndarray
elif iterator_type == "tf":
elif backend.backend() == "tensorflow":
it = adapter.get_tf_dataset()
expected_class = tf.Tensor
elif iterator_type == "jax":
elif backend.backend() == "jax":
it = adapter.get_jax_iterator()
expected_class = jax.Array
elif iterator_type == "torch":
elif backend.backend() == "torch":
it = adapter.get_torch_dataloader()
expected_class = torch.Tensor

Expand Down Expand Up @@ -257,10 +275,7 @@ def test_dict_inputs(self):
self.assertEqual(tuple(bx.shape), (4, 4))
self.assertEqual(tuple(by.shape), (4, 2))

@parameterized.named_parameters(
named_product(iterator_type=["np", "tf", "jax", "torch"])
)
def test_with_different_shapes(self, iterator_type):
def test_with_different_shapes(self):

class TestPyDataset(py_dataset_adapter.PyDataset):
@property
Expand All @@ -285,13 +300,13 @@ def __getitem__(self, idx):
TestPyDataset(), shuffle=False
)

if iterator_type == "np":
if backend.backend() == "numpy":
it = adapter.get_numpy_iterator()
elif iterator_type == "tf":
elif backend.backend() == "tensorflow":
it = adapter.get_tf_dataset()
elif iterator_type == "jax":
elif backend.backend() == "jax":
it = adapter.get_jax_iterator()
elif iterator_type == "torch":
elif backend.backend() == "torch":
it = adapter.get_torch_dataloader()

for i, batch in enumerate(it):
Expand All @@ -310,67 +325,52 @@ def __getitem__(self, idx):
self.assertEqual(by.shape, (2, 2))

@parameterized.named_parameters(
named_product(
[
{
"testcase_name": "multiprocessing",
"workers": 2,
"use_multiprocessing": True,
"max_queue_size": 10,
},
{
"testcase_name": "multithreading",
"workers": 2,
"max_queue_size": 10,
},
{
"testcase_name": "single",
},
],
iterator_type=["np", "tf", "jax", "torch"],
)
[
{
"testcase_name": "multiprocessing",
"workers": 2,
"use_multiprocessing": True,
"max_queue_size": 10,
},
{
"testcase_name": "multithreading",
"workers": 2,
"max_queue_size": 10,
},
{
"testcase_name": "single",
},
]
)
def test_exception_reported(
self,
iterator_type,
workers=0,
use_multiprocessing=False,
max_queue_size=0,
):
class ExceptionPyDataset(py_dataset_adapter.PyDataset):

@property
def num_batches(self):
return 4

def __getitem__(self, index):
if index < 2:
return (
np.random.random((64, 4)).astype("float32"),
np.random.random((64, 2)).astype("float32"),
)
raise ValueError("Excepted exception")

adapter = py_dataset_adapter.PyDatasetAdapter(
ExceptionPyDataset(), shuffle=False
dataset = ExceptionPyDataset(
workers=workers,
use_multiprocessing=use_multiprocessing,
max_queue_size=max_queue_size,
)
adapter = py_dataset_adapter.PyDatasetAdapter(dataset, shuffle=False)

expected_exception_class = ValueError
if iterator_type == "np":
if backend.backend() == "numpy":
it = adapter.get_numpy_iterator()
elif iterator_type == "tf":
elif backend.backend() == "tensorflow":
it = adapter.get_tf_dataset()
# tf.data wraps the exception
expected_exception_class = tf.errors.InvalidArgumentError
elif iterator_type == "jax":
elif backend.backend() == "jax":
it = adapter.get_jax_iterator()
elif iterator_type == "torch":
elif backend.backend() == "torch":
it = adapter.get_torch_dataloader()

it = iter(it)
next(it)
next(it)
with self.assertRaisesRegex(
expected_exception_class, "Excepted exception"
expected_exception_class, "Expected exception"
):
next(it)
Loading

0 comments on commit fe03ca5

Please sign in to comment.