[BUG] FlattenObservation transform with OneHotDiscreteTensorSpec #1904

Comments
Thanks for reporting this. May I ask what [...]?
Got it, I guess it was observation.

The issue is that you're trying to flatten a one-hot spec with 13 possible values, which does not make sense anymore once it has shape 8 * 8 * 13 (it won't be one-hot anymore). It works if you flatten with

FlattenObservation(
    first_dim=0,
    last_dim=-2,
    in_keys=["observation"],
    allow_positive_dim=True,
),

in which case your observation has shape (64, 13). You could also use a categorical encoding:

__all__ = ["_CustomEnv"]
import typing as ty

import torch
from tensordict import TensorDict
from torchrl.data import (
    CompositeSpec,
    UnboundedContinuousTensorSpec,
    BinaryDiscreteTensorSpec,
    DiscreteTensorSpec,
)
from torchrl.envs import EnvBase, TransformedEnv, Compose, FlattenObservation, StepCounter

WORST_REWARD = -1e6


class _CustomEnv(EnvBase):
    """Custom dummy environment."""

    def __init__(self, **kwargs: ty.Any) -> None:
        super().__init__(**kwargs)  # call the constructor of the base class
        # Action: a single categorical value with 10 possible choices
        self.action_spec = DiscreteTensorSpec(
            n=10,
            shape=(),
            device=self.device,
            dtype=torch.float32,
        )
        # Observation space: an 8x8 grid of categorical values in [0, 13)
        observation_spec = DiscreteTensorSpec(
            n=13,
            shape=(8, 8),
            device=self.device,
            dtype=torch.float32,
        )
        self.observation_spec = CompositeSpec(observation=observation_spec)
        # Unbounded reward space
        self.reward_spec = UnboundedContinuousTensorSpec(
            shape=torch.Size([1]),
            device=self.device,
            dtype=torch.float32,
        )
        # Done flag
        self.done_spec = BinaryDiscreteTensorSpec(
            n=1,
            shape=torch.Size([1]),
            device=self.device,
            dtype=torch.bool,
        )

    def _reset(self, tensordict: TensorDict = None, **kwargs: ty.Any) -> TensorDict:
        """The `_reset()` method potentially takes in a `TensorDict` and some kwargs,
        which may contain data used in the resetting of the environment, and returns
        a new `TensorDict` with an initial observation of the environment.

        The output `TensorDict` has to be new because the input tensordict is immutable.

        Args:
            tensordict (TensorDict): Immutable input.

        Returns:
            TensorDict: Initial state.
        """
        # Return a new TensorDict with a random initial observation
        return TensorDict(
            {
                "observation": torch.randint(
                    13,
                    (8, 8),
                    dtype=self.observation_spec.dtype,
                    device=self.device,
                ),
                "reward": torch.Tensor([0]).to(self.reward_spec.dtype).to(self.device),
                "done": False,
            },
            batch_size=torch.Size(),
            device=self.device,
        )

    def _step(self, tensordict: TensorDict) -> TensorDict:
        """The `_step()` method takes in a `TensorDict` from which it reads an action,
        applies the action and returns a new `TensorDict` containing the observation,
        reward and done signal for that timestep.

        Args:
            tensordict (TensorDict): Input state containing the action to apply.

        Returns:
            TensorDict: Next observation, reward and done signal.
        """
        # Return a new TensorDict with a random next observation
        return TensorDict(
            {
                "observation": torch.randint(
                    13,
                    (8, 8),
                    dtype=self.observation_spec.dtype,
                    device=self.device,
                ),
                "reward": torch.Tensor([0]).to(self.reward_spec.dtype).to(self.device),
                "done": True,
            },
            batch_size=torch.Size(),
            device=self.device,
        )

    def _set_seed(self, seed: int):
        # Dummy env: nothing to seed
        return seed


env = TransformedEnv(
    _CustomEnv(),
    transform=Compose(
        FlattenObservation(
            first_dim=0,
            last_dim=-1,
            in_keys=["observation"],
            allow_positive_dim=True,
        ),
        StepCounter(),
    ),
)
print(env.rollout(3))
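For intuition about why flattening into the one-hot dimension breaks the spec, here is a minimal sketch in plain torch (an illustration added here, not part of the original comment): flattening only the leading dimensions keeps each row a valid one-hot vector, while flattening everything does not.

import torch
import torch.nn.functional as F

# An (8, 8) board of categorical values in [0, 13), one-hot encoded to (8, 8, 13).
board = torch.randint(13, (8, 8))
one_hot = F.one_hot(board, num_classes=13)  # shape (8, 8, 13)

# Flattening only the leading dims keeps the one-hot axis intact: (64, 13),
# every row still sums to exactly 1.
flat_keep = one_hot.flatten(0, -2)
print(flat_keep.shape, flat_keep.sum(-1).unique())  # torch.Size([64, 13]) tensor([1])

# Flattening everything folds the one-hot axis in: (832,), which is no longer
# a stack of one-hot vectors, so a one-hot spec of n=13 cannot describe it.
flat_all = one_hot.flatten(0, -1)
print(flat_all.shape)  # torch.Size([832])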
Nice! Great answer. This worked, thanks. Perhaps there's not much to do here then, besides having a better error message, so that people in the future won't need to open an issue here :)
I'm closing this as not planned because the exception is hard to catch, unfortunately :/
Sad! Well, I produced it by flattening a set of one-hot tensors, for example from [...]. But I believe the intended way to do this is to prepend a [...].
Right, I can see why you'd want to do that.
Oh, I see what you are thinking about: converting to [...]. I think this would fix it; let's say it does. Then it is perhaps just a matter of whether [...]. What do you think?
A bit late into the discussion, but I had a similar problem, and I think the best solution would be to recognize the input spec as [...]. What do you guys think?
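As a sketch of the kind of check being discussed (purely illustrative, not existing torchrl internals; the helper name and its signature are hypothetical): a transform could detect that the input spec is one-hot and that the flatten range includes the trailing dimension, and raise a descriptive error instead of a shape mismatch further downstream.

from torchrl.data import OneHotDiscreteTensorSpec, TensorSpec


def _check_flatten_dims(spec: TensorSpec, first_dim: int, last_dim: int) -> None:
    """Hypothetical helper: refuse to flatten the one-hot dim of a one-hot spec."""
    ndim = len(spec.shape)
    last = last_dim if last_dim >= 0 else ndim + last_dim
    if isinstance(spec, OneHotDiscreteTensorSpec) and last == ndim - 1:
        raise ValueError(
            "FlattenObservation would flatten the one-hot dimension of a "
            "OneHotDiscreteTensorSpec; flatten up to last_dim=-2 instead, or "
            "use a categorical (DiscreteTensorSpec) encoding."
        )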
Describe the bug
Not sure this is a bug, but I am unable to use the FlattenObservation transform when the observation_spec = CompositeSpec(observation=OneHotDiscreteTensorSpec(...)).
To Reproduce
Create this env:
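A minimal sketch of such an env, assuming the same structure as the _CustomEnv from the discussion above but with a one-hot observation spec (the class name _OneHotCustomEnv and the use of spec.rand() are illustrative reconstructions, not the exact code from the report):

import torch
from tensordict import TensorDict
from torchrl.data import CompositeSpec, OneHotDiscreteTensorSpec


class _OneHotCustomEnv(_CustomEnv):  # hypothetical name; reuses _CustomEnv from above
    """Same dummy env, but the observation is a one-hot tensor of shape (8, 8, 13)."""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.observation_spec = CompositeSpec(
            observation=OneHotDiscreteTensorSpec(
                n=13,
                shape=(8, 8, 13),  # the trailing dim is the one-hot dim
                device=self.device,
                dtype=torch.float32,
            )
        )

    def _rand_obs(self) -> torch.Tensor:
        # Sample a valid one-hot observation directly from the spec.
        return self.observation_spec["observation"].rand()

    def _reset(self, tensordict: TensorDict = None, **kwargs) -> TensorDict:
        return TensorDict(
            {
                "observation": self._rand_obs(),
                "reward": torch.zeros(1, dtype=self.reward_spec.dtype, device=self.device),
                "done": False,
            },
            batch_size=torch.Size(),
            device=self.device,
        )

    def _step(self, tensordict: TensorDict) -> TensorDict:
        return TensorDict(
            {
                "observation": self._rand_obs(),
                "reward": torch.zeros(1, dtype=self.reward_spec.dtype, device=self.device),
                "done": True,
            },
            batch_size=torch.Size(),
            device=self.device,
        )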
Then, transform it:
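And the failing transform, again as a sketch assuming the hypothetical _OneHotCustomEnv above (the key point being last_dim=-1, which folds the one-hot dimension into the flattened shape):

from torchrl.envs import Compose, FlattenObservation, StepCounter, TransformedEnv

env = TransformedEnv(
    _OneHotCustomEnv(),  # hypothetical reconstruction, see the sketch above
    transform=Compose(
        FlattenObservation(
            first_dim=0,
            last_dim=-1,  # flattens the one-hot dim too, which triggers the error
            in_keys=["observation"],
            allow_positive_dim=True,
        ),
        StepCounter(),
    ),
)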
Then use it anywhere; even just printing will raise an error complaining that the observation has the wrong size.
Calling env.reset() will raise the same error, too.

Expected behavior
TorchRL should not complain that the observation has the wrong size; we want it to have the "wrong" size precisely because we want to flatten it.
System info
Describe the characteristic of your environment:
Reason and Possible fixes
Not sure, but I'd be happy to work on one.
Checklist