-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Unnecessary memory usage of the TrajectoryContextDataset
#13
Comments
Since the class hierarchy of the ContextWindow datasets is highly complex and difficult to edit, modify, or justify, here I propose an alternative setup which:
I prefer this approach because the current data-management pipeline's hierarchical structure is overly long and complicated to understand and modify. As also discussed with @prolearner, we believe it is best to keep the OOP to a bare minimum and provide a small number of concise, well-documented, and simple classes/functions.
class TrajectoryContextDataset(torch.utils.data.Dataset):
"""Class for a collection of context windows with tensor features."""
def __init__(
    self,
    trajectories: list[ArrayLike, ...],
    context_length: int = 2,
    time_lag: int = 1,
    backend: str = "numpy",
    shuffle: bool = False,
    seed: int = 1234,
    **backend_kw,
):
    """Initialize a Dataset instance that can be passed to a ``torch.utils.data.DataLoader``.

    The trajectories are converted once to the requested backend and kept in
    ``self._raw_data``. Each context window is stored only as a
    ``(traj_idx, slice)`` pair in ``self._indices``; windows are materialized
    lazily when the dataset is indexed.

    Args:
        trajectories: (list[ArrayLike, ...]) A list of trajectories (of potentially different time-length) of
            shape (time, *features).
        context_length: (int) Number of time-frames per context window. Default to 2.
        time_lag: (int) Time lag between successive context windows. Default to 1.
        backend: (str) Specifies the backend to be used (``'numpy'``, ``'torch'``). Default to ``'numpy'``.
        shuffle: (bool) If True, shuffles the context windows. Default to False.
        seed: (int) Seed for the random number generator. Default to 1234.
        **backend_kw: (dict) Keyword arguments to pass to the backend.
            If backend='torch', for instance it is possible to specify the device and type of the data samples.
            If backend='numpy', it is possible to specify the dtype of the data samples

    Raises:
        ValueError: If ``context_length`` is not an integer >= 1, if ``time_lag`` < 1, or if
            ``trajectories`` is not a list.
        ShapeError: If any trajectory has fewer than 2 dimensions.
    """
    # Check the type BEFORE the comparison and combine with `or`: a value is invalid when it is
    # either not an integer or smaller than 1. NumPy integer scalars are accepted as well.
    if not isinstance(context_length, (int, np.integer)) or context_length < 1:
        raise ValueError(f"context_length must be an interger >= 1, got {context_length}")
    if time_lag < 1:
        raise ValueError(f"time_lag must be >= 1, got {time_lag}")
    # Reject anything that is NOT a list (the check must not fire on valid input).
    if not isinstance(trajectories, list):
        raise ValueError(f"Expected list of trajectories of shape (time, *features), got {type(trajectories)}.")

    # NOTE(review): `parse_backend` is defined elsewhere; presumably it validates/normalizes the
    # backend name and returns the torch module (or a placeholder) — confirm.
    torch, backend = parse_backend(backend)

    self._backend = backend
    self._context_length = context_length
    self._time_lag = time_lag
    self._indices = []
    self._raw_data = []  # Variable containing the trajectories in the desired backend.

    # Convert trajectories to the desired backend. We copy data only once, and keep the original memory footprint.
    if backend == "numpy":  # If backend is numpy, we convert the data to numpy.
        self._raw_data = [np.array(traj, **backend_kw) for traj in trajectories]
    elif backend == "torch":  # Load raw data ONCE to GPU if specified in backend_kw.
        self._raw_data = [torch.tensor(traj, **backend_kw) for traj in trajectories]

    # Compute the list of indices (traj_idx, slice(start, end)) for each ContextWindow.
    for traj_idx, traj_data in enumerate(self._raw_data):
        if traj_data.ndim < 2:
            raise ShapeError(
                f"Shape of trajectory {traj_idx} is {traj_data.shape}. Expected a 2D array of (time, *features)."
            )
        context_window_slices = _slices_from_traj_len(
            time_horizon=traj_data.shape[0],
            context_length=context_length,
            time_lag=time_lag,
        )
        # Store a tuple of (traj_idx, context window slice) for each context window.
        self._indices.extend([(traj_idx, s) for s in context_window_slices])

    self._memory_footprint = None  # Computed lazily by the `memory_footprint` property.
    self._shuffled = False
    # Shuffle exactly once, and only when requested, so that `shuffle=False` preserves the
    # temporal order and `seed` fully determines the permutation.
    if shuffle:
        self.shuffle(seed=seed)
    log.info(f"TrajectoryContextDataset initialized with {len(self)} context windows.")
def shuffle(self, seed: int = None):
    """Randomly permute the stored context-window indices in place.

    Args:
        seed: (int, optional) When given, NumPy's global random generator is re-seeded with this
            value before shuffling. When ``None``, the current global generator state is used.
    """
    reseed_requested = seed is not None
    if reseed_requested:
        np.random.seed(seed)
    # Only the lightweight (traj_idx, slice) entries are permuted.
    np.random.shuffle(self._indices)
    self._shuffled = True
@property
def backend(self) -> str:
    """Return the value of ``self._backend`` converted to ``str``."""
    return str(self._backend)
@property
def context_length(self) -> int:
    """Return the value of ``self._context_length`` converted to ``int``."""
    return int(self._context_length)
@property
def time_lag(self) -> int:
    """Return the value of ``self._time_lag`` converted to ``int``."""
    return int(self._time_lag)
@property
def is_shuffled(self):
    """Return the ``self._shuffled`` flag."""
    return self._shuffled
@property
def memory_footprint(self):
    """Return the memory footprint of the raw trajectory data in bytes.

    The value is computed on first access and cached in ``self._memory_footprint``.
    For a backend other than ``"numpy"`` or ``"torch"`` nothing is computed and the
    cached value (``None`` unless set elsewhere) is returned.
    """
    # Fast path: reuse the cached value.
    if self._memory_footprint is not None:
        return self._memory_footprint

    if self._backend == "numpy":
        self._memory_footprint = sum(array.nbytes for array in self._raw_data)
    elif self._backend == "torch":
        self._memory_footprint = sum(
            tensor.element_size() * tensor.nelement() for tensor in self._raw_data
        )
    return self._memory_footprint
def __len__(self) -> int:
    """Return the number of entries in ``self._indices`` (one per context window)."""
    return len(self._indices)
def __getitem__(self, idx):
    """Return the context window stored at position ``idx``.

    The entry ``self._indices[idx]`` is a ``(traj_idx, slice)`` pair; the window is obtained by
    applying the slice to the corresponding raw trajectory.
    """
    which_traj, window = self._indices[idx]
    return self._raw_data[which_traj][window]
def __repr__(self):
    """Return a short summary with the memory footprint (in MB) and the device.

    The device is read from the first raw trajectory when the backend is ``"torch"`` and at least
    one trajectory is stored; otherwise it is reported as ``"cpu"``.
    """
    uses_torch = self._backend == "torch"
    if uses_torch and len(self._raw_data) > 0:
        device = self._raw_data[0].device
    else:
        device = "cpu"
    return f"Memory use: {self.memory_footprint / 1e6:.2f} MB on {device}"
def _slices_from_traj_len(time_horizon: int, context_length: int, time_lag: int) -> list[slice]:
    """Build the slices selecting each context window of a trajectory.

    Windows start at ``0, time_lag, 2 * time_lag, ...`` and each spans ``context_length`` frames.
    Only windows that fit entirely inside the trajectory are produced, so the result is empty
    when ``time_horizon < context_length``.

    Args:
        time_horizon: (int) Number of time-frames of the trajectory.
        context_length: (int) Number of time-frames per context window.
        time_lag: (int) Time lag between successive context windows.

    Returns:
        list[slice]: List of slices for each context window.

    Examples
    --------
    >>> time_horizon, context_length, time_lag = 10, 4, 2
    >>> slices = _slices_from_traj_len(time_horizon, context_length, time_lag)
    >>> for s in slices:
    ...     print(f"start: {s.start}, end: {s.stop}")
    start: 0, end: 4
    start: 2, end: 6
    start: 4, end: 8
    start: 6, end: 10
    """
    # Largest start index for which a full window still fits.
    last_start = time_horizon - context_length
    return [
        slice(first_frame, first_frame + context_length)
        for first_frame in range(0, last_start + 1, time_lag)
    ]
def traj_to_contexts(
    trajectory: np.ndarray,
    context_window_len: int = 2,
    time_lag: int = 1,
    backend: str = "auto",
    **backend_kwargs,
):
    """Transform a single trajectory into a dataset of context windows.

    Thin wrapper that wraps ``trajectory`` in a one-element list and forwards all arguments to
    ``TrajectoryContextDataset``.

    Args:
        trajectory (np.ndarray): A trajectory of shape ``(n_frames, *features_shape)``.
        context_window_len (int, optional): Length of the context window. Default to ``2``.
        time_lag (int, optional): Time lag, i.e. stride, between successive context windows. Default to ``1``.
        backend (str, optional): Specifies the backend to be used (``'numpy'``, ``'torch'``). Default to
            ``'auto'``. NOTE(review): ``'auto'`` is presumably resolved to the trajectory's own backend
            by ``TrajectoryContextDataset`` — confirm.
        **backend_kwargs: Keyword arguments to pass to the backend. For example, if ``'torch'``,
            it is possible to specify the device of the tensor.

    Returns:
        TrajectoryContextDataset: A sequence of context windows.
    """
    return TrajectoryContextDataset(
        trajectories=[trajectory],  # The dataset expects a list of trajectories.
        context_length=context_window_len,
        time_lag=time_lag,
        backend=backend,
        **backend_kwargs,
    )
def multi_traj_to_context(
    trajectories: list[ArrayLike, ...],
    context_window_len: int = 2,
    time_lag: int = 1,
    backend: str = "auto",
    **backend_kwargs,
):
    """Transform a collection of trajectories into a dataset of context windows.

    Thin wrapper that forwards all arguments to ``TrajectoryContextDataset``.

    Args:
        trajectories (list[ArrayLike, ...]): A list of trajectories, each of shape
            ``(n_frames, *features_shape)``.
        context_window_len (int, optional): Length of the context window. Default to ``2``.
        time_lag (int, optional): Time lag, i.e. stride, between successive context windows. Default to ``1``.
        backend (str, optional): Specifies the backend to be used (``'numpy'``, ``'torch'``). Default to
            ``'auto'``. NOTE(review): ``'auto'`` is presumably resolved to the trajectories' own backend
            by ``TrajectoryContextDataset`` — confirm.
        **backend_kwargs: Keyword arguments to pass to the backend. For example, if ``'torch'``,
            it is possible to specify the device of the tensor.

    Returns:
        TrajectoryContextDataset: A sequence of context windows.
    """
    return TrajectoryContextDataset(
        trajectories=trajectories,
        context_length=context_window_len,
        time_lag=time_lag,
        backend=backend,
        **backend_kwargs,
    )
Please have a look at this again, guys.
Hi @pietronvll,
There seems to be very poor memory management within the `TrajectoryContextDataset` class. When I do:
We go from 93 MB to 4 GB of memory consumption, which becomes quite problematic with the default behaviour of this class, which loads the data tensor to the GPU for fast training.
Will update soon on details.
The text was updated successfully, but these errors were encountered: