From 452540c8fd7d22c23f48d481fc16019e1c7241c6 Mon Sep 17 00:00:00 2001 From: lgd21356 <59635384+lgd21356@users.noreply.github.com> Date: Sun, 22 Sep 2024 00:49:55 +0900 Subject: [PATCH 1/2] Update run.py To fix the issue that can not select the training device correctly --- calm/run.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/calm/run.py b/calm/run.py index d318520..621a3cb 100644 --- a/calm/run.py +++ b/calm/run.py @@ -247,6 +247,8 @@ def main(): # Create default directories for weights and statistics cfg_train['params']['config']['train_dir'] = args.output_path + cfg_train['params']['config']['device']=args.rl_device + if args.track: wandb.init( From ea4f0191e1a026c5ec963fcd331d9ec63cedbb51 Mon Sep 17 00:00:00 2001 From: lgd21356 <59635384+lgd21356@users.noreply.github.com> Date: Sun, 22 Sep 2024 00:54:29 +0900 Subject: [PATCH 2/2] Update hrl_agent.py --- calm/learning/hrl_agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/calm/learning/hrl_agent.py b/calm/learning/hrl_agent.py index f0e9cde..fad2ebe 100644 --- a/calm/learning/hrl_agent.py +++ b/calm/learning/hrl_agent.py @@ -47,6 +47,7 @@ class HRLAgent(common_agent.CommonAgent): def __init__(self, base_name, config): with open(os.path.join(os.getcwd(), config['llc_config']), 'r') as f: llc_config = yaml.load(f, Loader=yaml.SafeLoader) + llc_config['params']['config']['device'] = config['device'] llc_config_params = llc_config['params'] self._latent_dim = llc_config_params['config']['latent_dim']