Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: update configs folder #1101

Merged
merged 9 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mava/advanced_usage/ff_ippo_store_experience.py
WiemKhlifi marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -670,11 +670,12 @@ def _reshape_experience(experience: Dict[str, chex.Array]) -> Dict[str, chex.Arr
logger.stop()


@hydra.main(config_path="../configs", config_name="default_ff_ippo.yaml", version_base="1.2")
@hydra.main(config_path="../configs/default", config_name="ff_ippo.yaml", version_base="1.2")
OmaymaMahjoub marked this conversation as resolved.
Show resolved Hide resolved
def hydra_entry_point(cfg: DictConfig) -> None:
"""Experiment entry point."""
# Allow dynamic attributes.
OmegaConf.set_struct(cfg, False)
cfg.logger.system_name = "ff_ippo"

# Run experiment.
run_experiment(cfg)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
defaults:
- logger: ff_ippo
- logger: logger
- arch: anakin
- system: ppo/ff_ippo
- network: mlp # [mlp, continuous_mlp, cnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax]
- _self_

hydra:
searchpath:
- file://mava/configs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
defaults:
- _self_
- logger: ff_isac
- logger: logger
- arch: anakin
- system: sac/ff_isac
- network: continuous_mlp # [continuous_mlp]
- env: mabrax # [mabrax]

hydra:
searchpath:
- file://mava/configs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
defaults:
- logger: ff_mappo
- logger: logger
- arch: anakin
- system: ppo/ff_mappo
- network: mlp # [mlp, continuous_mlp, cnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax]
- _self_

hydra:
searchpath:
- file://mava/configs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
defaults:
- _self_
- logger: ff_masac
- logger: logger
- arch: anakin
- system: sac/ff_masac
- network: continuous_mlp # [continuous_mlp]
- env: mabrax # [mabrax]

hydra:
searchpath:
- file://mava/configs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
defaults:
- logger: rec_ippo
- logger: logger
- arch: anakin
- system: ppo/rec_ippo
- network: rnn # [rnn, rcnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax]
- _self_

hydra:
searchpath:
- file://mava/configs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
defaults:
- _self_
- logger: rec_iql
- logger: logger
- arch: anakin
- system: q_learning/rec_iql
- network: rnn # [rnn, rcnn]
- env: smax # [cleaner, connector, gigastep, lbf, matrax, rware, smax]

hydra:
searchpath:
- file://mava/configs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
defaults:
- logger: rec_mappo
- logger: logger
- arch: anakin
- system: ppo/rec_mappo
- network: rnn # [rnn, rcnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax]
- _self_

hydra:
searchpath:
- file://mava/configs
4 changes: 0 additions & 4 deletions mava/configs/logger/ff_ippo.yaml

This file was deleted.

4 changes: 0 additions & 4 deletions mava/configs/logger/ff_isac.yaml

This file was deleted.

4 changes: 0 additions & 4 deletions mava/configs/logger/ff_mappo.yaml

This file was deleted.

4 changes: 0 additions & 4 deletions mava/configs/logger/ff_masac.yaml

This file was deleted.

4 changes: 0 additions & 4 deletions mava/configs/logger/rec_ippo.yaml

This file was deleted.

4 changes: 0 additions & 4 deletions mava/configs/logger/rec_iql.yaml

This file was deleted.

4 changes: 0 additions & 4 deletions mava/configs/logger/rec_mappo.yaml

This file was deleted.

7 changes: 6 additions & 1 deletion mava/systems/ppo/anakin/ff_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,11 +570,16 @@ def run_experiment(_config: DictConfig) -> float:
return eval_performance


@hydra.main(config_path="../../../configs", config_name="default_ff_ippo.yaml", version_base="1.2")
@hydra.main(
config_path="../../../configs/default",
config_name="ff_ippo.yaml",
version_base="1.2",
)
def hydra_entry_point(cfg: DictConfig) -> float:
"""Experiment entry point."""
# Allow dynamic attributes.
OmegaConf.set_struct(cfg, False)
cfg.logger.system_name = "ff_ippo"

# Run experiment.
eval_performance = run_experiment(cfg)
Expand Down
7 changes: 6 additions & 1 deletion mava/systems/ppo/anakin/ff_mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,11 +553,16 @@ def run_experiment(_config: DictConfig) -> float:
return eval_performance


@hydra.main(config_path="../../../configs", config_name="default_ff_mappo.yaml", version_base="1.2")
@hydra.main(
config_path="../../../configs/default",
config_name="ff_mappo.yaml",
version_base="1.2",
)
def hydra_entry_point(cfg: DictConfig) -> float:
"""Experiment entry point."""
# Allow dynamic attributes.
OmegaConf.set_struct(cfg, False)
cfg.logger.system_name = "ff_mappo"

# Run experiment.
eval_performance = run_experiment(cfg)
Expand Down
7 changes: 6 additions & 1 deletion mava/systems/ppo/anakin/rec_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,11 +721,16 @@ def run_experiment(_config: DictConfig) -> float:
return eval_performance


@hydra.main(config_path="../../../configs", config_name="default_rec_ippo.yaml", version_base="1.2")
@hydra.main(
config_path="../../../configs/default",
config_name="rec_ippo.yaml",
version_base="1.2",
)
def hydra_entry_point(cfg: DictConfig) -> float:
"""Experiment entry point."""
# Allow dynamic attributes.
OmegaConf.set_struct(cfg, False)
cfg.logger.system_name = "rec_ippo"

# Run experiment.
eval_performance = run_experiment(cfg)
Expand Down
5 changes: 4 additions & 1 deletion mava/systems/ppo/anakin/rec_mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,12 +716,15 @@ def run_experiment(_config: DictConfig) -> float:


@hydra.main(
config_path="../../../configs", config_name="default_rec_mappo.yaml", version_base="1.2"
config_path="../../../configs/default",
config_name="rec_mappo.yaml",
version_base="1.2",
)
def hydra_entry_point(cfg: DictConfig) -> float:
"""Experiment entry point."""
# Allow dynamic attributes.
OmegaConf.set_struct(cfg, False)
cfg.logger.system_name = "rec_mappo"

# Run experiment.
eval_performance = run_experiment(cfg)
Expand Down
7 changes: 6 additions & 1 deletion mava/systems/q_learning/anakin/rec_iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,11 +658,16 @@ def eval_act_fn(
return float(eval_performance)


@hydra.main(config_path="../../../configs", config_name="default_rec_iql.yaml", version_base="1.2")
@hydra.main(
config_path="../../../configs/default",
config_name="rec_iql.yaml",
version_base="1.2",
)
def hydra_entry_point(cfg: DictConfig) -> float:
"""Experiment entry point."""
# Allow dynamic attributes.
OmegaConf.set_struct(cfg, False)
cfg.logger.system_name = "rec_iql"

# Run experiment.
final_return = run_experiment(cfg)
Expand Down
7 changes: 6 additions & 1 deletion mava/systems/sac/anakin/ff_isac.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,11 +602,16 @@ def run_experiment(cfg: DictConfig) -> float:
return eval_performance


@hydra.main(config_path="../../../configs", config_name="default_ff_isac.yaml", version_base="1.2")
@hydra.main(
config_path="../../../configs/default",
config_name="ff_isac.yaml",
version_base="1.2",
)
def hydra_entry_point(cfg: DictConfig) -> float:
"""Experiment entry point."""
# Allow dynamic attributes.
OmegaConf.set_struct(cfg, False)
cfg.logger.system_name = "ff_isac"

# Run experiment.
final_return = run_experiment(cfg)
Expand Down
7 changes: 6 additions & 1 deletion mava/systems/sac/anakin/ff_masac.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,11 +621,16 @@ def run_experiment(cfg: DictConfig) -> float:
return eval_performance


@hydra.main(config_path="../../../configs", config_name="default_ff_masac.yaml", version_base="1.2")
@hydra.main(
config_path="../../../configs/default",
config_name="ff_masac.yaml",
version_base="1.2",
)
def hydra_entry_point(cfg: DictConfig) -> float:
"""Experiment entry point."""
# Allow dynamic attributes.
OmegaConf.set_struct(cfg, False)
cfg.logger.system_name = "ff_masac"

# Run experiment.
final_return = run_experiment(cfg)
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ id-marl-eval @ git+https://github.com/instadeepai/marl-eval
jax==0.4.30
jaxlib==0.4.30
jaxmarl
jumanji @ git+https://github.com/sash-a/jumanji # Includes a few extra MARL envs
jumanji @ git+https://github.com/sash-a/jumanji@old_jumanji # Includes a few extra MARL envs
matrax @ git+https://github.com/instadeepai/matrax
mujoco==3.1.3
mujoco-mjx==3.1.3
Expand Down
26 changes: 14 additions & 12 deletions test/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
# system run all envs, but each env and each system is run at least once.
# For each system we select a random environment to run.
# Then for each environment we select a random system to run.
config_path = "../mava/configs/default"

ppo_systems = [
"ppo.anakin.ff_ippo",
"ppo.anakin.ff_mappo",
Expand Down Expand Up @@ -67,8 +69,8 @@ def test_ppo_system(fast_config: dict, system_path: str) -> None:
_, _, system_name = system_path.split(".")
env = random.choice(discrete_envs)

with initialize(version_base=None, config_path="../mava/configs/"):
cfg = compose(config_name=f"default_{system_name}", overrides=[f"env={env}"])
with initialize(version_base=None, config_path=config_path):
cfg = compose(config_name=f"{system_name}", overrides=[f"env={env}"])
cfg = _get_fast_config(cfg, fast_config)

_run_system(system_path, cfg)
Expand All @@ -80,8 +82,8 @@ def test_q_learning_system(fast_config: dict, system_path: str) -> None:
_, _, system_name = system_path.split(".")
env = random.choice(discrete_envs)

with initialize(version_base=None, config_path="../mava/configs/"):
cfg = compose(config_name=f"default_{system_name}", overrides=[f"env={env}"])
with initialize(version_base=None, config_path=config_path):
cfg = compose(config_name=f"{system_name}", overrides=[f"env={env}"])
cfg = _get_fast_config(cfg, fast_config)

_run_system(system_path, cfg)
Expand All @@ -93,8 +95,8 @@ def test_sac_system(fast_config: dict, system_path: str) -> None:
_, _, system_name = system_path.split(".")
env = random.choice(continuous_envs)

with initialize(version_base=None, config_path="../mava/configs/"):
cfg = compose(config_name=f"default_{system_name}", overrides=[f"env={env}"])
with initialize(version_base=None, config_path=config_path):
cfg = compose(config_name=f"{system_name}", overrides=[f"env={env}"])
cfg = _get_fast_config(cfg, fast_config)

_run_system(system_path, cfg)
Expand All @@ -106,8 +108,8 @@ def test_discrete_env(fast_config: dict, env_name: str) -> None:
system_path = random.choice(ppo_systems + q_learning_systems)
_, _, system_name = system_path.split(".")

with initialize(version_base=None, config_path="../mava/configs/"):
cfg = compose(config_name=f"default_{system_name}", overrides=[f"env={env_name}"])
with initialize(version_base=None, config_path=config_path):
cfg = compose(config_name=f"{system_name}", overrides=[f"env={env_name}"])
cfg = _get_fast_config(cfg, fast_config)

_run_system(system_path, cfg)
Expand All @@ -122,8 +124,8 @@ def test_discrete_cnn_env(fast_config: dict, env_name: str) -> None:
network = "cnn" if "ff" in system_name else "rcnn"

overrides = [f"env={env_name}", f"network={network}"]
with initialize(version_base=None, config_path="../mava/configs/"):
cfg = compose(config_name=f"default_{system_name}", overrides=overrides)
with initialize(version_base=None, config_path=config_path):
cfg = compose(config_name=f"{system_name}", overrides=overrides)
cfg = _get_fast_config(cfg, fast_config)

_run_system(system_path, cfg)
Expand All @@ -138,8 +140,8 @@ def test_continuous_env(fast_config: dict, env_name: str) -> None:
_, _, system_name = system_path.split(".")

overrides = [f"env={env_name}", "network=continuous_mlp"]
with initialize(version_base=None, config_path="../mava/configs/"):
cfg = compose(config_name=f"default_{system_name}", overrides=overrides)
with initialize(version_base=None, config_path=config_path):
cfg = compose(config_name=f"{system_name}", overrides=overrides)
cfg = _get_fast_config(cfg, fast_config)

_run_system(system_path, cfg)
Loading