Commit

[Feature] step_and_maybe_reset in env (#1611)

vmoens authored Oct 24, 2023
1 parent 8d2bc8b commit 3b355dd
Showing 41 changed files with 2,595 additions and 1,110 deletions.
1 change: 1 addition & 0 deletions .github/unittest/linux/scripts/environment.yml
@@ -16,6 +16,7 @@ dependencies:
- pytest-mock
- pytest-instafail
- pytest-rerunfailures
- pytest-timeout
- expecttest
- pyyaml
- scipy
6 changes: 4 additions & 2 deletions .github/unittest/linux/scripts/run_all.sh
@@ -184,10 +184,12 @@ pytest test/smoke_test.py -v --durations 200
pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb'
if [ "${CU_VERSION:-}" != cpu ] ; then
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
--timeout=120
else
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py \
--timeout=120
fi

coverage combine
8 changes: 6 additions & 2 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -37,13 +37,17 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
logger.backend=
logger.backend= \
env.backend=gymnasium \
env.name=HalfCheetah-v4
python .github/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_dt.py \
optim.pretrain_gradient_steps=55 \
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
logger.backend=
logger.backend= \
env.backend=gymnasium \
env.name=HalfCheetah-v4

# ==================================================================================== #
# ================================ Gymnasium ========================================= #
82 changes: 38 additions & 44 deletions benchmarks/ecosystem/gym_env_throughput.py
@@ -16,7 +16,8 @@
"""
import time

import myosuite # noqa: F401
# import myosuite # noqa: F401
import torch
import tqdm
from torchrl._utils import timeit
from torchrl.collectors import (
@@ -29,6 +30,10 @@
from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend

if __name__ == "__main__":
avail_devices = ("cpu",)
if torch.cuda.device_count():
avail_devices = avail_devices + ("cuda:0",)

for envname in [
"CartPole-v1",
"HalfCheetah-v4",
@@ -69,24 +74,25 @@ def make(envname=envname, gym_backend=gym_backend):
log.flush()

# regular parallel env
for device in (
"cuda:0",
"cpu",
):
for device in avail_devices:

def make(envname=envname, gym_backend=gym_backend, device=device):
with set_gym_backend(gym_backend):
return GymEnv(envname, device=device)

env_make = EnvCreator(make)
penv = ParallelEnv(num_workers, env_make)
# warmup
penv.rollout(2)
pbar = tqdm.tqdm(total=num_workers * 10_000)
t0 = time.time()
for _ in range(100):
data = penv.rollout(100, break_when_any_done=False)
pbar.update(100 * num_workers)
# env_make = EnvCreator(make)
penv = ParallelEnv(num_workers, EnvCreator(make))
with torch.inference_mode():
# warmup
penv.rollout(2)
pbar = tqdm.tqdm(total=num_workers * 10_000)
t0 = time.time()
data = None
for _ in range(100):
data = penv.rollout(
100, break_when_any_done=False, out=data
)
pbar.update(100 * num_workers)
log.write(
f"penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n"
)
@@ -95,7 +101,7 @@ def make(envname=envname, gym_backend=gym_backend, device=device):
timeit.print()
del penv

for device in ("cuda:0", "cpu"):
for device in avail_devices:

def make(envname=envname, gym_backend=gym_backend, device=device):
with set_gym_backend(gym_backend):
@@ -109,29 +115,26 @@ def make(envname=envname, gym_backend=gym_backend, device=device):
RandomPolicy(penv.action_spec),
frames_per_batch=1024,
total_frames=num_workers * 10_000,
device=device,
storing_device=device,
)
pbar = tqdm.tqdm(total=num_workers * 10_000)
total_frames = 0
for i, data in enumerate(collector):
if i == num_collectors:
t0 = time.time()
if i >= num_collectors:
total_frames += data.numel()
pbar.update(data.numel())
pbar.set_description(
f"single collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
)
t0 = time.time()
for data in collector:
total_frames += data.numel()
pbar.update(data.numel())
pbar.set_description(
f"single collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
)
log.write(
f"single collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
)
log.flush()
collector.shutdown()
del collector

for device in (
"cuda:0",
"cpu",
):
for device in avail_devices:
# gym parallel env
def make_env(
envname=envname,
@@ -158,10 +161,7 @@ def make_env(
penv.close()
del penv

for device in (
"cuda:0",
"cpu",
):
for device in avail_devices:
# async collector
# + torchrl parallel env
def make_env(
@@ -179,6 +179,7 @@ def make_env(
frames_per_batch=1024,
total_frames=num_workers * 10_000,
device=device,
storing_device=device,
)
pbar = tqdm.tqdm(total=num_workers * 10_000)
total_frames = 0
@@ -198,10 +199,7 @@ def make_env(
collector.shutdown()
del collector

for device in (
"cuda:0",
"cpu",
):
for device in avail_devices:
# async collector
# + gym async env
def make_env(
@@ -226,6 +224,7 @@ def make_env(
total_frames=num_workers * 10_000,
num_sub_threads=num_workers // num_collectors,
device=device,
storing_device=device,
)
pbar = tqdm.tqdm(total=num_workers * 10_000)
total_frames = 0
@@ -245,10 +244,7 @@ def make_env(
collector.shutdown()
del collector

for device in (
"cuda:0",
"cpu",
):
for device in avail_devices:
# sync collector
# + torchrl parallel env
def make_env(
@@ -266,6 +262,7 @@ def make_env(
frames_per_batch=1024,
total_frames=num_workers * 10_000,
device=device,
storing_device=device,
)
pbar = tqdm.tqdm(total=num_workers * 10_000)
total_frames = 0
@@ -285,10 +282,7 @@ def make_env(
collector.shutdown()
del collector

for device in (
"cuda:0",
"cpu",
):
for device in avail_devices:
# sync collector
# + gym async env
def make_env(
107 changes: 106 additions & 1 deletion docs/source/reference/envs.rst
@@ -58,6 +58,23 @@ With these, the following methods are implemented:
- :meth:`env.step`: a step method that takes a :class:`tensordict.TensorDict` input
containing an input action as well as other inputs (for model-based or stateless
environments, for instance).
- :meth:`env.step_and_maybe_reset`: executes a step and (partially) resets the
environments if needed. It returns the updated input with a ``"next"``
key containing the data of the next step, as well as a tensordict containing
the input data for the next step (i.e., the output of a reset or of
:func:`~torchrl.envs.utils.step_mdp`).
This is done by reading the ``done_keys`` and
assigning a ``"_reset"`` signal to each done state. This method makes it
possible to write non-stopping rollout functions with little effort:

>>> data_ = env.reset()
>>> result = []
>>> for i in range(N):
... data, data_ = env.step_and_maybe_reset(data_)
... result.append(data)
...
>>> result = torch.stack(result)

- :meth:`env.set_seed`: a seeding method that will return the next seed
to be used in a multi-env setting. This next seed is deterministically computed
from the preceding one, such that one can seed multiple environments with a different
@@ -169,7 +186,95 @@ one can simply call:
>>> print(a)
9.81
It is also possible to reset some but not all of the environments:
TorchRL uses a private ``"_reset"`` key to indicate to the environment which
components (sub-environments or agents) should be reset.
This allows resetting some but not all of the components.

The ``"_reset"`` key has two distinct functionalities:
1. During a call to :meth:`~.EnvBase._reset`, the ``"_reset"`` key may or may
not be present in the input tensordict. TorchRL's convention is that the
absence of the ``"_reset"`` key at a given ``"done"`` level indicates
a total reset of that level (unless a ``"_reset"`` key was found at a level
above, see details below).
If it is present, it is expected that only those components where the
``"_reset"`` entry is ``True`` (along key and shape dimension) will be reset.

The way an environment deals with the ``"_reset"`` keys in its :meth:`~.EnvBase._reset`
method is specific to its class.
Designing an environment that behaves according to ``"_reset"`` inputs is the
developer's responsibility, as TorchRL has no control over the inner logic
of :meth:`~.EnvBase._reset`. Nevertheless, the following point should be
kept in mind when designing that method.

2. After a call to :meth:`~.EnvBase._reset`, the output will be masked with the
``"_reset"`` entries and the output of the previous :meth:`~.EnvBase.step`
will be written wherever the ``"_reset"`` was ``False``. In practice, this
means that if :meth:`~.EnvBase._reset` modifies data that is not covered by its
``"_reset"`` mask, this modification will be lost. After this masking operation,
the ``"_reset"`` entries will be erased from the :meth:`~.EnvBase.reset` outputs.
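
For illustration, the masking described in point 2 amounts to something like the
following (a conceptual sketch of the convention, not TorchRL's actual
implementation; ``prev_step_out`` and ``reset_out`` are hypothetical tensordicts
sharing the same keys):

>>> import torch
>>> from tensordict import TensorDict
>>> _reset = torch.tensor([False, True])
>>> prev_step_out = TensorDict({"val": torch.tensor([1, 1])}, [2])  # from the last step
>>> reset_out = TensorDict({"val": torch.tensor([0, 0])}, [2])  # from the reset
>>> # keep the reset output where "_reset" is True, the previous step output elsewhere
>>> merged = torch.where(_reset, reset_out.get("val"), prev_step_out.get("val"))
>>> print(merged)
tensor([1, 0])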

It must be pointed out that ``"_reset"`` is a private key, and it should only be
used when coding specific, internal-facing environment features.
In other words, it should NOT be used outside of the library, and developers
reserve the right to modify the logic of partial resets through ``"_reset"``
settings without prior warning, as long as they don't affect TorchRL
internal tests.

Finally, the following assumptions are made and should be kept in mind when
designing reset functionalities:

- Each ``"_reset"`` is paired with a ``"done"`` entry (+ ``"terminated"`` and,
possibly, ``"truncated"``). This means that the following structure is not
allowed: ``TensorDict({"done": done, "nested": {"_reset": reset}}, [])``, as
the ``"_reset"`` lives at a different nesting level than the ``"done"``.
- A reset at one level does not preclude the presence of a ``"_reset"`` at lower
levels, but it overrides their effects. The reason is simply that
whether the ``"_reset"`` at the root level corresponds to an ``all()``, ``any()``
or custom call to the nested ``"done"`` entries cannot be known in advance,
and it is explicitly assumed that the ``"_reset"`` at the root was placed
there to supersede the nested values (for an example, have a look at the
:class:`~.PettingZooWrapper` implementation, where each group has one or more
``"done"`` entries associated with it, which are aggregated at the root level
with an ``any`` or ``all`` logic depending on the task).
- When calling :meth:`env.reset(tensordict)` with a partial ``"_reset"`` entry
that will reset some but not all of the done sub-environments, the input data
should contain the data of the sub-environments that are **not** being reset.
The reason for this constraint lies in the fact that the output of
``env._reset(data)`` can only be predicted for the entries that are reset.
For the others, TorchRL cannot know in advance whether they will be meaningful or
not. For instance, one could perfectly well just pad the values of the non-reset
components, in which case the non-reset data would be meaningless and should
be discarded.
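
To make the first assumption above concrete, here is a small illustrative
sketch of a disallowed and an allowed layout (the tensors and names are
hypothetical):

>>> import torch
>>> from tensordict import TensorDict
>>> done = torch.zeros(2, 1, dtype=torch.bool)
>>> reset = torch.ones(2, 1, dtype=torch.bool)
>>> # not allowed: "_reset" and "done" live at different nesting levels
>>> bad = TensorDict({"done": done, "nested": {"_reset": reset}}, [])
>>> # allowed: each "_reset" sits next to the "done" entry it is paired with
>>> good = TensorDict({"nested": {"done": done, "_reset": reset}}, [])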

Below, we give some examples of the expected effect that ``"_reset"`` keys will
have on an environment returning zeros after reset:

>>> # single reset at the root
>>> data = TensorDict({"val": [1, 1], "_reset": [False, True]}, [])
>>> env.reset(data)
>>> print(data.get("val")) # only the second value is 0
tensor([1, 0])
>>> # nested resets
>>> data = TensorDict({
... ("agent0", "val"): [1, 1], ("agent0", "_reset"): [False, True],
... ("agent1", "val"): [2, 2], ("agent1", "_reset"): [True, False],
... }, [])
>>> env.reset(data)
>>> print(data.get(("agent0", "val"))) # only the second value is 0
tensor([1, 0])
>>> print(data.get(("agent1", "val"))) # only the second value is 0
tensor([0, 2])
>>> # nested resets are overridden by a "_reset" at the root
>>> data = TensorDict({
... "_reset": [True, True],
... ("agent0", "val"): [1, 1], ("agent0", "_reset"): [False, True],
... ("agent1", "val"): [2, 2], ("agent1", "_reset"): [True, False],
... }, [])
>>> env.reset(data)
>>> print(data.get(("agent0", "val"))) # reset at the root overrides nested
tensor([0, 0])
>>> print(data.get(("agent1", "val"))) # reset at the root overrides nested
tensor([0, 0])
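
Combined with :meth:`env.step_and_maybe_reset` described above, this convention
is what makes non-stopping rollouts over batched environments possible. Below is
a minimal, illustrative sketch (assuming the gym backend and ``CartPole-v1`` are
available; this is not code taken verbatim from the library):

>>> import torch
>>> from torchrl.envs import GymEnv, SerialEnv
>>> env = SerialEnv(2, lambda: GymEnv("CartPole-v1"))  # batch of 2 sub-envs
>>> data_ = env.reset()
>>> frames = []
>>> for _ in range(100):
...     data_ = env.rand_action(data_)  # fill in a random "action"
...     data, data_ = env.step_and_maybe_reset(data_)
...     frames.append(data)  # the "next" entry holds the step results
...
>>> rollout = torch.stack(frames)  # done sub-envs were reset transparently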

.. code-block::
:caption: Parallel environment reset
12 changes: 9 additions & 3 deletions examples/decision_transformer/dt.py
@@ -28,9 +28,10 @@
)


@set_gym_backend("gym") # D4RL uses gym so we make sure gymnasium is hidden
@hydra.main(config_path=".", config_name="dt_config")
@hydra.main(config_path=".", config_name="dt_config", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
set_gym_backend(cfg.env.backend).set()

model_device = cfg.optim.device

# Set seeds
@@ -63,6 +64,11 @@ def main(cfg: "DictConfig"):  # noqa: F821
policy=policy,
inference_context=cfg.env.inference_context,
).to(model_device)
inference_policy.set_tensor_keys(
observation="observation_cat",
action="action_cat",
return_to_go="return_to_go_cat",
)

pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)

@@ -76,7 +82,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
# Pretraining
start_time = time.time()
for i in range(pretrain_gradient_steps):
pbar.update(i)
pbar.update(1)

# Sample data
data = offline_buffer.sample()
1 change: 1 addition & 0 deletions examples/decision_transformer/dt_config.yaml
@@ -15,6 +15,7 @@ env:
target_return_mode: reduce
eval_target_return: 6000
collect_target_return: 12000
backend: gym # D4RL uses gym so we make sure gymnasium is hidden

# logger
logger:
1 change: 1 addition & 0 deletions examples/decision_transformer/odt_config.yaml
@@ -14,6 +14,7 @@ env:
target_return_mode: reduce
eval_target_return: 6000
collect_target_return: 12000
backend: gym # D4RL uses gym so we make sure gymnasium is hidden


# logger