From 1d07f937b549135a6d281edb44f87122f12a2b8f Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 20 Jan 2021 20:44:48 -0800 Subject: [PATCH 01/11] pipe working with 1 chunk --- src/transformers/modeling_utils.py | 62 +++++ src/transformers/models/t5/modeling_t5.py | 271 ++++++++++++++-------- 2 files changed, 241 insertions(+), 92 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d0fc1ad0f4b297..6ef2825936ddc1 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1785,3 +1785,65 @@ def forward(self, hidden_states): return torch.cat(output_chunks, dim=chunk_dim) return forward_fn(*input_tensors) + +def recursive_to(device, item): + """ + Switch any tensors found in `item` to `device`. Currently can handle a single tensor, or any of the nested list, + tuple and dict structures. + """ + + if torch.is_tensor(item): + return item.to(device) + + elif isinstance(item, list): + for i, x in enumerate(item): + item[i] = recursive_to(device, x) + return item + + elif isinstance(item, tuple): + return tuple(recursive_to(device, list(item))) + + elif isinstance(item, dict): + for k, v in item.items(): + item[k] = recursive_to(device, v) + return item + + else: + return item + +#tnone = torch.tensor([float('nan')]*batch_size) +def pipe_none_or_empty_to_torch(x, batch_size, device): + tnone = torch.tensor([-100]*batch_size).to(device) + tempty = torch.empty(0).to(device) + if x is None: + return tnone.to(device) + if x == (): + return tempty.to(device) + return x + +def pipe_torch_to_none_or_empty(x, batch_size, device): + tnone = torch.tensor([-100]*batch_size).to(device) + #tempty = torch.empty(0).to(device) + # if torch.is_tensor(x): + # print(x.shape, x) + # else: + # print(x) + if torch.is_tensor(x) and x.shape[0] == batch_size: + if not x.numel(): + return () + # print(x.numel(), batch_size, x, tnone) + if x.shape == tnone.shape and all(x == tnone): + return None + return x + +def pipe_encode_all(input, batch_size, device): + input = list(input) + for i, x in enumerate(input): + input[i] = pipe_none_or_empty_to_torch(x, batch_size, device) + return tuple(input) + +def pipe_decode_all(input, batch_size, device): + input = list(input) + for i, x in enumerate(input): + input[i] = pipe_torch_to_none_or_empty(x, batch_size, device) + return tuple(input) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 0ea405ed6e7ae0..2b8dd63d2801ad 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -38,11 +38,15 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer, recursive_to from ...utils import logging from ...utils.model_parallel_utils import assert_device_map, get_device_map from .configuration_t5 import T5Config +from torch.distributed.pipeline.sync import Pipe + +from ...modeling_utils import pipe_none_or_empty_to_torch, pipe_torch_to_none_or_empty, pipe_encode_all, pipe_decode_all + logger = logging.get_logger(__name__) @@ -774,6 +778,103 @@ def _shift_right(self, input_ids): return shifted_input_ids +class T5StackPipeSegment(nn.Module): + def __init__(self, idx, layer_module, is_decoder, head_mask, output_hidden_states, use_cache, output_attentions, all_hidden_states_add, present_key_value_states_add, 
all_attentions_add, all_cross_attentions_add): + super().__init__() + #self.batch_id = 0 + self.idx = idx + self.layer_module = layer_module + self.is_decoder = is_decoder + self.head_mask = head_mask + #self.past_key_value = past_key_value + self.output_hidden_states = output_hidden_states + self.use_cache = use_cache + self.output_attentions = output_attentions + self.all_hidden_states_add = all_hidden_states_add + self.present_key_value_states_add = present_key_value_states_add + self.all_attentions_add = all_attentions_add + self.all_cross_attentions_add = all_cross_attentions_add + + def forward(self, inputs): + inputs = pipe_decode_all(inputs, inputs[0].shape[0], inputs[0].device) + hidden_states, attention_mask, position_bias, past_key_values_p1, past_key_values_p2, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias = inputs + idx = self.idx + + #self.past_key_value = recursive_to(inputs[0].device, self.past_key_value) + self.head_mask = recursive_to(inputs[0].device, self.head_mask) + + # crazy restore: XXX: fix hardcoded numbers - 6 for number of blocks - 4 should always be there + if past_key_values_p1 is not None and past_key_values_p2 is not None: + past_key_value = (past_key_values_p1.chunk(6, 1)[self.idx].chunk(2, 1) + + past_key_values_p2.chunk(6, 1)[self.idx].chunk(2, 1)) + else: + past_key_value = None + + # if self.past_key_value is not None: + # past_key_value = tuple(self.past_key_value[i][self.batch_id] for i in self.past_key_value) + # else: + # past_key_value=None + # self.batch_id += 1 + + # # restore None's if any + # position_bias = None if len(position_bias.shape) == 1 else position_bias + # encoder_hidden_states = None if len(encoder_hidden_states.shape) == 1 else encoder_hidden_states + # encoder_attention_mask = None if len(encoder_attention_mask.shape) == 1 else encoder_attention_mask + # encoder_decoder_position_bias = None if len(encoder_decoder_position_bias.shape) == 1 else encoder_decoder_position_bias + + # all_hidden_states = () if torch.is_tensor(all_hidden_states) and len(all_hidden_states.shape) == 1 else all_hidden_states + # present_key_value_states = () if torch.is_tensor(present_key_value_states) and len(present_key_value_states.shape) == 1 else present_key_value_states + # all_attentions = () if len(all_attentions.shape) == 1 else all_attentions + # all_cross_attentions = () if len(all_cross_attentions.shape) == 1 else all_cross_attentions + + + if self.output_hidden_states: + self.all_hidden_states_add(hidden_states) + + layer_outputs = self.layer_module(hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + head_mask=self.head_mask, + past_key_value=past_key_value, + use_cache=self.use_cache, + output_attentions=self.output_attentions) + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention weights), + # (self-attention position bias), (cross-attention weights), (cross-attention position bias) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + 
encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3] + # append next layer key value states + if self.use_cache: + self.present_key_value_states_add(present_key_value_state) + + if self.output_attentions: + self.all_attentions_add(layer_outputs[3]) + if self.is_decoder: + self.all_cross_attentions_add(layer_outputs[5]) + + #tnone = torch.tensor([float('nan')]*out_shape) + # tnone = torch.tensor([-100]*out_shape).to(hidden_states.device) + # position_bias = tnone if position_bias is None else position_bias + # encoder_hidden_states = tnone if encoder_hidden_states is None else encoder_hidden_states + # encoder_attention_mask = tnone if encoder_attention_mask is None else encoder_attention_mask + # encoder_decoder_position_bias = tnone if encoder_decoder_position_bias is None else encoder_decoder_position_bias + # all_hidden_states = tnone if all_hidden_states is None else all_hidden_states + # present_key_value_states = tnone if present_key_value_states is None else present_key_value_states + # all_attentions = tnone if all_attentions is None else all_attentions + # all_cross_attentions = tnone if all_cross_attentions is None else all_cross_attentions + + outputs = (hidden_states, attention_mask, position_bias, past_key_values_p1, past_key_values_p2, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias) + outputs = pipe_encode_all(outputs, hidden_states.shape[0], hidden_states.device) + return outputs + class T5Stack(T5PreTrainedModel): def __init__(self, config, embed_tokens=None): super().__init__(config) @@ -784,46 +885,15 @@ def __init__(self, config, embed_tokens=None): self.block = nn.ModuleList( [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] ) + + #self.is_pipeline = False + self.is_pipeline = True + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) self.init_weights() - # Model parallel - self.model_parallel = False - self.device_map = None - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - # Check validity of device_map - self.device_map = ( - get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map - ) - assert_device_map(self.device_map, len(self.block)) - self.model_parallel = True - self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) - self.last_device = "cuda:" + str(max(self.device_map.keys())) - # Load onto devices - for k, v in self.device_map.items(): - for layer in v: - cuda_device = "cuda:" + str(k) - self.block[layer] = self.block[layer].to(cuda_device) - - # Set embed_tokens to first layer - self.embed_tokens = self.embed_tokens.to(self.first_device) - # Set final layer norm to last device - self.final_layer_norm = self.final_layer_norm.to(self.last_device) - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def deparallelize(self): self.model_parallel = False - self.device_map = None - self.first_device = "cpu" - self.last_device = "cpu" - for i in range(len(self.block)): - self.block[i] = self.block[i].to("cpu") - self.embed_tokens = self.embed_tokens.to("cpu") - self.final_layer_norm = self.final_layer_norm.to("cpu") - torch.cuda.empty_cache() def get_input_embeddings(self): return self.embed_tokens @@ -893,8 +963,8 @@ def forward( ) # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = 
[None] * len(self.block) + #if past_key_values is None: + # past_key_values = [None] * len(self.block) # ourselves in which case we just need to make it broadcastable to all heads. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) @@ -915,64 +985,81 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if position_bias is not None: - position_bias = position_bias.to(hidden_states.device) - if encoder_hidden_states is not None: - encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) - if encoder_extended_attention_mask is not None: - encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) - if encoder_decoder_position_bias is not None: - encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, - attention_mask=extended_attention_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - head_mask=head_mask[i], - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) - hidden_states, present_key_value_state = layer_outputs[:2] - - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, key-value-states (self-attention weights), - # (self-attention position bias), (cross-attention weights), (cross-attention position bias) - position_bias = layer_outputs[2] - if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) - if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) + def all_hidden_states_add(x): + nonlocal all_hidden_states + all_hidden_states += (x,) + def present_key_value_states_add(x): + nonlocal present_key_value_states + present_key_value_states += (x,) + #print(x.shape for x in present_key_value_states) + #print(present_key_value_states) + def all_attentions_add(x): + nonlocal all_attentions + all_attentions += (x,) + def all_cross_attentions_add(x): + nonlocal all_cross_attentions + all_cross_attentions += (x,) + + # crazy flattening of 2 level tuples so that the batch dimension is first to be spliced upon and then restored on the 
other side + if past_key_values is not None: + x1 = tuple(past_key_values[i][j] for i in range(len(past_key_values)) for j in [0,1]) + #for i in x1: print(i.shape) + x2 = tuple(past_key_values[i][j] for i in range(len(past_key_values)) for j in [2,3]) + #for i in x2: print(i.shape) + past_key_values_p1 = torch.cat(x1, 1) + past_key_values_p2 = torch.cat(x2, 1) + #input = torch.cat(tuple(past_key_values[i][j] for i in range(len(past_key_values)) for j in range(len(past_key_values[i]))), 1) + else: + past_key_values_p1 = None + past_key_values_p2 = None + + # rewrite the model after pre-trained weights were loaded + layers = [T5StackPipeSegment(idx, layer_module, self.is_decoder, head_mask[idx], output_hidden_states, use_cache, output_attentions, all_hidden_states_add, present_key_value_states_add, all_attentions_add, all_cross_attentions_add) for idx, layer_module in enumerate(self.block)] + self.block_sequential = nn.Sequential(*layers) + + # for now don't enable the pipe + if self.is_pipeline: + # XXX: switch pipe segments to other devices + device_idx = 0 + segments = len(self.block_sequential) + #print(f"segments: {segments}") + #for i, layer in enumerate(self.block_sequential): + # layer.to(0) # XXX: change to multiple GPUs when things work on one gpu + layers0 = nn.Sequential(self.block_sequential[0:3]).to(0) + layers1 = nn.Sequential(self.block_sequential[3:6]).to(0) + self.block_sequential = nn.Sequential(layers0, layers1) + + self.block_pipe = Pipe(self.block_sequential, chunks=1, checkpoint="never") + + inputs = (hidden_states, extended_attention_mask, position_bias, past_key_values_p1, past_key_values_p2, encoder_hidden_states, encoder_extended_attention_mask, encoder_decoder_position_bias) + #, all_hidden_states, present_key_value_states, all_attentions, all_cross_attentions) + inputs = pipe_encode_all(inputs, batch_size, input_ids.device) + + if self.is_pipeline: + outputs = self.block_pipe(inputs) + outputs = outputs.local_value() + outputs = recursive_to(input_ids.device, outputs) + else: + outputs = self.block_sequential(inputs) + + outputs = pipe_decode_all(outputs, batch_size, input_ids.device) + hidden_states, attention_mask, position_bias, past_key_values_p1, past_key_values_p2, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias = outputs hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) + # if self.is_pipeline and present_key_value_states is not None: + # new_x = () + # x = list(present_key_value_states) + # for i in range(int(len(present_key_value_states)/2)): + # new_y = () + # a = x[i*2] + # b = x[i*2+1] + # for j in range(len(a)): + # new_y += (torch.cat((a[j].to(0), b[j].to(0)), 0),) + # new_x += (new_y,) + # present_key_value_states = new_x + # Add last layer if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) From 0e4e71d8c94511def1bf712df174020a8d3b312d Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 20 Jan 2021 22:39:12 -0800 Subject: [PATCH 02/11] pipe working with 2 chunks --- src/transformers/models/t5/modeling_t5.py | 78 ++++++++++++++++------- 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 2b8dd63d2801ad..9fd9f80d496b8b 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -781,7 +781,7 @@ def _shift_right(self, input_ids): class T5StackPipeSegment(nn.Module): def __init__(self, idx, 
layer_module, is_decoder, head_mask, output_hidden_states, use_cache, output_attentions, all_hidden_states_add, present_key_value_states_add, all_attentions_add, all_cross_attentions_add): super().__init__() - #self.batch_id = 0 + self.batch_id = -1 self.idx = idx self.layer_module = layer_module self.is_decoder = is_decoder @@ -796,8 +796,9 @@ def __init__(self, idx, layer_module, is_decoder, head_mask, output_hidden_state self.all_cross_attentions_add = all_cross_attentions_add def forward(self, inputs): + self.batch_id += 1 inputs = pipe_decode_all(inputs, inputs[0].shape[0], inputs[0].device) - hidden_states, attention_mask, position_bias, past_key_values_p1, past_key_values_p2, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias = inputs + hidden_states, attention_mask, position_bias, past_key_values_p1, past_key_values_p2, present_key_values_p1, present_key_values_p2, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias = inputs idx = self.idx #self.past_key_value = recursive_to(inputs[0].device, self.past_key_value) @@ -814,7 +815,6 @@ def forward(self, inputs): # past_key_value = tuple(self.past_key_value[i][self.batch_id] for i in self.past_key_value) # else: # past_key_value=None - # self.batch_id += 1 # # restore None's if any # position_bias = None if len(position_bias.shape) == 1 else position_bias @@ -853,7 +853,12 @@ def forward(self, inputs): encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3] # append next layer key value states if self.use_cache: - self.present_key_value_states_add(present_key_value_state) + print(idx, self.batch_id) + # present_key_values_p1 = torch.tensor([[idx, self.batch_id], [idx, self.batch_id]]).to(hidden_states.device) + # present_key_values_p2 = torch.tensor([[idx, self.batch_id], [idx, self.batch_id]]).to(hidden_states.device) + present_key_values_p1 = torch.cat(present_key_value_state[0:2], 1) + present_key_values_p2 = torch.cat(present_key_value_state[2:4], 1) + self.present_key_value_states_add(present_key_value_state, idx, self.batch_id) if self.output_attentions: self.all_attentions_add(layer_outputs[3]) @@ -871,7 +876,7 @@ def forward(self, inputs): # all_attentions = tnone if all_attentions is None else all_attentions # all_cross_attentions = tnone if all_cross_attentions is None else all_cross_attentions - outputs = (hidden_states, attention_mask, position_bias, past_key_values_p1, past_key_values_p2, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias) + outputs = (hidden_states, attention_mask, position_bias, past_key_values_p1, past_key_values_p2, present_key_values_p1, present_key_values_p2, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias) outputs = pipe_encode_all(outputs, hidden_states.shape[0], hidden_states.device) return outputs @@ -988,9 +993,13 @@ def forward( def all_hidden_states_add(x): nonlocal all_hidden_states all_hidden_states += (x,) - def present_key_value_states_add(x): + + present_key_value_states = [[0 for x in range(2)] for y in range(6)] + def present_key_value_states_add(x, block_id, micro_batch_id): nonlocal present_key_value_states - present_key_value_states += (x,) + #present_key_value_states += (x,) + present_key_value_states[block_id][micro_batch_id] = x + #present_key_value_states += (x,) #print(x.shape for x in present_key_value_states) #print(present_key_value_states) def all_attentions_add(x): @@ -1002,9 +1011,9 @@ def all_cross_attentions_add(x): # crazy flattening of 2 
level tuples so that the batch dimension is first to be spliced upon and then restored on the other side if past_key_values is not None: - x1 = tuple(past_key_values[i][j] for i in range(len(past_key_values)) for j in [0,1]) + x1 = tuple(past_key_values[i][j].to(0) for i in range(len(past_key_values)) for j in [0,1]) #for i in x1: print(i.shape) - x2 = tuple(past_key_values[i][j] for i in range(len(past_key_values)) for j in [2,3]) + x2 = tuple(past_key_values[i][j].to(0) for i in range(len(past_key_values)) for j in [2,3]) #for i in x2: print(i.shape) past_key_values_p1 = torch.cat(x1, 1) past_key_values_p2 = torch.cat(x2, 1) @@ -1013,6 +1022,10 @@ def all_cross_attentions_add(x): past_key_values_p1 = None past_key_values_p2 = None + # batch_size=2, blocks=6, fixed=2 + present_key_values_p1 = torch.empty(2, 6*2).to(0) + present_key_values_p2 = torch.empty(2, 6*2).to(0) + # rewrite the model after pre-trained weights were loaded layers = [T5StackPipeSegment(idx, layer_module, self.is_decoder, head_mask[idx], output_hidden_states, use_cache, output_attentions, all_hidden_states_add, present_key_value_states_add, all_attentions_add, all_cross_attentions_add) for idx, layer_module in enumerate(self.block)] self.block_sequential = nn.Sequential(*layers) @@ -1026,12 +1039,12 @@ def all_cross_attentions_add(x): #for i, layer in enumerate(self.block_sequential): # layer.to(0) # XXX: change to multiple GPUs when things work on one gpu layers0 = nn.Sequential(self.block_sequential[0:3]).to(0) - layers1 = nn.Sequential(self.block_sequential[3:6]).to(0) + layers1 = nn.Sequential(self.block_sequential[3:6]).to(1) self.block_sequential = nn.Sequential(layers0, layers1) - self.block_pipe = Pipe(self.block_sequential, chunks=1, checkpoint="never") + self.block_pipe = Pipe(self.block_sequential, chunks=2, checkpoint="never") - inputs = (hidden_states, extended_attention_mask, position_bias, past_key_values_p1, past_key_values_p2, encoder_hidden_states, encoder_extended_attention_mask, encoder_decoder_position_bias) + inputs = (hidden_states, extended_attention_mask, position_bias, past_key_values_p1, past_key_values_p2, present_key_values_p1, present_key_values_p2, encoder_hidden_states, encoder_extended_attention_mask, encoder_decoder_position_bias) #, all_hidden_states, present_key_value_states, all_attentions, all_cross_attentions) inputs = pipe_encode_all(inputs, batch_size, input_ids.device) @@ -1043,22 +1056,39 @@ def all_cross_attentions_add(x): outputs = self.block_sequential(inputs) outputs = pipe_decode_all(outputs, batch_size, input_ids.device) - hidden_states, attention_mask, position_bias, past_key_values_p1, past_key_values_p2, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias = outputs + hidden_states, attention_mask, position_bias, past_key_values_p1, past_key_values_p2, present_key_values_p1, present_key_values_p2, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias = outputs hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) - # if self.is_pipeline and present_key_value_states is not None: - # new_x = () - # x = list(present_key_value_states) - # for i in range(int(len(present_key_value_states)/2)): - # new_y = () - # a = x[i*2] - # b = x[i*2+1] - # for j in range(len(a)): - # new_y += (torch.cat((a[j].to(0), b[j].to(0)), 0),) - # new_x += (new_y,) - # present_key_value_states = new_x + if present_key_values_p1 is not None: + x1 = present_key_values_p1.chunk(6, 1) + finalx1 = 
tuple(x.chunk(2, 1) for x in x1) + + #present_key_values_p1 = present_key_values_p1.chunk(6, 1).chunk(2, 1) + #present_key_values_p2 + + if self.is_pipeline and present_key_value_states is not None and present_key_value_states[0][0] != 0: + print() + new_x = () + x = present_key_value_states + for block in present_key_value_states: + new_y = () + for j in (0, 1, 2, 3): + new_y += (torch.cat((block[0][j].to(0), block[1][j].to(0)), 0),) + new_x += (new_y,) + present_key_value_states = new_x + + # new_x = () + # x = list(present_key_value_states) + # for i in range(int(len(present_key_value_states)/2)): + # new_y = () + # a = x[i*2] + # b = x[i*2+1] + # for j in range(len(a)): + # new_y += (torch.cat((a[j].to(0), b[j].to(0)), 0),) + # new_x += (new_y,) + # present_key_value_states = new_x # Add last layer if output_hidden_states: From f9fb99dd6da894f0efb6ee43e1906a1ae0f38fd8 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 21 Jan 2021 20:45:39 -0800 Subject: [PATCH 03/11] parametrized version, yay! --- src/transformers/models/t5/modeling_t5.py | 96 ++++++++++++++--------- 1 file changed, 59 insertions(+), 37 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 9fd9f80d496b8b..ed72089e5363ba 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -779,10 +779,11 @@ def _shift_right(self, input_ids): class T5StackPipeSegment(nn.Module): - def __init__(self, idx, layer_module, is_decoder, head_mask, output_hidden_states, use_cache, output_attentions, all_hidden_states_add, present_key_value_states_add, all_attentions_add, all_cross_attentions_add): + def __init__(self, idx, n_layers, layer_module, is_decoder, head_mask, output_hidden_states, use_cache, output_attentions, all_hidden_states_add, present_key_value_states_add, all_attentions_add, all_cross_attentions_add): super().__init__() self.batch_id = -1 self.idx = idx + self.n_layers = n_layers self.layer_module = layer_module self.is_decoder = is_decoder self.head_mask = head_mask @@ -804,10 +805,10 @@ def forward(self, inputs): #self.past_key_value = recursive_to(inputs[0].device, self.past_key_value) self.head_mask = recursive_to(inputs[0].device, self.head_mask) - # crazy restore: XXX: fix hardcoded numbers - 6 for number of blocks - 4 should always be there + # crazy restore: XXX: fix hardcoded numbers - self.n_layers for number of blocks - 2+2 should always be there if past_key_values_p1 is not None and past_key_values_p2 is not None: - past_key_value = (past_key_values_p1.chunk(6, 1)[self.idx].chunk(2, 1) + - past_key_values_p2.chunk(6, 1)[self.idx].chunk(2, 1)) + past_key_value = (past_key_values_p1.chunk(self.n_layers, 1)[self.idx].chunk(2, 1) + + past_key_values_p2.chunk(self.n_layers, 1)[self.idx].chunk(2, 1)) else: past_key_value = None @@ -853,7 +854,7 @@ def forward(self, inputs): encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3] # append next layer key value states if self.use_cache: - print(idx, self.batch_id) + #print(idx, self.batch_id) # present_key_values_p1 = torch.tensor([[idx, self.batch_id], [idx, self.batch_id]]).to(hidden_states.device) # present_key_values_p2 = torch.tensor([[idx, self.batch_id], [idx, self.batch_id]]).to(hidden_states.device) present_key_values_p1 = torch.cat(present_key_value_state[0:2], 1) @@ -990,11 +991,15 @@ def forward( hidden_states = self.dropout(inputs_embeds) + # PP + n_layers = len(self.block) + n_chunks = 4 + def 
all_hidden_states_add(x): nonlocal all_hidden_states all_hidden_states += (x,) - present_key_value_states = [[0 for x in range(2)] for y in range(6)] + present_key_value_states = [[0 for x in range(n_chunks)] for y in range(n_layers)] def present_key_value_states_add(x, block_id, micro_batch_id): nonlocal present_key_value_states #present_key_value_states += (x,) @@ -1002,6 +1007,7 @@ def present_key_value_states_add(x, block_id, micro_batch_id): #present_key_value_states += (x,) #print(x.shape for x in present_key_value_states) #print(present_key_value_states) + def all_attentions_add(x): nonlocal all_attentions all_attentions += (x,) @@ -1022,38 +1028,58 @@ def all_cross_attentions_add(x): past_key_values_p1 = None past_key_values_p2 = None - # batch_size=2, blocks=6, fixed=2 - present_key_values_p1 = torch.empty(2, 6*2).to(0) - present_key_values_p2 = torch.empty(2, 6*2).to(0) + # batch_size=2, blocks=self.n_layers, fixed=2 (2+2 keys) + present_key_values_p1 = torch.empty(batch_size, n_layers*2).to(0) + present_key_values_p2 = torch.empty(batch_size, n_layers*2).to(0) # rewrite the model after pre-trained weights were loaded - layers = [T5StackPipeSegment(idx, layer_module, self.is_decoder, head_mask[idx], output_hidden_states, use_cache, output_attentions, all_hidden_states_add, present_key_value_states_add, all_attentions_add, all_cross_attentions_add) for idx, layer_module in enumerate(self.block)] - self.block_sequential = nn.Sequential(*layers) + layers = [T5StackPipeSegment(idx, n_layers, layer_module, self.is_decoder, head_mask[idx], output_hidden_states, use_cache, output_attentions, all_hidden_states_add, present_key_value_states_add, all_attentions_add, all_cross_attentions_add) for idx, layer_module in enumerate(self.block)] + #block_sequential = nn.Sequential(*layers) # for now don't enable the pipe if self.is_pipeline: # XXX: switch pipe segments to other devices device_idx = 0 - segments = len(self.block_sequential) + #segments = len(layers) #print(f"segments: {segments}") - #for i, layer in enumerate(self.block_sequential): + #for i, layer in enumerate(block_sequential): # layer.to(0) # XXX: change to multiple GPUs when things work on one gpu - layers0 = nn.Sequential(self.block_sequential[0:3]).to(0) - layers1 = nn.Sequential(self.block_sequential[3:6]).to(1) - self.block_sequential = nn.Sequential(layers0, layers1) - self.block_pipe = Pipe(self.block_sequential, chunks=2, checkpoint="never") + # XXX: later will have a map - for now just roughly split + # XXX: Need to build it once outside the model + n_gpus = torch.cuda.device_count() + if n_gpus < 2: + assert "Need at least 2 gpus to use the pipeline" + + devices = list(range(n_gpus)) + layer_ids = list(range(n_layers)) + layer_splits = [layer_ids[i*n_layers // n_gpus: (i+1)*n_layers // n_gpus] for i in range(n_gpus)] + for device_id, layer_partition in zip(devices, layer_splits): + for layer_id in layer_partition: + #print(f"{layer_id} => {device_id}") + layers[layer_id].to(device_id) + + # XXX: fix this to match the number of GPUs + # layers0 = nn.Sequential(block_sequential[0:3]).to(0) + # layers1 = nn.Sequential(block_sequential[3:6]).to(1) + #block_sequential = nn.Sequential(layers0, layers1) + #block_sequential = nn.Sequential(*layer_stack) + + block_sequential = nn.Sequential(*layers) + block_pipe = Pipe(block_sequential, chunks=n_chunks, checkpoint="never") + else: + block_sequential = nn.Sequential(*layers) inputs = (hidden_states, extended_attention_mask, position_bias, past_key_values_p1, past_key_values_p2, 
present_key_values_p1, present_key_values_p2, encoder_hidden_states, encoder_extended_attention_mask, encoder_decoder_position_bias) #, all_hidden_states, present_key_value_states, all_attentions, all_cross_attentions) inputs = pipe_encode_all(inputs, batch_size, input_ids.device) if self.is_pipeline: - outputs = self.block_pipe(inputs) + outputs = block_pipe(inputs) outputs = outputs.local_value() outputs = recursive_to(input_ids.device, outputs) else: - outputs = self.block_sequential(inputs) + outputs = block_sequential(inputs) outputs = pipe_decode_all(outputs, batch_size, input_ids.device) hidden_states, attention_mask, position_bias, past_key_values_p1, past_key_values_p2, present_key_values_p1, present_key_values_p2, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias = outputs @@ -1061,35 +1087,31 @@ def all_cross_attentions_add(x): hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) - if present_key_values_p1 is not None: - x1 = present_key_values_p1.chunk(6, 1) - finalx1 = tuple(x.chunk(2, 1) for x in x1) + # if present_key_values_p1 is not None: + # x1 = present_key_values_p1.chunk(n_layers, 1) + # finalx1 = tuple(x.chunk(2, 1) for x in x1) - #present_key_values_p1 = present_key_values_p1.chunk(6, 1).chunk(2, 1) - #present_key_values_p2 + # #present_key_values_p1 = present_key_values_p1.chunk(n_layers, 1).chunk(2, 1) + # #present_key_values_p2 if self.is_pipeline and present_key_value_states is not None and present_key_value_states[0][0] != 0: - print() + #print() + # reconstruct the flattened tensor to tuple of tuples of tensors new_x = () x = present_key_value_states + # XXX: check we aren't inserting random garbage for the uneven last batch for block in present_key_value_states: new_y = () for j in (0, 1, 2, 3): - new_y += (torch.cat((block[0][j].to(0), block[1][j].to(0)), 0),) + entries = tuple(block[i][j].to(0) for i in range(n_chunks)) + new_y += (torch.cat(entries, 0),) + # new_y += (torch.cat((block[0][j].to(0), + # block[1][j].to(0), + # block[2][j].to(0), + # ), 0),) new_x += (new_y,) present_key_value_states = new_x - # new_x = () - # x = list(present_key_value_states) - # for i in range(int(len(present_key_value_states)/2)): - # new_y = () - # a = x[i*2] - # b = x[i*2+1] - # for j in range(len(a)): - # new_y += (torch.cat((a[j].to(0), b[j].to(0)), 0),) - # new_x += (new_y,) - # present_key_value_states = new_x - # Add last layer if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) From f935feb761bc3d60e39fa7ebe8c7e50906ff2173 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 23 Jan 2021 16:48:45 -0800 Subject: [PATCH 04/11] parametrize --- src/transformers/models/t5/modeling_t5.py | 129 +++++++++++++++++----- src/transformers/trainer.py | 29 +++++ src/transformers/training_args.py | 2 + 3 files changed, 130 insertions(+), 30 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index ed72089e5363ba..6f933f9a05ab23 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -901,6 +901,15 @@ def __init__(self, config, embed_tokens=None): self.init_weights() self.model_parallel = False + self.pipeline_chunks = 0 + self.device_map = None + self.pipeline_is_enabled = False + + def pipeline_params(self, chunks, device_map): + self.pipeline_chunks = chunks + self.device_map = device_map + self.pipeline_is_enabled = True + def get_input_embeddings(self): return 
self.embed_tokens @@ -991,15 +1000,23 @@ def forward( hidden_states = self.dropout(inputs_embeds) + if not self.pipeline_is_enabled: + self.pipeline_chunks = 1 + # PP + # handle batches (usually last) that are shorter than pipeline_chunks + if batch_size < self.pipeline_chunks: + # XXX: Is it always last? so that we don't override user's chunks setting unless it's the last batch + self.pipeline_chunks = batch_size + n_layers = len(self.block) - n_chunks = 4 + #n_chunks = 4 def all_hidden_states_add(x): nonlocal all_hidden_states all_hidden_states += (x,) - present_key_value_states = [[0 for x in range(n_chunks)] for y in range(n_layers)] + present_key_value_states = [[0 for x in range(self.pipeline_chunks)] for y in range(n_layers)] def present_key_value_states_add(x, block_id, micro_batch_id): nonlocal present_key_value_states #present_key_value_states += (x,) @@ -1037,36 +1054,16 @@ def all_cross_attentions_add(x): #block_sequential = nn.Sequential(*layers) # for now don't enable the pipe - if self.is_pipeline: - # XXX: switch pipe segments to other devices - device_idx = 0 - #segments = len(layers) - #print(f"segments: {segments}") - #for i, layer in enumerate(block_sequential): - # layer.to(0) # XXX: change to multiple GPUs when things work on one gpu - - # XXX: later will have a map - for now just roughly split - # XXX: Need to build it once outside the model - n_gpus = torch.cuda.device_count() - if n_gpus < 2: - assert "Need at least 2 gpus to use the pipeline" - - devices = list(range(n_gpus)) - layer_ids = list(range(n_layers)) - layer_splits = [layer_ids[i*n_layers // n_gpus: (i+1)*n_layers // n_gpus] for i in range(n_gpus)] - for device_id, layer_partition in zip(devices, layer_splits): + if self.pipeline_is_enabled: + + # print("using partitioning: ", dict(zip(devices, layer_splits))) + for device_id, layer_partition in self.device_map.items(): for layer_id in layer_partition: #print(f"{layer_id} => {device_id}") layers[layer_id].to(device_id) - # XXX: fix this to match the number of GPUs - # layers0 = nn.Sequential(block_sequential[0:3]).to(0) - # layers1 = nn.Sequential(block_sequential[3:6]).to(1) - #block_sequential = nn.Sequential(layers0, layers1) - #block_sequential = nn.Sequential(*layer_stack) - block_sequential = nn.Sequential(*layers) - block_pipe = Pipe(block_sequential, chunks=n_chunks, checkpoint="never") + block_pipe = Pipe(block_sequential, chunks=self.pipeline_chunks, checkpoint="never") else: block_sequential = nn.Sequential(*layers) @@ -1074,7 +1071,7 @@ def all_cross_attentions_add(x): #, all_hidden_states, present_key_value_states, all_attentions, all_cross_attentions) inputs = pipe_encode_all(inputs, batch_size, input_ids.device) - if self.is_pipeline: + if self.pipeline_is_enabled: outputs = block_pipe(inputs) outputs = outputs.local_value() outputs = recursive_to(input_ids.device, outputs) @@ -1094,7 +1091,8 @@ def all_cross_attentions_add(x): # #present_key_values_p1 = present_key_values_p1.chunk(n_layers, 1).chunk(2, 1) # #present_key_values_p2 - if self.is_pipeline and present_key_value_states is not None and present_key_value_states[0][0] != 0: + #if self.pipeline_is_enabled and present_key_value_states is not None and present_key_value_states[0][0] != 0: + if present_key_value_states is not None and present_key_value_states[0][0] != 0: #print() # reconstruct the flattened tensor to tuple of tuples of tensors new_x = () @@ -1103,7 +1101,7 @@ def all_cross_attentions_add(x): for block in present_key_value_states: new_y = () for j in (0, 1, 2, 3): - 
entries = tuple(block[i][j].to(0) for i in range(n_chunks)) + entries = tuple(block[i][j].to(0) for i in range(self.pipeline_chunks)) new_y += (torch.cat(entries, 0),) # new_y += (torch.cat((block[0][j].to(0), # block[1][j].to(0), @@ -1496,6 +1494,77 @@ def __init__(self, config): self.model_parallel = False self.device_map = None + self.pipeline_is_enabled = False + + def pipeline_enable(self, chunks, device_map): + logger.info(f"enabling pipeline with chunks={chunks}") + + # XXX: should be a separate function + import torch + n_gpus = torch.cuda.device_count() + if n_gpus < 2: + raise ValueError("Need at least 2 gpus to use the pipeline") + + if device_map is not None: + logger.info(f"using user-provided device_map") + else: + def make_device_map(n_gpus, n_layers): + print(f"making default device map: n_gpus={n_gpus}, n_layers={n_layers}") + devices = list(range(n_gpus)) + layer_ids = list(range(n_layers)) + # XXX: later will have a map - for now just roughly split + # XXX: probably should balance more so that the 0th gpu has the least number of layers, rather than the last one, because 0th gpu is already very busy + layer_splits = [layer_ids[i*n_layers // n_gpus: (i+1)*n_layers // n_gpus] for i in range(n_gpus)] + return dict(zip(devices, layer_splits)) + + + # XXX: for now assume encode/decoder symmetry - later fix to build each one separately + n_layers = len(self.encoder.block) + device_map = make_device_map(n_gpus, n_layers) + + self.device_map = device_map + self.pipeline_is_enabled = True + logger.info(f"using pipeline partitioning: {device_map}") + + # XXX: validate chunks is a good arg + + self.encoder.pipeline_params(chunks=chunks, device_map=device_map) + self.decoder.pipeline_params(chunks=chunks, device_map=device_map) + + # XXX for now hardcoded the RPC setup here - but it should happen in the trainer instead + import os + import torch + from torch.distributed import rpc + os.environ.update({"MASTER_ADDR": "localhost"}) + os.environ.update({"MASTER_PORT": "10638"}) + rpc.init_rpc( + "worker", + rank=0, + world_size=1, + ) + num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 2 + device = torch.device("cuda") + + def pipeline_finalize(self): + # XXX: should reset the max counter + # reset_peak_stats() + import torch + import json + n_gpus = torch.cuda.device_count() + mem_map = {} + for id in range(n_gpus): + with torch.cuda.device(id): + # XXX: this doesn't seem to report the right thing - getting much lower numbers + mem_map[id] = torch.cuda.max_memory_allocated() >> 20 + + logger.info(f"peak memory usage per device in MBs:\n{json.dumps(mem_map, sort_keys=True, indent=4)}") + # reset for the next train/eval/predict stage + # XXX: probably should do in the trainer? + torch.cuda.reset_peak_memory_stats() + + # XXX: would be great to add gpu utilization stats as well + + @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): self.device_map = ( diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a6ca42a9b105de..608d48a9051417 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -276,6 +276,25 @@ def __init__( else: self.is_model_parallel = False + # XXX: for now hack over naive MP to have the same behavior + if len(self.args.pipeline): + # using range() syntax for upper boundary (i.e. 
not inclusive) + # --pipeline "chunks=5; device_map=0:0-10,1:10-20" + self.is_model_parallel = True + + chunks_str, *device_map_str = self.args.pipeline.split() + chunks = int(chunks_str.split('=')[1]) + device_map = None + if len(device_map_str): + device_map_range = device_map_str[0].split('=')[1] + device_map_range_str = device_map_range.split(',') + device_map = {} + for x in device_map_range_str: + device_id, layers = x.split(':') + device_map[int(device_id)] = list(range(*map(int, layers.split('-')))) + + model.pipeline_enable(chunks=chunks, device_map=device_map) + default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset @@ -289,6 +308,10 @@ def __init__( # Force n_gpu to 1 to avoid DataParallel. self.args._n_gpu = 1 + if len(self.args.pipeline): + model = model.to(args.device) + + # later use `self.model is self.model_wrapped` to check if it's wrapped or not self.model_wrapped = model self.model = model @@ -999,6 +1022,9 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D # add remaining tr_loss self._total_loss_scalar += tr_loss.item() + if len(self.args.pipeline): + model.pipeline_finalize() + return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics) def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch): @@ -1631,6 +1657,9 @@ def prediction_loop( if not key.startswith(f"{metric_key_prefix}_"): metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + if len(self.args.pipeline): + model.pipeline_finalize() + return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) def _gather_and_numpify(self, tensors, name): diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index abef9e35c67db9..8e7a3b868c3e6a 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -350,6 +350,8 @@ class TrainingArguments: ) debug: bool = field(default=False, metadata={"help": "Whether to print debug metrics on TPU"}) + pipeline: str = field(default="", metadata={"help": "Whether to enable Pipeline Parallelism and the value is pipeline params: 'chunks=5; device_map=0:1-10,1:11-20'"}) + dataloader_drop_last: bool = field( default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} ) From 1c7becbc2528f734037acfd2bfe16a7ce6ccd51a Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 23 Jan 2021 18:21:41 -0800 Subject: [PATCH 05/11] fix edge case --- src/transformers/models/t5/modeling_t5.py | 32 ++++++++++++++--------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 6f933f9a05ab23..4aad6647fd914a 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -798,6 +798,7 @@ def __init__(self, idx, n_layers, layer_module, is_decoder, head_mask, output_hi def forward(self, inputs): self.batch_id += 1 + #print(f"micro BS: {inputs[0].shape[0]}") inputs = pipe_decode_all(inputs, inputs[0].shape[0], inputs[0].device) hidden_states, attention_mask, position_bias, past_key_values_p1, past_key_values_p2, present_key_values_p1, present_key_values_p2, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias = inputs idx = self.idx @@ -904,6 +905,7 @@ def __init__(self, 
config, embed_tokens=None): self.pipeline_chunks = 0 self.device_map = None self.pipeline_is_enabled = False + self.pipeline_batch_size = None def pipeline_params(self, chunks, device_map): self.pipeline_chunks = chunks @@ -930,6 +932,9 @@ def forward( output_hidden_states=None, return_dict=None, ): + + #print(f"mini BS: {input_ids.shape[0]}") + # Model parallel if self.model_parallel: torch.cuda.set_device(self.first_device) @@ -1000,14 +1005,16 @@ def forward( hidden_states = self.dropout(inputs_embeds) - if not self.pipeline_is_enabled: + if self.pipeline_is_enabled: + # handle batches (usually last) that are shorter than pipeline_chunks + if batch_size < self.pipeline_chunks: + # # XXX: Is it always last? so that we don't override user's chunks setting unless it's the last batch + self.pipeline_chunks = 1 + else: + # non-pipeline run is the same as chunks=1 batch size-wise self.pipeline_chunks = 1 # PP - # handle batches (usually last) that are shorter than pipeline_chunks - if batch_size < self.pipeline_chunks: - # XXX: Is it always last? so that we don't override user's chunks setting unless it's the last batch - self.pipeline_chunks = batch_size n_layers = len(self.block) #n_chunks = 4 @@ -1016,6 +1023,8 @@ def all_hidden_states_add(x): nonlocal all_hidden_states all_hidden_states += (x,) + # handle batches (usually last) that are can't be equally divided by pipeline_chunks + present_key_value_states = [[0 for x in range(self.pipeline_chunks)] for y in range(n_layers)] def present_key_value_states_add(x, block_id, micro_batch_id): nonlocal present_key_value_states @@ -1096,17 +1105,16 @@ def all_cross_attentions_add(x): #print() # reconstruct the flattened tensor to tuple of tuples of tensors new_x = () - x = present_key_value_states - # XXX: check we aren't inserting random garbage for the uneven last batch + # deal with unpredictable potential last short batch + real_chunks = 0 + for i in range(self.pipeline_chunks): + if not present_key_value_states[0][i] == 0: + real_chunks += 1 for block in present_key_value_states: new_y = () for j in (0, 1, 2, 3): - entries = tuple(block[i][j].to(0) for i in range(self.pipeline_chunks)) + entries = tuple(block[i][j].to(0) for i in range(real_chunks)) new_y += (torch.cat(entries, 0),) - # new_y += (torch.cat((block[0][j].to(0), - # block[1][j].to(0), - # block[2][j].to(0), - # ), 0),) new_x += (new_y,) present_key_value_states = new_x From fda68307c7a8ebecb46149787e088da40f3365e0 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 25 Jan 2021 11:03:00 -0800 Subject: [PATCH 06/11] missing commit --- src/transformers/models/t5/modeling_t5.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 58bf28c842abc8..e72551e613d5fc 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -781,14 +781,15 @@ def _shift_right(self, input_ids): class T5StackPipeSegment(nn.Module): - def __init__(self, idx, n_layers, layer_module, is_decoder, head_mask, output_hidden_states, use_cache, output_attentions, all_hidden_states_add, present_key_value_states_add, all_attentions_add, all_cross_attentions_add): + def __init__(self, idx, n_layers, layer_module, is_decoder, layer_head_mask, encoder_layer_head_mask, output_hidden_states, use_cache, output_attentions, all_hidden_states_add, present_key_value_states_add, all_attentions_add, all_cross_attentions_add): super().__init__() self.batch_id = -1 
self.idx = idx self.n_layers = n_layers self.layer_module = layer_module self.is_decoder = is_decoder - self.head_mask = head_mask + self.layer_head_mask = layer_head_mask + self.encoder_layer_head_mask = encoder_layer_head_mask #self.past_key_value = past_key_value self.output_hidden_states = output_hidden_states self.use_cache = use_cache @@ -806,7 +807,8 @@ def forward(self, inputs): idx = self.idx #self.past_key_value = recursive_to(inputs[0].device, self.past_key_value) - self.head_mask = recursive_to(inputs[0].device, self.head_mask) + self.layer_head_mask = recursive_to(inputs[0].device, self.layer_head_mask) + self.encoder_layer_head_mask = recursive_to(inputs[0].device, self.encoder_layer_head_mask) # crazy restore: XXX: fix hardcoded numbers - self.n_layers for number of blocks - 2+2 should always be there if past_key_values_p1 is not None and past_key_values_p2 is not None: @@ -841,7 +843,8 @@ def forward(self, inputs): encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, - head_mask=self.head_mask, + layer_head_mask=self.layer_head_mask, + encoder_layer_head_mask=self.encoder_layer_head_mask, past_key_value=past_key_value, use_cache=self.use_cache, output_attentions=self.output_attentions) @@ -1063,7 +1066,7 @@ def all_cross_attentions_add(x): present_key_values_p2 = torch.empty(batch_size, n_layers*2).to(0) # rewrite the model after pre-trained weights were loaded - layers = [T5StackPipeSegment(idx, n_layers, layer_module, self.is_decoder, head_mask[idx], output_hidden_states, use_cache, output_attentions, all_hidden_states_add, present_key_value_states_add, all_attentions_add, all_cross_attentions_add) for idx, layer_module in enumerate(self.block)] + layers = [T5StackPipeSegment(idx, n_layers, layer_module, self.is_decoder, head_mask[idx], encoder_head_mask[idx], output_hidden_states, use_cache, output_attentions, all_hidden_states_add, present_key_value_states_add, all_attentions_add, all_cross_attentions_add) for idx, layer_module in enumerate(self.block)] #block_sequential = nn.Sequential(*layers) # for now don't enable the pipe From 1ef2ef211bf63315c434f2a8653dd4979117bbe2 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 28 Jan 2021 13:15:33 -0800 Subject: [PATCH 07/11] add mpu + style --- src/transformers/integrations.py | 122 ++++++++++++ src/transformers/modeling_utils.py | 15 +- src/transformers/models/t5/modeling_t5.py | 224 ++++++++++++++++------ src/transformers/trainer.py | 11 +- src/transformers/training_args.py | 7 +- 5 files changed, 304 insertions(+), 75 deletions(-) diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index af7909fca23093..a27ee5280a8adc 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -252,6 +252,119 @@ def rewrite_logs(d): return new_d +# adjusted from Megatron-LM/mpu/ +import torch + + +# Model parallel group that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None + +class MPU: + def __init__(self, n_gpus): + self.n_gpus = n_gpus + + def initialize_model_parallel(self): + """ + Initialize model data parallel groups. + + Arguments: + model_parallel_size: number of GPUs used to parallelize model. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model. 
The present function will + create 4 model parallel groups and 2 data parallel grous as: + 4 model parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 data parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + + model_parallel_size_ = self.n_gpus + + def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) + + if torch.distributed.get_rank() == 0: + print("> initializing model parallel with size {}".format(model_parallel_size_)) + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size = torch.distributed.get_world_size() + model_parallel_size = min(model_parallel_size_, world_size) + ensure_divisibility(world_size, model_parallel_size) + rank = torch.distributed.get_rank() + + # Build the data parallel groups. + global _DATA_PARALLEL_GROUP + assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" + for i in range(model_parallel_size): + ranks = range(i, world_size, model_parallel_size) + group = torch.distributed.new_group(ranks) + if i == (rank % model_parallel_size): + _DATA_PARALLEL_GROUP = group + + # Build the model parallel groups. + global _MODEL_PARALLEL_GROUP + assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" + for i in range(world_size // model_parallel_size): + ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size) + group = torch.distributed.new_group(ranks) + if i == (rank // model_parallel_size): + _MODEL_PARALLEL_GROUP = group + + def model_parallel_is_initialized(self): + """Check if model and data parallel groups are initialized.""" + if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: + return False + return True + + def get_model_parallel_group(self): + """Get the model parallel group the caller rank belongs to.""" + assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" + return _MODEL_PARALLEL_GROUP + + def get_data_parallel_group(self): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" + return _DATA_PARALLEL_GROUP + + def get_model_parallel_world_size(self): + """Return world size for the model parallel group.""" + return torch.distributed.get_world_size(group=self.get_model_parallel_group()) + + def get_model_parallel_rank(self): + """Return my rank for the model parallel group.""" + return torch.distributed.get_rank(group=self.get_model_parallel_group()) + + def get_model_parallel_src_rank(self): + """Calculate the global rank corresponding to a local rank zeor + in the model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + def get_data_parallel_world_size(self): + """Return world size for the data parallel group.""" + return torch.distributed.get_world_size(group=self.get_data_parallel_group()) + + def get_data_parallel_rank(self): + """Return my rank for the data parallel group.""" + return 
torch.distributed.get_rank(group=self.get_data_parallel_group()) + + def destroy_model_parallel(self): + """Set the groups to none.""" + global _MODEL_PARALLEL_GROUP + _MODEL_PARALLEL_GROUP = None + global _DATA_PARALLEL_GROUP + _DATA_PARALLEL_GROUP = None + + def init_deepspeed(trainer, num_training_steps): """ Init DeepSpeed, after converting any relevant Trainer's args into DeepSpeed configuration @@ -268,6 +381,14 @@ def init_deepspeed(trainer, num_training_steps): ds_config_file = args.deepspeed model = trainer.model + # 2D Parallel + if len(args.pipeline): + n_gpus = torch.distributed.get_world_size() + mpu = MPU(n_gpus) + mpu.initialize_model_parallel() + else: + mpu = None + with io.open(ds_config_file, "r", encoding="utf-8") as f: config = json.load(f) @@ -398,6 +519,7 @@ def init_deepspeed(trainer, num_training_steps): model=model, model_parameters=model_parameters, config_params=config, + mpu = mpu, ) return model, optimizer, lr_scheduler diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6ef2825936ddc1..7a2b53773543b9 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1786,6 +1786,7 @@ def forward(self, hidden_states): return forward_fn(*input_tensors) + def recursive_to(device, item): """ Switch any tensors found in `item` to `device`. Currently can handle a single tensor, or any of the nested list, @@ -1811,9 +1812,10 @@ def recursive_to(device, item): else: return item -#tnone = torch.tensor([float('nan')]*batch_size) + +# tnone = torch.tensor([float('nan')]*batch_size) def pipe_none_or_empty_to_torch(x, batch_size, device): - tnone = torch.tensor([-100]*batch_size).to(device) + tnone = torch.tensor([-100] * batch_size).to(device) tempty = torch.empty(0).to(device) if x is None: return tnone.to(device) @@ -1821,9 +1823,10 @@ def pipe_none_or_empty_to_torch(x, batch_size, device): return tempty.to(device) return x + def pipe_torch_to_none_or_empty(x, batch_size, device): - tnone = torch.tensor([-100]*batch_size).to(device) - #tempty = torch.empty(0).to(device) + tnone = torch.tensor([-100] * batch_size).to(device) + # tempty = torch.empty(0).to(device) # if torch.is_tensor(x): # print(x.shape, x) # else: @@ -1836,14 +1839,16 @@ def pipe_torch_to_none_or_empty(x, batch_size, device): return None return x + def pipe_encode_all(input, batch_size, device): input = list(input) for i, x in enumerate(input): input[i] = pipe_none_or_empty_to_torch(x, batch_size, device) return tuple(input) + def pipe_decode_all(input, batch_size, device): - input = list(input) + input = list(input) for i, x in enumerate(input): input[i] = pipe_torch_to_none_or_empty(x, batch_size, device) return tuple(input) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index e72551e613d5fc..ea64940cb3fe93 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -23,6 +23,7 @@ import torch import torch.nn.functional as F from torch import nn +from torch.distributed.pipeline.sync import Pipe from torch.nn import CrossEntropyLoss from ...activations import ACT2FN @@ -39,15 +40,20 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer, recursive_to +from ...modeling_utils import ( + PreTrainedModel, + find_pruneable_heads_and_indices, + pipe_decode_all, + pipe_encode_all, + pipe_none_or_empty_to_torch, + pipe_torch_to_none_or_empty, + 
prune_linear_layer, + recursive_to, +) from ...utils import logging from ...utils.model_parallel_utils import assert_device_map, get_device_map from .configuration_t5 import T5Config -from torch.distributed.pipeline.sync import Pipe - -from ...modeling_utils import pipe_none_or_empty_to_torch, pipe_torch_to_none_or_empty, pipe_encode_all, pipe_decode_all - logger = logging.get_logger(__name__) @@ -781,7 +787,22 @@ def _shift_right(self, input_ids): class T5StackPipeSegment(nn.Module): - def __init__(self, idx, n_layers, layer_module, is_decoder, layer_head_mask, encoder_layer_head_mask, output_hidden_states, use_cache, output_attentions, all_hidden_states_add, present_key_value_states_add, all_attentions_add, all_cross_attentions_add): + def __init__( + self, + idx, + n_layers, + layer_module, + is_decoder, + layer_head_mask, + encoder_layer_head_mask, + output_hidden_states, + use_cache, + output_attentions, + all_hidden_states_add, + present_key_value_states_add, + all_attentions_add, + all_cross_attentions_add, + ): super().__init__() self.batch_id = -1 self.idx = idx @@ -790,7 +811,7 @@ def __init__(self, idx, n_layers, layer_module, is_decoder, layer_head_mask, enc self.is_decoder = is_decoder self.layer_head_mask = layer_head_mask self.encoder_layer_head_mask = encoder_layer_head_mask - #self.past_key_value = past_key_value + # self.past_key_value = past_key_value self.output_hidden_states = output_hidden_states self.use_cache = use_cache self.output_attentions = output_attentions @@ -801,19 +822,31 @@ def __init__(self, idx, n_layers, layer_module, is_decoder, layer_head_mask, enc def forward(self, inputs): self.batch_id += 1 - #print(f"micro BS: {inputs[0].shape[0]}") + # print(f"micro BS: {inputs[0].shape[0]}") inputs = pipe_decode_all(inputs, inputs[0].shape[0], inputs[0].device) - hidden_states, attention_mask, position_bias, past_key_values_p1, past_key_values_p2, present_key_values_p1, present_key_values_p2, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias = inputs + ( + hidden_states, + attention_mask, + position_bias, + past_key_values_p1, + past_key_values_p2, + present_key_values_p1, + present_key_values_p2, + encoder_hidden_states, + encoder_attention_mask, + encoder_decoder_position_bias, + ) = inputs idx = self.idx - #self.past_key_value = recursive_to(inputs[0].device, self.past_key_value) + # self.past_key_value = recursive_to(inputs[0].device, self.past_key_value) self.layer_head_mask = recursive_to(inputs[0].device, self.layer_head_mask) self.encoder_layer_head_mask = recursive_to(inputs[0].device, self.encoder_layer_head_mask) # crazy restore: XXX: fix hardcoded numbers - self.n_layers for number of blocks - 2+2 should always be there if past_key_values_p1 is not None and past_key_values_p2 is not None: - past_key_value = (past_key_values_p1.chunk(self.n_layers, 1)[self.idx].chunk(2, 1) + - past_key_values_p2.chunk(self.n_layers, 1)[self.idx].chunk(2, 1)) + past_key_value = past_key_values_p1.chunk(self.n_layers, 1)[self.idx].chunk( + 2, 1 + ) + past_key_values_p2.chunk(self.n_layers, 1)[self.idx].chunk(2, 1) else: past_key_value = None @@ -833,21 +866,22 @@ def forward(self, inputs): # all_attentions = () if len(all_attentions.shape) == 1 else all_attentions # all_cross_attentions = () if len(all_cross_attentions.shape) == 1 else all_cross_attentions - if self.output_hidden_states: self.all_hidden_states_add(hidden_states) - layer_outputs = self.layer_module(hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, 
- encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=self.layer_head_mask, - encoder_layer_head_mask=self.encoder_layer_head_mask, - past_key_value=past_key_value, - use_cache=self.use_cache, - output_attentions=self.output_attentions) + layer_outputs = self.layer_module( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=self.layer_head_mask, + encoder_layer_head_mask=self.encoder_layer_head_mask, + past_key_value=past_key_value, + use_cache=self.use_cache, + output_attentions=self.output_attentions, + ) # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) hidden_states, present_key_value_state = layer_outputs[:2] @@ -860,7 +894,7 @@ def forward(self, inputs): encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3] # append next layer key value states if self.use_cache: - #print(idx, self.batch_id) + # print(idx, self.batch_id) # present_key_values_p1 = torch.tensor([[idx, self.batch_id], [idx, self.batch_id]]).to(hidden_states.device) # present_key_values_p2 = torch.tensor([[idx, self.batch_id], [idx, self.batch_id]]).to(hidden_states.device) present_key_values_p1 = torch.cat(present_key_value_state[0:2], 1) @@ -872,7 +906,7 @@ def forward(self, inputs): if self.is_decoder: self.all_cross_attentions_add(layer_outputs[5]) - #tnone = torch.tensor([float('nan')]*out_shape) + # tnone = torch.tensor([float('nan')]*out_shape) # tnone = torch.tensor([-100]*out_shape).to(hidden_states.device) # position_bias = tnone if position_bias is None else position_bias # encoder_hidden_states = tnone if encoder_hidden_states is None else encoder_hidden_states @@ -883,10 +917,22 @@ def forward(self, inputs): # all_attentions = tnone if all_attentions is None else all_attentions # all_cross_attentions = tnone if all_cross_attentions is None else all_cross_attentions - outputs = (hidden_states, attention_mask, position_bias, past_key_values_p1, past_key_values_p2, present_key_values_p1, present_key_values_p2, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias) + outputs = ( + hidden_states, + attention_mask, + position_bias, + past_key_values_p1, + past_key_values_p2, + present_key_values_p1, + present_key_values_p2, + encoder_hidden_states, + encoder_attention_mask, + encoder_decoder_position_bias, + ) outputs = pipe_encode_all(outputs, hidden_states.shape[0], hidden_states.device) return outputs + class T5Stack(T5PreTrainedModel): def __init__(self, config, embed_tokens=None): super().__init__(config) @@ -898,7 +944,7 @@ def __init__(self, config, embed_tokens=None): [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] ) - #self.is_pipeline = False + # self.is_pipeline = False self.is_pipeline = True self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) @@ -939,7 +985,7 @@ def forward( return_dict=None, ): - #print(f"mini BS: {input_ids.shape[0]}") + # print(f"mini BS: {input_ids.shape[0]}") # Model parallel if self.model_parallel: @@ -989,7 +1035,7 @@ def forward( ) # initialize past_key_values with `None` if past does not exist - 
#if past_key_values is None: + # if past_key_values is None: # past_key_values = [None] * len(self.block) # ourselves in which case we just need to make it broadcastable to all heads. @@ -1015,8 +1061,8 @@ def forward( if self.pipeline_is_enabled: # handle batches (usually last) that are shorter than pipeline_chunks if batch_size < self.pipeline_chunks: - # # XXX: Is it always last? so that we don't override user's chunks setting unless it's the last batch - self.pipeline_chunks = 1 + # # XXX: Is it always last? so that we don't override user's chunks setting unless it's the last batch + self.pipeline_chunks = 1 else: # non-pipeline run is the same as chunks=1 batch size-wise self.pipeline_chunks = 1 @@ -1024,7 +1070,7 @@ def forward( # PP n_layers = len(self.block) - #n_chunks = 4 + # n_chunks = 4 def all_hidden_states_add(x): nonlocal all_hidden_states @@ -1033,41 +1079,60 @@ def all_hidden_states_add(x): # handle batches (usually last) that are can't be equally divided by pipeline_chunks present_key_value_states = [[0 for x in range(self.pipeline_chunks)] for y in range(n_layers)] + def present_key_value_states_add(x, block_id, micro_batch_id): nonlocal present_key_value_states - #present_key_value_states += (x,) + # present_key_value_states += (x,) present_key_value_states[block_id][micro_batch_id] = x - #present_key_value_states += (x,) - #print(x.shape for x in present_key_value_states) - #print(present_key_value_states) + # present_key_value_states += (x,) + # print(x.shape for x in present_key_value_states) + # print(present_key_value_states) def all_attentions_add(x): nonlocal all_attentions all_attentions += (x,) + def all_cross_attentions_add(x): nonlocal all_cross_attentions all_cross_attentions += (x,) # crazy flattening of 2 level tuples so that the batch dimension is first to be spliced upon and then restored on the other side if past_key_values is not None: - x1 = tuple(past_key_values[i][j].to(0) for i in range(len(past_key_values)) for j in [0,1]) - #for i in x1: print(i.shape) - x2 = tuple(past_key_values[i][j].to(0) for i in range(len(past_key_values)) for j in [2,3]) - #for i in x2: print(i.shape) + x1 = tuple(past_key_values[i][j].to(0) for i in range(len(past_key_values)) for j in [0, 1]) + # for i in x1: print(i.shape) + x2 = tuple(past_key_values[i][j].to(0) for i in range(len(past_key_values)) for j in [2, 3]) + # for i in x2: print(i.shape) past_key_values_p1 = torch.cat(x1, 1) past_key_values_p2 = torch.cat(x2, 1) - #input = torch.cat(tuple(past_key_values[i][j] for i in range(len(past_key_values)) for j in range(len(past_key_values[i]))), 1) + # input = torch.cat(tuple(past_key_values[i][j] for i in range(len(past_key_values)) for j in range(len(past_key_values[i]))), 1) else: past_key_values_p1 = None past_key_values_p2 = None # batch_size=2, blocks=self.n_layers, fixed=2 (2+2 keys) - present_key_values_p1 = torch.empty(batch_size, n_layers*2).to(0) - present_key_values_p2 = torch.empty(batch_size, n_layers*2).to(0) + present_key_values_p1 = torch.empty(batch_size, n_layers * 2).to(0) + present_key_values_p2 = torch.empty(batch_size, n_layers * 2).to(0) # rewrite the model after pre-trained weights were loaded - layers = [T5StackPipeSegment(idx, n_layers, layer_module, self.is_decoder, head_mask[idx], encoder_head_mask[idx], output_hidden_states, use_cache, output_attentions, all_hidden_states_add, present_key_value_states_add, all_attentions_add, all_cross_attentions_add) for idx, layer_module in enumerate(self.block)] - #block_sequential = 
nn.Sequential(*layers) + layers = [ + T5StackPipeSegment( + idx, + n_layers, + layer_module, + self.is_decoder, + head_mask[idx], + encoder_head_mask[idx], + output_hidden_states, + use_cache, + output_attentions, + all_hidden_states_add, + present_key_value_states_add, + all_attentions_add, + all_cross_attentions_add, + ) + for idx, layer_module in enumerate(self.block) + ] + # block_sequential = nn.Sequential(*layers) # for now don't enable the pipe if self.pipeline_is_enabled: @@ -1075,7 +1140,7 @@ def all_cross_attentions_add(x): # print("using partitioning: ", dict(zip(devices, layer_splits))) for device_id, layer_partition in self.device_map.items(): for layer_id in layer_partition: - #print(f"{layer_id} => {device_id}") + # print(f"{layer_id} => {device_id}") layers[layer_id].to(device_id) block_sequential = nn.Sequential(*layers) @@ -1083,8 +1148,19 @@ def all_cross_attentions_add(x): else: block_sequential = nn.Sequential(*layers) - inputs = (hidden_states, extended_attention_mask, position_bias, past_key_values_p1, past_key_values_p2, present_key_values_p1, present_key_values_p2, encoder_hidden_states, encoder_extended_attention_mask, encoder_decoder_position_bias) - #, all_hidden_states, present_key_value_states, all_attentions, all_cross_attentions) + inputs = ( + hidden_states, + extended_attention_mask, + position_bias, + past_key_values_p1, + past_key_values_p2, + present_key_values_p1, + present_key_values_p2, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + ) + # , all_hidden_states, present_key_value_states, all_attentions, all_cross_attentions) inputs = pipe_encode_all(inputs, batch_size, input_ids.device) if self.pipeline_is_enabled: @@ -1095,7 +1171,18 @@ def all_cross_attentions_add(x): outputs = block_sequential(inputs) outputs = pipe_decode_all(outputs, batch_size, input_ids.device) - hidden_states, attention_mask, position_bias, past_key_values_p1, past_key_values_p2, present_key_values_p1, present_key_values_p2, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias = outputs + ( + hidden_states, + attention_mask, + position_bias, + past_key_values_p1, + past_key_values_p2, + present_key_values_p1, + present_key_values_p2, + encoder_hidden_states, + encoder_attention_mask, + encoder_decoder_position_bias, + ) = outputs hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -1107,9 +1194,9 @@ def all_cross_attentions_add(x): # #present_key_values_p1 = present_key_values_p1.chunk(n_layers, 1).chunk(2, 1) # #present_key_values_p2 - #if self.pipeline_is_enabled and present_key_value_states is not None and present_key_value_states[0][0] != 0: + # if self.pipeline_is_enabled and present_key_value_states is not None and present_key_value_states[0][0] != 0: if present_key_value_states is not None and present_key_value_states[0][0] != 0: - #print() + # print() # reconstruct the flattened tensor to tuple of tuples of tensors new_x = () # deal with unpredictable potential last short batch @@ -1540,6 +1627,7 @@ def pipeline_enable(self, chunks, device_map): # XXX: should be a separate function import torch + n_gpus = torch.cuda.device_count() if n_gpus < 2: raise ValueError("Need at least 2 gpus to use the pipeline") @@ -1547,16 +1635,18 @@ def pipeline_enable(self, chunks, device_map): if device_map is not None: logger.info(f"using user-provided device_map") else: + def make_device_map(n_gpus, n_layers): print(f"making default device map: n_gpus={n_gpus}, 
n_layers={n_layers}") devices = list(range(n_gpus)) layer_ids = list(range(n_layers)) # XXX: later will have a map - for now just roughly split # XXX: probably should balance more so that the 0th gpu has the least number of layers, rather than the last one, because 0th gpu is already very busy - layer_splits = [layer_ids[i*n_layers // n_gpus: (i+1)*n_layers // n_gpus] for i in range(n_gpus)] + layer_splits = [ + layer_ids[i * n_layers // n_gpus : (i + 1) * n_layers // n_gpus] for i in range(n_gpus) + ] return dict(zip(devices, layer_splits)) - # XXX: for now assume encode/decoder symmetry - later fix to build each one separately n_layers = len(self.encoder.block) device_map = make_device_map(n_gpus, n_layers) @@ -1572,23 +1662,32 @@ def make_device_map(n_gpus, n_layers): # XXX for now hardcoded the RPC setup here - but it should happen in the trainer instead import os + import torch from torch.distributed import rpc - os.environ.update({"MASTER_ADDR": "localhost"}) - os.environ.update({"MASTER_PORT": "10638"}) - rpc.init_rpc( - "worker", - rank=0, - world_size=1, - ) + + # dynamically check if rpc has been initialized already - i.e in case we have deepspeed as the launcher of 2D + try: + # will succeed if rpc has started + torch.distributed.get_world_size() + except: + os.environ.update({"MASTER_ADDR": "localhost"}) + os.environ.update({"MASTER_PORT": "10638"}) + rpc.init_rpc( + "worker", + rank=0, + world_size=1, + ) num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 2 device = torch.device("cuda") def pipeline_finalize(self): # XXX: should reset the max counter # reset_peak_stats() - import torch import json + + import torch + n_gpus = torch.cuda.device_count() mem_map = {} for id in range(n_gpus): @@ -1603,7 +1702,6 @@ def pipeline_finalize(self): # XXX: would be great to add gpu utilization stats as well - @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): self.device_map = ( diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index d50b4ae9abed3e..83fdf7a593bf14 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -253,15 +253,15 @@ def __init__( self.is_model_parallel = True chunks_str, *device_map_str = self.args.pipeline.split() - chunks = int(chunks_str.split('=')[1]) + chunks = int(chunks_str.split("=")[1]) device_map = None if len(device_map_str): - device_map_range = device_map_str[0].split('=')[1] - device_map_range_str = device_map_range.split(',') + device_map_range = device_map_str[0].split("=")[1] + device_map_range_str = device_map_range.split(",") device_map = {} for x in device_map_range_str: - device_id, layers = x.split(':') - device_map[int(device_id)] = list(range(*map(int, layers.split('-')))) + device_id, layers = x.split(":") + device_map[int(device_id)] = list(range(*map(int, layers.split("-")))) model.pipeline_enable(chunks=chunks, device_map=device_map) @@ -281,7 +281,6 @@ def __init__( if len(self.args.pipeline): model = model.to(args.device) - # later use `self.model is self.model_wrapped` to check if it's wrapped or not self.model_wrapped = model self.model = model diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index c64552ec3fc46b..458ae2e013348a 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -354,7 +354,12 @@ class TrainingArguments: ) debug: bool = field(default=False, metadata={"help": "Whether to print debug metrics on TPU"}) - pipeline: str = field(default="", 
metadata={"help": "Whether to enable Pipeline Parallelism and the value is pipeline params: 'chunks=5; device_map=0:1-10,1:11-20'"}) + pipeline: str = field( + default="", + metadata={ + "help": "Whether to enable Pipeline Parallelism and the value is pipeline params: 'chunks=5; device_map=0:1-10,1:11-20'" + }, + ) dataloader_drop_last: bool = field( default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} From 361f68caf686ef05b088f3410dd13b5a014b40b8 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 28 Jan 2021 19:28:25 -0800 Subject: [PATCH 08/11] mpu work --- src/transformers/integrations.py | 41 ++++++++++++++++------- src/transformers/models/t5/modeling_t5.py | 31 +++++++++++++---- src/transformers/trainer.py | 20 +++++++++-- 3 files changed, 71 insertions(+), 21 deletions(-) diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index a27ee5280a8adc..53252299fab47b 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -258,6 +258,7 @@ def rewrite_logs(d): # Model parallel group that the current rank belongs to. _MODEL_PARALLEL_GROUP = None +_MODEL_PARALLEL_GROUP_DEVICE_IDS = None # Data parallel group that the current rank belongs to. _DATA_PARALLEL_GROUP = None @@ -271,18 +272,29 @@ def initialize_model_parallel(self): Arguments: model_parallel_size: number of GPUs used to parallelize model. + **Important**: not the total number of gpus! Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we use 2 GPUs to parallelize the model. The present function will - create 4 model parallel groups and 2 data parallel grous as: + create 4 model parallel groups and 2 data parallel groups as: 4 model parallel groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7] 2 data parallel groups: [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks are on the same DGX box. For example if we are using 2 DGX-1 boxes with a total of 16 GPUs, rank 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box. + + Let's say we have a total of 4 GPUs denoted by g0 ... g3 and we + use 2 GPUs to parallelize the model. The present function will + create 2 model parallel groups and 2 data parallel groups as: + 2 model parallel groups: + [g0, g1], [g2, g3] + 2 data parallel groups: + [g0, g2], [g1, g3] + """ model_parallel_size_ = self.n_gpus @@ -300,6 +312,8 @@ def ensure_divisibility(numerator, denominator): ensure_divisibility(world_size, model_parallel_size) rank = torch.distributed.get_rank() + #print(f"MP size: {model_parallel_size}") + # Build the data parallel groups. global _DATA_PARALLEL_GROUP assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" @@ -307,16 +321,20 @@ def ensure_divisibility(numerator, denominator): ranks = range(i, world_size, model_parallel_size) group = torch.distributed.new_group(ranks) if i == (rank % model_parallel_size): + #print(f"DP ranks: {list(ranks)}") _DATA_PARALLEL_GROUP = group # Build the model parallel groups. 
global _MODEL_PARALLEL_GROUP + global _MODEL_PARALLEL_GROUP_DEVICE_IDS assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" for i in range(world_size // model_parallel_size): ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size) group = torch.distributed.new_group(ranks) if i == (rank // model_parallel_size): + #print(f"MP ranks: {list(ranks)}") _MODEL_PARALLEL_GROUP = group + _MODEL_PARALLEL_GROUP_DEVICE_IDS = list(ranks) def model_parallel_is_initialized(self): """Check if model and data parallel groups are initialized.""" @@ -324,6 +342,11 @@ def model_parallel_is_initialized(self): return False return True + def get_model_parallel_group_device_ids(self): + """Get the model parallel group the caller rank belongs to.""" + assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" + return _MODEL_PARALLEL_GROUP_DEVICE_IDS + def get_model_parallel_group(self): """Get the model parallel group the caller rank belongs to.""" assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" @@ -343,7 +366,7 @@ def get_model_parallel_rank(self): return torch.distributed.get_rank(group=self.get_model_parallel_group()) def get_model_parallel_src_rank(self): - """Calculate the global rank corresponding to a local rank zeor + """Calculate the global rank corresponding to a local rank zero in the model parallel group.""" global_rank = torch.distributed.get_rank() local_world_size = get_model_parallel_world_size() @@ -360,12 +383,14 @@ def get_data_parallel_rank(self): def destroy_model_parallel(self): """Set the groups to none.""" global _MODEL_PARALLEL_GROUP - _MODEL_PARALLEL_GROUP = None + global _MODEL_PARALLEL_GROUP_DEVICE_IDS global _DATA_PARALLEL_GROUP + _MODEL_PARALLEL_GROUP = None + _MODEL_PARALLEL_GROUP_DEVICE_IDS = None _DATA_PARALLEL_GROUP = None -def init_deepspeed(trainer, num_training_steps): +def init_deepspeed(trainer, num_training_steps, mpu): """ Init DeepSpeed, after converting any relevant Trainer's args into DeepSpeed configuration @@ -381,14 +406,6 @@ def init_deepspeed(trainer, num_training_steps): ds_config_file = args.deepspeed model = trainer.model - # 2D Parallel - if len(args.pipeline): - n_gpus = torch.distributed.get_world_size() - mpu = MPU(n_gpus) - mpu.initialize_model_parallel() - else: - mpu = None - with io.open(ds_config_file, "r", encoding="utf-8") as f: config = json.load(f) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index ea64940cb3fe93..ae1a8cf4c2ca96 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1622,13 +1622,20 @@ def __init__(self, config): self.pipeline_is_enabled = False - def pipeline_enable(self, chunks, device_map): + def pipeline_enable(self, chunks, device_map, mpu=None): logger.info(f"enabling pipeline with chunks={chunks}") # XXX: should be a separate function import torch - n_gpus = torch.cuda.device_count() + if mpu is not None: + this_proc_device_ids = mpu.get_model_parallel_group_device_ids() + else: + this_proc_device_ids = list(range(torch.cuda.device_count())) + + logger.info(f"process { torch.distributed.get_rank() } uses MP/PP device ids: {this_proc_device_ids}") + + n_gpus = len(this_proc_device_ids) if n_gpus < 2: raise ValueError("Need at least 2 gpus to use the pipeline") @@ -1651,14 +1658,23 @@ def make_device_map(n_gpus, n_layers): n_layers = len(self.encoder.block) device_map = make_device_map(n_gpus, n_layers) - self.device_map 
= device_map + # 2D parallel - i.e. deepspeed + pp + if mpu is not None: + # we need to assign the correct set of IDs for this process - that we get from MPU + remapped_device_map = {} + for i, id in enumerate(device_map.keys()): + remapped_device_map[this_proc_device_ids[i]] = device_map[id] + self.device_map = remapped_device_map + else: + self.device_map = device_map + self.pipeline_is_enabled = True - logger.info(f"using pipeline partitioning: {device_map}") + logger.info(f"using pipeline partitioning: {self.device_map}") # XXX: validate chunks is a good arg - self.encoder.pipeline_params(chunks=chunks, device_map=device_map) - self.decoder.pipeline_params(chunks=chunks, device_map=device_map) + self.encoder.pipeline_params(chunks=chunks, device_map=self.device_map) + self.decoder.pipeline_params(chunks=chunks, device_map=self.device_map) # XXX for now hardcoded the RPC setup here - but it should happen in the trainer instead import os @@ -1667,8 +1683,9 @@ def make_device_map(n_gpus, n_layers): from torch.distributed import rpc # dynamically check if rpc has been initialized already - i.e in case we have deepspeed as the launcher of 2D + # XXX: this needs to be parameterized/cleaned up try: - # will succeed if rpc has started + # will succeed if rpc has started (deepspeed launcher) torch.distributed.get_world_size() except: os.environ.update({"MASTER_ADDR": "localhost"}) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 83fdf7a593bf14..778545f5cfed42 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -246,6 +246,22 @@ def __init__( else: self.is_model_parallel = False + + + # 2D Parallel + if self.args.deepspeed and len(self.args.pipeline): + from .integrations import MPU + #n_gpus = torch.distributed.get_world_size() + # XXX: hardcoded for 2 gpus for PP/MP - needs to be configurable + #n_gpus_per_mp = n_gpus/2 + # at the moment experimenting with just 4 gpus - hence 2 gpus for MP|PP, 2 for DP + n_gpus_per_mp = 2 + self.mpu = MPU(n_gpus_per_mp) + self.mpu.initialize_model_parallel() + else: + self.mpu = None + + # XXX: for now hack over naive MP to have the same behavior if len(self.args.pipeline): # using range() syntax for upper boundary (i.e. 
not inclusive) @@ -263,7 +279,7 @@ def __init__( device_id, layers = x.split(":") device_map[int(device_id)] = list(range(*map(int, layers.split("-")))) - model.pipeline_enable(chunks=chunks, device_map=device_map) + model.pipeline_enable(chunks=chunks, device_map=device_map, mpu=self.mpu) default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator @@ -737,7 +753,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D num_update_steps_per_epoch = max_steps if self.args.deepspeed: - model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps) + model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps, mpu=self.mpu) self.model = model.module self.model_wrapped = model # will get further wrapped in DDP self.deepspeed = model # DeepSpeedEngine object From 1c197b6e12dadb019903c1ad1435e66161275335 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sun, 31 Jan 2021 10:20:55 -0800 Subject: [PATCH 09/11] wip --- examples/seq2seq/finetune_trainer.py | 4 +- src/transformers/integrations.py | 26 +++++--- src/transformers/models/t5/modeling_t5.py | 78 ++++++++++++++++++----- src/transformers/trainer.py | 69 ++++++++++++-------- src/transformers/training_args.py | 2 +- 5 files changed, 124 insertions(+), 55 deletions(-) diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py index 73123063d07d46..31204f68fa27db 100755 --- a/examples/seq2seq/finetune_trainer.py +++ b/examples/seq2seq/finetune_trainer.py @@ -176,10 +176,10 @@ def main(): training_args.fp16, ) # Set the verbosity to info of the Transformers logger (on main process only): + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() if is_main_process(training_args.local_rank): transformers.utils.logging.set_verbosity_info() - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() logger.info("Training/evaluation parameters %s", training_args) # Set seed diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index 53252299fab47b..a4b187752fda5f 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -252,7 +252,7 @@ def rewrite_logs(d): return new_d -# adjusted from Megatron-LM/mpu/ + import torch @@ -261,12 +261,11 @@ def rewrite_logs(d): _MODEL_PARALLEL_GROUP_DEVICE_IDS = None # Data parallel group that the current rank belongs to. _DATA_PARALLEL_GROUP = None +_DATA_PARALLEL_GROUP_DEVICE_IDS = None +# adjusted from Megatron-LM/mpu/ class MPU: - def __init__(self, n_gpus): - self.n_gpus = n_gpus - - def initialize_model_parallel(self): + def initialize_model_parallel(self, model_parallel_size_): """ Initialize model data parallel groups. 
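For reference, the rank layout described in the docstring above (reworked in this hunk so that `initialize_model_parallel` takes `model_parallel_size_` as an argument) boils down to simple modular arithmetic over the global ranks. Below is a minimal standalone sketch of that arithmetic only; the helper name `mpu_group_layout` is illustrative and is not part of this patch:

def mpu_group_layout(world_size, model_parallel_size):
    """Return (data_parallel_groups, model_parallel_groups) as lists of rank lists."""
    assert world_size % model_parallel_size == 0, "world size must be divisible by MP size"
    # ranks that occupy the same position inside an MP group form a DP group
    data_parallel_groups = [
        list(range(i, world_size, model_parallel_size)) for i in range(model_parallel_size)
    ]
    # consecutive ranks form an MP/PP group (adjacent ranks should sit on the same box)
    model_parallel_groups = [
        list(range(i * model_parallel_size, (i + 1) * model_parallel_size))
        for i in range(world_size // model_parallel_size)
    ]
    return data_parallel_groups, model_parallel_groups

# world_size=4, model_parallel_size=2 (the 4-GPU setup experimented with above):
#   data parallel groups:  [[0, 2], [1, 3]]
#   model parallel groups: [[0, 1], [2, 3]]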
@@ -297,8 +296,6 @@ def initialize_model_parallel(self): """ - model_parallel_size_ = self.n_gpus - def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) @@ -312,10 +309,13 @@ def ensure_divisibility(numerator, denominator): ensure_divisibility(world_size, model_parallel_size) rank = torch.distributed.get_rank() - #print(f"MP size: {model_parallel_size}") + print(f"MP size: {model_parallel_size}") + print(f"world_size: {world_size}") + print(f"rank: {rank}") # Build the data parallel groups. global _DATA_PARALLEL_GROUP + global _DATA_PARALLEL_GROUP_DEVICE_IDS assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" for i in range(model_parallel_size): ranks = range(i, world_size, model_parallel_size) @@ -323,6 +323,7 @@ def ensure_divisibility(numerator, denominator): if i == (rank % model_parallel_size): #print(f"DP ranks: {list(ranks)}") _DATA_PARALLEL_GROUP = group + _DATA_PARALLEL_GROUP_DEVICE_IDS = list(ranks) # Build the model parallel groups. global _MODEL_PARALLEL_GROUP @@ -343,7 +344,7 @@ def model_parallel_is_initialized(self): return True def get_model_parallel_group_device_ids(self): - """Get the model parallel group the caller rank belongs to.""" + """Get the model parallel device ids of the group the caller rank belongs to.""" assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" return _MODEL_PARALLEL_GROUP_DEVICE_IDS @@ -352,6 +353,11 @@ def get_model_parallel_group(self): assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" return _MODEL_PARALLEL_GROUP + def get_data_parallel_group_device_ids(self): + """Get the data parallel device ids of the group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" + return _DATA_PARALLEL_GROUP_DEVICE_IDS + def get_data_parallel_group(self): """Get the data parallel group the caller rank belongs to.""" assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" @@ -385,9 +391,11 @@ def destroy_model_parallel(self): global _MODEL_PARALLEL_GROUP global _MODEL_PARALLEL_GROUP_DEVICE_IDS global _DATA_PARALLEL_GROUP + global _DATA_PARALLEL_GROUP_DEVICE_IDS _MODEL_PARALLEL_GROUP = None _MODEL_PARALLEL_GROUP_DEVICE_IDS = None _DATA_PARALLEL_GROUP = None + _DATA_PARALLEL_GROUP_DEVICE_IDS = None def init_deepspeed(trainer, num_training_steps, mpu): diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index ae1a8cf4c2ca96..a09e5b1e8fc658 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -253,6 +253,7 @@ def forward(self, hidden_states): # convert into float16 if necessary if self.weight.dtype == torch.float16: hidden_states = hidden_states.to(torch.float16) + return self.weight * hidden_states @@ -1098,9 +1099,9 @@ def all_cross_attentions_add(x): # crazy flattening of 2 level tuples so that the batch dimension is first to be spliced upon and then restored on the other side if past_key_values is not None: - x1 = tuple(past_key_values[i][j].to(0) for i in range(len(past_key_values)) for j in [0, 1]) + x1 = tuple(past_key_values[i][j].to(input_ids.device) for i in range(len(past_key_values)) for j in [0, 1]) # for i in x1: print(i.shape) - x2 = tuple(past_key_values[i][j].to(0) for i in range(len(past_key_values)) for j in [2, 3]) + 
x2 = tuple(past_key_values[i][j].to(input_ids.device) for i in range(len(past_key_values)) for j in [2, 3]) # for i in x2: print(i.shape) past_key_values_p1 = torch.cat(x1, 1) past_key_values_p2 = torch.cat(x2, 1) @@ -1110,8 +1111,8 @@ def all_cross_attentions_add(x): past_key_values_p2 = None # batch_size=2, blocks=self.n_layers, fixed=2 (2+2 keys) - present_key_values_p1 = torch.empty(batch_size, n_layers * 2).to(0) - present_key_values_p2 = torch.empty(batch_size, n_layers * 2).to(0) + present_key_values_p1 = torch.empty(batch_size, n_layers * 2).to(input_ids.device) + present_key_values_p2 = torch.empty(batch_size, n_layers * 2).to(input_ids.device) # rewrite the model after pre-trained weights were loaded layers = [ @@ -1207,7 +1208,7 @@ def all_cross_attentions_add(x): for block in present_key_value_states: new_y = () for j in (0, 1, 2, 3): - entries = tuple(block[i][j].to(0) for i in range(real_chunks)) + entries = tuple(block[i][j].to(input_ids.device) for i in range(real_chunks)) new_y += (torch.cat(entries, 0),) new_x += (new_y,) present_key_value_states = new_x @@ -1628,16 +1629,49 @@ def pipeline_enable(self, chunks, device_map, mpu=None): # XXX: should be a separate function import torch + + try: + # will succeed if rpc has started (deepspeed launcher) + dist_world_size = torch.distributed.get_world_size() + dist_rank = torch.distributed.get_rank() + except: + dist_world_size = 1 + dist_rank = 0 + + #dist_world_size=1 + if mpu is not None: + log_prefix = f"[p{dist_rank}]" + #logger.warn(f"{log_prefix} got MPU") + logger.warn(f"{log_prefix} DP group { mpu.get_data_parallel_group_device_ids() }") + this_proc_device_ids = mpu.get_model_parallel_group_device_ids() + #logger.warn(f"{log_prefix} MP group {this_proc_device_ids }") + + # XXX: automate this: + # I think we might be getting the right groups already in this_proc_device_ids + if dist_world_size == 4: + if dist_rank == 0: + this_proc_device_ids = [0, 1] + else: + this_proc_device_ids = [2, 3] + elif dist_world_size == 2: + if dist_rank == 0: + this_proc_device_ids = [0] + else: + this_proc_device_ids = [1] + else: + log_prefix = f"[p0]" this_proc_device_ids = list(range(torch.cuda.device_count())) - logger.info(f"process { torch.distributed.get_rank() } uses MP/PP device ids: {this_proc_device_ids}") + logger.warn(f"{log_prefix} MP group {this_proc_device_ids }") + #logger.warn(f"{log_prefix} uses MP/PP device ids: {this_proc_device_ids}") n_gpus = len(this_proc_device_ids) - if n_gpus < 2: - raise ValueError("Need at least 2 gpus to use the pipeline") + # XXX: restore this later + # if n_gpus < 2: + # raise ValueError("Need at least 2 gpus to use the pipeline") if device_map is not None: logger.info(f"using user-provided device_map") @@ -1661,6 +1695,13 @@ def make_device_map(n_gpus, n_layers): # 2D parallel - i.e. deepspeed + pp if mpu is not None: # we need to assign the correct set of IDs for this process - that we get from MPU + # in case of 2D the user describes the device map only for the first group of DP + # and we need to re-assign the ids for the rest of the groups, so say a user passes a device map: + # 0:0-7, 1:7-14 + # for process rank 0 it remains that, but for process rank 1 it should become: + # 2:0-7, 3:7-14 + # and so on. 
+ # MPU gives us the correct local MP group (this_proc_device_ids) remapped_device_map = {} for i, id in enumerate(device_map.keys()): remapped_device_map[this_proc_device_ids[i]] = device_map[id] @@ -1669,7 +1710,7 @@ def make_device_map(n_gpus, n_layers): self.device_map = device_map self.pipeline_is_enabled = True - logger.info(f"using pipeline partitioning: {self.device_map}") + logger.warn(f"{log_prefix} uses PP partitioning: {self.device_map}") # XXX: validate chunks is a good arg @@ -1684,16 +1725,19 @@ def make_device_map(n_gpus, n_layers): # dynamically check if rpc has been initialized already - i.e in case we have deepspeed as the launcher of 2D # XXX: this needs to be parameterized/cleaned up - try: - # will succeed if rpc has started (deepspeed launcher) - torch.distributed.get_world_size() - except: + # try: + # # will succeed if rpc has started (deepspeed launcher) + # torch.distributed.get_world_size() + # except: + if 1: os.environ.update({"MASTER_ADDR": "localhost"}) - os.environ.update({"MASTER_PORT": "10638"}) + os.environ.update({"MASTER_PORT": "10639"}) rpc.init_rpc( - "worker", - rank=0, - world_size=1, + #"worker", + f"worker{dist_rank}", + #rank=0, + rank=dist_rank, + world_size=dist_world_size, ) num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 2 device = torch.device("cuda") diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 778545f5cfed42..c03ac588210e2f 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -247,39 +247,56 @@ def __init__( self.is_model_parallel = False - - # 2D Parallel - if self.args.deepspeed and len(self.args.pipeline): - from .integrations import MPU - #n_gpus = torch.distributed.get_world_size() - # XXX: hardcoded for 2 gpus for PP/MP - needs to be configurable - #n_gpus_per_mp = n_gpus/2 - # at the moment experimenting with just 4 gpus - hence 2 gpus for MP|PP, 2 for DP - n_gpus_per_mp = 2 - self.mpu = MPU(n_gpus_per_mp) - self.mpu.initialize_model_parallel() - else: - self.mpu = None - - + self.mpu = None # XXX: for now hack over naive MP to have the same behavior if len(self.args.pipeline): # using range() syntax for upper boundary (i.e. 
not inclusive) # --pipeline "chunks=5; device_map=0:0-10,1:10-20" self.is_model_parallel = True - chunks_str, *device_map_str = self.args.pipeline.split() - chunks = int(chunks_str.split("=")[1]) - device_map = None - if len(device_map_str): - device_map_range = device_map_str[0].split("=")[1] - device_map_range_str = device_map_range.split(",") + # arg parser + pp_args = {} + args = self.args.pipeline.split() + if len(args): + for x in args: + k,v = x.split("=") + pp_args[k] = v + + if "chunks" in pp_args: + pp_args["chunks"] = int(pp_args["chunks"]) + else: + # XXX: probably can try some smart dynamic default based on batch_size + pp_args["chunks"] = 2 + + if "device_map" in pp_args: + device_map_range_str = pp_args["device_map"].split(",") device_map = {} for x in device_map_range_str: device_id, layers = x.split(":") device_map[int(device_id)] = list(range(*map(int, layers.split("-")))) + pp_args["device_map"] = device_map + else: + pp_args["device_map"] = None - model.pipeline_enable(chunks=chunks, device_map=device_map, mpu=self.mpu) + if "n_gpus_per_mp" in pp_args: + pp_args["n_gpus_per_mp"] = int(pp_args["n_gpus_per_mp"]) + else: + # XXX: can try some smart dynamic default here based on total_n_gpus, + # if it's not 2D all gpus will be used + # if 2D half gpus should be a good default + pp_args["n_gpus_per_mp"] = 2 + + # 2D Parallel + if self.args.deepspeed: + from .integrations import MPU + self.mpu = MPU() + #n_gpus = torch.distributed.get_world_size() + # XXX: hardcoded for 2 gpus for PP/MP - needs to be configurable + #n_gpus_per_mp = n_gpus/2 + # at the moment experimenting with just 4 gpus - hence 2 gpus for MP|PP, 2 for DP + self.mpu.initialize_model_parallel(pp_args["n_gpus_per_mp"]) + + model.pipeline_enable(chunks=pp_args["chunks"], device_map=pp_args["device_map"], mpu=self.mpu) default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator @@ -983,6 +1000,9 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D # Clean the state at the end of training delattr(self, "_past") + if len(self.args.pipeline): + self.model.pipeline_finalize() + logger.info("\n\nTraining completed. 
Do not forget to share your model on huggingface.co/models =)\n\n") if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None: logger.info( @@ -1011,9 +1031,6 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D # add remaining tr_loss self._total_loss_scalar += tr_loss.item() - if len(self.args.pipeline): - model.pipeline_finalize() - return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics) def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch): @@ -1650,7 +1667,7 @@ def prediction_loop( metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) if len(self.args.pipeline): - model.pipeline_finalize() + self.model.pipeline_finalize() return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 458ae2e013348a..42b24d6cca69fd 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -357,7 +357,7 @@ class TrainingArguments: pipeline: str = field( default="", metadata={ - "help": "Whether to enable Pipeline Parallelism and the value is pipeline params: 'chunks=5; device_map=0:1-10,1:11-20'" + "help": "Whether to enable Pipeline Parallelism and the value is pipeline params: 'chunks=5 device_map=0:1-10,1:10-20 n_gpus_per_pp=2" }, ) From 632bdd58bc1f017728cf17b2ad295892ebb9e00d Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sun, 31 Jan 2021 10:29:53 -0800 Subject: [PATCH 10/11] wip fix --- src/transformers/models/t5/modeling_t5.py | 9 +++++---- src/transformers/trainer.py | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 8c4860220bb785..6e43a95abd5c21 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1646,10 +1646,11 @@ def pipeline_enable(self, chunks, device_map, mpu=None): logger.warn(f"{log_prefix} DP group { mpu.get_data_parallel_group_device_ids() }") this_proc_device_ids = mpu.get_model_parallel_group_device_ids() - #logger.warn(f"{log_prefix} MP group {this_proc_device_ids }") + logger.warn(f"{log_prefix} PP group {this_proc_device_ids } [MPU]") # XXX: automate this: - # I think we might be getting the right groups already in this_proc_device_ids + # We must be getting the right groups already from get_model_parallel_group_device_ids() + # If we don't then deepspeed isn't getting the right groups - and it'd break if dist_world_size == 4: if dist_rank == 0: this_proc_device_ids = [0, 1] @@ -1665,7 +1666,7 @@ def pipeline_enable(self, chunks, device_map, mpu=None): log_prefix = f"[p0]" this_proc_device_ids = list(range(torch.cuda.device_count())) - logger.warn(f"{log_prefix} MP group {this_proc_device_ids }") + logger.warn(f"{log_prefix} PP group {this_proc_device_ids }") #logger.warn(f"{log_prefix} uses MP/PP device ids: {this_proc_device_ids}") n_gpus = len(this_proc_device_ids) @@ -1710,7 +1711,7 @@ def make_device_map(n_gpus, n_layers): self.device_map = device_map self.pipeline_is_enabled = True - logger.warn(f"{log_prefix} uses PP partitioning: {self.device_map}") + logger.warn(f"{log_prefix} PP partitions: {self.device_map}") # XXX: validate chunks is a good arg diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index ae702e7de94058..25f25986e53124 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -268,9 +268,9 @@ def 
__init__( # arg parser pp_args = {} - args = self.args.pipeline.split() - if len(args): - for x in args: + pp_args_str = self.args.pipeline.split() + if len(pp_args_str): + for x in pp_args_str: k,v = x.split("=") pp_args[k] = v From 757c3a7ca5c2924de41fbe1b42d62504db2bc275 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Fri, 26 Feb 2021 18:59:38 -0800 Subject: [PATCH 11/11] move --- src/transformers/trainer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index b3006f34c06d09..d866ff7e71187a 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -319,7 +319,9 @@ def __init__( ): self.place_model_on_device = False - + # XXX: this is probably wrong - as it won't fit on device normally + if len(self.args.pipeline): + model = model.to(args.device) self.mpu = None # XXX: for now hack over naive MP to have the same behavior @@ -392,9 +394,6 @@ def __init__( if self.is_model_parallel: self.args._n_gpu = 1 - if len(self.args.pipeline): - model = model.to(args.device) - # later use `self.model is self.model_wrapped` to check if it's wrapped or not self.model_wrapped = model self.model = model
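For reference, the `--pipeline` training argument used throughout these patches is a space-separated list of `key=value` pairs, e.g. `chunks=5 device_map=0:0-10,1:10-20 n_gpus_per_mp=2`, where the device-map ranges follow Python `range()` semantics (upper bound not inclusive). Note that the help string added to training_args.py spells the last key `n_gpus_per_pp`, while the parsing code added to trainer.py looks for `n_gpus_per_mp`. The sketch below mirrors that parsing as a standalone function; `parse_pipeline_args` is a hypothetical helper for illustration, not an API these patches introduce:

def parse_pipeline_args(pipeline: str) -> dict:
    """Parse e.g. 'chunks=5 device_map=0:0-10,1:10-20 n_gpus_per_mp=2' into a dict."""
    pp_args = {}
    for item in pipeline.split():
        key, value = item.split("=")
        pp_args[key] = value
    # defaults matching the ones hardcoded in the patches above
    pp_args["chunks"] = int(pp_args.get("chunks", 2))
    pp_args["n_gpus_per_mp"] = int(pp_args.get("n_gpus_per_mp", 2))
    if "device_map" in pp_args:
        device_map = {}
        for entry in pp_args["device_map"].split(","):
            device_id, layers = entry.split(":")
            start, end = map(int, layers.split("-"))
            device_map[int(device_id)] = list(range(start, end))  # upper bound not inclusive
        pp_args["device_map"] = device_map
    else:
        pp_args["device_map"] = None
    return pp_args

# parse_pipeline_args("chunks=5 device_map=0:0-10,1:10-20 n_gpus_per_mp=2")
# -> {'chunks': 5, 'device_map': {0: [0, ..., 9], 1: [10, ..., 19]}, 'n_gpus_per_mp': 2}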