Commit 6d6973d
removed unused codes.
Signed-off-by: Xuesong Yang <[email protected]>
XuesongYang committed Nov 9, 2024
1 parent 8fd7cd4 commit 6d6973d
Showing 2 changed files with 3 additions and 240 deletions.
@@ -1023,8 +1023,6 @@ def forward(
# For flat seq_pattern we need all the logits
token_logits = token_logits[:, :, :first_layer_vocabsize]
speech_layers = self.num_speech_codebooks - 1
last_layer_output = dec_output
last_layer_logits = token_logits

# speech_logits_list will be used in loss calculation (parallel output)
speech_logits_list = []
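The hunk above keeps only the first-layer vocabulary in token_logits and collects per-codebook speech logits for the parallel-output loss. A minimal sketch of that pattern follows; the module, head names, and vocabulary sizes are illustrative assumptions, not the model's actual implementation.

import torch
import torch.nn as nn

class ParallelSpeechLogits(nn.Module):
    """Illustrative only: text logits for the first codebook plus one head per extra codebook."""

    def __init__(self, hidden_size, first_layer_vocabsize, speech_vocab_size, num_speech_codebooks):
        super().__init__()
        self.token_head = nn.Linear(hidden_size, first_layer_vocabsize)
        self.speech_heads = nn.ModuleList(
            nn.Linear(hidden_size, speech_vocab_size) for _ in range(num_speech_codebooks - 1)
        )

    def forward(self, dec_output):
        # dec_output: [batch, time, hidden_size]
        token_logits = self.token_head(dec_output)  # restricted to the first-layer vocab
        speech_logits_list = [head(dec_output) for head in self.speech_heads]  # one per remaining codebook
        return token_logits, speech_logits_list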
241 changes: 3 additions & 238 deletions nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py
@@ -474,9 +474,7 @@ def _build_sample(self, tup):
taskname = "squad"
prompt_template = self.task_templates[taskname]["prompt_template"]
prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"]
total_virtual_tokens = self.task_templates[taskname]["total_virtual_tokens"]
virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"]
truncation_field = self.task_templates[taskname]['truncate_field']
answer_field = self.task_templates[taskname]["answer_field"]

input_example = prompt_template
@@ -563,9 +561,7 @@ def _build_sample(self, tup):
answer_text_ids = [self.tokenizer.pad_id]
else:
answer_text_ids = [self.tokenizer.bos_id]
# a trick to align with the data format in t5 pretraining
# if self.add_sentinel_to_input:
# answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.FIRST.value)

answer_text_ids += answer_ids

if self.add_eos_to_decoder_output:
@@ -662,6 +658,8 @@ def _build_sample(self, tup):
is_speech,
cross_attention_prior,
)
else:
return None

def _truncate_input_speech(self, context_tokens, question_tokens, virtual_tokens):
total_len = self._get_len(context_tokens, question_tokens, virtual_tokens)
@@ -990,236 +988,3 @@ def pad_batch_and_build_loss_mask(self, batch):
}

return data_dict


class GPTSpeechLMTarredDataset(T5SpeechLMTarredDataset):
"""No support for cross attention here yet"""

def _build_sample(self, tup):
audio_filename, self.encodec, self.ref_encodec, offset_id = tup

file_id, _ = os.path.splitext(os.path.basename(audio_filename))
manifest_idx = self.manifest_processor.collection.mapping[file_id][offset_id]
manifest_entry = self.manifest_processor.collection[manifest_idx]
doc = {}
doc['context'] = manifest_entry.context
doc['context_type'] = manifest_entry.context_type
doc['context_duration'] = manifest_entry.context_duration
doc['answer'] = manifest_entry.answer
doc['answer_type'] = manifest_entry.answer_type
doc['answer_duration'] = manifest_entry.answer_duration
doc['question'] = manifest_entry.question
doc['question_type'] = manifest_entry.question_type

taskname = "squad"
prompt_template = self.task_templates[taskname]["prompt_template"]
prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"]
virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"]
answer_field = self.task_templates[taskname]["answer_field"]

input_example = prompt_template

# Format the input example according to the template
# Get context, question and answer codes in a dict.
input_dict = self._insert_data_in_template(input_example, prompt_template_fields, doc, answer_field)
context_tokens = input_dict['context']
question_tokens = input_dict['question']

# Logic to prune context
# In case of TTS task, the entire reference speech is not required, so we randomly select a portion
# of the reference audio.
# In case of Next token prediction, We want context[:T] to go in the encoder and context[T+1:] to be
# predicted by the decoder.
start_token_index = 0
end_token_index = -1

total_context_len = context_tokens[0].size()[1]
context_3s = 3 * 75
if total_context_len > context_3s:
start_token_index = random.randint(0, total_context_len - context_3s)
# logging.debug(f"start_token_index: {start_token_index}")
end_token_index = start_token_index + min(context_3s, total_context_len)
# logging.debug(f"end_token_index: {end_token_index}")
context_tokens[0] = context_tokens[0][:, start_token_index:end_token_index]

# Get virtual tokens
virtual_tokens = self._insert_virtual_token_placeholders(input_example.split(' ')[0], virtual_token_splits)

# a trick to align with the data format in t5 pretraining
# new
virtual_tokens = self.tokenizer.text_to_ids(virtual_tokens)
if self.add_sentinel_to_input:
question_tokens = question_tokens + self.tokenizer.text_to_ids(T5Sentinel.FIRST.value)

# Add BOS/EOS to the input of encoder if desired, adds EOS by default
if self.ul2_prompt_token is not None:
ul2_prompt_token_id = self.tokenizer.text_to_ids(self.ul2_prompt_token)
assert len(ul2_prompt_token_id) == 1
context_tokens = ul2_prompt_token_id + context_tokens
if self.add_bos:
context_tokens = [self.tokenizer.bos_id] + context_tokens
if self.add_eos:
question_tokens = [self.tokenizer.pad_id] + question_tokens + [self.tokenizer.pad_id]

virtual_tokens, virtual_tokens_len = self.list_to_tensor(virtual_tokens)
context_tokens, context_tokens_len = self.list_to_tensor(context_tokens)
question_tokens, question_tokens_len = self.list_to_tensor(question_tokens)

if doc["question_type"] != "SPEECH" and doc["context_type"] == "SPEECH":
question_tokens = pad_text_to_speech_dims(question_tokens, self.tokenizer.pad_id)
if doc["context_type"] != "SPEECH" and doc["question_type"] == "SPEECH":
context_tokens = pad_text_to_speech_dims(context_tokens, self.tokenizer.pad_id)
context_and_question_tokens = torch.cat([context_tokens, question_tokens], dim=1)

# get answer ids
if answer_field in doc.keys(): # training and validation
answer_ids = self._get_tokens(doc, answer_field, doc[answer_field])
answer_text_ids = answer_ids

if self.add_eos_to_decoder_output:
answer_text_ids += [self.tokenizer.eos_id]
else:
answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.END.value)

# Skip example if the final length doesn't fit length requirements even after truncation
input_ids = answer_text_ids
input_ids, input_ids_len = self.list_to_tensor(input_ids, True)
input_len = self._get_element_len(context_and_question_tokens) + self._get_element_len(answer_text_ids) - 1
if input_len > self.max_seq_length:
# logging.debug(f"Overflow. input_len:{input_len}. self.max_seq_length:{self.max_seq_length}. overflow_len:{self.max_seq_length - input_len}.")
overflow_len = self.max_seq_length - input_len
# truncate context if context after truncation is at least 1s
# else truncate answer as final option
if context_tokens_len - overflow_len > 75:
# logging.debug(f"Cutting context. context_tokens:{context_tokens.shape}. context_tokens_len:{context_tokens_len}.")
context_tokens = context_tokens[:, : context_tokens_len - overflow_len]
context_tokens_len = context_tokens_len - overflow_len
# logging.debug(f"Cut context. context_tokens:{context_tokens.shape}. context_tokens_len:{context_tokens_len}.")
else:
# logging.debug(f"Cutting answer. input_ids:{input_ids.shape}. input_ids_len:{input_ids_len}.")
input_ids = input_ids[:, : input_ids_len - overflow_len]
input_ids_len = input_ids_len - overflow_len
# logging.debug(f"Cut answer. input_ids:{input_ids.shape}. input_ids_len:{input_ids_len}.")

is_speech = True if doc["answer_type"] == "SPEECH" else False
if is_speech:
assert input_ids.dim() == 2
if self.seq_pattern == "delay_parallel":
num_codebooks = input_ids.shape[0]
dec_input_padded = torch.cat(
[
torch.zeros_like(input_ids[:, 0:num_codebooks]),
input_ids,
torch.zeros_like(input_ids[:, 0:num_codebooks]),
],
dim=1,
)
dec_input_new = []
for _c in range(self.num_speech_codebooks):
st = num_codebooks - _c
et_decoder_input = dec_input_padded.shape[1] - _c
dec_input_new.append(dec_input_padded[_c, st:et_decoder_input])
input_ids = torch.stack(dec_input_new, dim=0)
input_ids_len = torch.tensor(input_ids.shape[1]).long()

return (
context_tokens,
context_tokens_len,
question_tokens,
question_tokens_len,
input_ids,
input_ids_len,
)
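In the delay_parallel branch above, each codebook row is shifted right by its codebook index before decoding. A self-contained sketch of the same shifting is below, assuming codes of shape [num_codebooks, T] and zero padding; the function name is hypothetical.

import torch

def apply_delay_pattern(codes: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
    # codes: [num_codebooks, T]; row c is delayed by c steps, output length T + num_codebooks.
    num_codebooks, seq_len = codes.shape
    out = torch.full((num_codebooks, seq_len + num_codebooks), pad_value, dtype=codes.dtype)
    for c in range(num_codebooks):
        out[c, c : c + seq_len] = codes[c]
    return out

# Example: 3 codebooks, 4 frames -> row c starts c positions later.
delayed = apply_delay_pattern(torch.arange(12).reshape(3, 4))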

def collate_fn(self, batch):
(
_,
context_tokens_len,
_,
question_tokens_len,
_,
input_ids_len,
) = zip(*batch)

decoder_input_len = (
torch.stack(context_tokens_len) + torch.stack(question_tokens_len) + torch.stack(input_ids_len)
)
max_decoder_input_len = max(decoder_input_len).item() if decoder_input_len is not None else 0

decoder_mask = get_mask_from_lengths(decoder_input_len - 1)
speech_mask = get_mask_from_lengths(decoder_input_len - 1)
context_question_mask = torch.ones(speech_mask.shape)
(
decoder_input_list,
decoder_labels_list,
) = (
[],
[],
)
for i, sample_tuple in enumerate(batch):
(
context_tokens,
context_tokens_len,
question_tokens,
question_tokens_len,
input_ids,
input_ids_len,
) = sample_tuple

context_tokens_input = context_tokens.clone().contiguous().detach()
for l in range(1, context_tokens_input.shape[0]):
context_tokens_input[l] += self.speech_offset + 1024 * l # TODO: fix hardcode
input_ids_shifted = input_ids.clone().contiguous().detach()
for l in range(1, input_ids_shifted.shape[0]):
input_ids_shifted[l] += self.speech_offset + 1024 * l # TODO: fix hardcode

complete_input = torch.cat([context_tokens_input, question_tokens, input_ids_shifted], dim=1)
complete_input_padded = general_padding(
complete_input,
decoder_input_len[i].item(),
max_decoder_input_len,
pad_value=self.tokenizer.pad_id,
)
complete_output = torch.cat([context_tokens, question_tokens, input_ids], dim=1)
complete_output_padded = general_padding(
complete_output,
decoder_input_len[i].item(),
max_decoder_input_len,
pad_value=self.tokenizer.pad_id,
)
decoder_labels = complete_output_padded[:, 1:].contiguous()
decoder_input = complete_input_padded[:, :-1].contiguous()

decoder_input_list.append(decoder_input)
decoder_labels_list.append(decoder_labels)

decoder_mask[i, : context_tokens_len + question_tokens_len - 1] = 0 # Mask out context and question
speech_mask[i, context_tokens_len : context_tokens_len + question_tokens_len] = (
0 # Mask out context and question
)
context_question_mask[i, : context_tokens_len + question_tokens_len] = 0

# Using causal attention mask for whole input
batch_size = len(decoder_input_list)
attention_mask = torch.tril(
torch.ones((batch_size, max_decoder_input_len - 1, max_decoder_input_len - 1))
).view(batch_size, 1, max_decoder_input_len - 1, max_decoder_input_len - 1)

# Convert attention mask from float to bool
attention_mask = attention_mask < 0.5

decoder_input = torch.stack(decoder_input_list)
decoder_input_p = decoder_input[:, 0, :] if decoder_input.dim() == 3 else decoder_input
position_ids = build_position_ids(decoder_input_p)
data_dict = {
"tokens": decoder_input,
"position_ids": position_ids,
"attention_mask": attention_mask,
"labels": torch.stack(decoder_labels_list),
"speech_mask": speech_mask, # For TTS, can just be loss_mask since answer will always be speech
"loss_mask": decoder_mask, # Mask out context and question and padding
"attention_prior": None,
"context_question_mask": context_question_mask,
}

return data_dict
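collate_fn above builds a lower-triangular attention mask (inverted to boolean, True meaning masked) and zeroes the loss over the context and question spans so that only answer tokens contribute to the loss. A minimal sketch of that masking, with hypothetical function and argument names:

import torch

def build_causal_and_loss_masks(context_len, question_len, total_len):
    # context_len, question_len, total_len: [batch] integer tensors; total_len includes the answer.
    batch_size = context_len.shape[0]
    max_len = int(total_len.max().item()) - 1  # inputs and labels are shifted by one position

    # True where attention is disallowed, matching the `attention_mask < 0.5` convention above.
    attention_mask = (torch.tril(torch.ones(batch_size, max_len, max_len)) < 0.5).view(
        batch_size, 1, max_len, max_len
    )

    # Loss mask: 1 on answer positions, 0 on context/question and padding.
    loss_mask = torch.zeros(batch_size, max_len)
    for i in range(batch_size):
        prompt_len = int(context_len[i] + question_len[i]) - 1
        loss_mask[i, prompt_len : int(total_len[i]) - 1] = 1.0
    return attention_mask, loss_mask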
