Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into fix-snapshot-nightly
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 19, 2024
2 parents 072d91f + c3ffb5a commit 469a0eb
Show file tree
Hide file tree
Showing 24 changed files with 1,975 additions and 143 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ to be able to create this other composition:
RewardScaling
RewardSum
Reward2GoTransform
+ RemoveEmptySpecs
SelectTransform
SignTransform
SqueezeTransform
Expand Down
13 changes: 11 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,18 @@ def _main(argv):
url="https://github.com/pytorch/rl",
long_description=long_description,
long_description_content_type="text/markdown",
- license="BSD",
+ license="MIT",
# Package info
- packages=find_packages(exclude=("test", "tutorials")),
+ packages=find_packages(
+     exclude=(
+         "test",
+         "tutorials",
+         "docs",
+         "examples",
+         "knowledge_base",
+         "packaging",
+     )
+ ),
ext_modules=get_extensions(),
cmdclass={
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
Expand Down
2 changes: 1 addition & 1 deletion test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _set_gym_environments(): # noqa: F811
PONG_VERSIONED = "ALE/Pong-v5"


- @implement_for("gymnasium", "0.27.0", None)
+ @implement_for("gymnasium")
def _set_gym_environments(): # noqa: F811
global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED

Expand Down
4 changes: 3 additions & 1 deletion test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def _step(
assert (a.sum(-1) == 1).all()

obs = self._get_in_obs(tensordict.get(self._out_key)) + a / self.maxstep
- tensordict = tensordict.empty() # empty tensordict
+ tensordict = tensordict.empty()

tensordict.set(self.out_key, self._get_out_obs(obs))
tensordict.set(self._out_key, self._get_out_obs(obs))
Expand Down Expand Up @@ -603,6 +603,7 @@ def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
# state = torch.zeros(self.size) + self.counter
if tensordict is None:
tensordict = TensorDict({}, self.batch_size, device=self.device)
+
tensordict = tensordict.empty()
tensordict.update(self.observation_spec.rand())
# tensordict.set("next_" + self.out_key, self._get_out_obs(state))
Expand All @@ -622,6 +623,7 @@ def _step(
a = tensordict.get("action")

obs = self._obs_step(self._get_in_obs(tensordict.get(self._out_key)), a)
+
tensordict = tensordict.empty() # empty tensordict

tensordict.set(self.out_key, self._get_out_obs(obs))
Expand Down
Loading

0 comments on commit 469a0eb

Please sign in to comment.