sequential stage 1
stas00 committed Feb 2, 2021
1 parent d4b4476 commit 4c0ea52
Showing 1 changed file with 134 additions and 61 deletions.
195 changes: 134 additions & 61 deletions src/transformers/models/t5/modeling_t5.py
@@ -776,6 +776,95 @@ def _shift_right(self, input_ids):
return shifted_input_ids


class T5StackPipeSegment(nn.Module):
    """
    Wraps a single T5 block so that the whole stack can be run as an ``nn.Sequential``:
    because ``nn.Sequential`` passes exactly one argument between modules, all per-layer
    state is threaded through a single packed tuple.
    """

    def __init__(
        self,
        idx,
        layer_module,
        is_decoder,
    ):
super().__init__()
self.idx = idx
self.layer_module = layer_module
self.is_decoder = is_decoder

def forward(self, input):

# unpack
        (
            hidden_states,
            attention_mask,
            position_bias,
            encoder_hidden_states,
            encoder_attention_mask,
            encoder_decoder_position_bias,
            head_mask,
            encoder_head_mask,
            past_key_values,
            use_cache,
            output_attentions,
            all_hidden_states,
            output_hidden_states,
            present_key_value_states,
            all_attentions,
            all_cross_attentions,
        ) = input

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

layer_outputs = self.layer_module(
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=head_mask[self.idx],
encoder_layer_head_mask=encoder_head_mask[self.idx],
past_key_value=past_key_values[self.idx],
use_cache=use_cache,
output_attentions=output_attentions,
)
# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
hidden_states, present_key_value_state = layer_outputs[:2]

        # We share the position biases between the layers - the first layer stores them.
        # layer_outputs = hidden-states, key-value-states, (self-attention weights),
        # (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)

if output_attentions:
all_attentions = all_attentions + (layer_outputs[3],)
if self.is_decoder:
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

# pack
        outputs = (
            hidden_states,
            attention_mask,
            position_bias,
            encoder_hidden_states,
            encoder_attention_mask,
            encoder_decoder_position_bias,
            head_mask,
            encoder_head_mask,
            past_key_values,
            use_cache,
            output_attentions,
            all_hidden_states,
            output_hidden_states,
            present_key_value_states,
            all_attentions,
            all_cross_attentions,
        )

return outputs
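The class above follows the standard trick for turning a per-layer loop into an nn.Sequential: since nn.Sequential hands each module exactly one positional argument, all per-layer state is packed into a tuple, updated inside the segment, and repacked for the next segment. A minimal, self-contained sketch of the same pattern (illustrative only, not part of this commit; ToySegment and its fields are made up):

import torch
from torch import nn

class ToySegment(nn.Module):
    # Hypothetical stand-in for T5StackPipeSegment: one "layer" plus the bookkeeping it threads along.
    def __init__(self, idx):
        super().__init__()
        self.idx = idx
        self.proj = nn.Linear(8, 8)

    def forward(self, state):
        hidden_states, all_hidden_states = state  # unpack the threaded tuple
        all_hidden_states = all_hidden_states + (hidden_states,)
        hidden_states = torch.relu(self.proj(hidden_states))
        return hidden_states, all_hidden_states  # repack for the next segment

net = nn.Sequential(*[ToySegment(i) for i in range(4)])
hidden_states, all_hidden_states = net((torch.randn(2, 8), ()))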


class T5Stack(T5PreTrainedModel):
def __init__(self, config, embed_tokens=None):
super().__init__(config)
@@ -919,67 +1008,52 @@ def forward(

hidden_states = self.dropout(inputs_embeds)

        # removed in this commit: the previous explicit per-layer loop
        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
layer_head_mask = head_mask[i]
encoder_layer_head_mask = encoder_head_mask[i]
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if position_bias is not None:
position_bias = position_bias.to(hidden_states.device)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
if encoder_extended_attention_mask is not None:
encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
if encoder_decoder_position_bias is not None:
encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
if layer_head_mask is not None:
layer_head_mask = layer_head_mask.to(hidden_states.device)
if encoder_layer_head_mask is not None:
encoder_layer_head_mask = encoder_layer_head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

layer_outputs = layer_module(
hidden_states,
attention_mask=extended_attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask,
encoder_layer_head_mask=encoder_layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
            )
            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer stores them.
            # layer_outputs = hidden-states, key-value-states, (self-attention weights),
            # (self-attention position bias), (cross-attention weights), (cross-attention position bias)
            position_bias = layer_outputs[2]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + (present_key_value_state,)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        # added in this commit: run the layers as an nn.Sequential of T5StackPipeSegment wrappers
        # rewrite the model after pre-trained weights were loaded
        layers = [
            T5StackPipeSegment(
                idx,
                layer_module,
                self.is_decoder,
            )
            for idx, layer_module in enumerate(self.block)
        ]
        net = nn.Sequential(*layers)

        input = (
            hidden_states,
            extended_attention_mask,
            position_bias,
            encoder_hidden_states,
            encoder_extended_attention_mask,
            encoder_decoder_position_bias,
            head_mask,
            encoder_head_mask,
            past_key_values,
            use_cache,
            output_attentions,
            all_hidden_states,
            output_hidden_states,
            present_key_value_states,
            all_attentions,
            all_cross_attentions,
        )
output = net(input)

        # unpack
        (
            hidden_states,
            extended_attention_mask,
            position_bias,
            encoder_hidden_states,
            encoder_extended_attention_mask,
            encoder_decoder_position_bias,
            head_mask,
            encoder_head_mask,
            past_key_values,
            use_cache,
            output_attentions,
            all_hidden_states,
            output_hidden_states,
            present_key_value_states,
            all_attentions,
            all_cross_attentions,
        ) = output

hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
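Expressing the stack as an nn.Sequential is what makes the later pipeline work possible: segments can be placed on different devices by index and the input micro-batched across them, which is what the pipeline_enable(chunks, device_map, ...) hook further down appears to be aiming at. A rough sketch of the device-placement half only, assuming device_map keeps the {device index: [segment indices]} shape used by the existing model-parallel code (place_segments is a hypothetical helper, not the commit's implementation):

from torch import nn

def place_segments(net: nn.Sequential, device_map: dict) -> nn.Sequential:
    # Hypothetical helper for illustration; the body of pipeline_enable is not shown in this diff.
    # device_map maps a CUDA device index to the segment indices it hosts, e.g. {0: [0, 1, 2], 1: [3, 4, 5]}.
    for device, segment_indices in device_map.items():
        for idx in segment_indices:
            net[idx].to(f"cuda:{device}")  # Module.to moves parameters and buffers in place
    return net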
@@ -1416,7 +1490,6 @@ def deparallelize(self):
self.device_map = None
torch.cuda.empty_cache()


def pipeline_enable(self, chunks, device_map, mpu=None):
logger.info(f"enabling pipeline with chunks={chunks}")

