Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 25, 2024
1 parent 1d96ae4 commit e10b734
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def _main(argv):
"dm_control": ["dm_control"],
"gym_continuous": ["gymnasium<1.0", "mujoco"],
"rendering": ["moviepy<2.0.0"],
"tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"],
"tests": ["pytest", "pyyaml", "pytest-instafail", "scipy", "pytest-mock", "pytest-cov", "pytest-benchmark", "pytest-rerunfailures", "pytest-error-for-skips", ""],
"utils": [
"tensorboard",
"wandb",
Expand Down
16 changes: 6 additions & 10 deletions test/test_rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,10 @@ def test_tensordict_tokenizer(
"Lettuce in, it's cold out here!",
]
}
if not truncation and return_tensordict and max_length == 10:
with pytest.raises(ValueError, match="TensorDict conversion only supports"):
out = process(example)
return
out = process(example)
if return_tensordict:
if not truncation and return_tensordict and max_length == 10:
assert out.get("input_ids").shape[-1] == -1
elif return_tensordict:
assert out.get("input_ids").shape[-1] == max_length
else:
obj = out.get("input_ids")
Expand Down Expand Up @@ -346,12 +344,10 @@ def test_prompt_tensordict_tokenizer(
],
"label": ["right", "wrong", "right", "wrong", "right"],
}
if not truncation and return_tensordict and max_length == 10:
with pytest.raises(ValueError, match="TensorDict conversion only supports"):
out = process(example)
return
out = process(example)
if return_tensordict:
if not truncation and return_tensordict and max_length == 10:
assert out.get("input_ids").shape[-1] == -1
elif return_tensordict:
assert out.get("input_ids").shape[-1] == max_length
else:
obj = out.get("input_ids")
Expand Down

0 comments on commit e10b734

Please sign in to comment.