Skip to content

Commit

Permalink
Iterate through batch generator in benchmarks (#140)
Browse files · Browse the repository at this point in the history
* Fix TorchLoader benchmarks

* Iterate through batch generator in benchmarks
  • Loading branch information
maxrjones authored Jan 4, 2023
1 parent f93af88 commit 89746a7
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions asv_bench/benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ def setup(self, *args, **kwargs):
shape_4d = (10, 50, 100, 3)
self.ds_4d = xr.Dataset(
{
"foo": (["time", "y", "x", "b"], np.random.rand(*shape_4d)),
"foo": (["time", "y", "x", "z"], np.random.rand(*shape_4d)),
},
{
"x": (["x"], np.arange(shape_4d[-2])),
"y": (["y"], np.arange(shape_4d[-3])),
"b": (["b"], np.arange(shape_4d[-1])),
"z": (["z"], np.arange(shape_4d[-1])),
},
)

self.ds_xy = xr.Dataset(
self.ds_2d = xr.Dataset(
{
"x": (
["sample", "feature"],
Expand All @@ -51,8 +51,12 @@ def time_batch_preload(self, preload_batch):
Construct a generator on a chunked DataSet with and without preloading
batches.
"""
ds_dask = self.ds_xy.chunk({"sample": 2})
BatchGenerator(ds_dask, input_dims={"sample": 2}, preload_batch=preload_batch)
ds_dask = self.ds_2d.chunk({"sample": 2})
bg = BatchGenerator(
ds_dask, input_dims={"sample": 2}, preload_batch=preload_batch
)
for batch in bg:
pass

@parameterized(
["input_dims", "batch_dims", "input_overlap"],
Expand All @@ -66,12 +70,14 @@ def time_batch_input(self, input_dims, batch_dims, input_overlap):
"""
Benchmark simple batch generation case.
"""
BatchGenerator(
bg = BatchGenerator(
self.ds_3d,
input_dims=input_dims,
batch_dims=batch_dims,
input_overlap=input_overlap,
)
for batch in bg:
pass

@parameterized(
["input_dims", "concat_input_dims"],
Expand All @@ -82,11 +88,13 @@ def time_batch_concat(self, input_dims, concat_input_dims):
Construct a generator on a DataSet with and without concatenating
chunks specified by ``input_dims`` into the batch dimension.
"""
BatchGenerator(
bg = BatchGenerator(
self.ds_3d,
input_dims=input_dims,
concat_input_dims=concat_input_dims,
)
for batch in bg:
pass

@parameterized(
["input_dims", "batch_dims", "concat_input_dims"],
Expand All @@ -101,12 +109,14 @@ def time_batch_concat_4d(self, input_dims, batch_dims, concat_input_dims):
Construct a generator on a DataSet with and without concatenating
chunks specified by ``input_dims`` into the batch dimension.
"""
BatchGenerator(
bg = BatchGenerator(
self.ds_4d,
input_dims=input_dims,
batch_dims=batch_dims,
concat_input_dims=concat_input_dims,
)
for batch in bg:
pass


class Accessor(Base):
Expand All @@ -119,27 +129,29 @@ def time_accessor_input_dim(self, input_dims):
Benchmark simple batch generation case using xarray accessor
Equivalent to subset of ``time_batch_input()``.
"""
self.ds_3d.batch.generator(input_dims=input_dims)
bg = self.ds_3d.batch.generator(input_dims=input_dims)
for batch in bg:
pass


class TorchLoader(Base):
def setup(self, *args, **kwargs):
super().setup(**kwargs)
self.x_gen = BatchGenerator(self.ds_xy["x"], {"sample": 10})
self.y_gen = BatchGenerator(self.ds_xy["y"], {"sample": 10})
self.x_gen = BatchGenerator(self.ds_2d["x"], {"sample": 10})
self.y_gen = BatchGenerator(self.ds_2d["y"], {"sample": 10})

def time_map_dataset(self):
"""
Benchmark MapDataset integration with torch DataLoader.
"""
dataset = MapDataset(self.x_gen, self.y_gen)
loader = torch.utils.data.DataLoader(dataset)
iter(loader).next()
next(iter(loader))

def time_iterable_dataset(self):
"""
Benchmark IterableDataset integration with torch DataLoader.
"""
dataset = IterableDataset(self.x_gen, self.y_gen)
loader = torch.utils.data.DataLoader(dataset)
iter(loader).next()
next(iter(loader))

0 comments on commit 89746a7

Please sign in to comment.