Skip to content

Commit

Permalink
Hope this is better
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasw21 committed Jul 4, 2022
1 parent 8ddfba9 commit cc465ca
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def merge_layers(layers, num_heads: int, hidden_size: int):
return torch.reshape(
torch.cat(
[
layer.view(num_heads, hidden_size // num_heads, 1)
layer.view(num_heads, 1, hidden_size // num_heads)
for layer in layers
],
dim=1
Expand All @@ -124,7 +124,7 @@ def merge_layers(layers, num_heads: int, hidden_size: int):
return torch.reshape(
torch.cat(
[
layer.view(num_heads, hidden_size // num_heads, 1, hidden_size)
layer.view(num_heads, 1, hidden_size // num_heads, hidden_size)
for layer in layers
],
dim=1
Expand Down

0 comments on commit cc465ca

Please sign in to comment.