From 89746a7b4a070245c799c24ce3a93dd0b9103f37 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Wed, 4 Jan 2023 11:51:59 -0500 Subject: [PATCH] Iterate through batch generator in benchmarks (#140) * Fix TorchLoader benchmarks * Iterate through batch generator in benchmarks --- asv_bench/benchmarks/benchmarks.py | 38 ++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/asv_bench/benchmarks/benchmarks.py b/asv_bench/benchmarks/benchmarks.py index 0b21247..3a9173d 100644 --- a/asv_bench/benchmarks/benchmarks.py +++ b/asv_bench/benchmarks/benchmarks.py @@ -24,16 +24,16 @@ def setup(self, *args, **kwargs): shape_4d = (10, 50, 100, 3) self.ds_4d = xr.Dataset( { - "foo": (["time", "y", "x", "b"], np.random.rand(*shape_4d)), + "foo": (["time", "y", "x", "z"], np.random.rand(*shape_4d)), }, { "x": (["x"], np.arange(shape_4d[-2])), "y": (["y"], np.arange(shape_4d[-3])), - "b": (["b"], np.arange(shape_4d[-1])), + "z": (["z"], np.arange(shape_4d[-1])), }, ) - self.ds_xy = xr.Dataset( + self.ds_2d = xr.Dataset( { "x": ( ["sample", "feature"], @@ -51,8 +51,12 @@ def time_batch_preload(self, preload_batch): Construct a generator on a chunked DataSet with and without preloading batches. """ - ds_dask = self.ds_xy.chunk({"sample": 2}) - BatchGenerator(ds_dask, input_dims={"sample": 2}, preload_batch=preload_batch) + ds_dask = self.ds_2d.chunk({"sample": 2}) + bg = BatchGenerator( + ds_dask, input_dims={"sample": 2}, preload_batch=preload_batch + ) + for batch in bg: + pass @parameterized( ["input_dims", "batch_dims", "input_overlap"], @@ -66,12 +70,14 @@ def time_batch_input(self, input_dims, batch_dims, input_overlap): """ Benchmark simple batch generation case. """ - BatchGenerator( + bg = BatchGenerator( self.ds_3d, input_dims=input_dims, batch_dims=batch_dims, input_overlap=input_overlap, ) + for batch in bg: + pass @parameterized( ["input_dims", "concat_input_dims"], @@ -82,11 +88,13 @@ def time_batch_concat(self, input_dims, concat_input_dims): Construct a generator on a DataSet with and without concatenating chunks specified by ``input_dims`` into the batch dimension. """ - BatchGenerator( + bg = BatchGenerator( self.ds_3d, input_dims=input_dims, concat_input_dims=concat_input_dims, ) + for batch in bg: + pass @parameterized( ["input_dims", "batch_dims", "concat_input_dims"], @@ -101,12 +109,14 @@ def time_batch_concat_4d(self, input_dims, batch_dims, concat_input_dims): Construct a generator on a DataSet with and without concatenating chunks specified by ``input_dims`` into the batch dimension. """ - BatchGenerator( + bg = BatchGenerator( self.ds_4d, input_dims=input_dims, batch_dims=batch_dims, concat_input_dims=concat_input_dims, ) + for batch in bg: + pass class Accessor(Base): @@ -119,14 +129,16 @@ def time_accessor_input_dim(self, input_dims): Benchmark simple batch generation case using xarray accessor Equivalent to subset of ``time_batch_input()``. """ - self.ds_3d.batch.generator(input_dims=input_dims) + bg = self.ds_3d.batch.generator(input_dims=input_dims) + for batch in bg: + pass class TorchLoader(Base): def setup(self, *args, **kwargs): super().setup(**kwargs) - self.x_gen = BatchGenerator(self.ds_xy["x"], {"sample": 10}) - self.y_gen = BatchGenerator(self.ds_xy["y"], {"sample": 10}) + self.x_gen = BatchGenerator(self.ds_2d["x"], {"sample": 10}) + self.y_gen = BatchGenerator(self.ds_2d["y"], {"sample": 10}) def time_map_dataset(self): """ @@ -134,7 +146,7 @@ def time_map_dataset(self): """ dataset = MapDataset(self.x_gen, self.y_gen) loader = torch.utils.data.DataLoader(dataset) - iter(loader).next() + next(iter(loader)) def time_iterable_dataset(self): """ @@ -142,4 +154,4 @@ def time_iterable_dataset(self): """ dataset = IterableDataset(self.x_gen, self.y_gen) loader = torch.utils.data.DataLoader(dataset) - iter(loader).next() + next(iter(loader))