diff --git a/examples/nlp/gpt/conf/gpt_ppo_actor.yaml b/examples/nlp/gpt/conf/gpt_ppo_actor.yaml
index 044532d34..f154946ac 100644
--- a/examples/nlp/gpt/conf/gpt_ppo_actor.yaml
+++ b/examples/nlp/gpt/conf/gpt_ppo_actor.yaml
@@ -125,6 +125,13 @@ model:
       # will be used in NeMo version > 1.20.0
       # keeping it for now
       end_strings: ["<|endoftext|>", "<extra_id_1>"]
 
+      # whether the sampling params above are used when computing log probs
+      # (if False, log prob computations assume the default temperature / top_k / top_p)
+      apply_to_logprobs:
+        loss: False
+        kl_penalty_actor: False
+        kl_penalty_ref: ${.kl_penalty_actor}
+
     # length argument for autoregressive sampling
     # max length means max amount of tokens to generate
diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py
index 1c250f233..d1b24657a 100644
--- a/nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py
+++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py
@@ -114,13 +114,21 @@ def fwd_output_and_loss_func(data_iterator, model):
 
             parallel_logits = model(batch["tokens"], batch["position_ids"], batch["attention_mask"], labels=None,)
 
+            sampling_params = self._sampling_params if self._sampling_params["apply_to_logprobs"]["loss"] else None
+
             def loss_func(parallel_logits):
                 mask = batch["mask"]
                 advantages = batch["advantages"]
                 prev_log_probs = batch["prev_log_probs"]
                 tokens = batch["tokens"]
 
-                curr_log_probs = from_parallel_logits_to_logprobs(vocab_parallel_logits=parallel_logits, target=tokens)
+                prompt_lengths = torch.argmax(mask, dim=1) + 1
+                curr_log_probs = from_parallel_logits_to_logprobs(
+                    vocab_parallel_logits=parallel_logits,
+                    target=tokens,
+                    sampling_params=sampling_params,
+                    prompt_lengths=prompt_lengths,
+                )
 
                 scaled_entropy = torch.tensor(0.0, dtype=parallel_logits.dtype, device=parallel_logits.device)
                 if self.entropy_bonus > 0:
@@ -208,17 +216,22 @@ def finish_training(self):
         """
 
     # inference calls
-    def get_logprob_output_only_func(self, inference_only=True):
+    def get_logprob_output_only_func(self, inference_only=True, apply_sampling_params=False):
         fwd_output_only_func = self.get_forward_output_only_func()
 
         def log_prob_output_only_func(dataloader_iter, model):
             batch = next(dataloader_iter)
-
-            output_tensor, _ = fwd_output_only_func(iter([batch,]), model)
+            # The last element of `batch` contains prompt lengths, which are not needed here.
+            output_tensor, _ = fwd_output_only_func(iter([batch[:-1],]), model)
+            sampling_params = self._sampling_params if apply_sampling_params else None
 
             def id_func(output_tensor, non_loss_data=True):
                 logprobs = from_parallel_logits_to_logprobs(
-                    vocab_parallel_logits=output_tensor, target=batch[0], inference_only=inference_only
+                    vocab_parallel_logits=output_tensor,
+                    target=batch[0],
+                    inference_only=inference_only,
+                    sampling_params=sampling_params,
+                    prompt_lengths=batch[-1],
                 )
                 return logprobs
 
@@ -227,18 +240,24 @@ def id_func(output_tensor, non_loss_data=True):
         return log_prob_output_only_func
 
     @torch.no_grad()
-    def get_inference_log_probs(self, response_tokens, forward_micro_batch_size):
+    def get_inference_log_probs(
+        self, prompt_lengths, response_tokens, forward_micro_batch_size, apply_sampling_params=False
+    ):
         set_sync_funcs(self, forward_only=True)
 
         mbs, seq_length = response_tokens.size()
         num_microbatches = divide(mbs, forward_micro_batch_size)
         attention_mask, _, position_ids = self.get_ltor_masks_and_position_ids(response_tokens)
-        batch_iter = get_iterator_k_split([response_tokens, attention_mask, position_ids], num_microbatches)
+        batch_iter = get_iterator_k_split(
+            [response_tokens, attention_mask, position_ids, prompt_lengths], num_microbatches
+        )
 
         fwd_bwd_function = get_forward_backward_func()
 
         logprobs_list = fwd_bwd_function(
-            forward_step_func=self.get_logprob_output_only_func(inference_only=True),
+            forward_step_func=self.get_logprob_output_only_func(
+                inference_only=True, apply_sampling_params=apply_sampling_params
+            ),
             data_iterator=batch_iter,
             model=self.model,
             num_microbatches=num_microbatches,
@@ -298,7 +317,10 @@ def infer(self, inference_batch):
 
         # TODO(geshen): get nemo generate to return the unaltered log probs
         log_probs = self.get_inference_log_probs(
-            response_tokens, forward_micro_batch_size=self.forward_micro_batch_size
+            prompt_lengths,
+            response_tokens,
+            forward_micro_batch_size=self.forward_micro_batch_size,
+            apply_sampling_params=self._sampling_params["apply_to_logprobs"]["kl_penalty_actor"],
         )
 
         rollout_batch = {
@@ -319,14 +341,20 @@ def get_init_policy_logprobs(self, rollout_batches):
             # With adapters disabled (meaning using the init policy), calculate init_log_probs
             for rollout_batch in rollout_batches:
                 init_log_prob = self.get_inference_log_probs(
-                    rollout_batch["response_tokens"].cuda(), forward_micro_batch_size=self.forward_micro_batch_size
+                    rollout_batch["prompt_lengths"].cuda(),
+                    rollout_batch["response_tokens"].cuda(),
+                    forward_micro_batch_size=self.forward_micro_batch_size,
+                    apply_sampling_params=self._sampling_params["apply_to_logprobs"]["kl_penalty_ref"],
                 )
                 init_log_probs.append(init_log_prob)
         else:
             with cpu_weight_swap(self, self.init_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2):
                 for rollout_batch in rollout_batches:
                     init_log_prob = self.get_inference_log_probs(
-                        rollout_batch["response_tokens"].cuda(), forward_micro_batch_size=self.forward_micro_batch_size
+                        rollout_batch["prompt_lengths"].cuda(),
+                        rollout_batch["response_tokens"].cuda(),
+                        forward_micro_batch_size=self.forward_micro_batch_size,
+                        apply_sampling_params=self._sampling_params["apply_to_logprobs"]["kl_penalty_ref"],
                     )
                     init_log_probs.append(init_log_prob)
 
diff --git a/nemo_aligner/utils/distributed.py b/nemo_aligner/utils/distributed.py
index 78c860108..fba877ae7 100644
--- a/nemo_aligner/utils/distributed.py
+++ b/nemo_aligner/utils/distributed.py
@@ -22,6 +22,8 @@
 from typing import Dict, Optional, Union
 
 import torch
+import torch.distributed
+import torch.nn.functional as F
 from megatron.core import parallel_state, tensor_parallel
 from nemo.collections.nlp.modules.common.text_generation_utils import get_model_parallel_src_rank
 
@@ -191,12 +193,124 @@ def _compute_distributed_log_softmax(vocab_parallel_logits):
     return vocab_parallel_logits - sum_exp_logits.log_().to(vocab_parallel_logits.dtype)
 
 
+def _compute_distributed_top_k_p(logits, k, p, rank, world_size):
+    """Expects a B x S x V//TP tensor and computes a distributed top_k and top_p, setting all other logits to -Inf.
+    Returns a B x S x V//TP tensor in which only the global top-k-p values (across the V dimension) survive the filter.
+    """
+    src_rank = get_model_parallel_src_rank()
+    get_vocab_range = tensor_parallel.utils.VocabUtility.vocab_range_from_per_partition_vocab_size
+    partition_vocab_size = logits.size(-1)
+    vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
+
+    local_topk_values, local_topk_indices = torch.topk(logits, k=k, dim=-1)  # [B,S,k]
+    local_topk_indices += vocab_start_index
+    # Prepare containers for the gathered top-k values and indices from all GPUs
+    if rank == src_rank:
+        gathered_values = [torch.zeros_like(local_topk_values) for _ in range(world_size)]
+        gathered_indices = [torch.zeros_like(local_topk_indices) for _ in range(world_size)]
+    else:
+        gathered_values = None
+        gathered_indices = None
+
+    # Gather top-k values and indices from all GPUs
+    torch.distributed.gather(local_topk_values, gathered_values, dst=src_rank)
+    torch.distributed.gather(local_topk_indices, gathered_indices, dst=src_rank)
+
+    if rank == src_rank:  # only the src rank does the computation and broadcasts the outcome
+        # Concatenate the gathered values and indices along the last dimension
+        all_values = torch.cat(gathered_values, dim=-1)  # [B,S,world_size*k]
+        all_indices = torch.cat(gathered_indices, dim=-1)
+
+        # Perform a global top-k operation to find the global top-k values and indices
+        global_topk_values, topk_indices = torch.topk(all_values, k=k, dim=-1)  # [B,S,k]
+        global_topk_indices = torch.gather(all_indices, -1, topk_indices)
+
+        # perform top_p
+        if 0.0 < p < 1.0:
+            # perform top_p, save the result in global_top_k_p_indices, and spread it to all ranks
+            sorted_logits, sorted_indices = torch.sort(global_topk_values, descending=True, dim=-1)  # [B,S,k] each
+            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)  # [B,S,k]
+            global_top_k_p_indices = torch.gather(global_topk_indices, -1, sorted_indices)
+            sorted_indices_to_remove = cumulative_probs > p
+            sorted_indices_to_remove = sorted_indices_to_remove.roll(shifts=1, dims=-1)
+            sorted_indices_to_remove[..., 0] = False
+            global_top_k_p_indices = torch.where(
+                sorted_indices_to_remove, torch.tensor(-1, dtype=torch.long), global_top_k_p_indices
+            )  # the top_p indices are kept, the rest are set to -1 (so they are filtered in the following lines)
+    else:
+        global_top_k_p_indices = torch.empty_like(local_topk_indices)  # [B,S,k]
+
+    torch.distributed.broadcast(global_top_k_p_indices, src=src_rank)
+
+    # generate a mask according to the rank:
+    # only indices that fall within the current rank's segment of the vocabulary are kept
+    mask_top_k_p_indices = (global_top_k_p_indices >= vocab_start_index) & (
+        global_top_k_p_indices < vocab_end_index
+    )  # [B,S,k] where only indices that are within the scope of the current rank are True
+
+    # adjust indices to the local index space
+    local_top_k_p_indices = global_top_k_p_indices
+    local_top_k_p_indices -= vocab_start_index
+    local_top_k_p_indices = torch.where(
+        mask_top_k_p_indices, local_top_k_p_indices, torch.tensor(-1, dtype=torch.long)
+    )  # [B,S,k] - the global top_k_p indices are localized to this rank, or set to -1 if they fall outside its segment
+
+    valid_logits = torch.zeros_like(logits, dtype=torch.bool)
+    batch_indices, sequence_indices = torch.where(
+        mask_top_k_p_indices.any(dim=-1)
+    )  # collect all (b, s) indices where there is a valid index in [b, s, :]
+    local_vocab_indices = local_top_k_p_indices[
+        batch_indices, sequence_indices
+    ]  # collect the vocab indices per (b, s); there should be up to k valid indices (invalid ones are -1)
+    valid_local_indx_mask = local_vocab_indices != -1
+    valid_local_batch_idx = batch_indices.unsqueeze(1).expand_as(valid_local_indx_mask)[valid_local_indx_mask]
+    valid_local_sequence_idx = sequence_indices.unsqueeze(1).expand_as(valid_local_indx_mask)[valid_local_indx_mask]
+    valid_local_vocab_idx = local_vocab_indices[valid_local_indx_mask]
+    valid_logits[valid_local_batch_idx, valid_local_sequence_idx, valid_local_vocab_idx] = True
+    logits[~valid_logits] = -torch.inf
+    # logits was modified in place above; return it for convenience
+    return logits
+
+
+def _distributed_apply_sampling_params(logits, context_lengths, sampling_params, rank, world_size):
+    # apply the sampling params to the logits, focusing only on the generated (response) tokens
+    if sampling_params.get("use_greedy", False):
+        return logits
+    if sampling_params.get("repetition_penalty", 1.0) != 1.0:
+        raise NotImplementedError("not supporting repetition penalty when applying sampling params to logprobs")
+
+    context_length = context_lengths.min().item()
+    resp_logits = logits[:, context_length - 1 :]
+    # divide by the temperature
+    if sampling_params["temperature"] != 1.0:
+        resp_logits /= sampling_params["temperature"]
+    top_k = sampling_params["top_k"]
+    top_p = sampling_params["top_p"]
+    if top_k > 0:
+        # Note: currently assuming that top_p is applied only if top_k > 0.
+        resp_logits = _compute_distributed_top_k_p(resp_logits, top_k, top_p, rank, world_size)
+    elif 0.0 < top_p < 1.0:
+        raise NotImplementedError(
+            "Currently not supporting 0 < top_p < 1 with top_k=0 when applying sampling params to log probs"
+        )
+
+    return logits
+
+
 class DistributedLogprob(torch.autograd.Function):
     """Function to get logprobs out and differentiate through it
     """
 
     @staticmethod
-    def forward(ctx, vocab_parallel_logits, target, inference_only=False, higher_stability=False):
+    def forward(
+        ctx,
+        vocab_parallel_logits,
+        target,
+        inference_only=False,
+        higher_stability=False,
+        sampling_params=None,
+        prompt_lengths=None,
+    ):
         get_vocab_range = tensor_parallel.utils.VocabUtility.vocab_range_from_per_partition_vocab_size
         partition_vocab_size = vocab_parallel_logits.size(-1)
         rank = parallel_state.get_tensor_model_parallel_rank()
@@ -208,6 +322,14 @@ def forward(ctx, vocab_parallel_logits, target, inference_only=False, higher_sta
 
         masked_target = target - vocab_start_index
         masked_target[target_mask] = 0
+        # if sampling_params should be applied, apply them to the vocab_parallel_logits
+        if sampling_params is not None:
+            if prompt_lengths is None:
+                raise ValueError("prompt_lengths must be provided to apply sampling params to log probs")
+            vocab_parallel_logits = _distributed_apply_sampling_params(
+                vocab_parallel_logits, prompt_lengths, sampling_params, rank, world_size
+            )
+
         # higher stability uses a more numerically stable distributed log_softmax instead of softmax
         # however, it uses more VRAM because there is an unavoidable exp() OP on the entire logits tensor
         # some models (like DPO) will get -inf in the resulting logprobs unless you set higher_stability=True
@@ -248,7 +370,7 @@ def backward(ctx, grad_output):
         grad_input.mul_(grad_output.unsqueeze(dim=-1))
 
         # if you add an argument to the forward method, then you must add a corresponding None here
-        return grad_input, None, None, None
+        return grad_input, None, None, None, None, None
 
 
 def calculate_distributed_entropy(vocab_parallel_logits, mask=None):
@@ -259,16 +381,23 @@
     return calculate_entropy(full_log_probs, mask)
 
 
-def from_parallel_logits_to_logprobs(vocab_parallel_logits, target, inference_only=False, higher_stability=False):
+def from_parallel_logits_to_logprobs(
+    vocab_parallel_logits,
+    target,
+    inference_only=False,
+    higher_stability=False,
+    sampling_params=None,
+    prompt_lengths=None,
+):
     """get log probs out of a B x S x V//TP tensor
 
     NOTE: this function shifts the target, which means you must give it the unmodified targets
     Returns a B x S-1 tensor
     """
     target = target.roll(shifts=-1, dims=-1)
-    return DistributedLogprob.apply(vocab_parallel_logits, target, inference_only, higher_stability)[
-        :, :-1
-    ].contiguous()
+    return DistributedLogprob.apply(
+        vocab_parallel_logits, target, inference_only, higher_stability, sampling_params, prompt_lengths
+    )[:, :-1].contiguous()
 
 
 def pad_tensors_to_max_global_seq_len(list_of_tensors, pad_value, group, sequence_length_to_pad_to=None):
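
For orientation, the new code paths read the following keys from `self._sampling_params`, i.e. the `sampling_params` section of the YAML above plus the new `apply_to_logprobs` block. A hedged sketch of that structure as a plain Python dict; the values shown are illustrative, not the config defaults:

# Illustrative only: mirrors the keys accessed by _distributed_apply_sampling_params
# and by the apply_to_logprobs lookups in megatron_gpt_ppo_actor.py.
sampling_params = {
    "use_greedy": False,           # True short-circuits _distributed_apply_sampling_params
    "temperature": 0.7,            # divides the response logits when != 1.0
    "top_k": 50,                   # > 0 enables the distributed top-k / top-p filter
    "top_p": 0.9,                  # only honored together with top_k > 0
    "repetition_penalty": 1.0,     # any other value raises NotImplementedError
    "apply_to_logprobs": {
        "loss": False,             # read in fwd_output_and_loss_func
        "kl_penalty_actor": True,  # read in infer()
        "kl_penalty_ref": True,    # read in get_init_policy_logprobs()
    },
}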
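
As a sanity reference for `_compute_distributed_top_k_p`, here is a minimal single-GPU sketch of the same top-k / top-p (nucleus) filtering on an unsharded [B, S, V] logits tensor. It is not part of the patch, the function name is made up, and it deliberately ignores tensor parallelism, which is the whole point of the distributed version:

import torch
import torch.nn.functional as F


def top_k_top_p_filter(logits: torch.Tensor, k: int, p: float) -> torch.Tensor:
    """Keep the top-k logits per position, then apply top-p within those candidates; the rest become -inf."""
    # top-k candidates, returned sorted in descending order
    topk_values, topk_indices = torch.topk(logits, k=k, dim=-1)  # [B, S, k]

    keep_topk = torch.ones_like(topk_values, dtype=torch.bool)
    if 0.0 < p < 1.0:
        # nucleus within the top-k candidates, mirroring the patch:
        # drop tokens once the cumulative probability exceeds p, always keeping the best token
        cumulative_probs = torch.cumsum(F.softmax(topk_values, dim=-1), dim=-1)
        remove = cumulative_probs > p
        remove = remove.roll(shifts=1, dims=-1)
        remove[..., 0] = False
        keep_topk &= ~remove

    # scatter the surviving candidate indices back onto the full vocabulary
    keep = torch.zeros_like(logits, dtype=torch.bool).scatter(-1, topk_indices, keep_topk)
    return logits.masked_fill(~keep, float("-inf"))

The distributed version is meant to reach the same keep/drop decision (up to ties) without materializing the full vocabulary on one rank: each tensor-parallel rank contributes its local top-k, the model-parallel source rank reduces those candidates to a global top-k-p index set, and the surviving indices are broadcast back so every rank can mask its own vocabulary shard.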
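
To make the end-to-end effect concrete: when sampling params are applied, per-token log probs come from temperature-scaled and top-k/top-p-filtered logits, but only from the first generated position onward, which is why `prompt_lengths` is now threaded through everywhere. Below is a single-GPU sketch of that flow; `response_logprobs` is a made-up name, it reuses the hypothetical `top_k_top_p_filter` above, and `from_parallel_logits_to_logprobs` does the equivalent on vocab-sharded logits:

import torch
import torch.nn.functional as F


def response_logprobs(logits, tokens, prompt_lengths, temperature=1.0, top_k=0, top_p=1.0):
    """Log prob of each next token under sampling-adjusted logits; logits is [B, S, V], tokens is [B, S]."""
    # adjust only the response region, sliced at the shortest prompt, like _distributed_apply_sampling_params
    start = prompt_lengths.min().item() - 1
    resp = logits[:, start:]
    if temperature != 1.0:
        resp = resp / temperature
    if top_k > 0:
        resp = top_k_top_p_filter(resp, top_k, top_p)
    logits = torch.cat([logits[:, :start], resp], dim=1)

    # shift targets so position t scores token t + 1, as from_parallel_logits_to_logprobs does
    target = tokens.roll(shifts=-1, dims=-1)
    token_logprobs = F.log_softmax(logits, dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1)
    return token_logprobs[:, :-1]  # [B, S - 1]

Tokens that fall outside the filtered set end up with a log prob of -inf under this adjustment; the `apply_to_logprobs` flags let the loss, actor KL penalty, and reference KL penalty computations opt in independently.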