Unnecessary memory usage of the TrajectoryContextDataset #13

Open
Danfoa opened this issue Jun 29, 2024 · 2 comments

Comments


Danfoa commented Jun 29, 2024

Hi @pietronvll,

There seems to be very poor memory management within the TrajectoryContextDataset class.

When I do:

train_trajs = np.float32(train_trajs)
val_trajs = np.float32(val_trajs)
test_trajs = np.float32(test_trajs)

log.info(f"Training dataset memory footprint {train_trajs.nbytes / 1e6:.2f} [MB]")
# Split data into context windows using kooplearn. Load data to the device if available, and store a reference to
# the raw dataset for the required regression tasks.
self.train_dataset = multi_traj_to_context(train_trajs,
                                           context_window_len=self.pred_horizon + self.lookback_len,
                                           backend="torch",
                                           device=self.device)

log.info(f"Torch dataset memory footprint {self.train_dataset.data.element_size() * self.train_dataset.data.nelement() / 1e6:.2f} [MB]")

We go from 93 MB to 4.4 GB of memory consumption, which becomes quite problematic given the default behaviour of this class, which loads the data tensor to the GPU for fast training.

[2024-06-29 13:15:48,471][data.DynamicsDataModule][INFO] - Training dataset memory footprint 93.00 [MB]
[2024-06-29 13:15:48,564][data.DynamicsDataModule][INFO] - Torch dataset memory footprint 4404.44 [MB]
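
The blow-up follows directly from materializing every overlapping context window as a copy rather than as a view/index into the raw trajectory: with stride 1, each frame gets duplicated roughly context_length times. A minimal sketch of the effect (the array shapes and the context length of 48 are illustrative guesses, consistent with the ~47x ratio in the logs above, not the actual experiment):

import numpy as np

# Illustrative float32 trajectory (not the real dataset).
traj = np.zeros((10_000, 145), dtype=np.float32)
print(f"raw: {traj.nbytes / 1e6:.2f} MB")  # 5.80 MB

# Materializing every overlapping window of length 48 with stride 1
# copies each frame ~48 times.
context_length, time_lag = 48, 1
starts = range(0, traj.shape[0] - context_length + 1, time_lag)
windows = np.stack([traj[s:s + context_length] for s in starts])
print(f"windowed: {windows.nbytes / 1e6:.2f} MB")  # ~277 MB, a ~48x blow-up

# The same ratio applied to the 93 MB training set gives ~4.4 GB,
# matching the log above. Storing (traj_idx, slice) indices instead of
# copies keeps the footprint at that of the raw data.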

I will update soon with details.


Danfoa commented Jun 29, 2024

Since the class hierarchy of the ContextWindow datasets is highly complex and difficult to edit, modify, or justify, I propose here an alternative setup which:

  1. Does not use any class hierarchy; the only class we inherit from is the abstract class torch.utils.data.Dataset.
  2. Handles both numpy and torch backends.
  3. Maintains the memory footprint of the original trajectory arrays.
  4. Reduces the number of lines of code.

I prefer this approach since the current data-management pipeline's hierarchical structure and length make it overly complex to understand and modify. As also discussed with @prolearner, we believe it is best to keep the OOP to a bare minimum and provide a small number of concise, well-documented, and simple classes/functions.

import logging

import numpy as np
import torch
from numpy.typing import ArrayLike

# NOTE: `parse_backend` and `ShapeError` are assumed to come from kooplearn's
# internal utilities; `_slices_from_traj_len` is defined below.

log = logging.getLogger(__name__)


class TrajectoryContextDataset(torch.utils.data.Dataset):
    """Class for a collection of context windows with tensor features."""

    def __init__(
            self,
            trajectories: list[ArrayLike],
            context_length: int = 2,
            time_lag: int = 1,
            backend: str = "numpy",
            shuffle: bool = False,
            seed: int = 1234,
            **backend_kw,
            ):
        """ Initialize a Dataset instance that can be passed to a torch.data.DataLoader.

        Args:
            trajectories: (list[ArrayLike, ...]) A list of trajectories (or potentially different time-length) of
                shape (time, *features).
            context_length: (int) Number of time-frames per context window. Default to 2.
            time_lag: (int) Time lag between successive context windows. Default to 1.
            backend: (str) Specifies the backend to be used (``'numpy'``, ``'torch'``). Default to ``'numpy'``.
            shuffle: (bool) If True, shuffles the context windows. Default to False.
            seed: (int) Seed for the random number generator. Default to 1234.
            **backend_kw: (dict) Keyword arguments to pass to the backend.
                If backend='torch', for instance it is possible to specify the device and type of the data samples.
                If backend='numpy', it is possible to specify the dtype of the data samples
        """
        if not isinstance(context_length, int) or context_length < 1:
            raise ValueError(f"context_length must be an integer >= 1, got {context_length}")

        if time_lag < 1:
            raise ValueError(f"time_lag must be >= 1, got {time_lag}")

        if not isinstance(trajectories, list):
            raise ValueError(f"Expected a list of trajectories of shape (time, *features), got {type(trajectories)}.")

        torch, backend = parse_backend(backend)  # Resolve the backend name and lazily import torch if needed.

        self._backend = backend
        self._context_length = context_length
        self._time_lag = time_lag
        self._indices = []
        self._raw_data = []  # Variable containing the trajectories in the desired backend.

        # Convert trajectories to the desired backend. We copy data only once, and keep the original memory footprint.
        if backend == "numpy":  # If backend is numpy, we convert the data to numpy.
            self._raw_data = [np.array(traj, **backend_kw) for traj in trajectories]
        elif backend == "torch":  # Load raw data ONCE to GPU if specified in backend_kw.
            self._raw_data = [torch.tensor(traj, **backend_kw) for traj in trajectories]

        # Compute the list of indices (traj_idx, slice(start, end)) for each ContextWindow.
        for traj_idx, traj_data in enumerate(self._raw_data):
            if traj_data.ndim < 2:
                raise ShapeError(
                    f"Shape of trajectory {traj_idx} is {traj_data.shape}. Expected a 2D array of (time, *features)."
                    )
            context_window_slices = _slices_from_traj_len(time_horizon=traj_data.shape[0],
                                                          context_length=context_length,
                                                          time_lag=time_lag)
            # Store a tuple of (traj_idx, context window slice) for each context window.
            self._indices.extend([(traj_idx, s) for s in context_window_slices])

        self._memory_footprint = None
        self._shuffled = False

        if shuffle:
            self.shuffle(seed=seed)

        log.info(f"TrajectoryContextDataset initialized with {len(self)} context windows.")

    def shuffle(self, seed: int = None):
        """Shuffles the context windows."""
        if seed is not None:
            np.random.seed(seed)
        np.random.shuffle(self._indices)
        self._shuffled = True

    @property
    def backend(self):
        return str(self._backend)

    @property
    def context_length(self):
        return int(self._context_length)

    @property
    def time_lag(self):
        return int(self._time_lag)

    @property
    def is_shuffled(self):
        return self._shuffled

    @property
    def memory_footprint(self):
        """Returns the memory footprint of the dataset in bytes."""
        if self._memory_footprint is None:
            if self._backend == "numpy":
                self._memory_footprint = sum(traj.nbytes for traj in self._raw_data)
            elif self._backend == "torch":
                self._memory_footprint = sum(traj.element_size() * traj.nelement() for traj in self._raw_data)
        return self._memory_footprint

    def __len__(self):
        return len(self._indices)

    def __getitem__(self, idx):
        traj_idx, slice_idx = self._indices[idx]
        sample = self._raw_data[traj_idx][slice_idx]
        return sample

    def __repr__(self):
        device = "cpu"
        if self._backend == "torch":
            if len(self._raw_data) > 0:
                device = self._raw_data[0].device
        return f"Memory use: {self.memory_footprint / 1e6:.2f} MB on {device}"

def _slices_from_traj_len(time_horizon: int, context_length: int, time_lag: int) -> list[slice]:
    """ Returns the list of slices (start_time_idx, end_time_idx) for each context window in the trajectory.
    Args:
        time_horizon: (int) Number time-frames of the trajectory.
        context_length: (int) Number of time-frames per context window
        time_lag: (int) Time lag between successive context windows.
    Returns:
        list[slice]: List of slices for each context window.

    Examples
    --------
    >>> time_horizon, context_length, time_lag = 10, 4, 2
    >>> slices = _slices_from_traj_len(time_horizon, context_length, time_lag)
    >>> for s in slices:
    ...     print(f"start: {s.start}, end: {s.stop}")
    start: 0, end: 4
    start: 2, end: 6
    start: 4, end: 8
    start: 6, end: 10
    """
    slices = []
    for start in range(0, time_horizon - context_length + 1, time_lag):
        end = start + context_length
        slices.append(slice(start, end))

    return slices



def traj_to_contexts(
        trajectory: np.ndarray,
        context_window_len: int = 2,
        time_lag: int = 1,
        backend: str = "auto",
        **backend_kwargs,
        ):
    """Transforms a single trajectory to a sequence of context windows.

    Args:
        trajectory (np.ndarray): A trajectory of shape ``(n_frames, *features_shape)``.
        context_window_len (int, optional): Length of the context window. Defaults to ``2``.
        time_lag (int, optional): Time lag, i.e. stride, between successive context windows. Defaults to ``1``.
        backend (str, optional): Specifies the backend to be used (``'numpy'``, ``'torch'``). If set to ``'auto'``,
            it will use the same backend as the trajectory. Defaults to ``'auto'``.
        **backend_kwargs (dict, optional): Keyword arguments to pass to the backend. For example, if ``'torch'``,
            it is possible to specify the device of the tensor.

    Returns:
        TrajectoryContextDataset: A sequence of context windows.
    """
    return TrajectoryContextDataset(
        trajectories=[trajectory],
        context_length=context_window_len,
        time_lag=time_lag,
        backend=backend,
        **backend_kwargs,
        )


def multi_traj_to_context(
        trajectories: list[ArrayLike],
        context_window_len: int = 2,
        time_lag: int = 1,
        backend: str = "auto",
        **backend_kwargs,
        ):
    """Transforms a collection of trajectories to a sequence of context windows.

    Args:
        trajectories (list[ArrayLike]): A list of trajectories, each of shape ``(n_frames, *features_shape)``.
        context_window_len (int, optional): Length of the context window. Defaults to ``2``.
        time_lag (int, optional): Time lag, i.e. stride, between successive context windows. Defaults to ``1``.
        backend (str, optional): Specifies the backend to be used (``'numpy'``, ``'torch'``). If set to ``'auto'``,
            it will use the same backend as the trajectories. Defaults to ``'auto'``.
        **backend_kwargs (dict, optional): Keyword arguments to pass to the backend. For example, if ``'torch'``,
            it is possible to specify the device of the tensor.

    Returns:
        TrajectoryContextDataset: A sequence of context windows.
    """
    return TrajectoryContextDataset(
        trajectories=trajectories,
        context_length=context_window_len,
        time_lag=time_lag,
        backend=backend,
        **backend_kwargs,
        )
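
For completeness, a hypothetical usage sketch of the proposed class with a standard DataLoader (the trajectory shapes and keyword arguments are illustrative; device and dtype are simply forwarded to torch.tensor via **backend_kw):

import numpy as np
import torch

# Three float32 trajectories of different lengths, as in the report above.
trajs = [np.random.randn(t, 6).astype(np.float32) for t in (500, 750, 1000)]

dataset = TrajectoryContextDataset(
    trajs,
    context_length=4,
    time_lag=1,
    backend="torch",       # trajectories are copied ONCE to torch tensors...
    device="cpu",          # ...optionally straight onto a target device
    dtype=torch.float32,
)
print(dataset)  # Memory use: 0.05 MB on cpu -- the same as the raw arrays

loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([32, 4, 6]) = (batch, context_length, *features)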


Danfoa commented Nov 26, 2024

@pietronvll @g-turri

Please have a look at this again, guys.
