From 8e141c043db7a551875791c2c76db89cc140038f Mon Sep 17 00:00:00 2001
From: takuseno
Date: Sat, 24 Jul 2021 21:05:50 +0900
Subject: [PATCH] Fix d4rl conversion

---
 d3rlpy/datasets.py | 68 ++++++++++++++++++++++++++++++++++++----------
 1 file changed, 53 insertions(+), 15 deletions(-)

diff --git a/d3rlpy/datasets.py b/d3rlpy/datasets.py
index b9977803..179a7f27 100644
--- a/d3rlpy/datasets.py
+++ b/d3rlpy/datasets.py
@@ -230,23 +230,61 @@ def get_d4rl(
     env = gym.make(env_name)
     dataset = env.get_dataset()
 
-    observations = dataset["observations"][1:]
-    actions = dataset["actions"][1:]
-    rewards = dataset["rewards"][:-1]
-    terminals = np.logical_and(
-        dataset["terminals"][:-1], np.logical_not(dataset["timeouts"][:-1])
-    )
-    episode_terminals = np.logical_or(
-        dataset["terminals"][:-1], dataset["timeouts"][:-1]
-    )
-    episode_terminals[-1] = 1.0
+    observations = []
+    actions = []
+    rewards = []
+    terminals = []
+    episode_terminals = []
+    episode_step = 0
+    cursor = 0
+    dataset_size = dataset["observations"].shape[0]
+    while cursor < dataset_size:
+        # collect data for step=t
+        observation = dataset["observations"][cursor]
+        action = dataset["actions"][cursor]
+        if episode_step == 0:
+            reward = 0.0
+        else:
+            reward = dataset["rewards"][cursor - 1]
+
+        observations.append(observation)
+        actions.append(action)
+        rewards.append(reward)
+        terminals.append(0.0)
+
+        episode_step += 1
+
+        if dataset["timeouts"][cursor]:
+            # skip adding the last step
+            episode_terminals.append(1.0)
+            episode_step = 0
+            cursor += 1
+            continue
+        else:
+            episode_terminals.append(0.0)
+
+        if dataset["terminals"][cursor]:
+            # collect data for step=t+1
+            dummy_observation = observation.copy()
+            dummy_action = action.copy()
+            next_reward = dataset["rewards"][cursor]
+
+            # the last observation is rarely used
+            observations.append(dummy_observation)
+            actions.append(dummy_action)
+            rewards.append(next_reward)
+            terminals.append(1.0)
+            episode_terminals.append(1.0)
+            episode_step = 0
+
+        cursor += 1
 
     mdp_dataset = MDPDataset(
-        observations=observations,
-        actions=actions,
-        rewards=rewards,
-        terminals=terminals,
-        episode_terminals=episode_terminals,
+        observations=np.array(observations, dtype=np.float32),
+        actions=np.array(actions, dtype=np.float32),
+        rewards=np.array(rewards, dtype=np.float32),
+        terminals=np.array(terminals, dtype=np.float32),
+        episode_terminals=np.array(episode_terminals, dtype=np.float32),
         create_mask=create_mask,
         mask_size=mask_size,
     )
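Note (not part of the patch): below is a minimal, self-contained sketch that re-runs the new conversion loop on a hand-built toy dataset dict with the same keys D4RL exposes ("observations", "actions", "rewards", "terminals", "timeouts"), to show how a timeout ends an episode without setting a terminal flag while a true terminal duplicates the final step so its reward is kept. The toy values are invented for illustration only.

import numpy as np

# Toy stand-in for a D4RL dataset: 5 transitions; step index 2 ends with a
# timeout (episode truncated), step index 4 ends with a true terminal.
dataset = {
    "observations": np.arange(10, dtype=np.float32).reshape(5, 2),
    "actions": np.ones((5, 1), dtype=np.float32),
    "rewards": np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32),
    "terminals": np.array([0, 0, 0, 0, 1], dtype=bool),
    "timeouts": np.array([0, 0, 1, 0, 0], dtype=bool),
}

observations, actions, rewards = [], [], []
terminals, episode_terminals = [], []
episode_step = 0
cursor = 0
dataset_size = dataset["observations"].shape[0]
while cursor < dataset_size:
    observation = dataset["observations"][cursor]
    action = dataset["actions"][cursor]
    # rewards are shifted by one step: the reward stored at t is paid at t+1
    reward = 0.0 if episode_step == 0 else dataset["rewards"][cursor - 1]

    observations.append(observation)
    actions.append(action)
    rewards.append(reward)
    terminals.append(0.0)
    episode_step += 1

    if dataset["timeouts"][cursor]:
        # timeout: end the episode here; the timed-out step's reward is dropped
        episode_terminals.append(1.0)
        episode_step = 0
        cursor += 1
        continue
    episode_terminals.append(0.0)

    if dataset["terminals"][cursor]:
        # true terminal: duplicate the last step so its reward is kept
        observations.append(observation.copy())
        actions.append(action.copy())
        rewards.append(dataset["rewards"][cursor])
        terminals.append(1.0)
        episode_terminals.append(1.0)
        episode_step = 0

    cursor += 1

print(np.array(rewards))            # [0. 1. 2. 0. 4. 5.]
print(np.array(terminals))          # [0. 0. 0. 0. 0. 1.]
print(np.array(episode_terminals))  # [0. 0. 1. 0. 0. 1.]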