Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 20, 2023
1 parent c7d4764 commit 9e4053b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
16 changes: 12 additions & 4 deletions torchrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,23 @@

_init_extension()

DEFAULT_START_METHOD = os.environ.get("DEFAULT_START_METHOD", None)
if DEFAULT_START_METHOD is None:
if torch.cuda.device_count() == 0:
DEFAULT_START_METHOD = "fork"
else:
DEFAULT_START_METHOD = "spawn"
try:
mp.set_start_method("spawn")
mp.set_start_method(DEFAULT_START_METHOD)
except RuntimeError as err:
if str(err).startswith("context has already been set"):
mp_start_method = mp.get_start_method()
if mp_start_method != "spawn":
if mp_start_method != DEFAULT_START_METHOD:
warn(
f"failed to set start method to spawn, "
f"and current start method for mp is {mp_start_method}."
f"failed to set start method to {DEFAULT_START_METHOD}, which is the default on this node. "
f"The current start method for mp is {mp_start_method}. "
f"To change the default mp start method, set the 'DEFAULT_START_METHOD' environment variable to "
f"'fork' or 'spawn'. If the mp start method is set before importing torchrl, it cannot be changed. "
)


Expand Down
4 changes: 3 additions & 1 deletion torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
clear_mpi_env_vars,
)

from .. import DEFAULT_START_METHOD

# legacy
from .libs.envpool import MultiThreadedEnv, MultiThreadedEnvWrapper # noqa: F401

Expand Down Expand Up @@ -715,7 +717,7 @@ def _start_workers(self) -> None:

torch.set_num_threads(self.num_threads)

ctx = mp.get_context("spawn")
ctx = mp.get_context(DEFAULT_START_METHOD)

_num_workers = self.num_workers

Expand Down

0 comments on commit 9e4053b

Please sign in to comment.