
Commit

Update
[ghstack-poisoned]
vmoens committed Oct 9, 2024
1 parent 5351230 commit 784d024
Showing 2 changed files with 81 additions and 41 deletions.
71 changes: 45 additions & 26 deletions test/test_transforms.py
@@ -5114,7 +5114,9 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
pass

@pytest.mark.parametrize("has_in_keys,", [True, False])
@pytest.mark.parametrize("reset_keys,", [None, ["_reset"] * 3])
@pytest.mark.parametrize(
"reset_keys,", [[("some", "nested", "reset")], ["_reset"] * 3, None]
)
def test_trans_multi_key(
self, has_in_keys, reset_keys, n_workers=2, batch_size=(3, 2), max_steps=5
):
@@ -5136,9 +5138,9 @@ def test_trans_multi_key(
)
with pytest.raises(
ValueError, match="Could not match the env reset_keys"
) if reset_keys is None else contextlib.nullcontext():
) if reset_keys == [("some", "nested", "reset")] else contextlib.nullcontext():
check_env_specs(env)
if reset_keys is not None:
if reset_keys != [("some", "nested", "reset")]:
td = env.rollout(max_steps, policy=policy)
for reward_key in env.reward_keys:
reward_key = _unravel_key_to_tuple(reward_key)
@@ -9955,16 +9957,27 @@ def test_transform_inverse(self):


class TestDeviceCastTransformPart(TransformBase):
@pytest.fixture(scope="class")
def _cast_device(self):
if torch.cuda.is_available():
yield torch.device("cuda:0")
elif torch.backends.mps.is_available():
yield torch.device("mps:0")
else:
yield torch.device("cpu:1")

@pytest.mark.parametrize("in_keys", ["observation"])
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_single_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
def test_single_trans_env_check(
self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
):
env = ContinuousActionVecMockEnv(device="cpu:0")
env = TransformedEnv(
env,
DeviceCastTransform(
"cpu:1",
_cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -9978,12 +9991,14 @@ def test_single_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_i
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_serial_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
def test_serial_trans_env_check(
self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
):
def make_env():
return TransformedEnv(
ContinuousActionVecMockEnv(device="cpu:0"),
DeviceCastTransform(
"cpu:1",
_cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -10000,13 +10015,13 @@ def make_env():
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_parallel_trans_env_check(
self, in_keys, out_keys, in_keys_inv, out_keys_inv
self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
):
def make_env():
return TransformedEnv(
ContinuousActionVecMockEnv(device="cpu:0"),
DeviceCastTransform(
"cpu:1",
_cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -10032,14 +10047,16 @@ def make_env():
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_trans_serial_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
def test_trans_serial_env_check(
self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
):
def make_env():
return ContinuousActionVecMockEnv(device="cpu:0")

env = TransformedEnv(
SerialEnv(2, make_env),
DeviceCastTransform(
"cpu:1",
_cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -10054,7 +10071,7 @@ def make_env():
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_trans_parallel_env_check(
self, in_keys, out_keys, in_keys_inv, out_keys_inv
self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
):
def make_env():
return ContinuousActionVecMockEnv(device="cpu:0")
@@ -10066,7 +10083,7 @@ def make_env():
mp_start_method=mp_ctx if not torch.cuda.is_available() else "spawn",
),
DeviceCastTransform(
"cpu:1",
_cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -10082,8 +10099,8 @@ def make_env():
except RuntimeError:
pass

def test_transform_no_env(self):
t = DeviceCastTransform("cpu:1", "cpu:0", in_keys=["a"], out_keys=["b"])
def test_transform_no_env(self, _cast_device):
t = DeviceCastTransform(_cast_device, "cpu:0", in_keys=["a"], out_keys=["b"])
td = TensorDict({"a": torch.randn((), device="cpu:0")}, [], device="cpu:0")
tdt = t._call(td)
assert tdt.device is None
@@ -10092,26 +10109,28 @@ def test_transform_no_env(self):
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_transform_env(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
def test_transform_env(
self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
):
env = ContinuousActionVecMockEnv(device="cpu:0")
env = TransformedEnv(
env,
DeviceCastTransform(
"cpu:1",
_cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
out_keys_inv=out_keys_inv,
),
)
assert env.device is None
assert env.transform.device == torch.device("cpu:1")
assert env.transform.device == _cast_device
assert env.transform.orig_device == torch.device("cpu:0")

def test_transform_compose(self):
def test_transform_compose(self, _cast_device):
t = Compose(
DeviceCastTransform(
"cpu:1",
_cast_device,
"cpu:0",
in_keys=["a"],
out_keys=["b"],
@@ -10123,7 +10142,7 @@ def test_transform_compose(self):
td = TensorDict(
{
"a": torch.randn((), device="cpu:0"),
"c": torch.randn((), device="cpu:1"),
"c": torch.randn((), device=_cast_device),
},
[],
device="cpu:0",
@@ -10134,11 +10153,11 @@ def make_env():
assert tdt.device is None
assert tdit.device is None

def test_transform_model(self):
def test_transform_model(self, _cast_device):
t = nn.Sequential(
Compose(
DeviceCastTransform(
"cpu:1",
_cast_device,
"cpu:0",
in_keys=["a"],
out_keys=["b"],
Expand All @@ -10161,11 +10180,11 @@ def test_transform_model(self):

@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
@pytest.mark.parametrize("storage", [LazyTensorStorage])
def test_transform_rb(self, rbclass, storage):
def test_transform_rb(self, rbclass, storage, _cast_device):
# we don't test casting to cuda on Memmap tensor storage since it's discouraged
t = Compose(
DeviceCastTransform(
"cpu:1",
_cast_device,
"cpu:0",
in_keys=["a"],
out_keys=["b"],
@@ -10178,7 +10197,7 @@ def test_transform_rb(self, rbclass, storage):
td = TensorDict(
{
"a": torch.randn((), device="cpu:0"),
"c": torch.randn((), device="cpu:1"),
"c": torch.randn((), device=_cast_device),
},
[],
device="cpu:0",
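
For readers skimming the test diff, the core of the change is the device-selection pattern introduced by the new `_cast_device` fixture. Below is a minimal standalone sketch of that logic, outside pytest; the availability checks and device strings are taken directly from the fixture above, not from any other torchrl API.

    import torch

    def pick_cast_device() -> torch.device:
        # Mirrors the _cast_device fixture above: prefer CUDA, then Apple MPS,
        # and otherwise fall back to a second CPU device index ("cpu:1") so the
        # cast remains observable even on CPU-only machines.
        if torch.cuda.is_available():
            return torch.device("cuda:0")
        if torch.backends.mps.is_available():
            return torch.device("mps:0")
        return torch.device("cpu:1")
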
51 changes: 36 additions & 15 deletions torchrl/envs/transforms/transforms.py
@@ -3893,6 +3893,19 @@ class DeviceCastTransform(Transform):
a parent environment exists, it is retrieved from it. In all other cases,
it remains unspecified.
Keyword Args:
in_keys (list of NestedKey): the list of entries to map to a different device.
Defaults to ``None``.
out_keys (list of NestedKey): the output names of the entries mapped onto a device.
Defaults to the values of ``in_keys``.
in_keys_inv (list of NestedKey): the list of entries to map to a different device.
``in_keys_inv`` are the names expected by the base environment.
Defaults to ``None``.
out_keys_inv (list of NestedKey): the output names of the entries mapped onto a device.
``out_keys_inv`` are the names of the keys as seen from outside the transformed env.
Defaults to the values of ``in_keys_inv``.
Examples:
>>> td = TensorDict(
... {'obs': torch.ones(1, dtype=torch.double),
@@ -3920,6 +3933,10 @@ def __init__(
self.orig_device = (
torch.device(orig_device) if orig_device is not None else orig_device
)
if out_keys is None:
out_keys = copy(in_keys)
if out_keys_inv is None:
out_keys_inv = copy(in_keys_inv)
super().__init__(
in_keys=in_keys,
out_keys=out_keys,
@@ -4043,52 +4060,54 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
if self._map_env_device:
return input_spec.to(self.device)
else:
input_spec.clear_device_()
return super().transform_input_spec(input_spec)

def transform_action_spec(self, full_action_spec: Composite) -> Composite:
full_action_spec = full_action_spec.clear_device_()
for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if in_key not in full_action_spec.keys(True, True):
continue
full_action_spec[out_key] = full_action_spec[in_key].to(self.device)
local_action_spec = full_action_spec.get(in_key, None)
if local_action_spec is not None:
full_action_spec[out_key] = local_action_spec.to(self.device)
return full_action_spec

def transform_state_spec(self, full_state_spec: Composite) -> Composite:
full_state_spec = full_state_spec.clear_device_()
for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if in_key not in full_state_spec.keys(True, True):
continue
full_state_spec[out_key] = full_state_spec[in_key].to(self.device)
local_state_spec = full_state_spec.get(in_key, None)
if local_state_spec is not None:
full_state_spec[out_key] = local_state_spec.to(self.device)
return full_state_spec

def transform_output_spec(self, output_spec: Composite) -> Composite:
if self._map_env_device:
return output_spec.to(self.device)
else:
output_spec.clear_device_()
return super().transform_output_spec(output_spec)

def transform_observation_spec(self, observation_spec: Composite) -> Composite:
observation_spec = observation_spec.clear_device_()
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key not in observation_spec.keys(True, True):
continue
observation_spec[out_key] = observation_spec[in_key].to(self.device)
local_obs_spec = observation_spec.get(in_key, None)
if local_obs_spec is not None:
observation_spec[out_key] = local_obs_spec.to(self.device)
return observation_spec

def transform_done_spec(self, full_done_spec: Composite) -> Composite:
full_done_spec = full_done_spec.clear_device_()
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key not in full_done_spec.keys(True, True):
continue
full_done_spec[out_key] = full_done_spec[in_key].to(self.device)
local_done_spec = full_done_spec.get(in_key, None)
if local_done_spec is not None:
full_done_spec[out_key] = local_done_spec.to(self.device)
return full_done_spec

def transform_reward_spec(self, full_reward_spec: Composite) -> Composite:
full_reward_spec = full_reward_spec.clear_device_()
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key not in full_reward_spec.keys(True, True):
continue
full_reward_spec[out_key] = full_reward_spec[in_key].to(self.device)
local_reward_spec = full_reward_spec.get(in_key, None)
if local_reward_spec is not None:
full_reward_spec[out_key] = local_reward_spec.to(self.device)
return full_reward_spec

def transform_env_device(self, device):
@@ -5494,6 +5513,8 @@ def reset_keys(self):
# We take the filtered reset keys, which are the only keys that really
# matter when calling reset, and check that they match the in_keys root.
reset_keys = parent._filtered_reset_keys
if len(reset_keys) == 1:
reset_keys = list(reset_keys) * len(self.in_keys)

def _check_match(reset_keys, in_keys):
# if this is called, the length of reset_keys and in_keys must match
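
As a quick orientation, here is a minimal usage sketch of the partial device cast these tests exercise. It follows the constructor signature and the `_call` usage shown in `test_transform_no_env` above; the import path from `torchrl.envs` is an assumption based on how the rest of the suite imports transforms.

    import torch
    from tensordict import TensorDict
    from torchrl.envs import DeviceCastTransform  # assumed import path

    # Cast only the listed key from cpu:0 to cpu:1; all other entries are untouched.
    t = DeviceCastTransform("cpu:1", "cpu:0", in_keys=["a"], out_keys=["b"])
    td = TensorDict({"a": torch.randn((), device="cpu:0")}, [], device="cpu:0")
    tdt = t._call(td)
    # As asserted in test_transform_no_env: the output now holds keys on mixed
    # devices, so the tensordict-level device is cleared.
    assert tdt.device is None
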
