From 47975de2cb8f47a471634c5361242a03be33d8d2 Mon Sep 17 00:00:00 2001
From: Mitchell Ostrow
Date: Thu, 13 Jun 2024 19:39:53 -0400
Subject: [PATCH] add function for generating trajectorycontextdataset from a list of trajectories

---
 kooplearn/data.py | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 76 insertions(+), 3 deletions(-)

diff --git a/kooplearn/data.py b/kooplearn/data.py
index 60a86a1..20ed18a 100644
--- a/kooplearn/data.py
+++ b/kooplearn/data.py
@@ -206,10 +206,83 @@ def _contexts_from_traj_np(trajectory, context_length, time_lag):
     data = np.moveaxis(data, -1, 1)[:, ::time_lag, ...]
     return data, TensorContextDataset(idx_map)
 
 
-class MultiTrajectoryContextDataset(TrajectoryContextDataset):
-    def __init__(self):
-        raise NotImplementedError
+def get_multi_trajectorycontextdataset(
+    trajectories: ArrayLike,
+    context_length: int = 2,
+    time_lag: int = 1,
+    backend: str = "auto",
+    **backend_kw,
+):
+    """Build one context-window dataset from several trajectories.
+
+    Every trajectory is turned into context windows on its own through
+    :func:`traj_to_contexts`, so no window mixes frames of two trajectories.
+    The windows are stacked along the first axis and each entry of the index
+    map is prefixed with the position of its trajectory in ``trajectories``.
+
+    Args:
+        trajectories (ArrayLike): Array of shape ``(n_trajs, n_frames, *features_shape)``, or a non-empty sequence of trajectories.
+        context_length (int, optional): Length of the context window. Default to ``2``.
+        time_lag (int, optional): Time lag, i.e. stride, between successive context windows. Default to ``1``.
+        backend (str, optional): Specifies the backend to be used (``'numpy'``, ``'torch'``). If set to ``'auto'``, will use the same backend of the trajectory. Default to ``'auto'``.
+        backend_kw (dict, optional): Keyword arguments to pass to the backend. For example, if ``'torch'``, it is possible to specify the device of the tensor.
+
+    Returns:
+        data (TrajectoryContextDataset): A sequence of context windows aggregating all trajectories. Its ``idx_map`` gains a leading entry on the last axis holding the trajectory index.
+
+    Raises:
+        ShapeError: If ``trajectories`` is an array with fewer than 3 dimensions, or if it holds no trajectory.
+    """
+    # Plain sequences (e.g. a list of arrays) have no ``ndim``; only arrays and
+    # tensors go through the dimensionality check.
+    ndim = getattr(trajectories, "ndim", None)
+    if ndim is not None and ndim < 3:
+        raise ShapeError(
+            f"Invalid trajectories shape {trajectories.shape}. The trajectories should be at least 3D."
+        )
+    if len(trajectories) == 0:
+        raise ShapeError("At least one trajectory is required.")
+    torch, backend = parse_backend(backend)
+
+    # Every trajectory gets the same backend settings, not only the first one.
+    contexts = [
+        traj_to_contexts(
+            trajectory,
+            context_window_len=context_length,
+            time_lag=time_lag,
+            backend=backend,
+            **backend_kw,
+        )
+        for trajectory in trajectories
+    ]
+    data = contexts[0]
+
+    # ``data`` is a dataset object, not a raw tensor: inspect the array it holds.
+    use_torch = torch is not None and torch.is_tensor(data.data)
+    full_like = torch.full_like if use_torch else np.full_like
+
+    def _stack(chunks, axis):
+        if use_torch:
+            return torch.cat(chunks, dim=axis)
+        return np.concatenate(chunks, axis=axis)
+
+    # The index map may arrive wrapped in a TensorContextDataset.
+    # NOTE(review): assumes the wrapper exposes its array as ``.data`` - confirm.
+    wrapped = isinstance(data.idx_map, TensorContextDataset)
+    idx_maps = []
+    for traj_idx, ctx in enumerate(contexts):
+        time_idx = ctx.idx_map.data if wrapped else ctx.idx_map
+        # ``full_like`` keeps the dtype (and device) of the existing indices, so
+        # the merged index map is not promoted to floating point.
+        traj_col = full_like(time_idx, traj_idx)
+        idx_maps.append(_stack((traj_col, time_idx), axis=-1))
+
+    # One concatenation per array instead of re-copying the result at each step.
+    data.data = _stack([ctx.data for ctx in contexts], axis=0)
+    idx_map = _stack(idx_maps, axis=0)
+    data.idx_map = TensorContextDataset(idx_map) if wrapped else idx_map
+    return data
 
 
 def traj_to_contexts(