Skip to content

Commit

Permalink
Prevent using fork if possible for async environments
Browse files Browse the repository at this point in the history
Prevents possible interference with JAX, see: jax-ml/jax#18852
  • Loading branch information
nico-bohlinger committed Feb 25, 2024
1 parent 0a0d383 commit 7f64cae
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,18 @@

class AsyncVectorEnvWithSkipping(gym.vector.AsyncVectorEnv):
def __init__(self, env_fns, async_skip_percentage=0.0,
observation_space=None, action_space=None, shared_memory=True, copy=True, context=None, daemon=True, worker=None):
super().__init__(env_fns, observation_space, action_space, shared_memory, copy, context, daemon, worker)
observation_space=None, action_space=None, shared_memory=True, copy=True, start_method=None, daemon=True, worker=None):

if start_method is None:
all_start_methods = mp.get_all_start_methods()
start_method = mp.get_start_method()
# Only use fork if it is the only available start method to prevent interference with JAX
if start_method == "fork":
if "forkserver" in all_start_methods:
start_method = "forkserver"
elif "spawn" in all_start_methods:
start_method = "spawn"
super().__init__(env_fns, observation_space, action_space, shared_memory, copy, start_method, daemon, worker)

if not shared_memory:
raise NotImplementedError("AsyncVectorEnvWithSkipping only supports shared_memory=True.")
Expand Down
14 changes: 12 additions & 2 deletions rl_x/environments/custom_mujoco/ant/async_vectorized_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,18 @@

class AsyncVectorEnvWithSkipping(gym.vector.AsyncVectorEnv):
def __init__(self, env_fns, async_skip_percentage=0.0,
observation_space=None, action_space=None, shared_memory=True, copy=True, context=None, daemon=True, worker=None):
super().__init__(env_fns, observation_space, action_space, shared_memory, copy, context, daemon, worker)
observation_space=None, action_space=None, shared_memory=True, copy=True, start_method=None, daemon=True, worker=None):

if start_method is None:
all_start_methods = mp.get_all_start_methods()
start_method = mp.get_start_method()
# Only use fork if it is the only available start method to prevent interference with JAX
if start_method == "fork":
if "forkserver" in all_start_methods:
start_method = "forkserver"
elif "spawn" in all_start_methods:
start_method = "spawn"
super().__init__(env_fns, observation_space, action_space, shared_memory, copy, start_method, daemon, worker)

if not shared_memory:
raise NotImplementedError("AsyncVectorEnvWithSkipping only supports shared_memory=True.")
Expand Down
14 changes: 12 additions & 2 deletions rl_x/environments/gym/atari/pong_v5/async_vectorized_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,18 @@

class AsyncVectorEnvWithSkipping(gym.vector.AsyncVectorEnv):
def __init__(self, env_fns, async_skip_percentage=0.0,
observation_space=None, action_space=None, shared_memory=True, copy=True, context=None, daemon=True, worker=None):
super().__init__(env_fns, observation_space, action_space, shared_memory, copy, context, daemon, worker)
observation_space=None, action_space=None, shared_memory=True, copy=True, start_method=None, daemon=True, worker=None):

if start_method is None:
all_start_methods = mp.get_all_start_methods()
start_method = mp.get_start_method()
# Only use fork if it is the only available start method to prevent interference with JAX
if start_method == "fork":
if "forkserver" in all_start_methods:
start_method = "forkserver"
elif "spawn" in all_start_methods:
start_method = "spawn"
super().__init__(env_fns, observation_space, action_space, shared_memory, copy, start_method, daemon, worker)

if not shared_memory:
raise NotImplementedError("AsyncVectorEnvWithSkipping only supports shared_memory=True.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,18 @@

class AsyncVectorEnvWithSkipping(gym.vector.AsyncVectorEnv):
def __init__(self, env_fns, async_skip_percentage=0.0,
observation_space=None, action_space=None, shared_memory=True, copy=True, context=None, daemon=True, worker=None):
super().__init__(env_fns, observation_space, action_space, shared_memory, copy, context, daemon, worker)
observation_space=None, action_space=None, shared_memory=True, copy=True, start_method=None, daemon=True, worker=None):

if start_method is None:
all_start_methods = mp.get_all_start_methods()
start_method = mp.get_start_method()
# Only use fork if it is the only available start method to prevent interference with JAX
if start_method == "fork":
if "forkserver" in all_start_methods:
start_method = "forkserver"
elif "spawn" in all_start_methods:
start_method = "spawn"
super().__init__(env_fns, observation_space, action_space, shared_memory, copy, start_method, daemon, worker)

if not shared_memory:
raise NotImplementedError("AsyncVectorEnvWithSkipping only supports shared_memory=True.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,18 @@

class AsyncVectorEnvWithSkipping(gym.vector.AsyncVectorEnv):
def __init__(self, env_fns, async_skip_percentage=0.0,
observation_space=None, action_space=None, shared_memory=True, copy=True, context=None, daemon=True, worker=None):
super().__init__(env_fns, observation_space, action_space, shared_memory, copy, context, daemon, worker)
observation_space=None, action_space=None, shared_memory=True, copy=True, start_method=None, daemon=True, worker=None):

if start_method is None:
all_start_methods = mp.get_all_start_methods()
start_method = mp.get_start_method()
# Only use fork if it is the only available start method to prevent interference with JAX
if start_method == "fork":
if "forkserver" in all_start_methods:
start_method = "forkserver"
elif "spawn" in all_start_methods:
start_method = "spawn"
super().__init__(env_fns, observation_space, action_space, shared_memory, copy, start_method, daemon, worker)

if not shared_memory:
raise NotImplementedError("AsyncVectorEnvWithSkipping only supports shared_memory=True.")
Expand Down

0 comments on commit 7f64cae

Please sign in to comment.