Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 25, 2024
2 parents 102003e + 3e99960 commit 8ec6725
Show file tree
Hide file tree
Showing 11 changed files with 362 additions and 51 deletions.
92 changes: 80 additions & 12 deletions benchmarks/test_objectives_benchmarks.py

Large diffs are not rendered by default.

13 changes: 12 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,18 @@ def _main(argv):
"dm_control": ["dm_control"],
"gym_continuous": ["gymnasium<1.0", "mujoco"],
"rendering": ["moviepy<2.0.0"],
"tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"],
"tests": [
"pytest",
"pyyaml",
"pytest-instafail",
"scipy",
"pytest-mock",
"pytest-cov",
"pytest-benchmark",
"pytest-rerunfailures",
"pytest-error-for-skips",
"",
],
"utils": [
"tensorboard",
"wandb",
Expand Down
119 changes: 119 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -4459,6 +4459,69 @@ def test_sac_notensordict(
assert loss_actor == loss_val_td["loss_actor"]
assert loss_alpha == loss_val_td["loss_alpha"]

@pytest.mark.parametrize("action_key", ["action", "action2"])
@pytest.mark.parametrize("observation_key", ["observation", "observation2"])
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
@pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
def test_sac_terminating(
self, action_key, observation_key, reward_key, done_key, terminated_key, version
):
torch.manual_seed(self.seed)
td = self._create_mock_data_sac(
action_key=action_key,
observation_key=observation_key,
reward_key=reward_key,
done_key=done_key,
terminated_key=terminated_key,
)

actor = self._create_mock_actor(
observation_key=observation_key, action_key=action_key
)
qvalue = self._create_mock_qvalue(
observation_key=observation_key,
action_key=action_key,
out_keys=["state_action_value"],
)
if version == 1:
value = self._create_mock_value(observation_key=observation_key)
else:
value = None

loss = SACLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
)
loss.set_keys(
action=action_key,
reward=reward_key,
done=done_key,
terminated=terminated_key,
)

torch.manual_seed(self.seed)

SoftUpdate(loss, eps=0.5)

done = td.get(("next", done_key))
while not (done.any() and not done.all()):
done.bernoulli_(0.1)
obs_nan = td.get(("next", terminated_key))
obs_nan[done.squeeze(-1)] = float("nan")

kwargs = {
action_key: td.get(action_key),
observation_key: td.get(observation_key),
f"next_{reward_key}": td.get(("next", reward_key)),
f"next_{done_key}": done,
f"next_{terminated_key}": obs_nan,
f"next_{observation_key}": td.get(("next", observation_key)),
}
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
assert loss(td).isfinite().all()

def test_state_dict(self, version):
if version == 1:
pytest.skip("Test not implemented for version 1.")
Expand Down Expand Up @@ -5112,6 +5175,62 @@ def test_discrete_sac_notensordict(
assert loss_actor == loss_val_td["loss_actor"]
assert loss_alpha == loss_val_td["loss_alpha"]

@pytest.mark.parametrize("action_key", ["action", "action2"])
@pytest.mark.parametrize("observation_key", ["observation", "observation2"])
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
@pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
def test_discrete_sac_terminating(
self, action_key, observation_key, reward_key, done_key, terminated_key
):
torch.manual_seed(self.seed)
td = self._create_mock_data_sac(
action_key=action_key,
observation_key=observation_key,
reward_key=reward_key,
done_key=done_key,
terminated_key=terminated_key,
)

actor = self._create_mock_actor(
observation_key=observation_key, action_key=action_key
)
qvalue = self._create_mock_qvalue(
observation_key=observation_key,
)

loss = DiscreteSACLoss(
actor_network=actor,
qvalue_network=qvalue,
num_actions=actor.spec[action_key].space.n,
action_space="one-hot",
)
loss.set_keys(
action=action_key,
reward=reward_key,
done=done_key,
terminated=terminated_key,
)

SoftUpdate(loss, eps=0.5)

torch.manual_seed(0)
done = td.get(("next", done_key))
while not (done.any() and not done.all()):
done = done.bernoulli_(0.1)
obs_none = td.get(("next", observation_key))
obs_none[done.squeeze(-1)] = float("nan")
kwargs = {
action_key: td.get(action_key),
observation_key: td.get(observation_key),
f"next_{reward_key}": td.get(("next", reward_key)),
f"next_{done_key}": done,
f"next_{terminated_key}": td.get(("next", terminated_key)),
f"next_{observation_key}": obs_none,
}
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
assert loss(td).isfinite().all()

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_discrete_sac_reduction(self, reduction):
torch.manual_seed(self.seed)
Expand Down
16 changes: 6 additions & 10 deletions test/test_rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,10 @@ def test_tensordict_tokenizer(
"Lettuce in, it's cold out here!",
]
}
if not truncation and return_tensordict and max_length == 10:
with pytest.raises(ValueError, match="TensorDict conversion only supports"):
out = process(example)
return
out = process(example)
if return_tensordict:
if not truncation and return_tensordict and max_length == 10:
assert out.get("input_ids").shape[-1] == -1
elif return_tensordict:
assert out.get("input_ids").shape[-1] == max_length
else:
obj = out.get("input_ids")
Expand Down Expand Up @@ -346,12 +344,10 @@ def test_prompt_tensordict_tokenizer(
],
"label": ["right", "wrong", "right", "wrong", "right"],
}
if not truncation and return_tensordict and max_length == 10:
with pytest.raises(ValueError, match="TensorDict conversion only supports"):
out = process(example)
return
out = process(example)
if return_tensordict:
if not truncation and return_tensordict and max_length == 10:
assert out.get("input_ids").shape[-1] == -1
elif return_tensordict:
assert out.get("input_ids").shape[-1] == max_length
else:
obj = out.get("input_ids")
Expand Down
81 changes: 79 additions & 2 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4265,6 +4265,14 @@ def __new__(cls, *args, **kwargs):
cls._locked = False
return super().__new__(cls)

@property
def batch_size(self):
return self._shape

@batch_size.setter
def batch_size(self, value: torch.Size):
self._shape = value

@property
def shape(self):
return self._shape
Expand All @@ -4286,8 +4294,22 @@ def shape(self, value: torch.Size):
)
self._shape = _size(value)

def is_empty(self):
"""Whether the composite spec contains specs or not."""
def is_empty(self, recurse: bool = False):
"""Whether the composite spec contains specs or not.
Args:
recurse (bool): whether to recursively assess if the spec is empty.
If ``True``, will return ``True`` if there are no leaves. If ``False``
(default) will return whether there is any spec defined at the root level.
"""
if recurse:
for spec in self._specs.values():
if spec is None:
continue
if isinstance(spec, Composite) and spec.is_empty(recurse=True):
continue
return False
return len(self._specs) == 0

@property
Expand All @@ -4297,6 +4319,61 @@ def ndim(self):
def ndimension(self):
return len(self.shape)

def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> Any:
"""Removes and returns the value associated with the specified key from the composite spec.
This method searches for the given key in the composite spec, removes it, and returns its associated value.
If the key is not found, it returns the provided default value if specified, otherwise raises a `KeyError`.
Args:
key (NestedKey):
The key to be removed from the composite spec. It can be a single key or a nested key.
default (Any, optional):
The value to return if the specified key is not found in the composite spec.
If not provided and the key is not found, a `KeyError` is raised.
Returns:
Any: The value associated with the specified key that was removed from the composite spec.
Raises:
KeyError: If the specified key is not found in the composite spec and no default value is provided.
"""
key = unravel_key(key)
if key in self.keys(True, True):
result = self[key]
del self[key]
return result
elif default is not NO_DEFAULT:
return default
raise KeyError(f"{key} not found in composite spec.")

def separates(self, *keys: NestedKey, default: Any = None) -> Composite:
"""Splits the composite spec by extracting specified keys and their associated values into a new composite spec.
This method iterates over the provided keys, removes them from the current composite spec, and adds them to a new
composite spec. If a key is not found, the specified default value is used. The new composite spec is returned.
Args:
*keys (NestedKey):
One or more keys to be extracted from the composite spec. Each key can be a single key or a nested key.
default (Any, optional):
The value to use if a specified key is not found in the composite spec. Defaults to `None`.
Returns:
Composite: A new composite spec containing the extracted keys and their associated values.
Note:
If none of the specified keys are found, the method returns `None`.
"""
out = None
for key in keys:
result = self.pop(key, default=default)
if result is not None:
if out is None:
out = Composite(batch_size=self.batch_size, device=self.device)
out[key] = result
return out

def set(self, name, spec):
if self.locked:
raise RuntimeError("Cannot modify a locked Composite.")
Expand Down
Loading

0 comments on commit 8ec6725

Please sign in to comment.