Skip to content

Commit

Permalink
[BugFix] Fix listing of updated keys in collectors (#2460)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Oct 1, 2024
1 parent 1858bea commit 97ccbb7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
10 changes: 7 additions & 3 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2980,7 +2980,7 @@ def make_and_test_policy(
envs,
policy=policy,
total_frames=1000,
frames_per_batch=100,
frames_per_batch=10,
policy_device=policy_device,
env_device=env_device,
device=device,
Expand Down Expand Up @@ -3044,11 +3044,15 @@ def make_and_test_policy(
make_and_test_policy(policy, env_device=original_device)

# If the policy is a CudaGraphModule, we know it's on cuda - no need to warn
if torch.cuda.is_available():
if torch.cuda.is_available() and collector_type is SyncDataCollector:
with pytest.warns(UserWarning, match="Tensordict is registered in PyTree"):
policy = make_policy(original_device)
cudagraph_policy = CudaGraphModule(policy)
make_and_test_policy(cudagraph_policy, policy_device=original_device)
make_and_test_policy(
cudagraph_policy,
policy_device=original_device,
env_device=shared_device,
)


if __name__ == "__main__":
Expand Down
13 changes: 8 additions & 5 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,11 +832,13 @@ def check_exclusive(val):
# changed them here).
# This will cause a failure to update entries when policy and env device mismatch and
# casting is necessary.
def filter_policy(value_output, value_input, value_input_clone):
if (
(value_input is None)
or (value_output is not value_input)
or ~torch.isclose(value_output, value_input_clone).any()
def filter_policy(name, value_output, value_input, value_input_clone):
if (value_input is None) or (
(value_output is not value_input)
and (
value_output.device != value_input_clone.device
or ~torch.isclose(value_output, value_input_clone).any()
)
):
return value_output

Expand All @@ -846,6 +848,7 @@ def filter_policy(value_output, value_input, value_input_clone):
policy_input_clone,
default=None,
filter_empty=True,
named=True,
)
self._policy_output_keys = list(
self._policy_output_keys.union(
Expand Down

0 comments on commit 97ccbb7

Please sign in to comment.