Accepted to ICLR 2023 (notable-top-25%, Spotlight) [arXiv] [Website]
If you use this codebase for your research, please cite the paper:
@inproceedings{furuta2023asystem,
  title={A System for Morphology-Task Generalization via Unified Representation and Behavior Distillation},
  author={Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo and Shixiang Shane Gu},
  booktitle={International Conference on Learning Representations},
  year={2023},
}
pip install -r requirements.txt
- Train a single-task, single-morphology PPO policy on the environment:
CUDA_VISIBLE_DEVICES=0 python train_ppo_mlp.py --logdir ../results --seed 0 --env ant_reach_4
- Pick the trained policy weights and collect expert brax.QP:
CUDA_VISIBLE_DEVICES=0,1 python generate_behavior_and_qp.py --seed 0 --env ant_reach_4 --task_name ant_reach --params_path ../results/ao_ppo_mlp_single_pro_ant_reach_4_20220707_174507/ppo_mlp_98304000.pkl
- Register qp_path (path to the saved brax.QP) in dataset_config.py.
- Convert brax.QP to a morphology-task graph representation (e.g. mtg_v2_base_m):
CUDA_VISIBLE_DEVICES=0 python generate_behavior_from_qp.py --seed 0 --env ant_reach_4 --task_name ant_reach --data_name ant_reach_4_mtg_v2_base_m --obs_config2 mtg_v2_base_m
- Register dataset_path (path to the saved observations) in dataset_config.py and task_config.py.
- Train a Transformer policy via multi-task behavior cloning:
CUDA_VISIBLE_DEVICES=0,1 python train_bc_transformer.py --task_name example --seed 0
# zero-shot evaluation
CUDA_VISIBLE_DEVICES=0,1 python train_bc_transformer_zs.py --task_name example --seed 0
# fine-tuning on multi-task imitation learning
CUDA_VISIBLE_DEVICES=0,1 python train_bc_transformer_fs.py --task_name example --seed 0 --params_path ../results/bc_transformer_zs/policy.pkl
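The data-generation and training steps above can be sketched as a small Python helper that assembles the same commands for a given environment. This is only an illustrative sketch: `build_pipeline` and its parameter names are hypothetical and not part of this codebase, the `CUDA_VISIBLE_DEVICES` prefixes are omitted, and the manual registration steps in dataset_config.py and task_config.py still have to be done by hand between commands.

```python
import shlex


def build_pipeline(env, task_name, data_name, obs_config,
                   ppo_params_path, bc_task_name, seed=0):
    """Return the README commands for one environment, in execution order.

    Hypothetical helper: it only formats strings and does not run anything.
    """
    common = f"--seed {seed} --env {env}"
    return [
        # 1. Train a single-task, single-morphology PPO policy.
        f"python train_ppo_mlp.py --logdir ../results {common}",
        # 2. Collect expert brax.QP with the trained weights.
        f"python generate_behavior_and_qp.py {common} --task_name {task_name} "
        f"--params_path {shlex.quote(ppo_params_path)}",
        # (register qp_path in dataset_config.py by hand before step 3)
        # 3. Convert brax.QP to the morphology-task graph representation.
        f"python generate_behavior_from_qp.py {common} --task_name {task_name} "
        f"--data_name {data_name} --obs_config2 {obs_config}",
        # (register dataset_path in dataset_config.py and task_config.py)
        # 4. Train the Transformer policy via multi-task behavior cloning.
        f"python train_bc_transformer.py --task_name {bc_task_name} --seed {seed}",
    ]


if __name__ == "__main__":
    for cmd in build_pipeline(
        env="ant_reach_4",
        task_name="ant_reach",
        data_name="ant_reach_4_mtg_v2_base_m",
        obs_config="mtg_v2_base_m",
        ppo_params_path="../results/my_ppo_run/ppo_mlp.pkl",  # placeholder path
        bc_task_name="example",
    ):
        print(cmd)
```

The PPO checkpoint path is passed in explicitly because the results directory name includes a run timestamp and cannot be known in advance.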
- Register a blueprint of new morphology in mxt_bench/procedural_envs/components (e.g. missing ant).
- If you are interested in custom agents, please follow unimal.py.
- See mxt_bench/procedural_envs/tasks. You need to prepare a dictionary of components (e.g. ant_reach), and register your task in register.py.
import functools

# load_desc is assumed to be defined earlier in the same task file.
ENV_DESCS = dict()
# add environments: ants with 2 to 6 legs
for i in range(2, 7, 1):
    ENV_DESCS[f'ant_reach_{i}'] = functools.partial(load_desc, num_legs=i)
    ENV_DESCS[f'ant_reach_hard_{i}'] = functools.partial(load_desc, num_legs=i, r_min=10.5, r_max=11.5)
# missing: ants with 3 to 6 legs, where leg j is the broken one
for i in range(3, 7, 1):
    for j in range(i):
        ENV_DESCS[f'ant_reach_{i}_b_{j}'] = functools.partial(load_desc, agent='broken_ant', num_legs=i, broken_id=j)
        ENV_DESCS[f'ant_reach_hard_{i}_b_{j}'] = functools.partial(load_desc, agent='broken_ant', num_legs=i, broken_id=j, r_min=10.5, r_max=11.5)
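To see what the registration loops above produce, the following self-contained sketch repeats them with a stand-in `load_desc`. The stand-in, its signature, and its default values are assumptions made for illustration only; the real `load_desc` builds a full environment description from the component dictionary.

```python
import functools


def load_desc(agent="ant", num_legs=4, broken_id=None, r_min=None, r_max=None):
    # Hypothetical stand-in: simply echoes its arguments back as a dict.
    return dict(agent=agent, num_legs=num_legs, broken_id=broken_id,
                r_min=r_min, r_max=r_max)


ENV_DESCS = dict()
for i in range(2, 7):
    ENV_DESCS[f"ant_reach_{i}"] = functools.partial(load_desc, num_legs=i)
    ENV_DESCS[f"ant_reach_hard_{i}"] = functools.partial(
        load_desc, num_legs=i, r_min=10.5, r_max=11.5)
for i in range(3, 7):
    for j in range(i):
        ENV_DESCS[f"ant_reach_{i}_b_{j}"] = functools.partial(
            load_desc, agent="broken_ant", num_legs=i, broken_id=j)
        ENV_DESCS[f"ant_reach_hard_{i}_b_{j}"] = functools.partial(
            load_desc, agent="broken_ant", num_legs=i, broken_id=j,
            r_min=10.5, r_max=11.5)

# 10 intact-ant entries plus 36 missing-leg entries.
print(len(ENV_DESCS))  # 46

# Each value is a zero-argument callable; the description is built on call.
print(ENV_DESCS["ant_reach_4_b_1"]())
```

Storing `functools.partial` objects instead of finished descriptions means nothing is built until an environment name is actually looked up and called.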
- If you would like to avoid immediate termination after the agent reaches the goal, please set min_dist=0 in each reward function dict.
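The effect of min_dist=0 can be illustrated with a minimal sketch. It assumes that an episode ends once the agent's distance to the goal drops strictly below `min_dist`; that rule, the `should_terminate` helper, and the example dict are hypothetical and are not taken from this codebase.

```python
def should_terminate(dist_to_goal: float, min_dist: float) -> bool:
    # Assumed rule: the episode ends when the agent is closer than min_dist.
    return dist_to_goal < min_dist


# Hypothetical reward function dict; only the min_dist key comes from the text.
reward_fn = dict(min_dist=0.0)

# A distance is never negative, so with min_dist=0 the check never fires.
for dist in (2.0, 0.3, 0.0):
    print(dist, should_terminate(dist, reward_fn["min_dist"]))
```

Under this assumption, a positive threshold such as 0.5 would end the episode at a distance of 0.3, while min_dist=0 lets the agent keep acting at the goal.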
- mxt_bench/algo: Algorithms for policy learning (PPO, BC), with support for MLP and Transformer policies.
- mxt_bench/models: NN architectures.
- mxt_bench/procedural_envs/components: Morphology.
- mxt_bench/procedural_envs/misc: Utility functions.
- mxt_bench/procedural_envs/tasks: Task.
- mxt_bench/procedural_envs/tasks/observation_config: Config files for morphology-task graph observations.