Skip to content

Commit

Permalink
Fix torch.export issue in dpt based models
Browse files Browse the repository at this point in the history
Signed-off-by: Phillip Kuznetsov <[email protected]>
  • Loading branch information
philkuz committed Oct 12, 2024
1 parent 617b212 commit aa7d562
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/transformers/models/dpt/modeling_dpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,11 +689,16 @@ def forward(self, hidden_states):
hidden_states = hidden_states[::-1]

fused_hidden_states = []
# first layer only uses the last hidden_state
fused_hidden_state = self.layers[0](hidden_states[0])
fused_hidden_states.append(fused_hidden_state)
fused_hidden_state = None
# looping from the last layer to the second
for hidden_state, layer in zip(hidden_states[1:], self.layers[1:]):
for idx, (hidden_state, layer) in enumerate(
zip(hidden_states, self.layers)
):
if idx == 0:
# first layer only uses the last hidden_state
fused_hidden_state = layer(hidden_state)
fused_hidden_states.append(fused_hidden_state)
continue
fused_hidden_state = layer(fused_hidden_state, hidden_state)
fused_hidden_states.append(fused_hidden_state)

Expand Down

0 comments on commit aa7d562

Please sign in to comment.