From d3402c050d5b53336d2c138649ccc8fc92760d54 Mon Sep 17 00:00:00 2001 From: nhuet Date: Tue, 10 Dec 2024 15:10:53 +0100 Subject: [PATCH] Use config files included in pyRDDLGym-jax install directory (#448) --- notebooks/16_rddl_tuto.ipynb | 8 +++++--- tests/solvers/python/test_pyrddlgym_solvers.py | 12 ++++-------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/notebooks/16_rddl_tuto.ipynb b/notebooks/16_rddl_tuto.ipynb index 4d10a10fc5..77656a9adf 100644 --- a/notebooks/16_rddl_tuto.ipynb +++ b/notebooks/16_rddl_tuto.ipynb @@ -85,6 +85,7 @@ "import os\n", "import shutil\n", "\n", + "import pyRDDLGym_jax.examples.configs\n", "from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator\n", "from pyRDDLGym_rl.core.env import SimplifiedActionRDDLEnv\n", "from ray.rllib.algorithms.ppo import PPO as RLLIB_PPO\n", @@ -472,8 +473,9 @@ "problem_info = manager.get_problem(problem_name)\n", "problem_visualizer = QuadcopterVisualizer\n", "\n", - "if not os.path.exists(\"Quadcopter_slp.cfg\"):\n", - " !wget https://raw.githubusercontent.com/pyrddlgym-project/pyRDDLGym-jax/main/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg\n", + "config_name = \"Quadcopter_slp.cfg\"\n", + "config_dir = pyRDDLGym_jax.examples.configs.__path__[0]\n", + "config_path = f\"{config_dir}/{config_name}\"\n", "\n", "domain_factory_jax_agent = lambda alg_name=None: RDDLDomain(\n", " rddl_domain=problem_info.get_domain(),\n", @@ -491,7 +493,7 @@ "\n", "logging.getLogger(\"matplotlib.font_manager\").disabled = True\n", "with RDDLJaxSolver(\n", - " domain_factory=domain_factory_jax_agent, config=\"Quadcopter_slp.cfg\"\n", + " domain_factory=domain_factory_jax_agent, config=config_path\n", ") as solver:\n", " solver.solve()\n", " rollout(\n", diff --git a/tests/solvers/python/test_pyrddlgym_solvers.py b/tests/solvers/python/test_pyrddlgym_solvers.py index 09123bc595..af8a43c9c7 100644 --- a/tests/solvers/python/test_pyrddlgym_solvers.py +++ b/tests/solvers/python/test_pyrddlgym_solvers.py @@ -2,6 +2,7 @@ import shutil from urllib.request import urlcleanup, urlretrieve +import pyRDDLGym_jax.examples.configs from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator from skdecide.hub.domain.rddl import RDDLDomain @@ -12,13 +13,8 @@ def test_pyrddlgymdomain_jax(): # get solver config config_name = "Cartpole_Continuous_gym_drp.cfg" - if not os.path.exists(config_name): - url = f"https://raw.githubusercontent.com/pyrddlgym-project/pyRDDLGym-jax/main/pyRDDLGym_jax/examples/configs/{config_name}" - try: - local_file_path, headers = urlretrieve(url) - shutil.move(local_file_path, config_name) - finally: - urlcleanup() + config_dir = pyRDDLGym_jax.examples.configs.__path__[0] + config_path = f"{config_dir}/{config_name}" # domain factory (with proper backend and vectorized flag) domain_factory = lambda: RDDLDomain( @@ -30,7 +26,7 @@ def test_pyrddlgymdomain_jax(): vectorized=True, ) solver_factory = lambda: RDDLJaxSolver( - domain_factory=domain_factory, config=config_name + domain_factory=domain_factory, config=config_path ) # solve