Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] step_and_maybe_reset in env #1611

Merged
merged 122 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
122 commits
Select commit Hold shift + click to select a range
1f539dd
init
vmoens Oct 6, 2023
d4c16e1
amend
vmoens Oct 6, 2023
d2321aa
amend
vmoens Oct 6, 2023
565115a
amend
vmoens Oct 6, 2023
3c46136
amend
vmoens Oct 6, 2023
78cfa41
amend
vmoens Oct 6, 2023
a6bd8eb
amend
vmoens Oct 6, 2023
04d4ae7
amend
vmoens Oct 6, 2023
2d0b4c6
init
vmoens Oct 10, 2023
528609a
amend
vmoens Oct 10, 2023
51cd5af
fix
vmoens Oct 10, 2023
3e31963
Merge remote-tracking branch 'origin/main' into step_maybe_reset
vmoens Oct 10, 2023
b11e73a
amend
vmoens Oct 10, 2023
1540407
remove `pip3 install -e .`
vmoens Oct 10, 2023
f1b0ea4
tensordict_
vmoens Oct 10, 2023
16b3538
amend rollout logic
vmoens Oct 10, 2023
7ad0864
amend
vmoens Oct 10, 2023
bcac398
amend
vmoens Oct 10, 2023
428f8ee
inference
vmoens Oct 10, 2023
6fbd0bd
cpu -> cuda
vmoens Oct 10, 2023
02db623
checks
vmoens Oct 10, 2023
08a8f47
using pipe instead of event
vmoens Oct 10, 2023
45e64f7
amend
vmoens Oct 10, 2023
7dd4821
amend
vmoens Oct 10, 2023
e1a2206
rm cuda event
vmoens Oct 10, 2023
dc2caab
amend
vmoens Oct 10, 2023
01ffbf9
amend
vmoens Oct 10, 2023
ac76ec3
amend
vmoens Oct 10, 2023
ceab010
amend
vmoens Oct 10, 2023
f0327c9
amend
vmoens Oct 10, 2023
518b3d1
amend
vmoens Oct 10, 2023
47dd93b
amend
vmoens Oct 10, 2023
354fb6f
amend
vmoens Oct 10, 2023
53d5f9a
amend
vmoens Oct 10, 2023
78c00e8
amend
vmoens Oct 10, 2023
f63480e
amend
vmoens Oct 10, 2023
9a3631f
amend
vmoens Oct 10, 2023
44336ed
Merge remote-tracking branch 'origin/main' into fix_ci
vmoens Oct 10, 2023
512f9f7
specs fix
vmoens Oct 10, 2023
2ceb438
amend
vmoens Oct 10, 2023
c666d5a
amend
vmoens Oct 10, 2023
6ecebda
amend
vmoens Oct 10, 2023
ce6e9bd
amend
vmoens Oct 10, 2023
5c613c3
amend
vmoens Oct 10, 2023
9f97e58
amend
vmoens Oct 10, 2023
9cbcbb0
amend
vmoens Oct 10, 2023
ae2748d
amend
vmoens Oct 10, 2023
72c4163
amend
vmoens Oct 10, 2023
6f4c374
amend
vmoens Oct 10, 2023
bf36bec
amend
vmoens Oct 10, 2023
4095766
amend
vmoens Oct 11, 2023
9f2a9ad
amend
vmoens Oct 11, 2023
5b34961
amend
vmoens Oct 11, 2023
deba78d
amend
vmoens Oct 11, 2023
b638461
amend
vmoens Oct 11, 2023
f8b2f60
amend
vmoens Oct 11, 2023
dfc868a
amend
vmoens Oct 11, 2023
1221c48
Merge branch 'fix_ci' into step_maybe_reset
vmoens Oct 11, 2023
22d3a27
amend
vmoens Oct 11, 2023
38e0e90
amend
vmoens Oct 11, 2023
5828a84
Merge branch 'fix_ci' into step_maybe_reset
vmoens Oct 11, 2023
004d381
amend
vmoens Oct 11, 2023
6ef720d
amend
vmoens Oct 11, 2023
d2167ea
Merge branch 'fix_ci' into step_maybe_reset
vmoens Oct 11, 2023
f66c6c9
amend
vmoens Oct 11, 2023
4fd2768
amend
vmoens Oct 11, 2023
ded1883
amend
vmoens Oct 11, 2023
cb1a83f
amend
vmoens Oct 11, 2023
6af6a45
amend
vmoens Oct 11, 2023
eaf5ebd
amend
vmoens Oct 12, 2023
f426322
Merge branch 'fix_ci' into step_maybe_reset
vmoens Oct 12, 2023
83672c6
amend
vmoens Oct 12, 2023
f0a6134
amend
vmoens Oct 12, 2023
c589a4a
amend
vmoens Oct 12, 2023
a85b321
amend
vmoens Oct 12, 2023
ed5e96f
amend
vmoens Oct 13, 2023
e705719
amend
vmoens Oct 16, 2023
4c3fa33
fix
vmoens Oct 17, 2023
e354bc8
fix
vmoens Oct 18, 2023
faa3b41
amend
vmoens Oct 18, 2023
80120b4
fix
vmoens Oct 18, 2023
09205e5
empty
vmoens Oct 18, 2023
3c41dab
amend
vmoens Oct 18, 2023
4db8110
amend
vmoens Oct 18, 2023
38cac91
empty
vmoens Oct 18, 2023
a336f0e
Merge remote-tracking branch 'origin/main' into step_maybe_reset
vmoens Oct 18, 2023
a9f4678
amend
vmoens Oct 18, 2023
9ee6348
amend
vmoens Oct 19, 2023
62c848b
amend
vmoens Oct 19, 2023
684d527
amend
vmoens Oct 19, 2023
89ddfd2
amend
vmoens Oct 19, 2023
a0376e4
amend
vmoens Oct 19, 2023
44dfe86
amend
vmoens Oct 19, 2023
d332597
amend
vmoens Oct 19, 2023
85cf664
amend
vmoens Oct 19, 2023
b55fae1
Merge remote-tracking branch 'origin/main' into step_maybe_reset
vmoens Oct 19, 2023
5dcbf15
amend
vmoens Oct 19, 2023
0236818
amend
vmoens Oct 20, 2023
36f151d
amend
vmoens Oct 20, 2023
afecc65
amend
vmoens Oct 20, 2023
f4a0beb
amend
vmoens Oct 20, 2023
5a331e5
amend
vmoens Oct 20, 2023
1ee4ba7
amend
vmoens Oct 20, 2023
9d14427
amend
vmoens Oct 20, 2023
8c15741
amend
vmoens Oct 20, 2023
1fe2cd3
amend
vmoens Oct 20, 2023
c378acc
amend
vmoens Oct 20, 2023
551c9eb
amend
vmoens Oct 20, 2023
e177c77
amend
vmoens Oct 20, 2023
28b0059
amend
vmoens Oct 20, 2023
5a21f2a
amend
vmoens Oct 20, 2023
b7b2081
lint
vmoens Oct 20, 2023
de1ecf2
amend
vmoens Oct 21, 2023
6fec99e
amend
vmoens Oct 21, 2023
6bfe517
amend
vmoens Oct 22, 2023
8f78ac0
cache keys
vmoens Oct 23, 2023
d2b734b
fix empty cache
vmoens Oct 23, 2023
a0c12d5
Merge remote-tracking branch 'origin/main' into step_maybe_reset
vmoens Oct 23, 2023
83eb52f
amend
vmoens Oct 24, 2023
5943c51
Merge remote-tracking branch 'origin/main' into step_maybe_reset
vmoens Oct 24, 2023
3e86847
addressing comments
vmoens Oct 24, 2023
ca1dd78
amend
vmoens Oct 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies:
- pytest-mock
- pytest-instafail
- pytest-rerunfailures
- pytest-timeout
- expecttest
- pyyaml
- scipy
Expand Down
6 changes: 4 additions & 2 deletions .github/unittest/linux/scripts/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,12 @@ pytest test/smoke_test.py -v --durations 200
pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb'
if [ "${CU_VERSION:-}" != cpu ] ; then
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
--timeout=120
else
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py \
--timeout=120
fi

coverage combine
Expand Down
8 changes: 6 additions & 2 deletions .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,17 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
logger.backend=
logger.backend= \
env.backend=gymnasium \
env.name=HalfCheetah-v4
python .github/unittest/helpers/coverage_run_parallel.py examples/decision_transformer/online_dt.py \
optim.pretrain_gradient_steps=55 \
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
logger.backend=
logger.backend= \
env.backend=gymnasium \
env.name=HalfCheetah-v4

# ==================================================================================== #
# ================================ Gymnasium ========================================= #
Expand Down
82 changes: 38 additions & 44 deletions benchmarks/ecosystem/gym_env_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
"""
import time

import myosuite # noqa: F401
# import myosuite # noqa: F401
import torch
import tqdm
from torchrl._utils import timeit
from torchrl.collectors import (
Expand All @@ -29,6 +30,10 @@
from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend

if __name__ == "__main__":
avail_devices = ("cpu",)
if torch.cuda.device_count():
avail_devices = avail_devices + ("cuda:0",)

for envname in [
"CartPole-v1",
"HalfCheetah-v4",
Expand Down Expand Up @@ -69,24 +74,25 @@ def make(envname=envname, gym_backend=gym_backend):
log.flush()

# regular parallel env
for device in (
"cuda:0",
"cpu",
):
for device in avail_devices:

def make(envname=envname, gym_backend=gym_backend, device=device):
with set_gym_backend(gym_backend):
return GymEnv(envname, device=device)

env_make = EnvCreator(make)
penv = ParallelEnv(num_workers, env_make)
# warmup
penv.rollout(2)
pbar = tqdm.tqdm(total=num_workers * 10_000)
t0 = time.time()
for _ in range(100):
data = penv.rollout(100, break_when_any_done=False)
pbar.update(100 * num_workers)
# env_make = EnvCreator(make)
penv = ParallelEnv(num_workers, EnvCreator(make))
with torch.inference_mode():
# warmup
penv.rollout(2)
pbar = tqdm.tqdm(total=num_workers * 10_000)
t0 = time.time()
data = None
for _ in range(100):
data = penv.rollout(
100, break_when_any_done=False, out=data
)
pbar.update(100 * num_workers)
log.write(
f"penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n"
)
Expand All @@ -95,7 +101,7 @@ def make(envname=envname, gym_backend=gym_backend, device=device):
timeit.print()
del penv

for device in ("cuda:0", "cpu"):
for device in avail_devices:

def make(envname=envname, gym_backend=gym_backend, device=device):
with set_gym_backend(gym_backend):
Expand All @@ -109,29 +115,26 @@ def make(envname=envname, gym_backend=gym_backend, device=device):
RandomPolicy(penv.action_spec),
frames_per_batch=1024,
total_frames=num_workers * 10_000,
device=device,
storing_device=device,
)
pbar = tqdm.tqdm(total=num_workers * 10_000)
total_frames = 0
for i, data in enumerate(collector):
if i == num_collectors:
t0 = time.time()
if i >= num_collectors:
total_frames += data.numel()
pbar.update(data.numel())
pbar.set_description(
f"single collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
)
t0 = time.time()
for data in collector:
total_frames += data.numel()
pbar.update(data.numel())
pbar.set_description(
f"single collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
)
log.write(
f"single collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
)
log.flush()
collector.shutdown()
del collector

for device in (
"cuda:0",
"cpu",
):
for device in avail_devices:
# gym parallel env
def make_env(
envname=envname,
Expand All @@ -158,10 +161,7 @@ def make_env(
penv.close()
del penv

for device in (
"cuda:0",
"cpu",
):
for device in avail_devices:
# async collector
# + torchrl parallel env
def make_env(
Expand All @@ -179,6 +179,7 @@ def make_env(
frames_per_batch=1024,
total_frames=num_workers * 10_000,
device=device,
storing_device=device,
)
pbar = tqdm.tqdm(total=num_workers * 10_000)
total_frames = 0
Expand All @@ -198,10 +199,7 @@ def make_env(
collector.shutdown()
del collector

for device in (
"cuda:0",
"cpu",
):
for device in avail_devices:
# async collector
# + gym async env
def make_env(
Expand All @@ -226,6 +224,7 @@ def make_env(
total_frames=num_workers * 10_000,
num_sub_threads=num_workers // num_collectors,
device=device,
storing_device=device,
)
pbar = tqdm.tqdm(total=num_workers * 10_000)
total_frames = 0
Expand All @@ -245,10 +244,7 @@ def make_env(
collector.shutdown()
del collector

for device in (
"cuda:0",
"cpu",
):
for device in avail_devices:
# sync collector
# + torchrl parallel env
def make_env(
Expand All @@ -266,6 +262,7 @@ def make_env(
frames_per_batch=1024,
total_frames=num_workers * 10_000,
device=device,
storing_device=device,
)
pbar = tqdm.tqdm(total=num_workers * 10_000)
total_frames = 0
Expand All @@ -285,10 +282,7 @@ def make_env(
collector.shutdown()
del collector

for device in (
"cuda:0",
"cpu",
):
for device in avail_devices:
# sync collector
# + gym async env
def make_env(
Expand Down
107 changes: 106 additions & 1 deletion docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,23 @@ With these, the following methods are implemented:
- :meth:`env.step`: a step method that takes a :class:`tensordict.TensorDict` input
containing an input action as well as other inputs (for model-based or stateless
environments, for instance).
- :meth:`env.step_and_maybe_reset`: executes a step, and (partially) resets the
environments if it needs to. It returns the updated input with a ``"next"``
key containing the data of the next step, as well as a tensordict containing
the input data for the next step (ie, reset or result or
:func:`~torchrl.envs.utils.step_mdp`)
This is done by reading the ``done_keys`` and
assigning a ``"_reset"`` signal to each done state. This method allows
to code non-stopping rollout functions with little effort:

>>> data_ = env.reset()
>>> result = []
>>> for i in range(N):
... data, data_ = env.step_and_maybe_reset(data_)
... result.append(data)
...
>>> result = torch.stack(result)

- :meth:`env.set_seed`: a seeding method that will return the next seed
to be used in a multi-env setting. This next seed is deterministically computed
from the preceding one, such that one can seed multiple environments with a different
Expand Down Expand Up @@ -169,7 +186,95 @@ one can simply call:
>>> print(a)
9.81

It is also possible to reset some but not all of the environments:
TorchRL uses a private ``"_reset"`` key to indicate to the environment which
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why using the word environment? I would use dimensions as this could be tasks or agents

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, fair point. In the context of batched envs it seemed clearer to refer to it as envs
we could say "envs, tasks or agent"

component (sub-environments or agents) should be reset.
This allows to reset some but not all of the components.

The ``"_reset"`` key has two distinct functionalities:
1. During a call to :meth:`~.EnvBase._reset`, the ``"_reset"`` key may or may
not be present in the input tensordict. TorchRL's convention is that the
absence of the ``"_reset"`` key at a given ``"done"`` level indicates
a total reset of that level (unless a ``"_reset"`` key was found at a level
above, see details below).
If it is present, it is expected that those entries and only those components
where the ``"_reset"`` entry is ``True`` (along key and shape dimension) will be reset.

The way an environment deals with the ``"_reset"`` keys in its :meth:`~.EnvBase._reset`
method is proper to its class.
Designing an environment that behaves according to ``"_reset"`` inputs is the
developer's responsibility, as TorchRL has no control over the inner logic
of :meth:`~.EnvBase._reset`. Nevertheless, the following point should be
kept in mind when desiging that method.

2. After a call to :meth:`~.EnvBase._reset`, the output will be masked with the
``"_reset"`` entries and the output of the previous :meth:`~.EnvBase.step`
will be written wherever the ``"_reset"`` was ``False``. In practice, this
means that if a ``"_reset"`` modifies data that isn't exposed by it, this
modification will be lost. After this masking operation, the ``"_reset"``
entries will be erased from the :meth:`~.EnvBase.reset` outputs.

It must be pointed that ``"_reset"`` is a private key, and it should only be
used when coding specific environment features that are internal facing.
In other words, this should NOT be used outside of the library, and developers
will keep the right to modify the logic of partial resets through ``"_reset"``
setting without preliminary warranty, as long as they don't affect TorchRL
internal tests.

Finally, the following assumptions are made and should be kept in mind when
designing reset functionalities:

- Each ``"_reset"`` is paired with a ``"done"`` entry (+ ``"terminated"`` and,
possibly, ``"truncated"``). This means that the following structure is not
allowed: ``TensorDict({"done": done, "nested": {"_reset": reset}}, [])``, as
the ``"_reset"`` lives at a different nesting level than the ``"done"``.
- A reset at one level does not preclude the presence of a ``"_reset"`` at lower
vmoens marked this conversation as resolved.
Show resolved Hide resolved
levels, but it annihilates its effects. The reason is simply that
whether the ``"_reset"`` at the root level corresponds to an ``all()``, ``any()``
or custom call to the nested ``"done"`` entries cannot be known in advance,
and it is explicitly assumed that the ``"_reset"`` at the root was placed
there to superseed the nested values (for an example, have a look at
:class:`~.PettingZooWrapper` implementation where each group has one or more
``"done"`` entries associated which is aggregated at the root level with a
``any`` or ``all`` logic depending on the task).
- When calling :meth:`env.reset(tensordict)` with a partial ``"_reset"`` entry
that will reset some but not all the done sub-environments, the input data
should contain the data of the sub-environemtns that are __not__ being reset.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can the input contain only the data "of the sub-environments that are not being reset." It has to contain all data no? otherwise it is sparse

Copy link
Contributor Author

@vmoens vmoens Oct 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

who said only? :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but if we specify that it should contain the data for those that are not being reset it might be read as implying that it contains only that. I think it is easier if we say it contains all data

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we want to be precise, we don't really care that the data of the envs being reset is there actually. As long as torch.where will work.
I don't really see how it can contain one and not the other (without recurring to lazy stacks or such -- and if it does contain lazy stacks with just the non-reset data it should still work).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Example:

import torch
from tensordict import TensorDict
reset = torch.tensor([True, False])
td = torch.stack([
    TensorDict({}, []),
    TensorDict({"a": 1}, []),
])
td_reset = torch.stack([
    TensorDict({"a": 0}, []),
    TensorDict({}, []),
])
td_reset.where(~reset, other=td, pad=0)

which returns a non-sparse td

LazyStackedTensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False,
    stack_dim=0)

The reason for this constrain lies in the fact that the output of the
``env._reset(data)`` can only be predicted for the entries that are reset.
For the others, TorchRL cannot know in advance if they will be meaningful or
not. For instance, one could perfectly just pad the values of the non-reset
components, in which case the non-reset data will be meaningless and should
be discarded.

Below, we give some examples of the expected effect that ``"_reset"`` keys will
have on an environment returning zeros after reset:

>>> # single reset at the root
>>> data = TensorDict({"val": [1, 1], "_reset": [False, True]}, [])
>>> env.reset(data)
>>> print(data.get("val")) # only the second value is 0
tensor([1, 0])
>>> # nested resets
>>> data = TensorDict({
... ("agent0", "val"): [1, 1], ("agent0", "_reset"): [False, True],
... ("agent1", "val"): [2, 2], ("agent1", "_reset"): [True, False],
... }, [])
>>> env.reset(data)
>>> print(data.get(("agent0", "val"))) # only the second value is 0
tensor([1, 0])
>>> print(data.get(("agent1", "val"))) # only the second value is 0
tensor([0, 2])
>>> # nested resets are overridden by a "_reset" at the root
>>> data = TensorDict({
... "_reset": [True, True],
... ("agent0", "val"): [1, 1], ("agent0", "_reset"): [False, True],
... ("agent1", "val"): [2, 2], ("agent1", "_reset"): [True, False],
... }, [])
>>> env.reset(data)
>>> print(data.get(("agent0", "val"))) # reset at the root overrides nested
tensor([0, 0])
>>> print(data.get(("agent1", "val"))) # reset at the root overrides nested
tensor([0, 0])

.. code-block::
:caption: Parallel environment reset
Expand Down
12 changes: 9 additions & 3 deletions examples/decision_transformer/dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@
)


@set_gym_backend("gym") # D4RL uses gym so we make sure gymnasium is hidden
@hydra.main(config_path=".", config_name="dt_config")
@hydra.main(config_path=".", config_name="dt_config", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
set_gym_backend(cfg.env.backend).set()

model_device = cfg.optim.device

# Set seeds
Expand Down Expand Up @@ -63,6 +64,11 @@ def main(cfg: "DictConfig"): # noqa: F821
policy=policy,
inference_context=cfg.env.inference_context,
).to(model_device)
inference_policy.set_tensor_keys(
observation="observation_cat",
action="action_cat",
return_to_go="return_to_go_cat",
)

pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)

Expand All @@ -76,7 +82,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Pretraining
start_time = time.time()
for i in range(pretrain_gradient_steps):
pbar.update(i)
pbar.update(1)

# Sample data
data = offline_buffer.sample()
Expand Down
1 change: 1 addition & 0 deletions examples/decision_transformer/dt_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ env:
target_return_mode: reduce
eval_target_return: 6000
collect_target_return: 12000
backend: gym # D4RL uses gym so we make sure gymnasium is hidden

# logger
logger:
Expand Down
1 change: 1 addition & 0 deletions examples/decision_transformer/odt_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ env:
target_return_mode: reduce
eval_target_return: 6000
collect_target_return: 12000
backend: gym # D4RL uses gym so we make sure gymnasium is hidden


# logger
Expand Down
Loading
Loading