Skip to content

Commit

Permalink
Fix d4rl conversion
Browse files (browse the repository at this point in the history)
  • Loading branch information
takuseno committed Jul 24, 2021
1 parent 6ad6506 commit 8e141c0
Showing 1 changed file with 53 additions and 15 deletions.
68 changes: 53 additions & 15 deletions d3rlpy/datasets.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -230,23 +230,61 @@ def get_d4rl(
env = gym.make(env_name)
dataset = env.get_dataset()

observations = dataset["observations"][1:]
actions = dataset["actions"][1:]
rewards = dataset["rewards"][:-1]
terminals = np.logical_and(
dataset["terminals"][:-1], np.logical_not(dataset["timeouts"][:-1])
)
episode_terminals = np.logical_or(
dataset["terminals"][:-1], dataset["timeouts"][:-1]
)
episode_terminals[-1] = 1.0
observations = []
actions = []
rewards = []
terminals = []
episode_terminals = []
episode_step = 0
cursor = 0
dataset_size = dataset["observations"].shape[0]
while cursor < dataset_size:
# collect data for step=t
observation = dataset["observations"][cursor]
action = dataset["actions"][cursor]
if episode_step == 0:
reward = 0.0
else:
reward = dataset["rewards"][cursor - 1]

observations.append(observation)
actions.append(action)
rewards.append(reward)
terminals.append(0.0)

episode_step += 1

if dataset["timeouts"][cursor]:
# skip adding the last step
episode_terminals.append(1.0)
episode_step = 0
cursor += 1
continue
else:
episode_terminals.append(0.0)

if dataset["terminals"][cursor]:
# collect data for step=t+1
dummy_observation = observation.copy()
dummy_action = action.copy()
next_reward = dataset["rewards"][cursor]

# the last observation is rarely used
observations.append(dummy_observation)
actions.append(dummy_action)
rewards.append(next_reward)
terminals.append(1.0)
episode_terminals.append(1.0)
episode_step = 0

cursor += 1

mdp_dataset = MDPDataset(
observations=observations,
actions=actions,
rewards=rewards,
terminals=terminals,
episode_terminals=episode_terminals,
observations=np.array(observations, dtype=np.float32),
actions=np.array(actions, dtype=np.float32),
rewards=np.array(rewards, dtype=np.float32),
terminals=np.array(terminals, dtype=np.float32),
episode_terminals=np.array(episode_terminals, dtype=np.float32),
create_mask=create_mask,
mask_size=mask_size,
)
Expand Down

0 comments on commit 8e141c0

Please sign in to comment.