fix register id problem #2

Open · wants to merge 3 commits into base: main
algos/agent.py (2 changes: 1 addition & 1 deletion)

@@ -98,7 +98,7 @@ def format_obs(self, obs, instr_dropout_prob=0):
         without_obs = [] if cutoff == 0 else [self.obs_preprocessor(obs[:cutoff], self.teacher, show_instrs=False)]
         with_obs = [] if cutoff == len(obs) else [self.obs_preprocessor(obs[cutoff:], self.teacher, show_instrs=True)]
         obs = without_obs + with_obs
-        obs = merge_dictlists(obs)
+        obs = merge_dictlists(obs, device=self.device)
         if self.state_encoder is not None:
             obs = self.state_encoder(obs)
         if self.task_encoder is not None:
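For context on the one-line change above: a minimal sketch of what a `merge_dictlists` helper with a `device` argument could look like. The repo's actual helper is defined elsewhere and may differ; the body below is an assumption for illustration only.

```python
# Hypothetical sketch of merge_dictlists -- not the repo's actual implementation.
import torch

def merge_dictlists(dictlists, device=None):
    """Merge a list of dicts of batched tensors into one dict,
    concatenating along the batch dimension and optionally moving
    each merged tensor to `device` so nothing straddles devices."""
    merged = {}
    for key in dictlists[0]:
        value = torch.cat([d[key] for d in dictlists], dim=0)
        if device is not None:
            value = value.to(device)
        merged[key] = value
    return merged
```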
algos/data_collector.py (14 changes: 10 additions & 4 deletions)

@@ -11,7 +11,7 @@
 class DataCollector(ABC):
     """The collection class."""
 
-    def __init__(self, collect_policy, envs, args, repeated_seed=None):
+    def __init__(self, collect_policy, envs, args, repeated_seed=None, device=None):
 
         if not args.sequential:
             self.env = ParallelEnv(envs, repeated_seed=repeated_seed)
@@ -22,7 +22,10 @@ def __init__(self, collect_policy, envs, args, repeated_seed=None):
 
 
         # Store helpers values
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if device is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        else:
+            self.device = device
         self.num_procs = len(envs)
         self.num_frames = self.args.frames_per_proc * self.num_procs
 
@@ -93,8 +96,9 @@ def collect_experiences(self, collect_with_oracle=False, collect_reward=True, tr
         """
         policy = self.policy
         policy.train(train)
-
+        ncount = 0
         for i in range(self.args.frames_per_proc):
+            ncount += 1
             with torch.no_grad():
                 action, agent_dict = policy.act(list(self.obs), sample=True,
                                                 instr_dropout_prob=self.args.collect_dropout_prob)
@@ -152,12 +156,14 @@
             self.log_episode_success += torch.tensor([e['success'] for e in env_info], device=self.device, dtype=torch.float)
             self.log_episode_reshaped_return += self.rewards[i]
             self.log_episode_num_frames += torch.ones(self.num_procs, device=self.device)
-
+            # print(f"self log_episode_success: {self.log_episode_success}")
             for i, done_ in enumerate(done):
                 if done_:
+                    # print(f"done on frame {ncount} index {i}")
                     self.log_done_counter += 1
                     self.log_return.append(self.log_episode_return[i].item())
                     self.log_success.append(self.log_episode_success[i].item())
+                    # print(f"self log_success: {self.log_success}")
                     if 'dist_to_goal' in env_info[i]:
                         self.log_dist_to_goal.append(env_info[i]['dist_to_goal'].item())
                     self.log_reshaped_return.append(self.log_episode_reshaped_return[i].item())
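The new `device` parameter lets the caller pin the collector to a specific device instead of always grabbing CUDA when it is available. A minimal standalone restatement of the fallback logic this diff adds (`resolve_device` is a made-up name for illustration):

```python
import torch

def resolve_device(device=None):
    """Honor an explicit device; otherwise fall back to CUDA when available,
    which is the default DataCollector used before this PR."""
    if device is None:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(device)

print(resolve_device())       # e.g. cuda on a GPU machine, cpu otherwise
print(resolve_device("cpu"))  # explicit override, e.g. to keep collection off the GPU
```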
algos/mf_trainer.py (3 changes: 3 additions & 0 deletions)

@@ -176,12 +176,14 @@ def train(self):
         self.should_collect = self.args.collect_teacher is not None
         self.should_train_rl = self.args.rl_teacher is not None
         if self.should_collect:
+            logger.log("collect data...")
             # Collect if we are distilling OR if we're not skipping
             samples_data, episode_logs = self.sampler.collect_experiences(
                 collect_with_oracle=self.args.collect_with_oracle,
                 collect_reward=self.should_train_rl,
                 train=self.should_train_rl)
             if self.relabel_policy is not None:
+                logger.log("relabel samples ...")
                 samples_data = self.relabel(samples_data)
             buffer_start = time.time()
             self.buffer.add_batch(samples_data, save=self.itr % 200 == 0)
@@ -199,6 +201,7 @@
             logger.log("RL Training...")
             for _ in range(self.args.epochs):
                 if self.args.on_policy:
+                    logger.log("on_policy data...")
                     sampled_batch = samples_data
                 else:
                     sampled_batch = self.buffer.sample(total_num_samples=self.args.batch_size, split='train')
envs/d4rl/d4rl_content/gym_bullet/__init__.py (6 changes: 6 additions & 0 deletions)

@@ -1,7 +1,13 @@
 from gym.envs.registration import register
 from envs.d4rl.d4rl_content.gym_bullet import gym_envs
 from envs.d4rl.d4rl_content import infos
+import gym
 
+env_dict = gym.envs.registration.registry.env_specs.copy()
+for env in env_dict:
+    if 'bullet' in env:
+        print('Remove {} from registry'.format(env))
+        del gym.envs.registration.registry.env_specs[env]
 
 for agent in ['hopper', 'halfcheetah', 'ant', 'walker2d']:
     register(
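This change and the two `__init__.py` changes below all apply the same fix for the PR's register-id problem: in classic `gym`, calling `register()` with an id that is already in the registry raises `gym.error.Error: Cannot re-register id`, which happens whenever these modules are imported more than once. Copying `registry.env_specs` and deleting the stale entries first makes registration idempotent. A standalone sketch of the pattern (`MyEnv-v0` and its entry point are made up for illustration; the PR filters on substrings such as 'bullet', 'antmaze', and 'maze2d' instead):

```python
# Standalone sketch of the de-register-then-register pattern used in this PR.
import gym
from gym.envs.registration import register

def register_fresh(env_id, entry_point, **kwargs):
    """Drop any stale spec for env_id, then register it, so a repeated
    import does not raise gym.error.Error: Cannot re-register id."""
    env_specs = gym.envs.registration.registry.env_specs
    if env_id in env_specs:
        print('Remove {} from registry'.format(env_id))
        del env_specs[env_id]
    register(id=env_id, entry_point=entry_point, **kwargs)

register_fresh('MyEnv-v0', 'my_package.envs:MyEnv')
register_fresh('MyEnv-v0', 'my_package.envs:MyEnv')  # idempotent: no error
```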
envs/d4rl/d4rl_content/locomotion/__init__.py (6 changes: 6 additions & 0 deletions)

@@ -1,3 +1,4 @@
+import gym
 from gym.envs.registration import register
 from envs.d4rl.d4rl_content.locomotion import ant
 from envs.d4rl.d4rl_content.locomotion import maze_env
@@ -19,6 +20,11 @@
     }
 )
 """
+env_dict = gym.envs.registration.registry.env_specs.copy()
+for env in env_dict:
+    if 'antmaze' in env:
+        print('Remove {} from registry'.format(env))
+        del gym.envs.registration.registry.env_specs[env]
 
 register(
     id='antmaze-open-v0',
envs/d4rl/d4rl_content/pointmaze/__init__.py (9 changes: 8 additions & 1 deletion)

@@ -1,6 +1,13 @@
 from .maze_model import MazeEnv, OPEN, U_MAZE, MEDIUM_MAZE, LARGE_MAZE, U_MAZE_EVAL, MEDIUM_MAZE_EVAL, LARGE_MAZE_EVAL, TWELVE, FIFTEEN
 from gym.envs.registration import register
-
+import gym
+
+env_dict = gym.envs.registration.registry.env_specs.copy()
+for env in env_dict:
+    if 'maze2d' in env:
+        print('Remove {} from registry'.format(env))
+        del gym.envs.registration.registry.env_specs[env]
+
 register(
     id='maze2d-open-v0',
     entry_point='envs.d4rl.d4rl_content.pointmaze:MazeEnv',