Skip to content

Commit

Permalink
[BugFix] Fix device of container generated values in transforms (#1827)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jan 22, 2024
1 parent 3f04131 commit 55ec016
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5151,6 +5151,8 @@ def _reset(
step_count = tensordict.get(step_count_key, default=None)
if step_count is None:
step_count = self.container.observation_spec[step_count_key].zero()
if step_count.device != reset.device:
step_count = step_count.to(reset.device, non_blocking=True)

# zero the step count if reset is needed
step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0)
Expand Down Expand Up @@ -6413,7 +6415,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
raise ValueError(
self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec))
)
action_spec.update_mask(mask)
action_spec.update_mask(mask.to(action_spec.device))
return tensordict

def _reset(
Expand All @@ -6424,7 +6426,10 @@ def _reset(
raise ValueError(
self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec))
)
action_spec.update_mask(tensordict.get(self.in_keys[1], None))
mask = tensordict.get(self.in_keys[1], None)
if mask is not None:
mask = mask.to(action_spec.device)
action_spec.update_mask(mask)

# TODO: Check that this makes sense
with _set_missing_tolerance(self, True):
Expand Down

2 comments on commit 55ec016

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 55ec016 Previous: 7b21e52 Ratio
benchmarks/test_objectives_benchmarks.py::test_dqn_speed 133.11258439843198 iter/sec (stddev: 0.0009929794650010305) 535.9056817308965 iter/sec (stddev: 0.00038493854065690963) 4.03
benchmarks/test_objectives_benchmarks.py::test_ddpg_speed 69.31635481741358 iter/sec (stddev: 0.0006998807802134125) 285.1020505805778 iter/sec (stddev: 0.0005619492767101129) 4.11
benchmarks/test_objectives_benchmarks.py::test_sac_speed 33.8450787841355 iter/sec (stddev: 0.0015474214657498202) 94.04144234196644 iter/sec (stddev: 0.0008278346101444233) 2.78
benchmarks/test_objectives_benchmarks.py::test_redq_speed 22.26250296065981 iter/sec (stddev: 0.0009597197888727151) 58.01109386044425 iter/sec (stddev: 0.0007185355348622557) 2.61
benchmarks/test_objectives_benchmarks.py::test_cql_speed 11.400037568823802 iter/sec (stddev: 0.0009102490007192575) 30.10759966629096 iter/sec (stddev: 0.0022295479867929066) 2.64
benchmarks/test_objectives_benchmarks.py::test_a2c_speed 37.62635043581897 iter/sec (stddev: 0.001458144806372114) 137.4074793786421 iter/sec (stddev: 0.0009561488809686526) 3.65
benchmarks/test_objectives_benchmarks.py::test_ppo_speed 37.05785567804892 iter/sec (stddev: 0.0007563101487669529) 133.16456956894947 iter/sec (stddev: 0.0005391927674600333) 3.59
benchmarks/test_objectives_benchmarks.py::test_reinforce_speed 38.78001081002825 iter/sec (stddev: 0.0006305962214068141) 184.30364872008184 iter/sec (stddev: 0.0007618179358653057) 4.75
benchmarks/test_objectives_benchmarks.py::test_iql_speed 15.814734731671571 iter/sec (stddev: 0.0021633193784848146) 37.378933948733426 iter/sec (stddev: 0.0017055008100570284) 2.36

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 55ec016 Previous: c96227a Ratio
benchmarks/test_objectives_benchmarks.py::test_values[td1_return_estimate-False-False] 18.64255690903953 iter/sec (stddev: 0.00013905361523456646) 74.75725850926985 iter/sec (stddev: 0.00013471599660878715) 4.01
benchmarks/test_objectives_benchmarks.py::test_values[td_lambda_return_estimate-True-False] 11.796465464704035 iter/sec (stddev: 0.0012667642584596774) 31.23680299738478 iter/sec (stddev: 0.00037885553338176436) 2.65
benchmarks/test_objectives_benchmarks.py::test_dqn_speed 135.62554338626444 iter/sec (stddev: 0.00007936252528534101) 613.6111725454206 iter/sec (stddev: 0.00027172557791339716) 4.52
benchmarks/test_objectives_benchmarks.py::test_ddpg_speed 69.28730556361555 iter/sec (stddev: 0.00019571181684099073) 362.13043770984643 iter/sec (stddev: 0.0002507883812106999) 5.23
benchmarks/test_objectives_benchmarks.py::test_sac_speed 34.32035628432647 iter/sec (stddev: 0.00023460000299078233) 119.2734990010058 iter/sec (stddev: 0.00032456047032026244) 3.48
benchmarks/test_objectives_benchmarks.py::test_redq_speed 20.995643425047042 iter/sec (stddev: 0.0005643656703533141) 65.1519127201971 iter/sec (stddev: 0.010589940721970557) 3.10
benchmarks/test_objectives_benchmarks.py::test_td3_speed 50.56676985264592 iter/sec (stddev: 0.000917379146741299) 103.50766122687905 iter/sec (stddev: 0.00020496404646321763) 2.05
benchmarks/test_objectives_benchmarks.py::test_cql_speed 12.10124030759098 iter/sec (stddev: 0.0004095658815692692) 37.45782352839336 iter/sec (stddev: 0.00043570247630634137) 3.10
benchmarks/test_objectives_benchmarks.py::test_a2c_speed 37.96887410430874 iter/sec (stddev: 0.0002527496907725071) 182.33459735190257 iter/sec (stddev: 0.0003460794900758151) 4.80
benchmarks/test_objectives_benchmarks.py::test_ppo_speed 37.52007790730031 iter/sec (stddev: 0.00029315157400287964) 170.58806038161623 iter/sec (stddev: 0.0002479578253287511) 4.55
benchmarks/test_objectives_benchmarks.py::test_reinforce_speed 38.90255081058404 iter/sec (stddev: 0.0002114853016190441) 231.59253813756877 iter/sec (stddev: 0.00020121137865884492) 5.95
benchmarks/test_objectives_benchmarks.py::test_iql_speed 17.508464496280975 iter/sec (stddev: 0.000469554301362209) 46.76456396119653 iter/sec (stddev: 0.000517310042232497) 2.67

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.