Skip to content

Commit

Permalink
testing out base games
Browse files Browse the repository at this point in the history
no success
  • Loading branch information
arunim1 committed Aug 13, 2023
1 parent f25c227 commit bd989fc
Show file tree
Hide file tree
Showing 4 changed files with 664 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pax/conf/experiment/ipd/inf_mfos_v_nl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ naive:

# Logging setup
wandb:
entity: "ucl-dark"
entity: "arunim1"
project: ipd
group: '${agent1}-vs-${agent2}-parity'
name: run-seed-${seed}
Expand Down
2 changes: 1 addition & 1 deletion pax/conf/experiment/ipd/shaper_v_ppo_mem.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ runner: evo
top_k: 5
popsize: 1000
num_envs: 2
num_opps: 1
num_opps: 10
num_outer_steps: 100
num_inner_steps: 100
num_iters: 5000
Expand Down
25 changes: 16 additions & 9 deletions pax/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

# NOTE: THIS MUST BE DONE BEFORE IMPORTING JAX
# uncomment to debug multi-devices on CPU
# os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"
# from jax.config import config
# config.update('jax_disable_jit', True)
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"
from jax.config import config
config.update('jax_disable_jit', True)

import hydra
import jax.numpy as jnp
Expand Down Expand Up @@ -49,6 +49,7 @@
)
from pax.runners.runner_eval import EvalRunner
from pax.runners.runner_evo import EvoRunner
from pax.runners.runner_evodiff import DiffEvoRunner
from pax.runners.runner_marl import RLRunner
from pax.runners.runner_sarl import SARLRunner
from pax.runners.runner_ipditm_eval import IPDITMEvalRunner
Expand Down Expand Up @@ -169,7 +170,7 @@ def runner_setup(args, env, agents, save_dir, logger):
logger.info("Evaluating with ipditmEvalRunner")
return IPDITMEvalRunner(agents, env, save_dir, args)

if args.runner == "evo":
if args.runner in ["evo", "evodiff"]:
agent1, _ = agents
algo = args.es.algo
strategies = {"CMA_ES", "OpenES", "PGPE", "SimpleGA"}
Expand Down Expand Up @@ -254,10 +255,16 @@ def get_pgpe_strategy(agent):
strategy, es_params, param_reshaper = get_ga_strategy(agent1)

logger.info(f"Evolution Strategy: {algo}")

return EvoRunner(
agents, env, strategy, es_params, param_reshaper, save_dir, args
)
if args.runner == "evo":
return EvoRunner(
agents, env, strategy, es_params, param_reshaper, save_dir, args
)
elif args.runner == "evodiff":
return DiffEvoRunner(
agents, env, strategy, es_params, param_reshaper, save_dir, args
)
else:
raise ValueError(f"Unknown runner type {args.runner}")

elif args.runner == "rl":
logger.info("Training with RL Runner")
Expand Down Expand Up @@ -575,7 +582,7 @@ def main(args):

print(f"Number of Training Iterations: {args.num_iters}")

if args.runner == "evo":
if args.runner in ["evo", "evodiff"]:
runner.run_loop(env_params, agent_pair, args.num_iters, watchers)

elif args.runner == "rl":
Expand Down
Loading

0 comments on commit bd989fc

Please sign in to comment.