fix for issue 46 #186

Open · wants to merge 7 commits into base: main
7 changes: 7 additions & 0 deletions examples/nlp/gpt/conf/gpt_ppo_actor.yaml
@@ -125,6 +125,13 @@ model:
# will be used in NeMo version > 1.20.0
# keeping it for now
end_strings: ["<|endoftext|>", "<extra_id_1>"]
# whether the sampling params above are used when computing log probs
# (if False then log probs computations assume default temperature / top_k / top_p)
apply_to_logprobs:
loss: False
kl_penalty_actor: False
kl_penalty_ref: ${.kl_penalty_actor}


# length argument for autoregressive sampling
# max length means max amount of tokens to generate
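
For context on the `apply_to_logprobs` block added above: `${.kl_penalty_actor}` is a relative OmegaConf interpolation, so `kl_penalty_ref` follows `kl_penalty_actor` unless it is overridden explicitly. A minimal sketch of how that resolves, assuming the usual Hydra/OmegaConf config loading in NeMo (illustrative only, not part of this PR):

from omegaconf import OmegaConf

# Illustrative only: mirrors the keys added to gpt_ppo_actor.yaml in this PR.
cfg = OmegaConf.create(
    {
        "apply_to_logprobs": {
            "loss": False,
            "kl_penalty_actor": False,
            "kl_penalty_ref": "${.kl_penalty_actor}",  # relative interpolation -> sibling key
        }
    }
)
assert cfg.apply_to_logprobs.kl_penalty_ref is False  # follows kl_penalty_actor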
50 changes: 39 additions & 11 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py
@@ -114,13 +114,21 @@ def fwd_output_and_loss_func(data_iterator, model):

parallel_logits = model(batch["tokens"], batch["position_ids"], batch["attention_mask"], labels=None,)

sampling_params = self._sampling_params if self._sampling_params["apply_to_logprobs"]["loss"] else None

def loss_func(parallel_logits):
mask = batch["mask"]
advantages = batch["advantages"]
prev_log_probs = batch["prev_log_probs"]
tokens = batch["tokens"]

curr_log_probs = from_parallel_logits_to_logprobs(vocab_parallel_logits=parallel_logits, target=tokens)
prompt_lengths = torch.argmax(mask, dim=1) + 1
curr_log_probs = from_parallel_logits_to_logprobs(
vocab_parallel_logits=parallel_logits,
target=tokens,
sampling_params=sampling_params,
prompt_lengths=prompt_lengths,
)

scaled_entropy = torch.tensor(0.0, dtype=parallel_logits.dtype, device=parallel_logits.device)
if self.entropy_bonus > 0:
@@ -208,17 +216,22 @@ def finish_training(self):
"""

# inference calls
def get_logprob_output_only_func(self, inference_only=True):
def get_logprob_output_only_func(self, inference_only=True, apply_sampling_params=False):
fwd_output_only_func = self.get_forward_output_only_func()

def log_prob_output_only_func(dataloader_iter, model):
batch = next(dataloader_iter)

output_tensor, _ = fwd_output_only_func(iter([batch,]), model)
# The last element of `batch` contains prompt lengths which are not needed here.
output_tensor, _ = fwd_output_only_func(iter([batch[:-1],]), model)
sampling_params = self._sampling_params if apply_sampling_params else None

def id_func(output_tensor, non_loss_data=True):
logprobs = from_parallel_logits_to_logprobs(
vocab_parallel_logits=output_tensor, target=batch[0], inference_only=inference_only
vocab_parallel_logits=output_tensor,
target=batch[0],
inference_only=inference_only,
sampling_params=sampling_params,
prompt_lengths=batch[-1],
)
return logprobs

@@ -227,18 +240,24 @@ def id_func(output_tensor, non_loss_data=True):
return log_prob_output_only_func

@torch.no_grad()
def get_inference_log_probs(self, response_tokens, forward_micro_batch_size):
def get_inference_log_probs(
self, prompt_lengths, response_tokens, forward_micro_batch_size, apply_sampling_params=False
):
set_sync_funcs(self, forward_only=True)

mbs, seq_length = response_tokens.size()
num_microbatches = divide(mbs, forward_micro_batch_size)
attention_mask, _, position_ids = self.get_ltor_masks_and_position_ids(response_tokens)

batch_iter = get_iterator_k_split([response_tokens, attention_mask, position_ids], num_microbatches)
batch_iter = get_iterator_k_split(
[response_tokens, attention_mask, position_ids, prompt_lengths], num_microbatches
)

fwd_bwd_function = get_forward_backward_func()
logprobs_list = fwd_bwd_function(
forward_step_func=self.get_logprob_output_only_func(inference_only=True),
forward_step_func=self.get_logprob_output_only_func(
inference_only=True, apply_sampling_params=apply_sampling_params
),
data_iterator=batch_iter,
model=self.model,
num_microbatches=num_microbatches,
@@ -298,7 +317,10 @@ def infer(self, inference_batch):

# TODO(geshen): get nemo generate to return the unaltered log probs
log_probs = self.get_inference_log_probs(
response_tokens, forward_micro_batch_size=self.forward_micro_batch_size
prompt_lengths,
response_tokens,
forward_micro_batch_size=self.forward_micro_batch_size,
apply_sampling_params=self._sampling_params["apply_to_logprobs"]["kl_penalty_actor"],
)

rollout_batch = {
@@ -319,14 +341,20 @@ def get_init_policy_logprobs(self, rollout_batches):
# With adapters disabled (meaning using the init policy), calculate init_log_probs
for rollout_batch in rollout_batches:
init_log_prob = self.get_inference_log_probs(
rollout_batch["response_tokens"].cuda(), forward_micro_batch_size=self.forward_micro_batch_size
rollout_batch["prompt_lengths"].cuda(),
rollout_batch["response_tokens"].cuda(),
forward_micro_batch_size=self.forward_micro_batch_size,
apply_sampling_params=self._sampling_params["apply_to_logprobs"]["kl_penalty_ref"],
)
init_log_probs.append(init_log_prob)
else:
with cpu_weight_swap(self, self.init_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2):
for rollout_batch in rollout_batches:
init_log_prob = self.get_inference_log_probs(
rollout_batch["response_tokens"].cuda(), forward_micro_batch_size=self.forward_micro_batch_size
rollout_batch["prompt_lengths"].cuda(),
rollout_batch["response_tokens"].cuda(),
forward_micro_batch_size=self.forward_micro_batch_size,
apply_sampling_params=self._sampling_params["apply_to_logprobs"]["kl_penalty_ref"],
)
init_log_probs.append(init_log_prob)

141 changes: 135 additions & 6 deletions nemo_aligner/utils/distributed.py
@@ -22,6 +22,8 @@
from typing import Dict, Optional, Union

import torch
import torch.distributed
import torch.nn.functional as F
from megatron.core import parallel_state, tensor_parallel

from nemo.collections.nlp.modules.common.text_generation_utils import get_model_parallel_src_rank
@@ -191,12 +193,124 @@ def _compute_distributed_log_softmax(vocab_parallel_logits):
return vocab_parallel_logits - sum_exp_logits.log_().to(vocab_parallel_logits.dtype)


def _compute_distributed_top_k_p(logits, k, p, rank, world_size):
"""Expects a size B x S x V//TP tensor, computes a distributed top_k and top_p - setting all other logits to -Inf.
return shape B x S x V//TP where only global top-k-p values (across the V dimension) are not filtered (i.e. set to -Inf)
"""
src_rank = get_model_parallel_src_rank()
get_vocab_range = tensor_parallel.utils.VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = logits.size(-1)
vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)

local_topk_values, local_topk_indices = torch.topk(logits, k=k, dim=-1) # [B,S,k]
local_topk_indices += vocab_start_index
# Prepare containers for the gathered top-k values and indices from all GPUs
if rank == src_rank:
gathered_values = [torch.zeros_like(local_topk_values) for _ in range(world_size)]
gathered_indices = [torch.zeros_like(local_topk_indices) for _ in range(world_size)]
else:
gathered_values = None
gathered_indices = None

# Gather top-k values and indices from all GPUs
torch.distributed.gather(local_topk_values, gathered_values, dst=src_rank)
torch.distributed.gather(local_topk_indices, gathered_indices, dst=src_rank)

if rank == src_rank: # only rank 0 will do the computation and scatter the outcome
# Concatenate the gathered values and indices along a new dimension
all_values = torch.cat(gathered_values, dim=-1) # [B,S,world_size*k]
all_indices = torch.cat(gathered_indices, dim=-1)

# Perform a global top-k operation to find the global top-k values and indices
global_topk_values, topk_indices = torch.topk(all_values, k=k, dim=-1) # [B,S,k]
global_topk_indices = torch.gather(all_indices, -1, topk_indices)

# perform top_p
if 0.0 < p < 1.0:
# perform top_p and save in global_top_k_p_indices and spread to all ranks
sorted_logits, sorted_indices = torch.sort(global_topk_values, descending=True, dim=-1) # [B,S,k] for each
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # [B,S,k]
global_top_k_p_indices = torch.gather(global_topk_indices, -1, sorted_indices)
sorted_indices_to_remove = cumulative_probs > p
sorted_indices_to_remove = sorted_indices_to_remove.roll(shifts=1, dims=-1)
sorted_indices_to_remove[..., 0] = False
global_top_k_p_indices = torch.where(
sorted_indices_to_remove, torch.tensor(-1, dtype=torch.long), global_top_k_p_indices
) # the top_p are kept, the rest are set to -1 (so will be filtered in the following lines)
else:
global_top_k_p_indices = torch.empty_like(local_topk_indices) # [B,S,k]

torch.distributed.broadcast(global_top_k_p_indices, src=src_rank)

# generate a mask according to the rank
# filter indices within the current rank's segment
mask_top_k_p_indices = (global_top_k_p_indices >= vocab_start_index) & (
global_top_k_p_indices < vocab_end_index
) # [B,S,k] where only indices that are within the scope of current rank are True

# adjust indices to local index space
local_top_k_p_indices = global_top_k_p_indices
local_top_k_p_indices -= vocab_start_index
local_top_k_p_indices = torch.where(
mask_top_k_p_indices, local_top_k_p_indices, torch.tensor(-1, dtype=torch.long)
) # [B,S,k] - the global top_k_p indices are localized to the rank, or -1 if they are not in this rank's segment

valid_logits = torch.zeros_like(logits, dtype=torch.bool)
batch_indices, sequence_indices = torch.where(
mask_top_k_p_indices.any(dim=-1)
) # collect all b,s indices where there is a valid index in [b,s,:]
local_vocab_indices = local_top_k_p_indices[
batch_indices, sequence_indices
] # collect the v indices per each [b,s]. should be up to k valid indices (not valid is -1)
valid_local_indx_mask = local_vocab_indices != -1
valid_local_batch_idx = batch_indices.unsqueeze(1).expand_as(valid_local_indx_mask)[valid_local_indx_mask]
valid_local_sequence_idx = sequence_indices.unsqueeze(1).expand_as(valid_local_indx_mask)[valid_local_indx_mask]
valid_local_vocab_idx = local_vocab_indices[valid_local_indx_mask]
valid_logits[valid_local_batch_idx, valid_local_sequence_idx, valid_local_vocab_idx] = True
logits[~valid_logits] = -torch.inf
# return updated_logits
return logits
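
To make the intent of the routine above easier to follow, here is a minimal single-device sketch of the same top-k / top-p (nucleus) filtering over a full, non-sharded vocab dimension. It is illustrative only (the helper name `top_k_top_p_filter` is made up here), not code from this PR:

import torch
import torch.nn.functional as F

def top_k_top_p_filter(logits: torch.Tensor, k: int, p: float) -> torch.Tensor:
    """Keep only the top-k logits, then drop the tail of the nucleus (cumulative prob > p).
    Everything filtered out is set to -inf, matching the distributed routine's semantics."""
    topk_values, _ = torch.topk(logits, k=k, dim=-1)
    logits = logits.masked_fill(logits < topk_values[..., -1:], float("-inf"))
    if 0.0 < p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        to_remove = cumulative_probs > p
        to_remove = to_remove.roll(shifts=1, dims=-1)
        to_remove[..., 0] = False  # always keep the most likely token
        # map the removal mask from sorted order back to the original vocab order
        remove_mask = torch.zeros_like(to_remove).scatter(-1, sorted_indices, to_remove)
        logits = logits.masked_fill(remove_mask, float("-inf"))
    return logits

# toy usage: batch of 2, sequence length 3, vocab of 16
filtered = top_k_top_p_filter(torch.randn(2, 3, 16), k=4, p=0.9)

The distributed version has to produce the same result when the vocab dimension is sharded across tensor-parallel ranks, which is why it gathers each rank's local top-k onto the model-parallel source rank, reduces to a global top-k / top-p there, and broadcasts the surviving indices back so every rank can mask its own shard.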


def _distributed_apply_sampling_params(logits, context_lengths, sampling_params, rank, world_size):
# apply the sampling params to the logits - focusing only on the generated tokens.
if sampling_params.get("use_greedy", False):
return logits
if sampling_params.get("repetition_penalty", 1.0) != 1.0:
raise NotImplementedError("not supporting repetition penalty when applying sampling params to logprobs")

context_length = context_lengths.min().item()
resp_logits = logits[:, context_length - 1 :]
# divide by temp
if sampling_params["temperature"] != 1.0:
resp_logits /= sampling_params["temperature"]
top_k = sampling_params["top_k"]
top_p = sampling_params["top_p"]
if top_k > 0:
# Note : currently assuming that top_p is applied only if top_k>0.
resp_logits = _compute_distributed_top_k_p(resp_logits, top_k, top_p, rank, world_size)
elif 0.0 < top_p < 1.0:
raise NotImplementedError(
"Currently not supporting 0 < top_p < 1 with top_k=0 when applying sampling params to log probs"
)

return logits
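
For intuition, a rough single-device analogue of `_distributed_apply_sampling_params`, reusing the hypothetical `top_k_top_p_filter` sketch above. It is illustrative only, not the PR's code; it just shows which positions and parameters get touched: logits from position `context_length - 1` onward are the ones whose next-token distributions produced the generated tokens (counting from the shortest prompt in the batch), temperature is applied first, then top-k / top-p:

def apply_sampling_params_reference(logits, context_lengths, sampling_params):
    # Single-device sketch, assuming the same sampling_params dict keys as above.
    if sampling_params.get("use_greedy", False):
        return logits
    context_length = context_lengths.min().item()
    resp_logits = logits[:, context_length - 1 :]  # view over the generated-token positions
    if sampling_params["temperature"] != 1.0:
        resp_logits /= sampling_params["temperature"]  # in-place on the view, so `logits` is updated too
    if sampling_params["top_k"] > 0:
        logits[:, context_length - 1 :] = top_k_top_p_filter(
            resp_logits, sampling_params["top_k"], sampling_params["top_p"]
        )
    return logits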


class DistributedLogprob(torch.autograd.Function):
"""Function to get logprobs out and differentiate through it
"""

@staticmethod
def forward(ctx, vocab_parallel_logits, target, inference_only=False, higher_stability=False):
def forward(
ctx,
vocab_parallel_logits,
target,
inference_only=False,
higher_stability=False,
sampling_params=None,
prompt_lengths=None,
):
get_vocab_range = tensor_parallel.utils.VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size(-1)
rank = parallel_state.get_tensor_model_parallel_rank()
@@ -208,6 +322,14 @@ def forward(ctx, vocab_parallel_logits, target, inference_only=False, higher_sta
masked_target = target - vocab_start_index
masked_target[target_mask] = 0

# if sampling_params should be applied, apply them to the vocab_parallel_logits
if sampling_params is not None:
if prompt_lengths is None:
raise ValueError("prompt_lengths must be provided to apply sampling params to log ptobs")
vocab_parallel_logits = _distributed_apply_sampling_params(
vocab_parallel_logits, prompt_lengths, sampling_params, rank, world_size
)

# higher stability uses a more numerically stable distributed log_softmax instead of softmax
# however, it uses more VRAM because there is an unavoidable exp() OP on the entire logits tensor
# some models (like DPO) will get -inf in the resulting logprobs unless you set higher_stability=True
@@ -248,7 +370,7 @@ def backward(ctx, grad_output):
grad_input.mul_(grad_output.unsqueeze(dim=-1))

# if you add an argument to the forward method, then you must add a corresponding None here
return grad_input, None, None, None
return grad_input, None, None, None, None, None


def calculate_distributed_entropy(vocab_parallel_logits, mask=None):
@@ -259,16 +381,23 @@ def calculate_distributed_entropy(vocab_parallel_logits, mask=None):
return calculate_entropy(full_log_probs, mask)


def from_parallel_logits_to_logprobs(vocab_parallel_logits, target, inference_only=False, higher_stability=False):
def from_parallel_logits_to_logprobs(
vocab_parallel_logits,
target,
inference_only=False,
higher_stability=False,
sampling_params=None,
prompt_lengths=None,
):
"""get log probs out of a B x S x V//TP tensor
NOTE: this function shifts the target, which means you must give it the unmodified targets

Returns a B x S-1 tensor
"""
target = target.roll(shifts=-1, dims=-1)
return DistributedLogprob.apply(vocab_parallel_logits, target, inference_only, higher_stability)[
:, :-1
].contiguous()
return DistributedLogprob.apply(
vocab_parallel_logits, target, inference_only, higher_stability, sampling_params, prompt_lengths
)[:, :-1].contiguous()
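
The shift-and-trim convention in the docstring is easier to see on a single device. A toy, non-tensor-parallel sketch of the same indexing (illustrative only, not the PR's code):

import torch

# B=2, S=5, V=8 on a single device (full vocab, no tensor parallelism)
logits = torch.randn(2, 5, 8)
target = torch.randint(0, 8, (2, 5))       # unmodified targets, as the NOTE above requires

log_probs = torch.log_softmax(logits, dim=-1)
shifted = target.roll(shifts=-1, dims=-1)  # align token t+1 with the logits at position t
per_token = torch.gather(log_probs, dim=-1, index=shifted.unsqueeze(-1)).squeeze(-1)
per_token = per_token[:, :-1]              # drop the wrapped-around last position -> B x (S-1)
assert per_token.shape == (2, 4)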


def pad_tensors_to_max_global_seq_len(list_of_tensors, pad_value, group, sequence_length_to_pad_to=None):