[BUG] FlattenObservation transform with OneHotDiscreteTensorSpec #1904

Closed
svnv-svsv-jm opened this issue Feb 12, 2024 · 8 comments

@svnv-svsv-jm

svnv-svsv-jm commented Feb 12, 2024

Describe the bug

Not sure this is a bug, but I am unable to use the FlattenObservation transform when the observation_spec = CompositeSpec(observation=OneHotDiscreteTensorSpec(...))

To Reproduce

Create this env:

__all__ = ["_CustomEnv"]

from loguru import logger
import typing as ty

import torch

from tensordict import TensorDict
from torchrl.data import (
    CompositeSpec,
    UnboundedContinuousTensorSpec,
    BinaryDiscreteTensorSpec,
    OneHotDiscreteTensorSpec,
)
from torchrl.envs import EnvBase


WORST_REWARD = -1e6


class _CustomEnv(EnvBase):
    """Custom dummy environment."""

    def __init__(
        self,
        **kwargs: ty.Any,
    ) -> None:
        super().__init__(**kwargs)  # call the constructor of the base class
        # Action is a one-hot tensor
        self.action_spec = OneHotDiscreteTensorSpec(
            n=10,
            shape=(10,),
            device=self.device,
            dtype=torch.float32,
        )
        # Observation space
        observation_spec = OneHotDiscreteTensorSpec(
            n=13,
            shape=(8, 8, 13),
            device=self.device,
            dtype=torch.float32,
        )
        self.observation_spec = CompositeSpec(observation=observation_spec)
        # Unlimited reward space
        self.reward_spec = UnboundedContinuousTensorSpec(
            shape=torch.Size([1]),
            device=self.device,
            dtype=torch.float32,
        )
        # Done
        self.done_spec = BinaryDiscreteTensorSpec(
            n=1,
            shape=torch.Size([1]),
            device=self.device,
            dtype=torch.bool,
        )
        logger.debug(f"action_spec: {self.action_spec}")
        logger.debug(f"observation_spec: {self.observation_spec}")
        logger.debug(f"reward_spec: {self.reward_spec}")

    def _reset(self, tensordict: TensorDict = None, **kwargs: ty.Any) -> TensorDict:
        """The `_reset()` method potentially takes in a `TensorDict` and some kwargs which may contain data used in the resetting of the environment and returns a new `TensorDict` with an initial observation of the environment.

        The output `TensorDict` has to be new because the input tensordict is immutable.

        Args:
            tensordict (TensorDict):
                Immutable input.

        Returns:
            TensorDict:
                Initial state.
        """
        logger.debug("Resetting environment.")
        # Return new TensorDict
        return TensorDict(
            {
                "observation": torch.zeros(
                    (8, 8, 13), dtype=self.observation_spec.dtype, device=self.device
                ),
                "reward": torch.Tensor([0]).to(self.reward_spec.dtype).to(self.device),
                "done": False,
            },
            batch_size=torch.Size(),
            device=self.device,
        )

    def _step(self, tensordict: TensorDict) -> TensorDict:
        """The `_step()` method takes in a `TensorDict` from which it reads an action, applies the action and returns a new `TensorDict` containing the observation, reward and done signal for that timestep.

        Args:
            tensordict (TensorDict): _description_

        Returns:
            TensorDict: _description_
        """
        # Return new TensorDict
        td = TensorDict(
            {
                "observation": torch.zeros(
                    (8, 8, 13), dtype=self.observation_spec.dtype, device=self.device
                ),
                "reward": torch.Tensor([0]).to(self.reward_spec.dtype).to(self.device),
                "done": True,
            },
            batch_size=torch.Size(),
            device=self.device,
        )
        logger.trace(f"Returning new TensorDict: {td}")
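        return td

    def _set_seed(self, seed: int) -> int:
        """Required by `EnvBase`; this dummy environment has no RNG to seed."""
        return seed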

Then, transform it:

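from torchrl.envs import Compose, FlattenObservation, StepCounter, TransformedEnv

# return base_env
env = TransformedEnv(
    _CustomEnv(),
    transform=Compose(
        FlattenObservation(
            first_dim=0,
            last_dim=-1,
            # `self.in_keys` comes from the reporter's surrounding class;
            # its value is not shown in the report.
            in_keys=self.in_keys,
            allow_positive_dim=True,
        ),
        StepCounter(),
    ),
)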

Then use it anywhere. Even just printing will raise an error:

logger.debug(f"observation_spec: {env.observation_spec}")

Will raise:

logger.debug(f"observation_spec: {env.observation_spec}")
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/common.py:1239: in observation_spec
    observation_spec = self.output_spec["full_observation_spec"]
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py:725: in output_spec
    output_spec = self.transform.transform_output_spec(output_spec)
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py:1025: in transform_output_spec
    output_spec = t.transform_output_spec(output_spec)
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py:5274: in transform_output_spec
    return super().transform_output_spec(output_spec)
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/envs/transforms/transforms.py:380: in transform_output_spec
    output_spec = output_spec.clone()
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:3720: in clone
    {
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:3721: in <dictcomp>
    key: item.clone() if item is not None else None
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:3720: in clone
    {
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:3721: in <dictcomp>
    key: item.clone() if item is not None else None
../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:1250: in clone
    return self.__class__(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <[AttributeError("'OneHotDiscreteTensorSpec' object has no attribute 'shape'") raised in repr()] OneHotDiscreteTensorSpec object at 0x2867d5000>
n = 13, shape = torch.Size([832]), device = device(type='cpu'), dtype = torch.float32, use_register = False, mask = None

    def __init__(
        self,
        n: int,
        shape: Optional[torch.Size] = None,
        device: Optional[DEVICE_TYPING] = None,
        dtype: Optional[Union[str, torch.dtype]] = torch.bool,
        use_register: bool = False,
        mask: torch.Tensor | None = None,
    ):
        dtype, device = _default_dtype_and_device(dtype, device)
        self.use_register = use_register
        space = DiscreteBox(n)
        if shape is None:
            shape = torch.Size((space.n,))
        else:
            shape = torch.Size(shape)
            if not len(shape) or shape[-1] != space.n:
>               raise ValueError(
                    f"The last value of the shape must match n for transform of type {self.__class__}. "
                    f"Got n={space.n} and shape={shape}."
                )
E               ValueError: The last value of the shape must match n for transform of type <class 'torchrl.data.tensor_specs.OneHotDiscreteTensorSpec'>. Got n=13 and shape=torch.Size([832]).

../../../.pyenv/versions/3.10.10/envs/chess-rl/lib/python3.10/site-packages/torchrl/data/tensor_specs.py:1206: ValueError

Calling env.reset() will raise the same error, too.
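
To make the failure easier to follow, here is a minimal sketch of the shape check visible in the traceback. The function name `validate_one_hot_shape` is hypothetical and this is a simplified stand-in written in plain Python, not torchrl's code; it only mirrors the condition `shape[-1] != n` shown above.

```python
from typing import Sequence, Tuple


def validate_one_hot_shape(n: int, shape: Sequence[int]) -> Tuple[int, ...]:
    """Simplified stand-in for the check shown in the traceback (not torchrl code)."""
    shape = tuple(shape)
    # A one-hot spec needs a trailing dimension of exactly n entries.
    if not shape or shape[-1] != n:
        raise ValueError(
            f"The last value of the shape must match n. Got n={n} and shape={shape}."
        )
    return shape


# The original spec passes the check: its trailing dimension is 13.
validate_one_hot_shape(13, (8, 8, 13))

# Flattening every dimension leaves a single axis of 8 * 8 * 13 = 832 entries,
# so the trailing dimension no longer equals n and the check rejects it.
try:
    validate_one_hot_shape(13, (8 * 8 * 13,))
except ValueError as err:
    print(f"rejected: {err}")
```

The error is therefore raised when the flattened spec is rebuilt, which is why merely reading `env.observation_spec` is enough to trigger it.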

Expected behavior

TorchRL should not complain that the observation has wrong size. We want it to be of the "wrong" size as we want to flatten it.

System info

Describe the characteristic of your environment:

  • Describe how the library was installed (pip, source, ...)
  • Python version
  • Versions of any other relevant libraries
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
# > 0.3.0 1.26.4 3.10.10 (main, Sep 14 2023, 16:59:47) [Clang 14.0.3 (clang-1403.0.22.14.1)] darwin

Reason and Possible fixes

Not sure, but I'd be happy to work on one.

Checklist

  • [v] I have checked that there is no similar issue in the repo (required)
  • [v] I have read the documentation (required)
  • [v] I have provided a minimal working example to reproduce the bug (required)
@svnv-svsv-jm svnv-svsv-jm added the bug (Something isn't working) label Feb 12, 2024
@vmoens
Contributor

vmoens commented Feb 13, 2024

Thanks for reporting this. May I ask what self.in_keys is in FlattenObservation?

@vmoens
Contributor

vmoens commented Feb 13, 2024

Got it, I guess it was observation.

The issue is that you're trying to flatten a one-hot spec with 13 possible values, which does not make sense anymore once it has shape 8 * 8 * 13 (it won't be one-hot anymore).
The error could be more explicit although it could be hard to capture.

It works if you flatten with

                FlattenObservation(
                    first_dim=0,
                    last_dim=-2,
                    in_keys=["observation"],
                    allow_positive_dim=True,
                ),

in which case your observation has shape 64, 13.
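
The difference between the two `last_dim` values can be checked on a plain tensor. The sketch below assumes that `torch.Tensor.flatten(start, end)` reshapes the data the same way the transform reshapes the spec; the observation itself is a made-up random one-hot tensor.

```python
import torch
import torch.nn.functional as F

# Hypothetical observation: one category index per cell, expanded to one-hot.
indices = torch.randint(13, (8, 8))
obs = F.one_hot(indices, num_classes=13).float()  # shape (8, 8, 13)

# last_dim=-2 merges only the two leading axes; the class axis is untouched.
partly_flat = obs.flatten(0, -2)
print(partly_flat.shape)  # torch.Size([64, 13])

# last_dim=-1 merges everything, so the trailing size is 832 rather than 13.
fully_flat = obs.flatten(0, -1)
print(fully_flat.shape)  # torch.Size([832])

# Each of the 64 rows is still a valid one-hot vector of length 13.
rows_are_one_hot = bool((partly_flat.sum(dim=-1) == 1).all())
```

With `last_dim=-2` the trailing dimension still equals `n=13`, so the one-hot spec remains valid.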

You could also use a categorical encoding:

__all__ = ["_CustomEnv"]

import typing as ty

import torch
from tensordict import TensorDict
from torchrl.data import (
    CompositeSpec,
    UnboundedContinuousTensorSpec,
    BinaryDiscreteTensorSpec,
    DiscreteTensorSpec,
)
from torchrl.envs import EnvBase, TransformedEnv, Compose, FlattenObservation, \
    StepCounter

WORST_REWARD = -1e6


class _CustomEnv(EnvBase):
    """Custom dummy environment."""

    def __init__(
        self,
        **kwargs: ty.Any,
    ) -> None:
        super().__init__(**kwargs)  # call the constructor of the base class
        # Action is a categorical (integer index) value
        self.action_spec = DiscreteTensorSpec(
            n=10,
            shape=(),
            device=self.device,
            dtype=torch.float32,
        )
        # Observation space
        observation_spec = DiscreteTensorSpec(
            n=13,
            shape=(8, 8,),
            device=self.device,
            dtype=torch.float32,
        )
        self.observation_spec = CompositeSpec(observation=observation_spec)
        # Unlimited reward space
        self.reward_spec = UnboundedContinuousTensorSpec(
            shape=torch.Size([1]),
            device=self.device,
            dtype=torch.float32,
        )
        # Done
        self.done_spec = BinaryDiscreteTensorSpec(
            n=1,
            shape=torch.Size([1]),
            device=self.device,
            dtype=torch.bool,
        )

    def _reset(
        self,
        tensordict: TensorDict = None,
        **kwargs: ty.Any
    ) -> TensorDict:
        """The `_reset()` method potentially takes in a `TensorDict` and some kwargs which may contain data used in the resetting of the environment and returns a new `TensorDict` with an initial observation of the environment.

        The output `TensorDict` has to be new because the input tensordict is immutable.

        Args:
            tensordict (TensorDict):
                Immutable input.

        Returns:
            TensorDict:
                Initial state.
        """
        # Return new TensorDict
        return TensorDict(
            {
                "observation": torch.randint(
                    13,
                    (8, 8),
                    dtype=self.observation_spec.dtype,
                    device=self.device
                    ),
                "reward": torch.Tensor([0]).to(self.reward_spec.dtype).to(
                    self.device
                    ),
                "done": False,
            },
            batch_size=torch.Size(),
            device=self.device,
        )

    def _step(self, tensordict: TensorDict) -> TensorDict:
        """The `_step()` method takes in a `TensorDict` from which it reads an action, applies the action and returns a new `TensorDict` containing the observation, reward and done signal for that timestep.

        Args:
            tensordict (TensorDict): _description_

        Returns:
            TensorDict: _description_
        """
        # Return new TensorDict
        td = TensorDict(
            {
                "observation": torch.randint(
                    13,
                    (8, 8),
                    dtype=self.observation_spec.dtype,
                    device=self.device
                    ),
                "reward": torch.Tensor([0]).to(self.reward_spec.dtype).to(
                    self.device
                    ),
                "done": True,
            },
            batch_size=torch.Size(),
            device=self.device,
        )
        return td

    def _set_seed(self, seed: int):
        return seed


# return base_env
env = TransformedEnv(
    _CustomEnv(),
    transform=Compose(
        FlattenObservation(
            first_dim=0,
            last_dim=-1,
            in_keys=["observation"],
            allow_positive_dim=True,
        ),
        StepCounter(),
    ),
)
print(env.rollout(3))
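
For readers comparing the two encodings, the sketch below shows how a categorical tensor relates to its one-hot counterpart. It uses plain PyTorch only, with made-up random data, and is an illustration rather than part of the environment above.

```python
import torch
import torch.nn.functional as F

n_classes = 13

# Categorical encoding: one integer in [0, 13) per cell.
categorical = torch.randint(n_classes, (8, 8))

# One-hot encoding of the same data: an extra trailing axis of size 13.
one_hot = F.one_hot(categorical, num_classes=n_classes)  # shape (8, 8, 13)

# argmax over the class axis recovers the integer indices.
recovered = one_hot.argmax(dim=-1)

# Flattening the categorical tensor gives 64 entries; there is no class axis
# whose size has to stay equal to n.
flat_categorical = categorical.flatten()
```

Because the categorical tensor has no trailing class axis, flattening it fully with `last_dim=-1` does not run into the one-hot shape constraint.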

@svnv-svsv-jm
Author

Nice! Great answer. This worked, thanks. Yes perhaps not much to do here then, besides having a better error message, so that people in the future won't need to open an issue here :)

@vmoens vmoens added the Good first issue (A good way to start hacking torchrl!) label Feb 15, 2024
@vmoens vmoens closed this as not planned Feb 20, 2024
@vmoens
Contributor

vmoens commented Feb 20, 2024

I'm closing this as not planned because, unfortunately, the exception is hard to catch :/

@svnv-svsv-jm
Author

Sad! Well, I produced it by flattening a set of one-hot tensors, for example from (64, 13) to (64*13,). Basically producing a concatenation of 64 one-hot tensors of length 13. While the final 832-sized tensor is not a one-hot tensor anymore (it will have 64 ones), this "concatenation"/flatten operation is needed if one wants to avoid using CNNs and use just an MLP.

But I believe the intended way to do this is to prepend a torch.nn.Flatten() to an MLP?
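
A minimal sketch of that approach, assuming an 8 x 8 x 13 observation and 10 actions as in the environment above; the hidden width of 128 is arbitrary. Note that `nn.Flatten` defaults to `start_dim=1`, which presumes a leading batch axis, so the sketch uses `start_dim=-3` to cover both batched and unbatched inputs.

```python
import torch
from torch import nn

# Hypothetical sizes: 8 x 8 x 13 one-hot observation, 10 actions, 128 hidden units.
mlp = nn.Sequential(
    # start_dim=-3 flattens only the last three axes, so batched and
    # unbatched observations are both handled.
    nn.Flatten(start_dim=-3),
    nn.Linear(8 * 8 * 13, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

single = mlp(torch.zeros(8, 8, 13))    # unbatched observation
batch = mlp(torch.zeros(4, 8, 8, 13))  # batch of four observations
```

Flattening inside the network leaves the environment's spec as a valid one-hot spec of shape (8, 8, 13), so no transform on the observation is needed.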

@vmoens
Contributor

vmoens commented Feb 20, 2024

Right I can see why you'd want to do that.
Maybe flatten should change the one-hot to something else (e.g. discrete binary tensor) then?
Would that work?

@svnv-svsv-jm
Author

Oh I see what you are thinking about, converting to BinaryDiscreteTensorSpec because each element in the flattened tensor can be either 0 or 1, and due to my flatten op being a concat of more than one one-hot tensor, there can be more than one 1 value anyway...

I think this would fix it, let's say it does, now it is perhaps just a matter of whether torchrl wants to allow this (maybe it can get messy?) or perhaps dedicate a doc section to this use case and offer the currently supported way for doing this.

What do you think?
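
The reasoning can be checked on a tensor. In the sketch below, `is_binary` and `is_one_hot` are hypothetical helpers written for illustration, not torchrl functions; they show that the fully flattened observation fails a one-hot test but passes a binary one.

```python
import torch
import torch.nn.functional as F


def is_binary(x: torch.Tensor) -> bool:
    """True if every entry is either 0 or 1."""
    return bool(((x == 0) | (x == 1)).all())


def is_one_hot(x: torch.Tensor) -> bool:
    """True if entries are 0/1 and each vector on the last axis sums to 1."""
    return is_binary(x) and bool((x.sum(dim=-1) == 1).all())


# Made-up observation: 64 one-hot vectors of length 13.
obs = F.one_hot(torch.randint(13, (8, 8)), num_classes=13).float()
flat = obs.flatten()  # shape (832,), contains 64 ones

print(is_one_hot(obs))   # each length-13 vector sums to 1
print(is_one_hot(flat))  # the single length-832 vector sums to 64
print(is_binary(flat))   # every entry is still 0 or 1
```

A binary spec would accept the flattened tensor, though it is looser than the data: it also admits vectors that are not a concatenation of one-hot groups.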

@denizetkar

A bit late into the discussion but I had a similar problem, and I think the best solution would be to recognize the input spec as OneHotDiscreteTensorSpec in FlattenTransform and convert it to the corresponding MultiOneHotDiscreteTensorSpec.

What do you guys think?
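
To illustrate what a multi-one-hot view of the flattened tensor would mean, the sketch below splits the 832 entries into 64 groups of 13 and checks each group. The list `nvec` of group sizes is an assumption made for this example, and the code is plain PyTorch, not the conversion logic proposed for the transform.

```python
import torch
import torch.nn.functional as F

# Assumed group sizes: 64 groups, each with 13 classes.
nvec = [13] * 64

# Made-up observation, fully flattened to 832 entries.
obs = F.one_hot(torch.randint(13, (8, 8)), num_classes=13).float()
flat = obs.flatten()

# Split the last axis back into its groups and check each one separately.
groups = torch.split(flat, nvec, dim=-1)
groups_are_one_hot = all(bool(g.sum() == 1) for g in groups)
```

Unlike a binary description, this keeps the constraint that exactly one entry per group of 13 is set.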
