From 3b52b13af56d2e189b1d435fcc5c2f04305e0d17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Tue, 13 Jun 2023 10:42:29 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Simplify=20H5Dataset?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/tutorials/simulators.ipynb | 47 +++++++++-------------------- lampe/data.py | 53 +++++++++++++++------------------ 2 files changed, 38 insertions(+), 62 deletions(-) diff --git a/docs/tutorials/simulators.ipynb b/docs/tutorials/simulators.ipynb index e600a2d..abc25cc 100644 --- a/docs/tutorials/simulators.ipynb +++ b/docs/tutorials/simulators.ipynb @@ -23,7 +23,7 @@ "import torch\n", "import zuko\n", "\n", - "from itertools import islice\n", + "from itertools import chain, islice\n", "from tqdm import tqdm" ] }, @@ -382,7 +382,7 @@ "source": [ "## Loading from disk\n", "\n", - "Now that the pairs are stored, we need to load them. The `H5Dataset` class creates an [`IterableDataset`](torch.utils.data.IterableDataset) of pairs $(\\theta, x)$ from HDF5 files. The pairs are dynamically loaded, meaning they are read from disk on demand instead of being cached in memory. This allows for very large datasets that do not even fit in memory." + "Now that the pairs are stored, we need to load them. The `H5Dataset` class creates an [`IterableDataset`](torch.utils.data.IterableDataset) of pairs $(\\theta, x)$ from an HDF5 file. The pairs are dynamically loaded, meaning they are read from disk on demand instead of being cached in memory. This allows for very large datasets that do not even fit in memory." ] }, { @@ -522,7 +522,7 @@ "source": [ "## Merging datasets\n", "\n", - "Another feature of `H5Dataset` is that it can load data from any number of HDF5 files." + "`H5Dataset` can only load data from a single HDF5 file, but it is easy to aggregate data from multiple sources, such as data generated on several machines or at different time." ] }, { @@ -539,17 +539,17 @@ } ], "source": [ - "dataset = lampe.data.H5Dataset('data_0.h5', 'data_1.h5', 'data_2.h5')\n", + "datasets = [\n", + " lampe.data.H5Dataset('data_0.h5', batch_size=512),\n", + " lampe.data.H5Dataset('data_1.h5', batch_size=256),\n", + " lampe.data.H5Dataset('data_2.h5', batch_size=128),\n", + "]\n", "\n", - "for theta, x in tqdm(dataset):\n", - " pass" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This feature is very useful when aggregating data from multiple sources, such as data generated on several machines or at different times." + "lampe.data.H5Dataset.store(\n", + " pairs=chain(*datasets),\n", + " file='data_all.h5',\n", + " size=sum(map(len, datasets)),\n", + ")" ] }, { @@ -561,26 +561,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 196608/196608 [00:00<00:00, 423420.86sample/s]\n" - ] - } - ], - "source": [ - "dataset = lampe.data.H5Dataset('data_0.h5', 'data_1.h5', 'data_2.h5', batch_size=256)\n", - "\n", - "lampe.data.H5Dataset.store(dataset, 'data_all.h5', size=len(dataset))" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 196608/196608 [00:00<00:00, 226463.55it/s]\n" + "100%|██████████| 196608/196608 [00:00<00:00, 242486.34it/s]\n" ] } ], diff --git a/lampe/data.py b/lampe/data.py index 7bbe322..d9f44a7 100644 --- a/lampe/data.py +++ b/lampe/data.py @@ -77,7 +77,7 @@ def __init__( self, prior: Distribution, simulator: Callable, - batch_size: int = 2**10, # 1024 + batch_size: int = 2**8, # 256 vectorized: bool = False, numpy: bool = False, **kwargs, @@ -158,12 +158,11 @@ def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]: class H5Dataset(IterableDataset): - r"""Creates an iterable dataset of pairs :math:`(\theta, x)` from HDF5 files. + r"""Creates an iterable dataset of pairs :math:`(\theta, x)` from an HDF5 file. As it can be slow to load pairs from disk one by one, :class:`H5Dataset` implements a custom :meth:`__iter__` method that loads several contiguous chunks of pairs at - once and shuffles their concatenation before yielding the pairs one by one, - unless a batch size is provided. + once and shuffles their concatenation before yielding the pairs. :class:`H5Dataset` also implements the :meth:`__len__` and :meth:`__getitem__` methods for convenience. @@ -176,7 +175,7 @@ class H5Dataset(IterableDataset): necessary to wrap the dataset in a :class:`torch.utils.data.DataLoader`. Arguments: - files: HDF5 files containing pairs :math:`(\theta, x)`. + file: An HDF5 file containing pairs :math:`(\theta, x)`. batch_size: The size of the batches. chunk_size: The size of the contiguous chunks. chunk_step: The number of chunks loaded at once. @@ -194,17 +193,15 @@ class H5Dataset(IterableDataset): def __init__( self, - *files: Union[str, Path], + file: Union[str, Path], batch_size: int = None, - chunk_size: int = 2**10, # 1024 + chunk_size: int = 2**8, # 256 chunk_step: str = 2**8, # 256 shuffle: bool = False, ): super().__init__() - self.files = [h5py.File(f, mode='r') for f in files] - self.sizes = [f['theta'].shape[0] for f in self.files] - self.cumsizes = np.cumsum(self.sizes) + self.file = h5py.File(file, mode='r') self.batch_size = batch_size self.chunk_size = chunk_size @@ -212,24 +209,17 @@ def __init__( self.shuffle = shuffle def __len__(self) -> int: - return self.cumsizes[-1] - - def __getitem__(self, i: int) -> Tuple[Tensor, Tensor]: - i = i % len(self) - j = bisect(self.cumsizes, i) - if j > 0: - i = i - self.cumsizes[j - 1] + return len(self.file['theta']) - f = self.files[j] - theta, x = f['theta'][i], f['x'][i] + def __getitem__(self, i: Union[int, slice]) -> Tuple[Tensor, Tensor]: + theta, x = self.file['theta'][i], self.file['x'][i] return torch.from_numpy(theta), torch.from_numpy(x) def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]: chunks = torch.tensor([ - (i, j, j + self.chunk_size) - for i, size in enumerate(self.sizes) - for j in range(0, size, self.chunk_size) + (i, i + self.chunk_size) + for i in range(0, len(self), self.chunk_size) ]) if self.shuffle: @@ -237,11 +227,19 @@ def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]: chunks = chunks[order] for slices in chunks.split(self.chunk_step): + # Merge contiguous slices slices = sorted(slices.tolist()) + stack = [] + + for s in slices: + if stack and stack[-1][-1] == s[0]: + stack[-1][-1] = s[-1] + else: + stack.append(s) # Load - theta = np.concatenate([self.files[i]['theta'][j:k] for i, j, k in slices]) - x = np.concatenate([self.files[i]['x'][j:k] for i, j, k in slices]) + theta = np.concatenate([self.file['theta'][i:j] for i, j in stack]) + x = np.concatenate([self.file['x'][i:j] for i, j in stack]) theta, x = torch.from_numpy(theta), torch.from_numpy(x) @@ -266,12 +264,9 @@ def to_memory(self) -> JointDataset: >>> dataset = H5Dataset('data.h5').to_memory() """ - theta = np.concatenate([f['theta'][:] for f in self.files]) - x = np.concatenate([f['x'][:] for f in self.files]) - return JointDataset( - theta, - x, + self.file['theta'][:], + self.file['x'][:], batch_size=self.batch_size, shuffle=self.shuffle, )