Skip to content

Commit

Permalink
polish style
Browse files Browse the repository at this point in the history
  • Loading branch information
karroyan committed Feb 8, 2023
1 parent d9c6566 commit d1df66c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
6 changes: 4 additions & 2 deletions ding/model/template/procedure_cloning.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@

class Block(nn.Module):

def __init__(self, cnn_hidden: int, att_hidden: int, att_heads: int, drop_p: float, max_T: int, n_att: int, \
feedforward_hidden: int, n_feedforward: int) -> None:
def __init__(
self, cnn_hidden: int, att_hidden: int, att_heads: int, drop_p: float, max_T: int, n_att: int,
feedforward_hidden: int, n_feedforward: int
) -> None:
super().__init__()
self.n_att = n_att
self.n_feedforward = n_feedforward
Expand Down
7 changes: 5 additions & 2 deletions ding/model/template/tests/test_procedure_cloning.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
class TestProcedureCloning:

def test_procedure_cloning(self, obs_shape, action_dim):
inputs = {'states': torch.randn(B, *obs_shape), 'goals': torch.randn(B, *obs_shape),\
'actions': torch.randn(B, T, action_dim)}
inputs = {
'states': torch.randn(B, *obs_shape),
'goals': torch.randn(B, *obs_shape),
'actions': torch.randn(B, T, action_dim)
}
model = ProcedureCloning(obs_shape=obs_shape, action_dim=action_dim)

print(model)
Expand Down

0 comments on commit d1df66c

Please sign in to comment.