Skip to content

Commit

Permalink
fix loss test case for batch size variation (Lightning-Universe#402)
Browse files Browse the repository at this point in the history
  • Loading branch information
sidhantls authored and chris-clem committed Dec 9, 2020
1 parent 1bd9ae7 commit a4d7e01
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/models/rl/unit/test_reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def setUp(self) -> None:
def test_loss(self):
"""Test the reinforce loss function"""

batch_states = torch.rand(32, 4)
batch_actions = torch.rand(32).long()
batch_qvals = torch.rand(32)
batch_states = torch.rand(16, 4)
batch_actions = torch.rand(16).long()
batch_qvals = torch.rand(16)

loss = self.model.loss(batch_states, batch_actions, batch_qvals)

Expand Down

0 comments on commit a4d7e01

Please sign in to comment.