[Enhance]: Fix sequence parallel memory bottleneck in DPO & ORPO (#830)
* [WIP]: Fix sequence parallel memory bottleneck in DPO

* loss mask before split

* refactor orpo
RangiLyu authored Jul 19, 2024
1 parent b92481f · commit ff226e1
Showing 2 changed files with 128 additions and 122 deletions.
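
Both files below apply the same idea: each sequence-parallel (SP) rank reduces its local `(bs, seqlen, vocab_size)` logits to per-token log-probabilities with `_gather_masked_logits` *before* calling `gather_forward_split_backward`, so only a `(bs, seqlen)` tensor is all-gathered across the SP group instead of the full-vocab logits. A minimal standalone sketch of that reduction; the `gather_masked_logps` helper is a hypothetical re-implementation of what `_gather_masked_logits` is assumed to do (pick the target token's log-probability and zero masked positions), not xtuner's exact code:

```python
import torch


def gather_masked_logps(logits, labels, loss_mask):
    """Hypothetical sketch: reduce (bs, seqlen, vocab) logits to (bs, seqlen) per-token log-probs."""
    logps = torch.log_softmax(logits, dim=-1)
    # Pick the log-prob of each target token, then zero out positions we do not train on.
    per_token = torch.gather(logps, dim=2, index=labels.unsqueeze(2)).squeeze(2)
    return per_token * loss_mask


bs, local_seqlen, vocab_size = 2, 8, 32000
logits = torch.randn(bs, local_seqlen, vocab_size)        # this rank's slice of the sequence
labels = torch.randint(0, vocab_size, (bs, local_seqlen))
loss_mask = labels != 0

local_logps = gather_masked_logps(logits, labels, loss_mask)
print(local_logps.shape)  # torch.Size([2, 8])

# Before this commit, the (bs, seqlen, vocab_size) logits were all-gathered across SP ranks;
# after it, only the (bs, seqlen) per-token log-probs are, shrinking the gathered tensor
# by roughly a factor of vocab_size.
```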
115 changes: 59 additions & 56 deletions xtuner/model/dpo.py
@@ -62,77 +62,66 @@ def _gather_masked_logits(self, logits, labels, mask):

def get_logps(
self,
all_logits, # bs, seqlen,vocab_size
all_ref_logits, # bs, seqlen,vocab_size
labels, # bs, seqlen
        policy_logps,  # bs, seqlen
        ref_logps,  # bs, seqlen
loss_mask, # bs, seqlen
):
labels = labels[:, 1:].clone()
all_logits = all_logits[:, :-1, :]
all_ref_logits = all_ref_logits[:, :-1, :]

labels[labels == -100] = 0
loss_mask = labels != 0
all_logps = self._gather_masked_logits(all_logits, labels,
loss_mask).sum(-1)
all_ref_logps = self._gather_masked_logits(all_ref_logits, labels,
loss_mask).sum(-1)
policy_logps = policy_logps[:, :-1].sum(-1)
ref_logps = ref_logps[:, :-1].sum(-1)
loss_mask = loss_mask[:, :-1]

if self.loss_type == 'ipo': # average_log_prob
all_logps = all_logps / loss_mask.sum(-1)
all_ref_logps = all_ref_logps / loss_mask.sum(-1)
policy_logps = policy_logps / loss_mask.sum(-1)
ref_logps = ref_logps / loss_mask.sum(-1)

policy_chosen_logps = all_logps[::2]
policy_rejected_logps = all_logps[1::2]
reference_chosen_logps = all_ref_logps[::2]
reference_rejected_logps = all_ref_logps[1::2]
policy_chosen_logps = policy_logps[::2]
policy_rejected_logps = policy_logps[1::2]
reference_chosen_logps = ref_logps[::2]
reference_rejected_logps = ref_logps[1::2]
return (policy_chosen_logps, policy_rejected_logps,
reference_chosen_logps, reference_rejected_logps)

def get_var_len_atten_logps(self, all_logits, all_ref_logits, labels,
def get_var_len_atten_logps(self, policy_logps, ref_logps, loss_mask,
cu_seqlens, attention_mask):
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
# unpack sequence
unpacked_logits = torch.split(all_logits, seqlens, dim=1)
unpacked_ref_logits = torch.split(all_ref_logits, seqlens, dim=1)
unpacked_labels = torch.split(labels, seqlens, dim=1)
unpacked_policy_logps = torch.split(policy_logps, seqlens, dim=1)
unpacked_ref_logps = torch.split(ref_logps, seqlens, dim=1)
unpacked_loss_mask = torch.split(loss_mask, seqlens, dim=1)
if attention_mask is not None:
# It indicates that we pad the original sequence, labels,
# position_ids and cumulative_len for sequence parallel if the
# attention_mask is not None.
# We then need to remove the padded segments.
assert False in attention_mask
unpacked_logits = unpacked_logits[:-1]
unpacked_ref_logits = unpacked_ref_logits[:-1]
unpacked_labels = unpacked_labels[:-1]
assert len(unpacked_logits) % 2 == 0
unpacked_policy_logps = unpacked_policy_logps[:-1]
unpacked_ref_logps = unpacked_ref_logps[:-1]
unpacked_loss_mask = unpacked_loss_mask[:-1]
assert len(unpacked_policy_logps) % 2 == 0

def compute_logps(_logits, _labels):
_labels = _labels[:, 1:].clone()
_logits = _logits[:, :-1, :]
_labels[_labels == -100] = 0
loss_mask = _labels != 0
logps = self._gather_masked_logits(_logits, _labels, loss_mask)
logps = logps.sum(-1)
def compute_logps(_logps, _mask):
_logps = _logps[:, :-1].sum(-1)
_mask = _mask[:, :-1]
if self.loss_type == 'ipo':
logps /= loss_mask.sum(-1)
return logps
_logps /= _mask.sum(-1)
return _logps

(policy_chosen_logps, policy_rejected_logps, reference_chosen_logps,
reference_rejected_logps) = [], [], [], []
for i in range(len(unpacked_logits) // 2):
chosen = unpacked_logits[2 * i]
rejected = unpacked_logits[2 * i + 1]
chosen_ref = unpacked_ref_logits[2 * i]
rejected_ref = unpacked_ref_logits[2 * i + 1]
chosen_label = unpacked_labels[2 * i]
rejected_label = unpacked_labels[2 * i + 1]
policy_chosen_logps.append(compute_logps(chosen, chosen_label))
for i in range(len(unpacked_policy_logps) // 2):
chosen = unpacked_policy_logps[2 * i]
rejected = unpacked_policy_logps[2 * i + 1]
chosen_ref = unpacked_ref_logps[2 * i]
rejected_ref = unpacked_ref_logps[2 * i + 1]
chosen_mask = unpacked_loss_mask[2 * i]
rejected_mask = unpacked_loss_mask[2 * i + 1]
policy_chosen_logps.append(compute_logps(chosen, chosen_mask))
policy_rejected_logps.append(
compute_logps(rejected, rejected_label))
compute_logps(rejected, rejected_mask))
reference_chosen_logps.append(
compute_logps(chosen_ref, chosen_label))
compute_logps(chosen_ref, chosen_mask))
reference_rejected_logps.append(
compute_logps(rejected_ref, rejected_label))
compute_logps(rejected_ref, rejected_mask))

return (torch.stack(policy_chosen_logps),
torch.stack(policy_rejected_logps),
@@ -142,7 +131,7 @@ def compute_logps(_logits, _labels):
@staticmethod
def _split_for_sequence_parallel(data):
# attention mask should not be split
ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids')
ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids', 'labels')
sp_group = get_sequence_parallel_group()
for key in ARGS_NEED_TO_SPLIT:
val = data.get(key, None)
@@ -154,8 +143,14 @@ def _split_for_sequence_parallel(data):

def compute_loss(self, data, data_samples=None):
# modified from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py # noqa

labels = data.pop('labels')
# shift labels first and add a dummy label at the end, to support sequence parallel # noqa
data['labels'] = torch.cat(
(data['labels'][:, 1:], torch.zeros_like(data['labels'][:, :1])),
dim=1)
tmp_label = data['labels'].clone()
tmp_label[tmp_label == 0] = -100
all_loss_mask = data[
'labels'] != -100 # loss mask of all tokens in all sp ranks # noqa

if get_sequence_parallel_world_size() > 1:
data = self._split_for_sequence_parallel(data)
@@ -168,14 +163,22 @@
else:
all_ref_logits = self.ref_llm(**data).logits

labels = data['labels']
labels[labels == -100] = 0
loss_mask = labels != 0 # loss mask in a single sp rank
policy_logps = self._gather_masked_logits(all_logits, labels,
loss_mask)
ref_logps = self._gather_masked_logits(all_ref_logits, labels,
loss_mask)

if get_sequence_parallel_world_size() > 1:
all_logits = gather_forward_split_backward(
all_logits,
policy_logps = gather_forward_split_backward(
policy_logps,
dim=1,
sp_group=get_sequence_parallel_group(),
grad_scale='up')
all_ref_logits = gather_forward_split_backward(
all_ref_logits,
ref_logps = gather_forward_split_backward(
ref_logps,
dim=1,
sp_group=get_sequence_parallel_group(),
grad_scale='up')
@@ -184,15 +187,15 @@
(policy_chosen_logps, policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps) = self.get_logps(
all_logits, all_ref_logits, labels)
policy_logps, ref_logps, all_loss_mask)
else:
message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
cu_seqlens = message_hub.get_info(f'cumulative_len_rank_{rank}')
(policy_chosen_logps, policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps) = self.get_var_len_atten_logps(
all_logits, all_ref_logits, labels, cu_seqlens,
policy_logps, ref_logps, all_loss_mask, cu_seqlens,
data['attention_mask'])

pi_logratios = policy_chosen_logps - policy_rejected_logps
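
A second detail both `compute_loss` hunks rely on: labels are shifted left by one and padded with a dummy zero label *before* the sequence-parallel split, so that after `_split_for_sequence_parallel` (which now also splits `'labels'`) each rank's label slice lines up token-for-token with its logit slice, and `get_logps` only has to drop the trailing dummy position instead of re-shifting labels against logits. A small sketch of why the pre-shift matters, assuming a sequence-parallel world size of 2 and an even split along the sequence dimension:

```python
import torch

labels = torch.tensor([[-100, 5, 6, 7, 8, 9, 10, 11]])  # (bs=1, seqlen=8)

# Pre-shift once, globally, and append a dummy 0 so the length stays seqlen
# (mirrors the `torch.cat((labels[:, 1:], zeros), dim=1)` lines in the diff).
shifted = torch.cat((labels[:, 1:], torch.zeros_like(labels[:, :1])), dim=1)

# Split along the sequence dimension as sequence parallel would (sp=2 assumed here).
rank0_labels, rank1_labels = torch.chunk(shifted, 2, dim=1)
print(rank0_labels)  # tensor([[5, 6, 7, 8]])        - targets for rank 0's logits
print(rank1_labels)  # tensor([[ 9, 10, 11,  0]])    - the dummy 0 lands only at the very end

# Had the *unshifted* labels been split instead, the target for rank 0's last position
# would live on rank 1, so shifting per rank after the split would be wrong.
```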
135 changes: 69 additions & 66 deletions xtuner/model/orpo.py
@@ -34,17 +34,12 @@ def _gather_masked_logits(self, logits, labels, mask):

def get_logps(
self,
all_logits, # bs, seqlen,vocab_size
average_log_prob, # bs, seqlen,vocab_size
labels, # bs, seqlen
all_logps, # bs, seqlen
average_log_prob,
loss_mask, # bs, seqlen
):
labels = labels[:, 1:].clone()
all_logits = all_logits[:, :-1, :]

labels[labels == -100] = 0
loss_mask = labels != 0
all_logps = self._gather_masked_logits(all_logits, labels,
loss_mask).sum(-1)
all_logps = all_logps[:, :-1].sum(-1)
loss_mask = loss_mask[:, :-1]

if average_log_prob: # average_log_prob
all_logps = all_logps / loss_mask.sum(-1)
@@ -53,47 +48,44 @@ def get_logps(
rejected_logps = all_logps[1::2]
return chosen_logps, rejected_logps

def get_var_len_atten_logps(self, all_logits, average_log_prob, labels,
def get_var_len_atten_logps(self, all_logps, average_log_prob, loss_mask,
cu_seqlens, attention_mask):
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
# unpack sequence
unpacked_logits = torch.split(all_logits, seqlens, dim=1)
unpacked_labels = torch.split(labels, seqlens, dim=1)
unpacked_logps = torch.split(all_logps, seqlens, dim=1)
unpacked_loss_mask = torch.split(loss_mask, seqlens, dim=1)
if attention_mask is not None:
# It indicates that we pad the original sequence, labels,
# position_ids and cumulative_len for sequence parallel if the
# attention_mask is not None.
# We then need to remove the padded segments.
assert False in attention_mask
unpacked_logits = unpacked_logits[:-1]
unpacked_labels = unpacked_labels[:-1]
assert len(unpacked_logits) % 2 == 0

def compute_logps(_logits, _labels):
_labels = _labels[:, 1:].clone()
_logits = _logits[:, :-1, :]
_labels[_labels == -100] = 0
loss_mask = _labels != 0
logps = self._gather_masked_logits(_logits, _labels, loss_mask)
logps = logps.sum(-1)
unpacked_logps = unpacked_logps[:-1]
unpacked_loss_mask = unpacked_loss_mask[:-1]
assert len(unpacked_logps) % 2 == 0

def compute_logps(_logps, _mask):
_logps = _logps[:, :-1].sum(-1)
_mask = _mask[:, :-1]
if average_log_prob:
logps /= loss_mask.sum(-1)
return logps
_logps /= _mask.sum(-1)
return _logps

chosen_logps, rejected_logps = [], []
for i in range(len(unpacked_logits) // 2):
chosen = unpacked_logits[2 * i]
rejected = unpacked_logits[2 * i + 1]
chosen_label = unpacked_labels[2 * i]
rejected_label = unpacked_labels[2 * i + 1]
chosen_logps.append(compute_logps(chosen, chosen_label))
rejected_logps.append(compute_logps(rejected, rejected_label))
for i in range(len(unpacked_logps) // 2):
chosen = unpacked_logps[2 * i]
rejected = unpacked_logps[2 * i + 1]
chosen_mask = unpacked_loss_mask[2 * i]
rejected_mask = unpacked_loss_mask[2 * i + 1]
chosen_logps.append(compute_logps(chosen, chosen_mask))
rejected_logps.append(compute_logps(rejected, rejected_mask))

return (torch.stack(chosen_logps), torch.stack(rejected_logps))

def cross_entropy_loss(self, logits, labels):
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# labels are already shifted; now we need to remove the last dummy label # noqa
labels = labels[..., :-1].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
@@ -126,7 +118,8 @@ def odds_ratio_loss(
@staticmethod
def _split_for_sequence_parallel(data):
# attention mask should not be split
ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids')
ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids', 'labels',
'chosen_rejected_tag')
sp_group = get_sequence_parallel_group()
for key in ARGS_NEED_TO_SPLIT:
val = data.get(key, None)
@@ -137,53 +130,63 @@ def _split_for_sequence_parallel(data):
return data

def compute_loss(self, data, data_samples=None):
labels_ori = data.pop('labels')
# shift labels first and add a dummy label at the end, to support sequence parallel # noqa
data['labels'] = torch.cat(
(data['labels'][:, 1:], torch.zeros_like(data['labels'][:, :1])),
dim=1)
tmp_label = data['labels'].clone()
tmp_label[tmp_label == 0] = -100
# loss mask of all tokens in all sp ranks
all_loss_mask = data['labels'] != -100

if self.use_varlen_attn:
# create a chosen rejected tag for varlen_attn ce loss
message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
cu_seqlens = message_hub.get_info(f'cumulative_len_rank_{rank}')
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

chosen_rejected_tag = torch.ones_like(data['labels'])
unpacked_tag = list(
torch.split(chosen_rejected_tag, seqlens, dim=1))
for i in range(len(unpacked_tag) // 2):
unpacked_tag[2 * i + 1] *= 0
chosen_rejected_tag = torch.cat(unpacked_tag, dim=1)
data['chosen_rejected_tag'] = chosen_rejected_tag

if get_sequence_parallel_world_size() > 1:
data = self._split_for_sequence_parallel(data)

chosen_rejected_tag = data.pop('chosen_rejected_tag', None)
all_logits = self.llm(**data).logits

labels = data['labels'].clone()
labels[labels == -100] = 0
loss_mask = labels != 0 # loss mask in a single sp rank
all_logps = self._gather_masked_logits(all_logits, labels, loss_mask)
if get_sequence_parallel_world_size() > 1:
all_logits = gather_forward_split_backward(
all_logits,
all_logps = gather_forward_split_backward(
all_logps,
dim=1,
sp_group=get_sequence_parallel_group(),
grad_scale='up')

if not self.use_varlen_attn:
chosen_nll_loss = self.cross_entropy_loss(all_logits[::2],
labels_ori.clone()[::2])
data['labels'][::2])
chosen_logps, rejected_logps = self.get_logps(
all_logits, True, labels_ori)
all_logps, True, all_loss_mask)
else:
message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
cu_seqlens = message_hub.get_info(f'cumulative_len_rank_{rank}')
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

attention_mask = data['attention_mask']
if attention_mask is not None:
# It indicates that we pad the original sequence, labels,
# position_ids and cumulative_len for sequence parallel if the
# attention_mask is not None.
# We then need to remove the padded segments.
logits = torch.split(all_logits, seqlens, dim=1)[:-1]
assert len(logits) % 2 == 0
chosen_logits = logits[::2]
labels = torch.split(labels_ori.clone(), seqlens, dim=1)[:-1]
assert len(labels) % 2 == 0
chosen_labels = labels[::2]
else:
chosen_logits = torch.split(all_logits, seqlens, dim=1)[::2]
chosen_labels = torch.split(
labels_ori.clone(), seqlens, dim=1)[::2]

chosen_logits = torch.cat(chosen_logits, dim=1)
chosen_labels = torch.cat(chosen_labels, dim=1)
chosen_idxs = chosen_rejected_tag == 1
chosen_logits = all_logits[chosen_idxs]
chosen_labels = data['labels'][chosen_idxs]
chosen_nll_loss = self.cross_entropy_loss(chosen_logits,
chosen_labels)

chosen_logps, rejected_logps = self.get_var_len_atten_logps(
all_logits, True, labels_ori, cu_seqlens, attention_mask)
all_logps, True, all_loss_mask, cu_seqlens,
data['attention_mask'])
(losses, chosen_rewards, rejected_rewards, log_odds_ratio,
log_odds_chosen) = self.odds_ratio_loss(chosen_logps, rejected_logps)
losses = losses.mean()
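
Finally, the ORPO varlen-attention path no longer splits and re-concatenates full-vocab logits to isolate the chosen sequences for the NLL loss; it builds a per-token `chosen_rejected_tag` (1 for chosen sequences, 0 for rejected ones) that is split across SP ranks together with the inputs, then selects chosen tokens by boolean indexing. A small sketch of how that tag is built, using made-up cumulative lengths for a packed batch:

```python
import torch

# Packed batch: chosen and rejected sequences alternate; cumulative lengths mark boundaries.
cu_seqlens = torch.tensor([0, 3, 5, 9, 12])              # chosen(3), rejected(2), chosen(4), rejected(3)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()    # [3, 2, 4, 3]
labels = torch.arange(12).unsqueeze(0)                   # (bs=1, packed_len=12), dummy values

# Tag every token: 1 for chosen sequences (even slots), 0 for rejected ones (odd slots),
# mirroring the loop over `unpacked_tag` in the diff above.
tag = torch.ones_like(labels)
chunks = list(torch.split(tag, seqlens, dim=1))
for i in range(len(chunks) // 2):
    chunks[2 * i + 1] *= 0
chosen_rejected_tag = torch.cat(chunks, dim=1)
print(chosen_rejected_tag)
# tensor([[1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0]])

# The NLL term then only needs a boolean index instead of splitting/concatenating logits:
chosen_labels = labels[chosen_rejected_tag == 1]         # all chosen-sequence tokens
print(chosen_labels.shape)                               # torch.Size([7])
```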
