Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add function for generating TrajectoryContextDataset from a list of trajectories #12

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions kooplearn/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,59 @@ def _contexts_from_traj_np(trajectory, context_length, time_lag):
data = np.moveaxis(data, -1, 1)[:, ::time_lag, ...]
return data, TensorContextDataset(idx_map)

def get_multi_trajectorycontextdataset(
    trajectories: ArrayLike,
    context_length: int = 2,
    time_lag: int = 1,
    backend: str = "auto",
    **backend_kw,
):
    """
    Generates a TrajectoryContextDataset from multiple trajectories.
    It takes as input a batch of trajectories and returns a single sequence of context windows
    in which the windows of every trajectory are stacked along the first axis.

    Args:
        trajectories (ArrayLike): A batch of trajectories of shape ``(n_trajs, n_frames, *features_shape)``.
        context_length (int, optional): Length of the context window. Default to ``2``.
        time_lag (int, optional): Time lag, i.e. stride, between successive context windows. Default to ``1``.
        backend (str, optional): Specifies the backend to be used (``'numpy'``, ``'torch'``). If set to ``'auto'``, will use the same backend of the trajectory. Default to ``'auto'``.
        backend_kw (dict, optional): Keyword arguments to pass to the backend. For example, if ``'torch'``, it is possible to specify the device of the tensor.

    Returns:
        data (TrajectoryContextDataset): A sequence of context windows aggregating all trajectories.
        Its ``idx_map`` is extended along the last axis with the index of the source trajectory.

    Raises:
        ShapeError: If ``trajectories`` has fewer than 3 dimensions.
    """

    if trajectories.ndim < 3:
        raise ShapeError(
            f"Invalid trajectories shape {trajectories.shape}. The trajectories should be at least 3D."
        )
    torch, backend = parse_backend(backend)

    def _contexts(trajectory):
        # Every trajectory must go through the same backend / backend options,
        # otherwise the pieces cannot be concatenated afterwards.
        return traj_to_contexts(
            trajectory,
            context_window_len=context_length,
            time_lag=time_lag,
            backend=backend,
            **backend_kw,
        )

    # The dataset built from the first trajectory is reused as the container
    # for the aggregated result.
    data = _contexts(trajectories[0])

    # Inspect the underlying storage (not the dataset wrapper) to pick the
    # array library used for concatenation.
    if torch is not None and torch.is_tensor(data.data):
        dt = torch

        def concat(arrays, axis):
            # torch.cat is available on every PyTorch version.
            return torch.cat(tuple(arrays), dim=axis)

    else:
        dt = np

        def concat(arrays, axis):
            return np.concatenate(tuple(arrays), axis=axis)

    def _tag_with_traj_index(idx_map, traj_index):
        # Prepend, along the last axis, a block holding the trajectory index.
        # NOTE(review): the block uses the backend default dtype/device, as in
        # the previous implementation - confirm whether idx_map should stay
        # integer-typed and whether non-default devices need handling.
        traj_ids = dt.zeros(idx_map.shape) + traj_index
        return concat((traj_ids, idx_map), axis=-1)

    data_chunks = [data.data]
    idx_chunks = [_tag_with_traj_index(data.idx_map, 0)]

    for i in range(1, len(trajectories)):
        new_traj = _contexts(trajectories[i])
        data_chunks.append(new_traj.data)
        idx_chunks.append(_tag_with_traj_index(new_traj.idx_map, i))

    # Concatenate once at the end instead of re-copying the accumulated
    # arrays at every iteration.
    if len(data_chunks) > 1:
        data.data = concat(data_chunks, axis=0)
    data.idx_map = concat(idx_chunks, axis=0)

    return data


def traj_to_contexts(
Expand Down