Small change to Wav2Vec2 model to support Tensor-Parallelism with DeepSpeed (#14298)

* minor modification to the wav2vec2 modeling file to support tensor-parallelism with DeepSpeed on this HuggingFace model

* refine the comments

* synch changes

* fix comments

* refine comments

* fix format
RezaYazdaniAminabadi authored Nov 9, 2021
1 parent d0e96c6 commit a503012
Showing 15 changed files with 90 additions and 30 deletions.
8 changes: 6 additions & 2 deletions src/transformers/models/bart/modeling_bart.py
@@ -171,7 +171,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -257,7 +258,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
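The same two-line change is propagated to every attention class that is copied from `BartAttention`. The reason for switching the reshape from the `embed_dim` unpacked off the input to `self.embed_dim` is that, under DeepSpeed tensor-parallelism, the module's attributes (`num_heads`, `embed_dim`) are rescaled to the per-GPU shard while the incoming `hidden_states` are replicated at the full hidden size, so the width read off the input no longer matches the sharded `attn_output`. The snippet below is a minimal standalone sketch of that mismatch, not code from the PR; the names (`world_size`, `local_heads`, `local_embed_dim`) and the assumption that the injection divides `embed_dim` by the parallel degree are illustrative.

```python
# Minimal sketch (not from the PR): simulate one rank of a hypothetical 2-way
# tensor-parallel shard to show why the reshape must use the per-rank width.
import torch

bsz, tgt_len = 2, 5
embed_dim, num_heads = 16, 4            # full model dimensions from the config
world_size = 2                          # assumed tensor-parallel degree
head_dim = embed_dim // num_heads
local_heads = num_heads // world_size   # heads owned by this rank
local_embed_dim = embed_dim // world_size

# The input hidden states are replicated, so unpacking their size still gives the full width.
hidden_states = torch.randn(bsz, tgt_len, embed_dim)
_, _, width_from_input = hidden_states.size()   # 16 -- what the old code used

# The sharded q/k/v projections only produce this rank's heads, so the local
# attention output is (bsz, local_heads, tgt_len, head_dim).
attn_output = torch.randn(bsz, local_heads, tgt_len, head_dim).transpose(1, 2)

try:
    # Old reshape: width taken from the input -> fails on a shard (80 elements vs. 160 expected).
    attn_output.reshape(bsz, tgt_len, width_from_input)
except RuntimeError as err:
    print("reshape with the full embed_dim fails on a shard:", err)

# New reshape: `self.embed_dim`, rescaled per rank by the injection, matches the local tensor.
print(attn_output.reshape(bsz, tgt_len, local_embed_dim).shape)  # torch.Size([2, 5, 8])
```

In the unsharded case the two widths coincide, so single-GPU behaviour is unchanged.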
8 changes: 6 additions & 2 deletions src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -1244,7 +1244,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -1330,7 +1331,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
8 changes: 6 additions & 2 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -173,7 +173,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -259,7 +260,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
8 changes: 6 additions & 2 deletions src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -171,7 +171,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -257,7 +258,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
8 changes: 6 additions & 2 deletions src/transformers/models/hubert/modeling_hubert.py
@@ -388,7 +388,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -474,7 +475,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
8 changes: 6 additions & 2 deletions src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -242,7 +242,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -328,7 +329,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
8 changes: 6 additions & 2 deletions src/transformers/models/marian/modeling_marian.py
@@ -188,7 +188,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -274,7 +275,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
8 changes: 6 additions & 2 deletions src/transformers/models/mbart/modeling_mbart.py
@@ -177,7 +177,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -263,7 +264,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
8 changes: 6 additions & 2 deletions src/transformers/models/pegasus/modeling_pegasus.py
@@ -188,7 +188,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -274,7 +275,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
8 changes: 6 additions & 2 deletions src/transformers/models/sew/modeling_sew.py
@@ -387,7 +387,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -473,7 +474,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
8 changes: 6 additions & 2 deletions src/transformers/models/speech_to_text/modeling_speech_to_text.py
@@ -255,7 +255,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -341,7 +342,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
8 changes: 6 additions & 2 deletions src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
@@ -195,7 +195,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -281,7 +282,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
8 changes: 6 additions & 2 deletions src/transformers/models/unispeech/modeling_unispeech.py
@@ -456,7 +456,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -542,7 +543,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
8 changes: 6 additions & 2 deletions src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -457,7 +457,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -543,7 +544,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
8 changes: 6 additions & 2 deletions src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -493,7 +493,8 @@ def forward(
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
-        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        bsz, tgt_len, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
@@ -579,7 +580,10 @@ def forward(
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

         attn_output = self.out_proj(attn_output)
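For context, the change exists so that a Wav2Vec2 checkpoint can be sharded across GPUs with DeepSpeed inference. Below is a hypothetical invocation, not part of this PR: the `mp_size`, `dtype`, and `injection_policy` arguments follow the DeepSpeed-Inference API of the 0.5.x era and may differ in other releases, and the module paths listed for `Wav2Vec2EncoderLayer` are an assumption about which projections need an all-reduce once their inputs are sharded.

```python
# Hypothetical usage sketch (not from the PR): shard a Wav2Vec2 model across 2 GPUs
# with DeepSpeed inference. Launch under the DeepSpeed launcher, e.g.
#   deepspeed --num_gpus 2 this_script.py
import deepspeed
import torch
from transformers import Wav2Vec2ForCTC
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2EncoderLayer

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Assumption: the attention and feed-forward output projections are the linears
# whose outputs must be all-reduced when the preceding linears are column-sharded.
engine = deepspeed.init_inference(
    model,
    mp_size=2,       # tensor-parallel degree
    dtype=torch.half,
    injection_policy={Wav2Vec2EncoderLayer: ("attention.out_proj", "feed_forward.output_dense")},
)

# 1 second of dummy 16 kHz audio; real inputs would come from a feature extractor.
input_values = torch.randn(1, 16000, dtype=torch.half, device=engine.module.device)
with torch.no_grad():
    logits = engine(input_values).logits
print(logits.shape)
```

With the old reshape, each rank would try to fold its partitioned attention output back to the full hidden size and fail; with `self.embed_dim`, which the injection rescales per rank, the forward pass runs unchanged on one GPU or many.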
