Commit

used flags.level_name
Signed-off-by: Shakti Kumar <[email protected]>
shaktikshri committed Mar 25, 2020
1 parent f83660b commit b4c8536
Showing 1 changed file with 43 additions and 43 deletions.
torchbeast/monobeast.py: 86 changes (43 additions & 43 deletions)
@@ -22,6 +22,7 @@
import traceback
import typing
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

os.environ["OMP_NUM_THREADS"] = "1" # Necessary for multithreading.
@@ -37,7 +38,6 @@
from torchbeast.core import prof
from torchbeast.core import vtrace


# yapf: disable
parser = argparse.ArgumentParser(description="PyTorch Scalable Agent")

@@ -70,6 +70,8 @@
help="Disable CUDA.")
parser.add_argument("--use_lstm", action="store_true",
help="Use LSTM in agent model.")
parser.add_argument("--level_name", type=str, default="rooms_collect_good_objects_train",
help="dmlab30 level name")

# Loss settings.
parser.add_argument("--entropy_cost", default=0.0006,
@@ -126,23 +128,24 @@ def compute_policy_gradient_loss(logits, actions, advantages):
cross_entropy = cross_entropy.view_as(advantages)
return torch.sum(cross_entropy * advantages.detach())


######changes function input with level name
def act(
flags,
actor_index: int,
free_queue: mp.SimpleQueue,
full_queue: mp.SimpleQueue,
model: torch.nn.Module,
buffers: Buffers,
initial_agent_state_buffers,
level_name
flags,
actor_index: int,
free_queue: mp.SimpleQueue,
full_queue: mp.SimpleQueue,
model: torch.nn.Module,
buffers: Buffers,
initial_agent_state_buffers,
level_name
):
try:
logging.info("Actor %i started.", actor_index)
timings = prof.Timings() # Keep track of how fast things are.
seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
######changed next line
gym_env = create_env(flags,level_name,seed)
gym_env = create_env(flags, level_name, seed)
env = environment.Environment(gym_env)
env_output = env.initial()
agent_state = model.initial_state(batch_size=1)
@@ -169,8 +172,6 @@ def act(

timings.time("model")



env_output = env.step(agent_output["action"])

timings.time("step")
@@ -196,13 +197,13 @@


def get_batch(
flags,
free_queue: mp.SimpleQueue,
full_queue: mp.SimpleQueue,
buffers: Buffers,
initial_agent_state_buffers,
timings,
lock=threading.Lock(),
flags,
free_queue: mp.SimpleQueue,
full_queue: mp.SimpleQueue,
buffers: Buffers,
initial_agent_state_buffers,
timings,
lock=threading.Lock(),
):
with lock:
timings.time("lock")
@@ -228,14 +229,14 @@ def get_batch(


def learn(
flags,
actor_model,
model,
batch,
initial_agent_state,
optimizer,
scheduler,
lock=threading.Lock(), # noqa: B008
flags,
actor_model,
model,
batch,
initial_agent_state,
optimizer,
scheduler,
lock=threading.Lock(), # noqa: B008
):
"""Performs a learning (optimization) step."""
with lock:
@@ -319,8 +320,9 @@ def create_buffers(flags, obs_shape, num_actions) -> Buffers:
buffers[key].append(torch.empty(**specs[key]).share_memory_())
return buffers


####changed function input
def train(flags,level_names): # pylint: disable=too-many-branches, too-many-statements
def train(flags): # pylint: disable=too-many-branches, too-many-statements
if flags.xpid is None:
flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
plogger = file_writer.FileWriter(
@@ -348,7 +350,7 @@ def train(flags,level_names): # pylint: disable=too-many-branches, too-many-statements
logging.info("Not using CUDA.")
flags.device = torch.device("cpu")
######I changed the next two line with level name and env.initial
env = create_env(flags,level_names[0],1)
env = create_env(flags, flags.level_name, 1)
model = Net(env.initial().shape, len(environment.DEFAULT_ACTION_SET), flags.use_lstm)
buffers = create_buffers(flags, env._observation().shape, model.num_actions)
print(env._observation().shape)
@@ -368,7 +370,6 @@ def train(flags,level_names): # pylint: disable=too-many-branches, too-many-statements
full_queue = ctx.SimpleQueue()
##########I changed this part
for i in range(flags.num_actors):
level_name= level_names[i%len(level_names)]
actor = ctx.Process(
target=act,
args=(
@@ -379,14 +380,14 @@ def train(flags,level_names): # pylint: disable=too-many-branches, too-many-statements
model,
buffers,
initial_agent_state_buffers,
level_name
flags.level_name
),
)
actor.start()
actor_processes.append(actor)

learner_model = Net(
env._observation().shape,len(environment.DEFAULT_ACTION_SET), flags.use_lstm
env._observation().shape, len(environment.DEFAULT_ACTION_SET), flags.use_lstm
).to(device=flags.device)

optimizer = torch.optim.RMSprop(
@@ -481,7 +482,7 @@ def checkpoint():
sps = (step - start_step) / (timer() - start_time)
if stats.get("episode_returns", None):
mean_return = (
"Return per episode: %.1f. " % stats["mean_episode_return"]
"Return per episode: %.1f. " % stats["mean_episode_return"]
)
else:
mean_return = ""
@@ -510,15 +511,15 @@ def checkpoint():
plogger.close()


def test(flags, level_names, num_episodes: int = 10):
def test(flags, num_episodes: int = 10):
if flags.xpid is None:
checkpointpath = "./latest/model.tar"
else:
checkpointpath = os.path.expandvars(
os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
)

gym_env = create_env(flags,level_names[0],1)
gym_env = create_env(flags, flags.level_name, 1)
env = environment.Environment(gym_env)
model = Net(gym_env._observation().shape, len(environment.DEFAULT_ACTION_SET), flags.use_lstm)
model.eval()
@@ -529,8 +530,8 @@ def test(flags, level_names, num_episodes: int = 10):
returns = []

while len(returns) < num_episodes:
#if flags.mode == "test_render":
#env.gym_env.render()
# if flags.mode == "test_render":
# env.gym_env.render()
agent_outputs = model(observation)
policy_outputs, _ = agent_outputs
observation = env.step(policy_outputs["action"])
@@ -639,8 +640,9 @@ def forward(self, inputs, core_state=()):

Net = AtariNet


#######I changed the create_env function to match dmlab
def create_env(flags,level_name,seed=1):
def create_env(flags, level_name, seed=1):
level_name = 'contributed/dmlab30/' + level_name
config = {
'width': 96,
@@ -649,15 +651,13 @@
}
return dmlab_wrappers.createDmLab(level_name, config, seed)


####changed main to incorporate level_names and action_Set
def main(flags):

if flags.mode == "train":
level_names = list(dmlab30.LEVEL_MAPPING.keys())
train(flags,level_names)
train(flags)
else:
level_names = list(dmlab30.LEVEL_MAPPING.values())
test(flags,level_names)
test(flags)


if __name__ == "__main__":

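For reference, below is a minimal, self-contained sketch of the argument flow this commit sets up: a single `--level_name` flag is parsed once and threaded into `create_env`, which prefixes it with the `contributed/dmlab30/` path. The `parser`, `create_env`, and `train` names mirror the diff, but the bodies are simplified stand-ins (the real `create_env` calls `dmlab_wrappers.createDmLab` and the real `train` spawns actor processes); the dictionary return value and the single-key `config` are assumptions made so the sketch runs without DeepMind Lab installed.

```python
# Minimal sketch of the flag flow after this commit. Simplified stand-ins,
# not the real monobeast.py: the actual create_env calls
# dmlab_wrappers.createDmLab and train() spawns actor processes.
import argparse

parser = argparse.ArgumentParser(description="PyTorch Scalable Agent")
parser.add_argument("--level_name", type=str,
                    default="rooms_collect_good_objects_train",
                    help="dmlab30 level name")


def create_env(flags, level_name, seed=1):
    # As in the diff: the bare level name is prefixed with the dmlab30 path.
    full_name = "contributed/dmlab30/" + level_name
    config = {"width": 96}  # only the width is visible in the shown hunk
    # Stand-in return value so the sketch runs without DeepMind Lab installed.
    return {"level": full_name, "config": config, "seed": seed}


def train(flags):
    # Every actor now receives the same flags.level_name instead of cycling
    # through a level_names list, so train() only needs the parsed flags.
    env = create_env(flags, flags.level_name, seed=1)
    print("training on", env["level"])


if __name__ == "__main__":
    train(parser.parse_args([]))  # defaults to rooms_collect_good_objects_train
```

Passing a different value, e.g. `--level_name explore_goal_locations_small`, selects another dmlab30 level at launch without touching the training code.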