diff --git a/calm/learning/hrl_agent.py b/calm/learning/hrl_agent.py index f0e9cde..fad2ebe 100644 --- a/calm/learning/hrl_agent.py +++ b/calm/learning/hrl_agent.py @@ -47,6 +47,7 @@ class HRLAgent(common_agent.CommonAgent): def __init__(self, base_name, config): with open(os.path.join(os.getcwd(), config['llc_config']), 'r') as f: llc_config = yaml.load(f, Loader=yaml.SafeLoader) + llc_config['params']['config']['device'] = config['device'] llc_config_params = llc_config['params'] self._latent_dim = llc_config_params['config']['latent_dim'] diff --git a/calm/run.py b/calm/run.py index d318520..621a3cb 100644 --- a/calm/run.py +++ b/calm/run.py @@ -247,6 +247,8 @@ def main(): # Create default directories for weights and statistics cfg_train['params']['config']['train_dir'] = args.output_path + cfg_train['params']['config']['device']=args.rl_device + if args.track: wandb.init(