⚡️ Simplify H5Dataset
francois-rozet committed Jun 14, 2023
1 parent e31d9d2 commit 3b52b13
Showing 2 changed files with 38 additions and 62 deletions.
47 changes: 14 additions & 33 deletions docs/tutorials/simulators.ipynb
@@ -23,7 +23,7 @@
"import torch\n",
"import zuko\n",
"\n",
"from itertools import islice\n",
"from itertools import chain, islice\n",
"from tqdm import tqdm"
]
},
@@ -382,7 +382,7 @@
"source": [
"## Loading from disk\n",
"\n",
"Now that the pairs are stored, we need to load them. The `H5Dataset` class creates an [`IterableDataset`](torch.utils.data.IterableDataset) of pairs $(\\theta, x)$ from HDF5 files. The pairs are dynamically loaded, meaning they are read from disk on demand instead of being cached in memory. This allows for very large datasets that do not even fit in memory."
"Now that the pairs are stored, we need to load them. The `H5Dataset` class creates an [`IterableDataset`](torch.utils.data.IterableDataset) of pairs $(\\theta, x)$ from an HDF5 file. The pairs are dynamically loaded, meaning they are read from disk on demand instead of being cached in memory. This allows for very large datasets that do not even fit in memory."
]
},
{
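As a quick illustration of the dynamic loading described in the cell above, here is a minimal sketch of how such a dataset is typically consumed. The file name data.h5 is hypothetical (assumed to have been written beforehand with H5Dataset.store); the methods used match the H5Dataset code shown in the lampe/data.py diff below.

import lampe

# Hypothetical file containing 'theta' and 'x' datasets.
dataset = lampe.data.H5Dataset('data.h5')

print(len(dataset))    # number of stored pairs
theta, x = dataset[0]  # random access reads a single pair from disk

for theta, x in dataset:  # iteration reads chunks of pairs on demand
    pass  # training or evaluation step would go here
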
@@ -522,7 +522,7 @@
"source": [
"## Merging datasets\n",
"\n",
"Another feature of `H5Dataset` is that it can load data from any number of HDF5 files."
"`H5Dataset` can only load data from a single HDF5 file, but it is easy to aggregate data from multiple sources, such as data generated on several machines or at different time."
]
},
{
@@ -539,17 +539,17 @@
}
],
"source": [
"dataset = lampe.data.H5Dataset('data_0.h5', 'data_1.h5', 'data_2.h5')\n",
"datasets = [\n",
" lampe.data.H5Dataset('data_0.h5', batch_size=512),\n",
" lampe.data.H5Dataset('data_1.h5', batch_size=256),\n",
" lampe.data.H5Dataset('data_2.h5', batch_size=128),\n",
"]\n",
"\n",
"for theta, x in tqdm(dataset):\n",
" pass"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This feature is very useful when aggregating data from multiple sources, such as data generated on several machines or at different times."
"lampe.data.H5Dataset.store(\n",
" pairs=chain(*datasets),\n",
" file='data_all.h5',\n",
" size=sum(map(len, datasets)),\n",
")"
]
},
{
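Because the old and new lines of this cell are interleaved in the diff above, here is the new merging pattern collected into a single snippet; the file names and calls are taken directly from the notebook cell shown in this commit.

import lampe

from itertools import chain

# Read three existing files in batches and store their concatenation
# into a single file.
datasets = [
    lampe.data.H5Dataset('data_0.h5', batch_size=512),
    lampe.data.H5Dataset('data_1.h5', batch_size=256),
    lampe.data.H5Dataset('data_2.h5', batch_size=128),
]

lampe.data.H5Dataset.store(
    pairs=chain(*datasets),
    file='data_all.h5',
    size=sum(map(len, datasets)),
)
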
@@ -561,26 +561,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 196608/196608 [00:00<00:00, 423420.86sample/s]\n"
]
}
],
"source": [
"dataset = lampe.data.H5Dataset('data_0.h5', 'data_1.h5', 'data_2.h5', batch_size=256)\n",
"\n",
"lampe.data.H5Dataset.store(dataset, 'data_all.h5', size=len(dataset))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 196608/196608 [00:00<00:00, 226463.55it/s]\n"
"100%|██████████| 196608/196608 [00:00<00:00, 242486.34it/s]\n"
]
}
],
53 changes: 24 additions & 29 deletions lampe/data.py
@@ -77,7 +77,7 @@ def __init__(
self,
prior: Distribution,
simulator: Callable,
batch_size: int = 2**10, # 1024
batch_size: int = 2**8, # 256
vectorized: bool = False,
numpy: bool = False,
**kwargs,
@@ -158,12 +158,11 @@ def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]:


class H5Dataset(IterableDataset):
r"""Creates an iterable dataset of pairs :math:`(\theta, x)` from HDF5 files.
r"""Creates an iterable dataset of pairs :math:`(\theta, x)` from an HDF5 file.
As it can be slow to load pairs from disk one by one, :class:`H5Dataset` implements
a custom :meth:`__iter__` method that loads several contiguous chunks of pairs at
once and shuffles their concatenation before yielding the pairs one by one,
unless a batch size is provided.
once and shuffles their concatenation before yielding the pairs.
:class:`H5Dataset` also implements the :meth:`__len__` and :meth:`__getitem__`
methods for convenience.
@@ -176,7 +175,7 @@ class H5Dataset(IterableDataset):
necessary to wrap the dataset in a :class:`torch.utils.data.DataLoader`.
Arguments:
files: HDF5 files containing pairs :math:`(\theta, x)`.
file: An HDF5 file containing pairs :math:`(\theta, x)`.
batch_size: The size of the batches.
chunk_size: The size of the contiguous chunks.
chunk_step: The number of chunks loaded at once.
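To illustrate the arguments listed above, a hedged sketch of a dataset configured for chunked, shuffled reads; the file name is hypothetical, and per the docstring no DataLoader is needed since the dataset itself yields batches.

import lampe

# Hypothetical file; argument names follow the docstring above.
dataset = lampe.data.H5Dataset(
    'data.h5',
    batch_size=256,  # yield batches of 256 pairs instead of single pairs
    chunk_size=256,  # each contiguous read covers 256 pairs
    chunk_step=256,  # 256 chunks are loaded and shuffled together at once
    shuffle=True,
)

for theta, x in dataset:
    pass  # training step would go here
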
@@ -194,54 +193,53 @@ class H5Dataset(IterableDataset):

def __init__(
self,
*files: Union[str, Path],
file: Union[str, Path],
batch_size: int = None,
chunk_size: int = 2**10, # 1024
chunk_size: int = 2**8, # 256
chunk_step: int = 2**8, # 256
shuffle: bool = False,
):
super().__init__()

self.files = [h5py.File(f, mode='r') for f in files]
self.sizes = [f['theta'].shape[0] for f in self.files]
self.cumsizes = np.cumsum(self.sizes)
self.file = h5py.File(file, mode='r')

self.batch_size = batch_size
self.chunk_size = chunk_size
self.chunk_step = chunk_step
self.shuffle = shuffle

def __len__(self) -> int:
return self.cumsizes[-1]

def __getitem__(self, i: int) -> Tuple[Tensor, Tensor]:
i = i % len(self)
j = bisect(self.cumsizes, i)
if j > 0:
i = i - self.cumsizes[j - 1]
return len(self.file['theta'])

f = self.files[j]
theta, x = f['theta'][i], f['x'][i]
def __getitem__(self, i: Union[int, slice]) -> Tuple[Tensor, Tensor]:
theta, x = self.file['theta'][i], self.file['x'][i]

return torch.from_numpy(theta), torch.from_numpy(x)

def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]:
chunks = torch.tensor([
(i, j, j + self.chunk_size)
for i, size in enumerate(self.sizes)
for j in range(0, size, self.chunk_size)
(i, i + self.chunk_size)
for i in range(0, len(self), self.chunk_size)
])

if self.shuffle:
order = torch.randperm(len(chunks))
chunks = chunks[order]

for slices in chunks.split(self.chunk_step):
# Merge contiguous slices
slices = sorted(slices.tolist())
stack = []

for s in slices:
if stack and stack[-1][-1] == s[0]:
stack[-1][-1] = s[-1]
else:
stack.append(s)

# Load
theta = np.concatenate([self.files[i]['theta'][j:k] for i, j, k in slices])
x = np.concatenate([self.files[i]['x'][j:k] for i, j, k in slices])
theta = np.concatenate([self.file['theta'][i:j] for i, j in stack])
x = np.concatenate([self.file['x'][i:j] for i, j in stack])

theta, x = torch.from_numpy(theta), torch.from_numpy(x)

@@ -266,12 +264,9 @@ def to_memory(self) -> JointDataset:
>>> dataset = H5Dataset('data.h5').to_memory()
"""

theta = np.concatenate([f['theta'][:] for f in self.files])
x = np.concatenate([f['x'][:] for f in self.files])

return JointDataset(
theta,
x,
self.file['theta'][:],
self.file['x'][:],
batch_size=self.batch_size,
shuffle=self.shuffle,
)