From 3e3f0c47bd72128557e7aa41a320ae13cc6cc375 Mon Sep 17 00:00:00 2001
From: Jason
Date: Wed, 6 Nov 2024 12:29:56 -0800
Subject: [PATCH 01/18] merge

Signed-off-by: Jason
---
 Dockerfile.t5tts_from2407 | 194 ++
 T5InferenceClean.ipynb | 389 +++
 .../conf/megatron_t5_speechllm_inference.yaml | 160 ++
 ...megatron_t5_speechllm_inference_model.yaml | 221 ++
 ...n_t5_speechllm_inference_multiencoder.yaml | 226 ++
 .../conf/megatron_t5_speechllm_medium.yaml | 161 ++
 .../megatron_t5_speechllm_multiencoder.yaml | 231 ++
 .../conf/megatron_t5_speechlm_model.yaml | 229 ++
 .../tts/speechllm/megatron_t5_speechllm.py | 59 +
 .../megatron_t5_speechllm_inference.py | 55 +
 .../asr/modules/conformer_encoder.py | 2 +-
 nemo/collections/asr/modules/conv_asr.py | 2 +-
 .../common/parts/preprocessing/collections.py | 124 +
 .../common/parts/preprocessing/manifest.py | 17 +-
 .../tokenizers/sentencepiece_tokenizer.py | 10 +-
 .../megatron/base_prompt_learning_dataset.py | 48 +-
 .../nlp/modules/common/megatron/attention.py | 171 +-
 .../common/megatron/megatron_decoders.py | 4 +-
 .../megatron/megatron_encoder_decoder.py | 38 +-
 .../common/megatron/megatron_encoders.py | 48 +-
 .../megatron/megatron_transformer_decoder.py | 32 +-
 .../megatron/megatron_transformer_encoder.py | 220 ++
 .../nlp/modules/common/megatron/module.py | 2 +-
 .../megatron/token_level_encoder_decoder.py | 433 ++-
 .../modules/common/megatron/transformer.py | 141 +-
 .../nlp/modules/common/megatron/utils.py | 22 +-
 .../tts/data/speechllm/__init__.py | 0
 .../data/speechllm/t5_speechllm_dataset.py | 1588 +++++++++++
 .../speechllm/t5_speechllm_tarred_dataset.py | 1212 ++++++++
 .../tts/g2p/models/zh_cn_pinyin.py | 10 +-
 .../tts/models/speechllm/__init__.py | 0
 .../megatron_base_speechllm_prompt_model.py | 445 +++
 .../speechllm/megatron_t5_speechllm_model.py | 2509 +++++++++++++++++
 .../tts/modules/audio_codec_modules.py | 858 +++++-
 nemo/collections/tts/parts/utils/helpers.py | 69 +
 .../tts/parts/utils/tts_dataset_utils.py | 5 +-
 nemo/utils/exp_manager.py | 13 +-
 nemo/utils/timers.py | 10 +-
 requirements/requirements_tts.txt | 2 +
 scripts/speechllm_multitask_dataprep.py | 783 +++++
 40 files changed, 10653 insertions(+), 90 deletions(-)
 create mode 100644 Dockerfile.t5tts_from2407
 create mode 100644 T5InferenceClean.ipynb
 create mode 100644 examples/tts/speechllm/conf/megatron_t5_speechllm_inference.yaml
 create mode 100644 examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml
 create mode 100644 examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml
 create mode 100644 examples/tts/speechllm/conf/megatron_t5_speechllm_medium.yaml
 create mode 100644 examples/tts/speechllm/conf/megatron_t5_speechllm_multiencoder.yaml
 create mode 100644 examples/tts/speechllm/conf/megatron_t5_speechlm_model.yaml
 create mode 100644 examples/tts/speechllm/megatron_t5_speechllm.py
 create mode 100644 examples/tts/speechllm/megatron_t5_speechllm_inference.py
 create mode 100644 nemo/collections/tts/data/speechllm/__init__.py
 create mode 100644 nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py
 create mode 100644 nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py
 create mode 100644 nemo/collections/tts/models/speechllm/__init__.py
 create mode 100644 nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py
 create mode 100644 nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py
 create mode 100644 scripts/speechllm_multitask_dataprep.py

diff --git a/Dockerfile.t5tts_from2407
b/Dockerfile.t5tts_from2407 new file mode 100644 index 000000000000..1f6a69364400 --- /dev/null +++ b/Dockerfile.t5tts_from2407 @@ -0,0 +1,194 @@ +# syntax=docker/dockerfile:experimental + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ARG BASE_IMAGE=nvcr.io/nvidia/nemo:24.07 + +# build an image that includes only the nemo dependencies, ensures that dependencies +# are included first for optimal caching, and useful for building a development +# image (by specifying build target as `nemo-deps`) +FROM ${BASE_IMAGE} as nemo-deps + +# dependency flags; should be declared after FROM +# torchaudio: not required by default +ARG REQUIRE_TORCHAUDIO=false +# k2: not required by default +ARG REQUIRE_K2=false +# ais cli: not required by default, install only if required +ARG REQUIRE_AIS_CLI=false + +# Ensure apt-get won't prompt for selecting options +ENV DEBIAN_FRONTEND=noninteractive +RUN cd + +# libavdevice-dev required for latest torchaudio +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install -y \ + libsndfile1 sox \ + libfreetype6 \ + swig \ + ffmpeg \ + libavdevice-dev && \ + rm -rf /var/lib/apt/lists/* + +# libtool, ... , libgts-dev are required for graphviz +# graphviz is required for k2 and pynini visualization +RUN apt-get update && \ + apt-get install -y \ + libtool \ + libltdl-dev \ + automake \ + autoconf \ + bison \ + flex \ + tcl \ + ghostscript \ + libgd-dev \ + fontconfig \ + libcairo2-dev \ + libpango1.0-dev \ + libgts-dev && \ + rm -rf /var/lib/apt/lists/* + +WORKDIR /workspace/ + +ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea +ARG MCORE_TAG=3f90b989c477ba9be5d6011866641eda9d91f588 +ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c +# Install megatron core, this can be removed once 0.3 pip package is released +# We leave it here in case we need to work off of a specific commit in main +RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \ + cd Megatron-LM && \ + git checkout ${MCORE_TAG} && \ + pip install . + +# Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771 +RUN git clone https://github.com/NVIDIA/apex.git && \ + cd apex && \ + git checkout ${APEX_TAG} && \ + pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir \ + --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ + +# Transformer Engine 1.2.0 +RUN git clone https://github.com/NVIDIA/TransformerEngine.git && \ + cd TransformerEngine && \ + git fetch origin ${TE_TAG} && \ + git checkout FETCH_HEAD && \ + git submodule init && git submodule update && \ + NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install . 
+ +WORKDIR /tmp/ + +# uninstall stuff from base container +RUN pip3 uninstall -y sacrebleu torchtext + +# build torchaudio +WORKDIR /tmp/torchaudio_build +COPY scripts/installers /tmp/torchaudio_build/scripts/installers/ +RUN INSTALL_MSG=$(/bin/bash /tmp/torchaudio_build/scripts/installers/install_torchaudio_latest.sh); INSTALL_CODE=$?; \ + echo ${INSTALL_MSG}; \ + if [ ${INSTALL_CODE} -ne 0 ]; then \ + echo "torchaudio installation failed"; \ + if [ "${REQUIRE_TORCHAUDIO}" = true ]; then \ + exit ${INSTALL_CODE}; \ + else echo "Skipping failed torchaudio installation"; fi \ + else echo "torchaudio installed successfully"; fi + +COPY scripts /tmp/nemo/scripts/ +# install correct graphviz version (k2 and pynini visualization tool), skip if installation fails +RUN INSTALL_MSG=$(/bin/bash /tmp/nemo/scripts/installers/install_graphviz.sh --docker); INSTALL_CODE=$?; \ + echo ${INSTALL_MSG}; \ + if [ ${INSTALL_CODE} -ne 0 ]; then \ + echo "graphviz installation failed"; \ + if [ "${REQUIRE_K2}" = true ]; then \ + exit ${INSTALL_CODE}; \ + else echo "Skipping failed graphviz installation"; fi \ + else echo "graphviz installed successfully"; fi + +# # install k2, skip if installation fails +# COPY scripts /tmp/nemo/scripts/ +# RUN INSTALL_MSG=$(/bin/bash /tmp/nemo/scripts/installers/install_k2.sh); INSTALL_CODE=$?; \ +# echo ${INSTALL_MSG}; \ +# if [ ${INSTALL_CODE} -ne 0 ]; then \ +# echo "k2 installation failed"; \ +# if [ "${REQUIRE_K2}" = true ]; then \ +# exit ${INSTALL_CODE}; \ +# else echo "Skipping failed k2 installation"; fi \ +# else echo "k2 installed successfully"; fi + +# install nemo dependencies +WORKDIR /tmp/nemo +ENV LHOTSE_REQUIRE_TORCHAUDIO=0 +COPY requirements . +# exclude requirements_vllm.txt, since `vllm==0.5.x` breaks the container due to hardcoded requirements `torch==2.3.0` +RUN for f in $(ls requirements*.txt | grep -v 'requirements_vllm.txt'); do \ + pip3 install --disable-pip-version-check --no-cache-dir -r $f; done + +# install flash attention +RUN pip install flash-attn +# install numba for latest containers +RUN pip install numba>=0.57.1 +# Extra t5 libraries +RUN pip install ipdb seaborn gradio + +# copy nemo source into a scratch image +FROM scratch as nemo-src +COPY . . + +# start building the final container +FROM nemo-deps as nemo +ARG NEMO_VERSION=2.0.0 + +# Check that NEMO_VERSION is set. Build will fail without this. Expose NEMO and base container +# version information as runtime environment variable for introspection purposes +RUN /usr/bin/test -n "$NEMO_VERSION" && \ + /bin/echo "export NEMO_VERSION=${NEMO_VERSION}" >> /root/.bashrc && \ + /bin/echo "export BASE_IMAGE=${BASE_IMAGE}" >> /root/.bashrc + +# Install NeMo +RUN --mount=from=nemo-src,target=/tmp/nemo,rw cd /tmp/nemo && pip install ".[all]" + +# Check install +# NB: adjusting LD_LIBRARY_PATH (only here, should not be persistent!) 
is a temporary hack +# to avoid failure if CUDA is unavailable (`docker build` does not expose GPUs) +# The error is raised in NeMo Core, and the main reason is reinstalled Transformer-Engine; +RUN export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${CUDA_HOME}/compat/lib.real && \ + python -c "import nemo.collections.asr as nemo_asr" && \ + python -c "import nemo.collections.nlp as nemo_nlp" && \ + python -c "import nemo.collections.tts as nemo_tts" && \ + python -c "import nemo_text_processing.text_normalization as text_normalization" + + +# copy scripts/examples/tests into container for end user +WORKDIR /workspace/nemo +# COPY scripts /workspace/nemo/scripts +# COPY examples /workspace/nemo/examples +# COPY tests /workspace/nemo/tests +# COPY tutorials /workspace/nemo/tutorials +# COPY README.rst LICENSE /workspace/nemo/ + +RUN printf "#!/bin/bash\njupyter lab --no-browser --allow-root --ip=0.0.0.0" >> start-jupyter.sh && \ + chmod +x start-jupyter.sh + +# If required, install AIS CLI and Python AIS SDK +RUN INSTALL_MSG=$(/bin/bash /tmp/nemo/scripts/installers/install_ais_cli_latest.sh && pip install aistore); INSTALL_CODE=$?; \ + echo ${INSTALL_MSG}; \ + if [ ${INSTALL_CODE} -ne 0 ]; then \ + echo "AIS CLI installation failed"; \ + if [ "${REQUIRE_AIS_CLI}" = true ]; then \ + exit ${INSTALL_CODE}; \ + else echo "Skipping AIS CLI installation"; fi \ + else echo "AIS CLI installed successfully"; fi diff --git a/T5InferenceClean.ipynb b/T5InferenceClean.ipynb new file mode 100644 index 000000000000..af43d6c8409e --- /dev/null +++ b/T5InferenceClean.ipynb @@ -0,0 +1,389 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "7554757b", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", + "import torch\n", + "import json\n", + "from omegaconf.omegaconf import OmegaConf, open_dict\n", + "import shutil\n", + "\n", + "from nemo.collections.tts.models.speechllm.megatron_t5_speechllm_model import MegatronT5SpeechLMModel\n", + "from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder\n", + "from nemo.collections.asr.parts.preprocessing.segment import AudioSegment\n", + "from nemo.core.config import hydra_runner\n", + "from nemo.utils import logging\n", + "from nemo.utils.exp_manager import exp_manager\n", + "from IPython.display import Audio, display\n", + "import torchaudio\n", + "\n", + "# CHANGE THIS TO A LOCAL DIRECTORY\n", + "EXP_DIR = \"/datap/misc/NotebookInference\"\n", + "\n", + "if not os.path.exists(EXP_DIR):\n", + " os.makedirs(EXP_DIR)" + ] + }, + { + "cell_type": "markdown", + "id": "5fdfa55a", + "metadata": {}, + "source": [ + "## Save a dummy manifest to setup Model Test Step" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a0c40ac", + "metadata": {}, + "outputs": [], + "source": [ + "def write_records(fp, records):\n", + " with open(fp, \"w\") as f:\n", + " for record in records:\n", + " f.write(json.dumps(record) + \"\\n\")\n", + "\n", + "dummy_codes = torch.ones(8, 300).cpu().type(torch.int16)\n", + "dummy_codes_fp = os.path.join(EXP_DIR, \"dummy_codes.pt\")\n", + "torch.save(dummy_codes, dummy_codes_fp)\n", + "\n", + "\n", + "dummy_record = {\n", + " \"question\" : \"Phoneme TTS Sample Text\",\n", + " \"answer\" : dummy_codes_fp,\n", + " \"context\" : dummy_codes_fp,\n", + " \"context_type\" : \"REFSPEAKERCODEC\",\n", + " \"question_type\" : \"TEXT\",\n", + " \"answer_type\" : 
\"AUDIOCODEC\",\n", + " \"context_duration\" : 5.0,\n", + " \"answer_duration\" : 5.0,\n", + " \"taskname\" : \"squad\"\n", + "}\n", + "\n", + "dummy_val_file = os.path.join(EXP_DIR, \"dummy_val.json\")\n", + "\n", + "write_records(dummy_val_file, [dummy_record])" + ] + }, + { + "cell_type": "markdown", + "id": "c9cd90c5", + "metadata": {}, + "source": [ + "## Load and setup the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee1df6bf", + "metadata": {}, + "outputs": [], + "source": [ + "# CHANGE THESE PATHS TO RELEVANT MOUNTED PATHS IN DOCKER\n", + "config_path = \"/home/pneekhara/2023/NeMo/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml\"\n", + "# checkpoint_path = \"/datap/misc/temp_checkpoints_new/desta_less_sophia_highLR_step159600.ckpt\"\n", + "checkpoint_path = \"/datap/misc/checkpoints/desta_less_sophia_213850.ckpt\"\n", + "codecmodel_path = \"/datap/misc/checkpoints/SpeechCodec_2402.nemo\"\n", + "vocab_file = \"/datap/misc/checkpoints/9a77f10c2793465e8e8a3fa5fcbef8b0_vocab.txt\"\n", + "\n", + "cfg = OmegaConf.load(config_path)\n", + "\n", + "if \"gradient_as_bucket_view\" not in cfg.model:\n", + " with open_dict(cfg):\n", + " cfg.model.gradient_as_bucket_view=False\n", + "\n", + "trainer = MegatronTrainerBuilder(cfg).create_trainer()\n", + "exp_manager(trainer, cfg.exp_manager)\n", + "\n", + "with open_dict(cfg):\n", + " cfg.exp_manager.exp_dir=EXP_DIR\n", + " cfg.checkpoint_path = checkpoint_path\n", + " cfg.model.data.sup_data_path=\"/datap/misc/speechllm_codecdatasets/\"\n", + " cfg.model.global_batch_size=1\n", + " cfg.model.micro_batch_size=1\n", + " cfg.model.data.speech_offset=30128\n", + " cfg.model.lm_vocab_size=30000\n", + " cfg.model.data.add_special_tokens_to_only_first_codebook=True\n", + " cfg.model.data.train_task=\"all\"\n", + " cfg.model.freeze_model=False\n", + " cfg.model.data.max_seq_length=2048\n", + " cfg.model.max_inference_timesteps=2000\n", + " cfg.model.data.context_duration_min=20.0\n", + " cfg.model.data.context_duration_max=20.0\n", + " cfg.model.top_k=80\n", + " cfg.model.temperature=0.85\n", + " cfg.model.data.speech_offset=30128\n", + " cfg.model.lm_vocab_size=30000\n", + " cfg.model.codecmodel_path=codecmodel_path\n", + " cfg.trainer.devices=1\n", + " cfg.trainer.precision=\"bf16\"\n", + " cfg.model.precision = cfg.trainer.precision\n", + " cfg.model.override_tokenizer_vocab_file=vocab_file\n", + " cfg.model.english_only_model=True\n", + " cfg.model.asr_model_name=\"stt_en_conformer_transducer_large\"\n", + " cfg.model.frozen_model.decoder.layer_type=[1,1,1,2,2,2,2,2,2,2,1,1]\n", + " cfg.model.alignment_decoder_layerids=[0,1,2,3,4]\n", + " cfg.model.enc_output_to_layers=[[8,9],[3,4,5,6,7]]\n", + " cfg.model.data.test_ds=[dummy_val_file]\n", + " cfg.model.data.num_workers = 0\n", + "\n", + "\n", + "checkpoint_path = cfg.get('checkpoint_path', None)\n", + "assert checkpoint_path is not None, \"checkpoint path needs to be valid\"\n", + "\n", + "model = MegatronT5SpeechLMModel.load_from_checkpoint(\n", + " checkpoint_path=checkpoint_path, trainer=trainer, cfg=cfg.model\n", + " )\n", + "model.eval()\n", + "model = model.cuda()\n", + "\n", + "codec_model = model.additional_models['codec']\n", + "trainer.test(model)\n" + ] + }, + { + "cell_type": "markdown", + "id": "e5461918", + "metadata": {}, + "source": [ + "## Helper functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "227b07f2", + "metadata": {}, + "outputs": [], + "source": [ + "out_dir = os.path.join( 
model.trainer.logger.save_dir, model.trainer.logger.name, model.trainer.logger.version, \"Sample_Audios\")\n", + "out_path = os.path.join(out_dir, 'predicted_wav_0.wav')\n", + "\n", + "\n", + "def encode(wav_path):\n", + " # Convert an audio file to nemo codec codes\n", + " features = AudioSegment.segment_from_file(\n", + " wav_path, target_sr=codec_model.sample_rate, n_segments=-1, trim=False,\n", + " )\n", + " audio_samples = features.samples\n", + " audio = torch.tensor(audio_samples).cuda()\n", + " audio_length = torch.tensor(audio.size(0)).long().cuda()\n", + " print(f\"audio {audio.size()} audio_length {audio_length}\")\n", + " print(f\"audio {audio.device} audio_length {audio_length.device} codec_model {codec_model.device}\")\n", + "\n", + " original_codec_codes, _ = codec_model.encode(audio=audio.unsqueeze(0), audio_len=audio_length.unsqueeze(0))\n", + " original_codec_codes = original_codec_codes[0]\n", + " print(f\"original_codec_codes {original_codec_codes.size()} audio {audio.size()} audio_length {audio_length}\")\n", + " duration = original_codec_codes.size()[1] / 86\n", + " \n", + " target_codec_filepath = wav_path[:-4] + \"_codes.pt\"\n", + " torch.save(original_codec_codes.cpu().type(torch.int16), target_codec_filepath)\n", + " return original_codec_codes, target_codec_filepath, duration\n", + " \n", + " \n", + " \n", + "def play_codec(codec_path):\n", + " # Convert nemo codecs to audio and play it\n", + " codec = torch.load(codec_path)\n", + " codec = codec.to('cuda')\n", + " codec = codec.unsqueeze(0)\n", + " codec_lens = torch.Tensor([codec.shape[2]]).long().cuda()\n", + " codec_decoded_audios, _ = codec_model.decode(tokens=codec.long(), tokens_len=codec_lens)\n", + " codec_decoded_audio = codec_decoded_audios[0]\n", + " temp_wav_path = os.path.join(EXP_DIR, \"temp.wav\")\n", + " torchaudio.save(temp_wav_path, codec_decoded_audio[None].cpu(), 22050)\n", + " display(Audio(temp_wav_path))\n", + "\n", + "def generate_new_audio(\n", + " text,\n", + " context,\n", + " context_duration=4.0,\n", + " context_type=\"REFSPEAKERCODEC\",\n", + " temperature=0.85,\n", + " top_k=80,\n", + " text_task=\"Phoneme TTS \"\n", + " ):\n", + " # Prepare data in speechllm format\n", + " model.cfg.temperature = temperature\n", + " model.cfg.top_k = top_k\n", + " dummy_answer = dummy_codes_fp\n", + " json_in = {}\n", + " json_in[\"question\"] = text_task + text\n", + " json_in[\"question_type\"] = \"TEXT\"\n", + " json_in[\"answer\"] = dummy_answer \n", + " json_in[\"context\"] = context \n", + " json_in[\"answer_type\"] = \"AUDIOCODEC\"\n", + " json_in[\"context_type\"] = context_type\n", + " json_in[\"context_duration\"] = context_duration\n", + " json_in[\"answer_duration\"] = 2.0\n", + " json_in[\"taskname\"] = \"squad\"\n", + " json_in[\"lang\"] = \"en\"\n", + " json_in = [json_in]\n", + " \n", + " # Prepare dataloader\n", + " model._test_ds.examples = []\n", + " model._test_ds.examples = model._test_ds.load_data(json_in)\n", + " \n", + " sampler = torch.utils.data.distributed.DistributedSampler(\n", + " model._test_ds, num_replicas=1, rank=0, shuffle=False, seed=1\n", + " )\n", + "\n", + " model._test_dl = torch.utils.data.DataLoader(\n", + " model._test_ds,\n", + " collate_fn=model._test_ds.collate_fn,\n", + " sampler=sampler,\n", + " batch_size=1,\n", + " drop_last=False,\n", + " num_workers=1,\n", + " pin_memory=False,\n", + " persistent_workers=True\n", + " )\n", + " \n", + " # Run inference\n", + " model.cfg.data.test_ds = None\n", + " trainer.test(model, model._test_dl)\n", + " 
print(\"Out path:\", out_path)\n", + " print(\"Inference done\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7f449a2", + "metadata": {}, + "outputs": [], + "source": [ + "text_contexts = [\n", + " \"TEXT CONTEXT: | Language:en Dataset:Riva Speaker:Lindy_CMU_FEARFUL |\",\n", + " \"TEXT CONTEXT: | Language:en Dataset:Riva Speaker:Lindy_CMU_HAPPY |\",\n", + " \"TEXT CONTEXT: | Language:en Dataset:Riva Speaker:Rodney_CMU_HAPPY |\",\n", + " \"TEXT CONTEXT: | Language:en Dataset:PromptTTS Gender:female SpeakingRate:2. Slow emotion:neutral Pitch:4. High SNR:5. Clean REVERB:5. Very close-sounding |\"\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "a3d5a467", + "metadata": {}, + "source": [ + "## Generate audio from a text context" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf9660d9", + "metadata": {}, + "outputs": [], + "source": [ + "text = \"As I closed my laptop for the night, my reflection in the screen continued to smile back at me.\"\n", + "text_task = \"Phoneme TTS \" # Can be \"Text to speech this \" (for sentence-piece tokenizer) or \"Phoneme TTS \" (for phoneme tokenizer)\n", + "context = text_contexts[1] # Sample Text Context\n", + "context_type = \"TEXT\" # Can be REFSPEAKERCODEC (for audio context), TEXT (for text context)\n", + "generate_new_audio(\n", + " text, \n", + " context, \n", + " context_type=context_type, \n", + " context_duration=5.0, # Does not matter, should just be > 3 so that dataset does not filter it out.\n", + " top_k=80, # Can play around with this to check roubstness\n", + " temperature=0.8, # Can play around with this. temperature < 0.85 can be more robust\n", + " text_task=text_task\n", + ")\n", + "display(Audio(out_path))" + ] + }, + { + "cell_type": "markdown", + "id": "f1a964c3", + "metadata": {}, + "source": [ + "## Listen to some ground-truth context audios" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f7c966e", + "metadata": {}, + "outputs": [], + "source": [ + "context_paths = [\n", + " \"/datap/misc/speechllm_codecdatasets/codecs/RivattsAllLanguagesUpdated_train_nemo_codec_bw_6.0/target_codes_en_Lindy_44khz_CMU_HAPPY_LINDY_CMU_HAPPY_000570.pt\",\n", + "]\n", + "\n", + "for cidx, context_path in enumerate(context_paths):\n", + " print(cidx, context_path)\n", + " play_codec(context_path)" + ] + }, + { + "cell_type": "markdown", + "id": "45f8bc28", + "metadata": {}, + "source": [ + "## Generate audio from an audio context" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7dd9deb2", + "metadata": {}, + "outputs": [], + "source": [ + "text = \"As I closed my laptop for the night, my reflection in the screen continued to smile back at me.\"\n", + "text_task = \"Text to speech this \" # Can be \"Text to speech this \" (for sentence-piece tokenizer) or \"Phoneme TTS \" (for phoneme tokenizer)\n", + "context = context_paths[0] # Sample Text Context\n", + "context_type = \"REFSPEAKERCODEC\" # Can be REFSPEAKERCODEC (for audio context), TEXT (for text context)\n", + "generate_new_audio(\n", + " text, \n", + " context, \n", + " context_type=context_type, \n", + " context_duration=5.0, # Does not matter, should just be > 3 so that dataset does not filter it out.\n", + " temperature=0.8, # Can play around with this. 
temperature < 0.85 can be more robust\n", + " text_task=text_task\n", + ")\n", + "display(Audio(out_path))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f40d2450", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference.yaml new file mode 100644 index 000000000000..8b37077bfdd5 --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference.yaml @@ -0,0 +1,160 @@ +name: megatron_t5_speechllm_tts_inference +checkpoint_path: ??? + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 10000 + max_steps: -1 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 3 + gradient_clip_val: 1.0 + +exp_manager: + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: False + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 2 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 16 + micro_batch_size: 16 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + temperature: 0.85 # Temperature to be used for inference + top_k: 80 # Top k to be used for inference + max_inference_timesteps: 1000 # Maximum number of timesteps to run inference for + + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + language_model_path: ??? # Path to the pretrained T5 language model .nemo file, always required + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + codecmodel_type: nemo_codec + codecmodel_path: ??? 
+ english_only_model: true + context_conditioning: decoder + use_flash_attention: false + lm_vocab_size: 30000 + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: null + validation_ds: null + test_ds: ??? + max_seq_length: 1536 + sample_rate: 24000 + add_eos: true + add_bos: false + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30000 + train_task: all + sup_data_path: None + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + context_slice_method: "fixed" + phoneme_probability: 1.0 + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 5e-5 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 \ No newline at end of file diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml new file mode 100644 index 000000000000..bd65eb956bbc --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml @@ -0,0 +1,221 @@ +name: megatron_t5_speechllm_tts_inference +checkpoint_path: ??? 
+ +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 10000 + max_steps: -1 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 3 + gradient_clip_val: 1.0 + +exp_manager: + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: False + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 2 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 16 + micro_batch_size: 16 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + temperature: 0.85 # Temperature to be used for inference + top_k: 80 # Top k to be used for inference + max_inference_timesteps: 1000 # Maximum number of timesteps to run inference for + + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: decoder + train_from_scratch: true + override_tokenizer_vocab_file: ??? 
+ use_flash_attention: false + lm_vocab_size: 30000 + + frozen_model: + # micro_batch_size: null + # global_batch_size: null + # megatron_amp_O2: true + # seq_length: 512 + # max_position_embeddings: 512 + # precision: bf16 + # Above is overridden in code + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_model_parallel_split_rank: 0 + make_vocab_size_divisible_by: 128 + pre_process: true + post_process: true + gradient_as_bucket_view: true + native_amp_init_scale: 4294967296 + native_amp_growth_interval: 1000 + fp16_lm_cross_entropy: false + seed: 1234 + use_cpu_initialization: false + apex_transformer_log_level: 30 + tokenizer: + library: megatron + type: BertWordPieceCase + model: null + vocab_file: null + merge_file: null + # num_sentinel_tokens: 100 + optim: + name: null + data: + dataset_type: t5 + encoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + decoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: null + validation_ds: null + test_ds: ??? 
+ max_seq_length: 1536 + sample_rate: 24000 + add_eos: true + add_bos: false + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30000 + train_task: all + sup_data_path: None + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + context_slice_method: "fixed" + phoneme_probability: 1.0 + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 5e-5 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 \ No newline at end of file diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml new file mode 100644 index 000000000000..761268ca6fa1 --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml @@ -0,0 +1,226 @@ +name: megatron_t5_speechllm +checkpoint_path: ??? 
+ +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 + max_steps: 250000 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 1 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: False + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 3 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 2 + micro_batch_size: 2 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + temperature: 0.85 # Temperature to be used for inference + top_k: 80 # Top k to be used for inference + max_inference_timesteps: 2000 # Maximum number of timesteps to run inference for + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: encoder + train_from_scratch: true + override_tokenizer_vocab_file: ??? 
+ use_flash_attention: false + lm_vocab_size: 30000 + enc_output_to_layers: [[0,1,2],[3,4,5,6,7,8]] + + frozen_model: + # micro_batch_size: null + # global_batch_size: null + # megatron_amp_O2: true + # seq_length: 512 + # max_position_embeddings: 512 + # precision: bf16 + # Above is overridden in code + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_model_parallel_split_rank: 0 + make_vocab_size_divisible_by: 128 + pre_process: true + post_process: true + gradient_as_bucket_view: true + native_amp_init_scale: 4294967296 + native_amp_growth_interval: 1000 + fp16_lm_cross_entropy: false + seed: 1234 + use_cpu_initialization: false + apex_transformer_log_level: 30 + tokenizer: + library: megatron + type: BertWordPieceCase + model: null + vocab_file: null + merge_file: null + # num_sentinel_tokens: 100 + optim: + name: null + data: + dataset_type: t5 + encoder: + arch: multi_transformer + n_transformers: 2 + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 6 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + decoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: null + validation_ds: null + test_ds: ??? 
+ max_seq_length: 2048 + sample_rate: 24000 + add_eos: true + add_bos: false + use_attention_prior: false + attention_prior_scaling_factor: 0.05 + cross_attention_epsilon: 0.0 + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30128 + train_task: all + sup_data_path: None + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + context_slice_method: "fixed" + phoneme_probability: 1.0 + encoder_type: ${model.frozen_model.encoder.arch} + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 \ No newline at end of file diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_medium.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_medium.yaml new file mode 100644 index 000000000000..bd31f0712fdf --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_medium.yaml @@ -0,0 +1,161 @@ +name: megatron_t5_speechllm_medium + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 + max_steps: 1000000 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 1 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 3 + mode: min + save_nemo_on_train_end: False + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: 
"p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 2 + micro_batch_size: 2 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + attn_prior_scaledown_start_step: 10000 + attn_prior_end_step: 11000 + return_all_crossattention_probs: True + num_cross_attention_heads: 12 # 12 for 220m, 16 for 3b. + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + language_model_path: ??? # Path to the pretrained T5 language model .nemo file, always required + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + freeze_model: false + use_alignment_loss: true + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: decoder + use_flash_attention: false + lm_vocab_size: 30000 + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: ??? + validation_ds: ??? 
+ max_seq_length: 2048 + sample_rate: 24000 + add_eos: true + add_bos: false + use_attention_prior: true + attention_prior_scaling_factor: 0.05 + cross_attention_epsilon: 0.0 + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30128 + train_task: all + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 \ No newline at end of file diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_multiencoder.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_multiencoder.yaml new file mode 100644 index 000000000000..c121c8f9a510 --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_multiencoder.yaml @@ -0,0 +1,231 @@ +name: megatron_t5_speechllm + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 + max_steps: 250000 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 1 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 3 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 
'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 2 + micro_batch_size: 2 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + attn_prior_scaledown_start_step: 10000 + attn_prior_end_step: 11000 + return_all_crossattention_probs: True + num_cross_attention_heads: 12 # 12 for 220m, 16 for 3b. + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + freeze_model: false + use_alignment_loss: true + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: encoder + train_from_scratch: true + override_tokenizer_vocab_file: ??? + use_flash_attention: false + lm_vocab_size: 30000 + enc_output_to_layers: [[0,1,2],[3,4,5,6,7,8]] + + frozen_model: + # micro_batch_size: null + # global_batch_size: null + # megatron_amp_O2: true + # seq_length: 512 + # max_position_embeddings: 512 + # precision: bf16 + # Above is overridden in code + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_model_parallel_split_rank: 0 + make_vocab_size_divisible_by: 128 + pre_process: true + post_process: true + gradient_as_bucket_view: true + native_amp_init_scale: 4294967296 + native_amp_growth_interval: 1000 + fp16_lm_cross_entropy: false + seed: 1234 + use_cpu_initialization: false + apex_transformer_log_level: 30 + tokenizer: + library: megatron + type: BertWordPieceCase + model: null + vocab_file: null + merge_file: null + # num_sentinel_tokens: 100 + optim: + name: null + data: + dataset_type: t5 + encoder: + arch: multi_transformer + n_transformers: 2 + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 6 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + decoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: ??? 
+ validation_ds: ??? + max_seq_length: 2048 + sample_rate: 24000 + add_eos: true + add_bos: false + use_attention_prior: true + attention_prior_scaling_factor: 0.05 + cross_attention_epsilon: 0.0 + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30128 + train_task: all + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + encoder_type: ${model.frozen_model.encoder.arch} + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 1000 + constant_steps: 0 + min_lr: 1e-5 + monitor: val_loss + reduce_on_plateau: false \ No newline at end of file diff --git a/examples/tts/speechllm/conf/megatron_t5_speechlm_model.yaml b/examples/tts/speechllm/conf/megatron_t5_speechlm_model.yaml new file mode 100644 index 000000000000..5210254a2e7d --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechlm_model.yaml @@ -0,0 +1,229 @@ +name: megatron_t5_speechllm + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 + max_steps: 250000 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 1 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 3 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 
1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 2 + micro_batch_size: 2 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + attn_prior_scaledown_start_step: 10000 + attn_prior_end_step: 11000 + return_all_crossattention_probs: True + num_cross_attention_heads: 12 # 12 for 220m, 16 for 3b. + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + freeze_model: false + use_alignment_loss: true + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: decoder + train_from_scratch: true + override_tokenizer_vocab_file: ??? + use_flash_attention: false + lm_vocab_size: 30000 + + frozen_model: + # micro_batch_size: null + # global_batch_size: null + # megatron_amp_O2: true + # seq_length: 512 + # max_position_embeddings: 512 + # precision: bf16 + # Above is overridden in code + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_model_parallel_split_rank: 0 + make_vocab_size_divisible_by: 128 + pre_process: true + post_process: true + gradient_as_bucket_view: true + native_amp_init_scale: 4294967296 + native_amp_growth_interval: 1000 + fp16_lm_cross_entropy: false + seed: 1234 + use_cpu_initialization: false + apex_transformer_log_level: 30 + tokenizer: + library: megatron + type: BertWordPieceCase + model: null + vocab_file: null + merge_file: null + # num_sentinel_tokens: 100 + optim: + name: null + data: + dataset_type: t5 + encoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + decoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init 
method is text, or None if init method is random + + data: + use_ipa: false + grapheme_prefix: null + train_ds: ??? + validation_ds: ??? + max_seq_length: 2048 + sample_rate: 24000 + add_eos: true + add_bos: false + use_attention_prior: true + attention_prior_scaling_factor: 0.05 + cross_attention_epsilon: 0.0 + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30128 + train_task: all + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 1000 + constant_steps: 0 + min_lr: 1e-5 + monitor: val_loss + reduce_on_plateau: false diff --git a/examples/tts/speechllm/megatron_t5_speechllm.py b/examples/tts/speechllm/megatron_t5_speechllm.py new file mode 100644 index 000000000000..3f0f7a2e76b1 --- /dev/null +++ b/examples/tts/speechllm/megatron_t5_speechllm.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
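Editor's note: the training entry point added below is a Hydra-driven script that defaults to conf/megatron_t5_speechllm_medium.yaml. As a rough usage sketch only, assuming the required ??? fields seen in the configs in this patch are supplied as command-line overrides (all paths are placeholders):

    python examples/tts/speechllm/megatron_t5_speechllm.py \
        model.data.train_ds=/path/to/train_manifest.json \
        model.data.validation_ds=/path/to/val_manifest.json \
        model.codecmodel_path=/path/to/audio_codec.nemo \
        model.override_tokenizer_vocab_file=/path/to/vocab.txt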
+ +import torch +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf, open_dict + +from nemo.collections.tts.models.speechllm.megatron_t5_speechllm_model import MegatronT5SpeechLMModel +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="megatron_t5_speechllm_medium.yaml") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + # MegatronTrainerBuilder compat checks + if "gradient_as_bucket_view" not in cfg.model: + with open_dict(cfg): + cfg.model.gradient_as_bucket_view = False + + trainer = MegatronTrainerBuilder(cfg).create_trainer() + exp_manager(trainer, cfg.exp_manager) + + # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams + with open_dict(cfg): + cfg.model.precision = cfg.trainer.precision + + # load existing or init new soft prompt T5 model + if cfg.model.get("restore_path", None) is not None: + logging.info(f"cfg.model.restore_path {cfg.model.restore_path}") + model = MegatronT5SpeechLMModel.restore_from( + cfg.model.restore_path, cfg.model, trainer=trainer, save_restore_connector=NLPSaveRestoreConnector() + ) + else: + logging.info(f"cfg.model.restore_path is None") + model = MegatronT5SpeechLMModel(cfg.model, trainer=trainer) + model.maybe_init_from_pretrained_checkpoint(cfg=cfg) + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/tts/speechllm/megatron_t5_speechllm_inference.py b/examples/tts/speechllm/megatron_t5_speechllm_inference.py new file mode 100644 index 000000000000..27e5deb1f81a --- /dev/null +++ b/examples/tts/speechllm/megatron_t5_speechllm_inference.py @@ -0,0 +1,55 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
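Editor's note: the matching inference entry point below loads a trained checkpoint via load_from_checkpoint and runs trainer.test. The script itself only requires checkpoint_path to be set; a hypothetical invocation (the path is a placeholder, not from this patch):

    python examples/tts/speechllm/megatron_t5_speechllm_inference.py \
        checkpoint_path=/path/to/megatron_t5_speechllm_tts.ckpt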
+ +import torch +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf, open_dict + +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.tts.models.speechllm.megatron_t5_speechllm_model import MegatronT5SpeechLMModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="megatron_t5_speechllm_inference.yaml") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + # MegatronTrainerBuilder compat checks + if "gradient_as_bucket_view" not in cfg.model: + with open_dict(cfg): + cfg.model.gradient_as_bucket_view = False + + trainer = MegatronTrainerBuilder(cfg).create_trainer() + exp_manager(trainer, cfg.exp_manager) + + # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams + with open_dict(cfg): + cfg.model.precision = cfg.trainer.precision + + # load existing or init new soft prompt T5 model + checkpoint_path = cfg.get('checkpoint_path', None) + assert checkpoint_path is not None, "Please specify checkpoint_path in the config file" + model = MegatronT5SpeechLMModel.load_from_checkpoint( + checkpoint_path=checkpoint_path, trainer=trainer, cfg=cfg.model + ) + model.eval() + model = model.cuda() + trainer.test(model) + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index 27d0cde33f8c..d081f5034eed 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -672,7 +672,7 @@ def forward_internal( def update_max_seq_length(self, seq_length: int, device): # Find global max audio length across all nodes - if torch.distributed.is_initialized(): + if torch.distributed.is_initialized() and (not getattr(self, 'disable_torch_distributed', False)): global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device) # Update across all ranks in the distributed system diff --git a/nemo/collections/asr/modules/conv_asr.py b/nemo/collections/asr/modules/conv_asr.py index 3cb9ec13109b..b4b9a6735013 100644 --- a/nemo/collections/asr/modules/conv_asr.py +++ b/nemo/collections/asr/modules/conv_asr.py @@ -200,7 +200,7 @@ def forward(self, audio_signal, length): def update_max_sequence_length(self, seq_length: int, device): # Find global max audio length across all nodes - if torch.distributed.is_initialized(): + if torch.distributed.is_initialized() and (not getattr(self, 'disable_torch_distributed', False)): global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device) # Update across all ranks in the distributed system diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index b16ac50e4d56..d4e96875be6b 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -308,6 +308,130 @@ def __init__( super().__init__(data) +class InstructionTuningAudioText(_Collection): + """`AudioText` collector from asr structured json files.""" + + OUTPUT_TYPE = collections.namedtuple( + typename='InstructionTuningText', + field_names='id context context_type context_duration question question_type answer answer_type answer_duration speaker', + ) + 
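Editor's note: for orientation, a hypothetical manifest record consistent with the fields this collector reads (parsing appears in __init__ below); every value here is an illustrative assumption, not taken from the patch:

    # Hypothetical JSONL manifest entry for the instruction-tuning audio/text collector.
    example_item = {
        "id": "sample_0001",
        "context": "/data/audio/sample_0001_context.wav",  # reference audio, mapped to audio_file in manifest.py
        "context_type": "SPEECH",
        "context_duration": 2.9,
        "question": "Text to speech this",                  # instruction / task prompt
        "question_type": "TEXT",
        "answer": "/data/audio/sample_0001.wav",            # target speech
        "answer_type": "SPEECH",
        "answer_duration": 4.2,
        "speaker": "speaker_17",
        "task": "tts",
    }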
+ def __init__( + self, + manifests_files: Union[str, List[str]], + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + max_seq_length: Optional[float] = None, + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + index_by_file_id: bool = False, + decoder_only_model: bool = False, + use_phoneme_tokenizer: bool = False, + ): + """Parse lists of audio files, durations and transcripts texts. + Args: + manifests_files: Either single string file or list of such - + manifests to yield items from. + *args: Args to pass to `AudioText` constructor. + **kwargs: Kwargs to pass to `AudioText` constructor. + """ + + output_type = self.OUTPUT_TYPE + self.use_phoneme_tokenizer = use_phoneme_tokenizer + data, duration_filtered, num_filtered, total_duration = [], 0.0, 0, 0.0 + if index_by_file_id: + self.mapping = {} + + for item in manifest.item_iter(manifests_files): + + id = item['id'] + context = item['context'] + context_duration = item['context_duration'] + context_type = item['context_type'] + question = item['question'] + question_type = item['question_type'] + speaker = item['speaker'] + answer = item['answer'] + answer_duration = item['answer_duration'] + answer_type = item['answer_type'] + task = item['task'] + + task = 'tts' if task is None else task + duration = answer_duration if task == 'tts' else context_duration + if min_duration is not None and duration < min_duration: + duration_filtered += duration + num_filtered += 1 + continue + + if max_duration is not None and duration > max_duration: + duration_filtered += duration + num_filtered += 1 + continue + + # Check segment length + approx_context_len = min(self._get_len(context_type, context, context_duration) * 0.3, 400) + approx_question_len = self._get_len(question_type, question, None) + approx_answer_len = self._get_len(answer_type, answer, answer_duration) + + if ( + decoder_only_model and approx_context_len + approx_question_len + approx_answer_len >= max_seq_length + ) or (approx_context_len + approx_question_len >= max_seq_length or approx_answer_len >= max_seq_length): + duration_filtered += duration + num_filtered += 1 + continue + + total_duration += duration + data.append( + output_type( + id, + context, + context_type, + context_duration, + question, + question_type, + answer, + answer_type, + answer_duration, + speaker, + ) + ) + + if index_by_file_id: + file_id, _ = os.path.splitext(os.path.basename(context)) + if ".context" in file_id: + file_id = file_id[:-8] + if file_id not in self.mapping: + self.mapping[file_id] = [] + self.mapping[file_id].append(len(data) - 1) + + # Max number of entities filter. 
+ if len(data) == max_number: + break + + if do_sort_by_duration: + if index_by_file_id: + logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + else: + data.sort(key=lambda entity: entity.duration) + + logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600) + logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600) + + super().__init__(data) + + def _get_len(self, field_type, data, duration_data): + if field_type == "SPEECH": + return duration_data * 76 + elif field_type == "TEXT": + if self.use_phoneme_tokenizer: + # Approx len is number of characters + return len(data) + else: + return len(data.split(' ')) + 3 + elif field_type == "TOKENS": + return len(data) + 3 + + class ASRAudioText(AudioText): """`AudioText` collector from asr structured json files.""" diff --git a/nemo/collections/common/parts/preprocessing/manifest.py b/nemo/collections/common/parts/preprocessing/manifest.py index 1d49bd7c7019..e2ad08bd04c2 100644 --- a/nemo/collections/common/parts/preprocessing/manifest.py +++ b/nemo/collections/common/parts/preprocessing/manifest.py @@ -110,6 +110,8 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]: item['audio_file'] = item.pop('audio_filename') elif 'audio_filepath' in item: item['audio_file'] = item.pop('audio_filepath') + elif 'context' in item: + item['audio_file'] = item['context'] # Video File if 'video_filename' in item: @@ -132,7 +134,9 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]: item['video_file'] = get_full_path(audio_file=item['video_file'], manifest_file=manifest_file) # Duration. - if 'duration' not in item: + if 'context_duration' in item and 'duration' not in item: + item['duration'] = item['context_duration'] + elif 'duration' not in item: raise ValueError( f"Manifest file {manifest_file} has invalid json line structure: {line} without proper duration key." ) @@ -184,6 +188,15 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]: orig_sr=item.get('orig_sample_rate', None), token_labels=item.get('token_labels', None), lang=item.get('lang', None), + context=item.get('context', None), + context_type=item.get('context_type', None), + context_duration=item.get('context_duration', None), + answer=item.get('answer', None), + answer_type=item.get('answer_type', None), + answer_duration=item.get('answer_duration', None), + question=item.get('question', None), + question_type=item.get('question_type', None), + task=item.get('task', None), ) return item @@ -247,7 +260,7 @@ def get_full_path( if ( (len(audio_file) < audio_file_len_limit) and not os.path.isabs(audio_file) - and not os.path.isfile(audio_file) + # and not os.path.isfile(audio_file) # Commented out because it slows down dataloading ): # If audio_file is not available and the path is not absolute, the full path is assumed # to be relative to the manifest file parent directory or data directory. 
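Editor's note: the filtering in InstructionTuningAudioText above budgets each example with a rough token-length heuristic. A condensed, illustrative sketch of that logic (multipliers and caps copied from the code above; the function names are mine, not part of the patch):

    def approx_len(field_type, data, duration, use_phoneme_tokenizer=False):
        # SPEECH is approximated as duration * 76; TEXT as word count (+3),
        # or character count when a phoneme tokenizer is used; TOKENS as len(data) + 3.
        if field_type == "SPEECH":
            return duration * 76
        if field_type == "TEXT":
            return len(data) if use_phoneme_tokenizer else len(data.split(' ')) + 3
        if field_type == "TOKENS":
            return len(data) + 3

    def keep_example(ctx_len, q_len, ans_len, max_seq_length=2048, decoder_only_model=False):
        ctx_len = min(ctx_len * 0.3, 400)  # context is down-weighted and capped, as in the collector
        if decoder_only_model:
            return ctx_len + q_len + ans_len < max_seq_length
        return ctx_len + q_len < max_seq_length and ans_len < max_seq_length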
diff --git a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py index a8ea949019c1..56a4b04dfe0f 100644 --- a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py +++ b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py @@ -25,7 +25,7 @@ from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.utils import logging -__all__ = ['SentencePieceTokenizer', 'create_spt_model'] +__all__ = ['SentencePieceTokenizer', 'SentencePieceSpeechLLMTTSTokenizer', 'create_spt_model'] class SentencePieceTokenizer(TokenizerSpec, ChatTemplateMixin): @@ -315,6 +315,14 @@ def vocab(self): return main_vocab + special_tokens +class SentencePieceSpeechLLMTTSTokenizer(SentencePieceTokenizer): + def add_phone_tokens_to_special_tokens(self): + for i, word in enumerate(self.vocab): + if word.startswith("p{"): + self.special_token_to_id[word] = i + self.id_to_special_token[i] = word + + def create_spt_model( data_file: str, vocab_size: int, diff --git a/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py index bbd14f47a651..cb43408478e4 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py @@ -13,6 +13,7 @@ # limitations under the License. import torch +import omegaconf from nemo.collections.nlp.modules.common import VirtualPromptSource from nemo.core import Dataset @@ -71,7 +72,52 @@ def __init__( elif isinstance(datasets[0], str): for path in datasets: dataset = open(path, 'r', encoding='utf-8') - self.load_data(dataset) + dataset_examples = self.load_data(dataset) + self.examples.extend(dataset_examples) + elif (isinstance(datasets[0], omegaconf.ListConfig) or isinstance(datasets[0], list)): + # Dataset is a list of tuples with the first element being the probability of sampling from the dataset + # This code repeates the smaller datasets to approximately match the target probabilities + total_examples = 0 + dataset_lengths = [] + target_probs = [] + datasets_examples_list = [] + for prob_and_path in datasets: + prob = prob_and_path[0] + path = prob_and_path[1] + dataset = open(path, 'r', encoding='utf-8') + dataset_examples = self.load_data(dataset) + datasets_examples_list.append(dataset_examples) + dataset_lengths.append(len(dataset_examples)) + total_examples += len(dataset_examples) + target_probs.append(prob) + + # Normalize the target probs + target_probs = [prob / sum(target_probs) for prob in target_probs] + current_probs = [dataset_lengths[i] / total_examples for i in range(len(dataset_lengths))] + + # Increase number of examples needed without reducing the larger datasets with low target probs + new_total_examples = total_examples + for dataset_idx in range(len(datasets)): + if target_probs[dataset_idx] < current_probs[dataset_idx]: + target_total_examples = int(dataset_lengths[dataset_idx] / target_probs[dataset_idx]) + new_total_examples = max(new_total_examples, target_total_examples) + + final_total_examples = 0 + final_dataset_lengths = [] + for dataset_idx in range(len(datasets)): + num_samples_required = int(new_total_examples * target_probs[dataset_idx]) + num_repeat = max(int(round(num_samples_required // dataset_lengths[dataset_idx])), 1) # At least 1 repeat + logging.info("dataset idx {}, num_repeat {}".format(dataset_idx, 
num_repeat)) + dataset_examples_repeated = datasets_examples_list[dataset_idx] * num_repeat + final_dataset_lengths.append(len(dataset_examples_repeated)) + final_total_examples += len(dataset_examples_repeated) + self.examples.extend(dataset_examples_repeated) + + final_probs = [final_dataset_lengths[i] / final_total_examples for i in range(len(final_dataset_lengths))] + logging.info("Target probs: {}".format(target_probs)) + logging.info("Final probs: {}".format(final_probs)) + logging.info("Initial total examples: {}".format(total_examples)) + logging.info("Final total examples: {}".format(final_total_examples)) else: raise ValueError("Datasets must be a list of dicts or a list of filepath strings") diff --git a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py index c1b4e3023e42..da1005ebbdb8 100644 --- a/nemo/collections/nlp/modules/common/megatron/attention.py +++ b/nemo/collections/nlp/modules/common/megatron/attention.py @@ -39,6 +39,7 @@ attention_mask_func, ) from nemo.core import adapter_mixins +from nemo.utils import logging try: from apex.transformer.enums import AttnMaskType, AttnType @@ -380,6 +381,7 @@ def forward( rotary_pos_emb=None, # rotary positional embedding relative_position_bias=None, checkpoint_core_attention=False, + return_scores=False, ): # hidden_states: [sq, b, h] @@ -387,6 +389,7 @@ def forward( # Pre-allocate memory for key-values for inference. # ================================================= if set_inference_key_value_memory: + logging.debug(f"Initializing KV Cache.") assert inference_max_sequence_len and inference_max_sequence_len > 0 self.inference_key_memory = self._allocate_memory( inference_max_sequence_len, hidden_states.size(1), hidden_states.dtype, hidden_states.device @@ -398,7 +401,9 @@ def forward( # Some consistency check. if inference_max_sequence_len: - assert self.inference_current_sequence_len < self.inference_key_memory.size(0) + # Added equals to as inference key_memory size refers to cross-attention key size + # which is already equal to the current "sequence length" + assert self.inference_current_sequence_len <= self.inference_key_memory.size(0) assert inference_max_sequence_len == self.inference_key_memory.size(0) # This is added for safety. 
In case inference_max_sequence_len # is not provided, make sure there is no potential memory left @@ -412,6 +417,7 @@ def forward( # ===================== if self.attention_type == AttnType.self_attn: + logging.debug(f"Start Self-Attention!") # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer, _ = self.query_key_value(hidden_states) if self.is_adapter_available(): @@ -433,8 +439,20 @@ def forward( (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim( mixed_x_layer, 3, contiguous_split_chunks=True ) - else: + else: # Else in cross_attention # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] + if ( + inference_max_sequence_len is None + ) or self.inference_current_sequence_len < inference_max_sequence_len: + # If we are in traning and inference_max_sequence_len is None + # Or we haven't cached the key and value part of cross attention in the decoder on step 0, + # Do the caching + mixed_kv_layer, _ = self.key_value(encoder_output) + if self.is_adapter_available(): + lora_kv_adapter = self.get_adapter_module(AdapterName.LORA_KV_ADAPTER) + if lora_kv_adapter and self.adapter_cfg[AdapterName.LORA_KV_ADAPTER]['enabled']: + lora_mixed_kv_layer = lora_kv_adapter(encoder_output) + mixed_kv_layer = mixed_kv_layer + lora_mixed_kv_layer mixed_kv_layer, _ = self.key_value(encoder_output) if self.is_adapter_available(): lora_kv_adapter = self.get_adapter_module(AdapterName.LORA_KV_ADAPTER) @@ -442,19 +460,25 @@ def forward( lora_mixed_kv_layer = lora_kv_adapter(encoder_output) mixed_kv_layer = mixed_kv_layer + lora_mixed_kv_layer - # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] - new_tensor_shape = mixed_kv_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 2 * self.hidden_size_per_attention_head, - ) - if self.megatron_legacy: - mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True) - mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head, + ) + if self.megatron_legacy: + mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True) + mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) - # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] - (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim( - mixed_kv_layer, 2, contiguous_split_chunks=True - ) + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim( + mixed_kv_layer, 2, contiguous_split_chunks=True + ) + else: + # else if we are in inference and have already cached key, value, can just read cache + key_layer = self.inference_key_memory[: self.inference_current_sequence_len, ...] + value_layer = self.inference_value_memory[: self.inference_current_sequence_len, ...] 
+ if attention_mask is not None: + attention_mask = attention_mask[..., -1, :].unsqueeze(-2) # Attention head [sq, b, h] --> [sq, b, hp] query_layer, _ = self.query(hidden_states) @@ -490,7 +514,10 @@ def forward( if rotary_pos_emb is not None: rotary_pos_emb = rotary_pos_emb if isinstance(rotary_pos_emb, tuple) else ((rotary_pos_emb,) * 2) - if inference_max_sequence_len: + # If we are in cross attention (inference_current_sequence_len == inference_max_sequence_len == inference_key_memory.size(0)) + # We only need to cache this once + if inference_max_sequence_len and self.inference_current_sequence_len < inference_max_sequence_len: + logging.debug(f"inference_current_sequence_len={self.inference_current_sequence_len} | key_layer.shape={key_layer.shape} | inference_key_memory={self.inference_key_memory.size()} | inference_value_memory={self.inference_value_memory.size()}") # Adjust the range variables. start = self.inference_current_sequence_len self.inference_current_sequence_len += key_layer.size(0) @@ -501,7 +528,7 @@ def forward( key_layer = self.inference_key_memory[:end, ...] value_layer = self.inference_value_memory[:end, ...] # Adjust attention mask - if attention_mask is not None: + if attention_mask is not None and self.attention_type == AttnType.self_attn: attention_mask = attention_mask[..., start:end, :end] # adjust the key rotary positional embedding if rotary_pos_emb is not None: @@ -569,7 +596,10 @@ def forward( relative_position_bias=relative_position_bias, headscale_tensor=self.head_scale_tensor if self.headscale else None, inference_mode=inference_max_sequence_len is not None and query_layer.shape[0] == 1, + return_scores=return_scores, ) + if return_scores: + context_layer, attention_probs = context_layer # ================= # Output. [sq, b, h] @@ -585,6 +615,9 @@ def forward( if get_key_value: output = [output, present] + if return_scores: + output = [output, attention_probs] + return output, bias @@ -857,6 +890,7 @@ def forward( relative_position_bias=None, headscale_tensor=None, inference_mode=None, + return_scores=None, ): b, np, sq, sk, hn = ( query_layer.size(1), @@ -865,6 +899,7 @@ def forward( key_layer.size(0), query_layer.size(3), ) + logging.debug(f"query_layer.shape={query_layer.size()}\tkey_layer.shape={key_layer.size()}") # ================================================== # Update attention mask for inference. 
[b, np, sq, sk] @@ -914,9 +949,36 @@ def forward( # relative_position_bias [b, np, sq, sk] # context_layer [b, np, sq, hn] # ================================================== - context_layer = self.attn_fn( - query_layer, key_layer, value_layer, attention_mask, relative_position_bias, inference_mode - ) + if not return_scores: + logging.debug( + f"not returning scores: attn_type={self.attention_type} | attn_fn={self.attn_fn} | return_scores={return_scores}" + ) + context_layer = self.attn_fn( + query_layer, key_layer, value_layer, attention_mask, relative_position_bias, inference_mode, + ) + else: + # SpeechLLM TTS modifications + if return_scores or relative_position_bias is not None: + logging.debug( + f"torch a: return_scores: {return_scores}, relative_position_bias is not None: {relative_position_bias is not None}" + ) + context_layer = self.torch_attention_with_prior( + query_layer, + key_layer, + value_layer, + attention_mask, + relative_position_bias, + inference_mode, + return_scores=return_scores, + ) + context_layer, attention_probs = context_layer + else: + logging.debug( + f"attn_fn: {self.attn_fn}, return_scores: {return_scores}, relative_position_bias is not None: {relative_position_bias is not None}" + ) + context_layer = self.attn_fn( + query_layer, key_layer, value_layer, attention_mask, relative_position_bias, inference_mode, + ) if headscale_tensor is not None: context_layer = context_layer * headscale_tensor @@ -928,7 +990,10 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) - return context_layer + if return_scores: + return context_layer, attention_probs + else: + return context_layer def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias, inference_mode): sq, b, np, hn = query_layer.shape @@ -966,6 +1031,9 @@ def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, a attention_scores += attention_bias attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) + logging.debug(f"attention_type={self.attention_type}") + logging.debug(f"attention_scores.shape={attention_scores.shape}") + logging.debug(f"attention_mask.shape={attention_mask.shape}") # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
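Editor's note: the torch_attention_with_prior path introduced in the next hunk injects the cross-attention prior in log space. A minimal sketch of the resulting attention weights, assuming an illustrative helper (the real code also applies the attention mask and dropout via scale_mask_softmax):

    import torch

    def attention_probs_with_prior(scores, prior, eps=1e-8):
        # scores: [b, heads, dec_len, enc_len] = Q K^T / sqrt(d)
        # prior:  same shape, values in [0, 1]
        biased = torch.log_softmax(scores, dim=-1) + torch.log(prior + eps)
        # softmax(log_softmax(scores) + log(prior)) equals softmax(scores) * prior, renormalized,
        # i.e. the prior acts as a soft mask on where each decoder step may attend.
        return torch.softmax(biased, dim=-1)

During training, the SpeechLLM module later in this patch anneals the prior toward all-ones between attn_prior_scaledown_start_step and attn_prior_end_step, then drops it entirely.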
@@ -986,6 +1054,69 @@ def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, a return context_layer + def torch_attention_with_prior( + self, query_layer, key_layer, value_layer, attention_mask, attention_bias, inference_mode, return_scores=False + ): + sq, b, np, hn = query_layer.shape + sk = key_layer.shape[0] + + if self.multi_query_attention: + query_layer = rearrange(query_layer, 'sq b np hn -> b (np sq) hn') + key_layer = rearrange(key_layer, 'sk b 1 hn -> b hn sk') + value_layer = rearrange(value_layer, 'sv b np hn -> (b np) sv hn') + else: + query_layer = rearrange(query_layer, 'sq b np hn -> (b np) sq hn') + key_layer = rearrange(key_layer, 'sk b np hn -> (b np) hn sk') + value_layer = rearrange(value_layer, 'sv b np hn -> (b np) sv hn') + + matmul_input_buffer = torch.empty( + query_layer.shape[0], + query_layer.shape[1], + key_layer.shape[2], + dtype=query_layer.dtype, + device=query_layer.device, + ) + + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer, + key_layer, + beta=0.0, + alpha=(1.0 / self.norm_factor) if self.normalize_attention_scores else 1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(b, np, sq, sk) + + if attention_bias is not None: + # attention_bias is not None only for cross attention layers right now in T5 + attention_scores = torch.log_softmax(attention_scores, dim=-1) + attention_bias + + _attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + + if not self.sequence_parallel: + with tensor_parallel.random.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(_attention_probs) + else: + attention_probs = self.attention_dropout(_attention_probs) + + # change view [b * np, sq, sk] + attention_probs = rearrange(attention_probs, 'b np sq sk -> (b np) sq sk') + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = rearrange(context_layer, '(b np) sq hn -> b np sq hn', np=np) + + if return_scores: + # return context_layer, _attention_probs + return context_layer, attention_scores + else: + return context_layer + def flash_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias, inference_mode): query_layer = rearrange(query_layer, 'sq b np hn -> b sq np hn') key_layer = rearrange(key_layer, 'sk b np hn -> b sk np hn') diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py b/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py index 712ce10b81b5..e70b26e5cb08 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py @@ -14,6 +14,7 @@ """Transformer based language model.""" from ast import Mod +from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType from nemo.collections.nlp.modules.common.megatron.megatron_transformer_decoder import MegatronTransformerDecoderModule from nemo.collections.nlp.modules.common.megatron.retrieval_transformer import ( MegatronRetrievalTransformerDecoderModule, @@ -87,7 +88,7 @@ def get_decoder_model( transformer_block_type="pre_ln", hidden_steps=-1, parent_model_type=ModelType.encoder_or_decoder, - layer_type=None, + layer_type=LayerType.decoder, chunk_size=64, layer_number_offset=0, # this is use only for 
attention norm_factor scaling megatron_legacy=False, @@ -158,6 +159,7 @@ def get_decoder_model( moe_dropout=moe_dropout, position_embedding_type=position_embedding_type, use_flash_attention=use_flash_attention, + layer_type=layer_type, ) elif arch == "retro": decoder = MegatronRetrievalTransformerDecoderModule( diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py index c4192dacb45a..cc7072be0c40 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py @@ -85,6 +85,8 @@ def __init__( encoder_attn_mask_type = AttnMaskType.padding elif hasattr(encoder.model, 'self_attn_mask_type'): encoder_attn_mask_type = encoder.model.self_attn_mask_type + elif isinstance(encoder.model, torch.nn.ModuleList) and hasattr(encoder.model[0], 'self_attn_mask_type'): + encoder_attn_mask_type = encoder.model[0].self_attn_mask_type else: raise AttributeError( "Could not find an attribute for encoder self_attn_mask_type, make sure it is set when instatiating the encoder or pass it to the constructor of this class." @@ -157,6 +159,11 @@ def decode( dec_get_key_value=False, dec_self_attention_relative_position_bias=None, dec_cross_attention_relative_position_bias=None, + return_all_crossattention_probs=False, + set_inference_key_value_memory=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + enc_output_to_layers=None ): if self.decoder is None: raise ValueError(f"Cannot call .decode(...) when self.decoder is None.") @@ -170,6 +177,11 @@ def decode( enc_attn_mask=enc_attn_mask, dec_self_attention_relative_position_bias=dec_self_attention_relative_position_bias, dec_cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias, + return_all_crossattention_probs=return_all_crossattention_probs, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + enc_output_to_layers=enc_output_to_layers ) return dec_output @@ -191,6 +203,11 @@ def forward( dec_self_attention_relative_position_bias=None, dec_cross_attention_relative_position_bias=None, batch_data=None, + return_all_crossattention_probs=False, + set_inference_key_value_memory=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + enc_output_to_layers=None ): # encoder if enc_output is None: @@ -207,7 +224,10 @@ def forward( assert self.encoder_hidden_state is not None enc_output = self.encoder_hidden_state else: - enc_attn_mask = enc_output_attn_mask.to(enc_attn_mask) + if isinstance(enc_output_attn_mask, list): + enc_attn_mask = [mask.to(enc_attn_mask[midx]) for midx, mask in enumerate(enc_output_attn_mask)] + else: + enc_attn_mask = enc_output_attn_mask.to(enc_attn_mask) if self.decoder is None or output_enc_hidden_only: return enc_output @@ -225,6 +245,11 @@ def forward( dec_get_key_value=dec_get_key_value, dec_self_attention_relative_position_bias=dec_self_attention_relative_position_bias, dec_cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias, + return_all_crossattention_probs=return_all_crossattention_probs, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + enc_output_to_layers=enc_output_to_layers ) # if 
self.hiddens_module is not None enc_output is a dict, else it is a torch.tensor @@ -246,7 +271,10 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars= def load_state_dict(self, state_dict, strict=True): """Customized load.""" - self.encoder.load_state_dict(state_dict[self._encoder_key], strict=strict) - self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict) - if self.hiddens_module is not None: - self.hiddens_module.load_state_dict(state_dict[self._hiddens_module], strict=strict) + try: + self.encoder.load_state_dict(state_dict[self._encoder_key], strict=strict) + self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict) + if self.hiddens_module is not None: + self.hiddens_module.load_state_dict(state_dict[self._hiddens_module], strict=strict) + except KeyError as e: + super().load_state_dict(state_dict, strict=strict) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py index 601eb320e8fc..e0e14e024629 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py @@ -14,7 +14,7 @@ """Transformer based language model.""" from nemo.collections.nlp.modules.common.megatron.megatron_perceiver_encoders import MegatronPerceiverEncoderModule -from nemo.collections.nlp.modules.common.megatron.megatron_transformer_encoder import MegatronTransformerEncoderModule +from nemo.collections.nlp.modules.common.megatron.megatron_transformer_encoder import MegatronTransformerEncoderModule, MultiMegatronTransformerEncoderModule from nemo.collections.nlp.modules.common.megatron.retrieval_transformer import ( MegatronRetrievalTransformerEncoderModule, ) @@ -108,6 +108,7 @@ def get_encoder_model( version=1, # model version position_embedding_type='learned_absolute', use_flash_attention=False, + n_transformers=1, ): """Build language model and return along with the key to save.""" @@ -167,6 +168,51 @@ def get_encoder_model( position_embedding_type=position_embedding_type, use_flash_attention=use_flash_attention, ) + elif arch == "multi_transformer": + encoder = MultiMegatronTransformerEncoderModule( + config=config, + n_transformers=n_transformers, + init_method=init_method, + output_layer_init_method=scaled_init_method, + hidden_size=hidden_size, + num_layers=num_layers, + num_attention_heads=num_attention_heads, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + kv_channels=kv_channels, + ffn_hidden_size=ffn_hidden_size, + encoder_attn_mask_type=encoder_attn_mask_type, + pre_process=pre_process, + post_process=post_process, + hidden_dropout=hidden_dropout, + attention_dropout=attention_dropout, + ffn_dropout=ffn_dropout, + precision=precision, + fp32_residual_connection=fp32_residual_connection, + activations_checkpoint_method=activations_checkpoint_method, + activations_checkpoint_num_layers=activations_checkpoint_num_layers, + activations_checkpoint_granularity=activations_checkpoint_granularity, + layernorm_epsilon=layernorm_epsilon, + bias_activation_fusion=bias_activation_fusion, + bias_dropout_add_fusion=bias_dropout_add_fusion, + masked_softmax_fusion=masked_softmax_fusion, + persist_layer_norm=persist_layer_norm, + openai_gelu=openai_gelu, + onnx_safe=onnx_safe, + activation=activation, + bias=bias, + normalization=normalization, + transformer_block_type=transformer_block_type, + headscale=headscale, + parent_model_type=parent_model_type, + 
megatron_legacy=megatron_legacy, + normalize_attention_scores=normalize_attention_scores, + num_moe_experts=num_moe_experts, + moe_frequency=moe_frequency, + moe_dropout=moe_dropout, + position_embedding_type=position_embedding_type, + use_flash_attention=use_flash_attention, + ) + elif arch == "retro": encoder = MegatronRetrievalTransformerEncoderModule( config=config, diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py index 4a05a08820e7..3fdff2a7068c 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py @@ -97,6 +97,7 @@ def __init__( moe_dropout=0.0, position_embedding_type='learned_absolute', use_flash_attention=False, + layer_type=LayerType.decoder, ): super(MegatronTransformerDecoderModule, self).__init__(config=config) @@ -121,7 +122,7 @@ def __init__( # Transformer. self.model = ParallelTransformer( config=config, - layer_type=LayerType.decoder, + layer_type=layer_type, init_method=self.init_method, output_layer_init_method=self.output_layer_init_method, num_layers=self.num_layers, @@ -178,14 +179,30 @@ def forward( get_key_value=False, dec_self_attention_relative_position_bias=None, dec_cross_attention_relative_position_bias=None, + return_all_crossattention_probs=False, + set_inference_key_value_memory=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + enc_output_to_layers=None, ): # convert to Megatron mask dec_attn_mask_3d = build_attention_mask_3d( source_mask=dec_attn_mask, target_mask=dec_attn_mask, attn_mask_type=self.model_attn_mask_type, ) - enc_dec_attn_mask_3d = build_attention_mask_3d( - source_mask=dec_attn_mask, target_mask=enc_attn_mask, attn_mask_type=AttnMaskType.padding, - ) + + if isinstance(enc_output, list): + assert len(enc_output) == len(enc_attn_mask) + enc_dec_attn_mask_3d = [] + for i in range(len(enc_output)): + enc_dec_attn_mask_3d.append( + attn_mask_postprocess(build_attention_mask_3d( + source_mask=dec_attn_mask, target_mask=enc_attn_mask[i], attn_mask_type=AttnMaskType.padding, + )) + ) + else: + enc_dec_attn_mask_3d = attn_mask_postprocess(build_attention_mask_3d( + source_mask=dec_attn_mask, target_mask=enc_attn_mask, attn_mask_type=AttnMaskType.padding, + )) # transformer decoder dec_output = self.model( @@ -194,9 +211,14 @@ def forward( layer_past=layer_past, get_key_value=get_key_value, encoder_output=enc_output, - enc_dec_attn_mask=attn_mask_postprocess(enc_dec_attn_mask_3d), + enc_dec_attn_mask=enc_dec_attn_mask_3d, self_attention_relative_position_bias=dec_self_attention_relative_position_bias, cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias, + return_all_crossattention_probs=return_all_crossattention_probs, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + enc_output_to_layers=enc_output_to_layers, ) return dec_output diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py index 7a41e1300066..67c4d071c279 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py @@ -23,6 +23,7 @@ build_attention_mask_3d, ) 
from nemo.core.classes.exportable import Exportable +import torch try: from apex.transformer.enums import AttnMaskType, ModelType @@ -173,6 +174,7 @@ def forward( layer_past=None, get_key_value=False, enc_self_attention_relative_position_bias=None, + set_inference_key_value_memory=False, ): # convert to Megatron mask if self.use_flash_attention: @@ -192,6 +194,7 @@ def forward( get_key_value=get_key_value, self_attention_relative_position_bias=enc_self_attention_relative_position_bias, cross_attention_relative_position_bias=None, + set_inference_key_value_memory=set_inference_key_value_memory, ) return enc_output @@ -231,3 +234,220 @@ def load_state_dict(self, state_dict, strict=True): state_dict_ = state_dict_self_attention self.model.load_state_dict(state_dict_, strict=strict) + + +class MultiMegatronTransformerEncoderModule(MegatronModule, Exportable, MegatronEncoderModule): + """Transformer encoder model.""" + + def __init__( + self, + config: ModelParallelConfig, + n_transformers, + init_method, + output_layer_init_method, + hidden_size, + ffn_hidden_size, + num_layers, + num_attention_heads, + apply_query_key_layer_scaling=True, + kv_channels=None, + pre_process=True, + post_process=True, + encoder_attn_mask_type=AttnMaskType.padding, + hidden_dropout=0.1, + attention_dropout=0.1, + ffn_dropout=0.0, + precision=16, + fp32_residual_connection=False, + activations_checkpoint_method=None, + activations_checkpoint_num_layers=1, + activations_checkpoint_granularity=None, + layernorm_epsilon=1e-5, + bias_activation_fusion=True, + bias_dropout_add_fusion=True, + masked_softmax_fusion=True, + persist_layer_norm=False, + openai_gelu=False, + onnx_safe=False, + activation='gelu', + bias=True, + normalization='layernorm', + transformer_block_type='pre_ln', + headscale=False, + parent_model_type=ModelType.encoder_or_decoder, + megatron_legacy=False, + normalize_attention_scores=True, + num_moe_experts=1, + moe_frequency=1, + moe_dropout=0.0, + position_embedding_type='learned_absolute', + use_flash_attention=False, + ): + super(MultiMegatronTransformerEncoderModule, self).__init__(config=config) + + self.pre_process = pre_process + self.post_process = post_process + self.hidden_size = hidden_size + self.num_layers = num_layers + self.init_method = init_method + self.model_attn_mask_type = encoder_attn_mask_type + self.hidden_dropout = hidden_dropout + self.output_layer_init_method = output_layer_init_method + self.parent_model_type = parent_model_type + self.normalization = normalization + self.transformer_block_type = transformer_block_type + self.use_flash_attention = use_flash_attention + + if kv_channels is None: + + assert ( + hidden_size % num_attention_heads == 0 + ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None' + kv_channels = hidden_size // num_attention_heads + + # Transformer List + self.model = [] + for i in range(n_transformers): + transformer = ParallelTransformer( + config=config, + layer_type=LayerType.encoder, + init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, + num_layers=self.num_layers, + hidden_size=self.hidden_size, + num_attention_heads=num_attention_heads, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + kv_channels=kv_channels, + ffn_hidden_size=ffn_hidden_size, + self_attn_mask_type=self.model_attn_mask_type, + pre_process=self.pre_process, + post_process=self.post_process, + precision=precision, + fp32_residual_connection=fp32_residual_connection, + 
activations_checkpoint_method=activations_checkpoint_method, + activations_checkpoint_num_layers=activations_checkpoint_num_layers, + activations_checkpoint_granularity=activations_checkpoint_granularity, + layernorm_epsilon=layernorm_epsilon, + hidden_dropout=hidden_dropout, + attention_dropout=attention_dropout, + ffn_dropout=ffn_dropout, + bias_activation_fusion=bias_activation_fusion, + bias_dropout_add_fusion=bias_dropout_add_fusion, + masked_softmax_fusion=masked_softmax_fusion, + persist_layer_norm=persist_layer_norm, + openai_gelu=openai_gelu, + onnx_safe=onnx_safe, + activation=activation, + bias=bias, + normalization=normalization, + transformer_block_type=transformer_block_type, + headscale=headscale, + model_type=parent_model_type, + megatron_legacy=megatron_legacy, + normalize_attention_scores=normalize_attention_scores, + num_moe_experts=num_moe_experts, + moe_frequency=moe_frequency, + moe_dropout=moe_dropout, + position_embedding_type=position_embedding_type, + use_flash_attention=use_flash_attention, + ) + self.model.append(transformer) + + self.model = torch.nn.ModuleList(self.model) + + self._model_key = 'model' + + def set_input_tensor(self, input_tensor): + """ See megatron.model.transformer.set_input_tensor()""" + for mi in range(len(self.model)): + self.model[mi].set_input_tensor(input_tensor) + + # def set_input_tensor(self, input_tensor): + # """ See megatron.model.transformer.set_input_tensor()""" + # import ipdb; ipdb.set_trace() + # assert isinstance(input_tensor, list) + # assert len(input_tensor) == len(self.model) + # for _input_tensor in input_tensor: + # self.model.set_input_tensor(_input_tensor) + + def forward( + self, + enc_input, + enc_attn_mask, + layer_past=None, + get_key_value=False, + enc_self_attention_relative_position_bias=None, + set_inference_key_value_memory=False, + ): + + assert isinstance(enc_input, list) + assert len(enc_input) == len(self.model) + assert isinstance(enc_attn_mask, list) + assert len(enc_attn_mask) == len(self.model) + assert isinstance(enc_self_attention_relative_position_bias, list) + # convert to Megatron mask + enc_outputs = [] + for encoder_number in range(len(self.model)): + enc_input_ = enc_input[encoder_number] + enc_attn_mask_ = enc_attn_mask[encoder_number] + enc_self_attention_relative_position_bias_ = enc_self_attention_relative_position_bias[encoder_number] + + if self.use_flash_attention: + enc_attn_mask_3d = enc_attn_mask_ < 0.5 + else: + enc_attn_mask_3d = attn_mask_postprocess( + build_attention_mask_3d( + source_mask=enc_attn_mask_, target_mask=enc_attn_mask_, attn_mask_type=self.model_attn_mask_type, + ) + ) + + # transformer encoder + enc_output = self.model[encoder_number]( + enc_input_, + enc_attn_mask_3d, + layer_past=layer_past, + get_key_value=get_key_value, + self_attention_relative_position_bias=enc_self_attention_relative_position_bias_, + cross_attention_relative_position_bias=None, + set_inference_key_value_memory=set_inference_key_value_memory, + ) + + enc_outputs.append(enc_output) + + return enc_outputs + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + + state_dict_[self._model_key] = self.model.state_dict_for_save_checkpoint(destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Encoder. + if self._model_key in state_dict: + state_dict_ = state_dict[self._model_key] + # for backward compatibility. 
+ elif 'transformer' in state_dict: + state_dict_ = state_dict['transformer'] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'transformer.' in key: + state_dict_[key.split('transformer.')[1]] = state_dict[key] + + # for backward compatibility. + state_dict_self_attention = {} + for key in state_dict_.keys(): + if '.attention.' in key: + state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = state_dict_[key] + else: + state_dict_self_attention[key] = state_dict_[key] + state_dict_ = state_dict_self_attention + + self.model.load_state_dict(state_dict_, strict=strict) diff --git a/nemo/collections/nlp/modules/common/megatron/module.py b/nemo/collections/nlp/modules/common/megatron/module.py index ccd485427c3c..c311f63e15de 100644 --- a/nemo/collections/nlp/modules/common/megatron/module.py +++ b/nemo/collections/nlp/modules/common/megatron/module.py @@ -113,7 +113,7 @@ def decoder_cross_attention_relative_position_embeddings_weight(self): def initialize_word_embeddings(self, init_method, vocab_size, hidden_size): if not self.share_token_embeddings: - raise Exception('initialize_word_embeddings() was called but ' 'share_token_embeddings is false') + raise Exception('initialize_word_embeddings() was called but share_token_embeddings is false') # This function just initializes the word embeddings in the final stage # when we are using pipeline parallelism. If we aren't using pipeline diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py index b7b377940eb4..b4875a2ffa41 100644 --- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py @@ -42,6 +42,7 @@ ) from nemo.collections.nlp.modules.common.megatron.vocab_parallel_cross_entropy import vocab_parallel_cross_entropy from nemo.core.classes.mixins import adapter_mixins +from nemo.utils import logging try: from apex.transformer.enums import AttnMaskType, ModelType @@ -252,6 +253,7 @@ def __init__( moe_dropout=encoder_cfg.get('moe_dropout', 0.0), position_embedding_type=encoder_cfg.get('position_embedding_type', 'learned_absolute'), use_flash_attention=encoder_cfg.get('use_flash_attention', False), + n_transformers=encoder_cfg.get('n_transformers', 1), ) if add_decoder: @@ -388,6 +390,7 @@ def __init__( moe_dropout=decoder_cfg.get('moe_dropout', 0.0), position_embedding_type=decoder_cfg.get('position_embedding_type', 'learned_absolute'), use_flash_attention=decoder_cfg.get('use_flash_attention', False), + layer_type=decoder_cfg.get('layer_type', LayerType.decoder), ) hiddens_module = get_hiddens_module(hiddens_cfg, model_parallel_cfg=config) @@ -410,6 +413,7 @@ def __init__( if add_decoder and post_process: if share_decoder_tokens_head_embeddings: + # parallel_output is True if TP > 1 (3b model) self.tokens_head = MegatronTokenLevelHead( self.word_embeddings_weight().size(0), parallel_output, bias=tokens_head_bias ) @@ -708,8 +712,433 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars= def load_state_dict(self, state_dict, strict=True): """Customized load.""" - - self.encoder_embedding.encoder_embeddingload_state_dict(state_dict[self._encoder_embedding_key], strict=strict) + self.encoder_embedding.load_state_dict(state_dict[self._encoder_embedding_key], strict=strict) 
self.decoder_embedding.load_state_dict(state_dict[self._decoder_embedding_key], strict=strict) self.enc_dec_model.load_state_dict(state_dict[self._enc_dec_model_key], strict=strict) self.tokens_head.load_state_dict(state_dict[self._tokens_head_key], strict=strict) + + +class MegatronTokenLevelEncoderDecoderSpeechLLMModule(MegatronTokenLevelEncoderDecoderModule): + def __init__(self, *args, **kwargs): + super(MegatronTokenLevelEncoderDecoderSpeechLLMModule, self).__init__(*args, **kwargs) + # Overridden in MegatronT5SpeechLMModel constructor + self.seq_pattern = "parallel" + self.speech_head_type = "token_level" + self.attn_prior_scaledown_start_step = 10000 + self.attn_prior_end_step = 11000 + self.use_alignment_loss = False + self.return_all_crossattention_probs = False + self.logging_step = False + self.num_cross_attention_heads = 12 # 12 for 220m T5, 16 for 11b T5 + self.enc_output_to_layers = None + + def get_decoder_embeddings(self, dec_input_ids, dec_position_ids, token_type_ids): + if dec_input_ids.dim() <= 2: + dec_input = self.decoder_embedding(dec_input_ids, dec_position_ids, token_type_ids=token_type_ids) + else: + dec_input = None + for i in range(dec_input_ids.size()[1]): + if i == 0: + # For the first channel (text + first layer of speech), use the decoder embedding layer + dec_input = self.decoder_embedding( + dec_input_ids[:, i, :], dec_position_ids, token_type_ids=token_type_ids + ) + else: + # For the rest of the channels (speech), use the speech embedding layer. No need for position, since already added in first layer. + current = self.speech_tokens_embeddings[i - 1](dec_input_ids[:, i, :]).permute(1, 0, 2) + # @pneekhara - Commenting the below because we always want to include all channels for speech. + # @pneekhara - include_channel_flag can become 0 when doing autoregressive inference and the first timestep is zeros + # For text inputs, only include 1st channel embeddings. Zero-out others. + # include_channel_flag = (torch.sum(dec_input_ids[:, i, :], dim=1) > 0).float() # [B] + # current = current * include_channel_flag.unsqueeze(0).unsqueeze(2) + dec_input = dec_input + current + + return dec_input + + def forward( + self, + enc_input_ids=None, + enc_attn_mask=None, + dec_input_ids=None, + dec_attn_mask=None, + token_type_ids=None, + labels=None, + batch_data=None, # additional data to be passed to hiddens module + enc_output=None, # Result of running the entire encoder + enc_output_attn_mask=None, + enc_input=None, # Result of running encoder embedding only + output_enc_hidden_only=False, + speech_mask=None, + cross_attention_prior=None, + text_limits=None, + global_step=None, + set_inference_key_value_memory=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + ): + """ + Return value is per token / per dimension (i.e., non collapsed loss value) + """ + ( + encoder_self_attention_relative_position_bias, + decoder_self_attention_relative_position_bias, + decoder_cross_attention_relative_position_bias, + ) = (None, None, None) + + if enc_input is not None and enc_output is not None: + raise ValueError( + """Both enc_input and enc_output are not None. + You should only be passing one of them. + enc_input is the result of the encoder embedding layer + enc_output is the result of running the entire transformer encoder.""" + ) + + # In order of precedence, we use enc_output, enc_input, and then enc_input_ids to determine the encoder sequence length. 
+ if enc_output is not None: + # If enc_output is provided in `batch_for_pipeline`, we need to transpose it from [B x S x H] -> [S x B x H]. + if isinstance(enc_output, list): + encoder_self_attention_relative_position_bias = [None for _ in enc_output] + enc_output = [x.transpose(0, 1) for x in enc_output] + enc_seq_length = [x.size(0) for x in enc_output] + else: + enc_output = enc_output.transpose(0, 1) + enc_seq_length = enc_output.size(0) + elif enc_input is not None: + # If enc_input is provided, we need to transpose it from [B x S x H] -> [S x B x H]. + if isinstance(enc_input, list): + encoder_self_attention_relative_position_bias = [None for _ in enc_input] + enc_input = [x.transpose(0, 1) for x in enc_input] + enc_seq_length = [x.size(0) for x in enc_input] + else: + enc_input = enc_input.transpose(0, 1) + enc_seq_length = enc_input.size(0) + # Only need to run encoder embedding and position ids if enc_input or enc_output is not provided. + elif enc_input_ids is not None: + assert False, "This should not be reached for speech models" + enc_seq_length = enc_input_ids.size(1) + if self.pre_process and self.add_encoder: + # We don't need position ids for RPE, because the embedding layer does not have position embeddings. + if self.encoder_relative_position_embedding is None: + enc_input_ids_p = enc_input_ids[:, 0, :] if enc_input_ids.dim() == 3 else enc_input_ids + enc_position_ids = build_position_ids(enc_input_ids_p) + else: + enc_position_ids = None + enc_input = self.encoder_embedding(enc_input_ids, enc_position_ids, token_type_ids=token_type_ids) + if self.is_adapter_available(): + _sq, _bs, _hs = enc_input.size() + ptuning_adapter = self.get_adapter_module(AdapterName.PTUNING_ADAPTER) + v = ptuning_adapter.virtual_tokens + if ( + ptuning_adapter and _sq >= v + ): # The sequence should be longer the v to insert virtual embeddings. + virtual_embeddings = ptuning_adapter(_bs) + enc_input = enc_input[ + v:, :, : + ] # the first v tokens are pads so that they can be swapped out with virtual embeddings. + enc_input = torch.concat([virtual_embeddings, enc_input], dim=0) + else: + enc_input = None + else: + assert False, "This should not be reached for speech models" + # This should only happen with PP > 1 for enc-dec prompt learning models + enc_seq_length = enc_attn_mask.size(1) + + if self.add_encoder and self.encoder_relative_position_embedding is not None: + assert False, "Not implemented for speech models yet." + encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding( + query_seq_length=enc_seq_length, key_seq_length=enc_seq_length, + ) + + if output_enc_hidden_only: + assert False, "Not implemented for speech models yet." + # When pipeline parallel > 1 we need to make sure encoder exist (will be missing in decoder) + # Speecht5 should not go here for inference + if enc_output is None and self.enc_dec_model.encoder is not None: + enc_output = self.enc_dec_model.encode( + enc_input=enc_input, + enc_attn_mask=enc_attn_mask, + enc_layer_past=None, + enc_get_key_value=False, + enc_self_attention_relative_position_bias=encoder_self_attention_relative_position_bias, + batch_data=batch_data, + ) + else: + enc_output = self.enc_dec_model.encoder_hidden_state + + return enc_output + else: + if enc_output_attn_mask is None: + enc_output_attn_mask = enc_attn_mask + + if self.pre_process and self.add_decoder: + # We don't need position ids for RPE, because the embedding layer does not have position embeddings. 
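+                # For speech targets dec_input_ids is [B, num_speech_codebooks, T]; position ids are
+                # built from the first codebook only, since all codebooks share the same time axis.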
+                if self.decoder_relative_position_embedding is None:
+                    dec_input_ids_p = dec_input_ids[:, 0, :] if dec_input_ids.dim() == 3 else dec_input_ids
+                    dec_position_ids = build_position_ids(dec_input_ids_p)
+                else:
+                    dec_position_ids = None
+                dec_input = self.get_decoder_embeddings(dec_input_ids, dec_position_ids, token_type_ids)
+                if not set_inference_key_value_memory and (decoder_max_sequence_len or encoder_max_sequence_len):
+                    # In inference:
+                    # On step 0, when set_inference_key_value_memory is True, we need all inputs in case
+                    # we are using decoder context.
+                    # On step >= 1, only the last input is needed.
+                    logging.debug("Clipping dec_input and only keeping the last input.")
+                    dec_input = dec_input[-1, :, :].unsqueeze(0)  # shape (1, b, embed_dim)
+            else:
+                # Note: This is when the decoder itself is split across PP ranks.
+                dec_input = None
+
+            if self.add_decoder and self.decoder_relative_position_embedding is not None:
+                assert False, "This should not be reached."
+                decoder_self_attention_relative_position_bias = self.decoder_relative_position_embedding(
+                    query_seq_length=dec_input_ids.size(1), key_seq_length=dec_input_ids.size(1)
+                )
+                if not self.decoder_cfg.relative_position_bias_self_attention_only:
+                    decoder_cross_attention_relative_position_bias = self.decoder_cross_attention_relative_position_embedding(
+                        query_seq_length=dec_input_ids.size(1), key_seq_length=enc_seq_length,
+                    )
+                else:
+                    decoder_cross_attention_relative_position_bias = None
+
+            return_all_crossattention_probs = self.return_all_crossattention_probs
+            single_encoder = False
+            if not isinstance(cross_attention_prior, list):
+                single_encoder = True
+                cross_attention_prior = [cross_attention_prior]
+
+            decoder_cross_attention_relative_position_bias = []
+            for _cross_attention_prior in cross_attention_prior:
+                _decoder_cross_attention_relative_position_bias = None
+                if _cross_attention_prior is not None:
+                    # cross_attention_prior shape [B, dec_len, enc_len]
+                    # Repeat it to make it [B, 12, dec_len, enc_len]
+                    attn_prior_end_step = self.attn_prior_end_step
+                    attn_prior_scaledown_start_step = self.attn_prior_scaledown_start_step
+                    num_attention_heads = self.num_cross_attention_heads
+                    assert attn_prior_scaledown_start_step <= attn_prior_end_step
+                    logging.debug(
+                        f"attn_prior_scaledown_start_step: {attn_prior_scaledown_start_step}, attn_prior_end_step: {attn_prior_end_step}"
+                    )
+                    if global_step >= attn_prior_end_step:
+                        _decoder_cross_attention_relative_position_bias = None
+                    elif global_step > attn_prior_scaledown_start_step and global_step < attn_prior_end_step:
+                        total_annealing_steps = attn_prior_end_step - attn_prior_scaledown_start_step
+                        curr_annealing_step = global_step - attn_prior_scaledown_start_step
+                        curr_cross_attention_prior = _cross_attention_prior + (
+                            (1.0 - _cross_attention_prior) * curr_annealing_step / total_annealing_steps
+                        )
+                        _decoder_cross_attention_relative_position_bias = curr_cross_attention_prior.unsqueeze(1).repeat(
+                            1, num_attention_heads, 1, 1
+                        )
+                        _decoder_cross_attention_relative_position_bias = torch.log(_decoder_cross_attention_relative_position_bias + 1e-8)
+                    else:
+                        _decoder_cross_attention_relative_position_bias = _cross_attention_prior.unsqueeze(1).repeat(
+                            1, num_attention_heads, 1, 1
+                        )
+                        _decoder_cross_attention_relative_position_bias = torch.log(_decoder_cross_attention_relative_position_bias + 1e-8)
+                decoder_cross_attention_relative_position_bias.append(_decoder_cross_attention_relative_position_bias)
+
+            return_all_crossattention_probs = 
return_all_crossattention_probs or self.logging_step + + if single_encoder: + decoder_cross_attention_relative_position_bias = decoder_cross_attention_relative_position_bias[0] + + output = self.enc_dec_model( + enc_input=enc_input, + enc_attn_mask=enc_attn_mask, + dec_input=dec_input, + dec_attn_mask=dec_attn_mask, + enc_layer_past=None, + enc_get_key_value=False, + enc_output=enc_output, + enc_output_attn_mask=enc_output_attn_mask, + dec_layer_past=None, + dec_get_key_value=False, + enc_self_attention_relative_position_bias=encoder_self_attention_relative_position_bias, + dec_self_attention_relative_position_bias=decoder_self_attention_relative_position_bias, + dec_cross_attention_relative_position_bias=decoder_cross_attention_relative_position_bias, + return_all_crossattention_probs=return_all_crossattention_probs, + batch_data=batch_data, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + enc_output_to_layers=self.enc_output_to_layers + ) + + alignment_loss = None + if self.post_process and self.add_decoder: + dec_output, enc_output = output # [s, b, h] + if return_all_crossattention_probs: + dec_output, attention_scores = dec_output + attention_probs = [torch.softmax(attention_score, dim=-1) for lidx, attention_score in enumerate(attention_scores) if lidx in self.alignment_decoder_layerids] + + if text_limits is not None and self.use_alignment_loss and hasattr(self, "forward_sum_loss"): + attention_scores_filtered = [ + attention_scores[lidx] for lidx in self.alignment_decoder_layerids + ] + attention_scores_combined = torch.cat(attention_scores_filtered, dim=1) + text_start_idx = text_limits[0, 0].item() + assert torch.all( + text_limits[:, 0] == text_start_idx + ) # all texts should start at the same index + end_offset = self.alignment_text_end_offset + # align_every_n_head: eg if set to 2, will skip every other head + # if set to 12, will select 1 head from every layer + align_every_n_head = self.align_every_n_head + dec_start_idx = self.decoder_context_len + 1 # +1 to remove bos + attention_scores_sliced = attention_scores_combined[ + :,::align_every_n_head,dec_start_idx:,text_start_idx:-(2 + end_offset) + ] # -2 to remove eos and pad + attention_logprobs = ( + attention_scores_sliced # not taking log_softmax, since we will do that in loss function + ) + attention_logprobs = torch.mean(attention_logprobs, dim=1, keepdim=True) + dec_len = torch.sum(dec_attn_mask, dim=1) - dec_start_idx + enc_len = text_limits[:, 1] - text_limits[:, 0] - end_offset + alignment_loss = self.forward_sum_loss( + attn_logprob=attention_logprobs, in_lens=enc_len, out_lens=dec_len + ) + else: + attention_probs = None + # project decoder output to vocabulary-size dimensions + if self.share_decoder_tokens_head_embeddings: + first_layer_vocabsize = ( + self.speech_offset + self.speech_codebook_size + ) # variables set in __init__ of speechlm model + token_logits = self.tokens_head(dec_output, self.word_embeddings_weight()) # s, b, vocab + if self.seq_pattern in ["parallel", "delay_parallel"]: + # For flat seq_pattern we need all the logits + token_logits = token_logits[:, :, :first_layer_vocabsize] + speech_layers = self.num_speech_codebooks - 1 + last_layer_output = dec_output + last_layer_logits = token_logits + + # speech_logits_list will be used in loss calculation (parallel output) + speech_logits_list = [] + if self.seq_pattern in ["parallel", "delay_parallel"] and 
torch.count_nonzero(speech_mask) > 0: + for i in range(speech_layers): + last_layer_logits = self.speech_tokens_heads[i](dec_output)[0] # T, B, 1024 + speech_logits_list.append(last_layer_logits) # T, B, 1024 + else: + token_logits = self.tokens_head(dec_output)[0] # T, B, WordEmbSize + + if labels is not None: + if labels.dim() == 2: + # [b, s] -> [s, b] + labels = labels.transpose(0, 1).contiguous() + elif labels.dim() == 3: + # [b, c, s] -> [c, s, b] + labels = labels.permute(1, 2, 0).contiguous() + + # Set label smoothing to 0 if in eval mode. + label_smoothing = self.label_smoothing if self.training else 0.0 + + # tensor_parallel.vocab_parallel_cross_entropy performs log_softmax and return log p(x_i|z) per token i + if self.fp16_cross_entropy: + assert token_logits.dtype == torch.half + if labels.dim() == 3: + raise NotImplementedError("fp16_cross_entropy is not support for labels of dimension 3") + tokens_loss = vocab_parallel_cross_entropy(token_logits, labels, label_smoothing) + else: + if labels.dim() == 2: + tokens_loss = vocab_parallel_cross_entropy(token_logits.float(), labels, label_smoothing) + elif labels.dim() == 3: + if token_logits.size()[0] != labels[0, :, :].size()[0]: + raise Exception("TODO: add a permute") + tokens_loss = vocab_parallel_cross_entropy( + token_logits.float(), labels[0, :, :], label_smoothing + ) + logging.debug(f"token_loss: {tokens_loss}") + logging.debug(f"token_loss: {torch.all(torch.isfinite(tokens_loss))}") + if ( + self.seq_pattern in ["parallel", "delay_parallel"] + and torch.count_nonzero(speech_mask) > 0 + ): + for i in range(speech_layers): + if speech_logits_list[i].size()[0] != labels[i + 1, :, :].size()[0]: + raise Exception("TODO: add a permute") + curr_codebook_loss = ( + vocab_parallel_cross_entropy( + speech_logits_list[i].float(), labels[i + 1, :, :], label_smoothing + ) + * speech_mask.T + ) + tokens_loss += curr_codebook_loss + logging.debug(f"token_loss_{i}: {tokens_loss}") + logging.debug(f"token_loss_{i}: {torch.all(torch.isfinite(tokens_loss))}") + + # [s, b] -> [b, s] + tokens_loss = tokens_loss.transpose(0, 1).contiguous() + + # check if hiddens is used + if self.hiddens_cfg is not None: + raise NotImplementedError("Not currently implemented for speechllm") + else: + return tokens_loss, [token_logits, speech_logits_list, attention_probs, alignment_loss] + else: + # else return token logits (and hiddens if needed) + # [s, b, h] -> [b, s, h] + # If labels is None then we are in inference mode and we return the gathered logits + if self.parallel_output: + # Gather logits from tensor parallel if in parallel_output mode + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region( + token_logits + ) # T, B, 30208 + for _i in range(len(speech_logits_list)): + speech_logits_list[_i] = tensor_parallel.gather_from_tensor_model_parallel_region( + speech_logits_list[_i] + ) # T, B, 1024 + + token_logits = token_logits.transpose(0, 1).contiguous() # (B, T, 30208) + speech_logits = torch.stack(speech_logits_list, dim=-1) # T, B, 1024, 7 + speech_logits = speech_logits.transpose(0, 1).contiguous() # (B, T, 1024, 7) + + _si = self.speech_offset + _ei = _si + self.speech_codebook_size + first_layer_speech_logits = token_logits[:, :, _si:_ei].unsqueeze(-1) # (b, s, 1023, 1) + + all_speech_logits = torch.cat( + [first_layer_speech_logits, speech_logits], dim=-1 + ) # (b, s, 1024, 8) + + if self.hiddens_cfg is not None: + raise NotImplementedError("Not currently implemented for speechllm") + else: + # all_speech_logits: tensor, (b, 
s, 1024, 8), all layers of speech. + # token_logits: tensor, (b, s, vocab_size), text token logits. + # speech_logits: tensor, (b, s, 1024, 7), 1-7 layers of speech. + # attention_probs: tensor or None, (b, s, ) + # enc_output: tensor, (virtual_token_len+context_token_len+question_token_len+extra_id_0+[SEP], b, ) + return all_speech_logits, [token_logits, speech_logits, attention_probs, enc_output] + + elif self.add_decoder and not self.add_encoder: + decoder_output, _ = output + return decoder_output + else: + encoder_output = output + return encoder_output + + def state_dict(self): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._encoder_embedding_key] = self.encoder_embedding.state_dict() + state_dict_[self._decoder_embedding_key] = self.decoder_embedding.state_dict() + state_dict_[self._enc_dec_model_key] = self.enc_dec_model.state_dict() + state_dict_[self._tokens_head_key] = self.tokens_head.state_dict() + if hasattr(self, "speech_tokens_heads"): + state_dict_["speech_tokens_heads"] = self.speech_tokens_heads.state_dict() + if hasattr(self, "speech_tokens_embeddings"): + state_dict_["speech_tokens_embeddings"] = self.speech_tokens_embeddings.state_dict() + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + super().load_state_dict(state_dict, strict=strict) + if hasattr(self, "speech_tokens_heads"): + self.speech_tokens_heads.load_state_dict(state_dict["speech_tokens_heads"], strict=strict) + if hasattr(self, "speech_tokens_embeddings"): + self.speech_tokens_embeddings.load_state_dict(state_dict["speech_tokens_embeddings"], strict=strict) diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index ab10b0d0e8b3..3203b03deca9 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -22,6 +22,7 @@ import torch import torch.nn as nn from einops import rearrange +from omegaconf.listconfig import ListConfig from nemo.collections.common.parts.adapter_modules import LinearAdapterConfig from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( @@ -479,6 +480,10 @@ def forward( self_attention_relative_position_bias=None, cross_attention_relative_position_bias=None, checkpoint_core_attention=False, + return_crossattention_scores=False, + return_selfattention_scores=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, ): # Self attention. 
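+        # During speech inference the self-attention KV cache is sized by decoder_max_sequence_len
+        # (when the older inference_max_sequence_len argument is not passed), while the cross-attention
+        # cache further below is sized by encoder_max_sequence_len.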
if rotary_pos_emb is not None: @@ -507,12 +512,16 @@ def forward( layer_past=layer_past, get_key_value=get_key_value, set_inference_key_value_memory=set_inference_key_value_memory, - inference_max_sequence_len=inference_max_sequence_len, + inference_max_sequence_len=inference_max_sequence_len or decoder_max_sequence_len, rotary_pos_emb=self_attention_pos_emb, relative_position_bias=self_attention_relative_position_bias, checkpoint_core_attention=checkpoint_core_attention, + return_scores=return_selfattention_scores, ) + if return_selfattention_scores: + attention_output, attention_probs = attention_output + if get_key_value: attention_output, presents = attention_output @@ -574,12 +583,10 @@ def forward( enc_dec_attn_mask, encoder_output=encoder_output, rotary_pos_emb=cross_attention_pos_emb, - set_inference_key_value_memory=set_inference_key_value_memory, - inference_max_sequence_len=inference_max_sequence_len, checkpoint_core_attention=checkpoint_core_attention, ) else: - + # Return Scores is being passed only for inter_attention and not self attention attention_output, attention_bias = self.inter_attention( normalization_output, enc_dec_attn_mask, @@ -587,7 +594,12 @@ def forward( rotary_pos_emb=cross_attention_pos_emb, relative_position_bias=cross_attention_relative_position_bias, checkpoint_core_attention=checkpoint_core_attention, + return_scores=return_crossattention_scores, + set_inference_key_value_memory=set_inference_key_value_memory, + inference_max_sequence_len=encoder_max_sequence_len, ) + if return_crossattention_scores: + attention_output, attention_probs = attention_output # If normformer, apply norm on the output of the self attention. if self.transformer_block_type == 'normformer': @@ -632,6 +644,9 @@ def forward( if get_key_value: output = [output, presents] + if return_crossattention_scores or return_selfattention_scores: + output = [output, attention_probs] + return output @@ -735,6 +750,10 @@ def forward( self_attention_relative_position_bias=None, cross_attention_relative_position_bias=None, checkpoint_core_attention=False, + return_crossattention_scores=False, + return_selfattention_scores=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, ): if self.dtype == torch.float32: return super().forward( @@ -750,6 +769,10 @@ def forward( self_attention_relative_position_bias, cross_attention_relative_position_bias, checkpoint_core_attention, + return_crossattention_scores=return_crossattention_scores, + return_selfattention_scores=return_selfattention_scores, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, ) with torch.autocast(device_type="cuda", dtype=self.dtype): return super().forward( @@ -765,6 +788,10 @@ def forward( self_attention_relative_position_bias, cross_attention_relative_position_bias, checkpoint_core_attention, + return_crossattention_scores=return_crossattention_scores, + return_selfattention_scores=return_selfattention_scores, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, ) @@ -1072,10 +1099,12 @@ def __init__( # Transformer layers. 
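+        # layer_type may arrive as an OmegaConf ListConfig of per-layer ints, so build_layer below
+        # normalizes integer entries into LayerType enum members before constructing each layer.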
def build_layer(layer_number): - if isinstance(layer_type, list): + if isinstance(layer_type, (list, ListConfig)): lt = layer_type[layer_number - 1] else: lt = layer_type + if isinstance(lt, int): + lt = LayerType(lt) if self.transformer_engine: transformer_layer_args = { @@ -1493,7 +1522,16 @@ def forward( self_attention_relative_position_bias=None, cross_attention_relative_position_bias=None, checkpoint_activations_all_layers=None, + return_all_crossattention_probs=False, + return_all_selfattention_probs=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + enc_output_to_layers=None, ): + if return_all_crossattention_probs and return_all_selfattention_probs: + raise NotImplementedError( + "We can only return 1 of cross attention probs or self attention probs. Not both yet." + ) # Checks. if inference_max_sequence_len: assert self.activations_checkpoint_method is None, 'inference does not work with activation checkpointing' @@ -1580,6 +1618,7 @@ def forward( if self.inference_params != None: self.inference_params.sequence_len_offset = self.inference_current_sequence_len + attention_probs_list = [] if self.return_select_layer < 0: assert ( parallel_state.get_pipeline_model_parallel_world_size() == 1 @@ -1588,10 +1627,29 @@ def forward( logging.warning("Returning embeddings states only!") return hidden_states + layer_to_encoder_num_mapping = {} + if enc_output_to_layers is not None: + assert len(enc_output_to_layers) == len(encoder_output) + for encoder_idx in range(len(encoder_output)): + for layer_idx in enc_output_to_layers[encoder_idx]: + layer_to_encoder_num_mapping[layer_idx] = encoder_idx + for index in range(self.num_layers): layer = self._get_layer(index) past = None + _encoder_output = encoder_output + _enc_dec_attn_mask = enc_dec_attn_mask + _cross_attention_relative_position_bias = cross_attention_relative_position_bias + _encoder_max_sequence_len = encoder_max_sequence_len + if index in layer_to_encoder_num_mapping: + _encoder_output = encoder_output[layer_to_encoder_num_mapping[index]] + _enc_dec_attn_mask = enc_dec_attn_mask[layer_to_encoder_num_mapping[index]] + _cross_attention_relative_position_bias = cross_attention_relative_position_bias[layer_to_encoder_num_mapping[index]] + if encoder_max_sequence_len is not None: + _encoder_max_sequence_len = encoder_max_sequence_len[layer_to_encoder_num_mapping[index]] + + if layer_past is not None: past = layer_past[index] @@ -1625,27 +1683,65 @@ def forward( hidden_states = layer( hidden_states, attention_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, + encoder_output=_encoder_output, + enc_dec_attn_mask=_enc_dec_attn_mask, inference_params=self.inference_params, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, ) else: - hidden_states = layer( - hidden_states, - attention_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, - layer_past=past, - get_key_value=get_key_value, - set_inference_key_value_memory=set_inference_key_value_memory, - inference_max_sequence_len=inference_max_sequence_len, - rotary_pos_emb=rotary_pos_emb, - self_attention_relative_position_bias=self_attention_relative_position_bias, - cross_attention_relative_position_bias=cross_attention_relative_position_bias, - checkpoint_core_attention=checkpoint_core_attention, - ) + if layer.layer_type == LayerType.decoder and return_all_crossattention_probs: + hidden_states, attention_probs = layer( + hidden_states, + attention_mask, + 
encoder_output=_encoder_output, + enc_dec_attn_mask=_enc_dec_attn_mask, + layer_past=past, + set_inference_key_value_memory=set_inference_key_value_memory, + inference_max_sequence_len=inference_max_sequence_len, + rotary_pos_emb=rotary_pos_emb, + self_attention_relative_position_bias=self_attention_relative_position_bias, + cross_attention_relative_position_bias=_cross_attention_relative_position_bias, + checkpoint_core_attention=checkpoint_core_attention, + return_crossattention_scores=return_all_crossattention_probs, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=_encoder_max_sequence_len, + ) + attention_probs_list.append(attention_probs) + elif layer.layer_type == LayerType.encoder and return_all_selfattention_probs: + hidden_states, attention_probs = layer( + hidden_states, + attention_mask, + encoder_output=_encoder_output, + enc_dec_attn_mask=_enc_dec_attn_mask, + layer_past=past, + get_key_value=get_key_value, + set_inference_key_value_memory=set_inference_key_value_memory, + inference_max_sequence_len=inference_max_sequence_len, + rotary_pos_emb=rotary_pos_emb, + self_attention_relative_position_bias=self_attention_relative_position_bias, + cross_attention_relative_position_bias=_cross_attention_relative_position_bias, + checkpoint_core_attention=checkpoint_core_attention, + return_selfattention_scores=return_all_selfattention_probs, + ) + attention_probs_list.append(attention_probs) + else: + hidden_states = layer( + hidden_states, + attention_mask, + encoder_output=_encoder_output, + enc_dec_attn_mask=_enc_dec_attn_mask, + layer_past=past, + get_key_value=get_key_value, + set_inference_key_value_memory=set_inference_key_value_memory, + inference_max_sequence_len=inference_max_sequence_len, + rotary_pos_emb=rotary_pos_emb, + self_attention_relative_position_bias=self_attention_relative_position_bias, + cross_attention_relative_position_bias=_cross_attention_relative_position_bias, + checkpoint_core_attention=checkpoint_core_attention, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=_encoder_max_sequence_len, + ) if self.return_select_layer < 0: assert ( @@ -1679,4 +1775,7 @@ def forward( if get_key_value: output = [output, presents] + if return_all_crossattention_probs or return_all_selfattention_probs: + output = [output, attention_probs_list] + return output diff --git a/nemo/collections/nlp/modules/common/megatron/utils.py b/nemo/collections/nlp/modules/common/megatron/utils.py index 601cb7a4d7e8..1540170a8dc1 100644 --- a/nemo/collections/nlp/modules/common/megatron/utils.py +++ b/nemo/collections/nlp/modules/common/megatron/utils.py @@ -474,9 +474,25 @@ def get_iterator_k_split( else: # Split a list of torch tensors assert batch[0].shape[0] % num_microbatches == 0, "Issue with batch size configuration!" 
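+        # The branch below extends splitting to batch items that are lists of tensors (e.g. one
+        # tensor per encoder): each tensor is split into num_microbatches chunks, and the chunks are
+        # regrouped so that every microbatch receives one slice from every tensor in the list.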
- split_batch = [ - torch.tensor_split(item, num_microbatches, dim=0) if torch.is_tensor(item) else item for item in batch - ] + split_batch = [] + for item in batch: + if torch.is_tensor(item): + split_batch.append(torch.tensor_split(item, num_microbatches, dim=0)) + elif isinstance(item, list): + if isinstance(item[0], torch.Tensor): + split_tensors = [torch.tensor_split(elem, num_microbatches, dim=0) for elem in item] + split_tuple = [] + for mbi in range(num_microbatches): + split_tuple.append([split_tensors[i][mbi] for i in range(len(split_tensors))]) + split_tuple = tuple(split_tuple) + split_batch.append(split_tuple) + else: + split_batch.append(split_list(item, num_microbatches)) + elif item is None: + split_batch.append(item) + else: + raise ValueError(f"Unsupported item type: {type(item)}") + microbatches = [ [elem[i] if elem is not None else elem for elem in split_batch] for i in range(num_microbatches) ] diff --git a/nemo/collections/tts/data/speechllm/__init__.py b/nemo/collections/tts/data/speechllm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py new file mode 100644 index 000000000000..6ebf64d7c17c --- /dev/null +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py @@ -0,0 +1,1588 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
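+# Dataset for T5 SpeechLM training: manifest entries provide a context, a question (e.g.
+# "Phoneme TTS <text>" or "Text to speech this <text>") and an answer. Text is tokenized
+# (optionally through G2P/phonemes) and speech answers are loaded as multi-codebook audio codec
+# tokens, with an optional beta-binomial cross-attention prior to guide text-to-audio alignment.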
+ +import enum +import json +import random +from dataclasses import dataclass +from pathlib import Path +from typing import ClassVar, List, Optional, Union + +import numpy as np +import torch +from hydra.utils import instantiate +from omegaconf import OmegaConf +from tqdm.auto import tqdm + +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.common.tokenizers.text_to_speech.ipa_lexicon import get_ipa_punctuation_list +from nemo.collections.common.tokenizers.text_to_speech.tokenizer_utils import any_locale_text_preprocessing +from nemo.collections.nlp.data.language_modeling.megatron.base_prompt_learning_dataset import BasePromptLearningDataset +from nemo.collections.nlp.models.language_modeling.megatron_t5_model import T5Sentinel +from nemo.collections.nlp.modules.common import VirtualPromptSource +from nemo.collections.nlp.modules.common.megatron.utils import build_position_ids +from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths +from nemo.collections.tts.parts.utils.tts_dataset_utils import ( + BetaBinomialInterpolator, + beta_binomial_prior_distribution, + general_padding, + get_base_dir, +) +from nemo.utils import logging + +__all__ = ['T5SpeechLMDataset'] + + +def get_full_list_puncts(): + punct_set = set() + for locale_id in ["en-US", "de-DE", "fr-FR"]: + punct_list = get_ipa_punctuation_list(locale=locale_id) + punct_set.update(punct_list) + return sorted(punct_set) + + +@dataclass +class G2PConfig: + _target_: str = "nemo.collections.tts.g2p.models.en_us_arpabet.EnglishG2p" + phoneme_dict: str = "scripts/tts_dataset_files/cmudict-0.7b_nv22.10" + heteronyms: str = "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: float = 0.5 + + +@dataclass +class EnglishIpaG2pConfig: + _target_: str = "nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p" + phoneme_dict: str = "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + locale: str = "en-US" + heteronyms: str = "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: float = 0.5 + grapheme_case: str = "upper" + use_stresses: bool = True + use_chars: bool = True + ignore_ambiguous_words: bool = False + + +@dataclass +class TextTokenizer: + _target_: str = "nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.EnglishPhonemesTokenizer" + punct: bool = True + stresses: bool = True + chars: bool = True + apostrophe: bool = True + pad_with_space: bool = True + add_blank_at: bool = True + g2p: G2PConfig = G2PConfig() + + +@dataclass +class EnglishIpaTextTokenizer: + _target_: str = "nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer" + locale: str = "en-US" + punct: bool = True + # Define non_default_punct_list as a ClassVar to explicitly mark it as a class variable + non_default_punct_list: ClassVar[List[str]] = get_full_list_puncts() + apostrophe: bool = True + pad_with_space: bool = True + add_blank_at: bool = True + g2p: EnglishIpaG2pConfig = EnglishIpaG2pConfig() + + +@dataclass +class TextTokenizerConfig: + text_tokenizer: TextTokenizer = TextTokenizer() + + +@dataclass +class EnglishIpaTextTokenizerConfig: + text_tokenizer: EnglishIpaTextTokenizer = EnglishIpaTextTokenizer() + + +def _get_default_text_tokenizer_conf(phoneme_probability: float = 0.5, use_ipa: bool = False): + if use_ipa: + g2p = EnglishIpaG2pConfig(phoneme_probability=phoneme_probability) + _text_tokenizer = EnglishIpaTextTokenizer(g2p=g2p) + 
text_tokenizer: EnglishIpaTextTokenizerConfig = EnglishIpaTextTokenizerConfig(text_tokenizer=_text_tokenizer) + else: + g2p = G2PConfig(phoneme_probability=phoneme_probability) + _text_tokenizer = TextTokenizer(g2p=g2p) + text_tokenizer: TextTokenizerConfig = TextTokenizerConfig(text_tokenizer=_text_tokenizer) + return OmegaConf.create(OmegaConf.to_yaml(text_tokenizer)) + + +def pad_text_to_speech_dims(text_tensor, pad_id, pad_size=7): + token_len = text_tensor.shape[0] + empty_padding = torch.ones((pad_size, token_len), dtype=text_tensor.dtype, device=text_tensor.device) * pad_id + return torch.cat((text_tensor.unsqueeze(0), empty_padding), dim=0) + + +class Lang(enum.Enum): + en = 1 + es = 2 + fr = 3 + zh = 4 + de = 4 + + +class T5SpeechLMDataset(BasePromptLearningDataset): + """ + The dataset class for prompt-tuning or p-tuning pretrained T5 SpeechLM models. + """ + + def __init__( + self, + datasets, + tokenizer, + virtual_prompt_source: VirtualPromptSource, + task_templates: dict, + pseudo_tokens, + pad_token_id: str, + max_seq_length: int, + sample_rate: int, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + for_train: bool = True, + decoder_starts_with_pad: bool = False, + add_eos_to_decoder_output: bool = True, + add_sentinel_to_input: bool = True, + ul2_prompt_token: str = None, + segment_max_duration: Optional[int] = None, + trim: bool = False, + trim_ref: Optional[float] = None, + trim_top_db: Optional[int] = None, + trim_frame_length: Optional[int] = None, + trim_hop_length: Optional[int] = None, + pad_multiple: int = 1, + pitch_augment: bool = False, + sup_data_path: Optional[Union[Path, str]] = None, + speech_offset: Optional[int] = None, + train_task: Optional[str] = None, + seq_pattern: Optional[str] = "parallel", + use_attention_prior: Optional[bool] = False, + attention_prior_scaling_factor: Optional[float] = 1.0, + spec_aug=False, + spec_aug_time_width=0.2, + spec_aug_time_masks=2, + cross_attention_epsilon: Optional[float] = 0.0, + lm_vocab_size: Optional[int] = None, + num_speech_codebooks: Optional[int] = 8, + codebook_fps: Optional[int] = 86, + add_special_tokens_to_only_first_codebook: Optional[bool] = False, + context_pattern: Optional[str] = "parallel", + context_duration_min: Optional[float] = 3.0, + context_duration_max: Optional[float] = 5.0, + skip_datasets: Optional[List[str]] = [], # substrings of dataset names to skip + english_only_model: Optional[bool] = False, + context_conditioning: Optional[str] = "decoder", # encoder or decoder + use_beta_binomial_interpolator: Optional[str] = False, # encoder or decoder + context_slice_method: Optional[str] = "random", # random or fixed + phoneme_probability: Optional[float] = 0.5, + encoder_type: Optional[str] = "single_transformer", + use_ipa: bool = False, + **kwargs, + ): + """ + Only speech parameters are explained here. + segment_max_duration: Optional[int] = None, - Speech max segment duration + trim: bool = False, - speech parameter + trim_ref: Optional[float] = None, - speech parameter + trim_top_db: Optional[int] = None, - speech parameter + trim_frame_length: Optional[int] = None, - speech parameter + trim_hop_length: Optional[int] = None, - speech parameter + pad_multiple: int = 1, - speech parameter + pitch_augment: bool = False, - speech parameter + sup_data_path: Optional[Union[Path, str]] = None, - Supplementary folder path where codecs are stored. 
+ speech_offset: Optional[int] = None, - if speech tokens then add this offset to the token indices to distinguish between text and speech tokens. + lm_vocab_size: Optional[int] = None, - vocab size of the original language model (phoneme tokens start from this index) + english_only_model: Optional[bool] = False, specify if monolingual or multi-lingual modeling. + use_ipa: bool = False, specify if using IPA tokens or default ARPABET tokens. Either choice still mixes chars. + **kwargs, + """ + # These two variables need to be set before calling super().__init__() because the parent class calls `load_data()` which requires these attributes. + self._rng = random.Random() + self.spec_aug = spec_aug if for_train else False + self.time_width = spec_aug_time_width + self.time_masks = spec_aug_time_masks + self.decoder_starts_with_pad = decoder_starts_with_pad + self.add_eos_to_decoder_output = add_eos_to_decoder_output + self.add_sentinel_to_input = add_sentinel_to_input + self.ul2_prompt_token = ul2_prompt_token + # Speech related variables + self.base_data_dir = None + self.segment_max_duration = segment_max_duration + self.sample_rate = sample_rate + self.featurizer = WaveformFeaturizer(sample_rate=self.sample_rate) + self.pad_multiple = pad_multiple + self.pitch_augment = pitch_augment + self.trim = trim + self.trim_ref = trim_ref if trim_ref is not None else np.max + self.trim_top_db = trim_top_db if trim_top_db is not None else 60 + self.trim_frame_length = trim_frame_length if trim_frame_length is not None else 2048 + self.trim_hop_length = trim_hop_length if trim_hop_length is not None else 512 + self.speech_offset = speech_offset if speech_offset is not None else 3 + self.seq_pattern = seq_pattern + self.use_attention_prior = use_attention_prior + self.attention_prior_scaling_factor = attention_prior_scaling_factor + self.cross_attention_epsilon = cross_attention_epsilon # value of prior for context tokens (b/w 0 and 1) + assert self.cross_attention_epsilon >= 0.0 and self.cross_attention_epsilon <= 1.0 + self.lm_vocab_size = tokenizer.vocab_size if lm_vocab_size is None else lm_vocab_size + self.num_speech_codebooks = num_speech_codebooks + self.codebook_fps = codebook_fps + self.add_special_tokens_to_only_first_codebook = add_special_tokens_to_only_first_codebook + # context_pattern and duration arguments are supported only if context_type is REFSPEAKERCODEC in the manifest + self.context_pattern = context_pattern + self.context_duration_min = context_duration_min + self.context_duration_max = context_duration_max + self.english_only_model = english_only_model + self.phoneme_tokenizer = None + if english_only_model: + self.phoneme_tokenizer = instantiate(_get_default_text_tokenizer_conf(phoneme_probability=phoneme_probability, use_ipa=use_ipa)).text_tokenizer + else: + self.g2p = {"fr": lambda x: x} + if kwargs.get("g2p", None): + if "english" in kwargs["g2p"]: + english_g2p = instantiate(kwargs["g2p"]["english"]) + self.g2p["en"] = lambda x: english_g2p(x) + if "spanish" in kwargs["g2p"]: + spanish_g2p = instantiate(kwargs["g2p"]["spanish"]) + self.g2p["es"] = lambda x: spanish_g2p(x) + if "mandarin" in kwargs["g2p"]: + mandarin_g2p = instantiate(kwargs["g2p"]["mandarin"]) + self.g2p["zh"] = lambda x: mandarin_g2p(x) + if "german" in kwargs["g2p"]: + german_g2p = instantiate(kwargs["g2p"]["german"]) + self.g2p["de"] = lambda x: german_g2p(x) + + self.context_conditioning = context_conditioning + if self.context_conditioning == "decoder": + assert self.context_duration_min == 
self.context_duration_max, "For decoder conditioning, context_duration_min and context_duration_max should be same" + self.decoder_context_len = int(self.context_duration_min * self.codebook_fps) #TODO: Just take from model var? + + # Initialize sup_data_path, sup_data_types and run preprocessing methods for every supplementary data type\ + self.sup_data_path = None + if sup_data_path is not None: + Path(sup_data_path).mkdir(parents=True, exist_ok=True) + self.sup_data_path = sup_data_path + + self.codec_folder = kwargs.pop('codec_folder', None) + self.train_task = train_task + if self.codec_folder is None and self.sup_data_path is not None: + self.codec_folder = Path(self.sup_data_path) / "codec" + elif isinstance(self.codec_folder, str): + self.codec_folder = Path(self.codec_folder) + + self.codec_folder.mkdir(exist_ok=True, parents=True) + + self.context_length = kwargs.pop('context_length', None) # only used in gpt dataset atm + # self.attention_prior_strength = attention_prior_strength + self.transformer_type = kwargs.pop('transformer_type', 'T5') + self.skip_datasets = skip_datasets + + self.beta_binomial_interpolator = BetaBinomialInterpolator(scaling_factor=self.attention_prior_scaling_factor) if use_beta_binomial_interpolator else None + self.context_slice_method = context_slice_method + self.encoder_type = encoder_type + super().__init__( + datasets=datasets, + tokenizer=tokenizer, + virtual_prompt_source=virtual_prompt_source, + task_templates=task_templates, + pseudo_tokens=pseudo_tokens, + pad_token_id=pad_token_id, + max_seq_length=max_seq_length, + min_seq_length=min_seq_length, + add_bos=add_bos, + add_eos=add_eos, + for_train=for_train, + ) + + def load_data(self, dataset): + """ + Loads a dataset by filling in the task templates specified in the config file + with the information from each training/inference example. Converts all input + text into token ids. Also replaces the <|VIRTUAL_PROMPT_#|> placeholders in + the task templates with the actual virtual prompt token ids. + + params: + dataset: A list of json objects or a dictionary objects each + containing the information needed for a training example + """ + copy_dataset = list(dataset) + audio_filelist = [] + # This loop is needed to calculate self.base_data_dir. 
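+        # get_base_dir() returns the common parent directory of every SPEECH-typed field in the
+        # manifest, so all audio paths are collected in a first pass before examples are filtered.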
+        for json_line in copy_dataset:
+            if type(json_line) == dict:
+                doc = json_line
+            else:
+                doc = json.loads(json_line)
+            taskname = doc["taskname"]
+            prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"]
+
+            for p in prompt_template_fields:
+                if f"{p}_type" in doc and doc[f"{p}_type"] == "SPEECH":
+                    audio_filelist.append(doc[p])
+        self.base_data_dir = get_base_dir(audio_filelist)
+
+        skipped = 0
+        tts = 0
+        asr = 0
+        i = 0
+        logging.info(f"copy_dataset len === {len(copy_dataset)}")
+        examples = []
+        for json_line in tqdm(copy_dataset):
+            i += 1
+
+            # Read example dict or load the information for a single example from .json file
+            if type(json_line) == dict:
+                doc = json_line
+            else:
+                doc = json.loads(json_line)
+
+            if self.context_conditioning == "decoder":
+                # Modify doc to combine the context and the answer
+                assert ";" not in doc['context'], "Multiple contexts not supported in decoder conditioning"
+                doc['answer'] = "{};{}".format(doc['context'], doc['answer'])
+                doc['answer_duration'] = self.context_duration_min + doc['answer_duration']
+                doc['answer_type'] = "CONTEXTANSWER"
+                doc['context_type'] = "DUMMYCONTEXT"
+                doc['context'] = "DUMMYCONTEXT"
+
+            question_in_manifest = doc['question']
+
+            if "Text to speech this" in question_in_manifest or "Phoneme TTS" in question_in_manifest:
+                tts += 1
+                if self.train_task not in ['tts', 'all']:
+                    continue
+            elif "Next token prediction" in question_in_manifest:
+                if self.train_task != 'tts':
+                    asr += 1
+                else:
+                    tts += 1
+                    continue
+            else:
+                if self.train_task == 'tts':
+                    continue
+                asr += 1
+
+            if doc["context_type"] == "SPEECH":
+                assert "context_duration" in doc, f"context_duration key not in document {doc}"
+                approx_context_len = 3 * (self.codebook_fps + 1)  # +1 just to be safe
+                if self.context_length is not None and doc["context_duration"] < self.context_length:
+                    logging.debug(
+                        f"skipped as context_length of {doc['context_duration']} is less than {self.context_length}"
+                    )
+                    skipped += 1
+                    continue
+            elif "Remove Noise" in question_in_manifest:
+                approx_context_len = doc["answer_duration"] * (self.codebook_fps + 1)
+            elif "Extract Speaker Audio" in question_in_manifest:
+                approx_context_len = (
+                    doc["answer_duration"] * (self.codebook_fps + 1) + 400
+                )  # 400 is the max ref speaker audio
+            elif ("Text to speech this" in question_in_manifest) or ('Phoneme TTS' in question_in_manifest):
+                # approx_context_len = 400
+                approx_context_len = 5 * (self.codebook_fps + 1)  # better than 400. TODO: pneekhara: Need to change things for multi-encoder vs single encoder based filtering.
+ elif "Edit Speech" in question_in_manifest: + approx_context_len = doc["answer_duration"] * (self.codebook_fps + 1) + else: + raise NotImplementedError(f"Unknown context type {doc['context_type']}") + + approx_question_len = len(doc["question"].split(' ')) + 3 + if 'Phoneme TTS' in question_in_manifest: + # approx len is equal to num of characters + approx_question_len = len(question_in_manifest) + + if doc["answer_type"] in ["SPEECH", "AUDIOCODEC", "CONTEXTANSWER"]: + assert "answer_duration" in doc, f"answer_duration key not in document {doc}" + approx_answer_len = doc["answer_duration"] * (self.codebook_fps + 1) + 3 # +3 for EOS, BOS padding + if self.seq_pattern == "delay_parallel": + # In delay parallel, there is padding so add 8 frames + approx_answer_len = approx_answer_len + self.num_speech_codebooks + else: + approx_answer_len = len(doc["answer"].split(' ')) + 3 + + skip_record = False + for skip_dataset in self.skip_datasets: + if skip_dataset in doc['answer']: + skip_record = True + + if not skip_record: + if (self.transformer_type == "GPT") and ( + self.min_seq_length + < approx_context_len + approx_question_len + approx_answer_len + < self.max_seq_length + ): + examples.append(doc) + elif (self.transformer_type == "T5") and ( + self.min_seq_length < approx_context_len + approx_question_len < self.max_seq_length + and self.min_seq_length < approx_answer_len < self.max_seq_length + ): + examples.append(doc) + else: + logging.debug(f"skipped for {approx_context_len + approx_question_len} {approx_answer_len} len") + skipped += 1 + else: + print("Skipping", doc['answer']) + logging.debug(f"skipped for {doc['answer']} as it is in skip_datasets") + skipped += 1 + + # logging.info(f"After Process len(self.examples) {len(self.examples)} TTS = {tts} ASR = {asr}") + logging.info(f'Skipped {skipped} sentences, sequence length too short or too long even after truncation') + + return examples + + def __getitem__(self, idx): + doc = self.examples[idx] + taskname = doc["taskname"] + prompt_template = self.task_templates[taskname]["prompt_template"] + prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"] + total_virtual_tokens = self.task_templates[taskname]["total_virtual_tokens"] + virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"] + truncation_field = self.task_templates[taskname]['truncate_field'] + answer_field = self.task_templates[taskname]["answer_field"] + + input_example = prompt_template + + self._input_sanity_checks( + total_virtual_tokens=total_virtual_tokens, + virtual_token_splits=virtual_token_splits, + prompt_template=prompt_template, + prompt_template_fields=doc.keys(), # Skip this check as we don't need it for TTS + truncation_field=truncation_field, + answer_field=answer_field, + doc=doc, + ) + question_in_manifest = doc['question'] + + # Format the input example according to the template + # Get context, question and answer codes in a dict. + # TODO @xueyang: declare the instructions when initializing the dataset so that they can be re-used. Temporally + # hardcode them here. 
+ question_text = doc["question"].strip() + instructions = ["Phoneme TTS", "Text to speech this"] + for prefix in instructions: + if doc["question"].startswith(prefix): + question_text = doc["question"][len(prefix):].strip() + break + + input_dict = self._insert_data_in_template(prompt_template_fields, doc, answer_field) + lang = Lang[doc.get("lang", "en")] + context_tokens = input_dict['context'] + question_tokens = input_dict['question'] + + # Logic to prune context + # In case of TTS task, the entire reference speech is not required, so we randomly select a portion + # of the reference audio. + # In case of Next token prediction, We want context[:T] to go in the encoder and context[T+1:] to be + # predicted by the decoder. + start_token_index = 0 + end_token_index = -1 + if ("Text to speech this" in question_in_manifest) and (doc["context_type"] == "SPEECH"): + total_context_len = context_tokens[0].size()[1] + reduced_len = min( + 400, + int(total_context_len * 0.2) + if total_context_len > 600 + else int(total_context_len * random.uniform(0.2, 0.5)), + ) + start_token_index = random.randint( + 0, total_context_len - reduced_len + ) # start index can be greater than 440 + context_tokens[0] = context_tokens[0][ + :, start_token_index : min(start_token_index + 440, start_token_index + reduced_len) + ] + elif "Next token prediction" in question_in_manifest: + total_context_len = context_tokens[0].size()[1] + end_token_index = int(total_context_len * random.uniform(0.01, 0.2)) + context_tokens[0] = context_tokens[0][:, :end_token_index] + + # Get virtual tokens + # `virtual_tokens` is "". + virtual_tokens = self._insert_virtual_token_placeholders(input_example.split(' ')[0], virtual_token_splits) + # print("virtual_tokens", virtual_tokens) + + # a trick to align with the data format in t5 pretraining + # new + virtual_tokens = self.tokenizer.text_to_ids(virtual_tokens) + if self.add_sentinel_to_input: + question_tokens = question_tokens + self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) + + # Add BOS/EOS to the input of encoder if desired, adds EOS by default + if self.ul2_prompt_token is not None: + ul2_prompt_token_id = self.tokenizer.text_to_ids(self.ul2_prompt_token) + assert len(ul2_prompt_token_id) == 1 + context_tokens = ul2_prompt_token_id + context_tokens + if self.add_bos: + context_tokens = [self.tokenizer.bos_id] + context_tokens + if self.add_eos: + question_tokens = question_tokens + [self.tokenizer.eos_id] + + # Try to truncate input text to fit into the max sequence length + if self._get_len(context_tokens, question_tokens, virtual_tokens) > self.max_seq_length: + context_tokens, question_tokens, virtual_tokens = self._truncate_input_speech( + context_tokens, question_tokens, virtual_tokens + ) + + virtual_tokens, virtual_tokens_len = self.list_to_tensor(virtual_tokens) + context_tokens, context_tokens_len = self.list_to_tensor(context_tokens) + question_tokens, question_tokens_len = self.list_to_tensor(question_tokens) + + if doc["question_type"] == "TEXT" and doc["context_type"] != "TEXT": + question_tokens = pad_text_to_speech_dims( + question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + if doc["context_type"] == "TEXT" and doc["question_type"] != "TEXT": + context_tokens = pad_text_to_speech_dims( + context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + if doc["context_type"] == "TEXT" and doc["question_type"] == "TEXT": + context_tokens = pad_text_to_speech_dims( + context_tokens, self.tokenizer.pad_id, 
self.num_speech_codebooks - 1 + ) + question_tokens = pad_text_to_speech_dims( + question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + + # context_tokens: tensor, (num_speech_codebooks, audio_context_len) + # question_tokens: tensor, (num_speech_codebooks, instruction token len + question token len + 1 ( + 1 ([SEP])), only first row includes token ids while all other rows are all zeros (pad) + if self.encoder_type == "multi_transformer": + context_and_question_tokens = [context_tokens, question_tokens] + else: + context_and_question_tokens = torch.cat([context_tokens, question_tokens], dim=1) + + # get answer ids + if answer_field in doc.keys(): # training and validation + answer_ids = self._get_tokens(doc, answer_field, doc[answer_field]) + if end_token_index > -1: + answer_ids[0] = answer_ids[0][:, end_token_index:] + + if self.decoder_starts_with_pad: + answer_text_ids = [self.tokenizer.pad_id] + else: + answer_text_ids = [self.tokenizer.bos_id] + # a trick to align with the data format in t5 pretraining + # if self.add_sentinel_to_input: + # answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) + answer_text_ids += answer_ids + + if self.add_eos_to_decoder_output: + answer_text_ids += [self.tokenizer.eos_id] + else: + answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.END.value) + + if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: + taskname_id = self.tokenizer.text_to_ids(taskname) + elif ( + self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT + ): # TODO (@adithyare) this class and GPTPromptLearningDataset should be merged. + taskname_id = -1 + else: + raise ValueError("Invalid virtual prompt source specified") + + dec_input = None + dec_labels = None + + # if single-encoder and context_condition is decoder, answer_text_ids = [CLS_id, context audio code tensors, zero-pad, answer audio code tensor, SEP_id] + # if multi-encoder, answer_text_ids = [CLS_id, answer audio codec tensor, SEP_id], so dec_input will not include audio context anymore. 
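+        # For illustration, the "delay_parallel" re-arrangement applied further below shifts
+        # codebook c by c frames; with 3 codebooks and frames [a b c] per codebook:
+        #   codebook 0: a b c 0 0 0
+        #   codebook 1: 0 a b c 0 0
+        #   codebook 2: 0 0 a b c 0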
+ if answer_field in doc.keys(): # training and validation + dec_input = answer_text_ids[:-1] + dec_labels = answer_text_ids[1:] + + # if single-encoder and context_condition is decoder: + # dec_input: shape=(self.num_speech_codebooks, 1([CLS]) + len(context audio frames) + 1([PAD]) + len(answer audio frames)) + # dec_labels: shape=(self.num_speech_codebooks, len(context audio frames) + 1([PAD]) + len(answer audio frames) + 1([SEP])) + # if multi-encoder: + # dec_input: (num_speech_codebooks, 1([CLS]) + len(answer audio frames)) + # dec_labels: (num_speech_codebooks, len(answer audio frames) + 1([SEP])) + dec_input, dec_input_len = self.list_to_tensor(dec_input, True) + dec_labels, dec_labels_len = self.list_to_tensor(dec_labels, True) + is_speech = True if doc["answer_type"] != "TEXT" else False + if is_speech: + assert dec_input.dim() == 2 and dec_labels.dim() == 2 + if self.seq_pattern == "delay_parallel": + num_codebooks = dec_input.shape[0] + dec_input_padded = torch.cat( + [ + torch.zeros_like(dec_input[:, 0:num_codebooks]), + dec_input, + torch.zeros_like(dec_input[:, 0:num_codebooks]), + ], + dim=1, + ) + dec_labels_padded = torch.cat( + [ + torch.zeros_like(dec_labels[:, 0:num_codebooks]), + dec_labels, + torch.zeros_like(dec_labels[:, 0:num_codebooks]), + ], + dim=1, + ) + dec_input_new = [] + dec_labels_new = [] + for _c in range(self.num_speech_codebooks): + st = num_codebooks - _c + et_decoder_input = dec_input_padded.shape[1] - _c + et_decoder_labels = dec_labels_padded.shape[1] - _c + dec_input_new.append(dec_input_padded[_c, st:et_decoder_input]) + dec_labels_new.append(dec_labels_padded[_c, st:et_decoder_labels]) + dec_input = torch.stack(dec_input_new, dim=0) + dec_labels = torch.stack(dec_labels_new, dim=0) + dec_input_len = torch.tensor(dec_input.shape[1]).long() + dec_labels_len = torch.tensor(dec_labels.shape[1]).long() + + if self.encoder_type == "multi_transformer": + enc_len = question_tokens_len + virtual_tokens_len + else: + enc_len = context_tokens_len + question_tokens_len + virtual_tokens_len + # TODO: Remove hardcoding + start_of_question_offset = 4 # For both "Text to Speech this" and "Phoneme TTS" + end_of_question_offset = 2 + cross_attention_prior = torch.zeros(dec_labels_len, enc_len) + self.cross_attention_epsilon + if self.use_attention_prior: + prior_dec_len = dec_labels_len.item() + prior_dec_start_idx = 0 + if self.context_conditioning == "decoder": + prior_dec_len = dec_labels_len.item() - (self.decoder_context_len + 1) + prior_dec_start_idx = self.decoder_context_len + 1 + text_len = question_tokens_len.item() - start_of_question_offset - end_of_question_offset + audio_len = prior_dec_len + if self.beta_binomial_interpolator is not None: + cross_attention_question_prior = torch.from_numpy( + self.beta_binomial_interpolator(audio_len, text_len) + ) + else: + cross_attention_question_prior = torch.from_numpy( + beta_binomial_prior_distribution( + text_len, + audio_len, + scaling_factor=self.attention_prior_scaling_factor, + ) + ) + if self.encoder_type == "multi_transformer": + cross_attention_prior[ + prior_dec_start_idx:, virtual_tokens_len + start_of_question_offset : -end_of_question_offset + ] = cross_attention_question_prior + else: + cross_attention_prior[ + prior_dec_start_idx:, virtual_tokens_len + context_tokens_len + start_of_question_offset : -end_of_question_offset + ] = cross_attention_question_prior + + if self.encoder_type == "multi_transformer": + context_and_question_len = [context_tokens_len, question_tokens_len] + else: + 
context_and_question_len = context_tokens_len + question_tokens_len + return ( + taskname_id, # List, only one item. token id for "squad" + virtual_tokens, # Tensor, shape=(3,). token id for ['', '', ''] + virtual_tokens_len, # tensor, 3 + context_tokens_len, # tensor, 1 + # tensor if single encoder and context_condition is encoder, shape=(self.num_speech_codebooks, 1(context) + question len + 1() + 1([SEP])). only first row includes token ids while all other rows are all zeros (pad). + # list if multi-encoder and context_condition is encoder. + context_and_question_tokens, + # tensor scalar if single encoder and context_condition is decoder, 1 + (question len + 1 + 1). + # list if multi-encoder and context_condition is encoder. + context_and_question_len, + dec_input, # tensor, shape=(self.num_speech_codebooks, 1 CLS + context audio frame len + 1 pad + answer audio frame len), first column is [CLS_id, 0*7]^T + dec_input_len, # scalar tensor, 1 CLS + context audio frame len + 1 pad + answer audio frame len. 1 corresponds to CLS id + dec_labels, # tensor, shape=(self.num_speech_codebooks, context audio frame len + 1 pad + answer frame len + 1 SEP). + dec_labels_len, # tensor, context audio frame len + 1 PAD + answer frame len + 1 SEP. 1 corresponds to SEP id. + is_speech, # True + cross_attention_prior, # tensor, shape=(dec_labels_len, context_tokens_len + question_tokens_len + virtual_tokens_len). + lang.value, # int, + question_text, # str, answer transcript without question type (Phoneme TTS or Text to speech this). + ) + + def _truncate_input_speech(self, context_tokens, question_tokens, virtual_tokens): + total_len = self._get_len(context_tokens, question_tokens, virtual_tokens) + context_len = self._get_element_len(context_tokens) + truncation_length = total_len - self.max_seq_length + 1 + context_tokens[0] = context_tokens[0][:, min(truncation_length, context_len) :] + return context_tokens, question_tokens, virtual_tokens + + def list_to_tensor(self, element, fill=False): + """ + Convert list to tensor. The list might contain integers, 2D-tensors (speech tokens) and combination of two. + If all of them are ints, simply convert to tensor + If combination of 2D-tensor and ints. Convert int to the dimension of the tensor. 
+ example: [2, 4, 5] -> torch.tensor([2, 4, 5]) + example: [2, torch.tensor([[4, 5, 6], [6, 7, 8]])] -> torch.tensor( [[-1, 4, 5, 6], [2, 6, 7, 8]] ) + """ + ret, ln = None, None + if element is None: + return ret, ln + + max_len = max([1 if isinstance(item, int) else len(item) for item in element]) + if max_len == 1: + ret = torch.as_tensor(element).long() + ln = torch.tensor(ret.size()[0]).long() + else: + ret = [] + for e in element: + if isinstance(e, int): + tmp = torch.full((self.num_speech_codebooks, 1), e if fill else -1) + tmp[self.num_speech_codebooks - 1] = e + if self.add_special_tokens_to_only_first_codebook: + # Fill zeros in all other codebooks (to avoid out of range when getting embeddings) + tmp[1:] = 0 + else: + tmp = e + ret.append(tmp) + ret = torch.cat(ret, dim=1) + ln = torch.tensor(ret.size()[1]).long() + return ret, ln + + def _get_text_tokens(self, text): + input_ids = self.tokenizer.text_to_ids(text) + return input_ids + + def _get_phoneme_tokens(self, text, lang="en"): + if self.english_only_model: + input_ids = self.phoneme_tokenizer.encode(text) + input_ids_adjusted = [_id + self.lm_vocab_size for _id in input_ids] + return input_ids_adjusted + else: + text = any_locale_text_preprocessing(text) + input_ids = self.g2p[lang](text) + input_ids_adjusted = [] + for i in input_ids: + input_ids_adjusted.append(f"p{{{i}}}") + input_ids_adjusted = self.tokenizer.text_to_ids("".join(input_ids_adjusted)) + return input_ids_adjusted + + def _pad_wav_to_multiple(self, wav): + if self.pad_multiple > 1: + if wav.shape[0] % self.pad_multiple != 0: + wav = torch.cat( + [wav, torch.zeros(self.pad_multiple - wav.shape[0] % self.pad_multiple, dtype=torch.float)] + ) + return wav + + def _get_element_len(self, element): + length = 0 + if isinstance(element, list): + for e in element: + if isinstance(e, int): + length += 1 + else: + if e.dim() > 1: + length += e.size()[1] + else: + length += e.size()[0] + else: + if element.dim() > 1: + length += element.size()[1] + else: + length += element.size()[0] + return length + + def _get_len(self, context_tokens, question_tokens, virtual_tokens): + length = 0 + length += self._get_element_len(context_tokens) + length += self._get_element_len(question_tokens) + length += self._get_element_len(virtual_tokens) + return length + + def _load_audio(self, audio_filepath, dur=-1): + if self.segment_max_duration is not None and dur > 0 and dur > self.segment_max_duration: + # this case has been added for segmenting audio for speaker verification task of SSLDisentangler + n_segments = int(self.segment_max_duration * self.sample_rate) + features = AudioSegment.segment_from_file( + audio_filepath, target_sr=self.sample_rate, n_segments=n_segments, trim=self.trim + ) + + features = torch.tensor(features.samples) + if self.pad_multiple > 1: + features = self._pad_wav_to_multiple(features) + audio, audio_length = features, torch.tensor(features.shape[0]).long() + else: + features = self.featurizer.process( + audio_filepath, + trim=self.trim, + trim_ref=self.trim_ref, + trim_top_db=self.trim_top_db, + trim_frame_length=self.trim_frame_length, + trim_hop_length=self.trim_hop_length, + ) + + if self.pad_multiple > 1: + features = self._pad_wav_to_multiple(features) + + audio, audio_length = features, torch.tensor(features.shape[0]).long() + + return audio, audio_length + + def convert_audio(self, audio, sample_rate, target_sample_rate, target_channels): + if audio.dim() == 1: + audio = audio.unsqueeze(0).unsqueeze(0) + assert audio.shape[1] in [1, 2], "Audio 
must be mono or stereo."
+        # assert sample_rate == target_sample_rate, "sample rate of FastPitch and Encodec model has to be same"
+        if target_channels == 2:
+            *shape, _, length = audio.shape
+            audio = audio.expand(*shape, target_channels, length)
+        return audio
+
+    def get_codec(self, audio):
+        wav1 = self.convert_audio(audio, self.sample_rate, self.encodec_model.sample_rate, self.encodec_model.channels)
+        encoded_frames = self.encodec_model.encode(wav1)
+        codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
+        return codes.squeeze(0)
+
+    def get_quantizer_codebook(self, reference_codec, reference_codec_length):
+        # Sum the dequantized embeddings from each codebook to reconstruct the continuous latent (128 x T after squeeze).
+        out = torch.zeros((1, 128, reference_codec_length.item()))
+        for i in range(reference_codec.size()[0]):
+            out += self.encodec_model.quantizer.vq.layers[i].decode(reference_codec[i, :].unsqueeze(0))
+        return out.squeeze(0)
+
+    def _get_speech_tokens(self, audio_filepath, dur=-1):
+        # Let's keep audio name and all internal directories in rel_audio_path_as_text_id to avoid any collisions
+        rel_audio_path = Path(audio_filepath).relative_to(self.base_data_dir).with_suffix("")
+        rel_audio_path_as_text_id = str(rel_audio_path).replace("/", "_")
+
+        # Load audio features
+        audio, audio_length = self._load_audio(audio_filepath, dur)
+
+        # Convert to codes
+        codec_codes, codec_codes_length = None, None  # Codes
+        codec_path = self.codec_folder / f"{rel_audio_path_as_text_id}.pt"
+
+        if codec_path.exists():
+            try:
+                codec_codes = torch.load(codec_path).long()
+            except Exception as e:
+                print(f"[ERROR IN LOADING {codec_path}] {e}")
+                codec_codes = self.get_codec(audio).long()
+                torch.save(codec_codes, codec_path)
+        else:
+            codec_codes = self.get_codec(audio).long()
+            torch.save(codec_codes, codec_path)
+
+        codec_codes_length = torch.tensor(codec_codes.shape[1]).long()
+
+        # Convert codes to codes corresponding to megatron embedding layer
+        codec_codes[0] = (codec_codes[0] + self.speech_offset).long()
+
+        return codec_codes
+
+    def _get_tokens(self, doc, field, field_data):
+        if self.context_slice_method == "random":
+            # During training, we want a random slice of the context
+            rng = random.Random()  # Custom random generator (since random uses fixed seeds)
+        elif self.context_slice_method == "fixed":
+            # During inference, we want a fixed slice of the context
+            rng = random
+        else:
+            raise ValueError(f"Invalid context_slice_method {self.context_slice_method}")
+        if f"{field}_type" not in doc.keys():
+            field_tokens = self._get_text_tokens(field_data.strip(" "))  # list of ids
+        elif doc[f"{field}_type"] == 'TEXT':
+            _text = field_data.strip(" ")
+            if _text.startswith("Phoneme TTS"):
+                lang = doc.get("lang", "en")
+                instruction_tokens = self._get_text_tokens("Phoneme TTS")
+                field_tokens = self._get_phoneme_tokens(_text[len("Phoneme TTS"):].strip(), lang=lang)
+                field_tokens = instruction_tokens + field_tokens
+            elif _text.startswith("Edit Speech"):
+                # Always use phoneme tokenizer for edit speech
+                instruction_tokens = self._get_text_tokens("Edit Speech")
+                field_tokens = self._get_phoneme_tokens(_text[len("Edit Speech"):].strip())
+                field_tokens = instruction_tokens + field_tokens
+            elif _text.startswith("TEXT CONTEXT:"):
+                # Speaker id conditioning
+                field_tokens = self._get_text_tokens(_text)
+                # pad field tokens to fixed length
+                # assert self.context_duration_min == self.context_duration_max, "TEXT CONTEXT only supports fixed context duration"
+                # To keep context length the same for audio or text context
+                # _fixed_context_len = int(self.context_duration_min * self.codebook_fps)
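+                # Text contexts (speaker id strings) are not padded to the fixed audio-context length here; they are simply tokenized and terminated with EOS below.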
+ field_tokens = field_tokens + [self.tokenizer.eos_id] + else: + # if starts with Text to speech this + field_tokens = self._get_text_tokens(field_data.strip(" ")) # list of ids + elif doc[f"{field}_type"] == 'SPEECH': + dur = -1 + if f"{field}_duration" in doc: + dur = doc[f"{field}_duration"] + field_tokens = self._get_speech_tokens(field_data, dur) # list of ids + if not isinstance(field_tokens, list): + field_tokens = [field_tokens] + elif doc[f"{field}_type"] == 'AUDIOCODEC': + reference_codec_paths = field_data.split(";") + reference_codec_path = rng.choice(reference_codec_paths) + if self.codec_folder is not None: + reference_codec_path = self.codec_folder / reference_codec_path + field_tokens = torch.load(reference_codec_path).long() + field_tokens[0] = (field_tokens[0] + self.speech_offset).long() + field_tokens = [field_tokens] + # print("AUDIOCODEC", field_tokens.shape) + elif doc[f"{field}_type"] == 'REFSPEAKERCODEC': + reference_codec_paths = field_data.split(";") + reference_codec_path = rng.choice(reference_codec_paths) + if self.codec_folder is not None: + reference_codec_path = self.codec_folder / reference_codec_path + field_tokens = torch.load(reference_codec_path).long() + field_tokens[0] = (field_tokens[0] + self.speech_offset).long() + _min_len = int(self.context_duration_min * self.codebook_fps) + _max_len = int(self.context_duration_max * self.codebook_fps) + reference_codec_len = rng.randint(_min_len, _max_len) + reference_codec_len = min(reference_codec_len, field_tokens.shape[1]) + si = rng.randint(0, field_tokens.shape[1] - reference_codec_len) + field_tokens = field_tokens[:, si : si + reference_codec_len] + if self.context_pattern == "delay_parallel": + field_tokens = torch.cat( + [ + torch.zeros(self.num_speech_codebooks, self.num_speech_codebooks).long(), + field_tokens, + torch.zeros(self.num_speech_codebooks, self.num_speech_codebooks).long(), + ], + dim=1, + ) + new_field_tokens = [] + for _c in range(self.num_speech_codebooks): + st = self.num_speech_codebooks - _c + et = field_tokens.shape[1] - _c + new_field_tokens.append(field_tokens[_c, st:et]) + field_tokens = torch.stack(new_field_tokens, dim=0) + field_tokens = [field_tokens] + elif doc[f"{field}_type"] == 'DUMMYCONTEXT': + field_tokens = torch.zeros(self.num_speech_codebooks, 1).long() + return [field_tokens] + elif doc[f"{field}_type"] == 'CONTEXTANSWER': + # Both Context and Answer are in the field + context_info, answer_codec_path = field_data.split(";") + if self.codec_folder is not None: + context_codec_path = self.codec_folder / context_info + answer_codec_path = self.codec_folder / answer_codec_path + if context_info.startswith("TEXT CONTEXT:"): + context_tokens = self._get_text_tokens(context_info.strip(" ")) + # pad field tokens to fixed length + assert self.context_duration_min == self.context_duration_max, "TEXT CONTEXT only supports fixed context duration" + _fixed_context_len = int(self.context_duration_min * self.codebook_fps) + context_tokens = context_tokens + [self.tokenizer.pad_id] * (_fixed_context_len - len(context_tokens)) + + answer_tokens = torch.load(answer_codec_path).long() + answer_tokens[0] = (answer_tokens[0] + self.speech_offset).long() + field_tokens = context_tokens + [self.tokenizer.pad_id] + [answer_tokens] + else: + context_tokens = torch.load(context_codec_path).long() + context_tokens[0] = (context_tokens[0] + self.speech_offset).long() + assert self.context_duration_min == self.context_duration_max, "CONTEXTANSWER only supports fixed context duration" + 
reference_codec_len = int(self.context_duration_min * self.codebook_fps) + if context_tokens.shape[1] < reference_codec_len: + # Repeat the context to match the reference_codec_len + context_tokens = torch.cat([context_tokens] * (reference_codec_len // context_tokens.shape[1] + 1), dim=1) + assert context_tokens.shape[1] >= reference_codec_len, "CONTEXTANSWER context duration is less than min duration {} {} {}".format(context_tokens.shape[1], reference_codec_len, context_codec_path) + si = rng.randint(0, context_tokens.shape[1] - reference_codec_len) + context_tokens = context_tokens[:, si:si+reference_codec_len] + + answer_tokens = torch.load(answer_codec_path).long() + answer_tokens[0] = (answer_tokens[0] + self.speech_offset).long() + pad_tokens = torch.zeros(self.num_speech_codebooks, 1).long() + # padding between context and answer + field_tokens = torch.cat([context_tokens, pad_tokens, answer_tokens], dim=1) + field_tokens = [field_tokens] + elif doc[f"{field}_type"] == 'SEPARATIONCODECS': + mixed_codec_path, reference_codec_paths = field_data.split(",") + reference_codec_paths = reference_codec_paths.split(";") + reference_codec_path = rng.choice(reference_codec_paths) + mixed_codec = torch.load(mixed_codec_path).long() + reference_codec = torch.load(reference_codec_path).long() + reference_codec_len = rng.randint(240, 400) + reference_codec = reference_codec[:, :reference_codec_len] + # MIXED AUDIO AND REF AUDIO ARE SEPARATED BY 8 TIMESTEPS OF 1023 TOKENS IN ALL CODEBOOKS + mask_tokens = (torch.ones(self.num_speech_codebooks, self.num_speech_codebooks) * 1023).long() + field_tokens = torch.cat([mixed_codec, mask_tokens, reference_codec], dim=1) + field_tokens[0] = (field_tokens[0] + self.speech_offset).long() + field_tokens = [field_tokens] + elif doc[f"{field}_type"] == 'EDITINGCODECS': + reference_audio_path = field_data + reference_codec = torch.load(reference_audio_path).long() + assert reference_codec.shape[1] > 80 # ensure reference audio is atleast 1 second + mask_len = rng.randint(40, 320) # ~0.5 second to 4 seconds + mask_len = min(mask_len, reference_codec.shape[1] - 80) + mask_start = rng.randint(0, reference_codec.shape[1] - mask_len) + mask_end = mask_start + mask_len + mask_tokens = (torch.ones(self.num_speech_codebooks, self.num_speech_codebooks) * 1023).long() + seg1 = reference_codec[:, :mask_start] + seg2 = reference_codec[:, mask_end:] + field_tokens = torch.cat([seg1, mask_tokens, seg2], dim=1) + # MISSING AUDIO IS REPLACED WITH 8 TIMESTEPS OF 1023 TOKENS IN ALL CODEBOOKS + field_tokens[0] = (field_tokens[0] + self.speech_offset).long() + field_tokens = [field_tokens] + else: + raise Exception(f"{field}_type not recognized") + return field_tokens + + def _insert_data_in_template(self, prompt_template_fields, doc, answer_field): + """ Format the input example according to the template """ + out_dict = {} + for field in prompt_template_fields: + # discard the last one, {label} / {answer} + # Or if some fields from the template aren't present, e.g. 
{answer} during inference + # just remove that field from the template, leaving the space blank + if field == answer_field or field not in doc.keys(): + continue + # out_dict[field] = "" + + elif field in doc.keys(): + field_data = doc[field] + if f"{field}_type" not in doc.keys(): + doc[f"{field}_type"] = "TEXT" + raise Exception(f"{field}_type does not exist in doc") + else: + out_dict[field] = self._get_tokens(doc, field, field_data) + return out_dict + + def get_position_ids(self, virtual_token, context_and_qquestion): + enc_input = [] + enc_input.append(virtual_token) + if context_and_qquestion.dim() > 2: + enc_input.append(context_and_qquestion[:, 0, :]) + else: + enc_input.append(context_and_qquestion) + + enc_input = torch.cat(enc_input, dim=1) + + enc_input_p = enc_input[:, 0, :] if enc_input.dim() == 3 else enc_input + return build_position_ids(enc_input_p).contiguous() + + def collate_fn(self, batch): + """ Prepares enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids for global batch """ + + data_dict = self.pad_batch_and_build_loss_mask(batch) + + if self.encoder_type == "multi_transformer": + position_ids = [self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens'][0]), self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens'][1])] + else: + position_ids = self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens']) + + return ( + data_dict['virtual_tokens'], + data_dict['context_and_question_tokens'], + data_dict['enc_mask'], + data_dict['dec_input'], + data_dict['dec_input_mask'], + data_dict['dec_labels'], + data_dict['dec_labels_mask'], + position_ids, + data_dict['taskname_id'], + data_dict['speech_mask'], + data_dict['context_and_question_tokens_lens'], + data_dict['cross_attention_prior'], + data_dict['text_limits'], + data_dict['lang'], + data_dict['question_texts'], + ) + + def pad_batch_and_build_loss_mask(self, batch): + """ Pad enc_input, dec_input, labels in batch to max batch length while building loss_mask, enc_mask, and dec_mask """ + ( + taskname_ids, + _, + virtual_tokens_len, + _, + _, + context_and_question_tokens_len, + _, + dec_input_len, + _, + dec_labels_len, + _, + _, + _, + question_texts, + ) = zip(*batch) + + taskname_ids = self.pad_taskname_ids(taskname_ids) + + max_virtual_tokens_len = max(virtual_tokens_len).item() if virtual_tokens_len is not None else 0 + if isinstance(virtual_tokens_len, tuple): + virtual_tokens_len = torch.stack(virtual_tokens_len) + virtual_mask = get_mask_from_lengths(virtual_tokens_len) + + if self.encoder_type == "multi_transformer": + max_context_len = max(_c[0] for _c in context_and_question_tokens_len) if context_and_question_tokens_len is not None else 0 + max_question_len = max(_c[1] for _c in context_and_question_tokens_len) if context_and_question_tokens_len is not None else 0 + max_context_and_question_tokens_len = [max_context_len, max_question_len] + context_len = torch.stack([_c[0] for _c in context_and_question_tokens_len]) + question_len = torch.stack([_c[1] for _c in context_and_question_tokens_len]) + context_mask = get_mask_from_lengths(context_len) + question_mask = get_mask_from_lengths(question_len) + context_and_question_tokens_len = [context_len, question_len] + context_and_question_mask = [context_mask, question_mask] + enc_mask = [torch.cat([virtual_mask, context_and_question_mask[0]], dim=1), torch.cat([virtual_mask, context_and_question_mask[1]], dim=1)] + # 
import ipdb; ipdb.set_trace() + else: + max_context_and_question_tokens_len = ( + max(context_and_question_tokens_len).item() if context_and_question_tokens_len is not None else 0 + ) + if isinstance(context_and_question_tokens_len, tuple): + context_and_question_tokens_len = torch.stack(context_and_question_tokens_len) + context_and_question_mask = get_mask_from_lengths(context_and_question_tokens_len) + enc_mask = torch.cat([virtual_mask, context_and_question_mask], dim=1) + + max_dec_input_len = max(dec_input_len).item() if dec_input_len is not None else 0 + max_dec_labels_len = max(dec_labels_len).item() if dec_labels_len is not None else 0 + + ( + virtual_tokens_list, + context_question_tokens_list, + dec_input_list, + dec_input_mask_list, + dec_labels_list, + dec_labels_mask_list, + speech_mask_list, + cross_attention_prior_list, + text_limits, + lang_list, + ) = ( + [], + [], + [], + [], + [], + [], + [], + [], + [], + [], + ) + + for i, sample_tuple in enumerate(batch): + ( + _, + virtual_token, + virtual_token_len, + context_token_len, + context_and_question_token, + context_and_question_token_len, + dec_input, + dec_input_len, + dec_label, + dec_label_len, + is_speech, + cross_attention_prior, + lang, + _, + ) = sample_tuple + + virtual_tokens_list.append( + general_padding( + virtual_token, virtual_token_len.item(), max_virtual_tokens_len, pad_value=self.tokenizer.pad_id + ) + ) + + if self.encoder_type == "multi_transformer": + context_tokens_padded = general_padding( + context_and_question_token[0], + context_and_question_token_len[0].item(), + max_context_and_question_tokens_len[0], + pad_value=self.tokenizer.pad_id, + ) + if len(context_tokens_padded.shape) < 2: + context_tokens_padded = pad_text_to_speech_dims( + context_tokens_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + question_tokens_padded = general_padding( + context_and_question_token[1], + context_and_question_token_len[1].item(), + max_context_and_question_tokens_len[1], + pad_value=self.tokenizer.pad_id, + ) + if len(question_tokens_padded.shape) < 2: + question_tokens_padded = pad_text_to_speech_dims( + question_tokens_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + context_question_tokens_list.append([context_tokens_padded, question_tokens_padded]) + else: + # This means context and questions are concatenated together + context_tokens_padded = general_padding( + context_and_question_token, + context_and_question_token_len.item(), + max_context_and_question_tokens_len, + pad_value=self.tokenizer.pad_id, + ) + if len(context_tokens_padded.shape) < 2: + context_tokens_padded = pad_text_to_speech_dims( + context_tokens_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + context_question_tokens_list.append(context_tokens_padded) + + if max_dec_input_len > 0: + dec_input_padded = general_padding( + dec_input, dec_input_len.item(), max_dec_input_len, pad_value=self.tokenizer.pad_id + ) + if len(dec_input_padded.shape) < 2: + dec_input_padded = pad_text_to_speech_dims( + dec_input_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + dec_input_list.append(dec_input_padded) + dec_mask = ( + torch.as_tensor(([1] * dec_input_len) + ([0] * (max_dec_input_len - dec_input_len))) + .long() + .contiguous() + ) + dec_input_mask_list.append(dec_mask) + speech_mask = dec_mask if is_speech else torch.zeros(dec_mask.shape) + speech_mask_list.append(speech_mask) + + if max_dec_labels_len > 0: + loss_mask = ( + torch.as_tensor(([1] * dec_label_len) + ([0] * 
(max_dec_labels_len - dec_label_len))) + .long() + .contiguous() + ) + dec_label_padded = general_padding( + dec_label, dec_label_len.item(), max_dec_labels_len, pad_value=self.tokenizer.pad_id + ) + if len(dec_label_padded.shape) < 2: + dec_label_padded = pad_text_to_speech_dims( + dec_label_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + dec_labels_list.append(dec_label_padded) + dec_labels_mask_list.append(loss_mask) + + _p0 = max_dec_labels_len - dec_label_len + if self.encoder_type == "multi_transformer": + _p1 = ( + max_virtual_tokens_len + + max_context_and_question_tokens_len[1] + - context_and_question_token_len[1] + - virtual_token_len + ) + else: + _p1 = ( + max_virtual_tokens_len + + max_context_and_question_tokens_len + - context_and_question_token_len + - virtual_token_len + ) + + cross_attention_prior_padded = torch.nn.functional.pad( + cross_attention_prior, pad=(0, _p1, 0, _p0), mode="constant", value=1, + ) + cross_attention_prior_list.append(cross_attention_prior_padded) + + if self.encoder_type == "multi_transformer": + _start_of_text_id = virtual_token_len + 4 + _end_of_text_id = _start_of_text_id + ( + context_and_question_token_len[1] - 2 - 4 + ) # -2 for some end tokens + else: + _start_of_text_id = virtual_token_len + context_token_len + 4 + _end_of_text_id = _start_of_text_id + ( + context_and_question_token_len - context_token_len - 2 - 4 + ) # -2 for some end tokens + text_limits.append(torch.tensor([_start_of_text_id.item(), _end_of_text_id.item()])) + lang_list.append(torch.tensor(lang)) + + dec_labels_mask = torch.stack(dec_labels_mask_list) if len(dec_labels_mask_list) > 0 else None + if dec_labels_mask is not None and self.context_conditioning == 'decoder': + # Mask out context tokens from loss computation. +1 for bos/pad in the beginning + dec_labels_mask[:,:self.decoder_context_len + 1] = 0 + + if self.encoder_type == "multi_transformer": + context_batch = torch.stack([c[0] for c in context_question_tokens_list]) + question_batch = torch.stack([c[1] for c in context_question_tokens_list]) + context_and_question_tokens = [context_batch, question_batch] + else: + context_and_question_tokens = torch.stack(context_question_tokens_list) + + data_dict = { + "taskname_id": taskname_ids, + "virtual_tokens": torch.stack(virtual_tokens_list), + "context_and_question_tokens": context_and_question_tokens, + "enc_mask": enc_mask, + "dec_input": torch.stack(dec_input_list) if len(dec_input_list) > 0 else None, + "dec_input_mask": torch.stack(dec_input_mask_list) if len(dec_input_mask_list) > 0 else None, + "dec_labels": torch.stack(dec_labels_list) if len(dec_labels_list) > 0 else None, + "dec_labels_mask": dec_labels_mask, + "speech_mask": torch.stack(speech_mask_list) if len(speech_mask_list) > 0 else None, + "context_and_question_tokens_lens": context_and_question_tokens_len, + "cross_attention_prior": torch.stack(cross_attention_prior_list) + if len(cross_attention_prior_list) > 0 + else None, + "text_limits": torch.stack(text_limits) if len(text_limits) > 0 else None, # tensor, valid range of answer transcripts without virtual/instruction/end tokens. 
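+            # "lang": per-sample language ids (stacked below); "question_texts": raw question transcripts, passed through as strings without padding.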
+ "lang": torch.stack(lang_list), + "question_texts": question_texts, + } + + return data_dict + + +class GPTSpeechLMDataset(T5SpeechLMDataset): + def __init__(self, *args, **kwargs): + kwargs["transformer_type"] = "GPT" + super().__init__(*args, **kwargs) + + def __getitem__(self, idx): + doc = self.examples[idx] + taskname = doc["taskname"] + prompt_template = self.task_templates[taskname]["prompt_template"] + prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"] + total_virtual_tokens = self.task_templates[taskname]["total_virtual_tokens"] + virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"] + truncation_field = self.task_templates[taskname]['truncate_field'] + answer_field = self.task_templates[taskname]["answer_field"] + + input_example = prompt_template + + self._input_sanity_checks( + total_virtual_tokens=total_virtual_tokens, + virtual_token_splits=virtual_token_splits, + prompt_template=prompt_template, + prompt_template_fields=prompt_template_fields, + truncation_field=truncation_field, + answer_field=answer_field, + doc=doc, + ) + question_in_manifest = doc['question'] + + # Format the input example according to the template + # Get context, question and answer codes in a dict. + input_dict = self._insert_data_in_template(prompt_template_fields, doc, answer_field) + context_tokens = input_dict['context'] + question_tokens = input_dict['question'] + + # Logic to prune context + # In case of TTS task, the entire reference speech is not required, so we randomly select a portion + # of the reference audio. + # In case of Next token prediction, We want context[:T] to go in the encoder and context[T+1:] to be + # predicted by the decoder. + start_token_index = 0 + end_token_index = -1 + if ("Text to speech this" in question_in_manifest or "Phoneme TTS" in question_in_manifest) and ( + doc["context_type"] == "SPEECH" + ): + total_context_len = context_tokens[0].size()[1] + + # Redo of this logic 11/29 + # logging.debug(f"total_context_len: {total_context_len}") + context_3s = 3 * self.codebook_fps + if total_context_len > context_3s: + start_token_index = random.randint(0, total_context_len - context_3s) + # logging.debug(f"start_token_index: {start_token_index}") + end_token_index = start_token_index + min(context_3s, total_context_len) + # logging.debug(f"end_token_index: {end_token_index}") + context_tokens[0] = context_tokens[0][:, start_token_index:end_token_index] + # logging.debug(f"context_tokens: {context_tokens[0].shape}") + + # Get virtual tokens + virtual_tokens = self._insert_virtual_token_placeholders(input_example.split(' ')[0], virtual_token_splits) + + # a trick to align with the data format in t5 pretraining + virtual_tokens = self.tokenizer.text_to_ids(virtual_tokens) + if self.add_sentinel_to_input: + question_tokens = question_tokens + self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) + + # Add BOS/EOS to the input of encoder if desired, adds EOS by default + if self.ul2_prompt_token is not None: + ul2_prompt_token_id = self.tokenizer.text_to_ids(self.ul2_prompt_token) + assert len(ul2_prompt_token_id) == 1 + context_tokens = ul2_prompt_token_id + context_tokens + if self.add_bos: + context_tokens = [self.tokenizer.bos_id] + context_tokens + if self.add_eos: + # question_tokens = question_tokens + [self.tokenizer.eos_id] + question_tokens = [self.tokenizer.pad_id] + question_tokens + [self.tokenizer.pad_id] + + virtual_tokens, virtual_tokens_len = self.list_to_tensor(virtual_tokens) + context_tokens, 
context_tokens_len = self.list_to_tensor(context_tokens) + question_tokens, question_tokens_len = self.list_to_tensor(question_tokens) + + if doc["question_type"] == "TEXT" and doc["context_type"] != "TEXT": + question_tokens = pad_text_to_speech_dims( + question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + if doc["context_type"] == "TEXT" and doc["question_type"] != "TEXT": + context_tokens = pad_text_to_speech_dims( + context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + if doc["context_type"] == "TEXT" and doc["question_type"] == "TEXT": + context_tokens = pad_text_to_speech_dims( + context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + question_tokens = pad_text_to_speech_dims( + question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + + # get answer ids + if answer_field in doc.keys(): # training and validation + answer_ids = self._get_tokens(doc, answer_field, doc[answer_field]) + answer_text_ids = answer_ids + + if self.add_eos_to_decoder_output: + answer_text_ids += [self.tokenizer.eos_id] + else: + answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.END.value) + + if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: + taskname_id = self.tokenizer.text_to_ids(taskname) + elif ( + self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT + ): + taskname_id = -1 + else: + raise ValueError("Invalid virtual prompt source specified") + + input_ids = answer_text_ids + + input_ids, input_ids_len = self.list_to_tensor(input_ids, True) + is_speech = True if doc["answer_type"] != "TEXT" else False + if is_speech: + assert input_ids.dim() == 2 + if self.seq_pattern == "delay_parallel": + + num_codebooks = input_ids.shape[0] + dinput_ids_padded = torch.cat( + [ + torch.zeros_like(input_ids[:, 0:num_codebooks]), + input_ids, + torch.zeros_like(input_ids[:, 0:num_codebooks]), + ], + dim=1, + ) + dec_input_new = [] + for _c in range(self.num_speech_codebooks): + st = num_codebooks - _c + et_decoder_input = dinput_ids_padded.shape[1] - _c - 1 + dec_input_new.append(dinput_ids_padded[_c, st:et_decoder_input]) + input_ids = torch.stack(dec_input_new, dim=0) + input_ids_len = torch.tensor(input_ids.shape[1]).long() + + # logging.debug( + # f"Return from getitem: \ncontext_tokens:{context_tokens.shape}\ncontext_tokens_len:{context_tokens_len}\n" + # f"question_tokens:{question_tokens.shape}\nquestion_tokens_len:{question_tokens_len}\ninput_ids:{input_ids.shape}\ninput_ids_len{input_ids_len}" + # ) + return ( + context_tokens, + context_tokens_len, + question_tokens, + question_tokens_len, + input_ids, + input_ids_len, + ) + + def collate_fn(self, batch): + (_, context_tokens_len, _, question_tokens_len, _, input_ids_len,) = zip(*batch) + + decoder_input_len = ( + torch.stack(context_tokens_len) + torch.stack(question_tokens_len) + torch.stack(input_ids_len) + ) + max_decoder_input_len = max(decoder_input_len).item() if decoder_input_len is not None else 0 + max_decoder_input_len_1 = max_decoder_input_len - 1 + + decoder_mask = get_mask_from_lengths(decoder_input_len - 1) + speech_mask = get_mask_from_lengths(decoder_input_len - 1) + context_question_mask = torch.ones(speech_mask.shape) + (decoder_input_list, decoder_labels_list,) = ( + [], + [], + ) + cross_attention_prior = torch.zeros(len(batch), max_decoder_input_len_1, max_decoder_input_len_1) + start_of_question_offset = 5 # For "Text to Speech this" - Only used in attention prior computation + end_of_question_offset = 3 # "" - Only used in 
attention prior computation + for i, sample_tuple in enumerate(batch): + ( + context_tokens, + context_tokens_len, + question_tokens, + question_tokens_len, + input_ids, + input_ids_len, + ) = sample_tuple + + context_tokens_input = context_tokens.clone().contiguous().detach() + for l in range(1, context_tokens_input.shape[0]): + context_tokens_input[l] += self.speech_offset + 1024 * l # TODO: fix hardcode + input_ids_shifted = input_ids.clone().contiguous().detach() + for l in range(1, input_ids_shifted.shape[0]): + input_ids_shifted[l] += self.speech_offset + 1024 * l # TODO: fix hardcode + + complete_input = torch.cat([context_tokens_input, question_tokens, input_ids_shifted], dim=1) + complete_input_padded = general_padding( + complete_input, decoder_input_len[i].item(), max_decoder_input_len, pad_value=self.tokenizer.pad_id, + ) + complete_output = torch.cat([context_tokens, question_tokens, input_ids], dim=1) + complete_output_padded = general_padding( + complete_output, decoder_input_len[i].item(), max_decoder_input_len, pad_value=self.tokenizer.pad_id, + ) + decoder_labels = complete_output_padded[:, 1:].contiguous() + decoder_input = complete_input_padded[:, :-1].contiguous() + + decoder_input_list.append(decoder_input) + decoder_labels_list.append(decoder_labels) + + decoder_mask[i, : context_tokens_len + question_tokens_len - 1] = 0 # Mask out context and question + # TODO: jasoli, the speech_mask looks wrong. I shouldn't be masking out the context + speech_mask[ + i, context_tokens_len : context_tokens_len + question_tokens_len + ] = 0 # Mask out context and question + context_question_mask[i, : context_tokens_len + question_tokens_len] = 0 + + if self.spec_aug: + # Derive time width, sometimes based percentage of input length. + time_max_width = max(1, int(input_ids_len.item() * self.time_width)) + time_start_upper_bound = max(1, input_ids_len.item() - time_max_width) + time_start = context_tokens_len.item() + question_tokens_len.item() + time_start_upper_bound += time_start + + # Set time masking + for _ in range(self.time_masks): + start = self._rng.randint(time_start, time_start_upper_bound) + width = self._rng.randint(0, time_max_width) + speech_mask[i, start : start + width] = 0 + + if self.use_attention_prior: + cross_attention_question_prior = torch.from_numpy( + beta_binomial_prior_distribution( + question_tokens_len.item() - start_of_question_offset - end_of_question_offset, + input_ids_len.item() - 1, + scaling_factor=self.attention_prior_scaling_factor, + ) + ) + cross_attention_prior[ + i, + context_tokens_len + + question_tokens_len : context_tokens_len + + question_tokens_len + + input_ids_len + - 1, + context_tokens_len + + start_of_question_offset : context_tokens_len + + question_tokens_len + - end_of_question_offset, + ] = cross_attention_question_prior + # Using causal attention mask for whole input + batch_size = len(decoder_input_list) + attention_mask = torch.tril(torch.ones((batch_size, max_decoder_input_len_1, max_decoder_input_len_1))).view( + batch_size, 1, max_decoder_input_len_1, max_decoder_input_len_1 + ) + + # Convert attention mask from float to bool + attention_mask = attention_mask < 0.5 # Currently not used, not sure if correct either + + decoder_input = torch.stack(decoder_input_list) + decoder_input_p = decoder_input[:, 0, :] if decoder_input.dim() == 3 else decoder_input + position_ids = build_position_ids(decoder_input_p) + data_dict = { + "tokens": decoder_input, + "position_ids": position_ids, + "attention_mask": attention_mask, + 
"labels": torch.stack(decoder_labels_list), + "speech_mask": speech_mask, # For TTS, can just be loss_mask since answer will always be speech + "loss_mask": decoder_mask, # Mask out context and question and padding + "attention_prior": cross_attention_prior, + "context_question_mask": context_question_mask, + } + + return data_dict diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py new file mode 100644 index 000000000000..7755b9d9bdbf --- /dev/null +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py @@ -0,0 +1,1212 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import os +import random +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import torch +import webdataset as wd +from omegaconf import OmegaConf + +from nemo.collections.asr.data.audio_to_text import ( + _speech_collate_fn, + cache_datastore_manifests, + expand_sharded_filepaths, + shard_manifests_if_needed, +) +from nemo.collections.common.parts.preprocessing import collections +from nemo.collections.nlp.models.language_modeling.megatron_t5_model import T5Sentinel +from nemo.collections.nlp.modules.common import VirtualPromptSource +from nemo.collections.nlp.modules.common.megatron.utils import build_position_ids +from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths +from nemo.collections.tts.parts.utils.tts_dataset_utils import ( + beta_binomial_prior_distribution, + general_padding, +) +from nemo.core.classes import IterableDataset +from nemo.utils import logging + +__all__ = ['T5SpeechLMTarredDataset', 'GPTSpeechLMTarredDataset'] + + +@dataclass +class G2PConfig: + _target_: str = "nemo.collections.tts.g2p.models.en_us_arpabet.EnglishG2p" + phoneme_dict: str = "scripts/tts_dataset_files/cmudict-0.7b_nv22.10" + heteronyms: str = "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: float = 0.5 + + +@dataclass +class TextTokenizer: + _target_: str = "nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.EnglishPhonemesTokenizer" + punct: bool = True + stresses: bool = True + chars: bool = True + apostrophe: bool = True + pad_with_space: bool = True + add_blank_at: bool = True + g2p: G2PConfig = G2PConfig() + + +@dataclass +class TextTokenizerConfig: + text_tokenizer: TextTokenizer = TextTokenizer() + + +def _get_default_text_tokenizer_conf(): + text_tokenizer: TextTokenizerConfig = TextTokenizerConfig() + return OmegaConf.create(OmegaConf.to_yaml(text_tokenizer)) + + +def pad_text_to_speech_dims(text_tensor, pad_id): + token_len = text_tensor.shape[0] + empty_padding = torch.ones((7, token_len), dtype=text_tensor.dtype, device=text_tensor.device) * pad_id + return torch.cat((text_tensor.unsqueeze(0), empty_padding), dim=0) + + +# tokenizer_config = _get_default_text_tokenizer_conf() +# phoneme_tokenizer = 
instantiate(tokenizer_config).text_tokenizer + + +class InstructionTuningManifestProcessor: + """ + Class that processes a manifest json file containing paths to audio files, transcripts, and durations (in seconds). + Each new line is a different sample. Example below: + {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147} + ... + {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt": + "utterance_id", "ctm_utt": "en_4156", "side": "A"} + Args: + manifest_filepath: Path to manifest json as described above. Can be comma-separated paths. + parser: Str for a language specific preprocessor or a callable. + max_duration: If audio exceeds this length, do not include in dataset. + min_duration: If audio is less than this length, do not include in dataset. + max_utts: Limit number of utterances. + bos_id: Id of beginning of sequence symbol to append if not None. + eos_id: Id of end of sequence symbol to append if not None. + pad_id: Id of pad symbol. Defaults to 0. + """ + + def __init__( + self, + manifest_filepath: str, + max_duration: Optional[float] = None, + min_duration: Optional[float] = None, + max_seq_length: Optional[float] = None, + max_utts: int = 0, + index_by_file_id: bool = False, + decoder_only_model: bool = False, + use_phoneme_tokenizer: bool = False, + ): + + # ASRAudioText( + self.collection = collections.InstructionTuningAudioText( + manifests_files=manifest_filepath, + min_duration=min_duration, + max_duration=max_duration, + max_seq_length=max_seq_length, + max_number=max_utts, + index_by_file_id=index_by_file_id, + decoder_only_model=decoder_only_model, + use_phoneme_tokenizer=use_phoneme_tokenizer, + ) + + +class _TarredInstructionTuningDataset(IterableDataset): + """ + A similar Dataset to the AudioToCharDataset/AudioToBPEDataset, but which loads tarred audio files. 
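+    Codec tensors are read from WebDataset tar shards (an answer '.pt' tensor and a 'context.pt' tensor per utterance key),
+    filtered against the manifest collection, and yielded as (file_id, answer_bytes, context_bytes, offset_id) tuples
+    that _build_sample turns into training examples.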
+ """ + + def __init__( + self, + audio_tar_filepaths: Union[str, List[str]], + manifest_filepath: str, + sample_rate: int, + shuffle_n: int = 0, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + max_seq_length: Optional[float] = None, + shard_strategy: str = "scatter", + shard_manifests: bool = False, + global_rank: int = 0, + world_size: int = 0, + return_sample_id: bool = False, + decoder_only_model: bool = False, + use_phoneme_tokenizer: bool = False, + ): + self.shard_manifests = shard_manifests + + # Shard manifests if necessary and possible and then expand the paths + manifest_filepath = shard_manifests_if_needed( + shard_manifests=shard_manifests, + shard_strategy=shard_strategy, + manifest_filepaths=manifest_filepath, + world_size=world_size, + global_rank=global_rank, + ) + + # If necessary, cache manifests from object store + cache_datastore_manifests(manifest_filepaths=manifest_filepath) + + self.manifest_processor = InstructionTuningManifestProcessor( + manifest_filepath=manifest_filepath, + max_duration=max_duration, + min_duration=min_duration, + max_seq_length=max_seq_length, + max_utts=0, + index_by_file_id=True, # Must set this so the manifest lines can be indexed by file ID + decoder_only_model=decoder_only_model, + use_phoneme_tokenizer=use_phoneme_tokenizer, + ) + + self.len = self._compute_len() + self.return_sample_id = return_sample_id + + audio_tar_filepaths = expand_sharded_filepaths( + sharded_filepaths=audio_tar_filepaths, + shard_strategy=shard_strategy, + world_size=world_size, + global_rank=global_rank, + ) + + if shuffle_n > 0: + # Only shuffle training data tar files + logging.info("Shuffling Tar files") + custom_rng = random.Random() + custom_rng.shuffle(audio_tar_filepaths) + logging.info("Done shuffling Tar files") + logging.info(audio_tar_filepaths[:10]) + + self.sample_rate = sample_rate + + # Put together WebDataset + self._dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None) + + if shuffle_n > 0: + self._dataset = self._dataset.shuffle(shuffle_n) + else: + logging.info("WebDataset will not shuffle files within the tar files.") + + self._dataset = ( + self._dataset.rename(key='__key__', answer='pt', context='context.pt') + .to_tuple('key', 'answer', 'context') + .pipe(self._filter) + .pipe(self._loop_offsets) + .map(f=self._build_sample) + ) + + def _filter(self, iterator): + """This function is used to remove samples that have been filtered out by ASRAudioText already. + Otherwise, we would get a KeyError as _build_sample attempts to find the manifest entry for a sample + that was filtered out (e.g. for duration). + Note that if using multi-GPU training, filtering may lead to an imbalance in samples in each shard, + which may make your code hang as one process will finish before the other. + """ + + class TarredAudioFilter: + def __init__(self, collection): + self.iterator = iterator + self.collection = collection + + def __iter__(self): + return self + + def __next__(self): + while True: + audio_filename, answer_bytes, context_bytes = next(self.iterator) + file_id, _ = os.path.splitext(os.path.basename(audio_filename)) + if file_id in self.collection.mapping: + return audio_filename, answer_bytes, context_bytes + + return TarredAudioFilter(self.manifest_processor.collection) + + def _loop_offsets(self, iterator): + """This function is used to iterate through utterances with different offsets for each file. 
+ """ + + class TarredAudioLoopOffsets: + def __init__(self, collection): + self.iterator = iterator + self.collection = collection + self.current_fn = None + self.current_bytes = None + self.current_context_bytes = None + self.offset_id = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.current_fn is None: + self.current_fn, self.current_bytes, self.current_context_bytes = next(self.iterator) + self.offset_id = 0 + else: + offset_list = self.collection.mapping[self.current_fn] + if len(offset_list) == self.offset_id + 1: + self.current_fn, self.current_bytes, self.current_context_bytes = next(self.iterator) + self.offset_id = 0 + else: + self.offset_id += 1 + + return self.current_fn, self.current_bytes, self.current_context_bytes, self.offset_id + + return TarredAudioLoopOffsets(self.manifest_processor.collection) + + def _collate_fn(self, batch): + return _speech_collate_fn(batch) + + def _build_sample(self, tup): + """Builds the training sample by combining the data from the WebDataset with the manifest info. + """ + audio_filename, encodec, ref_encodec, offset_id = tup + return audio_filename, encodec, ref_encodec, offset_id + + def get_manifest_sample(self, sample_id): + return self.manifest_processor.collection[sample_id] + + def __iter__(self): + return self._dataset.__iter__() + + def _compute_len(self): + if self.shard_manifests and torch.distributed.is_available() and torch.distributed.is_initialized(): + my_len = torch.tensor(len(self.manifest_processor.collection), dtype=torch.int32).cuda() + torch.distributed.all_reduce(my_len) + my_len = my_len.int() + logging.info(f'Sharded manifests: Total length: {my_len}') + else: + my_len = len(self.manifest_processor.collection) + + return my_len + + def __len__(self): + return self.len + + +class T5SpeechLMTarredDataset(_TarredInstructionTuningDataset): + """ + The dataset class for prompt-tuning or p-tuning pretrained T5 SpeechLM models. + """ + + def __init__( + self, + audio_tar_filepaths: Union[str, List[str]], + manifest_filepath: str, + tokenizer, + virtual_prompt_source: VirtualPromptSource, + task_templates: dict, + pseudo_tokens, + pad_token_id: str, + max_seq_length: int, + sample_rate: int, + shuffle_n: int = 0, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + for_train: bool = True, + decoder_starts_with_pad: bool = False, + add_eos_to_decoder_output: bool = True, + add_sentinel_to_input: bool = True, + ul2_prompt_token: str = None, + segment_max_duration: Optional[int] = None, + trim: bool = False, + trim_ref: Optional[float] = None, + trim_top_db: Optional[int] = None, + trim_frame_length: Optional[int] = None, + trim_hop_length: Optional[int] = None, + pad_multiple: int = 1, + pitch_augment: bool = False, + speech_offset: Optional[int] = None, + train_task: Optional[str] = None, + seq_pattern: Optional[str] = "parallel", + shard_strategy: str = "scatter", + shard_manifests: bool = False, + global_rank: int = 0, + world_size: int = 0, + return_sample_id: bool = False, + decoder_only_model: bool = False, + use_phoneme_tokenizer: Optional[bool] = False, + lm_vocab_size: Optional[int] = None, + use_attention_prior: Optional[bool] = False, + attention_prior_scaling_factor: Optional[float] = 1.0, + cross_attention_epsilon: Optional[float] = 0.0, + num_speech_codebooks: Optional[int] = 8, + **kwargs, + ): + """ + Only speech parameters are explained here. 
+ segment_max_duration: Optional[int] = None, - Speech max segment duration + trim: bool = False, - speech parameter + trim_ref: Optional[float] = None, - speech parameter + trim_top_db: Optional[int] = None, - speech parameter + trim_frame_length: Optional[int] = None, - speech parameter + trim_hop_length: Optional[int] = None, - speech parameter + pad_multiple: int = 1, - speech parameter + pitch_augment: bool = False, - speech parameter + speech_offset: Optional[int] = None, - if speech tokens then add this offset to the token indices to distinguish between text and speech tokens. + **kwargs, + """ + # These two variables need to be set before calling super().__init__() because the parent class calls `load_data()` which requires these attributes. + self.decoder_starts_with_pad = decoder_starts_with_pad + self.add_eos_to_decoder_output = add_eos_to_decoder_output + self.add_sentinel_to_input = add_sentinel_to_input + self.ul2_prompt_token = ul2_prompt_token + # Speech related variables + # self.encodec_model = EncodecModel.encodec_model_24khz() + # self.encodec_model.set_target_bandwidth(6.0) + self.base_data_dir = None + self.segment_max_duration = segment_max_duration + self.sample_rate = sample_rate + # self.featurizer = WaveformFeaturizer(sample_rate=self.sample_rate) + self.pad_multiple = pad_multiple + self.pitch_augment = pitch_augment + self.trim = trim + self.trim_ref = trim_ref if trim_ref is not None else np.max + self.trim_top_db = trim_top_db if trim_top_db is not None else 60 + self.trim_frame_length = trim_frame_length if trim_frame_length is not None else 2048 + self.trim_hop_length = trim_hop_length if trim_hop_length is not None else 512 + self.speech_offset = speech_offset if speech_offset is not None else 3 + self.seq_pattern = seq_pattern + self.min_duration = kwargs.get('min_duration', 0.1) + self.max_duration = kwargs.get('max_duration', 20) + self.use_attention_prior = use_attention_prior + self.attention_prior_scaling_factor = attention_prior_scaling_factor + self.cross_attention_epsilon = cross_attention_epsilon # value of prior for context tokens (b/w 0 and 1) + assert self.cross_attention_epsilon >= 0.0 and self.cross_attention_epsilon <= 1.0 + + self.train_task = train_task + + # Initialized super part + self.tokenizer = tokenizer + self.virtual_prompt_source = virtual_prompt_source + self.task_templates = task_templates + self.pseudo_tokens = pseudo_tokens + self.pseudo_token_ids = set(self.tokenizer.tokens_to_ids(self.pseudo_tokens)) + self.pad_token_id = pad_token_id + self.max_seq_length = max_seq_length + self.min_seq_length = min_seq_length + self.add_bos = add_bos + self.add_eos = add_eos + self.for_train = for_train + self.use_phoneme_tokenizer = use_phoneme_tokenizer + self.examples = [] + self.lm_vocab_size = tokenizer.vocab_size if lm_vocab_size is None else lm_vocab_size + self.num_speech_codebooks = num_speech_codebooks + + assert self.min_seq_length <= max_seq_length, "Min sequence length should be less than or equal to max" + assert self.max_seq_length > 0, "Max sequence length should be greater than 0" + + self.context_length = kwargs.pop('context_length', None) # only used in gpt dataset atm + + logging.info("Loading and tokenizing dataset ... 
") + + super().__init__( + audio_tar_filepaths=audio_tar_filepaths, + manifest_filepath=manifest_filepath, + sample_rate=sample_rate, + shuffle_n=shuffle_n, + min_duration=self.min_duration, + max_duration=self.max_duration, + max_seq_length=max_seq_length, + shard_strategy=shard_strategy, + shard_manifests=shard_manifests, + global_rank=global_rank, + world_size=world_size, + return_sample_id=return_sample_id, + decoder_only_model=decoder_only_model, + use_phoneme_tokenizer=use_phoneme_tokenizer, + ) + + self.encodec, self.ref_encodec = None, None + + def _insert_virtual_token_placeholders(self, input_example, virtual_token_splits): + """ Insert the correct number of pseudo tokens at the <|VIRTUAL_PROMPT_n|> markers """ + total_inserted_tokens = 0 + + for idx in range(len(virtual_token_splits)): + split_start = total_inserted_tokens + split_end = total_inserted_tokens + virtual_token_splits[idx] + pseudo_tokens_for_split = "".join(self.pseudo_tokens[split_start:split_end]) + input_example = input_example.replace(f'<|VIRTUAL_PROMPT_{idx}|>', pseudo_tokens_for_split) + total_inserted_tokens = split_end + + return input_example + + def pad_taskname_ids(self, taskname_ids): + # Pad taskname_ids to be the same length for the prompt encoder + if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: + max_taskname_length = max(len(ids) for ids in taskname_ids) + taskname_ids = [ids + [self.pad_token_id] * (max_taskname_length - len(ids)) for ids in taskname_ids] + taskname_ids = torch.tensor(taskname_ids) + + # Task ids are just used for a look up embeddings for prompt-table + elif self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT: + taskname_ids = torch.tensor(taskname_ids) + + return taskname_ids + + def _build_sample(self, tup): + audio_filename, self.encodec, self.ref_encodec, offset_id = tup + + file_id, _ = os.path.splitext(os.path.basename(audio_filename)) + manifest_idx = self.manifest_processor.collection.mapping[file_id][offset_id] + manifest_entry = self.manifest_processor.collection[manifest_idx] + doc = {} + doc['context'] = manifest_entry.context + doc['context_type'] = manifest_entry.context_type + doc['context_duration'] = manifest_entry.context_duration + doc['answer'] = manifest_entry.answer + doc['answer_type'] = manifest_entry.answer_type + doc['answer_duration'] = manifest_entry.answer_duration + doc['question'] = manifest_entry.question + doc['question_type'] = manifest_entry.question_type + + taskname = "squad" + prompt_template = self.task_templates[taskname]["prompt_template"] + prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"] + total_virtual_tokens = self.task_templates[taskname]["total_virtual_tokens"] + virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"] + truncation_field = self.task_templates[taskname]['truncate_field'] + answer_field = self.task_templates[taskname]["answer_field"] + + input_example = prompt_template + + question_in_manifest = manifest_entry.question + + # Format the input example according to the template + # Get context, question and answer codes in a dict. + input_dict = self._insert_data_in_template(input_example, prompt_template_fields, doc, answer_field) + context_tokens = input_dict['context'] + question_tokens = input_dict['question'] + + # Logic to prune context + # In case of TTS task, the entire reference speech is not required, so we randomly select a portion + # of the reference audio. 
+ # In case of Next token prediction, We want context[:T] to go in the encoder and context[T+1:] to be + # predicted by the decoder. + start_token_index = 0 + end_token_index = -1 + if "Text to speech this" in question_in_manifest: + total_context_len = context_tokens[0].size()[1] + reduced_len = min( + 400, + int(total_context_len * 0.2) + if total_context_len > 600 + else int(total_context_len * random.uniform(0.2, 0.5)), + ) + start_token_index = random.randint( + 0, total_context_len - reduced_len + ) # start index can be greater than 440 + context_tokens[0] = context_tokens[0][ + :, start_token_index : min(start_token_index + 440, start_token_index + reduced_len) + ] + elif "Next token prediction" in question_in_manifest: + total_context_len = context_tokens[0].size()[1] + end_token_index = int(total_context_len * random.uniform(0.01, 0.2)) + context_tokens[0] = context_tokens[0][:, :end_token_index] + + # Get virtual tokens + virtual_tokens = self._insert_virtual_token_placeholders(input_example.split(' ')[0], virtual_token_splits) + + # a trick to align with the data format in t5 pretraining + # new + virtual_tokens = self.tokenizer.text_to_ids(virtual_tokens) + if self.add_sentinel_to_input: + question_tokens = question_tokens + self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) + + # Add BOS/EOS to the input of encoder if desired, adds EOS by default + if self.ul2_prompt_token is not None: + ul2_prompt_token_id = self.tokenizer.text_to_ids(self.ul2_prompt_token) + assert len(ul2_prompt_token_id) == 1 + context_tokens = ul2_prompt_token_id + context_tokens + if self.add_bos: + context_tokens = [self.tokenizer.bos_id] + context_tokens + if self.add_eos: + question_tokens = question_tokens + [self.tokenizer.eos_id] + + # Try to truncate input text to fit into the max sequence length + if self._get_len(context_tokens, question_tokens, virtual_tokens) > self.max_seq_length: + context_tokens, question_tokens, virtual_tokens = self._truncate_input_speech( + context_tokens, question_tokens, virtual_tokens + ) + + virtual_tokens, virtual_tokens_len = self.list_to_tensor(virtual_tokens) + context_tokens, context_tokens_len = self.list_to_tensor(context_tokens) + question_tokens, question_tokens_len = self.list_to_tensor(question_tokens) + + if doc["question_type"] != "SPEECH" and doc["context_type"] == "SPEECH": + question_tokens = pad_text_to_speech_dims(question_tokens, self.tokenizer.pad_id) + if doc["context_type"] != "SPEECH" and doc["question_type"] == "SPEECH": + context_tokens = pad_text_to_speech_dims(context_tokens, self.tokenizer.pad_id) + context_tokens = context_tokens.to(question_tokens.device) + context_and_question_tokens = torch.cat([context_tokens, question_tokens], dim=1) + + # get answer ids + if answer_field in doc.keys(): # training and validation + answer_ids = self._get_tokens(doc, answer_field, doc[answer_field]) + if end_token_index > -1: + answer_ids[0] = answer_ids[0][:, end_token_index:] + + if self.decoder_starts_with_pad: + answer_text_ids = [self.tokenizer.pad_id] + else: + answer_text_ids = [self.tokenizer.bos_id] + # a trick to align with the data format in t5 pretraining + # if self.add_sentinel_to_input: + # answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) + answer_text_ids += answer_ids + + if self.add_eos_to_decoder_output: + answer_text_ids += [self.tokenizer.eos_id] + else: + answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.END.value) + + # Skip example if the final length doesn't fit length requirements even after 
truncation + if ( + self.min_seq_length + <= self._get_element_len(context_and_question_tokens) + self._get_element_len(virtual_tokens) + <= self.max_seq_length + and self.min_seq_length <= self._get_element_len(answer_text_ids) <= self.max_seq_length + ): + if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: + taskname_id = self.tokenizer.text_to_ids(taskname) + elif ( + self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT + ): # TODO (@adithyare) this class and GPTPromptLearningDataset should be merged. + taskname_id = -1 + else: + raise ValueError("Invalid virtual prompt source specified") + + dec_input = None + dec_labels = None + + if answer_field in doc.keys(): # training and validation + dec_input = answer_text_ids[:-1] + dec_labels = answer_text_ids[1:] + + dec_input, dec_input_len = self.list_to_tensor(dec_input, True) + dec_labels, dec_labels_len = self.list_to_tensor(dec_labels, True) + is_speech = True if doc["answer_type"] == "SPEECH" else False + if is_speech: + assert dec_input.dim() == 2 and dec_labels.dim() == 2 + if self.seq_pattern == "delay_parallel": + num_codebooks = dec_input.shape[0] + dec_input_padded = torch.cat( + [ + torch.zeros_like(dec_input[:, 0:num_codebooks]), + dec_input, + torch.zeros_like(dec_input[:, 0:num_codebooks]), + ], + dim=1, + ) + dec_labels_padded = torch.cat( + [ + torch.zeros_like(dec_labels[:, 0:num_codebooks]), + dec_labels, + torch.zeros_like(dec_labels[:, 0:num_codebooks]), + ], + dim=1, + ) + dec_input_new = [] + dec_labels_new = [] + for _c in range(self.num_speech_codebooks): + st = num_codebooks - _c + et_decoder_input = dec_input_padded.shape[1] - _c + et_decoder_labels = dec_labels_padded.shape[1] - _c + dec_input_new.append(dec_input_padded[_c, st:et_decoder_input]) + dec_labels_new.append(dec_labels_padded[_c, st:et_decoder_labels]) + dec_input = torch.stack(dec_input_new, dim=0) + dec_labels = torch.stack(dec_labels_new, dim=0) + dec_input_len = torch.tensor(dec_input.shape[1]).long() + dec_labels_len = torch.tensor(dec_labels.shape[1]).long() + + enc_len = context_tokens_len + question_tokens_len + virtual_tokens_len + # TODO: Remove hardcoding + num_question_offset = 4 # For "Text to Speech this" + + cross_attention_prior = torch.zeros(dec_labels_len, enc_len) + self.cross_attention_epsilon + if self.use_attention_prior: + cross_attention_question_prior = torch.from_numpy( + beta_binomial_prior_distribution( + question_tokens_len.item() - num_question_offset, + dec_labels_len.item(), + scaling_factor=self.attention_prior_scaling_factor, + ) + ) + cross_attention_prior[ + :, virtual_tokens_len + context_tokens_len + num_question_offset : + ] = cross_attention_question_prior + + return ( + taskname_id, + virtual_tokens, + virtual_tokens_len, + context_and_question_tokens, + context_tokens_len + question_tokens_len, + dec_input, + dec_input_len, + dec_labels, + dec_labels_len, + is_speech, + cross_attention_prior, + ) + + def _truncate_input_speech(self, context_tokens, question_tokens, virtual_tokens): + total_len = self._get_len(context_tokens, question_tokens, virtual_tokens) + context_len = self._get_element_len(context_tokens) + truncation_length = total_len - self.max_seq_length + 1 + context_tokens[0] = context_tokens[0][:, min(truncation_length, context_len) :] + return context_tokens, question_tokens, virtual_tokens + + def list_to_tensor(self, element, fill=False): + """ + Convert list to tensor. The list might contain integers, 2D-tensors (speech tokens) and combination of two. 
+ If all of them are ints, simply convert to tensor + If combination of 2D-tensor and ints. Convert int to the dimension of the tensor. + example: [2, 4, 5] -> torch.tensor([2, 4, 5]) + example: [2, torch.tensor([[4, 5, 6], [6, 7, 8]])] -> torch.tensor( [[-1, 4, 5, 6], [2, 6, 7, 8]] ) + """ + ret, ln = None, None + if element is None: + return ret, ln + + max_len = max([1 if isinstance(item, int) else len(item) for item in element]) + if max_len == 1: + ret = torch.as_tensor(element).long() + ln = torch.tensor(ret.size()[0]).long() + else: + ret = [] + for e in element: + if isinstance(e, int): + tmp = torch.full((8, 1), e if fill else -1) + tmp[7] = e + else: + tmp = e + ret.append(tmp) + ret = torch.cat(ret, dim=1) + ln = torch.tensor(ret.size()[1]).long() + return ret, ln + + def _get_text_tokens(self, text): + input_ids = self.tokenizer.text_to_ids(text) + return input_ids + + def _get_phoneme_tokens(self, text): + input_ids = phoneme_tokenizer.encode(text) + input_ids_adjusted = [_id + self.lm_vocab_size for _id in input_ids] + return input_ids_adjusted + + def _pad_wav_to_multiple(self, wav): + if self.pad_multiple > 1: + if wav.shape[0] % self.pad_multiple != 0: + wav = torch.cat( + [wav, torch.zeros(self.pad_multiple - wav.shape[0] % self.pad_multiple, dtype=torch.float)] + ) + return wav + + def _get_element_len(self, element): + length = 0 + if isinstance(element, list): + for e in element: + if isinstance(e, int): + length += 1 + else: + if e.dim() > 1: + length += e.size()[1] + else: + length += e.size()[0] + else: + if element.dim() > 1: + length += element.size()[1] + else: + length += element.size()[0] + return length + + def _get_len(self, context_tokens, question_tokens, virtual_tokens): + length = 0 + length += self._get_element_len(context_tokens) + length += self._get_element_len(question_tokens) + length += self._get_element_len(virtual_tokens) + return length + + def _get_speech_tokens(self, field): + + # Convert to codes + codec_codes, codec_codes_length = None, None # Codes + + if self.train_task == 'tts': + if field == 'context': + self.ref_encodec = torch.load(io.BytesIO(self.ref_encodec), map_location="cpu").long() + codec_codes = self.ref_encodec + elif field == 'answer': + self.encodec = torch.load(io.BytesIO(self.encodec), map_location="cpu").long() + codec_codes = self.encodec + elif self.train_task == 'asr': + if field == 'context': + self.ref_encodec = torch.load(io.BytesIO(self.ref_encodec), map_location="cpu").long() + codec_codes = self.ref_encodec + + # codec_codes_length = torch.tensor(codec_codes.shape[1]).long() + + # Convert codes to codes corresponding to megatron embedding layer + codec_codes[0] = (codec_codes[0] + self.speech_offset).long() + + return codec_codes + + def _get_tokens(self, doc, field, field_data): + if f"{field}_type" not in doc.keys(): + field_tokens = self._get_text_tokens(field_data.strip(" ")) # list of ids + elif doc[f"{field}_type"] == 'TEXT': + _text = field_data.strip(" ") + if self.use_phoneme_tokenizer: + instruction_tokens = self._get_text_tokens("Phoneme TTS") + field_tokens = self._get_phoneme_tokens(_text.replace("Text to speech this ", "")) + field_tokens = instruction_tokens + field_tokens + else: + field_tokens = self._get_text_tokens(_text) # list of ids + elif doc[f"{field}_type"] == 'SPEECH': + dur = -1 + if f"{field}_duration" in doc: + dur = doc[f"{field}_duration"] + field_tokens = self._get_speech_tokens(field) # list of ids + if not isinstance(field_tokens, list): + field_tokens = [field_tokens] + elif 
doc[f"{field}_type"] == 'TOKENS': + # Do nothing; already tokenized + field_tokens = field_data + else: + raise Exception(f"{field}_type not recognized") + return field_tokens + + def _insert_data_in_template(self, input_example, prompt_template_fields, doc, answer_field): + """ Format the input example according to the template """ + out_dict = {} + for field in prompt_template_fields: + # discard the last one, {label} / {answer} + # Or if some fields from the template aren't present, e.g. {answer} during inference + # just remove that field from the template, leaving the space blank + if field == answer_field or field not in doc.keys(): + continue + # out_dict[field] = "" + + elif field in doc.keys(): + field_data = doc[field] + if f"{field}_type" not in doc.keys(): + doc[f"{field}_type"] = "TEXT" + raise Exception(f"{field}_type does not exist in doc") + else: + out_dict[field] = self._get_tokens(doc, field, field_data) + return out_dict + + def get_position_ids(self, virtual_token, context_and_qquestion): + enc_input = [] + enc_input.append(virtual_token) + if context_and_qquestion.dim() > 2: + enc_input.append(context_and_qquestion[:, 0, :]) + else: + enc_input.append(context_and_qquestion) + + enc_input = torch.cat(enc_input, dim=1) + + enc_input_p = enc_input[:, 0, :] if enc_input.dim() == 3 else enc_input + return build_position_ids(enc_input_p).contiguous() + + def collate_fn(self, batch): + """ Prepares enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids for global batch """ + + data_dict = self.pad_batch_and_build_loss_mask(batch) + + position_ids = self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens']) + + return ( + data_dict['virtual_tokens'], + data_dict['context_and_question_tokens'], + data_dict['enc_mask'], + data_dict['dec_input'], + data_dict['dec_input_mask'], + data_dict['dec_labels'], + data_dict['dec_labels_mask'], + position_ids, + data_dict['taskname_id'], + data_dict['speech_mask'], + data_dict['context_and_question_tokens_lens'], + data_dict['cross_attention_prior'], + ) + + def pad_batch_and_build_loss_mask(self, batch): + """ Pad enc_input, dec_input, labels in batch to max batch length while building loss_mask, enc_mask, and dec_mask """ + ( + taskname_ids, + _, + virtual_tokens_len, + _, + context_and_question_tokens_len, + _, + dec_input_len, + _, + dec_labels_len, + _, + _, + ) = zip(*batch) + + taskname_ids = self.pad_taskname_ids(taskname_ids) + + max_virtual_tokens_len = max(virtual_tokens_len).item() if virtual_tokens_len is not None else 0 + if isinstance(virtual_tokens_len, tuple): + virtual_tokens_len = torch.stack(virtual_tokens_len) + virtual_mask = get_mask_from_lengths(virtual_tokens_len) + + max_context_and_question_tokens_len = ( + max(context_and_question_tokens_len).item() if context_and_question_tokens_len is not None else 0 + ) + if isinstance(context_and_question_tokens_len, tuple): + context_and_question_tokens_len = torch.stack(context_and_question_tokens_len) + context_and_question_mask = get_mask_from_lengths(context_and_question_tokens_len) + + max_dec_input_len = max(dec_input_len).item() if dec_input_len is not None else 0 + max_dec_labels_len = max(dec_labels_len).item() if dec_labels_len is not None else 0 + enc_mask = torch.cat([virtual_mask, context_and_question_mask], dim=1) + + ( + virtual_tokens_list, + context_question_tokens_list, + dec_input_list, + dec_input_mask_list, + dec_labels_list, + dec_labels_mask_list, + speech_mask_list, + 
cross_attention_prior_list, + ) = ( + [], + [], + [], + [], + [], + [], + [], + [], + ) + + for i, sample_tuple in enumerate(batch): + ( + _, + virtual_token, + virtual_token_len, + context_and_question_token, + context_and_question_token_len, + dec_input, + dec_input_len, + dec_label, + dec_label_len, + is_speech, + cross_attention_prior, + ) = sample_tuple + + virtual_tokens_list.append( + general_padding( + virtual_token, virtual_token_len.item(), max_virtual_tokens_len, pad_value=self.tokenizer.pad_id + ) + ) + + context_tokens_padded = general_padding( + context_and_question_token, + context_and_question_token_len.item(), + max_context_and_question_tokens_len, + pad_value=self.tokenizer.pad_id, + ) + if len(context_tokens_padded.shape) < 2: + context_tokens_padded = pad_text_to_speech_dims(context_tokens_padded, self.tokenizer.pad_id) + context_question_tokens_list.append(context_tokens_padded) + + if max_dec_input_len > 0: + dec_input_padded = general_padding( + dec_input, dec_input_len.item(), max_dec_input_len, pad_value=self.tokenizer.pad_id + ) + if len(dec_input_padded.shape) < 2: + dec_input_padded = pad_text_to_speech_dims(dec_input_padded, self.tokenizer.pad_id) + dec_input_list.append(dec_input_padded) + dec_mask = ( + torch.as_tensor(([1] * dec_input_len) + ([0] * (max_dec_input_len - dec_input_len))) + .long() + .contiguous() + ) + dec_input_mask_list.append(dec_mask) + speech_mask = dec_mask if is_speech else torch.zeros(dec_mask.shape) + speech_mask_list.append(speech_mask) + + if max_dec_labels_len > 0: + loss_mask = ( + torch.as_tensor(([1] * dec_label_len) + ([0] * (max_dec_labels_len - dec_label_len))) + .long() + .contiguous() + ) + dec_label_padded = general_padding( + dec_label, dec_label_len.item(), max_dec_labels_len, pad_value=self.tokenizer.pad_id + ) + if len(dec_label_padded.shape) < 2: + dec_label_padded = pad_text_to_speech_dims(dec_label_padded, self.tokenizer.pad_id) + dec_labels_list.append(dec_label_padded) + dec_labels_mask_list.append(loss_mask) + + _p0 = max_dec_labels_len - dec_label_len + _p1 = ( + max_virtual_tokens_len + + max_context_and_question_tokens_len + - context_and_question_token_len + - virtual_token_len + ) + + cross_attention_prior_padded = torch.nn.functional.pad( + cross_attention_prior, pad=(0, _p1, 0, _p0), mode="constant", value=1, + ) + cross_attention_prior_list.append(cross_attention_prior_padded) + + data_dict = { + "taskname_id": taskname_ids, + "virtual_tokens": torch.stack(virtual_tokens_list), + "context_and_question_tokens": torch.stack(context_question_tokens_list), + "enc_mask": enc_mask, + "dec_input": torch.stack(dec_input_list) if len(dec_input_list) > 0 else None, + "dec_input_mask": torch.stack(dec_input_mask_list) if len(dec_input_mask_list) > 0 else None, + "dec_labels": torch.stack(dec_labels_list) if len(dec_labels_list) > 0 else None, + "dec_labels_mask": torch.stack(dec_labels_mask_list) if len(dec_labels_mask_list) > 0 else None, + "speech_mask": torch.stack(speech_mask_list) if len(speech_mask_list) > 0 else None, + "context_and_question_tokens_lens": context_and_question_tokens_len, + "cross_attention_prior": torch.stack(cross_attention_prior_list) + if len(cross_attention_prior_list) > 0 + else None, + } + + return data_dict + + +class GPTSpeechLMTarredDataset(T5SpeechLMTarredDataset): + """ No support for cross attention here yet""" + + def _build_sample(self, tup): + audio_filename, self.encodec, self.ref_encodec, offset_id = tup + + file_id, _ = os.path.splitext(os.path.basename(audio_filename)) + 
manifest_idx = self.manifest_processor.collection.mapping[file_id][offset_id] + manifest_entry = self.manifest_processor.collection[manifest_idx] + doc = {} + doc['context'] = manifest_entry.context + doc['context_type'] = manifest_entry.context_type + doc['context_duration'] = manifest_entry.context_duration + doc['answer'] = manifest_entry.answer + doc['answer_type'] = manifest_entry.answer_type + doc['answer_duration'] = manifest_entry.answer_duration + doc['question'] = manifest_entry.question + doc['question_type'] = manifest_entry.question_type + + taskname = "squad" + prompt_template = self.task_templates[taskname]["prompt_template"] + prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"] + virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"] + answer_field = self.task_templates[taskname]["answer_field"] + + input_example = prompt_template + + # Format the input example according to the template + # Get context, question and answer codes in a dict. + input_dict = self._insert_data_in_template(input_example, prompt_template_fields, doc, answer_field) + context_tokens = input_dict['context'] + question_tokens = input_dict['question'] + + # Logic to prune context + # In case of TTS task, the entire reference speech is not required, so we randomly select a portion + # of the reference audio. + # In case of Next token prediction, We want context[:T] to go in the encoder and context[T+1:] to be + # predicted by the decoder. + start_token_index = 0 + end_token_index = -1 + + total_context_len = context_tokens[0].size()[1] + context_3s = 3 * 75 + if total_context_len > context_3s: + start_token_index = random.randint(0, total_context_len - context_3s) + # logging.debug(f"start_token_index: {start_token_index}") + end_token_index = start_token_index + min(context_3s, total_context_len) + # logging.debug(f"end_token_index: {end_token_index}") + context_tokens[0] = context_tokens[0][:, start_token_index:end_token_index] + + # Get virtual tokens + virtual_tokens = self._insert_virtual_token_placeholders(input_example.split(' ')[0], virtual_token_splits) + + # a trick to align with the data format in t5 pretraining + # new + virtual_tokens = self.tokenizer.text_to_ids(virtual_tokens) + if self.add_sentinel_to_input: + question_tokens = question_tokens + self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) + + # Add BOS/EOS to the input of encoder if desired, adds EOS by default + if self.ul2_prompt_token is not None: + ul2_prompt_token_id = self.tokenizer.text_to_ids(self.ul2_prompt_token) + assert len(ul2_prompt_token_id) == 1 + context_tokens = ul2_prompt_token_id + context_tokens + if self.add_bos: + context_tokens = [self.tokenizer.bos_id] + context_tokens + if self.add_eos: + question_tokens = [self.tokenizer.pad_id] + question_tokens + [self.tokenizer.pad_id] + + virtual_tokens, virtual_tokens_len = self.list_to_tensor(virtual_tokens) + context_tokens, context_tokens_len = self.list_to_tensor(context_tokens) + question_tokens, question_tokens_len = self.list_to_tensor(question_tokens) + + if doc["question_type"] != "SPEECH" and doc["context_type"] == "SPEECH": + question_tokens = pad_text_to_speech_dims(question_tokens, self.tokenizer.pad_id) + if doc["context_type"] != "SPEECH" and doc["question_type"] == "SPEECH": + context_tokens = pad_text_to_speech_dims(context_tokens, self.tokenizer.pad_id) + context_and_question_tokens = torch.cat([context_tokens, question_tokens], dim=1) + + # get answer ids + if answer_field in doc.keys(): # 
training and validation + answer_ids = self._get_tokens(doc, answer_field, doc[answer_field]) + answer_text_ids = answer_ids + + if self.add_eos_to_decoder_output: + answer_text_ids += [self.tokenizer.eos_id] + else: + answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.END.value) + + # Skip example if the final length doesn't fit length requirements even after truncation + input_ids = answer_text_ids + input_ids, input_ids_len = self.list_to_tensor(input_ids, True) + input_len = self._get_element_len(context_and_question_tokens) + self._get_element_len(answer_text_ids) - 1 + if input_len > self.max_seq_length: + # logging.debug(f"Overflow. input_len:{input_len}. self.max_seq_length:{self.max_seq_length}. overflow_len:{self.max_seq_length - input_len}.") + overflow_len = self.max_seq_length - input_len + # truncate context if context after truncation is at least 1s + # else truncate answer as final option + if context_tokens_len - overflow_len > 75: + # logging.debug(f"Cutting context. context_tokens:{context_tokens.shape}. context_tokens_len:{context_tokens_len}.") + context_tokens = context_tokens[:, : context_tokens_len - overflow_len] + context_tokens_len = context_tokens_len - overflow_len + # logging.debug(f"Cut context. context_tokens:{context_tokens.shape}. context_tokens_len:{context_tokens_len}.") + else: + # logging.debug(f"Cutting answer. input_ids:{input_ids.shape}. input_ids_len:{input_ids_len}.") + input_ids = input_ids[:, : input_ids_len - overflow_len] + input_ids_len = input_ids_len - overflow_len + # logging.debug(f"Cut answer. input_ids:{input_ids.shape}. input_ids_len:{input_ids_len}.") + + is_speech = True if doc["answer_type"] == "SPEECH" else False + if is_speech: + assert input_ids.dim() == 2 + if self.seq_pattern == "delay_parallel": + num_codebooks = input_ids.shape[0] + dec_input_padded = torch.cat( + [ + torch.zeros_like(input_ids[:, 0:num_codebooks]), + input_ids, + torch.zeros_like(input_ids[:, 0:num_codebooks]), + ], + dim=1, + ) + dec_input_new = [] + for _c in range(self.num_speech_codebooks): + st = num_codebooks - _c + et_decoder_input = dec_input_padded.shape[1] - _c + dec_input_new.append(dec_input_padded[_c, st:et_decoder_input]) + input_ids = torch.stack(dec_input_new, dim=0) + input_ids_len = torch.tensor(input_ids.shape[1]).long() + + return ( + context_tokens, + context_tokens_len, + question_tokens, + question_tokens_len, + input_ids, + input_ids_len, + ) + + def collate_fn(self, batch): + (_, context_tokens_len, _, question_tokens_len, _, input_ids_len,) = zip(*batch) + + decoder_input_len = ( + torch.stack(context_tokens_len) + torch.stack(question_tokens_len) + torch.stack(input_ids_len) + ) + max_decoder_input_len = max(decoder_input_len).item() if decoder_input_len is not None else 0 + + decoder_mask = get_mask_from_lengths(decoder_input_len - 1) + speech_mask = get_mask_from_lengths(decoder_input_len - 1) + context_question_mask = torch.ones(speech_mask.shape) + (decoder_input_list, decoder_labels_list,) = ( + [], + [], + ) + for i, sample_tuple in enumerate(batch): + ( + context_tokens, + context_tokens_len, + question_tokens, + question_tokens_len, + input_ids, + input_ids_len, + ) = sample_tuple + + context_tokens_input = context_tokens.clone().contiguous().detach() + for l in range(1, context_tokens_input.shape[0]): + context_tokens_input[l] += self.speech_offset + 1024 * l # TODO: fix hardcode + input_ids_shifted = input_ids.clone().contiguous().detach() + for l in range(1, input_ids_shifted.shape[0]): + input_ids_shifted[l] += 
self.speech_offset + 1024 * l # TODO: fix hardcode + + complete_input = torch.cat([context_tokens_input, question_tokens, input_ids_shifted], dim=1) + complete_input_padded = general_padding( + complete_input, decoder_input_len[i].item(), max_decoder_input_len, pad_value=self.tokenizer.pad_id, + ) + complete_output = torch.cat([context_tokens, question_tokens, input_ids], dim=1) + complete_output_padded = general_padding( + complete_output, decoder_input_len[i].item(), max_decoder_input_len, pad_value=self.tokenizer.pad_id, + ) + decoder_labels = complete_output_padded[:, 1:].contiguous() + decoder_input = complete_input_padded[:, :-1].contiguous() + + decoder_input_list.append(decoder_input) + decoder_labels_list.append(decoder_labels) + + decoder_mask[i, : context_tokens_len + question_tokens_len - 1] = 0 # Mask out context and question + speech_mask[ + i, context_tokens_len : context_tokens_len + question_tokens_len + ] = 0 # Mask out context and question + context_question_mask[i, : context_tokens_len + question_tokens_len] = 0 + + # Using causal attention mask for whole input + batch_size = len(decoder_input_list) + attention_mask = torch.tril( + torch.ones((batch_size, max_decoder_input_len - 1, max_decoder_input_len - 1)) + ).view(batch_size, 1, max_decoder_input_len - 1, max_decoder_input_len - 1) + + # Convert attention mask from float to bool + attention_mask = attention_mask < 0.5 + + decoder_input = torch.stack(decoder_input_list) + decoder_input_p = decoder_input[:, 0, :] if decoder_input.dim() == 3 else decoder_input + position_ids = build_position_ids(decoder_input_p) + data_dict = { + "tokens": decoder_input, + "position_ids": position_ids, + "attention_mask": attention_mask, + "labels": torch.stack(decoder_labels_list), + "speech_mask": speech_mask, # For TTS, can just be loss_mask since answer will always be speech + "loss_mask": decoder_mask, # Mask out context and question and padding + "attention_prior": None, + "context_question_mask": context_question_mask, + } + + return data_dict diff --git a/nemo/collections/tts/g2p/models/zh_cn_pinyin.py b/nemo/collections/tts/g2p/models/zh_cn_pinyin.py index 985897d8df3f..82810f8736be 100644 --- a/nemo/collections/tts/g2p/models/zh_cn_pinyin.py +++ b/nemo/collections/tts/g2p/models/zh_cn_pinyin.py @@ -93,7 +93,7 @@ def __init__( self.ascii_letter_dict = { x: ascii_letter_prefix + x for x in get_grapheme_character_set(locale="en-US", case=ascii_letter_case) } - self.ascii_letter_list = sorted(self.ascii_letter_dict) + self.ascii_letter_list = sorted(self.ascii_letter_dict.values()) self.ascii_letter_case = ascii_letter_case if apply_to_oov_word is None: @@ -181,6 +181,7 @@ def __call__(self, text: str) -> List[str]: `['wo3', 'jin1', 'tian1', 'qu4', 'le5', 'A', 'p', 'p', 'l', 'e', ' ', 'S', 't', 'o', 'r', 'e', ',', ' ', 'mai3', 'le5', 'yi2', 'ge4', 'i', 'P', 'h', 'o', 'n', 'e', '。']` """ + err = False text = set_grapheme_case(text, case=self.ascii_letter_case) pinyin_seq = [] @@ -201,7 +202,10 @@ def __call__(self, text: str) -> List[str]: tone_hyp = pinyin[-1] if tone_hyp in self.tone_dict: syllable = pinyin[:-1] - assert syllable in self.phoneme_dict, f"Syllable <{syllable}> does not exist in the dictionary." + if syllable not in self.phoneme_dict: + err = True + logging.error(f"Syllable <{syllable}> does not exist in the dictionary.") + continue phoneme_seq += self.phoneme_dict[syllable] phoneme_seq.append(self.tone_dict[tone_hyp]) # All pinyin would end up with a number in 1-5, which represents tones of the pinyin. 
@@ -211,4 +215,6 @@ def __call__(self, text: str) -> List[str]:
                 phoneme_seq.append(self.ascii_letter_dict[tone_hyp])
             else:
                 phoneme_seq.append(pinyin)
+        if err:
+            logging.error(f"|{text}| contained unknown syllables")
         return phoneme_seq
diff --git a/nemo/collections/tts/models/speechllm/__init__.py b/nemo/collections/tts/models/speechllm/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py b/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py
new file mode 100644
index 000000000000..eb917f0d7af3
--- /dev/null
+++ b/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py
@@ -0,0 +1,445 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+import torch
+from omegaconf.dictconfig import DictConfig
+from pytorch_lightning.trainer.trainer import Trainer
+from torch import Tensor
+
+from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
+from nemo.collections.nlp.metrics.prompt_learning_metrics import AccuracyScore, BLEUScore, ROUGEScores
+from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel
+from nemo.collections.nlp.modules.common import (
+    PromptEncoder,
+    PromptEncoderType,
+    VirtualPromptPlaceholderToken,
+    VirtualPromptSource,
+    VirtualPromptStyle,
+)
+from nemo.collections.nlp.modules.common.transformer.text_generation import TextGeneration
+from nemo.collections.nlp.parts import utils_funcs
+from nemo.utils import AppState
+
+try:
+    from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator
+
+    HAVE_APEX = True
+
+except (ImportError, ModuleNotFoundError):
+    HAVE_APEX = False
+
+try:
+    from megatron.core import parallel_state
+
+    HAVE_MEGATRON_CORE = True
+
+except (ImportError, ModuleNotFoundError):
+
+    HAVE_MEGATRON_CORE = False
+
+
+__all__ = ['MegatronBaseSpeechLM']
+
+
+class MegatronBaseSpeechLM(MegatronBaseModel, TextGeneration):
+    """
+    Model class for prompt-tuning or p-tuning a pretrained Megatron model.
+
+    Prompt Tuning initializes virtual prompt embeddings directly from a copy of
+    certain token embeddings from the pretrained model's vocabulary
+    and directly tunes these embedding weights. The token embeddings used in
+    initialization are specified by the user in the config file. The model can
+    be prompt-tuned for multiple tasks at once. Virtual prompts are stored in a
+    prompt table and can be added or deleted without disrupting virtual prompts
+    for other tasks.
+
+    P-tuning initializes an LSTM encoder model that generates virtual prompt
+    embeddings for every task. Each task shares the same encoder. After p-tuning
+    is complete, the learned virtual prompts can be saved to the prompt table
+    using add_ptuned_prompts_to_prompt_table(). Thus, if a user wants to add a
+    new virtual prompt via p-tuning, they do not need to retrain on all previous
+    tasks. This gives p-tuning the same task flexibility as prompt-tuning.
+    """
+
+    def __init__(self, cfg: DictConfig, trainer: Trainer):
+        super().__init__(cfg, trainer)
+        self.init_model(cfg, trainer)
+        self.config = self.model_parallel_config
+
+    def init_model(self, cfg: DictConfig, trainer: Trainer):
+        self.cfg = cfg
+
+        self.load_frozen_model(cfg, trainer)
+        self.prompt_encoder = None
+        self.tokenizer = self.frozen_model.tokenizer
+
+        if hasattr(self.frozen_model.cfg, "encoder") and hasattr(self.frozen_model.cfg, "decoder"):
+            self.hidden_size = (
+                self.frozen_model.cfg.encoder.hidden_size
+            )  # Encoder and decoder need to have the same hidden size and we check for this in the frozen enc-dec model.
+        else:
+            self.hidden_size = self.frozen_model.cfg.hidden_size
+
+        self.existing_tasks = list(self.cfg.get('existing_tasks', []))
+        self.new_tasks = list(self.cfg.get('new_tasks', []))
+        self.virtual_prompt_style = VirtualPromptStyle(cfg.virtual_prompt_style)
+
+        # Load templates for assigning virtual prompt token positions
+        self.load_task_templates(self.cfg.task_templates)
+
+        if self.first_stage_of_pipeline() and self.virtual_prompt_style in [
+            VirtualPromptStyle.P_TUNING,
+        ]:
+            # TODO: Handle this when moving GPT prompt learning to the base class.
+            self.word_embeddings = self.frozen_model.enc_dec_model.encoder_embedding.word_embeddings
+
+        # P-Tuning uses an LSTM Encoder to produce virtual token embeddings
+        if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING:
+            self.virtual_prompt_source = VirtualPromptSource.PROMPT_ENCODER
+        elif self.virtual_prompt_style == VirtualPromptStyle.NO_PROMPT:
+            self.virtual_prompt_source = VirtualPromptSource.NO_PROMPT
+        else:
+            raise ValueError(f"Unsupported virtual prompt style '{cfg.virtual_prompt_style}'.")
+
+        self._reduced_loss_buffer = []
+        self._inference_config = None
+
+        # Prepare pseudo token ids for virtual prompt tokens
+        self.pseudo_tokens = get_pseudo_tokens(self.max_virtual_tokens)
+        if isinstance(self.tokenizer, SentencePieceTokenizer):
+            self.tokenizer.add_special_tokens(self.pseudo_tokens)
+        else:
+            self.tokenizer.add_special_tokens({'additional_special_tokens': self.pseudo_tokens})
+        self.pseudo_token_ids = self.tokenizer.tokens_to_ids(self.pseudo_tokens)
+        self.pseudo_token_ids_start = self.pseudo_token_ids[0] if self.pseudo_token_ids else None
+        self.pad_token_id = self.tokenizer.pad_id if self.tokenizer.pad_id is not None else self.tokenizer.unk_id
+        self.decoder_seq_length = cfg.get('decoder_seq_length', 40)
+
+        self.autocast_dtype = utils_funcs.torch_dtype_from_precision(self.cfg.precision)  # Mixed precision datatype
+        # make sure the default PyTorch Lightning gradient clipping is used in the base model
+        self.grad_clip_pl_default = True
+        self.lowest_val_loss = None
+        self.prompt_encoder = None
+
+        self.enable_autocast = not self.megatron_amp_O2 and self.autocast_dtype in [torch.float16, torch.bfloat16]
+
+        # define validation metric
+        if self.cfg.get('report_validation_metric', False):
+            validation_metric = self.cfg.get('validation_metric', 'accuracy')
+            if validation_metric == 'accuracy':
+                self.validation_metric = AccuracyScore()
+            elif validation_metric == 'bleu':
+                self.validation_metric = BLEUScore()
+            elif validation_metric == 'rouge':
+                self.validation_metric = ROUGEScores()
+
+    def load_task_templates(self, task_templates):
+        """
+        Takes in the task template portion of the config and turns
+        it into a table where
each task's prompt template and + the number of virtual tokens to insert in a given part of + the prompt template are specified. + """ + self.task_templates = {} + self.task_id_num_to_name = {} + self.max_virtual_tokens = 0 + + task_id_num = 0 + for task in task_templates: + self.task_templates[task.taskname] = { + "prompt_template": task.prompt_template, + "prompt_template_fields": re.findall("\{(.*?)\}", task.prompt_template), + "answer_only_loss": task.get("answer_only_loss", False), + "answer_field": task.get("answer_field", None), + "truncate_field": task.truncate_field, + "total_virtual_tokens": task.total_virtual_tokens, + "virtual_token_splits": task.virtual_token_splits, + "task_id_num": task_id_num, + } + + self.max_virtual_tokens = max(self.max_virtual_tokens, task.total_virtual_tokens) + self.task_id_num_to_name[task_id_num] = task.taskname + task_id_num += 1 + + # Check that all new tasks have the same total num virtual tokens + # Num virtual tokens for new tasks don't need to match num used for previously tuned tasks + if self.new_tasks: + new_task_name = self.new_tasks[0] + self.total_new_task_virtual_tokens = self.task_templates[new_task_name]["total_virtual_tokens"] + + assert all( + self.task_templates[taskname]["total_virtual_tokens"] == self.total_new_task_virtual_tokens + for taskname in self.new_tasks + ), "Total virtual tokens for each task tuned simultaneously must match. If you want to use a different number of virtual tokens for different tasks, tune them separately." + + def init_prompt_encoder(self): + """ + Init the prompt encoder needed for p-tuning on a new task + """ + # Total virtual tokens should be the same across all new tasks, so just need one + new_task = self.new_tasks[0] + total_virtual_tokens = self.task_templates[new_task]["total_virtual_tokens"] + + encoder_type = PromptEncoderType(self.cfg.p_tuning.get("encoder_type", "tpmlp").lower()) + self.prompt_encoder = PromptEncoder( + config=self.model_parallel_config, + encoder_type=encoder_type, + total_virtual_tokens=total_virtual_tokens, + token_dim=self.hidden_size, + hidden_size=self.cfg.p_tuning.get("encoder_hidden", self.hidden_size // 2), + lstm_dropout=self.cfg.p_tuning.get("dropout", 0.0), + num_layers=self.cfg.p_tuning.get("num_layers", 2), + init_std=self.cfg.p_tuning.get("init_std", 0.023), + taskname=new_task, + ) + + def freeze_existing_word_embeddings(self): + """Freeze params of existing virtual prompts that should not be tuned further + """ + # Make sure word embeddings are frozen + for params in self.word_embeddings.parameters(): + params.requires_grad = False + + def state_dict(self): + """ + Custom state dict that only contains prompt table and prompt encoder parameters. + No frozen model parameters are stored in the state dict. Prompt encoder parameters + are only in state dict for intermediate checkpoints saved during training. Final + nemo checkpoints at the end of training will contain prompt table parameters only. + """ + state_dict_ = {} + state_dict_["frozen_model_enc_dec_model"] = self.frozen_model.enc_dec_model.state_dict() + state_dict_["word_embeddings"] = self.word_embeddings.state_dict() + if self.prompt_encoder is not None: + state_dict_["prompt_encoder"] = self.prompt_encoder.state_dict() + + return state_dict_ + + def load_state_dict(self, state_dict, strict: bool = True): + """ + Custom load state dict method that only loads prompt table and prompt encoder + parameters. Matching load method for this class' custom state dict method. 
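+        Expected keys are 'frozen_model_enc_dec_model', 'word_embeddings' and, for intermediate
+        checkpoints, 'prompt_encoder', matching what state_dict() above produces.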
+ """ + self.init_prompt_encoder() + self.frozen_model.enc_dec_model.load_state_dict(state_dict["frozen_model_enc_dec_model"], strict) + self.word_embeddings.load_state_dict(state_dict["word_embeddings"], strict) + if 'prompt_encoder' in state_dict: + self.prompt_encoder.load_state_dict(state_dict["prompt_encoder"], strict) + + # Not sure why when we resume training the prompt encoder is on cpu + # Because it's not created on init - Should really be moved to init + self.prompt_encoder.to("cuda") + + def embed_input(self, input_ids: Tensor, taskname_ids: Tensor, use_cached_reps: bool): + """ + Replaces the virtual tokens in the input_ids with embeddings + calculated from either the 'prompt_table' or 'prompt_encoder'. + The virtual token placeholders have token_ids listed in + `self.pseudo_token_ids`. + + params: + input_ids: the input token ids + taskname_ids: the NLP task tag token ids + returns: + the token embedding for the LM model. + """ + # Replace virtual token ids with padding for forward pass through vocab embeddings + discrete_token_ids = input_ids.clone() + discrete_token_ids[(input_ids >= self.pseudo_token_ids_start)] = self.pad_token_id + discrete_token_embeds = self.word_embeddings(discrete_token_ids).clone() + + # Find the indicies where virtual tokens should be inserted + virtual_token_locations = input_ids >= self.pseudo_token_ids_start + + # If there are no virtual tokens, just return discrete token embeds + if not virtual_token_locations.any(): + return discrete_token_embeds + + if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: + # taskname_embeddings = self.word_embeddings(taskname_ids) + batch_size, _ = taskname_ids.size() + virtual_token_embeds = self.prompt_encoder(batch_size=batch_size, use_cached_reps=use_cached_reps) + else: + raise ValueError("invalid VirtualPromptSource.") + + # Create index template specifying where virtual token embeddings should be placed + batch_size, _, embedding_size = discrete_token_embeds.shape + virtual_token_index = virtual_token_locations.nonzero().reshape((batch_size, -1, 2))[:, :, 1][:, :, None] + virtual_token_index = virtual_token_index.expand( + batch_size, self.total_new_task_virtual_tokens, embedding_size + ) + + # Make sure discrete_token_embeds and virtual_token_embeds share the same dtype + discrete_token_embeds = discrete_token_embeds.type(virtual_token_embeds.dtype) + + # Insert virtual token embeddings where they belong amoung the discrete token embeddings + discrete_token_embeds.scatter_(1, virtual_token_index, virtual_token_embeds) + input_embeds = discrete_token_embeds + + return input_embeds + + def on_train_end(self): + # Save p-tuned prompts to prompt table for inference or future task training + self.save_to(save_path=self.cfg.nemo_path) + + def setup(self, stage=None): + if stage == 'predict' and self.first_stage_of_pipeline(): + return + + self.setup_test_data() + if stage == 'test': + return + + if self.first_stage_of_pipeline(): + if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING: + if self.prompt_encoder is None: + self.init_prompt_encoder() + + self.setup_training_data() + self.setup_validation_data() + + def setup_training_data(self, training_data_config=None): + if self.cfg.data.get('train_ds', None): + self._train_ds, self._train_dl = self.build_virtual_prompt_dataset( + dataset_paths=self.cfg.data.train_ds, + batch_size=self.cfg.global_batch_size, + for_train=True, + drop_last=True, + shuffle=True, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + elif 
self.cfg.data.get('train_manifest', None): + self._train_ds, self._train_dl = self.build_virtual_prompt_tarred_dataset( + dataset_paths=self.cfg.data.train_manifest, + audio_path=self.cfg.data.train_audio_path, + batch_size=self.cfg.global_batch_size, + for_train=True, + drop_last=True, + shuffle=self.cfg.data.shuffle, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + + def setup_validation_data(self, validation_data_config=None): + if self.cfg.data.get('validation_ds', None): + self._validation_ds, self._validation_dl = self.build_virtual_prompt_dataset( + dataset_paths=self.cfg.data.validation_ds, + batch_size=self.cfg.get("validation_global_batch_size", self.cfg.global_batch_size), + for_train=True, + drop_last=self.cfg.get("validation_drop_last", True), + shuffle=False, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + elif self.cfg.data.get('validation_manifest', None): + self._validation_ds, self._validation_dl = self.build_virtual_prompt_tarred_dataset( + dataset_paths=self.cfg.data.validation_manifest, + audio_path=self.cfg.data.validation_audio_path, + batch_size=self.cfg.get("validation_global_batch_size", self.cfg.global_batch_size), + for_train=True, + drop_last=self.cfg.get("validation_drop_last", True), + shuffle=0, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + + def setup_test_data(self, test_data_config=None): + if self.cfg.data.get('test_ds', None): + self._test_ds, self._test_dl = self.build_virtual_prompt_dataset( + dataset_paths=self.cfg.data.test_ds, + batch_size=self.cfg.get("validation_global_batch_size", self.cfg.global_batch_size), + for_train=False, + drop_last=False, + shuffle=False, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + elif self.cfg.data.get('test_manifest', None): + self._test_ds, self._test_dl = self.build_virtual_prompt_tarred_dataset( + dataset_paths=self.cfg.data.test_manifest, + audio_path=self.cfg.data.test_audio_path, + batch_size=self.cfg.global_batch_size, + for_train=False, + drop_last=False, + shuffle=0, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + + def _reconfigure_and_process_inference_batch(self, global_batch_size_per_gpu, gbs): + # This should happen only on the last batch of the dataset. + if global_batch_size_per_gpu != gbs // parallel_state.get_data_parallel_world_size(): + # NOTE: This is reconfiguring to make sure there is no grad-acc for validation batches. 
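+            # The last batch of an eval epoch can be smaller than the configured global batch size, so the
+            # microbatch calculator is rebuilt with micro_batch_size equal to the per-GPU batch actually
+            # received, i.e. one microbatch per step and no implicit gradient accumulation.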
+ app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), + micro_batch_size=global_batch_size_per_gpu, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + + def _reconfigure_batch_sizes(self, gbs: int, mbs: int): + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=gbs, + micro_batch_size=mbs, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + + def set_inference_config(self, inference_config): + self._inference_config = inference_config + + def get_inference_config(self): + return self._inference_config + + def set_input_tensor(self, input_tensor): + pass + + def first_stage_of_pipeline(self): + pass + + @classmethod + def list_available_models(cls): + pass + + def load_frozen_model(self, cfg, trainer): + pass + + +def get_pseudo_tokens(num_virtual_tokens): + """ + Takes in an integer and returns a list of strings where each string + is a numbered virtual token placeholder. If + num_virtual_tokens = 3, then this function returns: + + ["", "", ""] + + Args: + num_virtual_tokens: (int) Number of virtual token strings you want to make + + returns a list of string. + + """ + pseudo_tokens = [ + VirtualPromptPlaceholderToken.BASE.value + str(i) + VirtualPromptPlaceholderToken.END.value + for i in range(num_virtual_tokens) + ] + + return pseudo_tokens diff --git a/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py new file mode 100644 index 000000000000..5e47db21795f --- /dev/null +++ b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py @@ -0,0 +1,2509 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
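+
+# Megatron T5 SpeechLLM model: prompt-learning wrapper around a frozen T5 encoder-decoder that adds
+# per-codebook speech token embeddings and output heads so the model can consume and generate audio
+# codec tokens in addition to text.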
+import itertools +import random +import json +import os +import string +from typing import Any, List +from functools import partial + +import editdistance +import numpy as np +import soundfile as sf +import torch +from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig +from omegaconf.omegaconf import open_dict +from pytorch_lightning.trainer.trainer import Trainer + +import nemo.collections.asr as nemo_asr +from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceSpeechLLMTTSTokenizer +from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model +from nemo.collections.nlp.models.language_modeling.megatron_t5_sft_model import MegatronT5SFTModel +from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import ( + MegatronTokenLevelEncoderDecoderSpeechLLMModule, +) +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_iterator_k_split, + init_method_normal, +) +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo.collections.tts.data.speechllm.t5_speechllm_dataset import Lang, T5SpeechLMDataset +from nemo.collections.tts.data.speechllm.t5_speechllm_tarred_dataset import T5SpeechLMTarredDataset +from nemo.collections.tts.losses.aligner_loss import ForwardSumLoss +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.models.speechllm.megatron_base_speechllm_prompt_model import MegatronBaseSpeechLM +from nemo.collections.tts.parts.utils.helpers import plot_alignment_to_numpy_for_speechllm, plot_codec_to_numpy +from nemo.utils import AppState, logging +import imageio + +try: + from apex.transformer.pipeline_parallel.utils import get_micro_batch_size, get_num_microbatches + + HAVE_APEX = True + +except (ImportError, ModuleNotFoundError): + + HAVE_APEX = False + +try: + from megatron.core import parallel_state, tensor_parallel + from megatron.core.enums import ModelType + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +import time +import torchaudio +from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE +from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector +import librosa + +__all__ = ['MegatronT5SpeechLMModel'] + + +class MegatronT5OverrideModel(MegatronT5Model): + def _build_tokenizer(self): + if self._cfg.tokenizer.library == "sentencepiece": + if hasattr(self._cfg.tokenizer, "sentencepiece_legacy"): + legacy = self._cfg.tokenizer.sentencepiece_legacy + else: + legacy = True if self._cfg.tokenizer.library == 'sentencepiece' else False + self.tokenizer = SentencePieceSpeechLLMTTSTokenizer( + model_path=self.register_artifact("tokenizer.model", self._cfg.tokenizer.get('model', None)), legacy=legacy + ) + + if self._cfg.tokenizer.get('additional_special_tokens', None) is not None: + tokens_list = OmegaConf.to_object(self._cfg.tokenizer.additional_special_tokens) + self.tokenizer.add_special_tokens(tokens_list) + else: + super()._build_tokenizer() + + def model_provider_func(self, pre_process, post_process, add_encoder, add_decoder): + if not hasattr(self.cfg, 'encoder') or not hasattr(self.cfg, 'decoder'): + logging.warning( + 'Could not find encoder or decoder in config. 
This is probably because of restoring an old checkpoint. Copying shared model configs to encoder and decoder configs.' + ) + # After the call below, self.cfg.encoder and self.cfg.decoder will be populated with the cfg.model configs from old checkpoints. + self._populate_encoder_decoder_configs_for_backward_compatibility(self.cfg) + + if parallel_state.get_pipeline_model_parallel_world_size() > 1 and self.cfg.encoder.arch == 'perceiver': + raise ValueError(f"Perceivers with pipeline parallel > 1 is not supported yet.") + + if not hasattr(self.cfg, 'embedding_init_method_std'): + embedding_init_method_std = self.cfg.encoder.init_method_std + else: + embedding_init_method_std = self.cfg.embedding_init_method_std + + if not hasattr(self.cfg, 'embedding_dropout'): + embedding_dropout = self.cfg.encoder.hidden_dropout + else: + embedding_dropout = self.cfg.embedding_dropout + + model = MegatronTokenLevelEncoderDecoderSpeechLLMModule( + config=self.model_parallel_config, + encoder_cfg=self.cfg.encoder, + decoder_cfg=self.cfg.decoder, + vocab_size=self.padded_vocab_size, + max_position_embeddings=self.cfg.max_position_embeddings, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + fp16_cross_entropy=self.cfg.get('fp16_lm_cross_entropy', False), + precision=self.cfg.get('precision', 16), + embedding_init_method_std=embedding_init_method_std, + embedding_dropout=embedding_dropout, + label_smoothing=self.cfg.get('label_smoothing', 0.0), + add_encoder=add_encoder, + add_decoder=add_decoder, + share_token_embeddings=self.cfg.get('share_token_embeddings', True), + share_decoder_tokens_head_embeddings=self.cfg.get('share_decoder_tokens_head_embeddings', True), + tokens_head_bias=self.cfg.get('tokens_head_bias', True), + hiddens_cfg=self.cfg.get('hiddens', None), + ) + return model + + +class MegatronT5SpeechLMModel(MegatronBaseSpeechLM): + """ + Model class for prompt-tuning or p-tuning a pretrained Megatron T5 model. + + Prompt Tuning initializes virtual prompt embeddings directly from a copy of + certain token embeddings from the pretrained T5 model's vocabulary + and directly tunes these embedding weights. The token embeddings used in + initialization are specified by the user in the config file. The model can + be prompt-tuned for multiple tasks at once. Virtual prompts are stored in a + prompt table and can be added or deleted without disrupting virtual prompts + for other tasks. + + P-tuning initializes an LSTM encoder model that generates virtual prompt + embeddings for every task. Each task shares the same encoder. After p-tuning + is complete, the learned virtual prompts can be saved to the prompt table + using add_ptuned_prompts_to_prompt_table(). Thus, if a user wants to add a + new virtual prompt via p-tuning, they do not need to retrain on all previous + tasks. This gives p-tuning the same task flexibility as prompt-tuning. 
+ """ + + def __init__(self, cfg: DictConfig, trainer: Trainer): + # torch.autograd.set_detect_anomaly(True) + super().__init__(cfg, trainer) + self.model_type = ModelType.encoder_and_decoder + speech_codebook_size = cfg.data.get('speech_codebook_size', 1024) + num_speech_codebooks = cfg.data.get('num_speech_codebooks', 8) + speech_offset = cfg.data.get('speech_offset', 30000) + codecmodel_type = cfg.get('codecmodel_type', 'nemo_codec') + attn_prior_scaledown_start_step = cfg.get('attn_prior_scaledown_start_step', 10000) + attn_prior_end_step = cfg.get('attn_prior_end_step', 11000) + num_cross_attention_heads = cfg.get('num_cross_attention_heads', 12) + self.lm_vocab_size = cfg.get('lm_vocab_size', 30000) + self.context_pattern = cfg.data.get('context_pattern', 'parallel') + self.context_conditioning = cfg.get('context_conditioning', "decoder") + self.context_duration_min = cfg.data.get('context_duration_min', 2.9) + self.context_duration_max = cfg.data.get('context_duration_max', 2.9) + self.codebook_fps = cfg.data.get('codebook_fps', 86) + self.decoder_context_len = 0 + if self.context_conditioning == "decoder": + assert self.context_duration_min == self.context_duration_max, "Decoder context duration must be fixed" + self.decoder_context_len = int(self.codebook_fps * self.context_duration_min) + + self.speech_offset = speech_offset + self.speech_codebook_size = speech_codebook_size + self.num_speech_codebooks = num_speech_codebooks + self.codecmodel_type = codecmodel_type + self.enc_output_to_layers = cfg.get('enc_output_to_layers', None) + if self.enc_output_to_layers is not None: + # Convert from listconfig to list + self.enc_output_to_layers = [ [l for l in encoder_layer] for encoder_layer in self.enc_output_to_layers ] + + self.frozen_model.enc_dec_model.speech_offset = speech_offset + self.frozen_model.enc_dec_model.speech_codebook_size = speech_codebook_size + self.frozen_model.enc_dec_model.num_speech_codebooks = num_speech_codebooks + self.frozen_model.enc_dec_model.seq_pattern = cfg.get('seq_pattern', 'parallel') + self.frozen_model.enc_dec_model.attn_prior_scaledown_start_step = attn_prior_scaledown_start_step + self.frozen_model.enc_dec_model.attn_prior_end_step = attn_prior_end_step + self.frozen_model.enc_dec_model.alignment_decoder_layerids = cfg.get('alignment_decoder_layerids', list(range(0, 12))) + self.frozen_model.enc_dec_model.return_all_crossattention_probs = cfg.get('return_all_crossattention_probs', False) + self.frozen_model.enc_dec_model.num_cross_attention_heads = num_cross_attention_heads + self.frozen_model.enc_dec_model.context_conditioning = self.context_conditioning + self.frozen_model.enc_dec_model.decoder_context_len = self.decoder_context_len + self.frozen_model.enc_dec_model.enc_output_to_layers = self.enc_output_to_layers + + self.alignment_loss_start_step = 0 + self.alignment_loss_end_step = float('inf') + self.use_alignment_loss = cfg.get('use_alignment_loss', False) + if self.use_alignment_loss: + alignment_loss_scale = cfg.get('alignment_loss_scale', 1.0) + self.frozen_model.enc_dec_model.use_alignment_loss = True + self.frozen_model.enc_dec_model.forward_sum_loss = ForwardSumLoss(loss_scale=alignment_loss_scale) + self.frozen_model.enc_dec_model.alignment_text_end_offset = cfg.get('alignment_text_end_offset', 0) + self.frozen_model.enc_dec_model.align_every_n_head = cfg.get('align_every_n_head', 1) + self.alignment_loss_start_step = cfg.get('alignment_loss_start_step', 0) + self.alignment_loss_end_step = cfg.get('alignment_loss_end_step', 
float('inf')) + + # Need to explicitly set this since it is already initialized + self.frozen_model.enc_dec_model.tokens_head.parallel_output = self.frozen_model.enc_dec_model.parallel_output + + list_of_speech_heads = [] + list_of_speech_tokens_embeddings = [] + for _ in range(self.num_speech_codebooks - 1): + # init is NOT used since we overwrite the weight below anyways + _speech_head_embedding = tensor_parallel.VocabParallelEmbedding( + speech_codebook_size, + embedding_dim=self.word_embeddings.embedding_dim, + init_method=lambda x: x.data.fill_(0), + config=self.model_parallel_config, + ) + _speech_head_embedding.weight.data.fill_(0) + _speech_head_embedding.shared = True + list_of_speech_tokens_embeddings.append(_speech_head_embedding) + # Linear layer that maps from hidden size to speech codebook size + hidden_size = self.frozen_model.enc_dec_model.decoder_cfg.hidden_size + init_method_std = self.frozen_model.enc_dec_model.decoder_cfg.init_method_std + # Changing to ColumnParallelLinear instead of Linear to support 3b Tensor Parallelism + _speech_head = tensor_parallel.ColumnParallelLinear( + input_size=hidden_size, + output_size=speech_codebook_size, + bias=True, + gather_output=not self.frozen_model.enc_dec_model.parallel_output, + init_method=init_method_normal(init_method_std), + config=self.model_parallel_config, + # use_cpu_initialization=False, + # params_dtype=self.frozen_model.enc_dec_model.dtype, + ) + list_of_speech_heads.append(_speech_head) + + self.frozen_model.enc_dec_model.speech_tokens_heads = torch.nn.ModuleList(list_of_speech_heads) + self.frozen_model.enc_dec_model.speech_tokens_embeddings = torch.nn.ModuleList( + list_of_speech_tokens_embeddings + ) + + self.sample_rate = 24000 + if codecmodel_type == 'nemo_codec': + codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path')) + codec_model.to('cuda') + codec_model.eval() + self.sample_rate = 22050 + else: + raise NotImplementedError() + + self.additional_models = {'codec': codec_model} + self.train_check_interval = self.cfg.get('train_check_interval', 500) + self.plot_alignments_sliced = self.cfg.get('plot_alignments_sliced', True) + app_state = AppState() + self.is_rank_zero = app_state.global_rank == 0 + self.predict_step_outputs = [] + self.phoneme_tokenizer = None + + # classifier-free guidance (CFG) option during training. The probability (0.0 <= ε <= 1.0) is used to trigger the action that the + # text or audio tokens in a batch are replaced by [UNK], such that mimicking the text- or audio-free scenario. + # If a random number is greater than ε, then keep text or audio tokens as-is, otherwise, the text or audio tokens are + # replaced by [UNK]. Default to 0.0, meaning CFG is disabled. + self.train_text_cfg_prob = cfg.get('train_text_cfg_prob', 0.0) + self.train_audio_cfg_prob = cfg.get('train_audio_cfg_prob', 0.0) + self._rng = random.Random() + + # control the strength of the classifier guidance during inference, Logits_cfg = w*Logits_cond + (1-w)*Logits_uncond, + # equivalent to Logits_cfg = Logits_cond + alpha*(Logits_cond - Logits_uncond) where alpha=w-1. + # Default w to 1.O, indicating no interpolation is applied. 
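+        # For example, with w = 1.5 the guided logits are 1.5 * Logits_cond - 0.5 * Logits_uncond,
+        # i.e. alpha = w - 1 = 0.5 in the equivalent form Logits_cond + alpha*(Logits_cond - Logits_uncond);
+        # w = 1.0 keeps only the conditioned logits, so CFG interpolation is effectively disabled.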
+ self.inference_cfg_interpolation_scale = cfg.get('inference_cfg_interpolation_scale', 1.0) + self.inference_apply_text_cfg = cfg.get('inference_apply_text_cfg', False) + self.inference_apply_audio_cfg = cfg.get('inference_apply_audio_cfg', False) + if self.inference_cfg_interpolation_scale == 1.0: + self.inference_apply_text_cfg = False + self.inference_apply_audio_cfg = False + + # whether to apply cfg filter to address faster speech rate. + self.inference_apply_cfg_filter = cfg.get("inference_apply_cfg_filter", False) + + # this scale is suggested to be smaller than `self.question_guidance_scale` and it is used to balance the weights + # between the conditioned logits after applying cfg filter and the original unconditioned logits. Default to 1.0, + # indicating only conditioned logits are used. + if not self.inference_apply_cfg_filter: + self.inference_cfg_filter_interpolation_scale = None + else: + self.inference_cfg_filter_interpolation_scale = cfg.get('inference_cfg_filter_interpolation_scale', 1.0) + + # whether to estimate MOS in predict_step. + self.estimate_mos = cfg.get('estimate_mos', True) + if self.estimate_mos: + # requires to specify a non-matching high-quality and clean reference audio file. It is used to estimate MOS. + self.non_matching_ref_audio_filepath = cfg.get('non_matching_ref_audio_filepath', None) + if self.non_matching_ref_audio_filepath is None: + raise ValueError(f"Please provide a high-quality reference audio to estimate the MOS. Alternatively, " + f"set `model.estimate_mos=False` to disable MOS estimation.") + if not os.path.exists(self.non_matching_ref_audio_filepath): + raise FileNotFoundError(f"Please provide a valid file path for a high-quality reference audio to estimate" + f" the MOS. Alternatively, set `model.estimate_mos=False` to disable MOS estimation.") + + def decode_wav_from_codec_model(self, codes): + codec_model = self.additional_models['codec'] + if self.codecmodel_type == 'nemo_codec': + codec_len = torch.Tensor([codes.shape[1]]).long().cuda() + if codec_len < 10: + # return a one-second silence + return torch.zeros(24000).cuda() + wav, _ = codec_model.decode(tokens=codes.unsqueeze(0), tokens_len=codec_len) + wav = wav[0] + else: + raise NotImplementedError() + return wav + + def first_stage_of_pipeline(self): + if self.frozen_model.enc_dec_model.pre_process and parallel_state.get_pipeline_model_parallel_rank() == 0: + return True + return False + + def forward( + self, + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_mask, + position_ids, + taskname_ids, + labels=None, + speech_mask=None, + inference=False, + inference_step=0, + cross_attention_prior=None, + text_limits=None, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + ): + """ + Special forward method for p-tuning/prompt-tuning pretrained + T5 style models. 
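+ When `context_and_question_tokens` (with the matching `enc_mask`, `position_ids` and
+ `cross_attention_prior`) is a list, each entry is routed to a separate encoder
+ (multi-encoder setup). At `inference_step == 0` the encoder input embeddings are built
+ from the tokens; at later autoregressive steps the same argument carries the cached
+ encoder output instead.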
+ """
+ multi_encoder = False
+ if isinstance(context_and_question_tokens, list):
+ multi_encoder = True
+ assert isinstance(enc_mask, list)
+ assert isinstance(position_ids, list)
+ if cross_attention_prior is None:
+ cross_attention_prior = [None for _ in range(len(context_and_question_tokens))]
+ assert isinstance(cross_attention_prior, list)
+ assert len(context_and_question_tokens) == len(enc_mask) == len(position_ids) == len(cross_attention_prior)
+ else:
+ multi_encoder = False
+ context_and_question_tokens = [context_and_question_tokens]
+ enc_mask = [enc_mask]
+ position_ids = [position_ids]
+ cross_attention_prior = [cross_attention_prior]
+
+
+ enc_output = None
+ logging.debug(f"self.first_stage_of_pipeline()={self.first_stage_of_pipeline()}\tinference_step={inference_step}")
+ if self.first_stage_of_pipeline() and inference_step == 0:
+ # Get embeddings for text tokens and insert virtual token embeddings
+ encoder_input_list = []
+ for ei in range(len(context_and_question_tokens)):
+ input_embeds = self.get_embeddings_and_combine(
+ [virtual_tokens, context_and_question_tokens[ei]], taskname_ids, inference
+ )
+ # TODO: This check needs to be revisited with PP support.
+ if hasattr(self.frozen_model.enc_dec_model.encoder_embedding, 'position_embeddings'):
+ position_embeddings = self.frozen_model.enc_dec_model.encoder_embedding.position_embeddings(
+ position_ids[ei]
+ )
+ encoder_input = input_embeds + position_embeddings
+ else:
+ encoder_input = input_embeds
+ encoder_input_list.append(encoder_input)
+ else:
+ encoder_input_list = None
+ encoder_input = None
+ if inference_step != 0:
+ enc_output = context_and_question_tokens if multi_encoder else context_and_question_tokens[0]
+
+ # If the decoder input starts with <pad> instead of <bos>, which is the case for huggingface T5 models, we don't want to mask the first token.
+ # For NeMo-Megatron, the sequence starts with <bos>, which is never masked so we can always set index 0 to be unmasked.
+ dec_mask[:, 0] = 1 + + if not self.cfg.data.get('use_attention_prior', False): + cross_attention_prior = [None for _ in range(len(cross_attention_prior))] + + _encoder_input = encoder_input_list + if not multi_encoder: + context_and_question_tokens = context_and_question_tokens[0] + enc_mask = enc_mask[0] + position_ids = position_ids[0] + cross_attention_prior = cross_attention_prior[0] + _encoder_input = encoder_input_list[0] if encoder_input_list is not None else None + + # Call forward on T5 model with preprocessed embeddings + if inference and inference_step == 0: + set_inference_key_value_memory = True + else: + set_inference_key_value_memory = False + + if self.autocast_dtype == torch.float32: + output, out_logits = self.frozen_model.enc_dec_model( + enc_input_ids=None, + enc_attn_mask=enc_mask, + dec_input_ids=dec_input, + dec_attn_mask=dec_mask, + token_type_ids=None, + labels=labels, + output_enc_hidden_only=False, + enc_input=_encoder_input, + enc_output=enc_output, + speech_mask=speech_mask, + cross_attention_prior=cross_attention_prior, + text_limits=text_limits, + global_step=self.global_step, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + ) + else: + with torch.autocast(device_type="cuda", dtype=self.autocast_dtype): + output, out_logits = self.frozen_model.enc_dec_model( + enc_input_ids=None, + enc_attn_mask=enc_mask, + dec_input_ids=dec_input, + dec_attn_mask=dec_mask, + token_type_ids=None, + labels=labels, + output_enc_hidden_only=False, + enc_input=_encoder_input, + enc_output=enc_output, + speech_mask=speech_mask, + cross_attention_prior=cross_attention_prior, + text_limits=text_limits, + global_step=self.global_step, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + ) + + return output, encoder_input, out_logits + + def load_frozen_model(self, cfg, trainer): + self.megatron_amp_O2 = cfg.get('megatron_amp_o2', False) + + # TODO: Fix this once apex patches FusedScaledMaskedSoftmax. + # This is a workaround for the fact that `masked_softmax_fusion` has issues with certain input sizes that may be present while finetuning. + cfg_language_model_path = cfg.get('language_model_path', None) + cfg_frozen_model = cfg.get('frozen_model', None) + if not (bool(cfg_language_model_path) ^ bool(cfg_frozen_model)): + raise ValueError( + "T5-TTS requires either 'language_model_path' or 'frozen_model' in its config, but not both." 
+ )
+
+ if cfg_language_model_path:
+ t5_cfg = MegatronT5Model.restore_from(cfg_language_model_path, trainer=trainer, return_config=True)
+ else:
+ t5_cfg = cfg_frozen_model
+
+ OmegaConf.set_struct(t5_cfg, True)
+ with open_dict(t5_cfg):
+ if hasattr(t5_cfg, 'encoder') and hasattr(t5_cfg, 'decoder'):
+ t5_cfg.encoder.masked_softmax_fusion = False
+ t5_cfg.decoder.masked_softmax_fusion = False
+ else:
+ t5_cfg.masked_softmax_fusion = False
+ t5_cfg.megatron_amp_O2 = self.megatron_amp_O2
+ # hack to make the _GLOBAL_NUM_MICROBATCHES_CALCULATOR initialize
+ t5_cfg.micro_batch_size = cfg.get('micro_batch_size', 4)
+ t5_cfg.global_batch_size = cfg.get('global_batch_size', 4)
+ t5_cfg.precision = trainer.precision
+ t5_cfg.tokenizer.num_sentinel_tokens = cfg.get('num_sentinel_tokens', 39184 - 29056)
+ t5_cfg.seq_length = cfg.data.max_seq_length
+ if cfg.get('max_position_embeddings', None) is None:
+ t5_cfg.max_position_embeddings = cfg.data.max_seq_length
+ else:
+ t5_cfg.max_position_embeddings = cfg.get('max_position_embeddings')
+ t5_cfg.use_flash_attention = cfg.get('use_flash_attention', False)
+ if cfg.get('override_token_model', None):
+ t5_cfg.tokenizer.model = cfg['override_token_model']
+ if cfg.get('override_tokenizer_vocab_file', None):
+ t5_cfg.tokenizer.vocab_file = cfg['override_tokenizer_vocab_file']
+
+ if cfg.get('train_from_scratch', False):
+ print("Training from scratch!")
+ # Defaults for 220m model
+ # To override any of these, add +model.override_<key>=<value> to the config file.
+ # E.g. +model.override_hidden_size=1024
+ override_keys = [
+ 'hidden_size', # 768
+ 'num_layers', # 12
+ 'num_attention_heads', # 12
+ 'hidden_dropout', # 0.1
+ 'attention_dropout', # 0.1
+ 'kv_channels', # 64
+ 'ffn_hidden_size', # 2048
+ ]
+ for k in override_keys:
+ if cfg.get(f'override_{k}') is not None:
+ t5_cfg[k] = cfg.get(f'override_{k}')
+
+ self.frozen_model = MegatronT5OverrideModel(t5_cfg, trainer=trainer)
+ num_params = sum(p.numel() for p in self.frozen_model.parameters() if p.requires_grad)
+ print(f"Number of parameters: {num_params}")
+ else:
+ print(f"Loading from pretrained checkpoint: {cfg_language_model_path}")
+ if cfg_language_model_path is None:
+ raise ValueError(
+ "T5-TTS SFT on a pretrained model checkpoint requires `language_model_path` in its config."
+ )
+
+ self.frozen_model = MegatronT5OverrideModel.restore_from(
+ cfg_language_model_path,
+ trainer=trainer,
+ override_config_path=t5_cfg,
+ save_restore_connector=NLPSaveRestoreConnector(),
+ )
+
+ if not cfg.get('english_only_model', False):
+ self.frozen_model.tokenizer.add_phone_tokens_to_special_tokens()
+
+ logging.info(f"self.frozen_model {self.frozen_model}")
+
+ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
+ """
+ Dataloader produces a global batch which is turned into a list of microbatches.
+ The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
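+ Only the last pipeline stage returns the real loss; earlier stages return a zero tensor.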
+ """ + # Get seq length of batch + batch = next(dataloader_iter) + _, seq_length = batch[0].shape + if batch[4].dim() > 2: + _, _, dec_seq_length = batch[4].shape + else: + _, dec_seq_length = batch[4].shape + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(forward_only), + data_iterator=data_iter, + model=[self], + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=seq_length, + micro_batch_size=get_micro_batch_size(), + decoder_seq_length=dec_seq_length, + ) + + # only the last stages of the pipeline return losses + if losses_reduced_per_micro_batch: + # average loss across micro batches + loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.concat(loss_tensors_list) + loss_mean = loss_tensor.mean() + else: + # we're not on the last pipeline stage so no losses + loss_mean = torch.tensor(0.0).cuda() + + return loss_mean + + def convert_tokens_to_range(self, tokens, apply_offset_correction=True, pattern=None): + # convert tokens to range [0, 1024] + output_tokens = tokens.clone() + if apply_offset_correction: + output_tokens[0] = output_tokens[0] - self.speech_offset + output_tokens = torch.clamp(output_tokens, min=0, max=self.speech_codebook_size - 1) + if pattern is None: + pattern = self.cfg.get('seq_pattern', 'delay_parallel') + if pattern == "delay_parallel": + output_tokens_new = [] + for _c in range(output_tokens.shape[0]): + si = _c + ei = _c + output_tokens.shape[1] - self.num_speech_codebooks + output_tokens_new.append(output_tokens[_c, si:ei]) + output_tokens_new = torch.stack(output_tokens_new) + output_tokens = output_tokens_new + + return output_tokens + + def get_forward_output_and_loss_func(self, validation_step=False): + def fwd_output_and_loss_func(dataloader_iter, model): + batch = next(dataloader_iter) + _batch = [] + for x in batch: + if isinstance(x, torch.Tensor): + x = x.cuda(non_blocking=True) + elif isinstance(x, list): + if isinstance(x[0], torch.Tensor): + x = [y.cuda(non_blocking=True) for y in x] + _batch.append(x) + batch = _batch + # batch = [x.cuda(non_blocking=True) if isinstance(x, torch.Tensor) else x for x in batch] + ( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + labels, + loss_mask, + position_ids, + taskname_ids, + speech_mask, + context_and_question_tokens_lens, + cross_attention_prior, + text_limits, + _, # TODO: text limit and lang not in tarred dataset + _, + ) = batch + + if self.trainer.global_step % self.train_check_interval == 0 and not validation_step and self.is_rank_zero: + self.frozen_model.enc_dec_model.logging_step = True + + _cross_attention_prior = cross_attention_prior + if isinstance(context_and_question_tokens, list): + # None for context and prior for question + _cross_attention_prior = [None, cross_attention_prior] + + output_tensor, encoder_input, out_logits = model( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + position_ids, + taskname_ids, + labels=labels, + speech_mask=speech_mask, + cross_attention_prior=_cross_attention_prior, + text_limits=text_limits, + inference=False, + ) + output_tensor = output_tensor.contiguous() + + alignment_loss = out_logits[3] + if alignment_loss is not None: + self.logger.experiment.add_scalar('train_alignment_loss', 
alignment_loss, self.global_step) + + if self.trainer.global_step % self.train_check_interval == 0 and not validation_step and self.is_rank_zero: + self.frozen_model.enc_dec_model.logging_step = False + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + if torch.count_nonzero(speech_mask) == 0: + text_labels = labels[:, 0, :] # [B, 8, T] -> [B, T] + token_logits = out_logits[0] * 1 # [T, B, V] + if self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region(token_logits) + token_logits = token_logits.argmax(dim=2) # [T, B] + token_logits = token_logits.t() # [B, T] + score = 0 + for i in range(text_labels.size()[0]): + r = text_labels[i].long() + nzm = r != 0 + r = r.tolist() + h = token_logits[i].long() * nzm + h = h.tolist() + score += editdistance.eval(r, h) + score /= text_labels.size()[0] + logging.info(f"wer score : {score}") + self.logger.experiment.add_scalar('WER', score, self.global_step) + else: + audio_len = self.decoder_context_len + (labels[0][0][self.decoder_context_len:] != 0).sum().item() + labels_to_1024 = self.convert_tokens_to_range(labels[0, :, 0:audio_len]) + label_wav = self.decode_wav_from_codec_model(labels_to_1024) + dec_input_to_1024 = self.convert_tokens_to_range(dec_input[0, :, 0:audio_len]) + dec_input_wav = self.decode_wav_from_codec_model(dec_input_to_1024) + self.logger.experiment.add_audio( + "train_label_wav", label_wav, self.global_step, self.sample_rate + ) + self.logger.experiment.add_audio( + "train_dec_input_wav", dec_input_wav, self.global_step, self.sample_rate + ) + if isinstance(context_and_question_tokens, list): + context_tokens = context_and_question_tokens[0] + question_tokens = context_and_question_tokens[1] + input_token_list_all = [ + question_tokens[0, 0, i].item() + for i in range(question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) + for ti, t in enumerate(input_token_list_all) + if t != 0 and t < self.speech_offset + ] + context_end_step = context_and_question_tokens_lens[0][0].item() + _context_tokens = context_tokens[0, :, :context_end_step] + else: + input_token_list_all = [ + context_and_question_tokens[0, 0, i].item() + for i in range(context_and_question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) + for ti, t in enumerate(input_token_list_all) + if t != 0 and t < self.speech_offset + ] + context_end_step = input_token_list[0][0] + _context_tokens = context_and_question_tokens[0, :, :context_end_step] + + if context_end_step > 1: + is_speech_context = _context_tokens[1,:].sum().item() > 0 + if is_speech_context: + _context_tokens = self.convert_tokens_to_range( + _context_tokens, pattern=self.context_pattern + ) + _context_wav = self.decode_wav_from_codec_model(_context_tokens) + self.logger.experiment.add_audio( + "train_context_wav", _context_wav, self.global_step, self.sample_rate + ) + else: + _context_token_list = [ v.item() for v in _context_tokens[0, :] ] + _context_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in _context_token_list if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text("train_context_text", _context_text, self.global_step) + + question_si = text_limits[0, 0].item() - virtual_tokens.shape[1] + question_ei = text_limits[0, 1].item() - virtual_tokens.shape[1] + text_si = text_limits[0, 0].item() + text_ei = text_limits[0, 1].item() + input_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in input_token_list_all[question_si:question_ei] 
if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text("Train Input Text", input_text, self.global_step) + + input_phoneme_tokens = [ + v - self.lm_vocab_size + for v in input_token_list_all[question_si:question_ei] + if v >= self.lm_vocab_size + ] + + if len(input_phoneme_tokens) > 0: + phoneme_text = self.phoneme_tokenizer.decode(input_phoneme_tokens) + self.logger.experiment.add_text( + "Train Input Phoneme Text", phoneme_text, self.global_step + ) + + token_logits = out_logits[0] + speech_logits_list = out_logits[1] + + attention_probs_list = out_logits[2] # list of (BS, 12, out_length, in_length) + if attention_probs_list is not None: + attention_sliced_list = [] + for lidx in range(len(attention_probs_list)): + attention_probs = attention_probs_list[lidx] + for _i in range(attention_probs.shape[1]): + name = f"Attention Probs Layer {lidx} Head {_i}" + attention_to_plot = attention_probs[0, _i, :audio_len, :text_ei] + if self.plot_alignments_sliced: + attention_to_plot = attention_probs[0, _i, 0:audio_len, text_si:text_ei] + # 4 to offset "Text to Speech this" + name += " Sliced" + alignment_image = plot_alignment_to_numpy_for_speechllm( + attention_to_plot.cpu().float().numpy().T, + phoneme_ver=0 if self.plot_alignments_sliced else 1, + phoneme_seq=None if self.plot_alignments_sliced else [text_si], + ) + self.logger.experiment.add_image( + name, alignment_image, self.global_step, dataformats="HWC", + ) + attention_sliced_list.append( + attention_probs[0, _i, self.decoder_context_len:audio_len, text_si:text_ei] + ) + attention_sliced = torch.stack(attention_sliced_list) + attention_sliced = torch.mean(attention_sliced, 0) + text = None + if len(input_text) > 0: + text = self.frozen_model.tokenizer.ids_to_tokens( + [ + v + for v in input_token_list_all[question_si:question_ei] + if v < self.lm_vocab_size + ] + ) + if len(input_phoneme_tokens) > 0: + text = phoneme_text.split("|") + alignment_image_sliced = plot_alignment_to_numpy_for_speechllm( + attention_sliced.cpu().float().numpy().T, + phoneme_seq=text, + phoneme_ver=2, + vmin=0.0, + phone_offset=0, + h_offset=False, + ) + self.logger.experiment.add_image( + f"Attention Probs Average Sliced", + alignment_image_sliced, + self.global_step, + dataformats="HWC", + ) + if self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region(token_logits) + for _i in range(len(speech_logits_list)): + speech_logits_list[_i] = tensor_parallel.gather_from_tensor_model_parallel_region( + speech_logits_list[_i] + ) + speech_logits = torch.stack(speech_logits_list, dim=-1) # (t, b, 1024, 7) + token_logits_example = token_logits[:, 0, :] * 1 + speech_logits_example = speech_logits[:, 0, :, :] * 1 + first_layer_tokens = token_logits_example.argmax(dim=1) - self.speech_offset + other_layer_tokens = [] + for _i in range(speech_logits_example.shape[2]): + other_layer_tokens.append(speech_logits_example[:, :, _i].argmax(dim=1)) + + all_layer_tokens = torch.stack([first_layer_tokens] + other_layer_tokens) # (8, t) + all_layer_tokens = self.convert_tokens_to_range( + all_layer_tokens, apply_offset_correction=False + ) + # all_layer_tokens = torch.clip(all_layer_tokens, 0, 1023) + predicted_wav = self.decode_wav_from_codec_model(all_layer_tokens) + self.logger.experiment.add_audio( + "train_tf_pred_wav", predicted_wav, self.global_step, self.sample_rate + ) + + def loss_func(loss_args): + output_tensor, out_logits, curr_step = loss_args + 
alignment_loss = out_logits[3] + loss = self.frozen_model.loss_func(loss_mask, output_tensor) + if ( + (alignment_loss is not None) + and (curr_step > self.alignment_loss_start_step) + and (curr_step < self.alignment_loss_end_step) + ): + logging.debug(f"Adding alignment loss. cur:{curr_step} start:{self.alignment_loss_start_step}") + loss = loss + alignment_loss + reduced_loss = average_losses_across_data_parallel_group([loss]) + return loss, {'avg': reduced_loss} + + return [output_tensor, out_logits, self.global_step], loss_func + + return fwd_output_and_loss_func + + def get_forward_output_only_func(self): + """ Used in inference / predict """ + + def fwd_output_only_func(dataloader_iter, model): + batch = next(dataloader_iter) + _batch = [] + for x in batch: + if isinstance(x, torch.Tensor): + x = x.cuda(non_blocking=True) + elif isinstance(x, list): + if isinstance(x[0], torch.Tensor): + x = [y.cuda(non_blocking=True) for y in x] + _batch.append(x) + batch = _batch + # batch = [x.cuda(non_blocking=True) if isinstance(x, torch.Tensor) else x for x in batch] + ( + decoder_max_sequence_len, + encoder_max_sequence_len, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + position_ids, + taskname_ids, + speech_mask, + ) = batch + + + output_logits, _, token_and_speech_logits = model( + context_and_question_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + position_ids, + taskname_ids, + labels=None, + speech_mask=speech_mask, + inference=True, + inference_step=1, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + ) + output_tensor = [output_logits, token_and_speech_logits] + + def id_func(output_tensor): + return 0, {'output_logits': output_tensor[0], 'token_and_speech_logits': output_tensor[1]} + + return output_tensor, id_func + + return fwd_output_only_func + + def backward(self, *args, **kwargs): + """ LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. + No need to call it here. + """ + return + + def optimizer_zero_grad(self, *args, **kwargs): + """ LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. + """ + return + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + When using pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.frozen_model.enc_dec_model.set_input_tensor(input_tensor) + + def on_train_epoch_start(self) -> None: + gbs = self.cfg.global_batch_size + mbs = self.cfg.micro_batch_size + self._reconfigure_batch_sizes(gbs, mbs) + return super().on_train_epoch_start() + + def on_validation_epoch_start(self) -> None: + gbs = self.cfg.get('validation_global_batch_size', self.cfg.global_batch_size) + mbs = self.cfg.get('validation_micro_batch_size', self.cfg.micro_batch_size) + self._reconfigure_batch_sizes(gbs, mbs) + return super().on_validation_epoch_start() + + def training_step(self, dataloader_iter, batch_idx): + self._optimizer.zero_grad() + batch = next(dataloader_iter) + + # apply text classifier-free guidance by replacing input question tokens with [UNK]. 
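+ # Batch layout (same order as in fwd_output_and_loss_func): batch[0]=virtual_tokens, batch[1]=context_and_question_tokens,
+ # batch[3]=dec_input, batch[11]=cross_attention_prior, batch[12]=text_limits; these entries are modified in place below.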
+ if self.train_text_cfg_prob > 0.0:
+ if self._rng.random() < self.train_text_cfg_prob:
+ logging.info(f"Text Classifier-Free Guidance is triggered for the {batch_idx}-th batch.")
+
+ # temporarily disable computing CTC alignment loss.
+ if self.use_alignment_loss:
+ self.frozen_model.enc_dec_model.use_alignment_loss = False
+
+ # set the cross-attention prior to None to remove the prior.
+ batch[11] = None
+
+ # replace question token IDs with [UNK]'s id. No speech offset for Phoneme's [UNK]. Same op as train.
+ # instruction token IDs are bpe token IDs directly obtained from self.tokenizer without any offset.
+ # question token IDs are phoneme and grapheme token IDs and are offset by self.lm_vocab_size
+ # if under the "Phoneme TTS" instruction, so there are no overlaps between instruction and question token IDs.
+ # question token IDs are bpe token IDs without any offset
+ # if under the "Text to speech this" instruction, so there are overlaps between instruction and question token IDs.
+ context_and_question_tokens = batch[1] # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len)
+ text_limits = batch[12]
+ virtual_tokens = batch[0]
+ question_limits = text_limits - virtual_tokens.size(1) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens.
+ question_start = question_limits[:, 0].unsqueeze(1) # (b, 1)
+ question_end = question_limits[:, 1].unsqueeze(1) # (b, 1)
+
+ if isinstance(context_and_question_tokens, list): # indicates self.encoder_type=multi_transformers.
+ context_tokens, question_tokens = context_and_question_tokens
+ question_tokens_unconditioned = question_tokens.clone()
+ time_range = torch.arange(question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device).unsqueeze(0)
+ question_mask = (time_range >= question_start) & (time_range < question_end) # create a mask for question-only tokens.
+ question_tokens_unconditioned[:, 0][question_mask] = self.tokenizer.unk_id # only the first layer has non-zero IDs.
+ batch[1] = [context_tokens, question_tokens_unconditioned]
+ else:
+ context_and_question_tokens_unconditioned = context_and_question_tokens.clone() # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len)
+ time_range = torch.arange(context_and_question_tokens_unconditioned.size(2), device=context_and_question_tokens_unconditioned.device).unsqueeze(0) # (1, max_context_question_tokens_len)
+ question_mask = (time_range >= question_start) & (time_range < question_end) # create a mask for question-only tokens.
+ context_and_question_tokens_unconditioned[:, 0][question_mask] = self.tokenizer.unk_id # only the first layer has non-zero IDs.
+ batch[1] = context_and_question_tokens_unconditioned
+
+ del question_limits, question_start, question_end, time_range, question_mask
+ else:
+ # recover the original alignment loss config.
+ self.frozen_model.enc_dec_model.use_alignment_loss = self.use_alignment_loss
+
+ # apply audio context classifier-free guidance by replacing audio codec tokens with [UNK].
+ if self.train_audio_cfg_prob > 0.0:
+ if self._rng.random() < self.train_audio_cfg_prob:
+ logging.info(f"Audio Classifier-Free Guidance is triggered for the {batch_idx}-th batch.")
+
+ context_and_question_tokens = batch[1] # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len)
+
+ if isinstance(context_and_question_tokens, list): # indicates self.encoder_type=multi_transformers.
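+ # Multi-encoder case: replace the context (audio) stream with [UNK] and leave the question stream unchanged.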
+ context_tokens, question_tokens = context_and_question_tokens + context_tokens_unconditioned = context_tokens.clone() + context_tokens_unconditioned[:, :, :] = self.tokenizer.unk_id # TODO @xueyang: verify if extra tokens other than audio codec tokens are appended. + batch[1] = [context_tokens_unconditioned, question_tokens] + else: + # dec_input + dec_input = batch[3] + dec_input_unconditioned = dec_input.clone() + dec_input_unconditioned[:, :, 1:self.decoder_context_len + 1] = self.tokenizer.unk_id # TODO @xueyang: switch to other token id if this one is conflict with text unk. + batch[3] = dec_input_unconditioned + + loss_mean = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=False) + self.allreduce_gradients() + + ## logging + # we can only log on one rank if it is rank zero so we broadcast from last rank + # we can avoid this broadcast by updating the PTL log function to accept specific ranks + torch.distributed.broadcast(loss_mean, get_last_rank()) + + if self.cfg.precision == 16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"): + loss_scale = self.trainer.precision_plugin.scaler._scale + if loss_scale is not None: + self.log('loss_scale', loss_scale, batch_size=1) + + self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) + lr = self._optimizer.param_groups[0]['lr'] + self.log('lr', lr, rank_zero_only=True, batch_size=1) + self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1) + return loss_mean + + def get_predictions(self, input_ids, enc_mask, encoder_input, labels): + predicted_token_ids, log_probs = self.frozen_model.decode( + tokens_enc=input_ids, + enc_mask=enc_mask, + num_tokens_to_generate=self.decoder_seq_length, + encoder_input=encoder_input, + bos_id=self.tokenizer.pad_id + if self.cfg.data.get('decoder_starts_with_pad', False) + else self.tokenizer.bos_id, + ) + # Special ids to text function to handle stripping and special tokens with sentencepiece tokenizers. 
+ preds_text = MegatronT5SFTModel.ids_to_text(predicted_token_ids, self.tokenizer) + labels_text = MegatronT5SFTModel.ids_to_text(labels, self.tokenizer) + input_text = MegatronT5SFTModel.ids_to_text(input_ids, self.tokenizer) + return { + 'predicted_token_ids': preds_text, + 'labels': labels_text, + 'enc_inputs': input_text, + } + + def get_embeddings(self, tokens, taskname_ids, inference=False): + out = None + if tokens.dim() > 2: + for i in range(tokens.size()[1]): # for 8 channels + if i == 0: + # Embed first layer using word embeddings + out = self.embed_input(tokens[:, i, :], taskname_ids, inference) # (B, T, D) + else: + # Embed other layers using speech embeddings + cur = self.frozen_model.enc_dec_model.speech_tokens_embeddings[i - 1](tokens[:, i, :]) + # do not add embeddings of zero tokens of other channels (except the first channel) + non_zero_flag = tokens[:, i, :] != 0 # (B, T) + cur = cur * non_zero_flag.unsqueeze(2) + out = out + cur + else: + out = self.embed_input(tokens, taskname_ids, inference) + return out + + def get_embeddings_and_combine(self, token_list, taskname_ids, inference): + embedding_list = [] + for tokens in token_list: + embedding_list.append(self.get_embeddings(tokens, taskname_ids, inference)) + return torch.cat(embedding_list, dim=1) + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + ( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + labels, + loss_mask, + position_ids, + taskname_ids, + speech_mask, + context_and_question_tokens_lens, + cross_attention_prior, + text_limits, + _, + _, + ) = batch + # loss_mask (b, t) + # does not use dataloader_iter due to device placement issues arising from PTL + + mode = self.training + self.eval() + gbs = self.cfg.get('validation_global_batch_size', self.cfg.global_batch_size) + self._reconfigure_and_process_inference_batch(virtual_tokens.size(0), gbs) + + loss_mean = self.fwd_bwd_step( + itertools.chain([batch]), batch_idx, forward_only=True + ) # comment this out and add custom forward function to calculate WER + # # logging.info (f'loss_mean {loss_mean}') + + if batch_idx == 0 and self.is_rank_zero: + self.frozen_model.enc_dec_model.logging_step = True + self.predict_step_outputs = [] + # log_scalars=False avoids logging scalar TTS metrics in the predict_step + # Images, audio and texts will still be logged + self.predict_step(batch=batch, batch_idx=batch_idx, log_scalars=False, global_step=self.global_step) + for inf_key in self.predict_step_outputs[0]: + if self.predict_step_outputs[0][inf_key] is not None: + self.logger.experiment.add_scalar( + f'Val_{inf_key}', self.predict_step_outputs[0][inf_key], self.global_step + ) + + labels_original = labels.clone() # (b, 8, t) + + _cross_attention_prior = cross_attention_prior + if isinstance(context_and_question_tokens, list): + _cross_attention_prior = [None, cross_attention_prior] + + output_loss, _, output_logits = self.forward( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + position_ids, + taskname_ids, + labels=labels, + speech_mask=speech_mask, + cross_attention_prior=_cross_attention_prior, + text_limits=text_limits, + inference=False, + ) + + if batch_idx == 0 and self.is_rank_zero: + self.frozen_model.enc_dec_model.logging_step = False + with torch.cuda.amp.autocast(enabled=False): + if torch.count_nonzero(speech_mask) == 0: + text_labels = labels[:, 0, :] # [B, 8, T] -> [B, T] + token_logits = output_logits[0] * 1 # [T, B, V] + if 
self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region(token_logits) + token_logits = token_logits.argmax(dim=2) # [T, B] + token_logits = token_logits.t() # [B, T] + score = 0 + for i in range(text_labels.size()[0]): + r = text_labels[i].long() + nzm = r != 0 + r = r.tolist() + h = token_logits[i].long() * nzm + h = h.tolist() + score += editdistance.eval(r, h) + score /= text_labels.size()[0] + logging.info(f"wer score : {score}") + self.logger.experiment.add_scalar('WER', score, self.global_step) + else: + audio_len = self.decoder_context_len + (labels[0][0][self.decoder_context_len:] != 0).sum().item() + labels_to_1024 = self.convert_tokens_to_range(labels[0, :, 0:audio_len]) + label_wav = self.decode_wav_from_codec_model(labels_to_1024) + dec_input_to_1024 = self.convert_tokens_to_range(dec_input[0, :, 0:audio_len]) + dec_input_wav = self.decode_wav_from_codec_model(dec_input_to_1024) + self.logger.experiment.add_audio("val_label_wav", label_wav, self.global_step, self.sample_rate) + self.logger.experiment.add_audio( + "val_dec_input_wav", dec_input_wav, self.global_step, self.sample_rate + ) + + if isinstance(context_and_question_tokens, list): + context_tokens = context_and_question_tokens[0] + question_tokens = context_and_question_tokens[1] + input_token_list_all = [ + question_tokens[0, 0, i].item() for i in range(question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) for ti, t in enumerate(input_token_list_all) if t != 0 and t < self.speech_offset + ] + context_end_step = context_and_question_tokens_lens[0][0].item() + _context_tokens = context_tokens[0, :, :context_end_step] + + else: + input_token_list_all = [ + context_and_question_tokens[0, 0, i].item() + for i in range(context_and_question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) for ti, t in enumerate(input_token_list_all) if t != 0 and t < self.speech_offset + ] + context_end_step = input_token_list[0][0] + _context_tokens = context_and_question_tokens[0, :, :context_end_step] + if context_end_step > 1: + is_speech_context = _context_tokens[1,:].sum().item() > 0 + if is_speech_context: + _context_tokens = self.convert_tokens_to_range(_context_tokens, pattern=self.context_pattern) + _context_wav = self.decode_wav_from_codec_model(_context_tokens) + self.logger.experiment.add_audio( + "val_context_wav", _context_wav, self.global_step, self.sample_rate + ) + else: + _context_token_list = [v.item() for v in _context_tokens[0, :]] + _context_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in _context_token_list if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text("val_context_text", _context_text, self.global_step) + + question_si = text_limits[0, 0].item() - virtual_tokens.shape[1] + question_ei = text_limits[0, 1].item() - virtual_tokens.shape[1] + + text_si = text_limits[0, 0].item() + text_ei = text_limits[0, 1].item() + + input_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in input_token_list_all[question_si:question_ei] if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text("Val Input Text", input_text, self.global_step) + + input_phoneme_tokens = [ + v - self.lm_vocab_size + for v in input_token_list_all[question_si:question_ei] + if v >= self.lm_vocab_size + ] + if len(input_phoneme_tokens) > 0: + phoneme_text = self.phoneme_tokenizer.decode(input_phoneme_tokens) + self.logger.experiment.add_text("Val Input Phoneme Text", phoneme_text, 
self.global_step) + + token_logits = output_logits[0] + speech_logits_list = output_logits[1] + + # if self.trainer.global_step % 500 == 0: + attention_probs_list = output_logits[2] # list of (BS, 12, out_length, in_length) + if attention_probs_list is not None: + attention_sliced_list = [] + for lidx in range(len(attention_probs_list)): + attention_probs = attention_probs_list[lidx] + for _i in range(attention_probs.shape[1]): + attention_sliced_list.append(attention_probs[0, _i, self.decoder_context_len:audio_len, text_si:text_ei]) + attention_sliced = torch.stack(attention_sliced_list) + attention_sliced = torch.mean(attention_sliced, 0) + text = None + if len(input_text) > 0: + text = self.frozen_model.tokenizer.ids_to_tokens( + [v for v in input_token_list_all[question_si:question_ei] if v < self.lm_vocab_size] + ) + if len(input_phoneme_tokens) > 0: + text = phoneme_text.split("|") + alignment_image_sliced = plot_alignment_to_numpy_for_speechllm( + attention_sliced.cpu().float().numpy().T, + phoneme_seq=text, + phoneme_ver=2, + vmin=0.0, + phone_offset=0, + h_offset=False, + ) + self.logger.experiment.add_image( + f"Val Attention Probs Average Sliced", + alignment_image_sliced, + self.global_step, + dataformats="HWC", + ) + if self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region(token_logits) + for _i in range(len(speech_logits_list)): + speech_logits_list[_i] = tensor_parallel.gather_from_tensor_model_parallel_region( + speech_logits_list[_i] + ) + speech_logits = torch.stack(speech_logits_list, dim=-1) # (t, b, 1024, 7) + token_logits_example = token_logits[:, 0, :] * 1 + speech_logits_example = speech_logits[:, 0, :, :] * 1 + first_layer_tokens = token_logits_example.argmax(dim=1) - self.speech_offset + other_layer_tokens = [] + for _i in range(speech_logits_example.shape[2]): + other_layer_tokens.append(speech_logits_example[:, :, _i].argmax(dim=1)) + + all_layer_tokens = torch.stack([first_layer_tokens] + other_layer_tokens) # (8, t) + all_layer_tokens = self.convert_tokens_to_range(all_layer_tokens, apply_offset_correction=False) + all_layer_tokens = torch.clip(all_layer_tokens, 0, self.speech_codebook_size - 1) + predicted_wav = self.decode_wav_from_codec_model(all_layer_tokens) + self.logger.experiment.add_audio( + "val_tf_pred_wav", predicted_wav, self.global_step, self.sample_rate + ) + + first_layer_logits = output_logits[0] + speech_logits_list = output_logits[1] + + if self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + first_layer_logits = tensor_parallel.gather_from_tensor_model_parallel_region(first_layer_logits) + if torch.count_nonzero(speech_mask) > 0: + for _i in range(len(speech_logits_list)): + speech_logits_list[_i] = tensor_parallel.gather_from_tensor_model_parallel_region( + speech_logits_list[_i] + ) + speech_logits = torch.stack(speech_logits_list, dim=-1) # (t, b, 1024, 7) + first_layer_preds = first_layer_logits.argmax(dim=2) # (t,bs) + first_layer_preds = first_layer_preds.transpose(0, 1) # (bs,t) + labels_first_layer = labels_original[:, 0, :] # (bs,t) + correct_predictions = first_layer_preds == labels_first_layer # (bs,t) + correct_predictions = correct_predictions * loss_mask # (bs,t) + total_correct_predictions = torch.sum(correct_predictions) + total_predictions = torch.sum(loss_mask) + first_layer_accuracy = total_correct_predictions / total_predictions + first_layer_loss = 
torch.nn.functional.cross_entropy( + first_layer_logits.permute(1, 2, 0), labels_first_layer, reduction='none' + ) # (bs,t) + first_layer_loss = torch.sum(first_layer_loss * loss_mask) / total_predictions + + metrics = { + 'loss': loss_mean, + 'first_layer_accuracy': first_layer_accuracy, + 'first_layer_loss': first_layer_loss, + } + loss_total = first_layer_loss + for i in range(self.num_speech_codebooks - 1): + if torch.count_nonzero(speech_mask) > 0: + speech_logits_i = speech_logits[:, :, :, i] + speech_preds_i = speech_logits_i.argmax(dim=2) # (t,bs) + speech_preds_i = speech_preds_i.transpose(0, 1) # (bs,t) + labels_i = labels_original[:, i + 1, :] # (bs,t) + correct_predictions_i = speech_preds_i == labels_i # (bs,t) + correct_predictions_i = correct_predictions_i * loss_mask * speech_mask # (bs,t) + total_correct_predictions_i = torch.sum(correct_predictions_i) + total_predictions_i = torch.sum(loss_mask * speech_mask) + speech_accuracy_i = total_correct_predictions_i / total_predictions_i + loss_i = torch.nn.functional.cross_entropy( + speech_logits_i.permute(1, 2, 0), labels_i, reduction='none' + ) # (bs,t) + loss_i = torch.sum(loss_i * loss_mask * speech_mask) / total_predictions_i + else: + speech_accuracy_i = torch.tensor(0.0) + loss_i = torch.tensor(0.0) + metrics[f'speech_accuracy_{i+1}'] = speech_accuracy_i + metrics[f'speech_loss_{i+1}'] = loss_i + loss_total += loss_i + + metrics['loss_total_check'] = loss_total + self.validation_step_outputs.append(metrics) + self.train(mode=mode) + self.frozen_model.train() + return metrics['loss'] + + def on_validation_epoch_end(self): + outputs = self.validation_step_outputs + if self.cfg.get('pipeline_model_parallel_size', 1) > 1: + if parallel_state.is_pipeline_last_stage(): + # only the last pipeline parallel stages return loss + averaged_loss = torch.stack([item['loss'] for item in outputs]).mean() + averaged_loss_total_check = torch.stack([item['loss_total_check'] for item in outputs]).mean() + averaged_first_layer_accuracy = torch.stack([item['first_layer_accuracy'] for item in outputs]).mean() + + self.log( + 'val_loss_total_check', averaged_loss_total_check, prog_bar=False, rank_zero_only=True, batch_size=1 + ) + self.log( + 'val_first_layer_accuracy', + averaged_first_layer_accuracy, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + logging.info(f'Validation first_layer_accuracy: {averaged_first_layer_accuracy}') + logging.info(f'Validation loss_total_check: {averaged_loss_total_check}') + + for i in range(1, self.num_speech_codebooks): + averaged_speech_accuracy = torch.stack([item[f'speech_accuracy_{i}'] for item in outputs]).mean() + averaged_speech_loss = torch.stack([item[f'speech_loss_{i}'] for item in outputs]).mean() + self.log( + f'val_speech_accuracy_{i}', + averaged_speech_accuracy, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + self.log( + f'val_speech_loss_{i}', averaged_speech_loss, prog_bar=True, rank_zero_only=True, batch_size=1 + ) + logging.info(f'Validation speech_accuracy_{i}: {averaged_speech_accuracy}') + logging.info(f'Validation speech_loss_{i}: {averaged_speech_loss}') + else: + averaged_loss = torch.tensor(0.0).cuda() + + # we can only log on one rank if it is rank zero so we broadcast from last rank + torch.distributed.broadcast(averaged_loss, get_last_rank()) + + self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) + logging.info(f'Validation loss: {averaged_loss}') + + else: + if len(outputs) > 0: + averaged_loss = 
torch.stack([item['loss'] for item in outputs]).mean() + averaged_loss_total_check = torch.stack([item['loss_total_check'] for item in outputs]).mean() + logging.info(f'Validation loss: {averaged_loss}') + self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log( + 'val_loss_total_check', averaged_loss_total_check, prog_bar=False, rank_zero_only=True, batch_size=1 + ) + + averaged_first_layer_accuracy = torch.stack([item['first_layer_accuracy'] for item in outputs]).mean() + logging.info(f'Validation first_layer_accuracy: {averaged_first_layer_accuracy}') + self.log( + 'val_first_layer_accuracy', + averaged_first_layer_accuracy, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + + for i in range(1, self.num_speech_codebooks): + averaged_speech_accuracy = torch.stack([item[f'speech_accuracy_{i}'] for item in outputs]).mean() + averaged_speech_loss = torch.stack([item[f'speech_loss_{i}'] for item in outputs]).mean() + logging.info(f'Validation speech_accuracy_{i}: {averaged_speech_accuracy}') + logging.info(f'Validation speech_loss_{i}: {averaged_speech_loss}') + self.log( + f'val_speech_accuracy_{i}', + averaged_speech_accuracy, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + self.log( + f'val_speech_loss_{i}', averaged_speech_loss, prog_bar=True, rank_zero_only=True, batch_size=1 + ) + + if self.cfg.get("report_validation_metric", False): + gather_results = [None for _ in range(parallel_state.get_data_parallel_world_size())] + + all_preds = list(itertools.chain(*[item['predicted_token_ids'] for item in outputs])) + all_labels = list(itertools.chain(*[item['labels'] for item in outputs])) + all_inputs = list(itertools.chain(*[item['enc_inputs'] for item in outputs])) + + assert len(all_preds) == len(all_labels) + assert len(all_preds) == len(all_inputs) + + # Gather inputs, preds, labels from all workers + torch.distributed.all_gather_object( + gather_results, + [(input, pred, label) for (input, pred, label) in zip(all_inputs, all_preds, all_labels)], + group=parallel_state.get_data_parallel_group(), + ) + + # Deduplicate sentences that may have been distributed across multiple data parallel ranks. + if parallel_state.get_data_parallel_rank() == 0: + + gather_results_dedup = list(set(itertools.chain(*gather_results))) + + val_metric_dict = self.validation_metric.get_score( + [i[2] for i in gather_results_dedup], [i[1] for i in gather_results_dedup], + ) + + for metric, val in val_metric_dict.items(): + logging.info(f'Validation {metric}: {val}') + val_metric = list(val_metric_dict.items())[0][1] + metric_name = list(val_metric_dict.items())[0][0] + else: + val_metric = torch.tensor(0.0).cuda() + metric_name = '' + + self.log(f'val_{metric_name}', val_metric, prog_bar=True, rank_zero_only=True, batch_size=1) + + gbs = self.cfg.global_batch_size + mbs = self.cfg.micro_batch_size + self._reconfigure_batch_sizes(gbs, mbs) + self.validation_step_outputs.clear() + + def test_step(self, batch, batch_idx): + result = self.predict_step(batch, batch_idx) + return result + + def on_test_epoch_end(self): + """ + This might still be broken for lightning 2.0. 
to fix: see + https://github.com/NVIDIA/NeMo/blob/9bdf4d12276ee8f95a340cf2f7f340e9b5b74a7e/docs/source/starthere/migration-guide.rst + """ + outputs = self.predict_step_outputs + average_metrics = {} + for output in outputs: + for key in output: + if key not in average_metrics: + average_metrics[key] = [] + if isinstance(output[key], torch.Tensor): + average_metrics[key].append(output[key].item()) + elif output[key] is None: + continue + else: + average_metrics[key].append(output[key]) + + for key in average_metrics: + average_metrics[key] = np.mean(average_metrics[key]).item() + logging.info(f'Test {key}: {average_metrics[key]}') + self.log(f'test_{key}', average_metrics[key], prog_bar=True, rank_zero_only=True, batch_size=1) + self.logger.experiment.add_scalar(f'Inf Cumulative {key}', average_metrics[key], 0) + + # save average metrics into json file + with open(os.path.join(self.logger.log_dir, 'output_metrics.json'), 'w') as f: + json.dump(average_metrics, f) + + def build_virtual_prompt_dataset( + self, dataset_paths, batch_size, for_train, drop_last, shuffle, num_workers, pin_memory + ): + dataset = T5SpeechLMDataset( + datasets=dataset_paths, + tokenizer=self.tokenizer, + sample_rate=self.cfg.data.get('sample_rate', 24000), + virtual_prompt_source=self.virtual_prompt_source, + task_templates=self.task_templates, + pseudo_tokens=self.pseudo_tokens, + pad_token_id=self.pad_token_id, + max_seq_length=self.cfg.data.get('max_seq_length', self.frozen_model.cfg.max_position_embeddings), + min_seq_length=self.cfg.data.get('min_seq_length', 1), + add_bos=self.cfg.data.get('add_bos', False), + add_eos=self.cfg.data.get('add_eos', True), + decoder_starts_with_pad=self.cfg.data.get('decoder_starts_with_pad', False), + add_eos_to_decoder_output=self.cfg.data.get('add_eos_to_decoder_output', True), + add_sentinel_to_input=self.cfg.data.get('add_sentinel_to_input', True), + ul2_prompt_token=self.cfg.data.get('ul2_prompt_token', None), + for_train=for_train, + segment_max_duration=self.cfg.data.get('segment_max_duration', None), + trim=self.cfg.data.get('trim', None), + trim_ref=self.cfg.data.get('trim_ref', None), + trim_top_db=self.cfg.data.get('trim_top_db', None), + trim_frame_length=self.cfg.data.get('trim_frame_length', None), + trim_hop_length=self.cfg.data.get('trim_hop_length', None), + pad_multiple=self.cfg.data.get('pad_multiple', 1), + pitch_augment=self.cfg.data.get('pitch_augment', None), + sup_data_path=self.cfg.data.get('sup_data_path', None), + codec_folder=self.cfg.data.get('codec_folder', None), + speech_offset=self.cfg.data.get('speech_offset', None), + train_task=self.cfg.data.get('train_task', "tts"), + seq_pattern=self.cfg.get('seq_pattern', 'delay_parallel'), + use_attention_prior=self.cfg.data.get('use_attention_prior', False), + attention_prior_scaling_factor=self.cfg.data.get('attention_prior_scaling_factor', 1.0), + cross_attention_epsilon=self.cfg.data.get('cross_attention_epsilon', 0.0), + lm_vocab_size=self.lm_vocab_size, + num_speech_codebooks=self.num_speech_codebooks, + codebook_fps=self.cfg.data.get('codebook_fps', 86), + add_special_tokens_to_only_first_codebook=self.cfg.data.get( + 'add_special_tokens_to_only_first_codebook', False + ), + context_pattern=self.cfg.data.get('context_pattern', 'parallel'), + context_duration_min=self.cfg.data.get('context_duration_min', 3.0), + context_duration_max=self.cfg.data.get('context_duration_max', 5.0), + g2p=self.cfg.data.get('g2p', None), + skip_datasets=self.cfg.data.get('skip_datasets', []), + 
english_only_model=self.cfg.get('english_only_model', False), + use_ipa=self.cfg.data.get('use_ipa', False), + context_conditioning=self.cfg.get('context_conditioning', "decoder"), + use_beta_binomial_interpolator=self.cfg.get('use_beta_binomial_interpolator', False), + context_slice_method=self.cfg.data.get('context_slice_method', 'random'), + phoneme_probability=self.cfg.data.get('phoneme_probability', 0.5), + encoder_type=self.cfg.data.get('encoder_type', 'single_transformer'), + ) + + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=self.cfg.seed + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + collate_fn=dataset.collate_fn, + sampler=sampler, + batch_size=batch_size // world_size, + drop_last=drop_last, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=True + if num_workers > 0 + else False, # (@adithyare and @eharper) We need to set this to True to get around issues with spawn=True + ) + logging.info(f'build success: {len(dataloader)} {dataset_paths}') + if self.phoneme_tokenizer is None: + self.phoneme_tokenizer = dataset.phoneme_tokenizer + return dataset, dataloader + + def build_virtual_prompt_tarred_dataset( + self, dataset_paths, audio_path, batch_size, for_train, drop_last, shuffle, num_workers, pin_memory + ): + dataset = T5SpeechLMTarredDataset( + audio_tar_filepaths=audio_path, + manifest_filepath=dataset_paths, + tokenizer=self.tokenizer, + sample_rate=self.cfg.data.get('sample_rate', 24000), + virtual_prompt_source=self.virtual_prompt_source, + task_templates=self.task_templates, + pseudo_tokens=self.pseudo_tokens, + pad_token_id=self.pad_token_id, + max_seq_length=self.cfg.data.get('max_seq_length', self.frozen_model.cfg.max_position_embeddings), + min_seq_length=self.cfg.data.get('min_seq_length', 1), + shuffle_n=shuffle, + add_bos=self.cfg.data.get('add_bos', False), + add_eos=self.cfg.data.get('add_eos', True), + decoder_starts_with_pad=self.cfg.data.get('decoder_starts_with_pad', False), + add_eos_to_decoder_output=self.cfg.data.get('add_eos_to_decoder_output', True), + add_sentinel_to_input=self.cfg.data.get('add_sentinel_to_input', True), + ul2_prompt_token=self.cfg.data.get('ul2_prompt_token', None), + for_train=for_train, + segment_max_duration=self.cfg.data.get('segment_max_duration', None), + trim=self.cfg.data.get('trim', None), + trim_ref=self.cfg.data.get('trim_ref', None), + trim_top_db=self.cfg.data.get('trim_top_db', None), + trim_frame_length=self.cfg.data.get('trim_frame_length', None), + trim_hop_length=self.cfg.data.get('trim_hop_length', None), + pad_multiple=self.cfg.data.get('pad_multiple', 1), + pitch_augment=self.cfg.data.get('pitch_augment', None), + speech_offset=self.cfg.data.get('speech_offset', None), + train_task=self.cfg.data.get('train_task', "tts"), + seq_pattern=self.cfg.get('seq_pattern', 'delay_parallel'), + use_attention_prior=self.cfg.data.get('use_attention_prior', False), + attention_prior_scaling_factor=self.cfg.data.get('attention_prior_scaling_factor', 1.0), + cross_attention_epsilon=self.cfg.data.get('cross_attention_epsilon', 0.0), + lm_vocab_size=self.lm_vocab_size, + num_speech_codebooks=self.num_speech_codebooks, + ) + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + dataloader = torch.utils.data.DataLoader( + dataset, + 
collate_fn=dataset.collate_fn, + batch_size=batch_size // world_size, + drop_last=drop_last, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=True + if num_workers > 0 + else False, # (@adithyare and @eharper) We need to set this to True to get around issues with spawn=True + ) + logging.info(f'build success: {len(dataloader)} {dataset_paths}') + + return dataset, dataloader + + def process_text(self, input_text): + """ + Normalizes text for CER/WER calculation. + Taken from hallucination_eval.py + """ + # Convert text to lowercase + lower_case_text = input_text.lower() + + # Remove commas from text + no_comma_text = lower_case_text.replace(",", "") + + # Replace "-" with spaces + no_dash_text = no_comma_text.replace("-", " ") + + # Replace double spaces with single space + single_space_text = " ".join(no_dash_text.split()) + + single_space_text = single_space_text.translate(str.maketrans('', '', string.punctuation)) + + return single_space_text + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_scalars=True, global_step=None) -> Any: + + with torch.no_grad(): + ( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input_raw, + dec_input_mask_raw, + labels, + loss_mask, + position_ids, + taskname_ids, + speech_mask, + context_and_question_tokens_lens, + cross_attention_prior, + text_limits, # [start of question token, question token len) in [0, enc_mask.size(1)) + lang, + question_texts, + ) = batch + + batch_size = virtual_tokens.size(0) + dec_input = dec_input_raw * 1 # (B, 8, T) # TODO @xueyang: apply clone() method bypasses this unnecessary computation. + dec_input_mask = dec_input_mask_raw * 1 # (B, T) + dec_input_mask[:, :] = 1 # Does not really matter + output_token_list = [] + + end_indices = {} + # pad dec_input (B, 8, T) to 1000 timesteps + max_inference_timesteps = self.cfg.get('max_inference_timesteps', 2000) + # TODO @xueyang: potential bug when max_inference_timesteps < dec_input.shape[2], then dec_input is clipped. + dec_input = torch.nn.functional.pad(dec_input, (0, max_inference_timesteps - dec_input.shape[2]), value=0) + dec_input[:, :, self.decoder_context_len + 1:].zero_() + # TODO @xueyang: why not just declare torch.ones(dec_input_raw.size(0), max_inference_timesteps)? + dec_input_mask = torch.nn.functional.pad( + dec_input_mask, (0, max_inference_timesteps - dec_input_mask.shape[1]), value=1 + ) + + if self.inference_apply_text_cfg and self.inference_apply_audio_cfg: + question_limits = text_limits - virtual_tokens.size(1) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens. + question_start = question_limits[:, 0].unsqueeze(1) # (b, 1) + question_end = question_limits[:, 1].unsqueeze(1) # (b, 1) + + # duplicate and glue two batches into a single one. + virtual_tokens = torch.cat((virtual_tokens, virtual_tokens), dim=0) + taskname_ids = torch.cat((taskname_ids, taskname_ids), dim=0) + speech_mask = torch.cat((speech_mask, speech_mask), dim=0) + dec_input_mask = torch.cat((dec_input_mask, dec_input_mask), dim=0) + + if isinstance(context_and_question_tokens, list): # indicate self.encoder_type = "multi_transformers". 
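+ # Multi-encoder case: build an unconditioned copy in which the question text tokens and the context audio tokens are both replaced by [UNK].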
+ context_tokens, question_tokens = context_and_question_tokens + + # text + question_tokens_unconditioned = question_tokens.clone() + time_range = torch.arange(question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device).unsqueeze(0) + question_mask = (time_range >= question_start) & (time_range < question_end) # create a mask for question only tokens. + question_tokens_unconditioned[:, 0][question_mask] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + + # audio + context_tokens_unconditioned = context_tokens.clone() + context_tokens_unconditioned[:, :, :] = self.tokenizer.unk_id + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = [torch.cat((context_tokens, context_tokens_unconditioned), dim=0), torch.cat((question_tokens, question_tokens_unconditioned), dim=0)] + enc_mask = [torch.cat((mask, mask), dim=0) for mask in enc_mask] + dec_input = torch.cat((dec_input, dec_input), dim=0) + position_ids = [torch.cat((pos_ids, pos_ids), dim=0) for pos_ids in position_ids] + else: + assert self.context_conditioning == "decoder", f"The encoder_type is single_transformer. We expect context_condition is decoder: context_condition={self.context_conditioning}" + + # text + context_and_question_tokens_unconditioned = context_and_question_tokens.clone() + time_range = torch.arange(context_and_question_tokens_unconditioned.size(2), device=context_and_question_tokens_unconditioned.device).unsqueeze(0) # (1, max_context_question_tokens_len) + question_mask = (time_range >= question_start) & (time_range < question_end) # create a mask for question only tokens. + context_and_question_tokens_unconditioned[:, 0][question_mask] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + + # audio + dec_input_unconditioned = dec_input.clone() + dec_input_unconditioned[:, :, 1:self.decoder_context_len + 1] = self.tokenizer.unk_id # TODO @xueyang: switch to other token id if this one is conflict with text unk. + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = torch.cat((context_and_question_tokens, context_and_question_tokens_unconditioned), dim=0) + enc_mask = torch.cat((enc_mask, enc_mask), dim=0) + dec_input = torch.cat((dec_input, dec_input_unconditioned), dim=0) + position_ids = torch.cat((position_ids, position_ids), dim=0) + + # clean up useless variables. + del question_limits, question_start, question_end, time_range, question_mask + elif self.inference_apply_text_cfg: + # replace question token IDs with [UNK]'s id. No speech offset for Phoneme's [UNK]. Same op as train. + # instruction token IDs are bpe token IDs directly obtained from self.tokenizer without any offset. + # question token IDs are phoneme and grapheme token IDs and are offset by self.lm_vocab_size + # if under "Phoneme TTS" instruction, so exising no overlaps between instruction and question token IDs. + # question token IDs are bpe token IDs without any offset + # if under "Text to speech this" instruction, so existing overlaps between instruction and question token IDs. + question_limits = text_limits - virtual_tokens.size(1) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens. + question_start = question_limits[:, 0].unsqueeze(1) # (b, 1) + question_end = question_limits[:, 1].unsqueeze(1) # (b, 1) + + # duplicate and glue two batches into a single one. 
+ virtual_tokens = torch.cat((virtual_tokens, virtual_tokens), dim=0) + taskname_ids = torch.cat((taskname_ids, taskname_ids), dim=0) + speech_mask = torch.cat((speech_mask, speech_mask), dim=0) + dec_input_mask = torch.cat((dec_input_mask, dec_input_mask), dim=0) + + if isinstance(context_and_question_tokens, list): # indicate self.encoder_type = "multi_transformers". + context_tokens, question_tokens = context_and_question_tokens + question_tokens_unconditioned = question_tokens.clone() + + time_range = torch.arange(question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device).unsqueeze(0) + question_mask = (time_range >= question_start) & (time_range < question_end) # create a mask for question only tokens. + question_tokens_unconditioned[:, 0][question_mask] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = [torch.cat((context_tokens, context_tokens), dim=0), torch.cat((question_tokens, question_tokens_unconditioned), dim=0)] + enc_mask = [torch.cat((mask, mask), dim=0) for mask in enc_mask] + dec_input = torch.cat((dec_input, dec_input), dim=0) + position_ids = [torch.cat((pos_ids, pos_ids), dim=0) for pos_ids in position_ids] + else: + assert self.context_conditioning == "decoder", f"The encoder_type is single_transformer. We expect context_condition is decoder: context_condition={self.context_conditioning}" + context_and_question_tokens_unconditioned = context_and_question_tokens.clone() + time_range = torch.arange(context_and_question_tokens_unconditioned.size(2), device=context_and_question_tokens_unconditioned.device).unsqueeze(0) # (1, max_context_question_tokens_len) + question_mask = (time_range >= question_start) & (time_range < question_end) # create a mask for question only tokens. + context_and_question_tokens_unconditioned[:, 0][question_mask] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = torch.cat((context_and_question_tokens, context_and_question_tokens_unconditioned), dim=0) + enc_mask = torch.cat((enc_mask, enc_mask), dim=0) + dec_input = torch.cat((dec_input, dec_input), dim=0) + position_ids = torch.cat((position_ids, position_ids), dim=0) + + # clean up useless variables. + del question_limits, question_start, question_end, time_range, question_mask + elif self.inference_apply_audio_cfg: + # duplicate and glue two batches into a single one. + virtual_tokens = torch.cat((virtual_tokens, virtual_tokens), dim=0) + taskname_ids = torch.cat((taskname_ids, taskname_ids), dim=0) + speech_mask = torch.cat((speech_mask, speech_mask), dim=0) + dec_input_mask = torch.cat((dec_input_mask, dec_input_mask), dim=0) + + if isinstance(context_and_question_tokens, list): # indicate that self.encoder_type = "multi_transformers" + context_tokens, question_tokens = context_and_question_tokens + context_tokens_unconditioned = context_tokens.clone() + context_tokens_unconditioned[:, :, :] = self.tokenizer.unk_id # TODO @xueyang: verify if extra tokens other than audio codec tokens are appended. + + # concatenate both conditioned and unconditioned batches as a single one. 
+ context_and_question_tokens = [torch.cat((context_tokens, context_tokens_unconditioned), dim=0), torch.cat((question_tokens, question_tokens), dim=0)] + enc_mask = [torch.cat((mask, mask), dim=0) for mask in enc_mask] + dec_input = torch.cat((dec_input, dec_input), dim=0) + position_ids = [torch.cat((pos_ids, pos_ids), dim=0) for pos_ids in position_ids] + else: + assert self.context_conditioning == "decoder", f"The encoder_type is single_transformer. We expect context_condition is decoder: context_condition={self.context_conditioning}" + dec_input_unconditioned = dec_input.clone() + dec_input_unconditioned[:, :, 1:self.decoder_context_len + 1] = self.tokenizer.unk_id # TODO @xueyang: switch to other token id if this one is conflict with text unk. + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = torch.cat((context_and_question_tokens, context_and_question_tokens), dim=0) + enc_mask = torch.cat((enc_mask, enc_mask), dim=0) + dec_input = torch.cat((dec_input, dec_input_unconditioned), dim=0) + position_ids = torch.cat((position_ids, position_ids), dim=0) + else: + logging.debug( + f"Neither text or audio cfg logits are applied:" + f" self.inference_apply_text_cfg={self.inference_apply_text_cfg}," + f" self.inference_apply_audio_cfg={self.inference_apply_audio_cfg}" + ) + + end_inference_loop_at = None + fwd_bwd_function = get_forward_backward_func() + encoder_output = None + attention_probs_all = [] + start_time = time.time() + for t in range(self.decoder_context_len + 1, dec_input.shape[2] - 1): + # Start at 0 if encoder context, else context_len + if t % 100 == 0: + logging.info("Timestep {}".format(t)) + if t == end_inference_loop_at: + print("All ends detected") + break + + if isinstance(enc_mask, list): + encoder_max_sequence_len = [e.size(1) for e in enc_mask] + else: + encoder_max_sequence_len = enc_mask.size(1) + + # if context_condition is decoder, then t starts at [PAD] token represented as [0] * 8. + # if context_condition is encoder, then t starts at [CLS]. + if t == self.decoder_context_len + 1: + # Run first step manually + output_logits, _, token_and_speech_logits = self.forward( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input[:, :, : t + 1], # tensors representing [CLS] + context audio tokens + [PAD] if context_condition is decoder, otherwise, tensors representing [CLS]. + dec_input_mask[:, : t + 1], # doesn't matter because of all ones. 
+ position_ids, + taskname_ids, + labels=None, + speech_mask=speech_mask, + inference=True, + inference_step=0, + decoder_max_sequence_len=max_inference_timesteps, + encoder_max_sequence_len=encoder_max_sequence_len + ) + encoder_output = token_and_speech_logits[-1] + + if isinstance(encoder_output, list): + encoder_output = [e.transpose(0, 1) for e in encoder_output] + else: + encoder_output = encoder_output.transpose(0, 1) + + else: + # Prepare batch + batch = [ + max_inference_timesteps, + encoder_max_sequence_len, + encoder_output, + enc_mask, + dec_input[:, :, : t + 1], + dec_input_mask[:, : t + 1], + position_ids, + taskname_ids, + speech_mask, + ] + + output_tensor = fwd_bwd_function( + forward_step_func=self.get_forward_output_only_func(), + data_iterator=iter([batch,]), + model=[self], + num_microbatches=get_num_microbatches(), + forward_only=True, + seq_length=t, + micro_batch_size=dec_input.shape[0], + ) + output_logits = output_tensor[0]['output_logits'] # (B, T, V, 8) or (2B, T, V, 8) + token_and_speech_logits = output_tensor[0]['token_and_speech_logits'] + + # when return_all_crossattention is False, attention_probs is None. + if self.frozen_model.enc_dec_model.return_all_crossattention_probs: + attention_probs = token_and_speech_logits[2] + attention_probs_mean = torch.stack(attention_probs).mean(dim=0) # B, 12, 1, enc_timesteps + attention_probs_all.append(attention_probs_mean) + + if self.inference_apply_text_cfg or self.inference_apply_audio_cfg: + # interpolate conditioned and unconditioned logits + token_logits = self.inference_cfg_interpolation_scale * token_and_speech_logits[0][:batch_size] + (1 - self.inference_cfg_interpolation_scale) * token_and_speech_logits[0][batch_size:] + output_speech_logits = self.inference_cfg_interpolation_scale * output_logits[:batch_size] + (1 - self.inference_cfg_interpolation_scale) * output_logits[batch_size:] + else: + token_logits = token_and_speech_logits[0] # (B, T, V) + output_speech_logits = output_logits + + token_logits_currtimestep = token_logits[:, -1, :] # (B, V) + token_preds = token_logits_currtimestep.argmax(dim=1) # (B,) + + if torch.count_nonzero(speech_mask) > 0: + output_logits_currtimestep = ( + output_speech_logits[:, -1, :, :].permute(0, 2, 1).contiguous().view(-1, self.speech_codebook_size) + ) # (B*8, V) + output_logits_currtimestep_conditioned = ( + output_logits[:batch_size][:, -1, :, :].permute(0, 2, 1).contiguous().view(-1, self.speech_codebook_size) + ) + output_logits_currtimestep_unconditioned = ( + output_logits[batch_size:][:, -1, :, :].permute(0, 2, 1).contiguous().view(-1, self.speech_codebook_size) + ) + else: + output_logits_currtimestep = token_logits_currtimestep # (B, V) + output_logits_currtimestep_conditioned = token_logits_currtimestep + output_logits_currtimestep_unconditioned = token_logits_currtimestep + + top_k = self.cfg.get('top_k', 80) + + # (B*8, 80) or (B, 80) + output_logits_currtimestep_topk = torch.topk(output_logits_currtimestep, top_k, dim=1)[0] + + # find indices which are not top k + indices_to_remove = output_logits_currtimestep < output_logits_currtimestep_topk[:, -1].unsqueeze(1) + # (B*8, 1024) or (B, 1024) + + if self.inference_apply_cfg_filter: + output_logits_currtimestep_rescored = output_logits_currtimestep_conditioned.clone() + else: + output_logits_currtimestep_rescored = output_logits_currtimestep.clone() + + output_logits_currtimestep_rescored[indices_to_remove] = -float('Inf') + + # logits interpolation between conditioned and unconditioned logits. 
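+                # A small worked reading of the blend applied below (only when CFG and the CFG filter are
+                # both enabled; assuming a scale s in [0, 1]):
+                #   s = inference_cfg_filter_interpolation_scale
+                #   rescored = s * rescored_conditioned + (1 - s) * unconditioned
+                # which can also be read as unconditioned + s * (rescored_conditioned - unconditioned),
+                # so s = 1.0 keeps only the (top-k filtered) conditioned logits.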
+ if (self.inference_apply_text_cfg or self.inference_apply_audio_cfg) and self.inference_apply_cfg_filter: + output_logits_currtimestep_rescored = self.inference_cfg_filter_interpolation_scale * output_logits_currtimestep_rescored + (1 - self.inference_cfg_filter_interpolation_scale) * output_logits_currtimestep_unconditioned + + temperature = self.cfg.get('temperature', 0.85) # Set temp 0.01 for greedy decoding + output_logits_currtimestep_rescored = output_logits_currtimestep_rescored / temperature + output_logits_currtimestep_rescored = torch.nn.functional.softmax( + output_logits_currtimestep_rescored, dim=1 + ) + + output_tokens_curr_timestep = torch.multinomial( + output_logits_currtimestep_rescored, num_samples=1 + ) # (B*8, 1) + + if torch.count_nonzero(speech_mask) > 0: + # Convert back to (B, 8) + output_tokens_curr_timestep = output_tokens_curr_timestep.view( + batch_size, self.num_speech_codebooks + ) + + for _b in range(token_preds.shape[0]): + if t > self.decoder_context_len + 10 and token_preds[_b] == self.tokenizer.eos_id: + if _b not in end_indices: + logging.info("End detected for item {}".format(_b) + " at timestep {}".format(t)) + end_indices[_b] = t + if len(end_indices) == token_preds.shape[0]: + end_inference_loop_at = t + self.num_speech_codebooks + + output_token_list.append(output_tokens_curr_timestep) + + # duplicate to 2b dim as input for the next iteration if enabling cfg. + if self.inference_apply_text_cfg or self.inference_apply_audio_cfg: + output_tokens_curr_timestep = torch.cat((output_tokens_curr_timestep, output_tokens_curr_timestep), dim=0) + + if torch.count_nonzero(speech_mask) > 0: + dec_input_next_timestep = output_tokens_curr_timestep * 1 # (B,8) + dec_input_next_timestep[:, 0] = ( + dec_input_next_timestep[:, 0] + self.speech_offset + ) # add offset to first codebook + dec_input[:, :, t + 1] = dec_input_next_timestep * 1 + else: + dec_input[:, 0, t + 1] = output_tokens_curr_timestep.squeeze(1) + # # TF + # if t+1 < 10: + # dec_input[:, :, t + 1] = dec_input_raw[:, :, t+1] + + # end of for loop + output_tokens_combined = torch.stack(output_token_list) # (T, B, 8) if speech else (T, B) + if torch.count_nonzero(speech_mask) > 0: + output_tokens_combined = output_tokens_combined.permute(1, 2, 0) # (B, 8, T) + else: + output_tokens_combined = output_tokens_combined.squeeze(2) + output_tokens_combined = output_tokens_combined.permute(1, 0) # (B, T) + + # consider only autoregressive time, disconsider loading eval models for RTF time + total_process_time = time.time() - start_time + + # Layerwise token error rate + ter_dict = {} + for i in range(self.num_speech_codebooks): + ter_dict[i] = {'hypothesis': [], 'gt': []} + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if 'nemo_sv_model' not in self.additional_models: + nemo_sv_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') + nemo_sv_model = nemo_sv_model.to(device) + nemo_sv_model.encoder.disable_torch_distributed = True # For multi-gpu training validation + nemo_sv_model.eval() + self.additional_models['nemo_sv_model'] = nemo_sv_model + logging.info(f"Loaded SV Model: {nemo_sv_model}") + else: + nemo_sv_model = self.additional_models['nemo_sv_model'] + + if 'asr_model' not in self.additional_models: + asr_model = self.cfg.get("asr_model_name", "stt_multilingual_fastconformer_hybrid_large_pc_blend_eu") + + if "hybrid" in asr_model: + model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel + else: + model = 
nemo_asr.models.EncDecRNNTBPEModel + asr_model = model.from_pretrained(model_name=asr_model) + asr_model = asr_model.to(device) + asr_model.encoder.disable_torch_distributed = True # For multi-gpu training validation + asr_model.eval() + self.additional_models['asr_model'] = asr_model + logging.info(f"Loaded ASR Model: {asr_model}") + else: + asr_model = self.additional_models['asr_model'] + + asr_model_zh = None + if Lang.zh.value in lang: + if 'asr_model_zh' not in self.additional_models: + asr_model_zh = nemo_asr.models.EncDecRNNTModel.from_pretrained( + model_name="stt_zh_conformer_transducer_large" + ) + asr_model_zh = asr_model_zh.to(device) + asr_model_zh.eval() + self.additional_models['asr_model_zh'] = asr_model_zh + else: + asr_model_zh = self.additional_models['asr_model_zh'] + + if 'wavlm_sv_model' not in self.additional_models: + wavlm_sv_extractor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-base-plus-sv') + wavlm_sv_model = WavLMForXVector.from_pretrained('microsoft/wavlm-base-plus-sv') + wavlm_sv_model = wavlm_sv_model.to(device) + wavlm_sv_model = wavlm_sv_model.eval() + self.additional_models['wavlm_sv_model'] = wavlm_sv_model + self.additional_models['wavlm_sv_extractor'] = wavlm_sv_extractor + logging.info(f"Loaded SV Model: {wavlm_sv_model}") + else: + wavlm_sv_model = self.additional_models['wavlm_sv_model'] + wavlm_sv_extractor = self.additional_models['wavlm_sv_extractor'] + + # load MOS estimator model only if True. + if self.estimate_mos: + # load mos estimator. + if 'squim_mos_model' not in self.additional_models: + squim_mos_model_full = SQUIM_SUBJECTIVE.get_model().to(device) + self.additional_models['squim_mos_model'] = squim_mos_model_full + else: + squim_mos_model_full = self.additional_models['squim_mos_model'] + + # load non-matching reference clean audio. + ref_16khz_wav, _ = librosa.load(self.non_matching_ref_audio_filepath, sr=16000) + + # prepare MOS estimator by taking a single audio example as an input. + squim_mos_model = partial( + squim_mos_model_full, + reference=torch.from_numpy(ref_16khz_wav).to(device).unsqueeze(0) + ) + + _exp_dir_path = self.logger.log_dir + _exp_dir_path = _exp_dir_path + '/Sample_Audios' + if not os.path.exists(_exp_dir_path): + os.mkdir(_exp_dir_path) + + squim_mos_list_pred = [] + squim_mos_list_context = [] + squim_mos_list_gt = [] + similarity_list = [] + similarity_list_wavlm = [] + pred_context_similarity_list = [] + pred_context_similarity_list_wavlm = [] + gt_context_similarity_list = [] + gt_context_similarity_list_wavlm = [] + question_type = [] + + # predicting audio + batch_size = output_tokens_combined.shape[0] + test_dataloader_batch_size = batch_size + # self.test_dataloader() is not defined during validation + if isinstance(self.test_dataloader(), torch.utils.data.DataLoader): + test_dataloader_batch_size = self.test_dataloader().batch_size + + # logging attention maps. + # empty attention_probs_all indicates self.frozen_model.enc_dec_model.return_all_crossattention_probs is False. 
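+            # Shape bookkeeping for the logged alignment, following the collection inside the loop above:
+            #   each decode step appended a map averaged over the returned cross-attention maps, with shape
+            #   (B, n_heads, 1, enc_timesteps); torch.cat along dim=2 below stitches the per-step maps into
+            #   (B, n_heads, dec_timesteps, enc_timesteps), and the mean over dim=1 gives
+            #   (B, dec_timesteps, enc_timesteps), which is what gets plotted per sample.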
+ if len(attention_probs_all) != 0: + attention_probs_all = torch.cat(attention_probs_all, dim=2) # B, 12, dec_timesteps, enc_timesteps + attention_probs_all = attention_probs_all.mean(dim=1) # B, dec_timesteps, enc_timesteps + + for i in range(batch_size): + text_end_step = text_limits[i, 1].item() + text_start_step = text_limits[i, 0].item() + end_index = end_indices.get(i, output_tokens_combined.shape[2]) + if len(attention_probs_all) != 0: + attention_probs_example = attention_probs_all[i][:end_index - (1 + self.decoder_context_len), + text_start_step:text_end_step] # T, enc_timesteps + attention_map = attention_probs_example.float().cpu().numpy().T + alignment_image = plot_alignment_to_numpy_for_speechllm( + attention_map, + phoneme_ver=1, + phoneme_seq=None, + ) + # ctc_loss = self.frozen_model.enc_dec_model.forward_sum_loss( + # attn_logprob=attention_probs_example[None,None,:,:], + # in_lens=torch.tensor([attention_probs_example.shape[1]]).to(device), + # out_lens=torch.tensor([attention_probs_example.shape[0]]).to(device) + # ) + + if global_step is not None: + # During validation, step is simply global_step + i + step = global_step + i + else: + # During inference, step is the index of the sample + step = batch_idx * test_dataloader_batch_size + i + + # print("Ctc Loss: ", step, ctc_loss.item()) + self.logger.experiment.add_image( + "Inf Attention Map", alignment_image, step, dataformats="HWC", + ) + # Save attention image to file + alignment_fp = os.path.join(_exp_dir_path, f'attention_map_{step}.png') + imageio.imwrite(alignment_fp, alignment_image) + + wer_score = 0 + audio_to_pred = [] + audio_to_pred_zh = [] + total_audio_seconds = 0 + for i in range(batch_size): + if global_step is not None: + # During validation, step is simply global_step + i + step = global_step + i + else: + # During inference, step is the index of the sample + step = batch_idx * test_dataloader_batch_size + i + + audio_len = self.decoder_context_len + (labels[i][0][self.decoder_context_len:] != 0).sum().item() + + if torch.count_nonzero(speech_mask) > 0: + dec_input_to_1024 = self.convert_tokens_to_range(dec_input_raw[i, :, 0:audio_len]) + dec_input_to_1024_answer = dec_input_to_1024[:,self.decoder_context_len+1:] + dec_input_wav = self.decode_wav_from_codec_model(dec_input_to_1024_answer) + self.logger.experiment.add_audio("Inf Dec Input Wav", dec_input_wav, step, self.sample_rate) + + predicted_tokens = output_tokens_combined[i] # Should not contain context even if decoder context + if i in end_indices: + logging.info(f"Clipping until end index for audio {i}") + if self.cfg.get('seq_pattern', 'parallel') == 'delay_parallel': + predicted_tokens = predicted_tokens[:, 0 : end_indices[i] - (1 + self.decoder_context_len) + self.num_speech_codebooks] # trim to audio length + else: + predicted_tokens = predicted_tokens[:, 0 : end_indices[i] - (1 + self.decoder_context_len)] # trim to audio length + + pred_img = predicted_tokens.data.cpu().float().numpy() + dec_inp_img = dec_input_to_1024.data.cpu().float().numpy() + start_time = time.time() + predicted_tokens = self.convert_tokens_to_range(predicted_tokens, apply_offset_correction=False) + predicted_wav = self.decode_wav_from_codec_model(predicted_tokens) + # accumulate audio length in seconds and process time in seconds to the RTF + total_process_time = total_process_time + (time.time() - start_time) + total_audio_seconds = total_audio_seconds + predicted_wav.size(-1) / self.sample_rate + + self.logger.experiment.add_audio("Inf Pred Wav", predicted_wav, 
step, self.sample_rate) + self.logger.experiment.add_image( + "Inf Pred Tokens", plot_codec_to_numpy(pred_img), step, dataformats="HWC", + ) + self.logger.experiment.add_image( + "Inf Dec Input Tokens", plot_codec_to_numpy(dec_inp_img), step, dataformats="HWC", + ) + + # save predicted_wav and gt_wav to a wav files in dir_path + if global_step is not None: + # During training, overwrite the wav file from the previous validation + wav_num = i + else: + wav_num = step + + audio_fp_pred = os.path.join(_exp_dir_path, f'predicted_wav_{wav_num}.wav') + sf.write(audio_fp_pred, predicted_wav.cpu().numpy(), self.sample_rate) + audio_fp_gt = os.path.join(_exp_dir_path, f'dec_input_wav_{wav_num}.wav') + sf.write(audio_fp_gt, dec_input_wav.cpu().numpy(), self.sample_rate) + + # speaker verification evaluation using nemo model + spk_embedding_pred = nemo_sv_model.get_embedding(audio_fp_pred) + spk_embedding_pred = spk_embedding_pred.cpu().detach().numpy().flatten() + spk_embedding_gt = nemo_sv_model.get_embedding(audio_fp_gt) + spk_embedding_gt = spk_embedding_gt.cpu().detach().numpy().flatten() + similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / ( + np.linalg.norm(spk_embedding_pred) * np.linalg.norm(spk_embedding_gt) + ) + + if log_scalars: + self.logger.experiment.add_scalar(f'Inf SV Cossim Individual Sample', similarity, step) + similarity_list.append(similarity) + + # speaker verification evaluation using wavlm model + gt_16khz_wav, _ = librosa.load(audio_fp_gt, sr=16000) + pred_16khz_wav, _ = librosa.load(audio_fp_pred, sr=16000) + inputs_wavlm = wavlm_sv_extractor([pred_16khz_wav, gt_16khz_wav], padding=True, return_tensors="pt", sampling_rate=16000) + for key in inputs_wavlm.keys(): + inputs_wavlm[key] = inputs_wavlm[key].to(device) + + with torch.no_grad(): + wavlm_embeddings = wavlm_sv_model(**inputs_wavlm).embeddings + wavlm_embeddings = torch.nn.functional.normalize(wavlm_embeddings, dim=-1).cpu() + + spk_embedding_pred_wavlm = wavlm_embeddings[0].cpu().detach().numpy().flatten() + spk_embedding_gt_wavlm = wavlm_embeddings[1].cpu().detach().numpy().flatten() + similarity_wavlm = np.dot(spk_embedding_pred_wavlm, spk_embedding_gt_wavlm) / ( + np.linalg.norm(spk_embedding_pred_wavlm) * np.linalg.norm(spk_embedding_gt_wavlm) + ) + similarity_list_wavlm.append(similarity_wavlm) + + if lang[i] == Lang.zh.value: + audio_to_pred_zh.append({"step": i, "audio": audio_fp_pred}) + audio_to_pred_zh.append({"step": i, "audio": audio_fp_gt}) + else: + audio_to_pred.append({"step": i, "audio": audio_fp_pred}) + audio_to_pred.append({"step": i, "audio": audio_fp_gt}) + + if isinstance(context_and_question_tokens, list): + context_tokens, question_tokens = context_and_question_tokens + input_token_list = [ + question_tokens[i, 0, j].item() + for j in range(context_and_question_tokens_lens[1][i].item()) + ] + input_token_list = [ + (ti, t) for ti, t in enumerate(input_token_list) if t != 0 and t < self.speech_offset + ] + context_end_step = context_and_question_tokens_lens[0][i] + context_tokens = context_tokens[i][:, :context_end_step] + else: + input_token_list = [ + context_and_question_tokens[i, 0, j].item() + for j in range(context_and_question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) for ti, t in enumerate(input_token_list) if t != 0 and t < self.speech_offset + ] + context_end_step = input_token_list[0][0] + context_tokens = context_and_question_tokens[i][:, :context_end_step] + + spk_embedding_context = spk_embedding_gt + spk_embedding_context_wavlm = spk_embedding_gt_wavlm + if 
self.decoder_context_len > 0: + context_tokens = dec_input_to_1024[:, :self.decoder_context_len+1] + context_wav = self.decode_wav_from_codec_model(context_tokens) + elif context_end_step > 1: + is_speech_context = context_tokens[1,:].sum().item() > 0 + if is_speech_context: + context_tokens = self.convert_tokens_to_range(context_tokens, pattern=self.context_pattern) + context_wav = self.decode_wav_from_codec_model(context_tokens) + else: + context_wav = None + _context_token_list = [ v.item() for v in context_tokens[0, :] ] + _context_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in _context_token_list if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text("Context Text", _context_text, self.global_step) + + else: + context_wav = None + # raise NotImplementedError("During prediction, there was no context found.") + if context_wav is not None: + self.logger.experiment.add_audio("Context Wav", context_wav, step, self.sample_rate) + context_wav_fp = os.path.join(_exp_dir_path, f'context_wav_{wav_num}.wav') + sf.write(context_wav_fp, context_wav.cpu().numpy(), self.sample_rate) + # titanet + spk_embedding_context = nemo_sv_model.get_embedding(context_wav_fp) + spk_embedding_context = spk_embedding_context.cpu().detach().numpy().flatten() + # wavlm + context_wavlm_wav, _ = librosa.load(context_wav_fp, sr=16000) + inputs_wavlm = wavlm_sv_extractor([context_wavlm_wav], padding=True, return_tensors="pt", sampling_rate=16000) + for key in inputs_wavlm.keys(): + inputs_wavlm[key] = inputs_wavlm[key].to(device) + + with torch.no_grad(): + wavlm_embeddings = wavlm_sv_model(**inputs_wavlm).embeddings + wavlm_embeddings = torch.nn.functional.normalize(wavlm_embeddings, dim=-1).cpu() + + spk_embedding_context_wavlm = wavlm_embeddings[0].cpu().detach().numpy().flatten() + + pred_similarity_context = np.dot(spk_embedding_context, spk_embedding_pred) / ( + np.linalg.norm(spk_embedding_context) * np.linalg.norm(spk_embedding_pred) + ) + gt_similarity_context = np.dot(spk_embedding_context, spk_embedding_gt) / ( + np.linalg.norm(spk_embedding_context) * np.linalg.norm(spk_embedding_gt) + ) + + pred_similarity_context_wavlm = np.dot(spk_embedding_context_wavlm, spk_embedding_pred_wavlm) / ( + np.linalg.norm(spk_embedding_context_wavlm) * np.linalg.norm(spk_embedding_pred_wavlm) + ) + gt_similarity_context_wavlm = np.dot(spk_embedding_context_wavlm, spk_embedding_gt_wavlm) / ( + np.linalg.norm(spk_embedding_context_wavlm) * np.linalg.norm(spk_embedding_gt_wavlm) + ) + + if log_scalars: + self.logger.experiment.add_scalar(f'Inf SV Cossim Context Pred', pred_similarity_context, step) + self.logger.experiment.add_scalar(f'Inf SV Cossim Context GT', gt_similarity_context, step) + pred_context_similarity_list.append(pred_similarity_context) + gt_context_similarity_list.append(gt_similarity_context) + pred_context_similarity_list_wavlm.append(pred_similarity_context_wavlm) + gt_context_similarity_list_wavlm.append(gt_similarity_context_wavlm) + + task_question = self.frozen_model.tokenizer.ids_to_text( + [v[1] for v in input_token_list if v[1] < self.lm_vocab_size] + ) + self.logger.experiment.add_text("Inf Task Question", task_question, step) + if "Phoneme TTS" in task_question: + question_type.append("Phoneme TTS") + elif "Text to speech this" in task_question: + question_type.append("Text to speech this") + else: + question_type.append("Other") + + task_question_phoneme_tokens = [ + v[1] - self.lm_vocab_size for v in input_token_list if v[1] >= self.lm_vocab_size + ] + if 
len(task_question_phoneme_tokens) > 0: + phoneme_text = self.phoneme_tokenizer.decode(task_question_phoneme_tokens) + self.logger.experiment.add_text("Inf Task Question Phoneme Text", phoneme_text, step) + + # store predicted_tokens for each layer to compute token error rate + for layer_idx in range(self.num_speech_codebooks): + ter_dict[layer_idx]['hypothesis'].append(predicted_tokens[layer_idx].cpu().numpy().tolist()) + ter_dict[layer_idx]['gt'].append(dec_input_to_1024_answer[layer_idx].cpu().numpy().tolist()) + + # estimate MOS scores. + if self.estimate_mos: + squim_mos_score_pred = squim_mos_model(torch.from_numpy(pred_16khz_wav).to(device).unsqueeze(0)).item() + squim_mos_score_gt = squim_mos_model(torch.from_numpy(gt_16khz_wav).to(device).unsqueeze(0)).item() + if context_wav is not None: + squim_mos_score_context = squim_mos_model(context_wav.to(device).unsqueeze(0)).item() + squim_mos_list_context.append(squim_mos_score_context) + squim_mos_list_pred.append(squim_mos_score_pred) + squim_mos_list_gt.append(squim_mos_score_gt) + else: + r = labels[i, 0].long() + nzm = r != 0 + r = r.tolist()[:-1] + nzm = nzm[:-1] + h = output_tokens_combined[i].long() * nzm + h = h.tolist() + cur_wer_score = editdistance.eval(r, h) + if log_scalars: + self.logger.experiment.add_scalar('WER', cur_wer_score, step) + logging.info(f"current wer score : {cur_wer_score}") + wer_score += cur_wer_score + if wer_score > 0: + wer_score /= batch_size + if log_scalars: + self.logger.experiment.add_scalar('AVG WER', wer_score, step) + logging.info(f"average wer score : {wer_score}") + + # compute token error rate for each layer + if log_scalars: + for layer_idx in range(self.num_speech_codebooks): + wer = word_error_rate(ter_dict[layer_idx]['hypothesis'], ter_dict[layer_idx]['gt'], use_cer=True) + self.logger.experiment.add_scalar(f'Inf TER Layer {layer_idx}', wer, 0) + + greedy_transcripts = [] + if len(audio_to_pred) > 0: + greedy_transcripts.extend(asr_model.transcribe([i["audio"] for i in audio_to_pred])[0]) + if len(audio_to_pred_zh) > 0: + greedy_transcripts.extend(asr_model_zh.transcribe([i["audio"] for i in audio_to_pred_zh])[0]) + + all_audio_to_pred = audio_to_pred + audio_to_pred_zh + # Note WER over the batch is not equal to WER(sample) / batch_size, but approx. 
here + + # These are between ASR outputs of GT audio and predicted audio + wer_batch = [] + cer_batch = [] + cer_phoneme = [] + wer_phoneme = [] + cer_tts = [] + wer_tts = [] + + # These are between ASR output of Pred audio and GT text + wer_batch_gt = [] + cer_batch_gt = [] + cer_phoneme_gt = [] + wer_phoneme_gt = [] + cer_tts_gt = [] + wer_tts_gt = [] + + for i in range(0, len(greedy_transcripts) - 1, 2): + assert all_audio_to_pred[i]["step"] == all_audio_to_pred[i + 1]["step"] + # step = batch_idx * self.test_dataloader().batch_size + all_audio_to_pred[i]["step"] + step = batch_idx * test_dataloader_batch_size + all_audio_to_pred[i]["step"] + question_text = question_texts[i//2] + + # No need to process text since both are ASR outputs + cer_sample = word_error_rate([greedy_transcripts[i]], [greedy_transcripts[i + 1]], use_cer=True) + wer_sample = word_error_rate([greedy_transcripts[i]], [greedy_transcripts[i + 1]], use_cer=False) + + # Processing text since one is ASR output and the other is the GT text + cer_gt = word_error_rate([self.process_text(greedy_transcripts[i])], [self.process_text(question_text)], use_cer=True) + wer_gt = word_error_rate([self.process_text(greedy_transcripts[i])], [self.process_text(question_text)], use_cer=False) + + self.logger.experiment.add_text("Inf Predicted Text", greedy_transcripts[i], step) + self.logger.experiment.add_text("Inf GT Text", greedy_transcripts[i + 1], step) + self.logger.experiment.add_text("Inf Question Text", question_text, step) + if log_scalars: + self.logger.experiment.add_scalar(f'Inf CER Transcript', cer_sample, step) + self.logger.experiment.add_scalar(f'Inf WER Transcript', wer_sample, step) + self.logger.experiment.add_scalar(f'Inf CER GT Transcript', cer_gt, step) + cer_batch.append(cer_sample) + wer_batch.append(wer_sample) + cer_batch_gt.append(cer_gt) + wer_batch_gt.append(wer_gt) + if question_type[all_audio_to_pred[i]["step"]] == "Phoneme TTS": + if log_scalars: + self.logger.experiment.add_scalar(f'Inf CER Phoneme Task', cer_sample, step) + self.logger.experiment.add_scalar(f'Inf WER Phoneme Task', wer_sample, step) + self.logger.experiment.add_scalar(f'Inf CER GT Phoneme Task', cer_gt, step) + cer_phoneme.append(cer_sample) + wer_phoneme.append(wer_sample) + cer_phoneme_gt.append(cer_gt) + wer_phoneme_gt.append(wer_gt) + elif question_type[all_audio_to_pred[i]["step"]] == "Text to speech this": + if log_scalars: + self.logger.experiment.add_scalar(f'Inf CER TTS Task', cer_sample, step) + self.logger.experiment.add_scalar(f'Inf WER TTS Task', wer_sample, step) + self.logger.experiment.add_scalar(f'Inf CER GT TTS Task', cer_gt, step) + cer_tts.append(cer_sample) + wer_tts.append(wer_sample) + cer_tts_gt.append(cer_gt) + wer_tts_gt.append(wer_gt) + + # compute average similarity + similarity_avg = np.mean(similarity_list) + pred_context_similarity_avg = np.mean(pred_context_similarity_list) + gt_context_similarity_avg = np.mean(gt_context_similarity_list) + similarity_avg_wavlm = np.mean(similarity_list_wavlm) + pred_context_similarity_avg_wavlm = np.mean(pred_context_similarity_list_wavlm) + gt_context_similarity_avg_wavlm = np.mean(gt_context_similarity_list_wavlm) + + if log_scalars: + self.logger.experiment.add_scalar(f'Inf SV Avg Cossim', similarity_avg, batch_idx) + self.predict_step_outputs.append( + { + 'titanet_avg_cossim': similarity_avg, + 'titanet_avg_cossim_context_pred': pred_context_similarity_avg, + 'titanet_avg_cossim_context_gt': gt_context_similarity_avg, + 'wavlm_avg_cossim': similarity_avg_wavlm, + 
'wavlm_avg_cossim_context_pred': pred_context_similarity_avg_wavlm, + 'wavlm_avg_cossim_context_gt': gt_context_similarity_avg_wavlm, + 'squim_mos_pred': np.mean(squim_mos_list_pred) if len(squim_mos_list_pred) > 0 else None, + 'squim_mos_context': np.mean(squim_mos_list_context) if len(squim_mos_list_context) > 0 else None, + 'squim_mos_gt': np.mean(squim_mos_list_gt) if len(squim_mos_list_gt) > 0 else None, + 'cer_transcript': np.mean(cer_batch), + 'wer_transcript': np.mean(wer_batch), + 'cer_phoneme': np.mean(cer_phoneme) if len(cer_phoneme) > 0 else None, + 'wer_phoneme': np.mean(wer_phoneme) if len(wer_phoneme) > 0 else None, + 'cer_tts': np.mean(cer_tts) if len(cer_tts) > 0 else None, + 'wer_tts': np.mean(wer_tts) if len(wer_tts) > 0 else None, + 'cer_transcript_gt': np.mean(cer_batch_gt), + 'wer_transcript_gt': np.mean(wer_batch_gt), + 'cer_phoneme_gt': np.mean(cer_phoneme_gt) if len(cer_phoneme_gt) > 0 else None, + 'wer_phoneme_gt': np.mean(wer_phoneme_gt) if len(wer_phoneme_gt) > 0 else None, + 'cer_tts_gt': np.mean(cer_tts_gt) if len(cer_tts_gt) > 0 else None, + 'wer_tts_gt': np.mean(wer_tts_gt) if len(wer_tts_gt) > 0 else None, + "RTF": total_process_time / total_audio_seconds, + } + ) + + #TODO @xueyang: PTL 2.0+ patch. Signature of method `on_predict_epoch_end` does not match signature of the base method in PTL class 'ModelHooks'. + # Remove the `outputs` param and choose `self.predict_step_output` instead. + def on_predict_epoch_end(self, outputs: List[Any]) -> None: + + gather_results = [None for _ in range(parallel_state.get_data_parallel_world_size())] + all_preds = list(itertools.chain(*[item['preds_text'] for item in outputs[0]])) + all_labels = list(itertools.chain(*[item['labels_text'] for item in outputs[0]])) + all_inputs = list(itertools.chain(*[item['input_text'] for item in outputs[0]])) + + assert len(all_preds) == len(all_labels) + assert len(all_preds) == len(all_inputs) + + # Gather inputs, predictions, and ground truths from all workers + torch.distributed.all_gather_object( + gather_results, + [(input, pred, label) for (input, pred, label) in zip(all_inputs, all_preds, all_labels)], + group=parallel_state.get_data_parallel_group(), + ) + + # Deduplicate sentences that may have been distributed across multiple data parallel ranks. 
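+        # gather_results now holds one list of (input, pred, label) tuples per data-parallel rank;
+        # flattening it and casting to set() below drops samples duplicated by distributed sampling,
+        # and the exact-match accuracy is then computed once, on rank 0 only.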
+ if parallel_state.get_data_parallel_rank() == 0: + gather_results_dedup = list(set(itertools.chain(*gather_results))) + + input_prediction_pair = [] + correct = 0 + for (input, pred, label) in gather_results_dedup: + input_prediction_pair.append((input, pred)) + if label: + if pred == label: + correct += 1 + + acc = correct / len(gather_results_dedup) if all_labels[0] else None + logging.info(f'Prediction results: {acc}') + logging.info(f'Test finish') diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py index 057d9f49546d..f25d5da9cbb9 100644 --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -19,6 +19,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import torchaudio from einops import rearrange from transformers import AutoModel @@ -203,6 +204,7 @@ def __init__( stride: int = 1, dilation: int = 1, padding: Optional[int] = None, + activation: Optional[str] = None ): super().__init__() if not padding: @@ -217,6 +219,10 @@ def __init__( padding_mode="reflect", ) self.conv = nn.utils.weight_norm(conv) + if activation is not None: + self.activation = CodecActivation(activation=activation, channels=out_channels) + else: + self.activation = None @property def input_types(self): @@ -237,6 +243,8 @@ def remove_weight_norm(self): @typecheck() def forward(self, inputs, input_len): out = self.conv(inputs) + if self.activation is not None: + out = self.activation(out) out = mask_sequence_tensor(out, input_len) return out @@ -433,6 +441,254 @@ def forward(self, audio_real, audio_gen): return scores_real, scores_gen, fmaps_real, fmaps_gen +class SSLModel(NeuralModule): + def __init__(self, slm_model_name): + super().__init__() + self.ssl_model = AutoModel.from_pretrained(slm_model_name) + + def forward(self, *args, **kwargs): + return self.ssl_model(*args, **kwargs) + + +class SLMDiscriminator(NeuralModule): + """SLM Discriminator as in StyleTTS2 paper. 
+ Adapted from https://github.com/yl4579/StyleTTS2/blob/5cedc71c333f8d8b8551ca59378bdcc7af4c9529/losses.py#L193""" + + def __init__(self, + slm_model_name="microsoft/wavlm-base-plus", + slm_sr=16000, + input_sr=22050, + slm_hidden=768, + slm_layers=13, + initial_channel=64, + use_spectral_norm=False, + lrelu_slope=0.1): + super().__init__() + + self.lrelu_slope = lrelu_slope + + # define slm model + self.slm_model = SSLModel(slm_model_name) + self.slm_model.ssl_model.feature_extractor._requires_grad = False + + # Freeze slm model + self.slm_model.freeze() + + self.resample = torchaudio.transforms.Resample(input_sr, slm_sr) + + norm_f = nn.utils.weight_norm if use_spectral_norm == False else nn.utils.spectral_norm + self.pre = norm_f(nn.Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)) + + self.convs = nn.ModuleList([ + norm_f(nn.Conv1d(initial_channel, initial_channel * 2, kernel_size=5, padding=2)), + norm_f(nn.Conv1d(initial_channel * 2, initial_channel * 4, kernel_size=5, padding=2)), + norm_f(nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)), + ]) + + self.conv_post = norm_f(nn.Conv1d(initial_channel * 4, 1, 3, 1, padding=1)) + + def _forward(self, x): + x = self.slm_model(input_values=self.resample(x), output_hidden_states=True).hidden_states + x = torch.stack(x, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2) + + x = self.pre(x) + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, self.lrelu_slope) + fmap.append(x.unsqueeze(-1)) + + x = self.conv_post(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + def forward(self, audio_real, audio_gen): + + y_d_r, fmap_r = self._forward(audio_real) + y_d_g, fmap_g = self._forward(audio_gen) + + return [y_d_r.unsqueeze(1)], [y_d_g.unsqueeze(1)], [fmap_r], [fmap_g] + + +class DiscriminatorSTFT(NeuralModule): + """ + Discriminator network from EnCodec for Complex STFT input, but without dilations. + + Args: + filters: number of filters to use in Conv2d layers + lrelu_slope: Slope to use for activations. Leaky relu with slope of 0.1 or 0.2 is recommended for the + stability of the feature matching loss + """ + + def __init__(self, filters: int = 32, lrelu_slope: float = 0.1): + super().__init__() + + self.activation = nn.LeakyReLU(lrelu_slope) + self.conv_layers = nn.ModuleList( + [ + Conv2dNorm(2, filters, kernel_size=(3, 9)), + Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), + Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), + Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), + Conv2dNorm(filters, filters, kernel_size=(3, 3)), + ] + ) + self.conv_post = Conv2dNorm(filters, 1, kernel_size=(3, 3)) + + @property + def input_types(self): + return { + "spec": NeuralType(('B', 'C', 'T_spec', 'D'), VoidType()), + } + + @property + def output_types(self): + return { + "scores": NeuralType(('B', 'C', 'T_spec'), VoidType()), + "fmap": [NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())], + } + + @typecheck() + def forward(self, spec): + fmap = [] + + # [batch, 2, T_spec, fft] + out = spec + for conv in self.conv_layers: + # [batch, filters, T_spec, fft // strides] + out = conv(inputs=out) + out = self.activation(out) + fmap.append(out) + # [batch, 1, T_spec, fft // 8] + scores = self.conv_post(inputs=out) + fmap.append(scores) + scores = rearrange(scores, "B 1 T C -> B C T") + + return scores, fmap + + +class MultiBandDiscriminatorSTFT(NeuralModule): + """ + Multi-band STFT discriminator proposed in DAC (https://arxiv.org/abs/2306.06546). 
+ + Computes the complex STFT for a given resolution and splits it into sub-bands, + which are given to separate discriminator networks. + + Args: + resolution: STFT resolution, provided as a tuple of 3 integers ordered (num_fft, hop_length, window_length) + stft_bands: List of tuples, with each tuple having 2 float values (band_start, band_end). + The floats are in the range [0, 1] representing the fraction of all stft bands. + For example for n_fft=1024, the stft output has 513 dimensions. + For band input [(0, 0.25), (0.25, 1.0)] it would use stft dimensions [0 through 127] and [128 through 512]. + """ + + def __init__(self, resolution: Tuple[int], stft_bands: Iterable[Tuple[int]]): + super().__init__() + + self.n_fft, self.hop_length, self.win_length = resolution + self.register_buffer("window", torch.hann_window(self.win_length, periodic=False)) + self.discriminators = nn.ModuleList([DiscriminatorSTFT() for _ in stft_bands]) + n_stft = self.n_fft // 2 + 1 + self.stft_bands = [(int(band[0] * n_stft), int(band[1] * n_stft)) for band in stft_bands] + + def compute_stft(self, audio): + # [B, fft, T_spec] + fft = torch.stft( + audio, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + normalized=True, + center=True, + return_complex=True, + ) + fft = rearrange(fft, "B fft T -> B T fft") + # [batch, 2, T_spec, fft] + out = torch.stack([fft.real, fft.imag], dim=1) + return out + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + } + + @property + def output_types(self): + return { + "scores_list": [NeuralType(('B', 'C', 'T_spec'), VoidType())], + "fmaps_list": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], + } + + @typecheck() + def forward(self, audio): + scores_list = [] + fmap_list = [] + spec = self.compute_stft(audio) + for band, disc in zip(self.stft_bands, self.discriminators): + spec_band = spec[:, :, :, band[0] : band[1]] + score, fmap = disc(spec=spec_band) + scores_list.append(score) + fmap_list.append(fmap) + + return scores_list, fmap_list + + +class MultiResolutionDiscriminatorSTFT(NeuralModule): + """ + Multi-resolution discriminator which creates a multi-band discriminator for each input resolution. + + Args: + resolutions: List of STFT resolutions, each resolution provided as a tuple of 3 integers ordered + (num_fft, hop_length, window_length) + stft_bands: List of tuples, with each tuple having 2 float values (band_start, band_end). + The floats are in the range [0, 1] representing the fraction of all stft bands. + For example for n_fft=1024, the stft output has 513 dimensions. + For band input [(0, 0.25), (0.25, 1.0)] it would use stft dimensions [0 through 127] and [128 through 512]. 
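+
+    Example (illustrative values only, not a prescribed configuration)::
+
+        disc = MultiResolutionDiscriminatorSTFT(
+            resolutions=[(512, 128, 512), (1024, 256, 1024)],
+            stft_bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
+        )
+        scores_real, scores_gen, fmaps_real, fmaps_gen = disc(audio_real=audio_real, audio_gen=audio_gen)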
+ """ + + def __init__(self, resolutions: Iterable[Tuple[int]], stft_bands: Iterable[Tuple[int]]): + super().__init__() + self.discriminators = nn.ModuleList( + [MultiBandDiscriminatorSTFT(resolution=resolution, stft_bands=stft_bands) for resolution in resolutions] + ) + + @property + def input_types(self): + return { + "audio_real": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()), + } + + @property + def output_types(self): + return { + "scores_real": [NeuralType(('B', 'C', 'T_spec'), VoidType())], + "scores_gen": [NeuralType(('B', 'C', 'T_spec'), VoidType())], + "fmaps_real": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], + "fmaps_gen": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], + } + + @typecheck() + def forward(self, audio_real, audio_gen): + scores_real = [] + scores_gen = [] + fmaps_real = [] + fmaps_gen = [] + + for disc in self.discriminators: + score_real_i, fmap_real_i = disc(audio=audio_real) + scores_real = scores_real + score_real_i + fmaps_real = fmaps_real + fmap_real_i + + score_gen_i, fmap_gen_i = disc(audio=audio_gen) + scores_gen = scores_gen + score_gen_i + fmaps_gen = fmaps_gen + fmap_gen_i + + return scores_real, scores_gen, fmaps_real, fmaps_gen + class DiscriminatorSTFT(NeuralModule): """ @@ -1064,29 +1320,152 @@ def forward(self, inputs, input_len): return out -class HiFiGANResBlock(NeuralModule): +class ResidualBlockV2(NeuralModule): """ - Residual block wrapper for HiFi-GAN which creates a block for multiple dilations. Args: channels: Input dimension. - kernel_size: Kernel size of the residual blocks. - dilations: List of dilations. One residual block will be created for each dilation in the list. - activation: Activation for the residual blocks. + filters: Number of channels in the residual convolutions. + kernel_size: Kernel size of the residual convolutions. + dilation: Dilation of the residual convolutions. + dropout_rate: Dropout to apply to residuals. + activation: Activation to apply in between residual convolutions. 
""" - def __init__(self, channels: int, kernel_size: int, dilations: Iterable[int], activation: str): - super().__init__() + def __init__( + self, + channels: int, + filters: int, + kernel_size: int = 3, + activation: str = "lrelu", + ): + super(ResidualBlockV2, self).__init__() - self.res_blocks = nn.ModuleList( - [ - ResidualBlock( - channels=channels, - filters=channels, - kernel_size=kernel_size, - dilation=dilation, - activation=activation, - ) + self.input_conv = Conv1dNorm( + in_channels=channels, out_channels=filters, kernel_size=kernel_size, activation=activation + ) + self.skip_conv = Conv1dNorm(in_channels=filters, out_channels=channels, kernel_size=kernel_size) + self.output_activation = CodecActivation(activation=activation, channels=channels) + + def remove_weight_norm(self): + self.input_conv.remove_weight_norm() + self.skip_conv.remove_weight_norm() + + @property + def input_types(self): + return {"inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType())} + + @property + def output_types(self): + return {"out": NeuralType(('B', 'C', 'T'), EncodedRepresentation())} + + @typecheck() + def forward(self, inputs, input_len): + res = self.input_conv(inputs=inputs, input_len=input_len) + res = self.skip_conv(inputs=res, input_len=input_len) + out = inputs + res + out = self.output_activation(out) + return out + + +class ResidualBlockV3(NeuralModule): + """ + The residual block structure defined by the HiFi-GAN V1 and V2 configurations. + + Args: + channels: Input dimension. + filters: Number of channels in the residual convolutions. + kernel_size: Kernel size of the residual convolutions. + dilation: Dilation of the residual convolutions. + dropout_rate: Dropout to apply to residuals. + activation: Activation to apply in between residual convolutions. 
+ """ + + def __init__( + self, + channels: int, + filters: int, + down_sample_rate: int, + kernel_size: int = 3, + activation: str = "lrelu", + ): + super(ResidualBlockV3, self).__init__() + + if down_sample_rate > 1: + self.down_sample_rate = down_sample_rate + self.down_sample_conv = Conv1dNorm( + in_channels=channels, out_channels=filters, kernel_size=kernel_size, stride=self.down_sample_rate + ) + self.down_sample_activation = CodecActivation(activation=activation, channels=filters) + channels = filters + else: + self.down_sample_rate = None + self.down_sample_conv = None + self.down_sample_activation = None + + self.input_conv = Conv1dNorm( + in_channels=channels, out_channels=filters, kernel_size=kernel_size + ) + self.skip_activation = CodecActivation(activation=activation, channels=filters) + self.skip_conv = Conv1dNorm(in_channels=filters, out_channels=channels, kernel_size=kernel_size) + self.output_activation = CodecActivation(activation=activation, channels=channels) + + def remove_weight_norm(self): + self.input_conv.remove_weight_norm() + self.skip_conv.remove_weight_norm() + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + "input_len": NeuralType(tuple('B'), LengthsType()) + } + + @property + def output_types(self): + return { + "out": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), + "out_len": NeuralType(tuple('B'), LengthsType()) + } + + @typecheck() + def forward(self, inputs, input_len): + if self.down_sample_rate is not None: + inputs = self.down_sample_conv(inputs=inputs, input_len=input_len) + inputs = self.down_sample_activation(inputs) + input_len = input_len // self.down_sample_rate + + skip_input = self.input_conv(inputs=inputs, input_len=input_len) + skip_input = self.skip_activation(skip_input) + res = self.skip_conv(inputs=skip_input, input_len=input_len) + out = inputs + res + out = self.output_activation(out) + return out, input_len + + +class HiFiGANResBlock(NeuralModule): + """ + Residual block wrapper for HiFi-GAN which creates a block for multiple dilations. + + Args: + channels: Input dimension. + kernel_size: Kernel size of the residual blocks. + dilations: List of dilations. One residual block will be created for each dilation in the list. + activation: Activation for the residual blocks. 
+ """ + + def __init__(self, channels: int, kernel_size: int, dilations: Iterable[int], activation: str): + super().__init__() + + self.res_blocks = nn.ModuleList( + [ + ResidualBlock( + channels=channels, + filters=channels, + kernel_size=kernel_size, + dilation=dilation, + activation=activation, + ) for dilation in dilations ] ) @@ -1442,6 +1821,51 @@ def forward(self, audio, audio_len): return spec, spec_len +class STFTProcessor(NeuralModule): + def __init__(self, n_fft, win_length, hop_length, log_guard=1.0): + super().__init__() + + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.register_buffer("window", torch.hann_window(self.win_length, periodic=False)) + self.log_guard = log_guard + self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()), + "spec_len": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward(self, audio, audio_len): + spec_len = audio_len // self.hop_length + audio_padded = torch.nn.functional.pad(audio, (self.stft_pad_amount, self.stft_pad_amount), "reflect") + # [B, n_fft, T_spec] + fft = torch.stft( + audio_padded, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + return_complex=True, + center=False + ) + fft_mag = torch.abs(fft) + fft_mag_log = torch.log(fft_mag + self.log_guard) + fft_mag_log = mask_sequence_tensor(fft_mag_log, spec_len) + return fft_mag_log, spec_len + + class ResNetEncoder(NeuralModule): """ Residual network which uses HiFi-GAN residual blocks to encode spectrogram features without changing @@ -1514,6 +1938,168 @@ def forward(self, inputs, input_len): return encoded +class ResNetEncoderV2(NeuralModule): + def __init__( + self, + in_channels, + out_channels, + num_layers, + hidden_channels, + filters, + kernel_size=3, + activation="lrelu" + ): + super(ResNetEncoderV2, self).__init__() + + self.pre_conv = Conv1dNorm( + in_channels=in_channels, + out_channels=hidden_channels, + kernel_size=kernel_size + ) + self.pre_act = CodecActivation(activation, channels=hidden_channels) + self.res_blocks = nn.ModuleList([ + ResidualBlockV2( + channels=hidden_channels, + filters=filters, + kernel_size=kernel_size, + activation=activation + ) + for _ in range(num_layers) + ]) + self.post_conv = Conv1dNorm( + in_channels=hidden_channels, + out_channels=out_channels, + kernel_size=kernel_size + ) + + def remove_weight_norm(self): + self.pre_conv.remove_weight_norm() + self.post_conv.remove_weight_norm() + for res_layer in self.res_layers: + res_layer.remove_weight_norm() + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'D', 'T'), VoidType()), + "input_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "encoded": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), + "encoded_len": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward(self, inputs, input_len): + encoded = self.pre_conv(inputs=inputs, input_len=input_len) + encoded = self.pre_act(encoded) + for res_block in self.res_blocks: + encoded = res_block(inputs=encoded, input_len=input_len) + encoded = self.post_conv(inputs=encoded, input_len=input_len) + return encoded, input_len + + +class 
ResNetEncoderV3(NeuralModule): + def __init__( + self, + in_channels, + out_channels, + filter_list, + stride_list, + kernel_size=3, + activation="lrelu" + ): + super(ResNetEncoderV3, self).__init__() + + input_dim = filter_list[0] + self.pre_conv = Conv1dNorm( + in_channels=in_channels, + out_channels=input_dim, + kernel_size=kernel_size + ) + self.pre_act = CodecActivation(activation, channels=input_dim) + self.res_blocks = nn.ModuleList([]) + for (filters, stride) in zip(filter_list, stride_list): + res_block = ResidualBlockV3( + channels=input_dim, + filters=filters, + down_sample_rate=stride, + kernel_size=kernel_size, + activation=activation + ) + self.res_blocks.append(res_block) + input_dim = filters + + self.post_conv = Conv1dNorm( + in_channels=input_dim, + out_channels=out_channels, + kernel_size=kernel_size + ) + + def remove_weight_norm(self): + self.pre_conv.remove_weight_norm() + self.post_conv.remove_weight_norm() + for res_layer in self.res_layers: + res_layer.remove_weight_norm() + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'D', 'T'), VoidType()), + "input_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "encoded": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), + "encoded_len": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward(self, inputs, input_len): + encoded = self.pre_conv(inputs=inputs, input_len=input_len) + encoded = self.pre_act(encoded) + encoded_len = input_len + for res_block in self.res_blocks: + encoded, encoded_len = res_block(inputs=encoded, input_len=encoded_len) + encoded = self.post_conv(inputs=encoded, input_len=encoded_len) + return encoded, encoded_len + + +class SpectrogramEncoder(NeuralModule): + def __init__(self, spec_processor, encoder): + super(SpectrogramEncoder, self).__init__() + self.spec_processor = spec_processor + self.encoder = encoder + + def remove_weight_norm(self): + self.encoder.remove_weight_norm() + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), + "encoded_len": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward(self, audio, audio_len): + spec, spec_len = self.spec_processor(audio=audio, audio_len=audio_len) + encoded, encoded_len = self.encoder(inputs=spec, input_len=spec_len) + return encoded, encoded_len + + class FullBandMelEncoder(NeuralModule): """ Encoder which encodes the entire mel spectrogram with a single encoder network. @@ -1617,3 +2203,243 @@ def forward(self, audio, audio_len): # [B, C, T] encoded = torch.cat(outputs, dim=1) return encoded, spec_len + + +class DownSampleResidualBlock(NeuralModule): + """ + The residual block structure defined by the HiFi-GAN V1 and V2 configurations. + + Args: + channels: Input dimension. + filters: Number of channels in the residual convolutions. + kernel_size: Kernel size of the residual convolutions. + dilation: Dilation of the residual convolutions. + dropout_rate: Dropout to apply to residuals. + activation: Activation to apply in between residual convolutions. 
+ """ + + def __init__( + self, + channels: int, + filters: int, + kernel_size: int, + down_sample_rate: int, + down_sample_kernel_size: int, + activation: str = "lrelu", + ): + super(DownSampleResidualBlock, self).__init__() + + self.down_sample_rate = down_sample_rate + self.down_sample_conv = Conv1dNorm( + in_channels=channels, + out_channels=filters, + kernel_size=down_sample_kernel_size, + stride=self.down_sample_rate, + activation=activation + ) + self.res_block = ResidualBlockV2( + channels=filters, filters=filters, kernel_size=kernel_size, activation=activation + ) + + def remove_weight_norm(self): + self.input_conv.remove_weight_norm() + self.skip_conv.remove_weight_norm() + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + "input_len": NeuralType(tuple('B'), LengthsType()) + } + + @property + def output_types(self): + return { + "out": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), + "out_len": NeuralType(tuple('B'), LengthsType()) + } + + @typecheck() + def forward(self, inputs, input_len): + output_len = input_len // self.down_sample_rate + out = self.down_sample_conv(inputs=inputs, input_len=output_len) + out = self.res_block(inputs=out, input_len=output_len) + return out, output_len + + +class STFTResidualBlock(NeuralModule): + """ + The residual block structure defined by the HiFi-GAN V1 and V2 configurations. + + Args: + channels: Input dimension. + filters: Number of channels in the residual convolutions. + kernel_size: Kernel size of the residual convolutions. + dilation: Dilation of the residual convolutions. + dropout_rate: Dropout to apply to residuals. + activation: Activation to apply in between residual convolutions. + """ + + def __init__( + self, + resolution, + input_dim, + filters, + kernel_size, + down_sample_rate, + down_sample_kernel_size, + activation, + ): + super(STFTResidualBlock, self).__init__() + + self.down_sample_rate = down_sample_rate + self.down_sample_conv = Conv1dNorm( + in_channels=input_dim, + out_channels=filters, + kernel_size=down_sample_kernel_size, + stride=self.down_sample_rate, + activation=activation + ) + + n_fft, hop_length, win_length = resolution + stft_dim = n_fft // 2 + 1 + self.spec_processor = STFTProcessor(n_fft=n_fft, win_length=win_length, hop_length=hop_length) + self.spec_conv = Conv1dNorm(in_channels=stft_dim, out_channels=filters, kernel_size=kernel_size) + self.spec_act = CodecActivation(activation=activation, channels=filters) + + self.res_block = ResidualBlockV2( + channels=filters, filters=filters, kernel_size=kernel_size, activation=activation + ) + + def remove_weight_norm(self): + self.input_conv.remove_weight_norm() + self.skip_conv.remove_weight_norm() + + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'C', 'T'), VoidType()), + "input_len": NeuralType(tuple('B'), LengthsType()), + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "out": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), + "out_len": NeuralType(tuple('B'), LengthsType()) + } + + @typecheck() + def forward(self, inputs, input_len, audio, audio_len): + out_len = input_len // self.down_sample_rate + out = self.down_sample_conv(inputs=inputs, input_len=out_len) + + spec, _ = self.spec_processor(audio=audio, audio_len=audio_len) + spec_res = self.spec_conv(inputs=spec, input_len=out_len) + out = out + spec_res + out = self.spec_act(out) + + out = 
self.res_block(inputs=out, input_len=out_len) + return out, out_len + + +class MultiResolutionSTFTEncoder(NeuralModule): + def __init__( + self, + resolutions, + filter_list, + down_sample_filter_list, + out_dim, + kernel_size=3, + down_sample_kernel_size=5, + activation="lrelu" + ): + super(MultiResolutionSTFTEncoder, self).__init__() + assert len(resolutions) == len(filter_list) + + n_fft, hop_length, win_length = resolutions[0] + input_filters = filter_list[0] + input_dim = n_fft // 2 + 1 + self.pre_spec_processor = STFTProcessor(n_fft=n_fft, win_length=win_length, hop_length=hop_length) + self.pre_conv = Conv1dNorm( + in_channels=input_dim, + out_channels=input_filters, + kernel_size=kernel_size, + activation=activation + ) + self.pre_res_block = ResidualBlockV2( + channels=input_filters, + filters=input_filters, + kernel_size=kernel_size, + activation=activation + ) + input_dim = input_filters + self.stft_res_blocks = nn.ModuleList([]) + for resolution, filters in zip(resolutions[1:], filter_list[1:]): + stft_res_block = STFTResidualBlock( + resolution=resolution, + input_dim=input_dim, + down_sample_rate=2, + filters=filters, + kernel_size=kernel_size, + down_sample_kernel_size=down_sample_kernel_size, + activation=activation, + ) + self.stft_res_blocks.append(stft_res_block) + input_dim = filters + + self.down_sample_res_blocks = nn.ModuleList([]) + for filters in down_sample_filter_list: + down_sample_res_block = DownSampleResidualBlock( + channels=input_dim, + filters=input_dim, + down_sample_rate=2, + kernel_size=kernel_size, + down_sample_kernel_size=down_sample_kernel_size, + activation=activation + ) + self.down_sample_res_blocks.append(down_sample_res_block) + input_dim = filters + + self.post_conv = Conv1dNorm( + in_channels=input_dim, + out_channels=out_dim, + kernel_size=kernel_size + ) + + def remove_weight_norm(self): + self.encoder.remove_weight_norm() + + @property + def input_types(self): + return { + "audio": NeuralType(('B', 'T_audio'), AudioSignal()), + "audio_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), + "encoded_len": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward(self, audio, audio_len): + encoded, encoded_len = self.pre_spec_processor(audio=audio, audio_len=audio_len) + encoded = self.pre_conv(inputs=encoded, input_len=encoded_len) + encoded = self.pre_res_block(inputs=encoded, input_len=encoded_len) + + for stft_res_block in self.stft_res_blocks: + encoded, encoded_len = stft_res_block( + inputs=encoded, input_len=encoded_len, audio=audio, audio_len=audio_len + ) + + for down_sample_res_block in self.down_sample_res_blocks: + encoded, encoded_len = down_sample_res_block(inputs=encoded, input_len=encoded_len) + + encoded = self.post_conv(inputs=encoded, input_len=encoded_len) + + return encoded, encoded_len \ No newline at end of file diff --git a/nemo/collections/tts/parts/utils/helpers.py b/nemo/collections/tts/parts/utils/helpers.py index a4c65f9ed0e5..1379fa169789 100644 --- a/nemo/collections/tts/parts/utils/helpers.py +++ b/nemo/collections/tts/parts/utils/helpers.py @@ -48,6 +48,7 @@ import librosa import matplotlib.pylab as plt import numpy as np +import seaborn as sns import torch from einops import rearrange from numba import jit, prange @@ -468,6 +469,74 @@ def plot_alignment_to_numpy(alignment, title='', info=None, phoneme_seq=None, vm return data +def plot_alignment_to_numpy_for_speechllm( + 
alignment, + title='', + info=None, + phoneme_seq=None, + vmin=None, + vmax=None, + phoneme_ver=0, + phone_offset=2, + h_offset=True, +): + alignment = np.clip(alignment, a_min=0, a_max=None) + fig, ax = plt.subplots(figsize=(8, 6)) + im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none', vmin=vmin, vmax=vmax) + ax.set_title(title) + fig.colorbar(im, ax=ax) + xlabel = 'Decoder timestep' + if info is not None: + xlabel += '\n\n' + info + plt.xlabel(xlabel) + plt.ylabel('Encoder timestep') + + if phoneme_seq is not None: + if phoneme_ver == 0: + # for debugging of phonemes and durs in maps. Not used by def in training code + ax.set_yticks(np.arange(len(phoneme_seq))) + ax.set_yticklabels(phoneme_seq) + ax.hlines(np.arange(len(phoneme_seq)), xmin=0.0, xmax=max(ax.get_xticks())) + elif phoneme_ver == 1: + yticks = ax.get_yticks() + new_yticks = [] + for tick in yticks: + if tick < 0 or tick > alignment.shape[0]: + continue + new_yticks.append(tick) + new_yticks += phoneme_seq + ax.set_yticks(new_yticks) + elif phoneme_ver == 2: + phones = phoneme_seq[phone_offset:] + ax.set_yticks(np.arange(len(phones))) + ax.set_yticklabels(phones) + ax.hlines(np.arange(0.5, len(phones) - 0.5, 1.0), xmin=0.0, xmax=alignment.shape[1] - 0.5, colors="black") + + if h_offset: + xticks = ax.get_xticks() + new_xticks = [] + for tick in xticks: + new_xticks.append(f"{tick+phoneme_seq[1]:.0f}") + ax.set_xticklabels(new_xticks) + + plt.tight_layout() + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +def plot_codec_to_numpy(codes, title=''): + fig, ax = plt.subplots(figsize=(10, 3)) + sns.heatmap(codes, ax=ax) + + plt.tight_layout() + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + def plot_pitch_to_numpy(pitch, ylim_range=None): fig, ax = plt.subplots(figsize=(12, 3)) plt.plot(pitch) diff --git a/nemo/collections/tts/parts/utils/tts_dataset_utils.py b/nemo/collections/tts/parts/utils/tts_dataset_utils.py index 5f1185c2c399..8dd8b8ab11e4 100644 --- a/nemo/collections/tts/parts/utils/tts_dataset_utils.py +++ b/nemo/collections/tts/parts/utils/tts_dataset_utils.py @@ -88,10 +88,11 @@ class BetaBinomialInterpolator: The implementation is taken from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/data_function.py """ - def __init__(self, round_mel_len_to=50, round_text_len_to=10, cache_size=500): + def __init__(self, round_mel_len_to=50, round_text_len_to=10, cache_size=500, scaling_factor: float = 1.0): self.round_mel_len_to = round_mel_len_to self.round_text_len_to = round_text_len_to - self.bank = functools.lru_cache(maxsize=cache_size)(beta_binomial_prior_distribution) + cached_func = lambda x, y: beta_binomial_prior_distribution(x, y, scaling_factor=scaling_factor) + self.bank = functools.lru_cache(maxsize=cache_size)(cached_func) @staticmethod def round(val, to): diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index b512bc57cbab..8cd76a5ccf5c 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -51,6 +51,7 @@ from nemo.utils.loggers import ClearMLLogger, ClearMLParams, DLLogger, DLLoggerParams, MLFlowParams from nemo.utils.mcore_logger import add_handlers_to_mcore_logger from nemo.utils.model_utils import uninject_model_parallel_rank +from nemo.utils.timers import NeMoTimerException try: # `ptl_resiliency` is included in `gwe_resiliency_pkg` package @@ -260,7 +261,12 @@ def _on_batch_start(self, name): self.timer.start(name) 
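# Illustrative sketch (not part of the patch): the failure mode that the try/except added to
# _on_batch_end just below guards against. With the NeMoTimerException introduced in
# nemo/utils/timers.py later in this patch, stopping a timer that was never started raises a
# catchable NeMo exception instead of a bare RuntimeError. The timer name here is an
# assumed example.
from nemo.utils.timers import NamedTimer, NeMoTimerException

timer = NamedTimer()
try:
    # No matching start() call was made for this name, e.g. when a batch-end hook fires
    # for a batch whose batch-start hook never ran.
    timer.stop("validation_step_timing")
except NeMoTimerException:
    pass  # the callback below now skips the error rather than crashing the run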
def _on_batch_end(self, name, pl_module): - self.timer.stop(name) + try: + self.timer.stop(name) + except NeMoTimerException as e: + # Skip the error + pass + # Set the `batch_size=1` as WAR for `dataloader_iter`, which is not used for any metric pl_module.log( name + ' in s', @@ -864,12 +870,13 @@ def check_resume( trainer.ckpt_path = str(checkpoint) logging.info(f'Resuming training from checkpoint: {trainer.ckpt_path}') + trainer.strategy.barrier() if is_global_rank_zero(): # Check to see if any files exist that need to be moved files_to_move = [] if Path(log_dir).exists(): for child in Path(log_dir).iterdir(): - if child.is_file(): + if child.is_file() and not child.name.startswith("events.out.tfevents"): files_to_move.append(child) if len(files_to_move) > 0: @@ -990,7 +997,7 @@ def get_log_dir( os.environ[NEMO_ENV_VARNAME_VERSION] = "" if version is None else version log_dir = Path(_exp_dir) / Path(str(name)) / Path("" if version is None else str(version)) - return log_dir, str(_exp_dir), name, version + return log_dir, str(_exp_dir), name, "" if version is None else str(version) def get_git_hash(): diff --git a/nemo/utils/timers.py b/nemo/utils/timers.py index a35c257652b9..3c1ebbf1db5e 100644 --- a/nemo/utils/timers.py +++ b/nemo/utils/timers.py @@ -21,9 +21,15 @@ import numpy as np import torch +from nemo.utils.exceptions import NeMoBaseException + __all__ = ["NamedTimer", "SimpleTimer"] +class NeMoTimerException(NeMoBaseException, RuntimeError): + pass + + class NamedTimer(object): """ A timer class that supports multiple named timers. @@ -90,7 +96,7 @@ def start(self, name=""): timer_data = self.timers.get(name, {}) if "start" in timer_data: - raise RuntimeError(f"Cannot start timer = '{name}' since it is already active") + raise NeMoTimerException(f"Cannot start timer = '{name}' since it is already active") # synchronize pytorch cuda execution if supported if self._sync_cuda and torch.cuda.is_initialized(): @@ -109,7 +115,7 @@ def stop(self, name=""): """ timer_data = self.timers.get(name, None) if (timer_data is None) or ("start" not in timer_data): - raise RuntimeError(f"Cannot end timer = '{name}' since it is not active") + raise NeMoTimerException(f"Cannot end timer = '{name}' since it is not active") # synchronize pytorch cuda execution if supported if self._sync_cuda and torch.cuda.is_initialized(): diff --git a/requirements/requirements_tts.txt b/requirements/requirements_tts.txt index 0d499feb3b1f..6d20e0f2250f 100644 --- a/requirements/requirements_tts.txt +++ b/requirements/requirements_tts.txt @@ -11,3 +11,5 @@ nltk pandas pypinyin pypinyin-dict +seaborn + diff --git a/scripts/speechllm_multitask_dataprep.py b/scripts/speechllm_multitask_dataprep.py new file mode 100644 index 000000000000..4859ddc896bf --- /dev/null +++ b/scripts/speechllm_multitask_dataprep.py @@ -0,0 +1,783 @@ +import argparse +import copy +import json +import math +import os +import random +import time +from pathlib import Path + +import numpy as np +import soundfile as sf +import torch +import torchaudio +from encodec import EncodecModel +from omegaconf import OmegaConf +from tqdm import tqdm + +from nemo.collections.asr.parts.preprocessing.perturb import NoisePerturbation, WhiteNoisePerturbation +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.modules.transformer import mask_from_lens +from nemo.collections.tts.parts.utils.tts_dataset_utils import get_base_dir +from nemo.core.classes 
import Dataset +from nemo.utils import logging + +try: + from models.soundstream import SoundStream +except: + logging.warning("SoundStream not found, uniaudio cannot be used") + +try: + import dac +except: + logging.warning("DAC not found") + + +class AudioDataset(Dataset): + def __init__( + self, + manifest_paths, + min_duration=0.0, + max_duration=22.0, + sample_rate=24000, + noise_manifest_path=None, + min_snr_db=0, + max_snr_db=5, + max_same_speaker_audios=1, + use_context_as_same_speaker_audio=False, + pad_multiple=320, + audio_type="actual", # actual or noise or silence + ): + self.data = [] + speakerwise_records = {} + for manifest_path in manifest_paths: + with open(manifest_path, "r") as f: + for line in f: + record = json.loads(line) + if 'answer_duration' not in record: + record['answer_duration'] = record['duration'] + + if isinstance(record['speaker'], str) and 'mls_english_' in record['speaker']: + record['speaker'] = record['speaker'].replace('mls_english_', '') + record['speaker'] = int(record['speaker']) + + if record['answer_duration'] < min_duration or record['answer_duration'] > max_duration: + continue + + if ('context_duration' in record) and ( + record['context_duration'] < min_duration or record['context_duration'] > max_duration + ): + continue + + if self._is_record_valid(record): + self.data.append(record) + if record['speaker'] not in speakerwise_records: + speakerwise_records[record['speaker']] = [] + speakerwise_records[record['speaker']].append(record) + + self.speakerwise_records = speakerwise_records + self.speaker_list = list(self.speakerwise_records.keys()) + + self.sample_rate = sample_rate + self.audio_type = audio_type + + # TODO: Using White Noise Perturbation right now (dont have noise manifest) + + # self.noise_perturber = NoisePerturbation( + # manifest_path=noise_manifest_path, + # min_snr_db=min_snr_db, + # max_snr_db=max_snr_db, + # ) + + self.noise_perturber = WhiteNoisePerturbation() + + self.max_same_speaker_audios = max_same_speaker_audios + + # If True, use the 'context' key as the same speaker reference audio, + # otherwise randomly choose from the same speaker audios + + self.use_context_as_same_speaker_audio = use_context_as_same_speaker_audio + self.pad_multiple = pad_multiple + + if self.use_context_as_same_speaker_audio: + logging.info("Using context as same speaker audio") + self.add_context_records_to_manifest() + + self.base_data_dir = get_base_dir([item["audio_filepath"] for item in self.data]) + # self.filter_invalid_records() + # if sup_data_dir is not None: + # self.sup_data_dir = sup_data_dir + # else: + # self.sup_data_dir = os.path.join(self.base_data_dir, "sup_data") + # if not os.path.exists(self.sup_data_dir): + # os.makedirs(self.sup_data_dir) + + def _is_record_valid(self, record): + return True + try: + sf.read(record["audio_filepath"]) + # sf.read(record["context"]) + return True + except: + print("Skipping invalid record", record["audio_filepath"]) + return False + + def filter_invalid_records(self): + filtered_data = [] + for ridx, record in enumerate(self.data): + if ridx % 1000 == 0: + print("Filtering", ridx, "of", len(self.data)) + try: + sf.read(record["audio_filepath"]) + sf.read(record["context"]) + except: + print("Skipping invalid record", record["audio_filepath"]) + continue + filtered_data.append(record) + print("Original data size", len(self.data)) + print("Filtered data size", len(filtered_data)) + self.data = filtered_data + + def add_context_records_to_manifest(self): + # Add dummy records with 
audio_filepath as context + # to ensure all context file paths have their codes extracted and saved. + context_paths = {} + target_paths = {} + + for record in self.data: + if 'context' in record: + if 'context_duration' not in record: + # Get duration from the context audio file + record['context_duration'] = float(sf.info(record['context']).duration) + + context_paths[record['context']] = { + 'speaker': record['speaker'], + 'duration': record['context_duration'], + } + if 'answer' in record: + target_paths[record['audio_filepath']] = True + + for context_path in context_paths: + if context_path not in target_paths: + self.data.append( + { + "audio_filepath": context_path, + "context": context_path, + "duration": context_paths[context_path]['duration'], + "answer_duration": context_paths[context_path]['duration'], + "context_duration": context_paths[context_path]['duration'], + "text": "", # Indicates that this is a dummy record + "question": "", + "speaker": context_paths[context_path]['speaker'], + } + ) + + def __len__(self): + return len(self.data) + + def _get_wav_from_filepath(self, audio_filepath, perturb=False): + if self.audio_type == "noise" or self.audio_type == "silence": + # Create a 6 second noise audio + if self.audio_type == "noise": + audio_samples = np.random.normal(0, 1, 6 * self.sample_rate) + else: + audio_samples = np.zeros(6 * self.sample_rate) + audio = torch.tensor(audio_samples).float() + audio = torch.nn.functional.pad(audio, (0, self.pad_multiple - audio.size(0) % self.pad_multiple), value=0) + audio_length = torch.tensor(audio.size(0)).long() + + perturbed_audio = None + perturbed_audio_length = None + if perturb: + perturbed_audio = audio * 1.0 + perturbed_audio_length = (audio_length * 1.0).long() + + return audio, audio_length, perturbed_audio, perturbed_audio_length + elif self.audio_type == "actual": + features = AudioSegment.segment_from_file( + audio_filepath, target_sr=self.sample_rate, n_segments=-1, trim=False, + ) + audio_samples = features.samples + audio = torch.tensor(audio_samples) + audio = torch.nn.functional.pad(audio, (0, self.pad_multiple - audio.size(0) % self.pad_multiple), value=0) + audio_length = torch.tensor(audio.size(0)).long() + + perturbed_audio = None + perturbed_audio_length = None + if perturb: + features_copy = copy.deepcopy(features) + self.noise_perturber.perturb(features_copy) + perturbed_audio_samples = features_copy.samples + perturbed_audio = torch.tensor(perturbed_audio_samples) + perturbed_audio = torch.nn.functional.pad( + perturbed_audio, (0, self.pad_multiple - perturbed_audio.size(0) % self.pad_multiple), value=0 + ) + perturbed_audio_length = torch.tensor(perturbed_audio.size(0)).long() + # import ipdb; ipdb.set_trace() + + return audio, audio_length, perturbed_audio, perturbed_audio_length + + else: + raise ValueError("Unknown audio type {}".format(self.audio_type)) + + def pad_collate_fn(self, batch): + final_batch = {} + for row in batch: + for key in row: + if key not in final_batch: + final_batch[key] = [] + final_batch[key].append(row[key]) + + max_audio_len = max([_audio_len.item() for _audio_len in final_batch["audio_len"]]) + + audios_padded = [] + for audio in final_batch["audio"]: + audio_padded = torch.nn.functional.pad(audio, (0, max_audio_len - audio.size(0)), value=0) + audios_padded.append(audio_padded) + + final_batch["audio"] = audios_padded + + perturbed_audios_padded = [] + max_perturbed_audio_len = max([_audio_len.item() for _audio_len in final_batch["perturbed_audio_len"]]) + for audio in 
final_batch["perturbed_audio"]: + audio_padded = torch.nn.functional.pad(audio, (0, max_perturbed_audio_len - audio.size(0)), value=0) + perturbed_audios_padded.append(audio_padded) + + final_batch["perturbed_audio"] = perturbed_audios_padded + + mixed_audios_padded = [] + max_mixed_audio_len = max([_audio_len.item() for _audio_len in final_batch["mixed_audio_len"]]) + for audio in final_batch["mixed_audio"]: + audio_padded = torch.nn.functional.pad(audio, (0, max_mixed_audio_len - audio.size(0)), value=0) + mixed_audios_padded.append(audio_padded) + + final_batch["mixed_audio"] = mixed_audios_padded + + non_tensor_keys = [ + "audio_filepath", + "question", + "text", + "context", + "old_speaker_id", + "duration", + "context_duration", + "rel_audio_path_as_text_id", + "samespeaker_audioids", + "samespeaker_wavpaths", + "speaker" + ] + + for key in final_batch: + if key not in non_tensor_keys: + final_batch[key] = torch.stack(final_batch[key]) + + return final_batch + + def __getitem__(self, index): + sample = self.data[index] + rel_audio_path = Path(sample["audio_filepath"]).relative_to(self.base_data_dir).with_suffix("") + rel_audio_path_as_text_id = str(rel_audio_path).replace("/", "_") + # speaker = torch.tensor(sample["speaker"]).long() + speaker = sample['speaker'] + + # Avoid fixed seed + random.seed(time.time()) + alternate_speaker = random.choice(self.speaker_list) + _ctr = 0 + while (alternate_speaker == speaker) and (_ctr < 10): + random.seed(time.time()) + alternate_speaker = random.choice(self.speaker_list) + _ctr += 1 + + random.seed(time.time()) + alternate_wavpath = random.choice(self.speakerwise_records[alternate_speaker])["audio_filepath"] + + if not self.use_context_as_same_speaker_audio: + random.shuffle(self.speakerwise_records[sample["speaker"]]) + samespeaker_wavpaths = [] + context_duration = 0.0 + for _record in self.speakerwise_records[sample["speaker"]][: self.max_same_speaker_audios]: + if _record["audio_filepath"] != sample["audio_filepath"]: + samespeaker_wavpath = _record["audio_filepath"] + samespeaker_wavpaths.append(samespeaker_wavpath) + context_duration += _record["answer_duration"] + + if len(samespeaker_wavpaths) == 0: + # Use the same audio if no other audio is available from the same speaker + samespeaker_wavpaths = [sample["audio_filepath"]] + context_duration = sample["answer_duration"] + else: + samespeaker_wavpaths = [sample["context"]] + context_duration = sample["context_duration"] + + samespeaker_audioids = [] + for samespeaker_wavpath in samespeaker_wavpaths: + samespeaker_rel_audio_path = Path(samespeaker_wavpath).relative_to(self.base_data_dir).with_suffix("") + samespeaker_rel_audio_path_as_text_id = str(samespeaker_rel_audio_path).replace("/", "_") + samespeaker_audioids.append(samespeaker_rel_audio_path_as_text_id) + + alternate_audio, alternate_audio_length, _, _ = self._get_wav_from_filepath(alternate_wavpath, perturb=False) + audio, audio_length, perturbed_audio, perturbed_audio_length = self._get_wav_from_filepath( + sample["audio_filepath"], perturb=True + ) + + # Mix audio and alternate audio + if audio_length > alternate_audio_length: + # Repeat alternate audio + alternate_audio = alternate_audio.repeat(audio_length // alternate_audio_length + 1) + alternate_audio = alternate_audio[:audio_length] + mixed_audio = 0.5 * (audio + alternate_audio) + elif audio_length <= alternate_audio_length: + alternate_audio = alternate_audio[:audio_length] + mixed_audio = 0.5 * (audio + alternate_audio) + + mixed_audio_length = audio_length + + if 
"question" not in sample: + sample['question'] = "Text to speech this " + sample['text'] + + return { + "audio": audio, + "audio_len": audio_length, + "perturbed_audio": perturbed_audio, + "perturbed_audio_len": perturbed_audio_length, + "mixed_audio": mixed_audio, + "mixed_audio_len": mixed_audio_length, + "rel_audio_path_as_text_id": rel_audio_path_as_text_id, + "samespeaker_audioids": samespeaker_audioids, + "samespeaker_wavpaths": samespeaker_wavpaths, + "audio_filepath": sample["audio_filepath"], + "question": sample["question"], + "text": sample["text"], + "context": sample.get("context", None), + "old_speaker_id": sample.get("old_speaker_id", None), + "duration": sample["answer_duration"], + "context_duration": context_duration, + "speaker": speaker, + } + + +def save_batch_audios(batch, bidx, temp_dir, codec_model, codec_model_type='encodec', codec_model_sample_rate=24000): + for sidx in range(batch["audio"].shape[0]): + sample_audio = batch["audio"][sidx] + sample_audio_len = batch["audio_len"][sidx].item() + sample_audio = sample_audio[:sample_audio_len] + + # Save sample_audio + sample_audio_path = os.path.join(temp_dir, f"{bidx}_{sidx}_sample.wav") + torchaudio.save(sample_audio_path, sample_audio[None].cpu(), codec_model_sample_rate) + + # Save perturbed_audio + perturbed_audio = batch["perturbed_audio"][sidx] + perturbed_audio_len = batch["perturbed_audio_len"][sidx].item() + perturbed_audio = perturbed_audio[:perturbed_audio_len] + perturbed_audio_path = os.path.join(temp_dir, f"{bidx}_{sidx}_perturbed.wav") + torchaudio.save(perturbed_audio_path, perturbed_audio[None].cpu(), codec_model_sample_rate) + + # Save mixed_audio + mixed_audio = batch["mixed_audio"][sidx] + mixed_audio_len = batch["mixed_audio_len"][sidx].item() + mixed_audio = mixed_audio[:mixed_audio_len] + mixed_audio_path = os.path.join(temp_dir, f"{bidx}_{sidx}_mixed.wav") + torchaudio.save(mixed_audio_path, mixed_audio[None].cpu(), codec_model_sample_rate) + + with torch.no_grad(): + for key in batch: + if "CODEC" in key: + codec = batch[key][sidx] # (8, T) + if codec_model_type == 'encodec': + codec_decoded_audio = codec_model.decode([[codec.unsqueeze(0), None]])[0][0] + elif codec_model_type == 'uniaudio_codec': + codec_decoded_audio = codec_model.decode(codec.unsqueeze(0))[0][0] + elif codec_model_type == 'dac': + _z = codec_model.quantizer.from_codes(codec.unsqueeze(0))[0] + codec_decoded_audio = codec_model.decoder(_z)[0][0] + elif codec_model_type in ['nemo_codec', 'nemo_codec21', 'nemo_codec211k', 'nemo_codec214k']: + codec_len = torch.Tensor([codec.shape[1]]).long().cuda() + codec_decoded_audio, _ = codec_model.decode(tokens=codec.unsqueeze(0), tokens_len=codec_len) + codec_decoded_audio = codec_decoded_audio[0] + + codec_decoded_audio_path = os.path.join(temp_dir, f"{bidx}_{sidx}_{key}_decoded.wav") + torchaudio.save(codec_decoded_audio_path, codec_decoded_audio[None].cpu(), codec_model_sample_rate) + + +def estimate_duration_from_codeclen(codec_len, codec_downsampling_factor=320.0, codec_model_sample_rate=24000): + num_audio_samples = codec_len * codec_downsampling_factor + duration = num_audio_samples / codec_model_sample_rate + return round(duration, 2) + + +def save_manifest(records, manifest_path): + with open(manifest_path, "w") as f: + file_str = "" + for record in records: + file_str += json.dumps(record) + "\n" + file_str = file_str.strip() + f.write(file_str) + print("Saved manifest to {}".format(manifest_path)) + + +def main(): + parser = argparse.ArgumentParser(description='Create 
multiple tasks') + parser.add_argument("--noise_manifest", type=str, default="/datap/misc/noisedata/train_manifest.json") + parser.add_argument( + '--manifest_paths', + type=str, + default="/Data/manifests_libri_local/train_clean_300_speechlm_ttstasks_with3sec_ref_all_random.json", + ) + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--out_dir', type=str, default='/Data/CodecDatasets/speechllm_codecdatasets/') + parser.add_argument('--dataset_name', type=str, default='LibriTTSCorrectContext_train') + parser.add_argument('--codec_model_path', type=str, default='/Data/Checkpoints/rlang_codec/SpeechCodec.nemo') + parser.add_argument('--codec_bw', type=float, default=6.0) # 6 for 8 codebooks, 1.5 for 3 codebooks + parser.add_argument('--codec_model', type=str, default='nemo_codec') # encodec, uniaudio_codec, dac, nemo_codec, nemo_codec21, nemo_codec211k, nemo_codec214k + parser.add_argument('--use_context_as_same_speaker_audio', action='store_true') + parser.add_argument('--save_only_tts_records', action='store_true') + parser.add_argument('--shuffle', action='store_true') + parser.add_argument('--split_into_train_val', action='store_true') + parser.add_argument('--num_val_records', type=int, default=500) + parser.add_argument('--audio_type', type=str, default='actual') # actual, noise or silence + args = parser.parse_args() + + if args.codec_model == 'encodec': + codec_model = EncodecModel.encodec_model_24khz() + codec_model.set_target_bandwidth(6.0) + codec_model.cuda() + codec_model.eval() + codec_model_sample_rate = 24000 + codec_model_downsampling_factor = 320.0 + elif args.codec_model == 'uniaudio_codec': + codec_config_path = os.path.join(os.path.dirname(args.codec_model_path), 'config.yaml') + codec_config = OmegaConf.load(codec_config_path) + codec_model = eval(codec_config.generator.name)(**codec_config.generator.config) + codec_parameter_dict = torch.load(args.codec_model_path) + codec_model.load_state_dict(codec_parameter_dict['codec_model']) # load model + codec_model = codec_model.cuda() + # codec_model.eval() + codec_model_sample_rate = 16000 + codec_model_downsampling_factor = 320.0 + elif args.codec_model == 'dac': + model_path = args.codec_model_path + codec_model = dac.DAC.load(model_path) + codec_model.to('cuda') + codec_model_sample_rate = 44100 + codec_model_downsampling_factor = 512.0 + elif args.codec_model == 'nemo_codec': + model_path = args.codec_model_path + codec_model = AudioCodecModel.restore_from(model_path) + codec_model.to('cuda') + codec_model.eval() + codec_model_sample_rate = 22050 + codec_model_downsampling_factor = 256.0 + elif args.codec_model in ['nemo_codec21', 'nemo_codec211k', 'nemo_codec214k']: + model_path = args.codec_model_path + codec_model = AudioCodecModel.restore_from(model_path) + codec_model.to('cuda') + codec_model.eval() + codec_model_sample_rate = 22050 + codec_model_downsampling_factor = 1024.0 + else: + raise ValueError("Unknown codec model {}".format(args.codec_model)) + + dataset = AudioDataset( + manifest_paths=[args.manifest_paths], + sample_rate=codec_model_sample_rate, + noise_manifest_path=args.noise_manifest, + use_context_as_same_speaker_audio=args.use_context_as_same_speaker_audio, + pad_multiple=int(codec_model_downsampling_factor), + audio_type=args.audio_type, + ) + + dataloader = torch.utils.data.DataLoader( + dataset=dataset, batch_size=args.batch_size, collate_fn=dataset.pad_collate_fn, shuffle=False, num_workers=8, + ) + + _exp_name = "{}_{}_bw_{}".format(args.dataset_name, 
args.codec_model, args.codec_bw) + temp_dir = os.path.join(args.out_dir, "temp_{}".format(_exp_name)) + if not os.path.exists(temp_dir): + os.makedirs(temp_dir) + + codec_base_dir = os.path.join(args.out_dir, "codecs") + manifest_dir = os.path.join(args.out_dir, "manifests") + + audiocodec_out_dir = os.path.join(codec_base_dir, _exp_name) + + if not os.path.exists(audiocodec_out_dir): + os.makedirs(audiocodec_out_dir) + + if not os.path.exists(manifest_dir): + os.makedirs(manifest_dir) + + all_tasks_records = [] + phoneme_tts_records = [] + sentencepiece_tts_records = [] + phoneme_plus_sentencepiece_tts_records = [] + + for bidx, batch in enumerate(tqdm(dataloader)): + # print("bidx", bidx+1, "of", len(dataloader)) + + audio_len_mask = mask_from_lens(batch["audio_len"]) + + cuda_keys = ['audio', 'perturbed_audio', 'mixed_audio', 'audio_len', 'perturbed_audio_len', 'mixed_audio_len'] + for key in cuda_keys: + batch[key] = batch[key].cuda() + with torch.no_grad(): + if args.codec_model == 'encodec': + original_codec_codes = codec_model.encode(batch["audio"].unsqueeze(1))[0][0] + if not args.save_only_tts_records: + perturbed_codec_codes = codec_model.encode(batch["perturbed_audio"].unsqueeze(1))[0][0] + mixed_codec_codes = codec_model.encode(batch["mixed_audio"].unsqueeze(1))[0][0] + elif args.codec_model == 'uniaudio_codec': + original_codec_codes = codec_model.encode( + batch["audio"].unsqueeze(1) * codec_config.audio_norm_scale, target_bw=args.codec_bw + ).permute(1, 0, 2) + if not args.save_only_tts_records: + perturbed_codec_codes = codec_model.encode( + batch["perturbed_audio"].unsqueeze(1) * codec_config.audio_norm_scale, target_bw=args.codec_bw + ).permute(1, 0, 2) + mixed_codec_codes = codec_model.encode( + batch["mixed_audio"].unsqueeze(1) * codec_config.audio_norm_scale, target_bw=args.codec_bw + ).permute(1, 0, 2) + elif args.codec_model == 'dac': + # z, codes, latents, _, _ = model.encode(x) + _, original_codec_codes, _, _, _ = codec_model.encode(batch["audio"].unsqueeze(1)) + if not args.save_only_tts_records: + _, perturbed_codec_codes, _, _, _ = codec_model.encode(batch["perturbed_audio"].unsqueeze(1)) + _, mixed_codec_codes, _, _, _ = codec_model.encode(batch["mixed_audio"].unsqueeze(1)) + elif args.codec_model in ['nemo_codec', 'nemo_codec21', 'nemo_codec211k', 'nemo_codec214k']: + original_codec_codes, _ = codec_model.encode(audio=batch["audio"], audio_len=batch["audio_len"]) + if not args.save_only_tts_records: + perturbed_codec_codes, _ = codec_model.encode( + audio=batch["perturbed_audio"], audio_len=batch["perturbed_audio_len"] + ) + mixed_codec_codes, _ = codec_model.encode( + audio=batch["mixed_audio"], audio_len=batch["mixed_audio_len"] + ) + else: + raise ValueError("Unknown codec model {}".format(args.codec_model)) + + if args.save_only_tts_records: + perturbed_codec_codes = original_codec_codes # Dummy values to not break the code + mixed_codec_codes = original_codec_codes # Dummy values to not break the code + + # codec_codes = transformer_encodec_model.encode(batch["audio"].unsqueeze(1), audio_len_mask, bandwidth=6.0) + target_codecs = [] + mixed_codecs = [] + perturbed_codecs = [] + for sidx in range(batch['audio'].shape[0]): + + codec_len = math.ceil(batch['audio_len'][sidx].item() / codec_model_downsampling_factor) + sample_codec_codes = original_codec_codes[sidx][:, :codec_len] + target_codecs.append(sample_codec_codes) + + perturbed_codec_len = math.ceil( + batch['perturbed_audio_len'][sidx].item() / codec_model_downsampling_factor + ) + 
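# Illustrative sketch (not part of the patch): the frame-count arithmetic used here, with the
# nemo_codec settings from main() (22050 Hz sample rate, downsampling factor 256). The dataset
# already pads audio to a multiple of the downsampling factor, so math.ceil gives an exact
# frame count. The clip length below is an assumed example, not taken from the source data.
import math

sample_rate = 22050
downsampling_factor = 256.0
audio_len = 110336                                            # ~5 s clip (110250 samples) padded up to a multiple of 256
codec_len = math.ceil(audio_len / downsampling_factor)        # 431 codec frames
est_duration = codec_len * downsampling_factor / sample_rate  # ~5.0 s, matching estimate_duration_from_codeclen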
perturbed_sample_codec_codes = perturbed_codec_codes[sidx][:, :perturbed_codec_len] + perturbed_codecs.append(perturbed_sample_codec_codes) + + mixed_codec_len = math.ceil(batch['mixed_audio_len'][sidx].item() / codec_model_downsampling_factor) + mixed_sample_codec_codes = mixed_codec_codes[sidx][:, :mixed_codec_len] + mixed_codecs.append(mixed_sample_codec_codes) + + example_name = batch['rel_audio_path_as_text_id'][sidx] + + target_codec_filepath = os.path.join(audiocodec_out_dir, "target_codes_{}.pt".format(example_name)) + torch.save(sample_codec_codes.cpu().type(torch.int16), target_codec_filepath) + + if batch['text'][sidx] == "": + # Only save target codes for dummy records + # Don't need to add dummy records to manifest + continue + + perturbed_codec_filepath = os.path.join(audiocodec_out_dir, "perturbed_codes_{}.pt".format(example_name)) + mixed_codec_filepath = os.path.join(audiocodec_out_dir, "mixed_codes_{}.pt".format(example_name)) + if not args.save_only_tts_records: + torch.save(perturbed_sample_codec_codes.cpu().type(torch.int16), perturbed_codec_filepath) + torch.save(mixed_sample_codec_codes.cpu().type(torch.int16), mixed_codec_filepath) + + tts_contextpath = "" + for samespeaker_audioid in batch['samespeaker_audioids'][sidx]: + tts_contextpath += os.path.join(audiocodec_out_dir, "target_codes_{}.pt".format(samespeaker_audioid)) + tts_contextpath += ";" + tts_contextpath = tts_contextpath[:-1] + + tts_record = { + "audio_filepath": batch['audio_filepath'][sidx], + "text": batch['text'][sidx], + "question": batch['question'][sidx].replace("Phoneme TTS", "Text to speech this"), + "answer": target_codec_filepath, + "context": tts_contextpath, + "question_type": "TEXT", + "answer_type": "AUDIOCODEC", + "context_type": "REFSPEAKERCODEC", + "context_duration": batch['context_duration'][sidx], + "answer_duration": batch['duration'][sidx], + "taskname": "squad", + "speaker": batch['speaker'][sidx].item() if torch.is_tensor(batch['speaker'][sidx]) else batch['speaker'][sidx], + } + + phoneme_tts_record = {key: value for key, value in tts_record.items()} + phoneme_tts_record["question"] = phoneme_tts_record["question"].replace( + "Text to speech this", "Phoneme TTS" + ) + + speechenhancement_record = { + "audio_filepath": batch['audio_filepath'][sidx], + "text": batch['text'][sidx], + "question": "Remove Noise", + "answer": target_codec_filepath, + "context": perturbed_codec_filepath, + "question_type": "TEXT", + "answer_type": "AUDIOCODEC", + "context_type": "AUDIOCODEC", + "context_duration": estimate_duration_from_codeclen( + perturbed_codec_len, codec_model_downsampling_factor, codec_model_sample_rate + ), + "answer_duration": batch['duration'][sidx], + "taskname": "squad", + } + + speechseparation_record = { + "audio_filepath": batch['audio_filepath'][sidx], + "text": batch['text'][sidx], + "question": "Extract Speaker Audio", + "answer": target_codec_filepath, + "context": "{},{}".format(mixed_codec_filepath, tts_contextpath), + "question_type": "TEXT", + "answer_type": "AUDIOCODEC", + "context_type": "SEPARATIONCODECS", + "context_duration": estimate_duration_from_codeclen( + mixed_codec_len, codec_model_downsampling_factor, codec_model_sample_rate + ), + "answer_duration": batch['duration'][sidx], + "taskname": "squad", + } + + speechediting_record = { + "audio_filepath": batch['audio_filepath'][sidx], + "text": batch['text'][sidx], + "question": batch['question'][sidx].replace("Text to speech this", "Edit Speech"), + "answer": target_codec_filepath, + "context": 
target_codec_filepath, + "question_type": "TEXT", + "answer_type": "AUDIOCODEC", + "context_type": "EDITINGCODECS", + "context_duration": batch['duration'][sidx] + 3, # 3 sec for speaker context + "answer_duration": batch['duration'][sidx], + "taskname": "squad", + } + + phoneme_tts_records.append(phoneme_tts_record) + sentencepiece_tts_records.append(tts_record) + + phoneme_plus_sentencepiece_tts_records.append(phoneme_tts_record) + phoneme_plus_sentencepiece_tts_records.append(tts_record) + + all_tasks_records.append(tts_record) + all_tasks_records.append(phoneme_tts_record) + all_tasks_records.append(speechenhancement_record) + all_tasks_records.append(speechseparation_record) + all_tasks_records.append(speechediting_record) + + batch['target_CODEC'] = target_codecs + batch['perturbed_CODEC'] = perturbed_codecs + batch['mixed_CODEC'] = mixed_codecs + + if bidx == 0: + save_batch_audios(batch, bidx, temp_dir, codec_model, args.codec_model, codec_model_sample_rate) + + if args.shuffle: + # To ensure same split for encodec and uniaudio_codec + random.seed(21) + random.shuffle(all_tasks_records) + random.shuffle(phoneme_tts_records) + random.shuffle(sentencepiece_tts_records) + random.shuffle(phoneme_plus_sentencepiece_tts_records) + + if args.split_into_train_val: + # Shuffle compulsory for splitting into train and val + # To ensure same split for encodec and uniaudio_codec + random.seed(21) + random.shuffle(all_tasks_records) + random.shuffle(phoneme_tts_records) + random.shuffle(sentencepiece_tts_records) + # random.shuffle(phoneme_plus_sentencepiece_tts_records) + phoneme_plus_sentencepiece_tts_records = [] + for idx in range(len(phoneme_tts_records)): + phoneme_plus_sentencepiece_tts_records.append(phoneme_tts_records[idx]) + phoneme_plus_sentencepiece_tts_records.append(sentencepiece_tts_records[idx]) + + num_val_records = args.num_val_records + train_phoneme_tts_records = phoneme_tts_records[num_val_records:] + val_phoneme_tts_records = phoneme_tts_records[:num_val_records] + + train_sentencepiece_tts_records = sentencepiece_tts_records[num_val_records:] + val_sentencepiece_tts_records = sentencepiece_tts_records[:num_val_records] + + train_phoneme_plus_sentencepiece_tts_records = phoneme_plus_sentencepiece_tts_records[num_val_records:] + val_phoneme_plus_sentencepiece_tts_records = phoneme_plus_sentencepiece_tts_records[:num_val_records] + # Shuffle train mixed records + random.shuffle(train_phoneme_plus_sentencepiece_tts_records) + + train_all_tasks_records = all_tasks_records[num_val_records:] + val_all_tasks_records = all_tasks_records[:num_val_records] + + manifest_base_name = _exp_name + phoneme_tts_train_manifest_path = os.path.join( + manifest_dir, "{}_train_phoneme_tts.json".format(manifest_base_name) + ) + phoneme_tts_val_manifest_path = os.path.join( + manifest_dir, "{}_val_phoneme_tts.json".format(manifest_base_name) + ) + save_manifest(train_phoneme_tts_records, phoneme_tts_train_manifest_path) + save_manifest(val_phoneme_tts_records, phoneme_tts_val_manifest_path) + + sentencepiece_tts_train_manifest_path = os.path.join( + manifest_dir, "{}_train_sentencepiece_tts.json".format(manifest_base_name) + ) + sentencepiece_tts_val_manifest_path = os.path.join( + manifest_dir, "{}_val_sentencepiece_tts.json".format(manifest_base_name) + ) + save_manifest(train_sentencepiece_tts_records, sentencepiece_tts_train_manifest_path) + save_manifest(val_sentencepiece_tts_records, sentencepiece_tts_val_manifest_path) + + sp_plus_phoneme_tts_train_manifest_path = os.path.join( + 
manifest_dir, "{}_train_phoneme_plus_sentencepiece_tts.json".format(manifest_base_name) + ) + sp_plus_phoneme_tts_val_manifest_path = os.path.join( + manifest_dir, "{}_val_phoneme_plus_sentencepiece_tts.json".format(manifest_base_name) + ) + save_manifest(train_phoneme_plus_sentencepiece_tts_records, sp_plus_phoneme_tts_train_manifest_path) + save_manifest(val_phoneme_plus_sentencepiece_tts_records, sp_plus_phoneme_tts_val_manifest_path) + + if not args.save_only_tts_records: + all_tasks_train_manifest_path = os.path.join( + manifest_dir, "{}_train_all_tasks.json".format(args.dataset_name) + ) + all_tasks_val_manifest_path = os.path.join(manifest_dir, "{}_val_all_tasks.json".format(args.dataset_name)) + save_manifest(train_all_tasks_records, all_tasks_train_manifest_path) + save_manifest(val_all_tasks_records, all_tasks_val_manifest_path) + else: + manifest_base_name = _exp_name + phoneme_tts_manifest_path = os.path.join(manifest_dir, "{}_phoneme_tts.json".format(manifest_base_name)) + save_manifest(phoneme_tts_records, phoneme_tts_manifest_path) + + sentencepiece_tts_manifest_path = os.path.join( + manifest_dir, "{}_sentencepiece_tts.json".format(manifest_base_name) + ) + save_manifest(sentencepiece_tts_records, sentencepiece_tts_manifest_path) + + phoneme_plus_sentencepiece_tts_manifest_path = os.path.join( + manifest_dir, "{}_phoneme_plus_sentencepiece_tts.json".format(manifest_base_name) + ) + save_manifest(phoneme_plus_sentencepiece_tts_records, phoneme_plus_sentencepiece_tts_manifest_path) + + if not args.save_only_tts_records: + all_manifest_path = os.path.join(manifest_dir, "{}_all_tasks.json".format(args.dataset_name)) + save_manifest(all_tasks_records, all_manifest_path) + + +if __name__ == '__main__': + main() \ No newline at end of file From c9e8839e8fcc30dfc7fa775660f996c6ae4e7f9c Mon Sep 17 00:00:00 2001 From: blisc Date: Wed, 6 Nov 2024 20:32:29 +0000 Subject: [PATCH 02/18] Apply isort and black reformatting Signed-off-by: blisc --- .../tts/speechllm/megatron_t5_speechllm.py | 4 +- .../megatron/base_prompt_learning_dataset.py | 8 +- .../nlp/modules/common/megatron/attention.py | 18 +- .../megatron/megatron_encoder_decoder.py | 25 +- .../common/megatron/megatron_encoders.py | 5 +- .../megatron/megatron_transformer_decoder.py | 29 +- .../megatron/megatron_transformer_encoder.py | 23 +- .../nlp/modules/common/megatron/module.py | 5 +- .../megatron/token_level_encoder_decoder.py | 52 +- .../modules/common/megatron/transformer.py | 11 +- .../data/speechllm/t5_speechllm_dataset.py | 152 ++++-- .../speechllm/t5_speechllm_tarred_dataset.py | 76 +-- .../megatron_base_speechllm_prompt_model.py | 3 +- .../speechllm/megatron_t5_speechllm_model.py | 461 ++++++++++++------ .../tts/modules/audio_codec_modules.py | 145 ++---- .../tts/parts/utils/tts_dataset_utils.py | 9 +- nemo/utils/timers.py | 1 + scripts/speechllm_multitask_dataprep.py | 45 +- 18 files changed, 676 insertions(+), 396 deletions(-) diff --git a/examples/tts/speechllm/megatron_t5_speechllm.py b/examples/tts/speechllm/megatron_t5_speechllm.py index 3f0f7a2e76b1..1b438d8c1fc4 100644 --- a/examples/tts/speechllm/megatron_t5_speechllm.py +++ b/examples/tts/speechllm/megatron_t5_speechllm.py @@ -16,9 +16,9 @@ import torch.multiprocessing as mp from omegaconf.omegaconf import OmegaConf, open_dict -from nemo.collections.tts.models.speechllm.megatron_t5_speechllm_model import MegatronT5SpeechLMModel -from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.collections.nlp.parts.megatron_trainer_builder 
import MegatronTrainerBuilder +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.collections.tts.models.speechllm.megatron_t5_speechllm_model import MegatronT5SpeechLMModel from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import exp_manager diff --git a/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py index cb43408478e4..826e139fe6ba 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch import omegaconf +import torch from nemo.collections.nlp.modules.common import VirtualPromptSource from nemo.core import Dataset @@ -74,7 +74,7 @@ def __init__( dataset = open(path, 'r', encoding='utf-8') dataset_examples = self.load_data(dataset) self.examples.extend(dataset_examples) - elif (isinstance(datasets[0], omegaconf.ListConfig) or isinstance(datasets[0], list)): + elif isinstance(datasets[0], omegaconf.ListConfig) or isinstance(datasets[0], list): # Dataset is a list of tuples with the first element being the probability of sampling from the dataset # This code repeates the smaller datasets to approximately match the target probabilities total_examples = 0 @@ -106,7 +106,9 @@ def __init__( final_dataset_lengths = [] for dataset_idx in range(len(datasets)): num_samples_required = int(new_total_examples * target_probs[dataset_idx]) - num_repeat = max(int(round(num_samples_required // dataset_lengths[dataset_idx])), 1) # At least 1 repeat + num_repeat = max( + int(round(num_samples_required // dataset_lengths[dataset_idx])), 1 + ) # At least 1 repeat logging.info("dataset idx {}, num_repeat {}".format(dataset_idx, num_repeat)) dataset_examples_repeated = datasets_examples_list[dataset_idx] * num_repeat final_dataset_lengths.append(len(dataset_examples_repeated)) diff --git a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py index da1005ebbdb8..46da533186c1 100644 --- a/nemo/collections/nlp/modules/common/megatron/attention.py +++ b/nemo/collections/nlp/modules/common/megatron/attention.py @@ -517,7 +517,9 @@ def forward( # If we are in cross attention (inference_current_sequence_len == inference_max_sequence_len == inference_key_memory.size(0)) # We only need to cache this once if inference_max_sequence_len and self.inference_current_sequence_len < inference_max_sequence_len: - logging.debug(f"inference_current_sequence_len={self.inference_current_sequence_len} | key_layer.shape={key_layer.shape} | inference_key_memory={self.inference_key_memory.size()} | inference_value_memory={self.inference_value_memory.size()}") + logging.debug( + f"inference_current_sequence_len={self.inference_current_sequence_len} | key_layer.shape={key_layer.shape} | inference_key_memory={self.inference_key_memory.size()} | inference_value_memory={self.inference_value_memory.size()}" + ) # Adjust the range variables. 
start = self.inference_current_sequence_len self.inference_current_sequence_len += key_layer.size(0) @@ -954,7 +956,12 @@ def forward( f"not returning scores: attn_type={self.attention_type} | attn_fn={self.attn_fn} | return_scores={return_scores}" ) context_layer = self.attn_fn( - query_layer, key_layer, value_layer, attention_mask, relative_position_bias, inference_mode, + query_layer, + key_layer, + value_layer, + attention_mask, + relative_position_bias, + inference_mode, ) else: # SpeechLLM TTS modifications @@ -977,7 +984,12 @@ def forward( f"attn_fn: {self.attn_fn}, return_scores: {return_scores}, relative_position_bias is not None: {relative_position_bias is not None}" ) context_layer = self.attn_fn( - query_layer, key_layer, value_layer, attention_mask, relative_position_bias, inference_mode, + query_layer, + key_layer, + value_layer, + attention_mask, + relative_position_bias, + inference_mode, ) if headscale_tensor is not None: diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py index cc7072be0c40..67753f08775e 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py @@ -46,8 +46,7 @@ class MegatronTransformerEncoderDecoderModule(MegatronModule): - """Transformer encoder-decoder model. - """ + """Transformer encoder-decoder model.""" def __init__( self, @@ -144,7 +143,11 @@ def encode( # apply hidden transformations if needed if self.hiddens_module is not None: enc_output = self.hiddens_module.apply_hidden_transforms( - {"hiddens": enc_output, "hiddens_mask": self.get_hiddens_mask(enc_attn_mask),}, batch_data=batch_data, + { + "hiddens": enc_output, + "hiddens_mask": self.get_hiddens_mask(enc_attn_mask), + }, + batch_data=batch_data, ) return enc_output @@ -163,7 +166,7 @@ def decode( set_inference_key_value_memory=False, decoder_max_sequence_len=None, encoder_max_sequence_len=None, - enc_output_to_layers=None + enc_output_to_layers=None, ): if self.decoder is None: raise ValueError(f"Cannot call .decode(...) when self.decoder is None.") @@ -181,7 +184,7 @@ def decode( set_inference_key_value_memory=set_inference_key_value_memory, decoder_max_sequence_len=decoder_max_sequence_len, encoder_max_sequence_len=encoder_max_sequence_len, - enc_output_to_layers=enc_output_to_layers + enc_output_to_layers=enc_output_to_layers, ) return dec_output @@ -207,7 +210,7 @@ def forward( set_inference_key_value_memory=False, decoder_max_sequence_len=None, encoder_max_sequence_len=None, - enc_output_to_layers=None + enc_output_to_layers=None, ): # encoder if enc_output is None: @@ -236,9 +239,11 @@ def forward( dec_output = self.decode( dec_input=dec_input, dec_attn_mask=dec_attn_mask, - enc_output=enc_output["enc_output"] # enc_output is a dict if we used hidden transformations - if self.hiddens_module is not None - else enc_output, + enc_output=( + enc_output["enc_output"] # enc_output is a dict if we used hidden transformations + if self.hiddens_module is not None + else enc_output + ), # Adjust encoder attention mask if encoder is a perceiver. 
enc_attn_mask=self.get_hiddens_mask(enc_attn_mask), dec_layer_past=dec_layer_past, @@ -249,7 +254,7 @@ def forward( set_inference_key_value_memory=set_inference_key_value_memory, decoder_max_sequence_len=decoder_max_sequence_len, encoder_max_sequence_len=encoder_max_sequence_len, - enc_output_to_layers=enc_output_to_layers + enc_output_to_layers=enc_output_to_layers, ) # if self.hiddens_module is not None enc_output is a dict, else it is a torch.tensor diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py index e0e14e024629..3d2b2c1ecc13 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py @@ -14,7 +14,10 @@ """Transformer based language model.""" from nemo.collections.nlp.modules.common.megatron.megatron_perceiver_encoders import MegatronPerceiverEncoderModule -from nemo.collections.nlp.modules.common.megatron.megatron_transformer_encoder import MegatronTransformerEncoderModule, MultiMegatronTransformerEncoderModule +from nemo.collections.nlp.modules.common.megatron.megatron_transformer_encoder import ( + MegatronTransformerEncoderModule, + MultiMegatronTransformerEncoderModule, +) from nemo.collections.nlp.modules.common.megatron.retrieval_transformer import ( MegatronRetrievalTransformerEncoderModule, ) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py index 3fdff2a7068c..14677552492b 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py @@ -52,8 +52,7 @@ class MegatronTransformerDecoderModule(MegatronModule, Exportable, MegatronDecoderModule): - """Transformer decoder model. 
- """ + """Transformer decoder model.""" def __init__( self, @@ -166,7 +165,7 @@ def __init__( self._model_key = 'model' def set_input_tensor(self, input_tensor): - """ See megatron.model.transformer.set_input_tensor()""" + """See megatron.model.transformer.set_input_tensor()""" self.model.set_input_tensor(input_tensor) def forward( @@ -187,7 +186,9 @@ def forward( ): # convert to Megatron mask dec_attn_mask_3d = build_attention_mask_3d( - source_mask=dec_attn_mask, target_mask=dec_attn_mask, attn_mask_type=self.model_attn_mask_type, + source_mask=dec_attn_mask, + target_mask=dec_attn_mask, + attn_mask_type=self.model_attn_mask_type, ) if isinstance(enc_output, list): @@ -195,14 +196,22 @@ def forward( enc_dec_attn_mask_3d = [] for i in range(len(enc_output)): enc_dec_attn_mask_3d.append( - attn_mask_postprocess(build_attention_mask_3d( - source_mask=dec_attn_mask, target_mask=enc_attn_mask[i], attn_mask_type=AttnMaskType.padding, - )) + attn_mask_postprocess( + build_attention_mask_3d( + source_mask=dec_attn_mask, + target_mask=enc_attn_mask[i], + attn_mask_type=AttnMaskType.padding, + ) + ) ) else: - enc_dec_attn_mask_3d = attn_mask_postprocess(build_attention_mask_3d( - source_mask=dec_attn_mask, target_mask=enc_attn_mask, attn_mask_type=AttnMaskType.padding, - )) + enc_dec_attn_mask_3d = attn_mask_postprocess( + build_attention_mask_3d( + source_mask=dec_attn_mask, + target_mask=enc_attn_mask, + attn_mask_type=AttnMaskType.padding, + ) + ) # transformer decoder dec_output = self.model( diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py index 67c4d071c279..1f1d962f2c4a 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py @@ -13,6 +13,8 @@ # limitations under the License. 
"""Transformer based language model.""" +import torch + from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType from nemo.collections.nlp.modules.common.megatron.megatron_encoder_module import MegatronEncoderModule from nemo.collections.nlp.modules.common.megatron.module import MegatronModule @@ -23,7 +25,6 @@ build_attention_mask_3d, ) from nemo.core.classes.exportable import Exportable -import torch try: from apex.transformer.enums import AttnMaskType, ModelType @@ -164,7 +165,7 @@ def __init__( self._model_key = 'model' def set_input_tensor(self, input_tensor): - """ See megatron.model.transformer.set_input_tensor()""" + """See megatron.model.transformer.set_input_tensor()""" self.model.set_input_tensor(input_tensor) def forward( @@ -182,7 +183,9 @@ def forward( else: enc_attn_mask_3d = attn_mask_postprocess( build_attention_mask_3d( - source_mask=enc_attn_mask, target_mask=enc_attn_mask, attn_mask_type=self.model_attn_mask_type, + source_mask=enc_attn_mask, + target_mask=enc_attn_mask, + attn_mask_type=self.model_attn_mask_type, ) ) @@ -352,13 +355,13 @@ def __init__( use_flash_attention=use_flash_attention, ) self.model.append(transformer) - + self.model = torch.nn.ModuleList(self.model) self._model_key = 'model' def set_input_tensor(self, input_tensor): - """ See megatron.model.transformer.set_input_tensor()""" + """See megatron.model.transformer.set_input_tensor()""" for mi in range(len(self.model)): self.model[mi].set_input_tensor(input_tensor) @@ -379,7 +382,7 @@ def forward( enc_self_attention_relative_position_bias=None, set_inference_key_value_memory=False, ): - + assert isinstance(enc_input, list) assert len(enc_input) == len(self.model) assert isinstance(enc_attn_mask, list) @@ -391,13 +394,15 @@ def forward( enc_input_ = enc_input[encoder_number] enc_attn_mask_ = enc_attn_mask[encoder_number] enc_self_attention_relative_position_bias_ = enc_self_attention_relative_position_bias[encoder_number] - + if self.use_flash_attention: enc_attn_mask_3d = enc_attn_mask_ < 0.5 else: enc_attn_mask_3d = attn_mask_postprocess( build_attention_mask_3d( - source_mask=enc_attn_mask_, target_mask=enc_attn_mask_, attn_mask_type=self.model_attn_mask_type, + source_mask=enc_attn_mask_, + target_mask=enc_attn_mask_, + attn_mask_type=self.model_attn_mask_type, ) ) @@ -413,7 +418,7 @@ def forward( ) enc_outputs.append(enc_output) - + return enc_outputs def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): diff --git a/nemo/collections/nlp/modules/common/megatron/module.py b/nemo/collections/nlp/modules/common/megatron/module.py index c311f63e15de..a4efb2992166 100644 --- a/nemo/collections/nlp/modules/common/megatron/module.py +++ b/nemo/collections/nlp/modules/common/megatron/module.py @@ -140,7 +140,10 @@ def initialize_word_embeddings(self, init_method, vocab_size, hidden_size): # set word_embeddings weights to 0 here, then copy first # stage's weights using all_reduce below. 
self.word_embeddings = tensor_parallel.VocabParallelEmbedding( - vocab_size, hidden_size, init_method=init_method, config=self.config, + vocab_size, + hidden_size, + init_method=init_method, + config=self.config, ) self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.shared = True diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py index b4875a2ffa41..4e73c2615cf3 100644 --- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py @@ -473,7 +473,7 @@ def _validate_config(self): return encoder_kv_channels, decoder_kv_channels def set_input_tensor(self, input_tensor): - """ See megatron.model.transformer.set_input_tensor()""" + """See megatron.model.transformer.set_input_tensor()""" # This is usually handled in schedules.py but some inference code still # gives us non-lists or None @@ -570,7 +570,8 @@ def forward( if self.add_encoder and self.encoder_relative_position_embedding is not None: encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding( - query_seq_length=enc_seq_length, key_seq_length=enc_seq_length, + query_seq_length=enc_seq_length, + key_seq_length=enc_seq_length, ) if output_enc_hidden_only: @@ -608,8 +609,11 @@ def forward( query_seq_length=dec_input_ids.size(1), key_seq_length=dec_input_ids.size(1) ) if not self.decoder_cfg.relative_position_bias_self_attention_only: - decoder_cross_attention_relative_position_bias = self.decoder_cross_attention_relative_position_embedding( - query_seq_length=dec_input_ids.size(1), key_seq_length=enc_seq_length, + decoder_cross_attention_relative_position_bias = ( + self.decoder_cross_attention_relative_position_embedding( + query_seq_length=dec_input_ids.size(1), + key_seq_length=enc_seq_length, + ) ) else: decoder_cross_attention_relative_position_bias = None @@ -660,7 +664,8 @@ def forward( # check if hiddens is used if self.hiddens_cfg is not None: loss_dict = self.enc_dec_model.hiddens_module.apply_loss_transforms( - outputs=enc_output, batch_data=batch_data, + outputs=enc_output, + batch_data=batch_data, ) loss_dict["tokens_loss"] = tokens_loss # We need to store default output in a known key, so that we can mimic default behaviour @@ -846,7 +851,8 @@ def forward( if self.add_encoder and self.encoder_relative_position_embedding is not None: assert False, "Not implemented for speech models yet." 
encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding( - query_seq_length=enc_seq_length, key_seq_length=enc_seq_length, + query_seq_length=enc_seq_length, + key_seq_length=enc_seq_length, ) if output_enc_hidden_only: @@ -895,8 +901,11 @@ def forward( query_seq_length=dec_input_ids.size(1), key_seq_length=dec_input_ids.size(1) ) if not self.decoder_cfg.relative_position_bias_self_attention_only: - decoder_cross_attention_relative_position_bias = self.decoder_cross_attention_relative_position_embedding( - query_seq_length=dec_input_ids.size(1), key_seq_length=enc_seq_length, + decoder_cross_attention_relative_position_bias = ( + self.decoder_cross_attention_relative_position_embedding( + query_seq_length=dec_input_ids.size(1), + key_seq_length=enc_seq_length, + ) ) else: decoder_cross_attention_relative_position_bias = None @@ -907,7 +916,6 @@ def forward( single_encoder = True cross_attention_prior = [cross_attention_prior] - decoder_cross_attention_relative_position_bias = [] for _cross_attention_prior in cross_attention_prior: _decoder_cross_attention_relative_position_bias = None @@ -929,15 +937,19 @@ def forward( curr_cross_attention_prior = _cross_attention_prior + ( (1.0 - _cross_attention_prior) * curr_annealing_step / total_annealing_steps ) - _decoder_cross_attention_relative_position_bias = curr_cross_attention_prior.unsqueeze(1).repeat( - 1, num_attention_heads, 1, 1 + _decoder_cross_attention_relative_position_bias = curr_cross_attention_prior.unsqueeze( + 1 + ).repeat(1, num_attention_heads, 1, 1) + _decoder_cross_attention_relative_position_bias = torch.log( + _decoder_cross_attention_relative_position_bias + 1e-8 ) - _decoder_cross_attention_relative_position_bias = torch.log(_decoder_cross_attention_relative_position_bias + 1e-8) else: _decoder_cross_attention_relative_position_bias = _cross_attention_prior.unsqueeze(1).repeat( 1, num_attention_heads, 1, 1 ) - _decoder_cross_attention_relative_position_bias = torch.log(_decoder_cross_attention_relative_position_bias + 1e-8) + _decoder_cross_attention_relative_position_bias = torch.log( + _decoder_cross_attention_relative_position_bias + 1e-8 + ) decoder_cross_attention_relative_position_bias.append(_decoder_cross_attention_relative_position_bias) return_all_crossattention_probs = return_all_crossattention_probs or self.logging_step @@ -964,7 +976,7 @@ def forward( set_inference_key_value_memory=set_inference_key_value_memory, decoder_max_sequence_len=decoder_max_sequence_len, encoder_max_sequence_len=encoder_max_sequence_len, - enc_output_to_layers=self.enc_output_to_layers + enc_output_to_layers=self.enc_output_to_layers, ) alignment_loss = None @@ -972,7 +984,11 @@ def forward( dec_output, enc_output = output # [s, b, h] if return_all_crossattention_probs: dec_output, attention_scores = dec_output - attention_probs = [torch.softmax(attention_score, dim=-1) for lidx, attention_score in enumerate(attention_scores) if lidx in self.alignment_decoder_layerids] + attention_probs = [ + torch.softmax(attention_score, dim=-1) + for lidx, attention_score in enumerate(attention_scores) + if lidx in self.alignment_decoder_layerids + ] if text_limits is not None and self.use_alignment_loss and hasattr(self, "forward_sum_loss"): attention_scores_filtered = [ @@ -987,10 +1003,10 @@ def forward( # align_every_n_head: eg if set to 2, will skip every other head # if set to 12, will select 1 head from every layer align_every_n_head = self.align_every_n_head - dec_start_idx = self.decoder_context_len + 1 # 
+1 to remove bos + dec_start_idx = self.decoder_context_len + 1 # +1 to remove bos attention_scores_sliced = attention_scores_combined[ - :,::align_every_n_head,dec_start_idx:,text_start_idx:-(2 + end_offset) - ] # -2 to remove eos and pad + :, ::align_every_n_head, dec_start_idx:, text_start_idx : -(2 + end_offset) + ] # -2 to remove eos and pad attention_logprobs = ( attention_scores_sliced # not taking log_softmax, since we will do that in loss function ) diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index 3203b03deca9..3d5d36d05cc5 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -1633,7 +1633,7 @@ def forward( for encoder_idx in range(len(encoder_output)): for layer_idx in enc_output_to_layers[encoder_idx]: layer_to_encoder_num_mapping[layer_idx] = encoder_idx - + for index in range(self.num_layers): layer = self._get_layer(index) past = None @@ -1645,10 +1645,13 @@ def forward( if index in layer_to_encoder_num_mapping: _encoder_output = encoder_output[layer_to_encoder_num_mapping[index]] _enc_dec_attn_mask = enc_dec_attn_mask[layer_to_encoder_num_mapping[index]] - _cross_attention_relative_position_bias = cross_attention_relative_position_bias[layer_to_encoder_num_mapping[index]] + _cross_attention_relative_position_bias = cross_attention_relative_position_bias[ + layer_to_encoder_num_mapping[index] + ] if encoder_max_sequence_len is not None: - _encoder_max_sequence_len = encoder_max_sequence_len[layer_to_encoder_num_mapping[index]] - + _encoder_max_sequence_len = encoder_max_sequence_len[ + layer_to_encoder_num_mapping[index] + ] if layer_past is not None: past = layer_past[index] diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py index 6ebf64d7c17c..cd8abe647990 100644 --- a/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py @@ -185,9 +185,9 @@ def __init__( context_duration_max: Optional[float] = 5.0, skip_datasets: Optional[List[str]] = [], # substrings of dataset names to skip english_only_model: Optional[bool] = False, - context_conditioning: Optional[str] = "decoder", # encoder or decoder - use_beta_binomial_interpolator: Optional[str] = False, # encoder or decoder - context_slice_method: Optional[str] = "random", # random or fixed + context_conditioning: Optional[str] = "decoder", # encoder or decoder + use_beta_binomial_interpolator: Optional[str] = False, # encoder or decoder + context_slice_method: Optional[str] = "random", # random or fixed phoneme_probability: Optional[float] = 0.5, encoder_type: Optional[str] = "single_transformer", use_ipa: bool = False, @@ -248,7 +248,9 @@ def __init__( self.english_only_model = english_only_model self.phoneme_tokenizer = None if english_only_model: - self.phoneme_tokenizer = instantiate(_get_default_text_tokenizer_conf(phoneme_probability=phoneme_probability, use_ipa=use_ipa)).text_tokenizer + self.phoneme_tokenizer = instantiate( + _get_default_text_tokenizer_conf(phoneme_probability=phoneme_probability, use_ipa=use_ipa) + ).text_tokenizer else: self.g2p = {"fr": lambda x: x} if kwargs.get("g2p", None): @@ -267,8 +269,12 @@ def __init__( self.context_conditioning = context_conditioning if self.context_conditioning == "decoder": - assert self.context_duration_min == self.context_duration_max, "For 
decoder conditioning, context_duration_min and context_duration_max should be same" - self.decoder_context_len = int(self.context_duration_min * self.codebook_fps) #TODO: Just take from model var? + assert ( + self.context_duration_min == self.context_duration_max + ), "For decoder conditioning, context_duration_min and context_duration_max should be same" + self.decoder_context_len = int( + self.context_duration_min * self.codebook_fps + ) # TODO: Just take from model var? # Initialize sup_data_path, sup_data_types and run preprocessing methods for every supplementary data type\ self.sup_data_path = None @@ -290,7 +296,11 @@ def __init__( self.transformer_type = kwargs.pop('transformer_type', 'T5') self.skip_datasets = skip_datasets - self.beta_binomial_interpolator = BetaBinomialInterpolator(scaling_factor=self.attention_prior_scaling_factor) if use_beta_binomial_interpolator else None + self.beta_binomial_interpolator = ( + BetaBinomialInterpolator(scaling_factor=self.attention_prior_scaling_factor) + if use_beta_binomial_interpolator + else None + ) self.context_slice_method = context_slice_method self.encoder_type = encoder_type super().__init__( @@ -392,7 +402,9 @@ def load_data(self, dataset): ) # 400 is the max ref speaker audio elif ("Text to speech this" in question_in_manifest) or ('Phoneme TTS' in question_in_manifest): # approx_context_len = 400 - approx_context_len = 5 * (self.codebook_fps + 1) # better than 400. TODO: pneekhara: Need to change things for multi-encoder vs single encoder based filtering. + approx_context_len = 5 * ( + self.codebook_fps + 1 + ) # better than 400. TODO: pneekhara: Need to change things for multi-encoder vs single encoder based filtering. elif "Edit Speech" in question_in_manifest: approx_context_len = doc["answer_duration"] * (self.codebook_fps + 1) else: @@ -405,7 +417,7 @@ def load_data(self, dataset): if doc["answer_type"] in ["SPEECH", "AUDIOCODEC", "CONTEXTANSWER"]: assert "answer_duration" in doc, f"answer_duration key not in document {doc}" - approx_answer_len = doc["answer_duration"] * (self.codebook_fps + 1) + 3 # +3 for EOS, BOS padding + approx_answer_len = doc["answer_duration"] * (self.codebook_fps + 1) + 3 # +3 for EOS, BOS padding if self.seq_pattern == "delay_parallel": # In delay parallel, there is padding so add 8 frames approx_answer_len = approx_answer_len + self.num_speech_codebooks @@ -473,7 +485,7 @@ def __getitem__(self, idx): instructions = ["Phoneme TTS", "Text to speech this"] for prefix in instructions: if doc["question"].startswith(prefix): - question_text = doc["question"][len(prefix):].strip() + question_text = doc["question"][len(prefix) :].strip() break input_dict = self._insert_data_in_template(prompt_template_fields, doc, answer_field) @@ -492,9 +504,11 @@ def __getitem__(self, idx): total_context_len = context_tokens[0].size()[1] reduced_len = min( 400, - int(total_context_len * 0.2) - if total_context_len > 600 - else int(total_context_len * random.uniform(0.2, 0.5)), + ( + int(total_context_len * 0.2) + if total_context_len > 600 + else int(total_context_len * random.uniform(0.2, 0.5)) + ), ) start_token_index = random.randint( 0, total_context_len - reduced_len @@ -658,9 +672,7 @@ def __getitem__(self, idx): text_len = question_tokens_len.item() - start_of_question_offset - end_of_question_offset audio_len = prior_dec_len if self.beta_binomial_interpolator is not None: - cross_attention_question_prior = torch.from_numpy( - self.beta_binomial_interpolator(audio_len, text_len) - ) + 
cross_attention_question_prior = torch.from_numpy(self.beta_binomial_interpolator(audio_len, text_len)) else: cross_attention_question_prior = torch.from_numpy( beta_binomial_prior_distribution( @@ -675,7 +687,8 @@ def __getitem__(self, idx): ] = cross_attention_question_prior else: cross_attention_prior[ - prior_dec_start_idx:, virtual_tokens_len + context_tokens_len + start_of_question_offset : -end_of_question_offset + prior_dec_start_idx:, + virtual_tokens_len + context_tokens_len + start_of_question_offset : -end_of_question_offset, ] = cross_attention_question_prior if self.encoder_type == "multi_transformer": @@ -890,12 +903,12 @@ def _get_tokens(self, doc, field, field_data): if _text.startswith("Phoneme TTS"): lang = doc.get("lang", "en") instruction_tokens = self._get_text_tokens("Phoneme TTS") - field_tokens = self._get_phoneme_tokens(_text[len("Phoneme TTS"):].strip(), lang=lang) + field_tokens = self._get_phoneme_tokens(_text[len("Phoneme TTS") :].strip(), lang=lang) field_tokens = instruction_tokens + field_tokens elif _text.startswith("Edit Speech"): # Always use phoneme tokenizer for edit speech instruction_tokens = self._get_text_tokens("Edit Speech") - field_tokens = self._get_phoneme_tokens(_text[len("Edit Speech"):].strip()) + field_tokens = self._get_phoneme_tokens(_text[len("Edit Speech") :].strip()) field_tokens = instruction_tokens + field_tokens elif _text.startswith("TEXT CONTEXT:"): # Speaker id conditioning @@ -965,7 +978,9 @@ def _get_tokens(self, doc, field, field_data): if context_info.startswith("TEXT CONTEXT:"): context_tokens = self._get_text_tokens(context_info.strip(" ")) # pad field tokens to fixed length - assert self.context_duration_min == self.context_duration_max, "TEXT CONTEXT only supports fixed context duration" + assert ( + self.context_duration_min == self.context_duration_max + ), "TEXT CONTEXT only supports fixed context duration" _fixed_context_len = int(self.context_duration_min * self.codebook_fps) context_tokens = context_tokens + [self.tokenizer.pad_id] * (_fixed_context_len - len(context_tokens)) @@ -975,15 +990,23 @@ def _get_tokens(self, doc, field, field_data): else: context_tokens = torch.load(context_codec_path).long() context_tokens[0] = (context_tokens[0] + self.speech_offset).long() - assert self.context_duration_min == self.context_duration_max, "CONTEXTANSWER only supports fixed context duration" + assert ( + self.context_duration_min == self.context_duration_max + ), "CONTEXTANSWER only supports fixed context duration" reference_codec_len = int(self.context_duration_min * self.codebook_fps) if context_tokens.shape[1] < reference_codec_len: # Repeat the context to match the reference_codec_len - context_tokens = torch.cat([context_tokens] * (reference_codec_len // context_tokens.shape[1] + 1), dim=1) - assert context_tokens.shape[1] >= reference_codec_len, "CONTEXTANSWER context duration is less than min duration {} {} {}".format(context_tokens.shape[1], reference_codec_len, context_codec_path) + context_tokens = torch.cat( + [context_tokens] * (reference_codec_len // context_tokens.shape[1] + 1), dim=1 + ) + assert ( + context_tokens.shape[1] >= reference_codec_len + ), "CONTEXTANSWER context duration is less than min duration {} {} {}".format( + context_tokens.shape[1], reference_codec_len, context_codec_path + ) si = rng.randint(0, context_tokens.shape[1] - reference_codec_len) - context_tokens = context_tokens[:, si:si+reference_codec_len] - + context_tokens = context_tokens[:, si : si + reference_codec_len] + 
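Illustrative sketch, not part of the patch: the CONTEXTANSWER branch above fixes the reference context to context_duration_min * codebook_fps frames, tiles the codec tokens when the stored context is shorter, and then samples a random window. The helper name slice_reference_context below is hypothetical.

import random
import torch

def slice_reference_context(context_tokens: torch.Tensor, context_duration_s: float,
                            codebook_fps: float, rng: random.Random) -> torch.Tensor:
    # context_tokens: (num_codebooks, T) codec token ids
    reference_len = int(context_duration_s * codebook_fps)
    if context_tokens.shape[1] < reference_len:
        # tile the context until it covers at least reference_len frames
        repeats = reference_len // context_tokens.shape[1] + 1
        context_tokens = torch.cat([context_tokens] * repeats, dim=1)
    start = rng.randint(0, context_tokens.shape[1] - reference_len)
    return context_tokens[:, start : start + reference_len]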
answer_tokens = torch.load(answer_codec_path).long() answer_tokens[0] = (answer_tokens[0] + self.speech_offset).long() pad_tokens = torch.zeros(self.num_speech_codebooks, 1).long() @@ -1023,7 +1046,7 @@ def _get_tokens(self, doc, field, field_data): return field_tokens def _insert_data_in_template(self, prompt_template_fields, doc, answer_field): - """ Format the input example according to the template """ + """Format the input example according to the template""" out_dict = {} for field in prompt_template_fields: # discard the last one, {label} / {answer} @@ -1056,12 +1079,15 @@ def get_position_ids(self, virtual_token, context_and_qquestion): return build_position_ids(enc_input_p).contiguous() def collate_fn(self, batch): - """ Prepares enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids for global batch """ + """Prepares enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids for global batch""" data_dict = self.pad_batch_and_build_loss_mask(batch) if self.encoder_type == "multi_transformer": - position_ids = [self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens'][0]), self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens'][1])] + position_ids = [ + self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens'][0]), + self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens'][1]), + ] else: position_ids = self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens']) @@ -1084,7 +1110,7 @@ def collate_fn(self, batch): ) def pad_batch_and_build_loss_mask(self, batch): - """ Pad enc_input, dec_input, labels in batch to max batch length while building loss_mask, enc_mask, and dec_mask """ + """Pad enc_input, dec_input, labels in batch to max batch length while building loss_mask, enc_mask, and dec_mask""" ( taskname_ids, _, @@ -1110,8 +1136,16 @@ def pad_batch_and_build_loss_mask(self, batch): virtual_mask = get_mask_from_lengths(virtual_tokens_len) if self.encoder_type == "multi_transformer": - max_context_len = max(_c[0] for _c in context_and_question_tokens_len) if context_and_question_tokens_len is not None else 0 - max_question_len = max(_c[1] for _c in context_and_question_tokens_len) if context_and_question_tokens_len is not None else 0 + max_context_len = ( + max(_c[0] for _c in context_and_question_tokens_len) + if context_and_question_tokens_len is not None + else 0 + ) + max_question_len = ( + max(_c[1] for _c in context_and_question_tokens_len) + if context_and_question_tokens_len is not None + else 0 + ) max_context_and_question_tokens_len = [max_context_len, max_question_len] context_len = torch.stack([_c[0] for _c in context_and_question_tokens_len]) question_len = torch.stack([_c[1] for _c in context_and_question_tokens_len]) @@ -1119,7 +1153,10 @@ def pad_batch_and_build_loss_mask(self, batch): question_mask = get_mask_from_lengths(question_len) context_and_question_tokens_len = [context_len, question_len] context_and_question_mask = [context_mask, question_mask] - enc_mask = [torch.cat([virtual_mask, context_and_question_mask[0]], dim=1), torch.cat([virtual_mask, context_and_question_mask[1]], dim=1)] + enc_mask = [ + torch.cat([virtual_mask, context_and_question_mask[0]], dim=1), + torch.cat([virtual_mask, context_and_question_mask[1]], dim=1), + ] # import ipdb; ipdb.set_trace() else: max_context_and_question_tokens_len = ( @@ -1268,7 
+1305,10 @@ def pad_batch_and_build_loss_mask(self, batch): ) cross_attention_prior_padded = torch.nn.functional.pad( - cross_attention_prior, pad=(0, _p1, 0, _p0), mode="constant", value=1, + cross_attention_prior, + pad=(0, _p1, 0, _p0), + mode="constant", + value=1, ) cross_attention_prior_list.append(cross_attention_prior_padded) @@ -1288,7 +1328,7 @@ def pad_batch_and_build_loss_mask(self, batch): dec_labels_mask = torch.stack(dec_labels_mask_list) if len(dec_labels_mask_list) > 0 else None if dec_labels_mask is not None and self.context_conditioning == 'decoder': # Mask out context tokens from loss computation. +1 for bos/pad in the beginning - dec_labels_mask[:,:self.decoder_context_len + 1] = 0 + dec_labels_mask[:, : self.decoder_context_len + 1] = 0 if self.encoder_type == "multi_transformer": context_batch = torch.stack([c[0] for c in context_question_tokens_list]) @@ -1308,10 +1348,12 @@ def pad_batch_and_build_loss_mask(self, batch): "dec_labels_mask": dec_labels_mask, "speech_mask": torch.stack(speech_mask_list) if len(speech_mask_list) > 0 else None, "context_and_question_tokens_lens": context_and_question_tokens_len, - "cross_attention_prior": torch.stack(cross_attention_prior_list) - if len(cross_attention_prior_list) > 0 - else None, - "text_limits": torch.stack(text_limits) if len(text_limits) > 0 else None, # tensor, valid range of answer transcripts without virtual/instruction/end tokens. + "cross_attention_prior": ( + torch.stack(cross_attention_prior_list) if len(cross_attention_prior_list) > 0 else None + ), + "text_limits": ( + torch.stack(text_limits) if len(text_limits) > 0 else None + ), # tensor, valid range of answer transcripts without virtual/instruction/end tokens. "lang": torch.stack(lang_list), "question_texts": question_texts, } @@ -1427,9 +1469,7 @@ def __getitem__(self, idx): if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: taskname_id = self.tokenizer.text_to_ids(taskname) - elif ( - self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT - ): + elif self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT: taskname_id = -1 else: raise ValueError("Invalid virtual prompt source specified") @@ -1473,7 +1513,14 @@ def __getitem__(self, idx): ) def collate_fn(self, batch): - (_, context_tokens_len, _, question_tokens_len, _, input_ids_len,) = zip(*batch) + ( + _, + context_tokens_len, + _, + question_tokens_len, + _, + input_ids_len, + ) = zip(*batch) decoder_input_len = ( torch.stack(context_tokens_len) + torch.stack(question_tokens_len) + torch.stack(input_ids_len) @@ -1484,7 +1531,10 @@ def collate_fn(self, batch): decoder_mask = get_mask_from_lengths(decoder_input_len - 1) speech_mask = get_mask_from_lengths(decoder_input_len - 1) context_question_mask = torch.ones(speech_mask.shape) - (decoder_input_list, decoder_labels_list,) = ( + ( + decoder_input_list, + decoder_labels_list, + ) = ( [], [], ) @@ -1510,11 +1560,17 @@ def collate_fn(self, batch): complete_input = torch.cat([context_tokens_input, question_tokens, input_ids_shifted], dim=1) complete_input_padded = general_padding( - complete_input, decoder_input_len[i].item(), max_decoder_input_len, pad_value=self.tokenizer.pad_id, + complete_input, + decoder_input_len[i].item(), + max_decoder_input_len, + pad_value=self.tokenizer.pad_id, ) complete_output = torch.cat([context_tokens, question_tokens, input_ids], dim=1) complete_output_padded = general_padding( - complete_output, decoder_input_len[i].item(), max_decoder_input_len, pad_value=self.tokenizer.pad_id, + 
complete_output, + decoder_input_len[i].item(), + max_decoder_input_len, + pad_value=self.tokenizer.pad_id, ) decoder_labels = complete_output_padded[:, 1:].contiguous() decoder_input = complete_input_padded[:, :-1].contiguous() @@ -1524,9 +1580,9 @@ def collate_fn(self, batch): decoder_mask[i, : context_tokens_len + question_tokens_len - 1] = 0 # Mask out context and question # TODO: jasoli, the speech_mask looks wrong. I shouldn't be masking out the context - speech_mask[ - i, context_tokens_len : context_tokens_len + question_tokens_len - ] = 0 # Mask out context and question + speech_mask[i, context_tokens_len : context_tokens_len + question_tokens_len] = ( + 0 # Mask out context and question + ) context_question_mask[i, : context_tokens_len + question_tokens_len] = 0 if self.spec_aug: diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py index 7755b9d9bdbf..940c5d2eaab6 100644 --- a/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py @@ -34,10 +34,7 @@ from nemo.collections.nlp.modules.common import VirtualPromptSource from nemo.collections.nlp.modules.common.megatron.utils import build_position_ids from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths -from nemo.collections.tts.parts.utils.tts_dataset_utils import ( - beta_binomial_prior_distribution, - general_padding, -) +from nemo.collections.tts.parts.utils.tts_dataset_utils import beta_binomial_prior_distribution, general_padding from nemo.core.classes import IterableDataset from nemo.utils import logging @@ -237,8 +234,7 @@ def __next__(self): return TarredAudioFilter(self.manifest_processor.collection) def _loop_offsets(self, iterator): - """This function is used to iterate through utterances with different offsets for each file. - """ + """This function is used to iterate through utterances with different offsets for each file.""" class TarredAudioLoopOffsets: def __init__(self, collection): @@ -272,8 +268,7 @@ def _collate_fn(self, batch): return _speech_collate_fn(batch) def _build_sample(self, tup): - """Builds the training sample by combining the data from the WebDataset with the manifest info. 
- """ + """Builds the training sample by combining the data from the WebDataset with the manifest info.""" audio_filename, encodec, ref_encodec, offset_id = tup return audio_filename, encodec, ref_encodec, offset_id @@ -435,7 +430,7 @@ def __init__( self.encodec, self.ref_encodec = None, None def _insert_virtual_token_placeholders(self, input_example, virtual_token_splits): - """ Insert the correct number of pseudo tokens at the <|VIRTUAL_PROMPT_n|> markers """ + """Insert the correct number of pseudo tokens at the <|VIRTUAL_PROMPT_n|> markers""" total_inserted_tokens = 0 for idx in range(len(virtual_token_splits)): @@ -505,9 +500,11 @@ def _build_sample(self, tup): total_context_len = context_tokens[0].size()[1] reduced_len = min( 400, - int(total_context_len * 0.2) - if total_context_len > 600 - else int(total_context_len * random.uniform(0.2, 0.5)), + ( + int(total_context_len * 0.2) + if total_context_len > 600 + else int(total_context_len * random.uniform(0.2, 0.5)) + ), ) start_token_index = random.randint( 0, total_context_len - reduced_len @@ -648,9 +645,9 @@ def _build_sample(self, tup): scaling_factor=self.attention_prior_scaling_factor, ) ) - cross_attention_prior[ - :, virtual_tokens_len + context_tokens_len + num_question_offset : - ] = cross_attention_question_prior + cross_attention_prior[:, virtual_tokens_len + context_tokens_len + num_question_offset :] = ( + cross_attention_question_prior + ) return ( taskname_id, @@ -794,7 +791,7 @@ def _get_tokens(self, doc, field, field_data): return field_tokens def _insert_data_in_template(self, input_example, prompt_template_fields, doc, answer_field): - """ Format the input example according to the template """ + """Format the input example according to the template""" out_dict = {} for field in prompt_template_fields: # discard the last one, {label} / {answer} @@ -827,7 +824,7 @@ def get_position_ids(self, virtual_token, context_and_qquestion): return build_position_ids(enc_input_p).contiguous() def collate_fn(self, batch): - """ Prepares enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids for global batch """ + """Prepares enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids for global batch""" data_dict = self.pad_batch_and_build_loss_mask(batch) @@ -849,7 +846,7 @@ def collate_fn(self, batch): ) def pad_batch_and_build_loss_mask(self, batch): - """ Pad enc_input, dec_input, labels in batch to max batch length while building loss_mask, enc_mask, and dec_mask """ + """Pad enc_input, dec_input, labels in batch to max batch length while building loss_mask, enc_mask, and dec_mask""" ( taskname_ids, _, @@ -972,7 +969,10 @@ def pad_batch_and_build_loss_mask(self, batch): ) cross_attention_prior_padded = torch.nn.functional.pad( - cross_attention_prior, pad=(0, _p1, 0, _p0), mode="constant", value=1, + cross_attention_prior, + pad=(0, _p1, 0, _p0), + mode="constant", + value=1, ) cross_attention_prior_list.append(cross_attention_prior_padded) @@ -987,16 +987,16 @@ def pad_batch_and_build_loss_mask(self, batch): "dec_labels_mask": torch.stack(dec_labels_mask_list) if len(dec_labels_mask_list) > 0 else None, "speech_mask": torch.stack(speech_mask_list) if len(speech_mask_list) > 0 else None, "context_and_question_tokens_lens": context_and_question_tokens_len, - "cross_attention_prior": torch.stack(cross_attention_prior_list) - if len(cross_attention_prior_list) > 0 - else None, + "cross_attention_prior": ( + torch.stack(cross_attention_prior_list) if 
len(cross_attention_prior_list) > 0 else None + ), } return data_dict class GPTSpeechLMTarredDataset(T5SpeechLMTarredDataset): - """ No support for cross attention here yet""" + """No support for cross attention here yet""" def _build_sample(self, tup): audio_filename, self.encodec, self.ref_encodec, offset_id = tup @@ -1135,7 +1135,14 @@ def _build_sample(self, tup): ) def collate_fn(self, batch): - (_, context_tokens_len, _, question_tokens_len, _, input_ids_len,) = zip(*batch) + ( + _, + context_tokens_len, + _, + question_tokens_len, + _, + input_ids_len, + ) = zip(*batch) decoder_input_len = ( torch.stack(context_tokens_len) + torch.stack(question_tokens_len) + torch.stack(input_ids_len) @@ -1145,7 +1152,10 @@ def collate_fn(self, batch): decoder_mask = get_mask_from_lengths(decoder_input_len - 1) speech_mask = get_mask_from_lengths(decoder_input_len - 1) context_question_mask = torch.ones(speech_mask.shape) - (decoder_input_list, decoder_labels_list,) = ( + ( + decoder_input_list, + decoder_labels_list, + ) = ( [], [], ) @@ -1168,11 +1178,17 @@ def collate_fn(self, batch): complete_input = torch.cat([context_tokens_input, question_tokens, input_ids_shifted], dim=1) complete_input_padded = general_padding( - complete_input, decoder_input_len[i].item(), max_decoder_input_len, pad_value=self.tokenizer.pad_id, + complete_input, + decoder_input_len[i].item(), + max_decoder_input_len, + pad_value=self.tokenizer.pad_id, ) complete_output = torch.cat([context_tokens, question_tokens, input_ids], dim=1) complete_output_padded = general_padding( - complete_output, decoder_input_len[i].item(), max_decoder_input_len, pad_value=self.tokenizer.pad_id, + complete_output, + decoder_input_len[i].item(), + max_decoder_input_len, + pad_value=self.tokenizer.pad_id, ) decoder_labels = complete_output_padded[:, 1:].contiguous() decoder_input = complete_input_padded[:, :-1].contiguous() @@ -1181,9 +1197,9 @@ def collate_fn(self, batch): decoder_labels_list.append(decoder_labels) decoder_mask[i, : context_tokens_len + question_tokens_len - 1] = 0 # Mask out context and question - speech_mask[ - i, context_tokens_len : context_tokens_len + question_tokens_len - ] = 0 # Mask out context and question + speech_mask[i, context_tokens_len : context_tokens_len + question_tokens_len] = ( + 0 # Mask out context and question + ) context_question_mask[i, : context_tokens_len + question_tokens_len] = 0 # Using causal attention mask for whole input diff --git a/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py b/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py index eb917f0d7af3..aedc1b07d92f 100644 --- a/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py +++ b/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py @@ -207,8 +207,7 @@ def init_prompt_encoder(self): ) def freeze_existing_word_embeddings(self): - """Freeze params of existing virtual prompts that should not be tuned further - """ + """Freeze params of existing virtual prompts that should not be tuned further""" # Make sure word embeddings are frozen for params in self.word_embeddings.parameters(): params.requires_grad = False diff --git a/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py index 5e47db21795f..5bd6b8993525 100644 --- a/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py +++ 
b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -import random import json import os +import random import string -from typing import Any, List from functools import partial +from typing import Any, List import editdistance +import imageio import numpy as np import soundfile as sf import torch @@ -50,7 +51,6 @@ from nemo.collections.tts.models.speechllm.megatron_base_speechllm_prompt_model import MegatronBaseSpeechLM from nemo.collections.tts.parts.utils.helpers import plot_alignment_to_numpy_for_speechllm, plot_codec_to_numpy from nemo.utils import AppState, logging -import imageio try: from apex.transformer.pipeline_parallel.utils import get_micro_batch_size, get_num_microbatches @@ -74,23 +74,25 @@ import time + +import librosa import torchaudio from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector -import librosa __all__ = ['MegatronT5SpeechLMModel'] class MegatronT5OverrideModel(MegatronT5Model): def _build_tokenizer(self): - if self._cfg.tokenizer.library == "sentencepiece": + if self._cfg.tokenizer.library == "sentencepiece": if hasattr(self._cfg.tokenizer, "sentencepiece_legacy"): legacy = self._cfg.tokenizer.sentencepiece_legacy else: legacy = True if self._cfg.tokenizer.library == 'sentencepiece' else False self.tokenizer = SentencePieceSpeechLLMTTSTokenizer( - model_path=self.register_artifact("tokenizer.model", self._cfg.tokenizer.get('model', None)), legacy=legacy + model_path=self.register_artifact("tokenizer.model", self._cfg.tokenizer.get('model', None)), + legacy=legacy, ) if self._cfg.tokenizer.get('additional_special_tokens', None) is not None: @@ -194,16 +196,20 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.enc_output_to_layers = cfg.get('enc_output_to_layers', None) if self.enc_output_to_layers is not None: # Convert from listconfig to list - self.enc_output_to_layers = [ [l for l in encoder_layer] for encoder_layer in self.enc_output_to_layers ] - + self.enc_output_to_layers = [[l for l in encoder_layer] for encoder_layer in self.enc_output_to_layers] + self.frozen_model.enc_dec_model.speech_offset = speech_offset self.frozen_model.enc_dec_model.speech_codebook_size = speech_codebook_size self.frozen_model.enc_dec_model.num_speech_codebooks = num_speech_codebooks self.frozen_model.enc_dec_model.seq_pattern = cfg.get('seq_pattern', 'parallel') self.frozen_model.enc_dec_model.attn_prior_scaledown_start_step = attn_prior_scaledown_start_step self.frozen_model.enc_dec_model.attn_prior_end_step = attn_prior_end_step - self.frozen_model.enc_dec_model.alignment_decoder_layerids = cfg.get('alignment_decoder_layerids', list(range(0, 12))) - self.frozen_model.enc_dec_model.return_all_crossattention_probs = cfg.get('return_all_crossattention_probs', False) + self.frozen_model.enc_dec_model.alignment_decoder_layerids = cfg.get( + 'alignment_decoder_layerids', list(range(0, 12)) + ) + self.frozen_model.enc_dec_model.return_all_crossattention_probs = cfg.get( + 'return_all_crossattention_probs', False + ) self.frozen_model.enc_dec_model.num_cross_attention_heads = num_cross_attention_heads self.frozen_model.enc_dec_model.context_conditioning = self.context_conditioning self.frozen_model.enc_dec_model.decoder_context_len = self.decoder_context_len @@ -310,11 +316,15 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): # 
requires to specify a non-matching high-quality and clean reference audio file. It is used to estimate MOS. self.non_matching_ref_audio_filepath = cfg.get('non_matching_ref_audio_filepath', None) if self.non_matching_ref_audio_filepath is None: - raise ValueError(f"Please provide a high-quality reference audio to estimate the MOS. Alternatively, " - f"set `model.estimate_mos=False` to disable MOS estimation.") + raise ValueError( + f"Please provide a high-quality reference audio to estimate the MOS. Alternatively, " + f"set `model.estimate_mos=False` to disable MOS estimation." + ) if not os.path.exists(self.non_matching_ref_audio_filepath): - raise FileNotFoundError(f"Please provide a valid file path for a high-quality reference audio to estimate" - f" the MOS. Alternatively, set `model.estimate_mos=False` to disable MOS estimation.") + raise FileNotFoundError( + f"Please provide a valid file path for a high-quality reference audio to estimate" + f" the MOS. Alternatively, set `model.estimate_mos=False` to disable MOS estimation." + ) def decode_wav_from_codec_model(self, codes): codec_model = self.additional_models['codec'] @@ -372,9 +382,10 @@ def forward( position_ids = [position_ids] cross_attention_prior = [cross_attention_prior] - enc_output = None - logging.debug(f"self.first_stage_of_pipeline()={self.first_stage_of_pipeline()}\tinference_step={inference_step}") + logging.debug( + f"self.first_stage_of_pipeline()={self.first_stage_of_pipeline()}\tinference_step={inference_step}" + ) if self.first_stage_of_pipeline() and inference_step == 0: # Get embeddings for text tokens and insert virtual token embeddings encoder_input_list = [] @@ -544,8 +555,8 @@ def load_frozen_model(self, cfg, trainer): def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): """ - Dataloader produces a global batch which is turned into a list of microbatches. - The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + Dataloader produces a global batch which is turned into a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. 
""" # Get seq length of batch batch = next(dataloader_iter) @@ -630,7 +641,7 @@ def fwd_output_and_loss_func(dataloader_iter, model): _, # TODO: text limit and lang not in tarred dataset _, ) = batch - + if self.trainer.global_step % self.train_check_interval == 0 and not validation_step and self.is_rank_zero: self.frozen_model.enc_dec_model.logging_step = True @@ -683,7 +694,9 @@ def fwd_output_and_loss_func(dataloader_iter, model): logging.info(f"wer score : {score}") self.logger.experiment.add_scalar('WER', score, self.global_step) else: - audio_len = self.decoder_context_len + (labels[0][0][self.decoder_context_len:] != 0).sum().item() + audio_len = ( + self.decoder_context_len + (labels[0][0][self.decoder_context_len :] != 0).sum().item() + ) labels_to_1024 = self.convert_tokens_to_range(labels[0, :, 0:audio_len]) label_wav = self.decode_wav_from_codec_model(labels_to_1024) dec_input_to_1024 = self.convert_tokens_to_range(dec_input[0, :, 0:audio_len]) @@ -698,8 +711,7 @@ def fwd_output_and_loss_func(dataloader_iter, model): context_tokens = context_and_question_tokens[0] question_tokens = context_and_question_tokens[1] input_token_list_all = [ - question_tokens[0, 0, i].item() - for i in range(question_tokens.shape[2]) + question_tokens[0, 0, i].item() for i in range(question_tokens.shape[2]) ] input_token_list = [ (ti, t) @@ -722,7 +734,7 @@ def fwd_output_and_loss_func(dataloader_iter, model): _context_tokens = context_and_question_tokens[0, :, :context_end_step] if context_end_step > 1: - is_speech_context = _context_tokens[1,:].sum().item() > 0 + is_speech_context = _context_tokens[1, :].sum().item() > 0 if is_speech_context: _context_tokens = self.convert_tokens_to_range( _context_tokens, pattern=self.context_pattern @@ -732,11 +744,13 @@ def fwd_output_and_loss_func(dataloader_iter, model): "train_context_wav", _context_wav, self.global_step, self.sample_rate ) else: - _context_token_list = [ v.item() for v in _context_tokens[0, :] ] + _context_token_list = [v.item() for v in _context_tokens[0, :]] _context_text = self.frozen_model.tokenizer.ids_to_text( [v for v in _context_token_list if v < self.lm_vocab_size] ) - self.logger.experiment.add_text("train_context_text", _context_text, self.global_step) + self.logger.experiment.add_text( + "train_context_text", _context_text, self.global_step + ) question_si = text_limits[0, 0].item() - virtual_tokens.shape[1] question_ei = text_limits[0, 1].item() - virtual_tokens.shape[1] @@ -780,10 +794,15 @@ def fwd_output_and_loss_func(dataloader_iter, model): phoneme_seq=None if self.plot_alignments_sliced else [text_si], ) self.logger.experiment.add_image( - name, alignment_image, self.global_step, dataformats="HWC", + name, + alignment_image, + self.global_step, + dataformats="HWC", ) attention_sliced_list.append( - attention_probs[0, _i, self.decoder_context_len:audio_len, text_si:text_ei] + attention_probs[ + 0, _i, self.decoder_context_len : audio_len, text_si:text_ei + ] ) attention_sliced = torch.stack(attention_sliced_list) attention_sliced = torch.mean(attention_sliced, 0) @@ -856,7 +875,7 @@ def loss_func(loss_args): return fwd_output_and_loss_func def get_forward_output_only_func(self): - """ Used in inference / predict """ + """Used in inference / predict""" def fwd_output_only_func(dataloader_iter, model): batch = next(dataloader_iter) @@ -882,7 +901,6 @@ def fwd_output_only_func(dataloader_iter, model): speech_mask, ) = batch - output_logits, _, token_and_speech_logits = model( context_and_question_tokens, 
context_and_question_tokens, @@ -908,15 +926,15 @@ def id_func(output_tensor): return fwd_output_only_func def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. + No need to call it here. """ return def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. - We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. """ return @@ -963,25 +981,46 @@ def training_step(self, dataloader_iter, batch_idx): # if under "Phoneme TTS" instruction, so existing no overlaps between instruction and question token IDs. # question token IDs are bpe token IDs without any offset # if under "Text to speech this" instruction, so existing overlaps between instruction and question token IDs. - context_and_question_tokens = batch[1] # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len) + context_and_question_tokens = batch[ + 1 + ] # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len) text_limits = batch[12] virtual_tokens = batch[0] - question_limits = text_limits - virtual_tokens.size(1) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens. + question_limits = text_limits - virtual_tokens.size( + 1 + ) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens. question_start = question_limits[:, 0].unsqueeze(1) # (b, 1) question_end = question_limits[:, 1].unsqueeze(1) # (b, 1) if isinstance(context_and_question_tokens, list): # indicate self.encoder_type=multi_transformers. context_tokens, question_tokens = context_and_question_tokens question_tokens_unconditioned = question_tokens.clone() - time_range = torch.arange(question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device).unsqueeze(0) - question_mask = (time_range >= question_start) & (time_range < question_end) # create a mask for question only tokens. - question_tokens_unconditioned[:, 0][question_mask] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + time_range = torch.arange( + question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device + ).unsqueeze(0) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. batch[1] = [context_tokens, question_tokens_unconditioned] else: - context_and_question_tokens_unconditioned = context_and_question_tokens.clone() # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len) - time_range = torch.arange(context_and_question_tokens_unconditioned.size(2), device=context_and_question_tokens_unconditioned.device).unsqueeze(0) # (1, max_context_question_tokens_len) - question_mask = (time_range >= question_start) & (time_range < question_end) # create a mask for question only tokens. - context_and_question_tokens_unconditioned[:, 0][question_mask] = self.tokenizer.unk_id # only the first layer has non-zero IDs. 
+ context_and_question_tokens_unconditioned = ( + context_and_question_tokens.clone() + ) # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len) + time_range = torch.arange( + context_and_question_tokens_unconditioned.size(2), + device=context_and_question_tokens_unconditioned.device, + ).unsqueeze( + 0 + ) # (1, max_context_question_tokens_len) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + context_and_question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. batch[1] = context_and_question_tokens_unconditioned del question_limits, question_start, question_end, time_range, question_mask @@ -994,18 +1033,24 @@ def training_step(self, dataloader_iter, batch_idx): if self._rng.random() < self.train_audio_cfg_prob: logging.info(f"Audio Classifier-Free Guidance is triggered for the {batch_idx}-th batch.") - context_and_question_tokens = batch[1] # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len) + context_and_question_tokens = batch[ + 1 + ] # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len) if isinstance(context_and_question_tokens, list): # indicate self.encoder_type=multi_transformers. context_tokens, question_tokens = context_and_question_tokens context_tokens_unconditioned = context_tokens.clone() - context_tokens_unconditioned[:, :, :] = self.tokenizer.unk_id # TODO @xueyang: verify if extra tokens other than audio codec tokens are appended. + context_tokens_unconditioned[:, :, :] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: verify if extra tokens other than audio codec tokens are appended. batch[1] = [context_tokens_unconditioned, question_tokens] else: # dec_input dec_input = batch[3] dec_input_unconditioned = dec_input.clone() - dec_input_unconditioned[:, :, 1:self.decoder_context_len + 1] = self.tokenizer.unk_id # TODO @xueyang: switch to other token id if this one is conflict with text unk. + dec_input_unconditioned[:, :, 1 : self.decoder_context_len + 1] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: switch to other token id if this one is conflict with text unk. batch[3] = dec_input_unconditioned loss_mean = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=False) @@ -1033,9 +1078,9 @@ def get_predictions(self, input_ids, enc_mask, encoder_input, labels): enc_mask=enc_mask, num_tokens_to_generate=self.decoder_seq_length, encoder_input=encoder_input, - bos_id=self.tokenizer.pad_id - if self.cfg.data.get('decoder_starts_with_pad', False) - else self.tokenizer.bos_id, + bos_id=( + self.tokenizer.pad_id if self.cfg.data.get('decoder_starts_with_pad', False) else self.tokenizer.bos_id + ), ) # Special ids to text function to handle stripping and special tokens with sentencepiece tokenizers. 
preds_text = MegatronT5SFTModel.ids_to_text(predicted_token_ids, self.tokenizer) @@ -1091,7 +1136,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): ) = batch # loss_mask (b, t) # does not use dataloader_iter due to device placement issues arising from PTL - + mode = self.training self.eval() gbs = self.cfg.get('validation_global_batch_size', self.cfg.global_batch_size) @@ -1115,11 +1160,11 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): ) labels_original = labels.clone() # (b, 8, t) - + _cross_attention_prior = cross_attention_prior if isinstance(context_and_question_tokens, list): _cross_attention_prior = [None, cross_attention_prior] - + output_loss, _, output_logits = self.forward( virtual_tokens, context_and_question_tokens, @@ -1158,7 +1203,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): logging.info(f"wer score : {score}") self.logger.experiment.add_scalar('WER', score, self.global_step) else: - audio_len = self.decoder_context_len + (labels[0][0][self.decoder_context_len:] != 0).sum().item() + audio_len = self.decoder_context_len + (labels[0][0][self.decoder_context_len :] != 0).sum().item() labels_to_1024 = self.convert_tokens_to_range(labels[0, :, 0:audio_len]) label_wav = self.decode_wav_from_codec_model(labels_to_1024) dec_input_to_1024 = self.convert_tokens_to_range(dec_input[0, :, 0:audio_len]) @@ -1191,9 +1236,11 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): context_end_step = input_token_list[0][0] _context_tokens = context_and_question_tokens[0, :, :context_end_step] if context_end_step > 1: - is_speech_context = _context_tokens[1,:].sum().item() > 0 + is_speech_context = _context_tokens[1, :].sum().item() > 0 if is_speech_context: - _context_tokens = self.convert_tokens_to_range(_context_tokens, pattern=self.context_pattern) + _context_tokens = self.convert_tokens_to_range( + _context_tokens, pattern=self.context_pattern + ) _context_wav = self.decode_wav_from_codec_model(_context_tokens) self.logger.experiment.add_audio( "val_context_wav", _context_wav, self.global_step, self.sample_rate @@ -1235,7 +1282,9 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): for lidx in range(len(attention_probs_list)): attention_probs = attention_probs_list[lidx] for _i in range(attention_probs.shape[1]): - attention_sliced_list.append(attention_probs[0, _i, self.decoder_context_len:audio_len, text_si:text_ei]) + attention_sliced_list.append( + attention_probs[0, _i, self.decoder_context_len : audio_len, text_si:text_ei] + ) attention_sliced = torch.stack(attention_sliced_list) attention_sliced = torch.mean(attention_sliced, 0) text = None @@ -1351,7 +1400,11 @@ def on_validation_epoch_end(self): averaged_first_layer_accuracy = torch.stack([item['first_layer_accuracy'] for item in outputs]).mean() self.log( - 'val_loss_total_check', averaged_loss_total_check, prog_bar=False, rank_zero_only=True, batch_size=1 + 'val_loss_total_check', + averaged_loss_total_check, + prog_bar=False, + rank_zero_only=True, + batch_size=1, ) self.log( 'val_first_layer_accuracy', @@ -1394,7 +1447,11 @@ def on_validation_epoch_end(self): logging.info(f'Validation loss: {averaged_loss}') self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) self.log( - 'val_loss_total_check', averaged_loss_total_check, prog_bar=False, rank_zero_only=True, batch_size=1 + 'val_loss_total_check', + averaged_loss_total_check, + prog_bar=False, + rank_zero_only=True, + batch_size=1, ) averaged_first_layer_accuracy 
= torch.stack([item['first_layer_accuracy'] for item in outputs]).mean() @@ -1446,7 +1503,8 @@ def on_validation_epoch_end(self): gather_results_dedup = list(set(itertools.chain(*gather_results))) val_metric_dict = self.validation_metric.get_score( - [i[2] for i in gather_results_dedup], [i[1] for i in gather_results_dedup], + [i[2] for i in gather_results_dedup], + [i[1] for i in gather_results_dedup], ) for metric, val in val_metric_dict.items(): @@ -1566,9 +1624,9 @@ def build_virtual_prompt_dataset( drop_last=drop_last, num_workers=num_workers, pin_memory=pin_memory, - persistent_workers=True - if num_workers > 0 - else False, # (@adithyare and @eharper) We need to set this to True to get around issues with spawn=True + persistent_workers=( + True if num_workers > 0 else False + ), # (@adithyare and @eharper) We need to set this to True to get around issues with spawn=True ) logging.info(f'build success: {len(dataloader)} {dataset_paths}') if self.phoneme_tokenizer is None: @@ -1623,9 +1681,9 @@ def build_virtual_prompt_tarred_dataset( drop_last=drop_last, num_workers=num_workers, pin_memory=pin_memory, - persistent_workers=True - if num_workers > 0 - else False, # (@adithyare and @eharper) We need to set this to True to get around issues with spawn=True + persistent_workers=( + True if num_workers > 0 else False + ), # (@adithyare and @eharper) We need to set this to True to get around issues with spawn=True ) logging.info(f'build success: {len(dataloader)} {dataset_paths}') @@ -1652,7 +1710,9 @@ def process_text(self, input_text): return single_space_text - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_scalars=True, global_step=None) -> Any: + def predict_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_scalars=True, global_step=None + ) -> Any: with torch.no_grad(): ( @@ -1674,7 +1734,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ ) = batch batch_size = virtual_tokens.size(0) - dec_input = dec_input_raw * 1 # (B, 8, T) # TODO @xueyang: apply clone() method bypasses this unnecessary computation. + dec_input = ( + dec_input_raw * 1 + ) # (B, 8, T) # TODO @xueyang: apply clone() method bypasses this unnecessary computation. dec_input_mask = dec_input_mask_raw * 1 # (B, T) dec_input_mask[:, :] = 1 # Does not really matter output_token_list = [] @@ -1684,14 +1746,16 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ max_inference_timesteps = self.cfg.get('max_inference_timesteps', 2000) # TODO @xueyang: potential bug when max_inference_timesteps < dec_input.shape[2], then dec_input is clipped. dec_input = torch.nn.functional.pad(dec_input, (0, max_inference_timesteps - dec_input.shape[2]), value=0) - dec_input[:, :, self.decoder_context_len + 1:].zero_() + dec_input[:, :, self.decoder_context_len + 1 :].zero_() # TODO @xueyang: why not just declare torch.ones(dec_input_raw.size(0), max_inference_timesteps)? dec_input_mask = torch.nn.functional.pad( dec_input_mask, (0, max_inference_timesteps - dec_input_mask.shape[1]), value=1 ) if self.inference_apply_text_cfg and self.inference_apply_audio_cfg: - question_limits = text_limits - virtual_tokens.size(1) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens. + question_limits = text_limits - virtual_tokens.size( + 1 + ) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens. 
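Illustrative sketch, not part of the patch: the text classifier-free guidance path above derives question_start and question_end from text_limits, blanks the question span of codebook 0 with unk_id in a cloned copy, and stacks the conditioned and unconditioned copies along the batch axis. make_text_cfg_batch is a hypothetical name for that step.

import torch

def make_text_cfg_batch(tokens: torch.Tensor, question_start: torch.Tensor,
                        question_end: torch.Tensor, unk_id: int) -> torch.Tensor:
    # tokens: (B, num_codebooks, T); question_start / question_end: (B, 1) frame indices
    unconditioned = tokens.clone()
    time_range = torch.arange(tokens.size(2), device=tokens.device).unsqueeze(0)  # (1, T)
    question_mask = (time_range >= question_start) & (time_range < question_end)  # (B, T)
    unconditioned[:, 0][question_mask] = unk_id  # only codebook 0 carries text ids
    return torch.cat((tokens, unconditioned), dim=0)  # (2B, num_codebooks, T)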
question_start = question_limits[:, 0].unsqueeze(1) # (b, 1) question_end = question_limits[:, 1].unsqueeze(1) # (b, 1) @@ -1706,34 +1770,58 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ # text question_tokens_unconditioned = question_tokens.clone() - time_range = torch.arange(question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device).unsqueeze(0) - question_mask = (time_range >= question_start) & (time_range < question_end) # create a mask for question only tokens. - question_tokens_unconditioned[:, 0][question_mask] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + time_range = torch.arange( + question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device + ).unsqueeze(0) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. # audio context_tokens_unconditioned = context_tokens.clone() context_tokens_unconditioned[:, :, :] = self.tokenizer.unk_id # concatenate both conditioned and unconditioned batches as a single one. - context_and_question_tokens = [torch.cat((context_tokens, context_tokens_unconditioned), dim=0), torch.cat((question_tokens, question_tokens_unconditioned), dim=0)] + context_and_question_tokens = [ + torch.cat((context_tokens, context_tokens_unconditioned), dim=0), + torch.cat((question_tokens, question_tokens_unconditioned), dim=0), + ] enc_mask = [torch.cat((mask, mask), dim=0) for mask in enc_mask] dec_input = torch.cat((dec_input, dec_input), dim=0) position_ids = [torch.cat((pos_ids, pos_ids), dim=0) for pos_ids in position_ids] else: - assert self.context_conditioning == "decoder", f"The encoder_type is single_transformer. We expect context_condition is decoder: context_condition={self.context_conditioning}" + assert ( + self.context_conditioning == "decoder" + ), f"The encoder_type is single_transformer. We expect context_condition is decoder: context_condition={self.context_conditioning}" # text context_and_question_tokens_unconditioned = context_and_question_tokens.clone() - time_range = torch.arange(context_and_question_tokens_unconditioned.size(2), device=context_and_question_tokens_unconditioned.device).unsqueeze(0) # (1, max_context_question_tokens_len) - question_mask = (time_range >= question_start) & (time_range < question_end) # create a mask for question only tokens. - context_and_question_tokens_unconditioned[:, 0][question_mask] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + time_range = torch.arange( + context_and_question_tokens_unconditioned.size(2), + device=context_and_question_tokens_unconditioned.device, + ).unsqueeze( + 0 + ) # (1, max_context_question_tokens_len) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + context_and_question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. # audio dec_input_unconditioned = dec_input.clone() - dec_input_unconditioned[:, :, 1:self.decoder_context_len + 1] = self.tokenizer.unk_id # TODO @xueyang: switch to other token id if this one is conflict with text unk. + dec_input_unconditioned[:, :, 1 : self.decoder_context_len + 1] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: switch to other token id if this one is conflict with text unk. 
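Assumption, not shown in this hunk: after the batch is doubled for classifier-free guidance (here the first decoder_context_len decoder frames are overwritten with unk_id for the audio condition), the conditioned and unconditioned halves of the resulting logits are typically recombined with a guidance scale. The sketch below shows the common formulation; apply_cfg and guidance_scale are hypothetical names.

import torch

def apply_cfg(logits: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # logits: (2B, ...) with the conditioned half first and the unconditioned half second
    cond, uncond = logits.chunk(2, dim=0)
    return uncond + guidance_scale * (cond - uncond)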
# concatenate both conditioned and unconditioned batches as a single one. - context_and_question_tokens = torch.cat((context_and_question_tokens, context_and_question_tokens_unconditioned), dim=0) + context_and_question_tokens = torch.cat( + (context_and_question_tokens, context_and_question_tokens_unconditioned), dim=0 + ) enc_mask = torch.cat((enc_mask, enc_mask), dim=0) dec_input = torch.cat((dec_input, dec_input_unconditioned), dim=0) position_ids = torch.cat((position_ids, position_ids), dim=0) @@ -1747,7 +1835,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ # if under "Phoneme TTS" instruction, so exising no overlaps between instruction and question token IDs. # question token IDs are bpe token IDs without any offset # if under "Text to speech this" instruction, so existing overlaps between instruction and question token IDs. - question_limits = text_limits - virtual_tokens.size(1) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens. + question_limits = text_limits - virtual_tokens.size( + 1 + ) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens. question_start = question_limits[:, 0].unsqueeze(1) # (b, 1) question_end = question_limits[:, 1].unsqueeze(1) # (b, 1) @@ -1761,24 +1851,46 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ context_tokens, question_tokens = context_and_question_tokens question_tokens_unconditioned = question_tokens.clone() - time_range = torch.arange(question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device).unsqueeze(0) - question_mask = (time_range >= question_start) & (time_range < question_end) # create a mask for question only tokens. - question_tokens_unconditioned[:, 0][question_mask] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + time_range = torch.arange( + question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device + ).unsqueeze(0) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. # concatenate both conditioned and unconditioned batches as a single one. - context_and_question_tokens = [torch.cat((context_tokens, context_tokens), dim=0), torch.cat((question_tokens, question_tokens_unconditioned), dim=0)] + context_and_question_tokens = [ + torch.cat((context_tokens, context_tokens), dim=0), + torch.cat((question_tokens, question_tokens_unconditioned), dim=0), + ] enc_mask = [torch.cat((mask, mask), dim=0) for mask in enc_mask] dec_input = torch.cat((dec_input, dec_input), dim=0) position_ids = [torch.cat((pos_ids, pos_ids), dim=0) for pos_ids in position_ids] else: - assert self.context_conditioning == "decoder", f"The encoder_type is single_transformer. We expect context_condition is decoder: context_condition={self.context_conditioning}" + assert ( + self.context_conditioning == "decoder" + ), f"The encoder_type is single_transformer. 
We expect context_condition is decoder: context_condition={self.context_conditioning}" context_and_question_tokens_unconditioned = context_and_question_tokens.clone() - time_range = torch.arange(context_and_question_tokens_unconditioned.size(2), device=context_and_question_tokens_unconditioned.device).unsqueeze(0) # (1, max_context_question_tokens_len) - question_mask = (time_range >= question_start) & (time_range < question_end) # create a mask for question only tokens. - context_and_question_tokens_unconditioned[:, 0][question_mask] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + time_range = torch.arange( + context_and_question_tokens_unconditioned.size(2), + device=context_and_question_tokens_unconditioned.device, + ).unsqueeze( + 0 + ) # (1, max_context_question_tokens_len) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + context_and_question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. # concatenate both conditioned and unconditioned batches as a single one. - context_and_question_tokens = torch.cat((context_and_question_tokens, context_and_question_tokens_unconditioned), dim=0) + context_and_question_tokens = torch.cat( + (context_and_question_tokens, context_and_question_tokens_unconditioned), dim=0 + ) enc_mask = torch.cat((enc_mask, enc_mask), dim=0) dec_input = torch.cat((dec_input, dec_input), dim=0) position_ids = torch.cat((position_ids, position_ids), dim=0) @@ -1792,23 +1904,36 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ speech_mask = torch.cat((speech_mask, speech_mask), dim=0) dec_input_mask = torch.cat((dec_input_mask, dec_input_mask), dim=0) - if isinstance(context_and_question_tokens, list): # indicate that self.encoder_type = "multi_transformers" + if isinstance( + context_and_question_tokens, list + ): # indicate that self.encoder_type = "multi_transformers" context_tokens, question_tokens = context_and_question_tokens context_tokens_unconditioned = context_tokens.clone() - context_tokens_unconditioned[:, :, :] = self.tokenizer.unk_id # TODO @xueyang: verify if extra tokens other than audio codec tokens are appended. + context_tokens_unconditioned[:, :, :] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: verify if extra tokens other than audio codec tokens are appended. # concatenate both conditioned and unconditioned batches as a single one. - context_and_question_tokens = [torch.cat((context_tokens, context_tokens_unconditioned), dim=0), torch.cat((question_tokens, question_tokens), dim=0)] + context_and_question_tokens = [ + torch.cat((context_tokens, context_tokens_unconditioned), dim=0), + torch.cat((question_tokens, question_tokens), dim=0), + ] enc_mask = [torch.cat((mask, mask), dim=0) for mask in enc_mask] dec_input = torch.cat((dec_input, dec_input), dim=0) position_ids = [torch.cat((pos_ids, pos_ids), dim=0) for pos_ids in position_ids] else: - assert self.context_conditioning == "decoder", f"The encoder_type is single_transformer. We expect context_condition is decoder: context_condition={self.context_conditioning}" + assert ( + self.context_conditioning == "decoder" + ), f"The encoder_type is single_transformer. 
We expect context_condition is decoder: context_condition={self.context_conditioning}" dec_input_unconditioned = dec_input.clone() - dec_input_unconditioned[:, :, 1:self.decoder_context_len + 1] = self.tokenizer.unk_id # TODO @xueyang: switch to other token id if this one is conflict with text unk. + dec_input_unconditioned[:, :, 1 : self.decoder_context_len + 1] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: switch to other token id if this one is conflict with text unk. # concatenate both conditioned and unconditioned batches as a single one. - context_and_question_tokens = torch.cat((context_and_question_tokens, context_and_question_tokens), dim=0) + context_and_question_tokens = torch.cat( + (context_and_question_tokens, context_and_question_tokens), dim=0 + ) enc_mask = torch.cat((enc_mask, enc_mask), dim=0) dec_input = torch.cat((dec_input, dec_input_unconditioned), dim=0) position_ids = torch.cat((position_ids, position_ids), dim=0) @@ -1831,7 +1956,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ if t == end_inference_loop_at: print("All ends detected") break - + if isinstance(enc_mask, list): encoder_max_sequence_len = [e.size(1) for e in enc_mask] else: @@ -1845,7 +1970,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ virtual_tokens, context_and_question_tokens, enc_mask, - dec_input[:, :, : t + 1], # tensors representing [CLS] + context audio tokens + [PAD] if context_condition is decoder, otherwise, tensors representing [CLS]. + dec_input[ + :, :, : t + 1 + ], # tensors representing [CLS] + context audio tokens + [PAD] if context_condition is decoder, otherwise, tensors representing [CLS]. dec_input_mask[:, : t + 1], # doesn't matter because of all ones. position_ids, taskname_ids, @@ -1854,15 +1981,15 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ inference=True, inference_step=0, decoder_max_sequence_len=max_inference_timesteps, - encoder_max_sequence_len=encoder_max_sequence_len + encoder_max_sequence_len=encoder_max_sequence_len, ) encoder_output = token_and_speech_logits[-1] - + if isinstance(encoder_output, list): encoder_output = [e.transpose(0, 1) for e in encoder_output] else: encoder_output = encoder_output.transpose(0, 1) - + else: # Prepare batch batch = [ @@ -1876,10 +2003,14 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ taskname_ids, speech_mask, ] - + output_tensor = fwd_bwd_function( forward_step_func=self.get_forward_output_only_func(), - data_iterator=iter([batch,]), + data_iterator=iter( + [ + batch, + ] + ), model=[self], num_microbatches=get_num_microbatches(), forward_only=True, @@ -1892,13 +2023,19 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ # when return_all_crossattention is False, attention_probs is None. 
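The inference path above grows the decoder input one frame at a time: at step t it feeds `dec_input[:, :, : t + 1]`, reads the logits for the last position, and writes the newly sampled tokens into position t + 1. The toy skeleton below shows only that loop structure, with a dummy `step_fn` standing in for the frozen T5 encoder-decoder and greedy selection standing in for the model's sampling.

    import torch

    batch_size, num_codebooks, vocab_size, max_steps = 2, 8, 32, 6

    def step_fn(dec_input_prefix):
        # Stand-in for the frozen encoder-decoder: returns per-codebook logits
        # for every position in the prefix; only the last one is used.
        t = dec_input_prefix.size(2)
        return torch.randn(batch_size, t, num_codebooks, vocab_size)

    dec_input = torch.zeros(batch_size, num_codebooks, max_steps, dtype=torch.long)
    for t in range(max_steps - 1):
        logits = step_fn(dec_input[:, :, : t + 1])     # (B, t+1, 8, V)
        last_step_logits = logits[:, -1]               # (B, 8, V)
        next_tokens = last_step_logits.argmax(dim=-1)  # greedy here; the model samples instead
        dec_input[:, :, t + 1] = next_tokens           # feed the new frame back in
    print(dec_input)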
if self.frozen_model.enc_dec_model.return_all_crossattention_probs: attention_probs = token_and_speech_logits[2] - attention_probs_mean = torch.stack(attention_probs).mean(dim=0) # B, 12, 1, enc_timesteps + attention_probs_mean = torch.stack(attention_probs).mean(dim=0) # B, 12, 1, enc_timesteps attention_probs_all.append(attention_probs_mean) if self.inference_apply_text_cfg or self.inference_apply_audio_cfg: # interpolate conditioned and unconditioned logits - token_logits = self.inference_cfg_interpolation_scale * token_and_speech_logits[0][:batch_size] + (1 - self.inference_cfg_interpolation_scale) * token_and_speech_logits[0][batch_size:] - output_speech_logits = self.inference_cfg_interpolation_scale * output_logits[:batch_size] + (1 - self.inference_cfg_interpolation_scale) * output_logits[batch_size:] + token_logits = ( + self.inference_cfg_interpolation_scale * token_and_speech_logits[0][:batch_size] + + (1 - self.inference_cfg_interpolation_scale) * token_and_speech_logits[0][batch_size:] + ) + output_speech_logits = ( + self.inference_cfg_interpolation_scale * output_logits[:batch_size] + + (1 - self.inference_cfg_interpolation_scale) * output_logits[batch_size:] + ) else: token_logits = token_and_speech_logits[0] # (B, T, V) output_speech_logits = output_logits @@ -1908,13 +2045,22 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ if torch.count_nonzero(speech_mask) > 0: output_logits_currtimestep = ( - output_speech_logits[:, -1, :, :].permute(0, 2, 1).contiguous().view(-1, self.speech_codebook_size) + output_speech_logits[:, -1, :, :] + .permute(0, 2, 1) + .contiguous() + .view(-1, self.speech_codebook_size) ) # (B*8, V) output_logits_currtimestep_conditioned = ( - output_logits[:batch_size][:, -1, :, :].permute(0, 2, 1).contiguous().view(-1, self.speech_codebook_size) + output_logits[:batch_size][:, -1, :, :] + .permute(0, 2, 1) + .contiguous() + .view(-1, self.speech_codebook_size) ) output_logits_currtimestep_unconditioned = ( - output_logits[batch_size:][:, -1, :, :].permute(0, 2, 1).contiguous().view(-1, self.speech_codebook_size) + output_logits[batch_size:][:, -1, :, :] + .permute(0, 2, 1) + .contiguous() + .view(-1, self.speech_codebook_size) ) else: output_logits_currtimestep = token_logits_currtimestep # (B, V) @@ -1938,8 +2084,14 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ output_logits_currtimestep_rescored[indices_to_remove] = -float('Inf') # logits interpolation between conditioned and unconditioned logits. 
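With guidance enabled, the conditioned and unconditioned halves of the doubled batch are blended with a single interpolation scale, `scale * conditioned + (1 - scale) * unconditioned`. A small sketch of that blend follows; the scale value is illustrative, not a recommended setting for this model.

    import torch

    batch_size, vocab_size = 2, 32
    cfg_scale = 1.5  # values > 1 extrapolate toward the conditioned branch

    # Doubled batch: first half conditioned, second half unconditioned.
    logits_2b = torch.randn(2 * batch_size, vocab_size)
    conditioned = logits_2b[:batch_size]
    unconditioned = logits_2b[batch_size:]

    guided_logits = cfg_scale * conditioned + (1 - cfg_scale) * unconditioned  # (B, V)
    print(guided_logits.shape)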
- if (self.inference_apply_text_cfg or self.inference_apply_audio_cfg) and self.inference_apply_cfg_filter: - output_logits_currtimestep_rescored = self.inference_cfg_filter_interpolation_scale * output_logits_currtimestep_rescored + (1 - self.inference_cfg_filter_interpolation_scale) * output_logits_currtimestep_unconditioned + if ( + self.inference_apply_text_cfg or self.inference_apply_audio_cfg + ) and self.inference_apply_cfg_filter: + output_logits_currtimestep_rescored = ( + self.inference_cfg_filter_interpolation_scale * output_logits_currtimestep_rescored + + (1 - self.inference_cfg_filter_interpolation_scale) + * output_logits_currtimestep_unconditioned + ) temperature = self.cfg.get('temperature', 0.85) # Set temp 0.01 for greedy decoding output_logits_currtimestep_rescored = output_logits_currtimestep_rescored / temperature @@ -1969,7 +2121,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ # duplicate to 2b dim as input for the next iteration if enabling cfg. if self.inference_apply_text_cfg or self.inference_apply_audio_cfg: - output_tokens_curr_timestep = torch.cat((output_tokens_curr_timestep, output_tokens_curr_timestep), dim=0) + output_tokens_curr_timestep = torch.cat( + (output_tokens_curr_timestep, output_tokens_curr_timestep), dim=0 + ) if torch.count_nonzero(speech_mask) > 0: dec_input_next_timestep = output_tokens_curr_timestep * 1 # (B,8) @@ -2004,7 +2158,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ if 'nemo_sv_model' not in self.additional_models: nemo_sv_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') nemo_sv_model = nemo_sv_model.to(device) - nemo_sv_model.encoder.disable_torch_distributed = True # For multi-gpu training validation + nemo_sv_model.encoder.disable_torch_distributed = True # For multi-gpu training validation nemo_sv_model.eval() self.additional_models['nemo_sv_model'] = nemo_sv_model logging.info(f"Loaded SV Model: {nemo_sv_model}") @@ -2020,7 +2174,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ model = nemo_asr.models.EncDecRNNTBPEModel asr_model = model.from_pretrained(model_name=asr_model) asr_model = asr_model.to(device) - asr_model.encoder.disable_torch_distributed = True # For multi-gpu training validation + asr_model.encoder.disable_torch_distributed = True # For multi-gpu training validation asr_model.eval() self.additional_models['asr_model'] = asr_model logging.info(f"Loaded ASR Model: {asr_model}") @@ -2038,7 +2192,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ self.additional_models['asr_model_zh'] = asr_model_zh else: asr_model_zh = self.additional_models['asr_model_zh'] - + if 'wavlm_sv_model' not in self.additional_models: wavlm_sv_extractor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-base-plus-sv') wavlm_sv_model = WavLMForXVector.from_pretrained('microsoft/wavlm-base-plus-sv') @@ -2065,8 +2219,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ # prepare MOS estimator by taking a single audio example as an input. squim_mos_model = partial( - squim_mos_model_full, - reference=torch.from_numpy(ref_16khz_wav).to(device).unsqueeze(0) + squim_mos_model_full, reference=torch.from_numpy(ref_16khz_wav).to(device).unsqueeze(0) ) _exp_dir_path = self.logger.log_dir @@ -2095,16 +2248,17 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ # logging attention maps. 
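Sampling in the hunk above is plain top-k with a temperature: logits outside the top k are set to -inf, the survivors are divided by the temperature, and one token per codebook row is drawn from the resulting softmax. A compact sketch of that step is below; k and the temperature match the defaults mentioned in the diff but the tensor sizes are made up.

    import torch

    logits = torch.randn(16, 1024)   # (B * num_codebooks, V), as in the flattened view
    top_k, temperature = 80, 0.85

    # Keep only the top-k logits per row; everything else cannot be sampled.
    kth_value = torch.topk(logits, top_k, dim=-1).values[:, -1, None]
    filtered = logits.masked_fill(logits < kth_value, float('-inf'))

    # Lower temperature sharpens the distribution (near 0.01 is effectively greedy).
    probs = torch.softmax(filtered / temperature, dim=-1)
    samples = torch.multinomial(probs, num_samples=1).squeeze(-1)  # one token per row
    print(samples.shape)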
# empty attention_probs_all indicates self.frozen_model.enc_dec_model.return_all_crossattention_probs is False. if len(attention_probs_all) != 0: - attention_probs_all = torch.cat(attention_probs_all, dim=2) # B, 12, dec_timesteps, enc_timesteps - attention_probs_all = attention_probs_all.mean(dim=1) # B, dec_timesteps, enc_timesteps + attention_probs_all = torch.cat(attention_probs_all, dim=2) # B, 12, dec_timesteps, enc_timesteps + attention_probs_all = attention_probs_all.mean(dim=1) # B, dec_timesteps, enc_timesteps for i in range(batch_size): text_end_step = text_limits[i, 1].item() text_start_step = text_limits[i, 0].item() end_index = end_indices.get(i, output_tokens_combined.shape[2]) if len(attention_probs_all) != 0: - attention_probs_example = attention_probs_all[i][:end_index - (1 + self.decoder_context_len), - text_start_step:text_end_step] # T, enc_timesteps + attention_probs_example = attention_probs_all[i][ + : end_index - (1 + self.decoder_context_len), text_start_step:text_end_step + ] # T, enc_timesteps attention_map = attention_probs_example.float().cpu().numpy().T alignment_image = plot_alignment_to_numpy_for_speechllm( attention_map, @@ -2126,7 +2280,10 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ # print("Ctc Loss: ", step, ctc_loss.item()) self.logger.experiment.add_image( - "Inf Attention Map", alignment_image, step, dataformats="HWC", + "Inf Attention Map", + alignment_image, + step, + dataformats="HWC", ) # Save attention image to file alignment_fp = os.path.join(_exp_dir_path, f'attention_map_{step}.png') @@ -2144,11 +2301,11 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ # During inference, step is the index of the sample step = batch_idx * test_dataloader_batch_size + i - audio_len = self.decoder_context_len + (labels[i][0][self.decoder_context_len:] != 0).sum().item() + audio_len = self.decoder_context_len + (labels[i][0][self.decoder_context_len :] != 0).sum().item() if torch.count_nonzero(speech_mask) > 0: dec_input_to_1024 = self.convert_tokens_to_range(dec_input_raw[i, :, 0:audio_len]) - dec_input_to_1024_answer = dec_input_to_1024[:,self.decoder_context_len+1:] + dec_input_to_1024_answer = dec_input_to_1024[:, self.decoder_context_len + 1 :] dec_input_wav = self.decode_wav_from_codec_model(dec_input_to_1024_answer) self.logger.experiment.add_audio("Inf Dec Input Wav", dec_input_wav, step, self.sample_rate) @@ -2156,9 +2313,13 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ if i in end_indices: logging.info(f"Clipping until end index for audio {i}") if self.cfg.get('seq_pattern', 'parallel') == 'delay_parallel': - predicted_tokens = predicted_tokens[:, 0 : end_indices[i] - (1 + self.decoder_context_len) + self.num_speech_codebooks] # trim to audio length + predicted_tokens = predicted_tokens[ + :, 0 : end_indices[i] - (1 + self.decoder_context_len) + self.num_speech_codebooks + ] # trim to audio length else: - predicted_tokens = predicted_tokens[:, 0 : end_indices[i] - (1 + self.decoder_context_len)] # trim to audio length + predicted_tokens = predicted_tokens[ + :, 0 : end_indices[i] - (1 + self.decoder_context_len) + ] # trim to audio length pred_img = predicted_tokens.data.cpu().float().numpy() dec_inp_img = dec_input_to_1024.data.cpu().float().numpy() @@ -2171,10 +2332,16 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ self.logger.experiment.add_audio("Inf Pred Wav", predicted_wav, step, self.sample_rate) 
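The alignment images logged above come from averaging the stacked cross-attention maps over the layer/head dimension and then cropping to the generated-audio rows and the question-text columns before plotting. A toy version of that reduction, with made-up shapes:

    import torch

    batch, heads, dec_steps, enc_steps = 2, 12, 50, 40
    attention_probs_all = torch.rand(batch, heads, dec_steps, enc_steps)

    # Average over the attention maps to get one (dec, enc) matrix per sample.
    attention_mean = attention_probs_all.mean(dim=1)  # (B, dec, enc)

    # Crop to the frames that were actually generated and to the text span,
    # then transpose so text runs along one axis for plotting.
    end_index, context_len = 30, 4
    text_start, text_end = 5, 25
    example = attention_mean[0][: end_index - (1 + context_len), text_start:text_end]
    alignment = example.T.cpu().numpy()  # (text_len, audio_frames)
    print(alignment.shape)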
self.logger.experiment.add_image( - "Inf Pred Tokens", plot_codec_to_numpy(pred_img), step, dataformats="HWC", + "Inf Pred Tokens", + plot_codec_to_numpy(pred_img), + step, + dataformats="HWC", ) self.logger.experiment.add_image( - "Inf Dec Input Tokens", plot_codec_to_numpy(dec_inp_img), step, dataformats="HWC", + "Inf Dec Input Tokens", + plot_codec_to_numpy(dec_inp_img), + step, + dataformats="HWC", ) # save predicted_wav and gt_wav to a wav files in dir_path @@ -2205,7 +2372,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ # speaker verification evaluation using wavlm model gt_16khz_wav, _ = librosa.load(audio_fp_gt, sr=16000) pred_16khz_wav, _ = librosa.load(audio_fp_pred, sr=16000) - inputs_wavlm = wavlm_sv_extractor([pred_16khz_wav, gt_16khz_wav], padding=True, return_tensors="pt", sampling_rate=16000) + inputs_wavlm = wavlm_sv_extractor( + [pred_16khz_wav, gt_16khz_wav], padding=True, return_tensors="pt", sampling_rate=16000 + ) for key in inputs_wavlm.keys(): inputs_wavlm[key] = inputs_wavlm[key].to(device) @@ -2252,21 +2421,21 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ spk_embedding_context = spk_embedding_gt spk_embedding_context_wavlm = spk_embedding_gt_wavlm if self.decoder_context_len > 0: - context_tokens = dec_input_to_1024[:, :self.decoder_context_len+1] + context_tokens = dec_input_to_1024[:, : self.decoder_context_len + 1] context_wav = self.decode_wav_from_codec_model(context_tokens) elif context_end_step > 1: - is_speech_context = context_tokens[1,:].sum().item() > 0 + is_speech_context = context_tokens[1, :].sum().item() > 0 if is_speech_context: context_tokens = self.convert_tokens_to_range(context_tokens, pattern=self.context_pattern) context_wav = self.decode_wav_from_codec_model(context_tokens) else: context_wav = None - _context_token_list = [ v.item() for v in context_tokens[0, :] ] + _context_token_list = [v.item() for v in context_tokens[0, :]] _context_text = self.frozen_model.tokenizer.ids_to_text( [v for v in _context_token_list if v < self.lm_vocab_size] ) self.logger.experiment.add_text("Context Text", _context_text, self.global_step) - + else: context_wav = None # raise NotImplementedError("During prediction, there was no context found.") @@ -2279,10 +2448,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ spk_embedding_context = spk_embedding_context.cpu().detach().numpy().flatten() # wavlm context_wavlm_wav, _ = librosa.load(context_wav_fp, sr=16000) - inputs_wavlm = wavlm_sv_extractor([context_wavlm_wav], padding=True, return_tensors="pt", sampling_rate=16000) + inputs_wavlm = wavlm_sv_extractor( + [context_wavlm_wav], padding=True, return_tensors="pt", sampling_rate=16000 + ) for key in inputs_wavlm.keys(): inputs_wavlm[key] = inputs_wavlm[key].to(device) - + with torch.no_grad(): wavlm_embeddings = wavlm_sv_model(**inputs_wavlm).embeddings wavlm_embeddings = torch.nn.functional.normalize(wavlm_embeddings, dim=-1).cpu() @@ -2336,8 +2507,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ # estimate MOS scores. 
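The speaker-similarity numbers in this section come from embedding both waveforms with a speaker-verification model, L2-normalizing the embeddings, and taking their cosine similarity. The sketch below shows only the similarity math with random stand-in embeddings; it does not call WavLM or TitaNet, and the embedding size is assumed.

    import torch
    import torch.nn.functional as F

    # Stand-ins for speaker embeddings of the predicted and reference audio.
    emb_pred = torch.randn(1, 512)
    emb_ref = torch.randn(1, 512)

    # L2-normalize so the dot product is the cosine similarity in [-1, 1].
    emb_pred = F.normalize(emb_pred, dim=-1)
    emb_ref = F.normalize(emb_ref, dim=-1)
    similarity = (emb_pred * emb_ref).sum(dim=-1)
    print(similarity.item())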
if self.estimate_mos: - squim_mos_score_pred = squim_mos_model(torch.from_numpy(pred_16khz_wav).to(device).unsqueeze(0)).item() - squim_mos_score_gt = squim_mos_model(torch.from_numpy(gt_16khz_wav).to(device).unsqueeze(0)).item() + squim_mos_score_pred = squim_mos_model( + torch.from_numpy(pred_16khz_wav).to(device).unsqueeze(0) + ).item() + squim_mos_score_gt = squim_mos_model( + torch.from_numpy(gt_16khz_wav).to(device).unsqueeze(0) + ).item() if context_wav is not None: squim_mos_score_context = squim_mos_model(context_wav.to(device).unsqueeze(0)).item() squim_mos_list_context.append(squim_mos_score_context) @@ -2396,15 +2571,19 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ assert all_audio_to_pred[i]["step"] == all_audio_to_pred[i + 1]["step"] # step = batch_idx * self.test_dataloader().batch_size + all_audio_to_pred[i]["step"] step = batch_idx * test_dataloader_batch_size + all_audio_to_pred[i]["step"] - question_text = question_texts[i//2] + question_text = question_texts[i // 2] # No need to process text since both are ASR outputs cer_sample = word_error_rate([greedy_transcripts[i]], [greedy_transcripts[i + 1]], use_cer=True) wer_sample = word_error_rate([greedy_transcripts[i]], [greedy_transcripts[i + 1]], use_cer=False) # Processing text since one is ASR output and the other is the GT text - cer_gt = word_error_rate([self.process_text(greedy_transcripts[i])], [self.process_text(question_text)], use_cer=True) - wer_gt = word_error_rate([self.process_text(greedy_transcripts[i])], [self.process_text(question_text)], use_cer=False) + cer_gt = word_error_rate( + [self.process_text(greedy_transcripts[i])], [self.process_text(question_text)], use_cer=True + ) + wer_gt = word_error_rate( + [self.process_text(greedy_transcripts[i])], [self.process_text(question_text)], use_cer=False + ) self.logger.experiment.add_text("Inf Predicted Text", greedy_transcripts[i], step) self.logger.experiment.add_text("Inf GT Text", greedy_transcripts[i + 1], step) @@ -2473,7 +2652,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_ } ) - #TODO @xueyang: PTL 2.0+ patch. Signature of method `on_predict_epoch_end` does not match signature of the base method in PTL class 'ModelHooks'. + # TODO @xueyang: PTL 2.0+ patch. Signature of method `on_predict_epoch_end` does not match signature of the base method in PTL class 'ModelHooks'. # Remove the `outputs` param and choose `self.predict_step_output` instead. 
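The intelligibility metrics pair NeMo's `word_error_rate` helper with the light text normalization done in `process_text`: CER and WER between the ASR transcript of the generated audio and either a second transcript or the prompt text. A short usage sketch follows; the transcripts are invented, the `normalize` helper is a rough analogue of `process_text`, and the import path assumes the usual `nemo.collections.asr.metrics.wer` location.

    import re

    from nemo.collections.asr.metrics.wer import word_error_rate

    # Hypothetical ASR output for the generated audio vs. the prompt text.
    hypothesis = "as i closed my laptop for the night"
    reference = "As I closed my laptop for the night."

    def normalize(text):
        # Rough analogue of process_text: lowercase and strip punctuation.
        return re.sub(r"[^a-z\s]", "", text.lower()).strip()

    cer = word_error_rate([normalize(hypothesis)], [normalize(reference)], use_cer=True)
    wer = word_error_rate([normalize(hypothesis)], [normalize(reference)], use_cer=False)
    print(f"CER={cer:.3f} WER={wer:.3f}")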
def on_predict_epoch_end(self, outputs: List[Any]) -> None: @@ -2498,7 +2677,7 @@ def on_predict_epoch_end(self, outputs: List[Any]) -> None: input_prediction_pair = [] correct = 0 - for (input, pred, label) in gather_results_dedup: + for input, pred, label in gather_results_dedup: input_prediction_pair.append((input, pred)) if label: if pred == label: diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py index f25d5da9cbb9..88c7204070cd 100644 --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -204,7 +204,7 @@ def __init__( stride: int = 1, dilation: int = 1, padding: Optional[int] = None, - activation: Optional[str] = None + activation: Optional[str] = None, ): super().__init__() if not padding: @@ -441,6 +441,7 @@ def forward(self, audio_real, audio_gen): return scores_real, scores_gen, fmaps_real, fmaps_gen + class SSLModel(NeuralModule): def __init__(self, slm_model_name): super().__init__() @@ -454,15 +455,17 @@ class SLMDiscriminator(NeuralModule): """SLM Discriminator as in StyleTTS2 paper. Adapted from https://github.com/yl4579/StyleTTS2/blob/5cedc71c333f8d8b8551ca59378bdcc7af4c9529/losses.py#L193""" - def __init__(self, - slm_model_name="microsoft/wavlm-base-plus", - slm_sr=16000, - input_sr=22050, - slm_hidden=768, - slm_layers=13, - initial_channel=64, - use_spectral_norm=False, - lrelu_slope=0.1): + def __init__( + self, + slm_model_name="microsoft/wavlm-base-plus", + slm_sr=16000, + input_sr=22050, + slm_hidden=768, + slm_layers=13, + initial_channel=64, + use_spectral_norm=False, + lrelu_slope=0.1, + ): super().__init__() self.lrelu_slope = lrelu_slope @@ -479,11 +482,13 @@ def __init__(self, norm_f = nn.utils.weight_norm if use_spectral_norm == False else nn.utils.spectral_norm self.pre = norm_f(nn.Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)) - self.convs = nn.ModuleList([ - norm_f(nn.Conv1d(initial_channel, initial_channel * 2, kernel_size=5, padding=2)), - norm_f(nn.Conv1d(initial_channel * 2, initial_channel * 4, kernel_size=5, padding=2)), - norm_f(nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)), - ]) + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv1d(initial_channel, initial_channel * 2, kernel_size=5, padding=2)), + norm_f(nn.Conv1d(initial_channel * 2, initial_channel * 4, kernel_size=5, padding=2)), + norm_f(nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)), + ] + ) self.conv_post = norm_f(nn.Conv1d(initial_channel * 4, 1, 3, 1, padding=1)) @@ -1403,9 +1408,7 @@ def __init__( self.down_sample_conv = None self.down_sample_activation = None - self.input_conv = Conv1dNorm( - in_channels=channels, out_channels=filters, kernel_size=kernel_size - ) + self.input_conv = Conv1dNorm(in_channels=channels, out_channels=filters, kernel_size=kernel_size) self.skip_activation = CodecActivation(activation=activation, channels=filters) self.skip_conv = Conv1dNorm(in_channels=filters, out_channels=channels, kernel_size=kernel_size) self.output_activation = CodecActivation(activation=activation, channels=channels) @@ -1416,16 +1419,13 @@ def remove_weight_norm(self): @property def input_types(self): - return { - "inputs": NeuralType(('B', 'C', 'T'), VoidType()), - "input_len": NeuralType(tuple('B'), LengthsType()) - } + return {"inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType())} @property def output_types(self): return { "out": 
NeuralType(('B', 'C', 'T'), EncodedRepresentation()), - "out_len": NeuralType(tuple('B'), LengthsType()) + "out_len": NeuralType(tuple('B'), LengthsType()), } @typecheck() @@ -1858,7 +1858,7 @@ def forward(self, audio, audio_len): win_length=self.win_length, window=self.window, return_complex=True, - center=False + center=False, ) fft_mag = torch.abs(fft) fft_mag_log = torch.log(fft_mag + self.log_guard) @@ -1940,37 +1940,21 @@ def forward(self, inputs, input_len): class ResNetEncoderV2(NeuralModule): def __init__( - self, - in_channels, - out_channels, - num_layers, - hidden_channels, - filters, - kernel_size=3, - activation="lrelu" + self, in_channels, out_channels, num_layers, hidden_channels, filters, kernel_size=3, activation="lrelu" ): super(ResNetEncoderV2, self).__init__() - self.pre_conv = Conv1dNorm( - in_channels=in_channels, - out_channels=hidden_channels, - kernel_size=kernel_size - ) + self.pre_conv = Conv1dNorm(in_channels=in_channels, out_channels=hidden_channels, kernel_size=kernel_size) self.pre_act = CodecActivation(activation, channels=hidden_channels) - self.res_blocks = nn.ModuleList([ - ResidualBlockV2( - channels=hidden_channels, - filters=filters, - kernel_size=kernel_size, - activation=activation - ) - for _ in range(num_layers) - ]) - self.post_conv = Conv1dNorm( - in_channels=hidden_channels, - out_channels=out_channels, - kernel_size=kernel_size + self.res_blocks = nn.ModuleList( + [ + ResidualBlockV2( + channels=hidden_channels, filters=filters, kernel_size=kernel_size, activation=activation + ) + for _ in range(num_layers) + ] ) + self.post_conv = Conv1dNorm(in_channels=hidden_channels, out_channels=out_channels, kernel_size=kernel_size) def remove_weight_norm(self): self.pre_conv.remove_weight_norm() @@ -2003,41 +1987,25 @@ def forward(self, inputs, input_len): class ResNetEncoderV3(NeuralModule): - def __init__( - self, - in_channels, - out_channels, - filter_list, - stride_list, - kernel_size=3, - activation="lrelu" - ): + def __init__(self, in_channels, out_channels, filter_list, stride_list, kernel_size=3, activation="lrelu"): super(ResNetEncoderV3, self).__init__() input_dim = filter_list[0] - self.pre_conv = Conv1dNorm( - in_channels=in_channels, - out_channels=input_dim, - kernel_size=kernel_size - ) + self.pre_conv = Conv1dNorm(in_channels=in_channels, out_channels=input_dim, kernel_size=kernel_size) self.pre_act = CodecActivation(activation, channels=input_dim) self.res_blocks = nn.ModuleList([]) - for (filters, stride) in zip(filter_list, stride_list): + for filters, stride in zip(filter_list, stride_list): res_block = ResidualBlockV3( channels=input_dim, filters=filters, down_sample_rate=stride, kernel_size=kernel_size, - activation=activation + activation=activation, ) self.res_blocks.append(res_block) input_dim = filters - self.post_conv = Conv1dNorm( - in_channels=input_dim, - out_channels=out_channels, - kernel_size=kernel_size - ) + self.post_conv = Conv1dNorm(in_channels=input_dim, out_channels=out_channels, kernel_size=kernel_size) def remove_weight_norm(self): self.pre_conv.remove_weight_norm() @@ -2235,7 +2203,7 @@ def __init__( out_channels=filters, kernel_size=down_sample_kernel_size, stride=self.down_sample_rate, - activation=activation + activation=activation, ) self.res_block = ResidualBlockV2( channels=filters, filters=filters, kernel_size=kernel_size, activation=activation @@ -2247,16 +2215,13 @@ def remove_weight_norm(self): @property def input_types(self): - return { - "inputs": NeuralType(('B', 'C', 'T'), VoidType()), - 
"input_len": NeuralType(tuple('B'), LengthsType()) - } + return {"inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType())} @property def output_types(self): return { "out": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), - "out_len": NeuralType(tuple('B'), LengthsType()) + "out_len": NeuralType(tuple('B'), LengthsType()), } @typecheck() @@ -2298,7 +2263,7 @@ def __init__( out_channels=filters, kernel_size=down_sample_kernel_size, stride=self.down_sample_rate, - activation=activation + activation=activation, ) n_fft, hop_length, win_length = resolution @@ -2328,7 +2293,7 @@ def input_types(self): def output_types(self): return { "out": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), - "out_len": NeuralType(tuple('B'), LengthsType()) + "out_len": NeuralType(tuple('B'), LengthsType()), } @typecheck() @@ -2354,7 +2319,7 @@ def __init__( out_dim, kernel_size=3, down_sample_kernel_size=5, - activation="lrelu" + activation="lrelu", ): super(MultiResolutionSTFTEncoder, self).__init__() assert len(resolutions) == len(filter_list) @@ -2364,16 +2329,10 @@ def __init__( input_dim = n_fft // 2 + 1 self.pre_spec_processor = STFTProcessor(n_fft=n_fft, win_length=win_length, hop_length=hop_length) self.pre_conv = Conv1dNorm( - in_channels=input_dim, - out_channels=input_filters, - kernel_size=kernel_size, - activation=activation + in_channels=input_dim, out_channels=input_filters, kernel_size=kernel_size, activation=activation ) self.pre_res_block = ResidualBlockV2( - channels=input_filters, - filters=input_filters, - kernel_size=kernel_size, - activation=activation + channels=input_filters, filters=input_filters, kernel_size=kernel_size, activation=activation ) input_dim = input_filters self.stft_res_blocks = nn.ModuleList([]) @@ -2398,16 +2357,12 @@ def __init__( down_sample_rate=2, kernel_size=kernel_size, down_sample_kernel_size=down_sample_kernel_size, - activation=activation + activation=activation, ) self.down_sample_res_blocks.append(down_sample_res_block) input_dim = filters - self.post_conv = Conv1dNorm( - in_channels=input_dim, - out_channels=out_dim, - kernel_size=kernel_size - ) + self.post_conv = Conv1dNorm(in_channels=input_dim, out_channels=out_dim, kernel_size=kernel_size) def remove_weight_norm(self): self.encoder.remove_weight_norm() @@ -2442,4 +2397,4 @@ def forward(self, audio, audio_len): encoded = self.post_conv(inputs=encoded, input_len=encoded_len) - return encoded, encoded_len \ No newline at end of file + return encoded, encoded_len diff --git a/nemo/collections/tts/parts/utils/tts_dataset_utils.py b/nemo/collections/tts/parts/utils/tts_dataset_utils.py index 8dd8b8ab11e4..96806f633a54 100644 --- a/nemo/collections/tts/parts/utils/tts_dataset_utils.py +++ b/nemo/collections/tts/parts/utils/tts_dataset_utils.py @@ -67,8 +67,7 @@ def get_audio_filepaths(manifest_entry: Dict[str, Any], audio_dir: Path) -> Tupl def normalize_volume(audio: np.array, volume_level: float = 0.95) -> np.array: - """Apply peak normalization to the input audio. 
- """ + """Apply peak normalization to the input audio.""" if not (0.0 <= volume_level <= 1.0): raise ValueError(f"Volume must be in range [0.0, 1.0], received {volume_level}") @@ -316,7 +315,11 @@ def load_audio( def sample_audio( - manifest_entry: Dict[str, Any], audio_dir: Path, sample_rate: int, n_samples: int, volume_norm: bool = False, + manifest_entry: Dict[str, Any], + audio_dir: Path, + sample_rate: int, + n_samples: int, + volume_norm: bool = False, ) -> Tuple[np.ndarray, Path, Path]: """ Randomly sample an audio segment from a manifest entry. diff --git a/nemo/utils/timers.py b/nemo/utils/timers.py index 3c1ebbf1db5e..4197c7be6337 100644 --- a/nemo/utils/timers.py +++ b/nemo/utils/timers.py @@ -1,6 +1,7 @@ """ This module support timing of code blocks. """ + # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/scripts/speechllm_multitask_dataprep.py b/scripts/speechllm_multitask_dataprep.py index 4859ddc896bf..82757fa76e0b 100644 --- a/scripts/speechllm_multitask_dataprep.py +++ b/scripts/speechllm_multitask_dataprep.py @@ -47,7 +47,7 @@ def __init__( max_same_speaker_audios=1, use_context_as_same_speaker_audio=False, pad_multiple=320, - audio_type="actual", # actual or noise or silence + audio_type="actual", # actual or noise or silence ): self.data = [] speakerwise_records = {} @@ -57,7 +57,7 @@ def __init__( record = json.loads(line) if 'answer_duration' not in record: record['answer_duration'] = record['duration'] - + if isinstance(record['speaker'], str) and 'mls_english_' in record['speaker']: record['speaker'] = record['speaker'].replace('mls_english_', '') record['speaker'] = int(record['speaker']) @@ -69,7 +69,7 @@ def __init__( record['context_duration'] < min_duration or record['context_duration'] > max_duration ): continue - + if self._is_record_valid(record): self.data.append(record) if record['speaker'] not in speakerwise_records: @@ -122,7 +122,7 @@ def _is_record_valid(self, record): except: print("Skipping invalid record", record["audio_filepath"]) return False - + def filter_invalid_records(self): filtered_data = [] for ridx, record in enumerate(self.data): @@ -144,7 +144,7 @@ def add_context_records_to_manifest(self): # to ensure all context file paths have their codes extracted and saved. 
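The dataprep dataset above reads a JSONL manifest line by line, normalizes the speaker field, and drops records outside a duration window before grouping them by speaker. A trimmed-down sketch of that filtering pass follows; the field names mirror the manifest format used here, while the manifest path and duration bounds are hypothetical.

    import json

    min_duration, max_duration = 0.5, 20.0
    records, speakerwise_records = [], {}

    with open("train_manifest.json") as f:  # hypothetical manifest path
        for line in f:
            record = json.loads(line)
            # Fall back to the overall duration when no answer_duration is given.
            if "answer_duration" not in record:
                record["answer_duration"] = record["duration"]
            if not (min_duration <= record["answer_duration"] <= max_duration):
                continue
            records.append(record)
            speakerwise_records.setdefault(record["speaker"], []).append(record)

    print(len(records), "records from", len(speakerwise_records), "speakers")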
context_paths = {} target_paths = {} - + for record in self.data: if 'context' in record: if 'context_duration' not in record: @@ -192,11 +192,14 @@ def _get_wav_from_filepath(self, audio_filepath, perturb=False): if perturb: perturbed_audio = audio * 1.0 perturbed_audio_length = (audio_length * 1.0).long() - + return audio, audio_length, perturbed_audio, perturbed_audio_length elif self.audio_type == "actual": features = AudioSegment.segment_from_file( - audio_filepath, target_sr=self.sample_rate, n_segments=-1, trim=False, + audio_filepath, + target_sr=self.sample_rate, + n_segments=-1, + trim=False, ) audio_samples = features.samples audio = torch.tensor(audio_samples) @@ -217,7 +220,7 @@ def _get_wav_from_filepath(self, audio_filepath, perturb=False): # import ipdb; ipdb.set_trace() return audio, audio_length, perturbed_audio, perturbed_audio_length - + else: raise ValueError("Unknown audio type {}".format(self.audio_type)) @@ -265,7 +268,7 @@ def pad_collate_fn(self, batch): "rel_audio_path_as_text_id", "samespeaker_audioids", "samespeaker_wavpaths", - "speaker" + "speaker", ] for key in final_batch: @@ -431,7 +434,9 @@ def main(): parser.add_argument('--dataset_name', type=str, default='LibriTTSCorrectContext_train') parser.add_argument('--codec_model_path', type=str, default='/Data/Checkpoints/rlang_codec/SpeechCodec.nemo') parser.add_argument('--codec_bw', type=float, default=6.0) # 6 for 8 codebooks, 1.5 for 3 codebooks - parser.add_argument('--codec_model', type=str, default='nemo_codec') # encodec, uniaudio_codec, dac, nemo_codec, nemo_codec21, nemo_codec211k, nemo_codec214k + parser.add_argument( + '--codec_model', type=str, default='nemo_codec' + ) # encodec, uniaudio_codec, dac, nemo_codec, nemo_codec21, nemo_codec211k, nemo_codec214k parser.add_argument('--use_context_as_same_speaker_audio', action='store_true') parser.add_argument('--save_only_tts_records', action='store_true') parser.add_argument('--shuffle', action='store_true') @@ -490,7 +495,11 @@ def main(): ) dataloader = torch.utils.data.DataLoader( - dataset=dataset, batch_size=args.batch_size, collate_fn=dataset.pad_collate_fn, shuffle=False, num_workers=8, + dataset=dataset, + batch_size=args.batch_size, + collate_fn=dataset.pad_collate_fn, + shuffle=False, + num_workers=8, ) _exp_name = "{}_{}_bw_{}".format(args.dataset_name, args.codec_model, args.codec_bw) @@ -556,10 +565,10 @@ def main(): ) else: raise ValueError("Unknown codec model {}".format(args.codec_model)) - + if args.save_only_tts_records: - perturbed_codec_codes = original_codec_codes # Dummy values to not break the code - mixed_codec_codes = original_codec_codes # Dummy values to not break the code + perturbed_codec_codes = original_codec_codes # Dummy values to not break the code + mixed_codec_codes = original_codec_codes # Dummy values to not break the code # codec_codes = transformer_encodec_model.encode(batch["audio"].unsqueeze(1), audio_len_mask, bandwidth=6.0) target_codecs = [] @@ -615,7 +624,11 @@ def main(): "context_duration": batch['context_duration'][sidx], "answer_duration": batch['duration'][sidx], "taskname": "squad", - "speaker": batch['speaker'][sidx].item() if torch.is_tensor(batch['speaker'][sidx]) else batch['speaker'][sidx], + "speaker": ( + batch['speaker'][sidx].item() + if torch.is_tensor(batch['speaker'][sidx]) + else batch['speaker'][sidx] + ), } phoneme_tts_record = {key: value for key, value in tts_record.items()} @@ -780,4 +793,4 @@ def main(): if __name__ == '__main__': - main() \ No newline at end of file + main() From 
75ed036cb7d6fb0e64683e0cfe8bb6bf55a7bb27 Mon Sep 17 00:00:00 2001 From: Jason Date: Wed, 6 Nov 2024 12:34:40 -0800 Subject: [PATCH 03/18] undo some changes Signed-off-by: Jason --- Dockerfile.t5tts_from2407 | 194 ----- T5InferenceClean.ipynb | 389 --------- .../asr/modules/conformer_encoder.py | 2 +- nemo/collections/asr/modules/conv_asr.py | 2 +- nemo/utils/exp_manager.py | 13 +- nemo/utils/timers.py | 11 +- scripts/speechllm_multitask_dataprep.py | 796 ------------------ 7 files changed, 7 insertions(+), 1400 deletions(-) delete mode 100644 Dockerfile.t5tts_from2407 delete mode 100644 T5InferenceClean.ipynb delete mode 100644 scripts/speechllm_multitask_dataprep.py diff --git a/Dockerfile.t5tts_from2407 b/Dockerfile.t5tts_from2407 deleted file mode 100644 index 1f6a69364400..000000000000 --- a/Dockerfile.t5tts_from2407 +++ /dev/null @@ -1,194 +0,0 @@ -# syntax=docker/dockerfile:experimental - -# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -ARG BASE_IMAGE=nvcr.io/nvidia/nemo:24.07 - -# build an image that includes only the nemo dependencies, ensures that dependencies -# are included first for optimal caching, and useful for building a development -# image (by specifying build target as `nemo-deps`) -FROM ${BASE_IMAGE} as nemo-deps - -# dependency flags; should be declared after FROM -# torchaudio: not required by default -ARG REQUIRE_TORCHAUDIO=false -# k2: not required by default -ARG REQUIRE_K2=false -# ais cli: not required by default, install only if required -ARG REQUIRE_AIS_CLI=false - -# Ensure apt-get won't prompt for selecting options -ENV DEBIAN_FRONTEND=noninteractive -RUN cd - -# libavdevice-dev required for latest torchaudio -RUN apt-get update && \ - apt-get upgrade -y && \ - apt-get install -y \ - libsndfile1 sox \ - libfreetype6 \ - swig \ - ffmpeg \ - libavdevice-dev && \ - rm -rf /var/lib/apt/lists/* - -# libtool, ... , libgts-dev are required for graphviz -# graphviz is required for k2 and pynini visualization -RUN apt-get update && \ - apt-get install -y \ - libtool \ - libltdl-dev \ - automake \ - autoconf \ - bison \ - flex \ - tcl \ - ghostscript \ - libgd-dev \ - fontconfig \ - libcairo2-dev \ - libpango1.0-dev \ - libgts-dev && \ - rm -rf /var/lib/apt/lists/* - -WORKDIR /workspace/ - -ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea -ARG MCORE_TAG=3f90b989c477ba9be5d6011866641eda9d91f588 -ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c -# Install megatron core, this can be removed once 0.3 pip package is released -# We leave it here in case we need to work off of a specific commit in main -RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \ - cd Megatron-LM && \ - git checkout ${MCORE_TAG} && \ - pip install . 
- -# Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771 -RUN git clone https://github.com/NVIDIA/apex.git && \ - cd apex && \ - git checkout ${APEX_TAG} && \ - pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir \ - --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ - -# Transformer Engine 1.2.0 -RUN git clone https://github.com/NVIDIA/TransformerEngine.git && \ - cd TransformerEngine && \ - git fetch origin ${TE_TAG} && \ - git checkout FETCH_HEAD && \ - git submodule init && git submodule update && \ - NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install . - -WORKDIR /tmp/ - -# uninstall stuff from base container -RUN pip3 uninstall -y sacrebleu torchtext - -# build torchaudio -WORKDIR /tmp/torchaudio_build -COPY scripts/installers /tmp/torchaudio_build/scripts/installers/ -RUN INSTALL_MSG=$(/bin/bash /tmp/torchaudio_build/scripts/installers/install_torchaudio_latest.sh); INSTALL_CODE=$?; \ - echo ${INSTALL_MSG}; \ - if [ ${INSTALL_CODE} -ne 0 ]; then \ - echo "torchaudio installation failed"; \ - if [ "${REQUIRE_TORCHAUDIO}" = true ]; then \ - exit ${INSTALL_CODE}; \ - else echo "Skipping failed torchaudio installation"; fi \ - else echo "torchaudio installed successfully"; fi - -COPY scripts /tmp/nemo/scripts/ -# install correct graphviz version (k2 and pynini visualization tool), skip if installation fails -RUN INSTALL_MSG=$(/bin/bash /tmp/nemo/scripts/installers/install_graphviz.sh --docker); INSTALL_CODE=$?; \ - echo ${INSTALL_MSG}; \ - if [ ${INSTALL_CODE} -ne 0 ]; then \ - echo "graphviz installation failed"; \ - if [ "${REQUIRE_K2}" = true ]; then \ - exit ${INSTALL_CODE}; \ - else echo "Skipping failed graphviz installation"; fi \ - else echo "graphviz installed successfully"; fi - -# # install k2, skip if installation fails -# COPY scripts /tmp/nemo/scripts/ -# RUN INSTALL_MSG=$(/bin/bash /tmp/nemo/scripts/installers/install_k2.sh); INSTALL_CODE=$?; \ -# echo ${INSTALL_MSG}; \ -# if [ ${INSTALL_CODE} -ne 0 ]; then \ -# echo "k2 installation failed"; \ -# if [ "${REQUIRE_K2}" = true ]; then \ -# exit ${INSTALL_CODE}; \ -# else echo "Skipping failed k2 installation"; fi \ -# else echo "k2 installed successfully"; fi - -# install nemo dependencies -WORKDIR /tmp/nemo -ENV LHOTSE_REQUIRE_TORCHAUDIO=0 -COPY requirements . -# exclude requirements_vllm.txt, since `vllm==0.5.x` breaks the container due to hardcoded requirements `torch==2.3.0` -RUN for f in $(ls requirements*.txt | grep -v 'requirements_vllm.txt'); do \ - pip3 install --disable-pip-version-check --no-cache-dir -r $f; done - -# install flash attention -RUN pip install flash-attn -# install numba for latest containers -RUN pip install numba>=0.57.1 -# Extra t5 libraries -RUN pip install ipdb seaborn gradio - -# copy nemo source into a scratch image -FROM scratch as nemo-src -COPY . . - -# start building the final container -FROM nemo-deps as nemo -ARG NEMO_VERSION=2.0.0 - -# Check that NEMO_VERSION is set. Build will fail without this. 
Expose NEMO and base container -# version information as runtime environment variable for introspection purposes -RUN /usr/bin/test -n "$NEMO_VERSION" && \ - /bin/echo "export NEMO_VERSION=${NEMO_VERSION}" >> /root/.bashrc && \ - /bin/echo "export BASE_IMAGE=${BASE_IMAGE}" >> /root/.bashrc - -# Install NeMo -RUN --mount=from=nemo-src,target=/tmp/nemo,rw cd /tmp/nemo && pip install ".[all]" - -# Check install -# NB: adjusting LD_LIBRARY_PATH (only here, should not be persistent!) is a temporary hack -# to avoid failure if CUDA is unavailable (`docker build` does not expose GPUs) -# The error is raised in NeMo Core, and the main reason is reinstalled Transformer-Engine; -RUN export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${CUDA_HOME}/compat/lib.real && \ - python -c "import nemo.collections.asr as nemo_asr" && \ - python -c "import nemo.collections.nlp as nemo_nlp" && \ - python -c "import nemo.collections.tts as nemo_tts" && \ - python -c "import nemo_text_processing.text_normalization as text_normalization" - - -# copy scripts/examples/tests into container for end user -WORKDIR /workspace/nemo -# COPY scripts /workspace/nemo/scripts -# COPY examples /workspace/nemo/examples -# COPY tests /workspace/nemo/tests -# COPY tutorials /workspace/nemo/tutorials -# COPY README.rst LICENSE /workspace/nemo/ - -RUN printf "#!/bin/bash\njupyter lab --no-browser --allow-root --ip=0.0.0.0" >> start-jupyter.sh && \ - chmod +x start-jupyter.sh - -# If required, install AIS CLI and Python AIS SDK -RUN INSTALL_MSG=$(/bin/bash /tmp/nemo/scripts/installers/install_ais_cli_latest.sh && pip install aistore); INSTALL_CODE=$?; \ - echo ${INSTALL_MSG}; \ - if [ ${INSTALL_CODE} -ne 0 ]; then \ - echo "AIS CLI installation failed"; \ - if [ "${REQUIRE_AIS_CLI}" = true ]; then \ - exit ${INSTALL_CODE}; \ - else echo "Skipping AIS CLI installation"; fi \ - else echo "AIS CLI installed successfully"; fi diff --git a/T5InferenceClean.ipynb b/T5InferenceClean.ipynb deleted file mode 100644 index af43d6c8409e..000000000000 --- a/T5InferenceClean.ipynb +++ /dev/null @@ -1,389 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "7554757b", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", - "import torch\n", - "import json\n", - "from omegaconf.omegaconf import OmegaConf, open_dict\n", - "import shutil\n", - "\n", - "from nemo.collections.tts.models.speechllm.megatron_t5_speechllm_model import MegatronT5SpeechLMModel\n", - "from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder\n", - "from nemo.collections.asr.parts.preprocessing.segment import AudioSegment\n", - "from nemo.core.config import hydra_runner\n", - "from nemo.utils import logging\n", - "from nemo.utils.exp_manager import exp_manager\n", - "from IPython.display import Audio, display\n", - "import torchaudio\n", - "\n", - "# CHANGE THIS TO A LOCAL DIRECTORY\n", - "EXP_DIR = \"/datap/misc/NotebookInference\"\n", - "\n", - "if not os.path.exists(EXP_DIR):\n", - " os.makedirs(EXP_DIR)" - ] - }, - { - "cell_type": "markdown", - "id": "5fdfa55a", - "metadata": {}, - "source": [ - "## Save a dummy manifest to setup Model Test Step" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1a0c40ac", - "metadata": {}, - "outputs": [], - "source": [ - "def write_records(fp, records):\n", - " with open(fp, \"w\") as f:\n", - " for record in records:\n", - " 
f.write(json.dumps(record) + \"\\n\")\n", - "\n", - "dummy_codes = torch.ones(8, 300).cpu().type(torch.int16)\n", - "dummy_codes_fp = os.path.join(EXP_DIR, \"dummy_codes.pt\")\n", - "torch.save(dummy_codes, dummy_codes_fp)\n", - "\n", - "\n", - "dummy_record = {\n", - " \"question\" : \"Phoneme TTS Sample Text\",\n", - " \"answer\" : dummy_codes_fp,\n", - " \"context\" : dummy_codes_fp,\n", - " \"context_type\" : \"REFSPEAKERCODEC\",\n", - " \"question_type\" : \"TEXT\",\n", - " \"answer_type\" : \"AUDIOCODEC\",\n", - " \"context_duration\" : 5.0,\n", - " \"answer_duration\" : 5.0,\n", - " \"taskname\" : \"squad\"\n", - "}\n", - "\n", - "dummy_val_file = os.path.join(EXP_DIR, \"dummy_val.json\")\n", - "\n", - "write_records(dummy_val_file, [dummy_record])" - ] - }, - { - "cell_type": "markdown", - "id": "c9cd90c5", - "metadata": {}, - "source": [ - "## Load and setup the model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ee1df6bf", - "metadata": {}, - "outputs": [], - "source": [ - "# CHANGE THESE PATHS TO RELEVANT MOUNTED PATHS IN DOCKER\n", - "config_path = \"/home/pneekhara/2023/NeMo/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml\"\n", - "# checkpoint_path = \"/datap/misc/temp_checkpoints_new/desta_less_sophia_highLR_step159600.ckpt\"\n", - "checkpoint_path = \"/datap/misc/checkpoints/desta_less_sophia_213850.ckpt\"\n", - "codecmodel_path = \"/datap/misc/checkpoints/SpeechCodec_2402.nemo\"\n", - "vocab_file = \"/datap/misc/checkpoints/9a77f10c2793465e8e8a3fa5fcbef8b0_vocab.txt\"\n", - "\n", - "cfg = OmegaConf.load(config_path)\n", - "\n", - "if \"gradient_as_bucket_view\" not in cfg.model:\n", - " with open_dict(cfg):\n", - " cfg.model.gradient_as_bucket_view=False\n", - "\n", - "trainer = MegatronTrainerBuilder(cfg).create_trainer()\n", - "exp_manager(trainer, cfg.exp_manager)\n", - "\n", - "with open_dict(cfg):\n", - " cfg.exp_manager.exp_dir=EXP_DIR\n", - " cfg.checkpoint_path = checkpoint_path\n", - " cfg.model.data.sup_data_path=\"/datap/misc/speechllm_codecdatasets/\"\n", - " cfg.model.global_batch_size=1\n", - " cfg.model.micro_batch_size=1\n", - " cfg.model.data.speech_offset=30128\n", - " cfg.model.lm_vocab_size=30000\n", - " cfg.model.data.add_special_tokens_to_only_first_codebook=True\n", - " cfg.model.data.train_task=\"all\"\n", - " cfg.model.freeze_model=False\n", - " cfg.model.data.max_seq_length=2048\n", - " cfg.model.max_inference_timesteps=2000\n", - " cfg.model.data.context_duration_min=20.0\n", - " cfg.model.data.context_duration_max=20.0\n", - " cfg.model.top_k=80\n", - " cfg.model.temperature=0.85\n", - " cfg.model.data.speech_offset=30128\n", - " cfg.model.lm_vocab_size=30000\n", - " cfg.model.codecmodel_path=codecmodel_path\n", - " cfg.trainer.devices=1\n", - " cfg.trainer.precision=\"bf16\"\n", - " cfg.model.precision = cfg.trainer.precision\n", - " cfg.model.override_tokenizer_vocab_file=vocab_file\n", - " cfg.model.english_only_model=True\n", - " cfg.model.asr_model_name=\"stt_en_conformer_transducer_large\"\n", - " cfg.model.frozen_model.decoder.layer_type=[1,1,1,2,2,2,2,2,2,2,1,1]\n", - " cfg.model.alignment_decoder_layerids=[0,1,2,3,4]\n", - " cfg.model.enc_output_to_layers=[[8,9],[3,4,5,6,7]]\n", - " cfg.model.data.test_ds=[dummy_val_file]\n", - " cfg.model.data.num_workers = 0\n", - "\n", - "\n", - "checkpoint_path = cfg.get('checkpoint_path', None)\n", - "assert checkpoint_path is not None, \"checkpoint path needs to be valid\"\n", - "\n", - "model = 
MegatronT5SpeechLMModel.load_from_checkpoint(\n", - " checkpoint_path=checkpoint_path, trainer=trainer, cfg=cfg.model\n", - " )\n", - "model.eval()\n", - "model = model.cuda()\n", - "\n", - "codec_model = model.additional_models['codec']\n", - "trainer.test(model)\n" - ] - }, - { - "cell_type": "markdown", - "id": "e5461918", - "metadata": {}, - "source": [ - "## Helper functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "227b07f2", - "metadata": {}, - "outputs": [], - "source": [ - "out_dir = os.path.join( model.trainer.logger.save_dir, model.trainer.logger.name, model.trainer.logger.version, \"Sample_Audios\")\n", - "out_path = os.path.join(out_dir, 'predicted_wav_0.wav')\n", - "\n", - "\n", - "def encode(wav_path):\n", - " # Convert an audio file to nemo codec codes\n", - " features = AudioSegment.segment_from_file(\n", - " wav_path, target_sr=codec_model.sample_rate, n_segments=-1, trim=False,\n", - " )\n", - " audio_samples = features.samples\n", - " audio = torch.tensor(audio_samples).cuda()\n", - " audio_length = torch.tensor(audio.size(0)).long().cuda()\n", - " print(f\"audio {audio.size()} audio_length {audio_length}\")\n", - " print(f\"audio {audio.device} audio_length {audio_length.device} codec_model {codec_model.device}\")\n", - "\n", - " original_codec_codes, _ = codec_model.encode(audio=audio.unsqueeze(0), audio_len=audio_length.unsqueeze(0))\n", - " original_codec_codes = original_codec_codes[0]\n", - " print(f\"original_codec_codes {original_codec_codes.size()} audio {audio.size()} audio_length {audio_length}\")\n", - " duration = original_codec_codes.size()[1] / 86\n", - " \n", - " target_codec_filepath = wav_path[:-4] + \"_codes.pt\"\n", - " torch.save(original_codec_codes.cpu().type(torch.int16), target_codec_filepath)\n", - " return original_codec_codes, target_codec_filepath, duration\n", - " \n", - " \n", - " \n", - "def play_codec(codec_path):\n", - " # Convert nemo codecs to audio and play it\n", - " codec = torch.load(codec_path)\n", - " codec = codec.to('cuda')\n", - " codec = codec.unsqueeze(0)\n", - " codec_lens = torch.Tensor([codec.shape[2]]).long().cuda()\n", - " codec_decoded_audios, _ = codec_model.decode(tokens=codec.long(), tokens_len=codec_lens)\n", - " codec_decoded_audio = codec_decoded_audios[0]\n", - " temp_wav_path = os.path.join(EXP_DIR, \"temp.wav\")\n", - " torchaudio.save(temp_wav_path, codec_decoded_audio[None].cpu(), 22050)\n", - " display(Audio(temp_wav_path))\n", - "\n", - "def generate_new_audio(\n", - " text,\n", - " context,\n", - " context_duration=4.0,\n", - " context_type=\"REFSPEAKERCODEC\",\n", - " temperature=0.85,\n", - " top_k=80,\n", - " text_task=\"Phoneme TTS \"\n", - " ):\n", - " # Prepare data in speechllm format\n", - " model.cfg.temperature = temperature\n", - " model.cfg.top_k = top_k\n", - " dummy_answer = dummy_codes_fp\n", - " json_in = {}\n", - " json_in[\"question\"] = text_task + text\n", - " json_in[\"question_type\"] = \"TEXT\"\n", - " json_in[\"answer\"] = dummy_answer \n", - " json_in[\"context\"] = context \n", - " json_in[\"answer_type\"] = \"AUDIOCODEC\"\n", - " json_in[\"context_type\"] = context_type\n", - " json_in[\"context_duration\"] = context_duration\n", - " json_in[\"answer_duration\"] = 2.0\n", - " json_in[\"taskname\"] = \"squad\"\n", - " json_in[\"lang\"] = \"en\"\n", - " json_in = [json_in]\n", - " \n", - " # Prepare dataloader\n", - " model._test_ds.examples = []\n", - " model._test_ds.examples = model._test_ds.load_data(json_in)\n", - " \n", - " sampler = 
torch.utils.data.distributed.DistributedSampler(\n", - " model._test_ds, num_replicas=1, rank=0, shuffle=False, seed=1\n", - " )\n", - "\n", - " model._test_dl = torch.utils.data.DataLoader(\n", - " model._test_ds,\n", - " collate_fn=model._test_ds.collate_fn,\n", - " sampler=sampler,\n", - " batch_size=1,\n", - " drop_last=False,\n", - " num_workers=1,\n", - " pin_memory=False,\n", - " persistent_workers=True\n", - " )\n", - " \n", - " # Run inference\n", - " model.cfg.data.test_ds = None\n", - " trainer.test(model, model._test_dl)\n", - " print(\"Out path:\", out_path)\n", - " print(\"Inference done\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b7f449a2", - "metadata": {}, - "outputs": [], - "source": [ - "text_contexts = [\n", - " \"TEXT CONTEXT: | Language:en Dataset:Riva Speaker:Lindy_CMU_FEARFUL |\",\n", - " \"TEXT CONTEXT: | Language:en Dataset:Riva Speaker:Lindy_CMU_HAPPY |\",\n", - " \"TEXT CONTEXT: | Language:en Dataset:Riva Speaker:Rodney_CMU_HAPPY |\",\n", - " \"TEXT CONTEXT: | Language:en Dataset:PromptTTS Gender:female SpeakingRate:2. Slow emotion:neutral Pitch:4. High SNR:5. Clean REVERB:5. Very close-sounding |\"\n", - "]" - ] - }, - { - "cell_type": "markdown", - "id": "a3d5a467", - "metadata": {}, - "source": [ - "## Generate audio from a text context" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bf9660d9", - "metadata": {}, - "outputs": [], - "source": [ - "text = \"As I closed my laptop for the night, my reflection in the screen continued to smile back at me.\"\n", - "text_task = \"Phoneme TTS \" # Can be \"Text to speech this \" (for sentence-piece tokenizer) or \"Phoneme TTS \" (for phoneme tokenizer)\n", - "context = text_contexts[1] # Sample Text Context\n", - "context_type = \"TEXT\" # Can be REFSPEAKERCODEC (for audio context), TEXT (for text context)\n", - "generate_new_audio(\n", - " text, \n", - " context, \n", - " context_type=context_type, \n", - " context_duration=5.0, # Does not matter, should just be > 3 so that dataset does not filter it out.\n", - " top_k=80, # Can play around with this to check roubstness\n", - " temperature=0.8, # Can play around with this. 
temperature < 0.85 can be more robust\n", - " text_task=text_task\n", - ")\n", - "display(Audio(out_path))" - ] - }, - { - "cell_type": "markdown", - "id": "f1a964c3", - "metadata": {}, - "source": [ - "## Listen to some ground-truth context audios" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0f7c966e", - "metadata": {}, - "outputs": [], - "source": [ - "context_paths = [\n", - " \"/datap/misc/speechllm_codecdatasets/codecs/RivattsAllLanguagesUpdated_train_nemo_codec_bw_6.0/target_codes_en_Lindy_44khz_CMU_HAPPY_LINDY_CMU_HAPPY_000570.pt\",\n", - "]\n", - "\n", - "for cidx, context_path in enumerate(context_paths):\n", - " print(cidx, context_path)\n", - " play_codec(context_path)" - ] - }, - { - "cell_type": "markdown", - "id": "45f8bc28", - "metadata": {}, - "source": [ - "## Generate audio from an audio context" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7dd9deb2", - "metadata": {}, - "outputs": [], - "source": [ - "text = \"As I closed my laptop for the night, my reflection in the screen continued to smile back at me.\"\n", - "text_task = \"Text to speech this \" # Can be \"Text to speech this \" (for sentence-piece tokenizer) or \"Phoneme TTS \" (for phoneme tokenizer)\n", - "context = context_paths[0] # Sample Text Context\n", - "context_type = \"REFSPEAKERCODEC\" # Can be REFSPEAKERCODEC (for audio context), TEXT (for text context)\n", - "generate_new_audio(\n", - " text, \n", - " context, \n", - " context_type=context_type, \n", - " context_duration=5.0, # Does not matter, should just be > 3 so that dataset does not filter it out.\n", - " temperature=0.8, # Can play around with this. temperature < 0.85 can be more robust\n", - " text_task=text_task\n", - ")\n", - "display(Audio(out_path))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f40d2450", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index d081f5034eed..27d0cde33f8c 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -672,7 +672,7 @@ def forward_internal( def update_max_seq_length(self, seq_length: int, device): # Find global max audio length across all nodes - if torch.distributed.is_initialized() and (not getattr(self, 'disable_torch_distributed', False)): + if torch.distributed.is_initialized(): global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device) # Update across all ranks in the distributed system diff --git a/nemo/collections/asr/modules/conv_asr.py b/nemo/collections/asr/modules/conv_asr.py index b4b9a6735013..3cb9ec13109b 100644 --- a/nemo/collections/asr/modules/conv_asr.py +++ b/nemo/collections/asr/modules/conv_asr.py @@ -200,7 +200,7 @@ def forward(self, audio_signal, length): def update_max_sequence_length(self, seq_length: int, device): # Find global max audio length across all nodes - if torch.distributed.is_initialized() and (not getattr(self, 'disable_torch_distributed', False)): + if 
torch.distributed.is_initialized(): global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device) # Update across all ranks in the distributed system diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 8cd76a5ccf5c..b512bc57cbab 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -51,7 +51,6 @@ from nemo.utils.loggers import ClearMLLogger, ClearMLParams, DLLogger, DLLoggerParams, MLFlowParams from nemo.utils.mcore_logger import add_handlers_to_mcore_logger from nemo.utils.model_utils import uninject_model_parallel_rank -from nemo.utils.timers import NeMoTimerException try: # `ptl_resiliency` is included in `gwe_resiliency_pkg` package @@ -261,12 +260,7 @@ def _on_batch_start(self, name): self.timer.start(name) def _on_batch_end(self, name, pl_module): - try: - self.timer.stop(name) - except NeMoTimerException as e: - # Skip the error - pass - + self.timer.stop(name) # Set the `batch_size=1` as WAR for `dataloader_iter`, which is not used for any metric pl_module.log( name + ' in s', @@ -870,13 +864,12 @@ def check_resume( trainer.ckpt_path = str(checkpoint) logging.info(f'Resuming training from checkpoint: {trainer.ckpt_path}') - trainer.strategy.barrier() if is_global_rank_zero(): # Check to see if any files exist that need to be moved files_to_move = [] if Path(log_dir).exists(): for child in Path(log_dir).iterdir(): - if child.is_file() and not child.name.startswith("events.out.tfevents"): + if child.is_file(): files_to_move.append(child) if len(files_to_move) > 0: @@ -997,7 +990,7 @@ def get_log_dir( os.environ[NEMO_ENV_VARNAME_VERSION] = "" if version is None else version log_dir = Path(_exp_dir) / Path(str(name)) / Path("" if version is None else str(version)) - return log_dir, str(_exp_dir), name, "" if version is None else str(version) + return log_dir, str(_exp_dir), name, version def get_git_hash(): diff --git a/nemo/utils/timers.py b/nemo/utils/timers.py index 4197c7be6337..a35c257652b9 100644 --- a/nemo/utils/timers.py +++ b/nemo/utils/timers.py @@ -1,7 +1,6 @@ """ This module support timing of code blocks. """ - # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,15 +21,9 @@ import numpy as np import torch -from nemo.utils.exceptions import NeMoBaseException - __all__ = ["NamedTimer", "SimpleTimer"] -class NeMoTimerException(NeMoBaseException, RuntimeError): - pass - - class NamedTimer(object): """ A timer class that supports multiple named timers. 
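With NeMoTimerException gone, NamedTimer.start() and NamedTimer.stop() raise a plain RuntimeError when a timer is started twice or stopped while inactive, and the exp_manager callback above now calls stop() directly. A minimal sketch of a caller that still wants to tolerate a redundant stop under that assumption (hypothetical usage, not part of this patch):

from nemo.utils.timers import NamedTimer

timer = NamedTimer()
timer.start("train_step_timing")
# ... timed work ...
try:
    timer.stop("train_step_timing")
except RuntimeError:
    # The timer was never started or was already stopped; skip rather than fail.
    pass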
@@ -97,7 +90,7 @@ def start(self, name=""): timer_data = self.timers.get(name, {}) if "start" in timer_data: - raise NeMoTimerException(f"Cannot start timer = '{name}' since it is already active") + raise RuntimeError(f"Cannot start timer = '{name}' since it is already active") # synchronize pytorch cuda execution if supported if self._sync_cuda and torch.cuda.is_initialized(): @@ -116,7 +109,7 @@ def stop(self, name=""): """ timer_data = self.timers.get(name, None) if (timer_data is None) or ("start" not in timer_data): - raise NeMoTimerException(f"Cannot end timer = '{name}' since it is not active") + raise RuntimeError(f"Cannot end timer = '{name}' since it is not active") # synchronize pytorch cuda execution if supported if self._sync_cuda and torch.cuda.is_initialized(): diff --git a/scripts/speechllm_multitask_dataprep.py b/scripts/speechllm_multitask_dataprep.py deleted file mode 100644 index 82757fa76e0b..000000000000 --- a/scripts/speechllm_multitask_dataprep.py +++ /dev/null @@ -1,796 +0,0 @@ -import argparse -import copy -import json -import math -import os -import random -import time -from pathlib import Path - -import numpy as np -import soundfile as sf -import torch -import torchaudio -from encodec import EncodecModel -from omegaconf import OmegaConf -from tqdm import tqdm - -from nemo.collections.asr.parts.preprocessing.perturb import NoisePerturbation, WhiteNoisePerturbation -from nemo.collections.asr.parts.preprocessing.segment import AudioSegment -from nemo.collections.tts.models import AudioCodecModel -from nemo.collections.tts.modules.transformer import mask_from_lens -from nemo.collections.tts.parts.utils.tts_dataset_utils import get_base_dir -from nemo.core.classes import Dataset -from nemo.utils import logging - -try: - from models.soundstream import SoundStream -except: - logging.warning("SoundStream not found, uniaudio cannot be used") - -try: - import dac -except: - logging.warning("DAC not found") - - -class AudioDataset(Dataset): - def __init__( - self, - manifest_paths, - min_duration=0.0, - max_duration=22.0, - sample_rate=24000, - noise_manifest_path=None, - min_snr_db=0, - max_snr_db=5, - max_same_speaker_audios=1, - use_context_as_same_speaker_audio=False, - pad_multiple=320, - audio_type="actual", # actual or noise or silence - ): - self.data = [] - speakerwise_records = {} - for manifest_path in manifest_paths: - with open(manifest_path, "r") as f: - for line in f: - record = json.loads(line) - if 'answer_duration' not in record: - record['answer_duration'] = record['duration'] - - if isinstance(record['speaker'], str) and 'mls_english_' in record['speaker']: - record['speaker'] = record['speaker'].replace('mls_english_', '') - record['speaker'] = int(record['speaker']) - - if record['answer_duration'] < min_duration or record['answer_duration'] > max_duration: - continue - - if ('context_duration' in record) and ( - record['context_duration'] < min_duration or record['context_duration'] > max_duration - ): - continue - - if self._is_record_valid(record): - self.data.append(record) - if record['speaker'] not in speakerwise_records: - speakerwise_records[record['speaker']] = [] - speakerwise_records[record['speaker']].append(record) - - self.speakerwise_records = speakerwise_records - self.speaker_list = list(self.speakerwise_records.keys()) - - self.sample_rate = sample_rate - self.audio_type = audio_type - - # TODO: Using White Noise Perturbation right now (dont have noise manifest) - - # self.noise_perturber = NoisePerturbation( - # 
manifest_path=noise_manifest_path, - # min_snr_db=min_snr_db, - # max_snr_db=max_snr_db, - # ) - - self.noise_perturber = WhiteNoisePerturbation() - - self.max_same_speaker_audios = max_same_speaker_audios - - # If True, use the 'context' key as the same speaker reference audio, - # otherwise randomly choose from the same speaker audios - - self.use_context_as_same_speaker_audio = use_context_as_same_speaker_audio - self.pad_multiple = pad_multiple - - if self.use_context_as_same_speaker_audio: - logging.info("Using context as same speaker audio") - self.add_context_records_to_manifest() - - self.base_data_dir = get_base_dir([item["audio_filepath"] for item in self.data]) - # self.filter_invalid_records() - # if sup_data_dir is not None: - # self.sup_data_dir = sup_data_dir - # else: - # self.sup_data_dir = os.path.join(self.base_data_dir, "sup_data") - # if not os.path.exists(self.sup_data_dir): - # os.makedirs(self.sup_data_dir) - - def _is_record_valid(self, record): - return True - try: - sf.read(record["audio_filepath"]) - # sf.read(record["context"]) - return True - except: - print("Skipping invalid record", record["audio_filepath"]) - return False - - def filter_invalid_records(self): - filtered_data = [] - for ridx, record in enumerate(self.data): - if ridx % 1000 == 0: - print("Filtering", ridx, "of", len(self.data)) - try: - sf.read(record["audio_filepath"]) - sf.read(record["context"]) - except: - print("Skipping invalid record", record["audio_filepath"]) - continue - filtered_data.append(record) - print("Original data size", len(self.data)) - print("Filtered data size", len(filtered_data)) - self.data = filtered_data - - def add_context_records_to_manifest(self): - # Add dummy records with audio_filepath as context - # to ensure all context file paths have their codes extracted and saved. 
- context_paths = {} - target_paths = {} - - for record in self.data: - if 'context' in record: - if 'context_duration' not in record: - # Get duration from the context audio file - record['context_duration'] = float(sf.info(record['context']).duration) - - context_paths[record['context']] = { - 'speaker': record['speaker'], - 'duration': record['context_duration'], - } - if 'answer' in record: - target_paths[record['audio_filepath']] = True - - for context_path in context_paths: - if context_path not in target_paths: - self.data.append( - { - "audio_filepath": context_path, - "context": context_path, - "duration": context_paths[context_path]['duration'], - "answer_duration": context_paths[context_path]['duration'], - "context_duration": context_paths[context_path]['duration'], - "text": "", # Indicates that this is a dummy record - "question": "", - "speaker": context_paths[context_path]['speaker'], - } - ) - - def __len__(self): - return len(self.data) - - def _get_wav_from_filepath(self, audio_filepath, perturb=False): - if self.audio_type == "noise" or self.audio_type == "silence": - # Create a 6 second noise audio - if self.audio_type == "noise": - audio_samples = np.random.normal(0, 1, 6 * self.sample_rate) - else: - audio_samples = np.zeros(6 * self.sample_rate) - audio = torch.tensor(audio_samples).float() - audio = torch.nn.functional.pad(audio, (0, self.pad_multiple - audio.size(0) % self.pad_multiple), value=0) - audio_length = torch.tensor(audio.size(0)).long() - - perturbed_audio = None - perturbed_audio_length = None - if perturb: - perturbed_audio = audio * 1.0 - perturbed_audio_length = (audio_length * 1.0).long() - - return audio, audio_length, perturbed_audio, perturbed_audio_length - elif self.audio_type == "actual": - features = AudioSegment.segment_from_file( - audio_filepath, - target_sr=self.sample_rate, - n_segments=-1, - trim=False, - ) - audio_samples = features.samples - audio = torch.tensor(audio_samples) - audio = torch.nn.functional.pad(audio, (0, self.pad_multiple - audio.size(0) % self.pad_multiple), value=0) - audio_length = torch.tensor(audio.size(0)).long() - - perturbed_audio = None - perturbed_audio_length = None - if perturb: - features_copy = copy.deepcopy(features) - self.noise_perturber.perturb(features_copy) - perturbed_audio_samples = features_copy.samples - perturbed_audio = torch.tensor(perturbed_audio_samples) - perturbed_audio = torch.nn.functional.pad( - perturbed_audio, (0, self.pad_multiple - perturbed_audio.size(0) % self.pad_multiple), value=0 - ) - perturbed_audio_length = torch.tensor(perturbed_audio.size(0)).long() - # import ipdb; ipdb.set_trace() - - return audio, audio_length, perturbed_audio, perturbed_audio_length - - else: - raise ValueError("Unknown audio type {}".format(self.audio_type)) - - def pad_collate_fn(self, batch): - final_batch = {} - for row in batch: - for key in row: - if key not in final_batch: - final_batch[key] = [] - final_batch[key].append(row[key]) - - max_audio_len = max([_audio_len.item() for _audio_len in final_batch["audio_len"]]) - - audios_padded = [] - for audio in final_batch["audio"]: - audio_padded = torch.nn.functional.pad(audio, (0, max_audio_len - audio.size(0)), value=0) - audios_padded.append(audio_padded) - - final_batch["audio"] = audios_padded - - perturbed_audios_padded = [] - max_perturbed_audio_len = max([_audio_len.item() for _audio_len in final_batch["perturbed_audio_len"]]) - for audio in final_batch["perturbed_audio"]: - audio_padded = torch.nn.functional.pad(audio, (0, 
max_perturbed_audio_len - audio.size(0)), value=0) - perturbed_audios_padded.append(audio_padded) - - final_batch["perturbed_audio"] = perturbed_audios_padded - - mixed_audios_padded = [] - max_mixed_audio_len = max([_audio_len.item() for _audio_len in final_batch["mixed_audio_len"]]) - for audio in final_batch["mixed_audio"]: - audio_padded = torch.nn.functional.pad(audio, (0, max_mixed_audio_len - audio.size(0)), value=0) - mixed_audios_padded.append(audio_padded) - - final_batch["mixed_audio"] = mixed_audios_padded - - non_tensor_keys = [ - "audio_filepath", - "question", - "text", - "context", - "old_speaker_id", - "duration", - "context_duration", - "rel_audio_path_as_text_id", - "samespeaker_audioids", - "samespeaker_wavpaths", - "speaker", - ] - - for key in final_batch: - if key not in non_tensor_keys: - final_batch[key] = torch.stack(final_batch[key]) - - return final_batch - - def __getitem__(self, index): - sample = self.data[index] - rel_audio_path = Path(sample["audio_filepath"]).relative_to(self.base_data_dir).with_suffix("") - rel_audio_path_as_text_id = str(rel_audio_path).replace("/", "_") - # speaker = torch.tensor(sample["speaker"]).long() - speaker = sample['speaker'] - - # Avoid fixed seed - random.seed(time.time()) - alternate_speaker = random.choice(self.speaker_list) - _ctr = 0 - while (alternate_speaker == speaker) and (_ctr < 10): - random.seed(time.time()) - alternate_speaker = random.choice(self.speaker_list) - _ctr += 1 - - random.seed(time.time()) - alternate_wavpath = random.choice(self.speakerwise_records[alternate_speaker])["audio_filepath"] - - if not self.use_context_as_same_speaker_audio: - random.shuffle(self.speakerwise_records[sample["speaker"]]) - samespeaker_wavpaths = [] - context_duration = 0.0 - for _record in self.speakerwise_records[sample["speaker"]][: self.max_same_speaker_audios]: - if _record["audio_filepath"] != sample["audio_filepath"]: - samespeaker_wavpath = _record["audio_filepath"] - samespeaker_wavpaths.append(samespeaker_wavpath) - context_duration += _record["answer_duration"] - - if len(samespeaker_wavpaths) == 0: - # Use the same audio if no other audio is available from the same speaker - samespeaker_wavpaths = [sample["audio_filepath"]] - context_duration = sample["answer_duration"] - else: - samespeaker_wavpaths = [sample["context"]] - context_duration = sample["context_duration"] - - samespeaker_audioids = [] - for samespeaker_wavpath in samespeaker_wavpaths: - samespeaker_rel_audio_path = Path(samespeaker_wavpath).relative_to(self.base_data_dir).with_suffix("") - samespeaker_rel_audio_path_as_text_id = str(samespeaker_rel_audio_path).replace("/", "_") - samespeaker_audioids.append(samespeaker_rel_audio_path_as_text_id) - - alternate_audio, alternate_audio_length, _, _ = self._get_wav_from_filepath(alternate_wavpath, perturb=False) - audio, audio_length, perturbed_audio, perturbed_audio_length = self._get_wav_from_filepath( - sample["audio_filepath"], perturb=True - ) - - # Mix audio and alternate audio - if audio_length > alternate_audio_length: - # Repeat alternate audio - alternate_audio = alternate_audio.repeat(audio_length // alternate_audio_length + 1) - alternate_audio = alternate_audio[:audio_length] - mixed_audio = 0.5 * (audio + alternate_audio) - elif audio_length <= alternate_audio_length: - alternate_audio = alternate_audio[:audio_length] - mixed_audio = 0.5 * (audio + alternate_audio) - - mixed_audio_length = audio_length - - if "question" not in sample: - sample['question'] = "Text to speech this " + 
sample['text'] - - return { - "audio": audio, - "audio_len": audio_length, - "perturbed_audio": perturbed_audio, - "perturbed_audio_len": perturbed_audio_length, - "mixed_audio": mixed_audio, - "mixed_audio_len": mixed_audio_length, - "rel_audio_path_as_text_id": rel_audio_path_as_text_id, - "samespeaker_audioids": samespeaker_audioids, - "samespeaker_wavpaths": samespeaker_wavpaths, - "audio_filepath": sample["audio_filepath"], - "question": sample["question"], - "text": sample["text"], - "context": sample.get("context", None), - "old_speaker_id": sample.get("old_speaker_id", None), - "duration": sample["answer_duration"], - "context_duration": context_duration, - "speaker": speaker, - } - - -def save_batch_audios(batch, bidx, temp_dir, codec_model, codec_model_type='encodec', codec_model_sample_rate=24000): - for sidx in range(batch["audio"].shape[0]): - sample_audio = batch["audio"][sidx] - sample_audio_len = batch["audio_len"][sidx].item() - sample_audio = sample_audio[:sample_audio_len] - - # Save sample_audio - sample_audio_path = os.path.join(temp_dir, f"{bidx}_{sidx}_sample.wav") - torchaudio.save(sample_audio_path, sample_audio[None].cpu(), codec_model_sample_rate) - - # Save perturbed_audio - perturbed_audio = batch["perturbed_audio"][sidx] - perturbed_audio_len = batch["perturbed_audio_len"][sidx].item() - perturbed_audio = perturbed_audio[:perturbed_audio_len] - perturbed_audio_path = os.path.join(temp_dir, f"{bidx}_{sidx}_perturbed.wav") - torchaudio.save(perturbed_audio_path, perturbed_audio[None].cpu(), codec_model_sample_rate) - - # Save mixed_audio - mixed_audio = batch["mixed_audio"][sidx] - mixed_audio_len = batch["mixed_audio_len"][sidx].item() - mixed_audio = mixed_audio[:mixed_audio_len] - mixed_audio_path = os.path.join(temp_dir, f"{bidx}_{sidx}_mixed.wav") - torchaudio.save(mixed_audio_path, mixed_audio[None].cpu(), codec_model_sample_rate) - - with torch.no_grad(): - for key in batch: - if "CODEC" in key: - codec = batch[key][sidx] # (8, T) - if codec_model_type == 'encodec': - codec_decoded_audio = codec_model.decode([[codec.unsqueeze(0), None]])[0][0] - elif codec_model_type == 'uniaudio_codec': - codec_decoded_audio = codec_model.decode(codec.unsqueeze(0))[0][0] - elif codec_model_type == 'dac': - _z = codec_model.quantizer.from_codes(codec.unsqueeze(0))[0] - codec_decoded_audio = codec_model.decoder(_z)[0][0] - elif codec_model_type in ['nemo_codec', 'nemo_codec21', 'nemo_codec211k', 'nemo_codec214k']: - codec_len = torch.Tensor([codec.shape[1]]).long().cuda() - codec_decoded_audio, _ = codec_model.decode(tokens=codec.unsqueeze(0), tokens_len=codec_len) - codec_decoded_audio = codec_decoded_audio[0] - - codec_decoded_audio_path = os.path.join(temp_dir, f"{bidx}_{sidx}_{key}_decoded.wav") - torchaudio.save(codec_decoded_audio_path, codec_decoded_audio[None].cpu(), codec_model_sample_rate) - - -def estimate_duration_from_codeclen(codec_len, codec_downsampling_factor=320.0, codec_model_sample_rate=24000): - num_audio_samples = codec_len * codec_downsampling_factor - duration = num_audio_samples / codec_model_sample_rate - return round(duration, 2) - - -def save_manifest(records, manifest_path): - with open(manifest_path, "w") as f: - file_str = "" - for record in records: - file_str += json.dumps(record) + "\n" - file_str = file_str.strip() - f.write(file_str) - print("Saved manifest to {}".format(manifest_path)) - - -def main(): - parser = argparse.ArgumentParser(description='Create multiple tasks') - parser.add_argument("--noise_manifest", type=str, 
default="/datap/misc/noisedata/train_manifest.json") - parser.add_argument( - '--manifest_paths', - type=str, - default="/Data/manifests_libri_local/train_clean_300_speechlm_ttstasks_with3sec_ref_all_random.json", - ) - parser.add_argument('--batch_size', type=int, default=16) - parser.add_argument('--out_dir', type=str, default='/Data/CodecDatasets/speechllm_codecdatasets/') - parser.add_argument('--dataset_name', type=str, default='LibriTTSCorrectContext_train') - parser.add_argument('--codec_model_path', type=str, default='/Data/Checkpoints/rlang_codec/SpeechCodec.nemo') - parser.add_argument('--codec_bw', type=float, default=6.0) # 6 for 8 codebooks, 1.5 for 3 codebooks - parser.add_argument( - '--codec_model', type=str, default='nemo_codec' - ) # encodec, uniaudio_codec, dac, nemo_codec, nemo_codec21, nemo_codec211k, nemo_codec214k - parser.add_argument('--use_context_as_same_speaker_audio', action='store_true') - parser.add_argument('--save_only_tts_records', action='store_true') - parser.add_argument('--shuffle', action='store_true') - parser.add_argument('--split_into_train_val', action='store_true') - parser.add_argument('--num_val_records', type=int, default=500) - parser.add_argument('--audio_type', type=str, default='actual') # actual, noise or silence - args = parser.parse_args() - - if args.codec_model == 'encodec': - codec_model = EncodecModel.encodec_model_24khz() - codec_model.set_target_bandwidth(6.0) - codec_model.cuda() - codec_model.eval() - codec_model_sample_rate = 24000 - codec_model_downsampling_factor = 320.0 - elif args.codec_model == 'uniaudio_codec': - codec_config_path = os.path.join(os.path.dirname(args.codec_model_path), 'config.yaml') - codec_config = OmegaConf.load(codec_config_path) - codec_model = eval(codec_config.generator.name)(**codec_config.generator.config) - codec_parameter_dict = torch.load(args.codec_model_path) - codec_model.load_state_dict(codec_parameter_dict['codec_model']) # load model - codec_model = codec_model.cuda() - # codec_model.eval() - codec_model_sample_rate = 16000 - codec_model_downsampling_factor = 320.0 - elif args.codec_model == 'dac': - model_path = args.codec_model_path - codec_model = dac.DAC.load(model_path) - codec_model.to('cuda') - codec_model_sample_rate = 44100 - codec_model_downsampling_factor = 512.0 - elif args.codec_model == 'nemo_codec': - model_path = args.codec_model_path - codec_model = AudioCodecModel.restore_from(model_path) - codec_model.to('cuda') - codec_model.eval() - codec_model_sample_rate = 22050 - codec_model_downsampling_factor = 256.0 - elif args.codec_model in ['nemo_codec21', 'nemo_codec211k', 'nemo_codec214k']: - model_path = args.codec_model_path - codec_model = AudioCodecModel.restore_from(model_path) - codec_model.to('cuda') - codec_model.eval() - codec_model_sample_rate = 22050 - codec_model_downsampling_factor = 1024.0 - else: - raise ValueError("Unknown codec model {}".format(args.codec_model)) - - dataset = AudioDataset( - manifest_paths=[args.manifest_paths], - sample_rate=codec_model_sample_rate, - noise_manifest_path=args.noise_manifest, - use_context_as_same_speaker_audio=args.use_context_as_same_speaker_audio, - pad_multiple=int(codec_model_downsampling_factor), - audio_type=args.audio_type, - ) - - dataloader = torch.utils.data.DataLoader( - dataset=dataset, - batch_size=args.batch_size, - collate_fn=dataset.pad_collate_fn, - shuffle=False, - num_workers=8, - ) - - _exp_name = "{}_{}_bw_{}".format(args.dataset_name, args.codec_model, args.codec_bw) - temp_dir = 
os.path.join(args.out_dir, "temp_{}".format(_exp_name)) - if not os.path.exists(temp_dir): - os.makedirs(temp_dir) - - codec_base_dir = os.path.join(args.out_dir, "codecs") - manifest_dir = os.path.join(args.out_dir, "manifests") - - audiocodec_out_dir = os.path.join(codec_base_dir, _exp_name) - - if not os.path.exists(audiocodec_out_dir): - os.makedirs(audiocodec_out_dir) - - if not os.path.exists(manifest_dir): - os.makedirs(manifest_dir) - - all_tasks_records = [] - phoneme_tts_records = [] - sentencepiece_tts_records = [] - phoneme_plus_sentencepiece_tts_records = [] - - for bidx, batch in enumerate(tqdm(dataloader)): - # print("bidx", bidx+1, "of", len(dataloader)) - - audio_len_mask = mask_from_lens(batch["audio_len"]) - - cuda_keys = ['audio', 'perturbed_audio', 'mixed_audio', 'audio_len', 'perturbed_audio_len', 'mixed_audio_len'] - for key in cuda_keys: - batch[key] = batch[key].cuda() - with torch.no_grad(): - if args.codec_model == 'encodec': - original_codec_codes = codec_model.encode(batch["audio"].unsqueeze(1))[0][0] - if not args.save_only_tts_records: - perturbed_codec_codes = codec_model.encode(batch["perturbed_audio"].unsqueeze(1))[0][0] - mixed_codec_codes = codec_model.encode(batch["mixed_audio"].unsqueeze(1))[0][0] - elif args.codec_model == 'uniaudio_codec': - original_codec_codes = codec_model.encode( - batch["audio"].unsqueeze(1) * codec_config.audio_norm_scale, target_bw=args.codec_bw - ).permute(1, 0, 2) - if not args.save_only_tts_records: - perturbed_codec_codes = codec_model.encode( - batch["perturbed_audio"].unsqueeze(1) * codec_config.audio_norm_scale, target_bw=args.codec_bw - ).permute(1, 0, 2) - mixed_codec_codes = codec_model.encode( - batch["mixed_audio"].unsqueeze(1) * codec_config.audio_norm_scale, target_bw=args.codec_bw - ).permute(1, 0, 2) - elif args.codec_model == 'dac': - # z, codes, latents, _, _ = model.encode(x) - _, original_codec_codes, _, _, _ = codec_model.encode(batch["audio"].unsqueeze(1)) - if not args.save_only_tts_records: - _, perturbed_codec_codes, _, _, _ = codec_model.encode(batch["perturbed_audio"].unsqueeze(1)) - _, mixed_codec_codes, _, _, _ = codec_model.encode(batch["mixed_audio"].unsqueeze(1)) - elif args.codec_model in ['nemo_codec', 'nemo_codec21', 'nemo_codec211k', 'nemo_codec214k']: - original_codec_codes, _ = codec_model.encode(audio=batch["audio"], audio_len=batch["audio_len"]) - if not args.save_only_tts_records: - perturbed_codec_codes, _ = codec_model.encode( - audio=batch["perturbed_audio"], audio_len=batch["perturbed_audio_len"] - ) - mixed_codec_codes, _ = codec_model.encode( - audio=batch["mixed_audio"], audio_len=batch["mixed_audio_len"] - ) - else: - raise ValueError("Unknown codec model {}".format(args.codec_model)) - - if args.save_only_tts_records: - perturbed_codec_codes = original_codec_codes # Dummy values to not break the code - mixed_codec_codes = original_codec_codes # Dummy values to not break the code - - # codec_codes = transformer_encodec_model.encode(batch["audio"].unsqueeze(1), audio_len_mask, bandwidth=6.0) - target_codecs = [] - mixed_codecs = [] - perturbed_codecs = [] - for sidx in range(batch['audio'].shape[0]): - - codec_len = math.ceil(batch['audio_len'][sidx].item() / codec_model_downsampling_factor) - sample_codec_codes = original_codec_codes[sidx][:, :codec_len] - target_codecs.append(sample_codec_codes) - - perturbed_codec_len = math.ceil( - batch['perturbed_audio_len'][sidx].item() / codec_model_downsampling_factor - ) - perturbed_sample_codec_codes = perturbed_codec_codes[sidx][:, 
:perturbed_codec_len] - perturbed_codecs.append(perturbed_sample_codec_codes) - - mixed_codec_len = math.ceil(batch['mixed_audio_len'][sidx].item() / codec_model_downsampling_factor) - mixed_sample_codec_codes = mixed_codec_codes[sidx][:, :mixed_codec_len] - mixed_codecs.append(mixed_sample_codec_codes) - - example_name = batch['rel_audio_path_as_text_id'][sidx] - - target_codec_filepath = os.path.join(audiocodec_out_dir, "target_codes_{}.pt".format(example_name)) - torch.save(sample_codec_codes.cpu().type(torch.int16), target_codec_filepath) - - if batch['text'][sidx] == "": - # Only save target codes for dummy records - # Don't need to add dummy records to manifest - continue - - perturbed_codec_filepath = os.path.join(audiocodec_out_dir, "perturbed_codes_{}.pt".format(example_name)) - mixed_codec_filepath = os.path.join(audiocodec_out_dir, "mixed_codes_{}.pt".format(example_name)) - if not args.save_only_tts_records: - torch.save(perturbed_sample_codec_codes.cpu().type(torch.int16), perturbed_codec_filepath) - torch.save(mixed_sample_codec_codes.cpu().type(torch.int16), mixed_codec_filepath) - - tts_contextpath = "" - for samespeaker_audioid in batch['samespeaker_audioids'][sidx]: - tts_contextpath += os.path.join(audiocodec_out_dir, "target_codes_{}.pt".format(samespeaker_audioid)) - tts_contextpath += ";" - tts_contextpath = tts_contextpath[:-1] - - tts_record = { - "audio_filepath": batch['audio_filepath'][sidx], - "text": batch['text'][sidx], - "question": batch['question'][sidx].replace("Phoneme TTS", "Text to speech this"), - "answer": target_codec_filepath, - "context": tts_contextpath, - "question_type": "TEXT", - "answer_type": "AUDIOCODEC", - "context_type": "REFSPEAKERCODEC", - "context_duration": batch['context_duration'][sidx], - "answer_duration": batch['duration'][sidx], - "taskname": "squad", - "speaker": ( - batch['speaker'][sidx].item() - if torch.is_tensor(batch['speaker'][sidx]) - else batch['speaker'][sidx] - ), - } - - phoneme_tts_record = {key: value for key, value in tts_record.items()} - phoneme_tts_record["question"] = phoneme_tts_record["question"].replace( - "Text to speech this", "Phoneme TTS" - ) - - speechenhancement_record = { - "audio_filepath": batch['audio_filepath'][sidx], - "text": batch['text'][sidx], - "question": "Remove Noise", - "answer": target_codec_filepath, - "context": perturbed_codec_filepath, - "question_type": "TEXT", - "answer_type": "AUDIOCODEC", - "context_type": "AUDIOCODEC", - "context_duration": estimate_duration_from_codeclen( - perturbed_codec_len, codec_model_downsampling_factor, codec_model_sample_rate - ), - "answer_duration": batch['duration'][sidx], - "taskname": "squad", - } - - speechseparation_record = { - "audio_filepath": batch['audio_filepath'][sidx], - "text": batch['text'][sidx], - "question": "Extract Speaker Audio", - "answer": target_codec_filepath, - "context": "{},{}".format(mixed_codec_filepath, tts_contextpath), - "question_type": "TEXT", - "answer_type": "AUDIOCODEC", - "context_type": "SEPARATIONCODECS", - "context_duration": estimate_duration_from_codeclen( - mixed_codec_len, codec_model_downsampling_factor, codec_model_sample_rate - ), - "answer_duration": batch['duration'][sidx], - "taskname": "squad", - } - - speechediting_record = { - "audio_filepath": batch['audio_filepath'][sidx], - "text": batch['text'][sidx], - "question": batch['question'][sidx].replace("Text to speech this", "Edit Speech"), - "answer": target_codec_filepath, - "context": target_codec_filepath, - "question_type": "TEXT", - 
"answer_type": "AUDIOCODEC", - "context_type": "EDITINGCODECS", - "context_duration": batch['duration'][sidx] + 3, # 3 sec for speaker context - "answer_duration": batch['duration'][sidx], - "taskname": "squad", - } - - phoneme_tts_records.append(phoneme_tts_record) - sentencepiece_tts_records.append(tts_record) - - phoneme_plus_sentencepiece_tts_records.append(phoneme_tts_record) - phoneme_plus_sentencepiece_tts_records.append(tts_record) - - all_tasks_records.append(tts_record) - all_tasks_records.append(phoneme_tts_record) - all_tasks_records.append(speechenhancement_record) - all_tasks_records.append(speechseparation_record) - all_tasks_records.append(speechediting_record) - - batch['target_CODEC'] = target_codecs - batch['perturbed_CODEC'] = perturbed_codecs - batch['mixed_CODEC'] = mixed_codecs - - if bidx == 0: - save_batch_audios(batch, bidx, temp_dir, codec_model, args.codec_model, codec_model_sample_rate) - - if args.shuffle: - # To ensure same split for encodec and uniaudio_codec - random.seed(21) - random.shuffle(all_tasks_records) - random.shuffle(phoneme_tts_records) - random.shuffle(sentencepiece_tts_records) - random.shuffle(phoneme_plus_sentencepiece_tts_records) - - if args.split_into_train_val: - # Shuffle compulsory for splitting into train and val - # To ensure same split for encodec and uniaudio_codec - random.seed(21) - random.shuffle(all_tasks_records) - random.shuffle(phoneme_tts_records) - random.shuffle(sentencepiece_tts_records) - # random.shuffle(phoneme_plus_sentencepiece_tts_records) - phoneme_plus_sentencepiece_tts_records = [] - for idx in range(len(phoneme_tts_records)): - phoneme_plus_sentencepiece_tts_records.append(phoneme_tts_records[idx]) - phoneme_plus_sentencepiece_tts_records.append(sentencepiece_tts_records[idx]) - - num_val_records = args.num_val_records - train_phoneme_tts_records = phoneme_tts_records[num_val_records:] - val_phoneme_tts_records = phoneme_tts_records[:num_val_records] - - train_sentencepiece_tts_records = sentencepiece_tts_records[num_val_records:] - val_sentencepiece_tts_records = sentencepiece_tts_records[:num_val_records] - - train_phoneme_plus_sentencepiece_tts_records = phoneme_plus_sentencepiece_tts_records[num_val_records:] - val_phoneme_plus_sentencepiece_tts_records = phoneme_plus_sentencepiece_tts_records[:num_val_records] - # Shuffle train mixed records - random.shuffle(train_phoneme_plus_sentencepiece_tts_records) - - train_all_tasks_records = all_tasks_records[num_val_records:] - val_all_tasks_records = all_tasks_records[:num_val_records] - - manifest_base_name = _exp_name - phoneme_tts_train_manifest_path = os.path.join( - manifest_dir, "{}_train_phoneme_tts.json".format(manifest_base_name) - ) - phoneme_tts_val_manifest_path = os.path.join( - manifest_dir, "{}_val_phoneme_tts.json".format(manifest_base_name) - ) - save_manifest(train_phoneme_tts_records, phoneme_tts_train_manifest_path) - save_manifest(val_phoneme_tts_records, phoneme_tts_val_manifest_path) - - sentencepiece_tts_train_manifest_path = os.path.join( - manifest_dir, "{}_train_sentencepiece_tts.json".format(manifest_base_name) - ) - sentencepiece_tts_val_manifest_path = os.path.join( - manifest_dir, "{}_val_sentencepiece_tts.json".format(manifest_base_name) - ) - save_manifest(train_sentencepiece_tts_records, sentencepiece_tts_train_manifest_path) - save_manifest(val_sentencepiece_tts_records, sentencepiece_tts_val_manifest_path) - - sp_plus_phoneme_tts_train_manifest_path = os.path.join( - manifest_dir, 
"{}_train_phoneme_plus_sentencepiece_tts.json".format(manifest_base_name) - ) - sp_plus_phoneme_tts_val_manifest_path = os.path.join( - manifest_dir, "{}_val_phoneme_plus_sentencepiece_tts.json".format(manifest_base_name) - ) - save_manifest(train_phoneme_plus_sentencepiece_tts_records, sp_plus_phoneme_tts_train_manifest_path) - save_manifest(val_phoneme_plus_sentencepiece_tts_records, sp_plus_phoneme_tts_val_manifest_path) - - if not args.save_only_tts_records: - all_tasks_train_manifest_path = os.path.join( - manifest_dir, "{}_train_all_tasks.json".format(args.dataset_name) - ) - all_tasks_val_manifest_path = os.path.join(manifest_dir, "{}_val_all_tasks.json".format(args.dataset_name)) - save_manifest(train_all_tasks_records, all_tasks_train_manifest_path) - save_manifest(val_all_tasks_records, all_tasks_val_manifest_path) - else: - manifest_base_name = _exp_name - phoneme_tts_manifest_path = os.path.join(manifest_dir, "{}_phoneme_tts.json".format(manifest_base_name)) - save_manifest(phoneme_tts_records, phoneme_tts_manifest_path) - - sentencepiece_tts_manifest_path = os.path.join( - manifest_dir, "{}_sentencepiece_tts.json".format(manifest_base_name) - ) - save_manifest(sentencepiece_tts_records, sentencepiece_tts_manifest_path) - - phoneme_plus_sentencepiece_tts_manifest_path = os.path.join( - manifest_dir, "{}_phoneme_plus_sentencepiece_tts.json".format(manifest_base_name) - ) - save_manifest(phoneme_plus_sentencepiece_tts_records, phoneme_plus_sentencepiece_tts_manifest_path) - - if not args.save_only_tts_records: - all_manifest_path = os.path.join(manifest_dir, "{}_all_tasks.json".format(args.dataset_name)) - save_manifest(all_tasks_records, all_manifest_path) - - -if __name__ == '__main__': - main() From 20f707fedb2366703a5c7145c8e4502c4dcda9fa Mon Sep 17 00:00:00 2001 From: Jason Date: Wed, 6 Nov 2024 12:40:20 -0800 Subject: [PATCH 04/18] undo some more changes Signed-off-by: Jason --- .../tts/modules/audio_codec_modules.py | 785 +----------------- 1 file changed, 2 insertions(+), 783 deletions(-) diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py index 88c7204070cd..057d9f49546d 100644 --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -19,7 +19,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torchaudio from einops import rearrange from transformers import AutoModel @@ -204,7 +203,6 @@ def __init__( stride: int = 1, dilation: int = 1, padding: Optional[int] = None, - activation: Optional[str] = None, ): super().__init__() if not padding: @@ -219,10 +217,6 @@ def __init__( padding_mode="reflect", ) self.conv = nn.utils.weight_norm(conv) - if activation is not None: - self.activation = CodecActivation(activation=activation, channels=out_channels) - else: - self.activation = None @property def input_types(self): @@ -243,8 +237,6 @@ def remove_weight_norm(self): @typecheck() def forward(self, inputs, input_len): out = self.conv(inputs) - if self.activation is not None: - out = self.activation(out) out = mask_sequence_tensor(out, input_len) return out @@ -442,259 +434,6 @@ def forward(self, audio_real, audio_gen): return scores_real, scores_gen, fmaps_real, fmaps_gen -class SSLModel(NeuralModule): - def __init__(self, slm_model_name): - super().__init__() - self.ssl_model = AutoModel.from_pretrained(slm_model_name) - - def forward(self, *args, **kwargs): - return self.ssl_model(*args, **kwargs) - - -class 
SLMDiscriminator(NeuralModule): - """SLM Discriminator as in StyleTTS2 paper. - Adapted from https://github.com/yl4579/StyleTTS2/blob/5cedc71c333f8d8b8551ca59378bdcc7af4c9529/losses.py#L193""" - - def __init__( - self, - slm_model_name="microsoft/wavlm-base-plus", - slm_sr=16000, - input_sr=22050, - slm_hidden=768, - slm_layers=13, - initial_channel=64, - use_spectral_norm=False, - lrelu_slope=0.1, - ): - super().__init__() - - self.lrelu_slope = lrelu_slope - - # define slm model - self.slm_model = SSLModel(slm_model_name) - self.slm_model.ssl_model.feature_extractor._requires_grad = False - - # Freeze slm model - self.slm_model.freeze() - - self.resample = torchaudio.transforms.Resample(input_sr, slm_sr) - - norm_f = nn.utils.weight_norm if use_spectral_norm == False else nn.utils.spectral_norm - self.pre = norm_f(nn.Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)) - - self.convs = nn.ModuleList( - [ - norm_f(nn.Conv1d(initial_channel, initial_channel * 2, kernel_size=5, padding=2)), - norm_f(nn.Conv1d(initial_channel * 2, initial_channel * 4, kernel_size=5, padding=2)), - norm_f(nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)), - ] - ) - - self.conv_post = norm_f(nn.Conv1d(initial_channel * 4, 1, 3, 1, padding=1)) - - def _forward(self, x): - x = self.slm_model(input_values=self.resample(x), output_hidden_states=True).hidden_states - x = torch.stack(x, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2) - - x = self.pre(x) - fmap = [] - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, self.lrelu_slope) - fmap.append(x.unsqueeze(-1)) - - x = self.conv_post(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - def forward(self, audio_real, audio_gen): - - y_d_r, fmap_r = self._forward(audio_real) - y_d_g, fmap_g = self._forward(audio_gen) - - return [y_d_r.unsqueeze(1)], [y_d_g.unsqueeze(1)], [fmap_r], [fmap_g] - - -class DiscriminatorSTFT(NeuralModule): - """ - Discriminator network from EnCodec for Complex STFT input, but without dilations. - - Args: - filters: number of filters to use in Conv2d layers - lrelu_slope: Slope to use for activations. 
Leaky relu with slope of 0.1 or 0.2 is recommended for the - stability of the feature matching loss - """ - - def __init__(self, filters: int = 32, lrelu_slope: float = 0.1): - super().__init__() - - self.activation = nn.LeakyReLU(lrelu_slope) - self.conv_layers = nn.ModuleList( - [ - Conv2dNorm(2, filters, kernel_size=(3, 9)), - Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), - Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), - Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)), - Conv2dNorm(filters, filters, kernel_size=(3, 3)), - ] - ) - self.conv_post = Conv2dNorm(filters, 1, kernel_size=(3, 3)) - - @property - def input_types(self): - return { - "spec": NeuralType(('B', 'C', 'T_spec', 'D'), VoidType()), - } - - @property - def output_types(self): - return { - "scores": NeuralType(('B', 'C', 'T_spec'), VoidType()), - "fmap": [NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())], - } - - @typecheck() - def forward(self, spec): - fmap = [] - - # [batch, 2, T_spec, fft] - out = spec - for conv in self.conv_layers: - # [batch, filters, T_spec, fft // strides] - out = conv(inputs=out) - out = self.activation(out) - fmap.append(out) - # [batch, 1, T_spec, fft // 8] - scores = self.conv_post(inputs=out) - fmap.append(scores) - scores = rearrange(scores, "B 1 T C -> B C T") - - return scores, fmap - - -class MultiBandDiscriminatorSTFT(NeuralModule): - """ - Multi-band STFT discriminator proposed in DAC (https://arxiv.org/abs/2306.06546). - - Computes the complex STFT for a given resolution and splits it into sub-bands, - which are given to separate discriminator networks. - - Args: - resolution: STFT resolution, provided as a tuple of 3 integers ordered (num_fft, hop_length, window_length) - stft_bands: List of tuples, with each tuple having 2 float values (band_start, band_end). - The floats are in the range [0, 1] representing the fraction of all stft bands. - For example for n_fft=1024, the stft output has 513 dimensions. - For band input [(0, 0.25), (0.25, 1.0)] it would use stft dimensions [0 through 127] and [128 through 512]. 
- """ - - def __init__(self, resolution: Tuple[int], stft_bands: Iterable[Tuple[int]]): - super().__init__() - - self.n_fft, self.hop_length, self.win_length = resolution - self.register_buffer("window", torch.hann_window(self.win_length, periodic=False)) - self.discriminators = nn.ModuleList([DiscriminatorSTFT() for _ in stft_bands]) - n_stft = self.n_fft // 2 + 1 - self.stft_bands = [(int(band[0] * n_stft), int(band[1] * n_stft)) for band in stft_bands] - - def compute_stft(self, audio): - # [B, fft, T_spec] - fft = torch.stft( - audio, - n_fft=self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=self.window, - normalized=True, - center=True, - return_complex=True, - ) - fft = rearrange(fft, "B fft T -> B T fft") - # [batch, 2, T_spec, fft] - out = torch.stack([fft.real, fft.imag], dim=1) - return out - - @property - def input_types(self): - return { - "audio": NeuralType(('B', 'T_audio'), AudioSignal()), - } - - @property - def output_types(self): - return { - "scores_list": [NeuralType(('B', 'C', 'T_spec'), VoidType())], - "fmaps_list": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], - } - - @typecheck() - def forward(self, audio): - scores_list = [] - fmap_list = [] - spec = self.compute_stft(audio) - for band, disc in zip(self.stft_bands, self.discriminators): - spec_band = spec[:, :, :, band[0] : band[1]] - score, fmap = disc(spec=spec_band) - scores_list.append(score) - fmap_list.append(fmap) - - return scores_list, fmap_list - - -class MultiResolutionDiscriminatorSTFT(NeuralModule): - """ - Multi-resolution discriminator which creates a multi-band discriminator for each input resolution. - - Args: - resolutions: List of STFT resolutions, each resolution provided as a tuple of 3 integers ordered - (num_fft, hop_length, window_length) - stft_bands: List of tuples, with each tuple having 2 float values (band_start, band_end). - The floats are in the range [0, 1] representing the fraction of all stft bands. - For example for n_fft=1024, the stft output has 513 dimensions. - For band input [(0, 0.25), (0.25, 1.0)] it would use stft dimensions [0 through 127] and [128 through 512]. - """ - - def __init__(self, resolutions: Iterable[Tuple[int]], stft_bands: Iterable[Tuple[int]]): - super().__init__() - self.discriminators = nn.ModuleList( - [MultiBandDiscriminatorSTFT(resolution=resolution, stft_bands=stft_bands) for resolution in resolutions] - ) - - @property - def input_types(self): - return { - "audio_real": NeuralType(('B', 'T_audio'), AudioSignal()), - "audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()), - } - - @property - def output_types(self): - return { - "scores_real": [NeuralType(('B', 'C', 'T_spec'), VoidType())], - "scores_gen": [NeuralType(('B', 'C', 'T_spec'), VoidType())], - "fmaps_real": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], - "fmaps_gen": [[NeuralType(('B', 'D', 'T_spec', 'C'), VoidType())]], - } - - @typecheck() - def forward(self, audio_real, audio_gen): - scores_real = [] - scores_gen = [] - fmaps_real = [] - fmaps_gen = [] - - for disc in self.discriminators: - score_real_i, fmap_real_i = disc(audio=audio_real) - scores_real = scores_real + score_real_i - fmaps_real = fmaps_real + fmap_real_i - - score_gen_i, fmap_gen_i = disc(audio=audio_gen) - scores_gen = scores_gen + score_gen_i - fmaps_gen = fmaps_gen + fmap_gen_i - - return scores_real, scores_gen, fmaps_real, fmaps_gen - - class DiscriminatorSTFT(NeuralModule): """ Discriminator network from EnCodec for Complex STFT input, but without dilations. 
@@ -1325,127 +1064,9 @@ def forward(self, inputs, input_len): return out -class ResidualBlockV2(NeuralModule): +class HiFiGANResBlock(NeuralModule): """ - - Args: - channels: Input dimension. - filters: Number of channels in the residual convolutions. - kernel_size: Kernel size of the residual convolutions. - dilation: Dilation of the residual convolutions. - dropout_rate: Dropout to apply to residuals. - activation: Activation to apply in between residual convolutions. - """ - - def __init__( - self, - channels: int, - filters: int, - kernel_size: int = 3, - activation: str = "lrelu", - ): - super(ResidualBlockV2, self).__init__() - - self.input_conv = Conv1dNorm( - in_channels=channels, out_channels=filters, kernel_size=kernel_size, activation=activation - ) - self.skip_conv = Conv1dNorm(in_channels=filters, out_channels=channels, kernel_size=kernel_size) - self.output_activation = CodecActivation(activation=activation, channels=channels) - - def remove_weight_norm(self): - self.input_conv.remove_weight_norm() - self.skip_conv.remove_weight_norm() - - @property - def input_types(self): - return {"inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType())} - - @property - def output_types(self): - return {"out": NeuralType(('B', 'C', 'T'), EncodedRepresentation())} - - @typecheck() - def forward(self, inputs, input_len): - res = self.input_conv(inputs=inputs, input_len=input_len) - res = self.skip_conv(inputs=res, input_len=input_len) - out = inputs + res - out = self.output_activation(out) - return out - - -class ResidualBlockV3(NeuralModule): - """ - The residual block structure defined by the HiFi-GAN V1 and V2 configurations. - - Args: - channels: Input dimension. - filters: Number of channels in the residual convolutions. - kernel_size: Kernel size of the residual convolutions. - dilation: Dilation of the residual convolutions. - dropout_rate: Dropout to apply to residuals. - activation: Activation to apply in between residual convolutions. 
- """ - - def __init__( - self, - channels: int, - filters: int, - down_sample_rate: int, - kernel_size: int = 3, - activation: str = "lrelu", - ): - super(ResidualBlockV3, self).__init__() - - if down_sample_rate > 1: - self.down_sample_rate = down_sample_rate - self.down_sample_conv = Conv1dNorm( - in_channels=channels, out_channels=filters, kernel_size=kernel_size, stride=self.down_sample_rate - ) - self.down_sample_activation = CodecActivation(activation=activation, channels=filters) - channels = filters - else: - self.down_sample_rate = None - self.down_sample_conv = None - self.down_sample_activation = None - - self.input_conv = Conv1dNorm(in_channels=channels, out_channels=filters, kernel_size=kernel_size) - self.skip_activation = CodecActivation(activation=activation, channels=filters) - self.skip_conv = Conv1dNorm(in_channels=filters, out_channels=channels, kernel_size=kernel_size) - self.output_activation = CodecActivation(activation=activation, channels=channels) - - def remove_weight_norm(self): - self.input_conv.remove_weight_norm() - self.skip_conv.remove_weight_norm() - - @property - def input_types(self): - return {"inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType())} - - @property - def output_types(self): - return { - "out": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), - "out_len": NeuralType(tuple('B'), LengthsType()), - } - - @typecheck() - def forward(self, inputs, input_len): - if self.down_sample_rate is not None: - inputs = self.down_sample_conv(inputs=inputs, input_len=input_len) - inputs = self.down_sample_activation(inputs) - input_len = input_len // self.down_sample_rate - - skip_input = self.input_conv(inputs=inputs, input_len=input_len) - skip_input = self.skip_activation(skip_input) - res = self.skip_conv(inputs=skip_input, input_len=input_len) - out = inputs + res - out = self.output_activation(out) - return out, input_len - - -class HiFiGANResBlock(NeuralModule): - """ - Residual block wrapper for HiFi-GAN which creates a block for multiple dilations. + Residual block wrapper for HiFi-GAN which creates a block for multiple dilations. Args: channels: Input dimension. 
@@ -1821,51 +1442,6 @@ def forward(self, audio, audio_len): return spec, spec_len -class STFTProcessor(NeuralModule): - def __init__(self, n_fft, win_length, hop_length, log_guard=1.0): - super().__init__() - - self.n_fft = n_fft - self.win_length = win_length - self.hop_length = hop_length - self.register_buffer("window", torch.hann_window(self.win_length, periodic=False)) - self.log_guard = log_guard - self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 - - @property - def input_types(self): - return { - "audio": NeuralType(('B', 'T_audio'), AudioSignal()), - "audio_len": NeuralType(tuple('B'), LengthsType()), - } - - @property - def output_types(self): - return { - "spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()), - "spec_len": NeuralType(tuple('B'), LengthsType()), - } - - @typecheck() - def forward(self, audio, audio_len): - spec_len = audio_len // self.hop_length - audio_padded = torch.nn.functional.pad(audio, (self.stft_pad_amount, self.stft_pad_amount), "reflect") - # [B, n_fft, T_spec] - fft = torch.stft( - audio_padded, - n_fft=self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=self.window, - return_complex=True, - center=False, - ) - fft_mag = torch.abs(fft) - fft_mag_log = torch.log(fft_mag + self.log_guard) - fft_mag_log = mask_sequence_tensor(fft_mag_log, spec_len) - return fft_mag_log, spec_len - - class ResNetEncoder(NeuralModule): """ Residual network which uses HiFi-GAN residual blocks to encode spectrogram features without changing @@ -1938,136 +1514,6 @@ def forward(self, inputs, input_len): return encoded -class ResNetEncoderV2(NeuralModule): - def __init__( - self, in_channels, out_channels, num_layers, hidden_channels, filters, kernel_size=3, activation="lrelu" - ): - super(ResNetEncoderV2, self).__init__() - - self.pre_conv = Conv1dNorm(in_channels=in_channels, out_channels=hidden_channels, kernel_size=kernel_size) - self.pre_act = CodecActivation(activation, channels=hidden_channels) - self.res_blocks = nn.ModuleList( - [ - ResidualBlockV2( - channels=hidden_channels, filters=filters, kernel_size=kernel_size, activation=activation - ) - for _ in range(num_layers) - ] - ) - self.post_conv = Conv1dNorm(in_channels=hidden_channels, out_channels=out_channels, kernel_size=kernel_size) - - def remove_weight_norm(self): - self.pre_conv.remove_weight_norm() - self.post_conv.remove_weight_norm() - for res_layer in self.res_layers: - res_layer.remove_weight_norm() - - @property - def input_types(self): - return { - "inputs": NeuralType(('B', 'D', 'T'), VoidType()), - "input_len": NeuralType(tuple('B'), LengthsType()), - } - - @property - def output_types(self): - return { - "encoded": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), - "encoded_len": NeuralType(tuple('B'), LengthsType()), - } - - @typecheck() - def forward(self, inputs, input_len): - encoded = self.pre_conv(inputs=inputs, input_len=input_len) - encoded = self.pre_act(encoded) - for res_block in self.res_blocks: - encoded = res_block(inputs=encoded, input_len=input_len) - encoded = self.post_conv(inputs=encoded, input_len=input_len) - return encoded, input_len - - -class ResNetEncoderV3(NeuralModule): - def __init__(self, in_channels, out_channels, filter_list, stride_list, kernel_size=3, activation="lrelu"): - super(ResNetEncoderV3, self).__init__() - - input_dim = filter_list[0] - self.pre_conv = Conv1dNorm(in_channels=in_channels, out_channels=input_dim, kernel_size=kernel_size) - self.pre_act = CodecActivation(activation, channels=input_dim) - 
self.res_blocks = nn.ModuleList([]) - for filters, stride in zip(filter_list, stride_list): - res_block = ResidualBlockV3( - channels=input_dim, - filters=filters, - down_sample_rate=stride, - kernel_size=kernel_size, - activation=activation, - ) - self.res_blocks.append(res_block) - input_dim = filters - - self.post_conv = Conv1dNorm(in_channels=input_dim, out_channels=out_channels, kernel_size=kernel_size) - - def remove_weight_norm(self): - self.pre_conv.remove_weight_norm() - self.post_conv.remove_weight_norm() - for res_layer in self.res_layers: - res_layer.remove_weight_norm() - - @property - def input_types(self): - return { - "inputs": NeuralType(('B', 'D', 'T'), VoidType()), - "input_len": NeuralType(tuple('B'), LengthsType()), - } - - @property - def output_types(self): - return { - "encoded": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), - "encoded_len": NeuralType(tuple('B'), LengthsType()), - } - - @typecheck() - def forward(self, inputs, input_len): - encoded = self.pre_conv(inputs=inputs, input_len=input_len) - encoded = self.pre_act(encoded) - encoded_len = input_len - for res_block in self.res_blocks: - encoded, encoded_len = res_block(inputs=encoded, input_len=encoded_len) - encoded = self.post_conv(inputs=encoded, input_len=encoded_len) - return encoded, encoded_len - - -class SpectrogramEncoder(NeuralModule): - def __init__(self, spec_processor, encoder): - super(SpectrogramEncoder, self).__init__() - self.spec_processor = spec_processor - self.encoder = encoder - - def remove_weight_norm(self): - self.encoder.remove_weight_norm() - - @property - def input_types(self): - return { - "audio": NeuralType(('B', 'T_audio'), AudioSignal()), - "audio_len": NeuralType(tuple('B'), LengthsType()), - } - - @property - def output_types(self): - return { - "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), - "encoded_len": NeuralType(tuple('B'), LengthsType()), - } - - @typecheck() - def forward(self, audio, audio_len): - spec, spec_len = self.spec_processor(audio=audio, audio_len=audio_len) - encoded, encoded_len = self.encoder(inputs=spec, input_len=spec_len) - return encoded, encoded_len - - class FullBandMelEncoder(NeuralModule): """ Encoder which encodes the entire mel spectrogram with a single encoder network. @@ -2171,230 +1617,3 @@ def forward(self, audio, audio_len): # [B, C, T] encoded = torch.cat(outputs, dim=1) return encoded, spec_len - - -class DownSampleResidualBlock(NeuralModule): - """ - The residual block structure defined by the HiFi-GAN V1 and V2 configurations. - - Args: - channels: Input dimension. - filters: Number of channels in the residual convolutions. - kernel_size: Kernel size of the residual convolutions. - dilation: Dilation of the residual convolutions. - dropout_rate: Dropout to apply to residuals. - activation: Activation to apply in between residual convolutions. 
- """ - - def __init__( - self, - channels: int, - filters: int, - kernel_size: int, - down_sample_rate: int, - down_sample_kernel_size: int, - activation: str = "lrelu", - ): - super(DownSampleResidualBlock, self).__init__() - - self.down_sample_rate = down_sample_rate - self.down_sample_conv = Conv1dNorm( - in_channels=channels, - out_channels=filters, - kernel_size=down_sample_kernel_size, - stride=self.down_sample_rate, - activation=activation, - ) - self.res_block = ResidualBlockV2( - channels=filters, filters=filters, kernel_size=kernel_size, activation=activation - ) - - def remove_weight_norm(self): - self.input_conv.remove_weight_norm() - self.skip_conv.remove_weight_norm() - - @property - def input_types(self): - return {"inputs": NeuralType(('B', 'C', 'T'), VoidType()), "input_len": NeuralType(tuple('B'), LengthsType())} - - @property - def output_types(self): - return { - "out": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), - "out_len": NeuralType(tuple('B'), LengthsType()), - } - - @typecheck() - def forward(self, inputs, input_len): - output_len = input_len // self.down_sample_rate - out = self.down_sample_conv(inputs=inputs, input_len=output_len) - out = self.res_block(inputs=out, input_len=output_len) - return out, output_len - - -class STFTResidualBlock(NeuralModule): - """ - The residual block structure defined by the HiFi-GAN V1 and V2 configurations. - - Args: - channels: Input dimension. - filters: Number of channels in the residual convolutions. - kernel_size: Kernel size of the residual convolutions. - dilation: Dilation of the residual convolutions. - dropout_rate: Dropout to apply to residuals. - activation: Activation to apply in between residual convolutions. - """ - - def __init__( - self, - resolution, - input_dim, - filters, - kernel_size, - down_sample_rate, - down_sample_kernel_size, - activation, - ): - super(STFTResidualBlock, self).__init__() - - self.down_sample_rate = down_sample_rate - self.down_sample_conv = Conv1dNorm( - in_channels=input_dim, - out_channels=filters, - kernel_size=down_sample_kernel_size, - stride=self.down_sample_rate, - activation=activation, - ) - - n_fft, hop_length, win_length = resolution - stft_dim = n_fft // 2 + 1 - self.spec_processor = STFTProcessor(n_fft=n_fft, win_length=win_length, hop_length=hop_length) - self.spec_conv = Conv1dNorm(in_channels=stft_dim, out_channels=filters, kernel_size=kernel_size) - self.spec_act = CodecActivation(activation=activation, channels=filters) - - self.res_block = ResidualBlockV2( - channels=filters, filters=filters, kernel_size=kernel_size, activation=activation - ) - - def remove_weight_norm(self): - self.input_conv.remove_weight_norm() - self.skip_conv.remove_weight_norm() - - @property - def input_types(self): - return { - "inputs": NeuralType(('B', 'C', 'T'), VoidType()), - "input_len": NeuralType(tuple('B'), LengthsType()), - "audio": NeuralType(('B', 'T_audio'), AudioSignal()), - "audio_len": NeuralType(tuple('B'), LengthsType()), - } - - @property - def output_types(self): - return { - "out": NeuralType(('B', 'C', 'T'), EncodedRepresentation()), - "out_len": NeuralType(tuple('B'), LengthsType()), - } - - @typecheck() - def forward(self, inputs, input_len, audio, audio_len): - out_len = input_len // self.down_sample_rate - out = self.down_sample_conv(inputs=inputs, input_len=out_len) - - spec, _ = self.spec_processor(audio=audio, audio_len=audio_len) - spec_res = self.spec_conv(inputs=spec, input_len=out_len) - out = out + spec_res - out = self.spec_act(out) - - out = 
self.res_block(inputs=out, input_len=out_len) - return out, out_len - - -class MultiResolutionSTFTEncoder(NeuralModule): - def __init__( - self, - resolutions, - filter_list, - down_sample_filter_list, - out_dim, - kernel_size=3, - down_sample_kernel_size=5, - activation="lrelu", - ): - super(MultiResolutionSTFTEncoder, self).__init__() - assert len(resolutions) == len(filter_list) - - n_fft, hop_length, win_length = resolutions[0] - input_filters = filter_list[0] - input_dim = n_fft // 2 + 1 - self.pre_spec_processor = STFTProcessor(n_fft=n_fft, win_length=win_length, hop_length=hop_length) - self.pre_conv = Conv1dNorm( - in_channels=input_dim, out_channels=input_filters, kernel_size=kernel_size, activation=activation - ) - self.pre_res_block = ResidualBlockV2( - channels=input_filters, filters=input_filters, kernel_size=kernel_size, activation=activation - ) - input_dim = input_filters - self.stft_res_blocks = nn.ModuleList([]) - for resolution, filters in zip(resolutions[1:], filter_list[1:]): - stft_res_block = STFTResidualBlock( - resolution=resolution, - input_dim=input_dim, - down_sample_rate=2, - filters=filters, - kernel_size=kernel_size, - down_sample_kernel_size=down_sample_kernel_size, - activation=activation, - ) - self.stft_res_blocks.append(stft_res_block) - input_dim = filters - - self.down_sample_res_blocks = nn.ModuleList([]) - for filters in down_sample_filter_list: - down_sample_res_block = DownSampleResidualBlock( - channels=input_dim, - filters=input_dim, - down_sample_rate=2, - kernel_size=kernel_size, - down_sample_kernel_size=down_sample_kernel_size, - activation=activation, - ) - self.down_sample_res_blocks.append(down_sample_res_block) - input_dim = filters - - self.post_conv = Conv1dNorm(in_channels=input_dim, out_channels=out_dim, kernel_size=kernel_size) - - def remove_weight_norm(self): - self.encoder.remove_weight_norm() - - @property - def input_types(self): - return { - "audio": NeuralType(('B', 'T_audio'), AudioSignal()), - "audio_len": NeuralType(tuple('B'), LengthsType()), - } - - @property - def output_types(self): - return { - "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()), - "encoded_len": NeuralType(tuple('B'), LengthsType()), - } - - @typecheck() - def forward(self, audio, audio_len): - encoded, encoded_len = self.pre_spec_processor(audio=audio, audio_len=audio_len) - encoded = self.pre_conv(inputs=encoded, input_len=encoded_len) - encoded = self.pre_res_block(inputs=encoded, input_len=encoded_len) - - for stft_res_block in self.stft_res_blocks: - encoded, encoded_len = stft_res_block( - inputs=encoded, input_len=encoded_len, audio=audio, audio_len=audio_len - ) - - for down_sample_res_block in self.down_sample_res_blocks: - encoded, encoded_len = down_sample_res_block(inputs=encoded, input_len=encoded_len) - - encoded = self.post_conv(inputs=encoded, input_len=encoded_len) - - return encoded, encoded_len From 19650b5ecccbc0cce0a990d7b57ba3325964495e Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 7 Nov 2024 07:42:56 -0800 Subject: [PATCH 05/18] fix some attention errors Signed-off-by: Jason --- .../nlp/modules/common/megatron/attention.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py index 46da533186c1..5932936a2782 100644 --- a/nemo/collections/nlp/modules/common/megatron/attention.py +++ b/nemo/collections/nlp/modules/common/megatron/attention.py @@ -389,7 +389,6 @@ def forward( # 
Pre-allocate memory for key-values for inference. # ================================================= if set_inference_key_value_memory: - logging.debug(f"Initializing KV Cache.") assert inference_max_sequence_len and inference_max_sequence_len > 0 self.inference_key_memory = self._allocate_memory( inference_max_sequence_len, hidden_states.size(1), hidden_states.dtype, hidden_states.device @@ -417,7 +416,6 @@ def forward( # ===================== if self.attention_type == AttnType.self_attn: - logging.debug(f"Start Self-Attention!") # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer, _ = self.query_key_value(hidden_states) if self.is_adapter_available(): @@ -453,12 +451,6 @@ def forward( if lora_kv_adapter and self.adapter_cfg[AdapterName.LORA_KV_ADAPTER]['enabled']: lora_mixed_kv_layer = lora_kv_adapter(encoder_output) mixed_kv_layer = mixed_kv_layer + lora_mixed_kv_layer - mixed_kv_layer, _ = self.key_value(encoder_output) - if self.is_adapter_available(): - lora_kv_adapter = self.get_adapter_module(AdapterName.LORA_KV_ADAPTER) - if lora_kv_adapter and self.adapter_cfg[AdapterName.LORA_KV_ADAPTER]['enabled']: - lora_mixed_kv_layer = lora_kv_adapter(encoder_output) - mixed_kv_layer = mixed_kv_layer + lora_mixed_kv_layer # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] new_tensor_shape = mixed_kv_layer.size()[:-1] + ( @@ -517,9 +509,6 @@ def forward( # If we are in cross attention (inference_current_sequence_len == inference_max_sequence_len == inference_key_memory.size(0)) # We only need to cache this once if inference_max_sequence_len and self.inference_current_sequence_len < inference_max_sequence_len: - logging.debug( - f"inference_current_sequence_len={self.inference_current_sequence_len} | key_layer.shape={key_layer.shape} | inference_key_memory={self.inference_key_memory.size()} | inference_value_memory={self.inference_value_memory.size()}" - ) # Adjust the range variables. start = self.inference_current_sequence_len self.inference_current_sequence_len += key_layer.size(0) @@ -901,7 +890,6 @@ def forward( key_layer.size(0), query_layer.size(3), ) - logging.debug(f"query_layer.shape={query_layer.size()}\tkey_layer.shape={key_layer.size()}") # ================================================== # Update attention mask for inference. 
[b, np, sq, sk] @@ -952,9 +940,6 @@ def forward( # context_layer [b, np, sq, hn] # ================================================== if not return_scores: - logging.debug( - f"not returning scores: attn_type={self.attention_type} | attn_fn={self.attn_fn} | return_scores={return_scores}" - ) context_layer = self.attn_fn( query_layer, key_layer, @@ -966,9 +951,6 @@ def forward( else: # SpeechLLM TTS modifications if return_scores or relative_position_bias is not None: - logging.debug( - f"torch a: return_scores: {return_scores}, relative_position_bias is not None: {relative_position_bias is not None}" - ) context_layer = self.torch_attention_with_prior( query_layer, key_layer, @@ -980,9 +962,6 @@ def forward( ) context_layer, attention_probs = context_layer else: - logging.debug( - f"attn_fn: {self.attn_fn}, return_scores: {return_scores}, relative_position_bias is not None: {relative_position_bias is not None}" - ) context_layer = self.attn_fn( query_layer, key_layer, @@ -1043,9 +1022,6 @@ def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, a attention_scores += attention_bias attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) - logging.debug(f"attention_type={self.attention_type}") - logging.debug(f"attention_scores.shape={attention_scores.shape}") - logging.debug(f"attention_mask.shape={attention_mask.shape}") # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. From ab8945eac9e8984475e7ea26829de15c1109b02a Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 7 Nov 2024 10:54:44 -0800 Subject: [PATCH 06/18] more ci fixes Signed-off-by: Jason --- .../megatron/base_prompt_learning_dataset.py | 8 ++--- .../nlp/modules/common/megatron/attention.py | 30 +++++++------------ .../megatron/token_level_encoder_decoder.py | 5 ---- .../modules/common/megatron/transformer.py | 2 ++ .../data/speechllm/t5_speechllm_dataset.py | 10 ++----- .../speechllm/t5_speechllm_tarred_dataset.py | 3 -- 6 files changed, 18 insertions(+), 40 deletions(-) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py index 826e139fe6ba..ea5f8c5a930b 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py @@ -71,8 +71,8 @@ def __init__( # Datasets are a list of file path strings to .json or .jsonl files elif isinstance(datasets[0], str): for path in datasets: - dataset = open(path, 'r', encoding='utf-8') - dataset_examples = self.load_data(dataset) + with open(path, 'r', encoding='utf-8') as dataset: + dataset_examples = self.load_data(dataset) self.examples.extend(dataset_examples) elif isinstance(datasets[0], omegaconf.ListConfig) or isinstance(datasets[0], list): # Dataset is a list of tuples with the first element being the probability of sampling from the dataset @@ -84,8 +84,8 @@ def __init__( for prob_and_path in datasets: prob = prob_and_path[0] path = prob_and_path[1] - dataset = open(path, 'r', encoding='utf-8') - dataset_examples = self.load_data(dataset) + with open(path, 'r', encoding='utf-8') as dataset: + dataset_examples = self.load_data(dataset) datasets_examples_list.append(dataset_examples) dataset_lengths.append(len(dataset_examples)) total_examples += len(dataset_examples) diff --git 
a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py index 5932936a2782..d2cccfbb218c 100644 --- a/nemo/collections/nlp/modules/common/megatron/attention.py +++ b/nemo/collections/nlp/modules/common/megatron/attention.py @@ -950,26 +950,16 @@ def forward( ) else: # SpeechLLM TTS modifications - if return_scores or relative_position_bias is not None: - context_layer = self.torch_attention_with_prior( - query_layer, - key_layer, - value_layer, - attention_mask, - relative_position_bias, - inference_mode, - return_scores=return_scores, - ) - context_layer, attention_probs = context_layer - else: - context_layer = self.attn_fn( - query_layer, - key_layer, - value_layer, - attention_mask, - relative_position_bias, - inference_mode, - ) + context_layer = self.torch_attention_with_prior( + query_layer, + key_layer, + value_layer, + attention_mask, + relative_position_bias, + inference_mode, + return_scores=return_scores, + ) + context_layer, attention_probs = context_layer if headscale_tensor is not None: context_layer = context_layer * headscale_tensor diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py index 4e73c2615cf3..fb4404495e89 100644 --- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py @@ -819,7 +819,6 @@ def forward( enc_seq_length = enc_input.size(0) # Only need to run encoder embedding and position ids if enc_input or enc_output is not provided. elif enc_input_ids is not None: - assert False, "This should not be reached for speech models" enc_seq_length = enc_input_ids.size(1) if self.pre_process and self.add_encoder: # We don't need position ids for RPE, because the embedding layer does not have position embeddings. @@ -844,19 +843,16 @@ def forward( else: enc_input = None else: - assert False, "This should not be reached for speech models" # This should only happen with PP > 1 for enc-dec prompt learning models enc_seq_length = enc_attn_mask.size(1) if self.add_encoder and self.encoder_relative_position_embedding is not None: - assert False, "Not implemented for speech models yet." encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding( query_seq_length=enc_seq_length, key_seq_length=enc_seq_length, ) if output_enc_hidden_only: - assert False, "Not implemented for speech models yet." # When pipeline parallel > 1 we need to make sure encoder exist (will be missing in decoder) # Speecht5 should not go here for inference if enc_output is None and self.enc_dec_model.encoder is not None: @@ -896,7 +892,6 @@ def forward( dec_input = None if self.add_decoder and self.decoder_relative_position_embedding is not None: - assert False, "This should not be reached." 
decoder_self_attention_relative_position_bias = self.decoder_relative_position_embedding( query_seq_length=dec_input_ids.size(1), key_seq_length=dec_input_ids.size(1) ) diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index 3d5d36d05cc5..89613a392dcf 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -583,6 +583,8 @@ def forward( enc_dec_attn_mask, encoder_output=encoder_output, rotary_pos_emb=cross_attention_pos_emb, + set_inference_key_value_memory=set_inference_key_value_memory, + inference_max_sequence_len=inference_max_sequence_len, checkpoint_core_attention=checkpoint_core_attention, ) else: diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py index cd8abe647990..9997aa598782 100644 --- a/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py @@ -866,7 +866,6 @@ def _get_speech_tokens(self, audio_filepath, dur=-1): audio, audio_length = self._load_audio(audio_filepath, dur) # Convert to codes - codec_codes, codec_codes_length = None, None # Codes codec_path = self.codec_folder / f"{rel_audio_path_as_text_id}.pt" if codec_path.exists(): @@ -880,8 +879,6 @@ def _get_speech_tokens(self, audio_filepath, dur=-1): codec_codes = self.get_codec(audio).long() torch.save(codec_codes, codec_path) - codec_codes_length = torch.tensor(codec_codes.shape[1]).long() - # Convert codes to codes corresponding to megatron embedding layer codec_codes[0] = (codec_codes[0] + self.speech_offset).long() @@ -1400,11 +1397,10 @@ def __getitem__(self, idx): # of the reference audio. # In case of Next token prediction, We want context[:T] to go in the encoder and context[T+1:] to be # predicted by the decoder. - start_token_index = 0 - end_token_index = -1 if ("Text to speech this" in question_in_manifest or "Phoneme TTS" in question_in_manifest) and ( doc["context_type"] == "SPEECH" ): + start_token_index = 0 total_context_len = context_tokens[0].size()[1] # Redo of this logic 11/29 @@ -1474,9 +1470,7 @@ def __getitem__(self, idx): else: raise ValueError("Invalid virtual prompt source specified") - input_ids = answer_text_ids - - input_ids, input_ids_len = self.list_to_tensor(input_ids, True) + input_ids, input_ids_len = self.list_to_tensor(answer_text_ids, True) is_speech = True if doc["answer_type"] != "TEXT" else False if is_speech: assert input_ids.dim() == 2 diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py index 940c5d2eaab6..2f653eb05d6a 100644 --- a/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py @@ -777,9 +777,6 @@ def _get_tokens(self, doc, field, field_data): else: field_tokens = self._get_text_tokens(_text) # list of ids elif doc[f"{field}_type"] == 'SPEECH': - dur = -1 - if f"{field}_duration" in doc: - dur = doc[f"{field}_duration"] field_tokens = self._get_speech_tokens(field) # list of ids if not isinstance(field_tokens, list): field_tokens = [field_tokens] From c7de50e54ae42e392f2c0c0afce1e30dc939c7f2 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Date: Thu, 7 Nov 2024 14:51:49 -0800 Subject: [PATCH 07/18] remove commented codes. 
Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> --- .../conf/megatron_t5_speechllm_inference_model.yaml | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml index bd65eb956bbc..42d59b91abc5 100644 --- a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml @@ -72,13 +72,6 @@ model: lm_vocab_size: 30000 frozen_model: - # micro_batch_size: null - # global_batch_size: null - # megatron_amp_O2: true - # seq_length: 512 - # max_position_embeddings: 512 - # precision: bf16 - # Above is overridden in code tensor_model_parallel_size: 1 pipeline_model_parallel_size: 1 pipeline_model_parallel_split_rank: 0 @@ -218,4 +211,4 @@ model: weight_decay: 0.01 betas: - 0.9 - - 0.98 \ No newline at end of file + - 0.98 From d1251472e8a9bfa772d38a9cabfeeb6a97750ea0 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Fri, 8 Nov 2024 14:56:34 -0800 Subject: [PATCH 08/18] removed unused codes. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- ...megatron_t5_speechllm_inference_model.yaml | 1 - ...n_t5_speechllm_inference_multiencoder.yaml | 10 +- .../megatron_t5_speechllm_multiencoder.yaml | 10 +- .../conf/megatron_t5_speechlm_model.yaml | 8 - .../tts/speechllm/megatron_t5_speechllm.py | 2 - .../megatron_t5_speechllm_inference.py | 2 - .../megatron/megatron_transformer_encoder.py | 8 - .../data/speechllm/t5_speechllm_dataset.py | 283 ------------------ .../speechllm/megatron_t5_speechllm_model.py | 15 +- 9 files changed, 3 insertions(+), 336 deletions(-) diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml index 42d59b91abc5..1858edf9e667 100644 --- a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml @@ -91,7 +91,6 @@ model: model: null vocab_file: null merge_file: null - # num_sentinel_tokens: 100 optim: name: null data: diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml index 761268ca6fa1..8ad967d20538 100644 --- a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml @@ -72,13 +72,6 @@ model: enc_output_to_layers: [[0,1,2],[3,4,5,6,7,8]] frozen_model: - # micro_batch_size: null - # global_batch_size: null - # megatron_amp_O2: true - # seq_length: 512 - # max_position_embeddings: 512 - # precision: bf16 - # Above is overridden in code tensor_model_parallel_size: 1 pipeline_model_parallel_size: 1 pipeline_model_parallel_split_rank: 0 @@ -98,7 +91,6 @@ model: model: null vocab_file: null merge_file: null - # num_sentinel_tokens: 100 optim: name: null data: @@ -223,4 +215,4 @@ model: weight_decay: 0.01 betas: - 0.9 - - 0.98 \ No newline at end of file + - 0.98 diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_multiencoder.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_multiencoder.yaml index c121c8f9a510..bf3f65ff9e00 100644 --- a/examples/tts/speechllm/conf/megatron_t5_speechllm_multiencoder.yaml +++ 
b/examples/tts/speechllm/conf/megatron_t5_speechllm_multiencoder.yaml @@ -74,13 +74,6 @@ model: enc_output_to_layers: [[0,1,2],[3,4,5,6,7,8]] frozen_model: - # micro_batch_size: null - # global_batch_size: null - # megatron_amp_O2: true - # seq_length: 512 - # max_position_embeddings: 512 - # precision: bf16 - # Above is overridden in code tensor_model_parallel_size: 1 pipeline_model_parallel_size: 1 pipeline_model_parallel_split_rank: 0 @@ -100,7 +93,6 @@ model: model: null vocab_file: null merge_file: null - # num_sentinel_tokens: 100 optim: name: null data: @@ -228,4 +220,4 @@ model: constant_steps: 0 min_lr: 1e-5 monitor: val_loss - reduce_on_plateau: false \ No newline at end of file + reduce_on_plateau: false diff --git a/examples/tts/speechllm/conf/megatron_t5_speechlm_model.yaml b/examples/tts/speechllm/conf/megatron_t5_speechlm_model.yaml index 5210254a2e7d..d69bfb979182 100644 --- a/examples/tts/speechllm/conf/megatron_t5_speechlm_model.yaml +++ b/examples/tts/speechllm/conf/megatron_t5_speechlm_model.yaml @@ -73,13 +73,6 @@ model: lm_vocab_size: 30000 frozen_model: - # micro_batch_size: null - # global_batch_size: null - # megatron_amp_O2: true - # seq_length: 512 - # max_position_embeddings: 512 - # precision: bf16 - # Above is overridden in code tensor_model_parallel_size: 1 pipeline_model_parallel_size: 1 pipeline_model_parallel_split_rank: 0 @@ -99,7 +92,6 @@ model: model: null vocab_file: null merge_file: null - # num_sentinel_tokens: 100 optim: name: null data: diff --git a/examples/tts/speechllm/megatron_t5_speechllm.py b/examples/tts/speechllm/megatron_t5_speechllm.py index 1b438d8c1fc4..755b72f3b322 100644 --- a/examples/tts/speechllm/megatron_t5_speechllm.py +++ b/examples/tts/speechllm/megatron_t5_speechllm.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch -import torch.multiprocessing as mp from omegaconf.omegaconf import OmegaConf, open_dict from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder diff --git a/examples/tts/speechllm/megatron_t5_speechllm_inference.py b/examples/tts/speechllm/megatron_t5_speechllm_inference.py index 27e5deb1f81a..65f0f79988af 100644 --- a/examples/tts/speechllm/megatron_t5_speechllm_inference.py +++ b/examples/tts/speechllm/megatron_t5_speechllm_inference.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import torch -import torch.multiprocessing as mp from omegaconf.omegaconf import OmegaConf, open_dict from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py index 1f1d962f2c4a..a9b80868558f 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py @@ -365,14 +365,6 @@ def set_input_tensor(self, input_tensor): for mi in range(len(self.model)): self.model[mi].set_input_tensor(input_tensor) - # def set_input_tensor(self, input_tensor): - # """ See megatron.model.transformer.set_input_tensor()""" - # import ipdb; ipdb.set_trace() - # assert isinstance(input_tensor, list) - # assert len(input_tensor) == len(self.model) - # for _input_tensor in input_tensor: - # self.model.set_input_tensor(_input_tensor) - def forward( self, enc_input, diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py index 9997aa598782..9252e92e028b 100644 --- a/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py @@ -449,7 +449,6 @@ def load_data(self, dataset): logging.debug(f"skipped for {doc['answer']} as it is in skip_datasets") skipped += 1 - # logging.info(f"After Process len(self.examples) {len(self.examples)} TTS = {tts} ASR = {asr}") logging.info(f'Skipped {skipped} sentences, sequence length too short or too long even after truncation') return examples @@ -524,10 +523,8 @@ def __getitem__(self, idx): # Get virtual tokens # `virtual_tokens` is "". virtual_tokens = self._insert_virtual_token_placeholders(input_example.split(' ')[0], virtual_token_splits) - # print("virtual_tokens", virtual_tokens) # a trick to align with the data format in t5 pretraining - # new virtual_tokens = self.tokenizer.text_to_ids(virtual_tokens) if self.add_sentinel_to_input: question_tokens = question_tokens + self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) @@ -1356,283 +1353,3 @@ def pad_batch_and_build_loss_mask(self, batch): } return data_dict - - -class GPTSpeechLMDataset(T5SpeechLMDataset): - def __init__(self, *args, **kwargs): - kwargs["transformer_type"] = "GPT" - super().__init__(*args, **kwargs) - - def __getitem__(self, idx): - doc = self.examples[idx] - taskname = doc["taskname"] - prompt_template = self.task_templates[taskname]["prompt_template"] - prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"] - total_virtual_tokens = self.task_templates[taskname]["total_virtual_tokens"] - virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"] - truncation_field = self.task_templates[taskname]['truncate_field'] - answer_field = self.task_templates[taskname]["answer_field"] - - input_example = prompt_template - - self._input_sanity_checks( - total_virtual_tokens=total_virtual_tokens, - virtual_token_splits=virtual_token_splits, - prompt_template=prompt_template, - prompt_template_fields=prompt_template_fields, - truncation_field=truncation_field, - answer_field=answer_field, - doc=doc, - ) - question_in_manifest = doc['question'] - - # Format the input example according to the template - # Get context, question and answer codes in a dict. 
- input_dict = self._insert_data_in_template(prompt_template_fields, doc, answer_field) - context_tokens = input_dict['context'] - question_tokens = input_dict['question'] - - # Logic to prune context - # In case of TTS task, the entire reference speech is not required, so we randomly select a portion - # of the reference audio. - # In case of Next token prediction, We want context[:T] to go in the encoder and context[T+1:] to be - # predicted by the decoder. - if ("Text to speech this" in question_in_manifest or "Phoneme TTS" in question_in_manifest) and ( - doc["context_type"] == "SPEECH" - ): - start_token_index = 0 - total_context_len = context_tokens[0].size()[1] - - # Redo of this logic 11/29 - # logging.debug(f"total_context_len: {total_context_len}") - context_3s = 3 * self.codebook_fps - if total_context_len > context_3s: - start_token_index = random.randint(0, total_context_len - context_3s) - # logging.debug(f"start_token_index: {start_token_index}") - end_token_index = start_token_index + min(context_3s, total_context_len) - # logging.debug(f"end_token_index: {end_token_index}") - context_tokens[0] = context_tokens[0][:, start_token_index:end_token_index] - # logging.debug(f"context_tokens: {context_tokens[0].shape}") - - # Get virtual tokens - virtual_tokens = self._insert_virtual_token_placeholders(input_example.split(' ')[0], virtual_token_splits) - - # a trick to align with the data format in t5 pretraining - virtual_tokens = self.tokenizer.text_to_ids(virtual_tokens) - if self.add_sentinel_to_input: - question_tokens = question_tokens + self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) - - # Add BOS/EOS to the input of encoder if desired, adds EOS by default - if self.ul2_prompt_token is not None: - ul2_prompt_token_id = self.tokenizer.text_to_ids(self.ul2_prompt_token) - assert len(ul2_prompt_token_id) == 1 - context_tokens = ul2_prompt_token_id + context_tokens - if self.add_bos: - context_tokens = [self.tokenizer.bos_id] + context_tokens - if self.add_eos: - # question_tokens = question_tokens + [self.tokenizer.eos_id] - question_tokens = [self.tokenizer.pad_id] + question_tokens + [self.tokenizer.pad_id] - - virtual_tokens, virtual_tokens_len = self.list_to_tensor(virtual_tokens) - context_tokens, context_tokens_len = self.list_to_tensor(context_tokens) - question_tokens, question_tokens_len = self.list_to_tensor(question_tokens) - - if doc["question_type"] == "TEXT" and doc["context_type"] != "TEXT": - question_tokens = pad_text_to_speech_dims( - question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 - ) - if doc["context_type"] == "TEXT" and doc["question_type"] != "TEXT": - context_tokens = pad_text_to_speech_dims( - context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 - ) - if doc["context_type"] == "TEXT" and doc["question_type"] == "TEXT": - context_tokens = pad_text_to_speech_dims( - context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 - ) - question_tokens = pad_text_to_speech_dims( - question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 - ) - - # get answer ids - if answer_field in doc.keys(): # training and validation - answer_ids = self._get_tokens(doc, answer_field, doc[answer_field]) - answer_text_ids = answer_ids - - if self.add_eos_to_decoder_output: - answer_text_ids += [self.tokenizer.eos_id] - else: - answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.END.value) - - if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: - taskname_id = 
self.tokenizer.text_to_ids(taskname) - elif self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT: - taskname_id = -1 - else: - raise ValueError("Invalid virtual prompt source specified") - - input_ids, input_ids_len = self.list_to_tensor(answer_text_ids, True) - is_speech = True if doc["answer_type"] != "TEXT" else False - if is_speech: - assert input_ids.dim() == 2 - if self.seq_pattern == "delay_parallel": - - num_codebooks = input_ids.shape[0] - dinput_ids_padded = torch.cat( - [ - torch.zeros_like(input_ids[:, 0:num_codebooks]), - input_ids, - torch.zeros_like(input_ids[:, 0:num_codebooks]), - ], - dim=1, - ) - dec_input_new = [] - for _c in range(self.num_speech_codebooks): - st = num_codebooks - _c - et_decoder_input = dinput_ids_padded.shape[1] - _c - 1 - dec_input_new.append(dinput_ids_padded[_c, st:et_decoder_input]) - input_ids = torch.stack(dec_input_new, dim=0) - input_ids_len = torch.tensor(input_ids.shape[1]).long() - - # logging.debug( - # f"Return from getitem: \ncontext_tokens:{context_tokens.shape}\ncontext_tokens_len:{context_tokens_len}\n" - # f"question_tokens:{question_tokens.shape}\nquestion_tokens_len:{question_tokens_len}\ninput_ids:{input_ids.shape}\ninput_ids_len{input_ids_len}" - # ) - return ( - context_tokens, - context_tokens_len, - question_tokens, - question_tokens_len, - input_ids, - input_ids_len, - ) - - def collate_fn(self, batch): - ( - _, - context_tokens_len, - _, - question_tokens_len, - _, - input_ids_len, - ) = zip(*batch) - - decoder_input_len = ( - torch.stack(context_tokens_len) + torch.stack(question_tokens_len) + torch.stack(input_ids_len) - ) - max_decoder_input_len = max(decoder_input_len).item() if decoder_input_len is not None else 0 - max_decoder_input_len_1 = max_decoder_input_len - 1 - - decoder_mask = get_mask_from_lengths(decoder_input_len - 1) - speech_mask = get_mask_from_lengths(decoder_input_len - 1) - context_question_mask = torch.ones(speech_mask.shape) - ( - decoder_input_list, - decoder_labels_list, - ) = ( - [], - [], - ) - cross_attention_prior = torch.zeros(len(batch), max_decoder_input_len_1, max_decoder_input_len_1) - start_of_question_offset = 5 # For "Text to Speech this" - Only used in attention prior computation - end_of_question_offset = 3 # "" - Only used in attention prior computation - for i, sample_tuple in enumerate(batch): - ( - context_tokens, - context_tokens_len, - question_tokens, - question_tokens_len, - input_ids, - input_ids_len, - ) = sample_tuple - - context_tokens_input = context_tokens.clone().contiguous().detach() - for l in range(1, context_tokens_input.shape[0]): - context_tokens_input[l] += self.speech_offset + 1024 * l # TODO: fix hardcode - input_ids_shifted = input_ids.clone().contiguous().detach() - for l in range(1, input_ids_shifted.shape[0]): - input_ids_shifted[l] += self.speech_offset + 1024 * l # TODO: fix hardcode - - complete_input = torch.cat([context_tokens_input, question_tokens, input_ids_shifted], dim=1) - complete_input_padded = general_padding( - complete_input, - decoder_input_len[i].item(), - max_decoder_input_len, - pad_value=self.tokenizer.pad_id, - ) - complete_output = torch.cat([context_tokens, question_tokens, input_ids], dim=1) - complete_output_padded = general_padding( - complete_output, - decoder_input_len[i].item(), - max_decoder_input_len, - pad_value=self.tokenizer.pad_id, - ) - decoder_labels = complete_output_padded[:, 1:].contiguous() - decoder_input = complete_input_padded[:, :-1].contiguous() - - decoder_input_list.append(decoder_input) - 
decoder_labels_list.append(decoder_labels) - - decoder_mask[i, : context_tokens_len + question_tokens_len - 1] = 0 # Mask out context and question - # TODO: jasoli, the speech_mask looks wrong. I shouldn't be masking out the context - speech_mask[i, context_tokens_len : context_tokens_len + question_tokens_len] = ( - 0 # Mask out context and question - ) - context_question_mask[i, : context_tokens_len + question_tokens_len] = 0 - - if self.spec_aug: - # Derive time width, sometimes based percentage of input length. - time_max_width = max(1, int(input_ids_len.item() * self.time_width)) - time_start_upper_bound = max(1, input_ids_len.item() - time_max_width) - time_start = context_tokens_len.item() + question_tokens_len.item() - time_start_upper_bound += time_start - - # Set time masking - for _ in range(self.time_masks): - start = self._rng.randint(time_start, time_start_upper_bound) - width = self._rng.randint(0, time_max_width) - speech_mask[i, start : start + width] = 0 - - if self.use_attention_prior: - cross_attention_question_prior = torch.from_numpy( - beta_binomial_prior_distribution( - question_tokens_len.item() - start_of_question_offset - end_of_question_offset, - input_ids_len.item() - 1, - scaling_factor=self.attention_prior_scaling_factor, - ) - ) - cross_attention_prior[ - i, - context_tokens_len - + question_tokens_len : context_tokens_len - + question_tokens_len - + input_ids_len - - 1, - context_tokens_len - + start_of_question_offset : context_tokens_len - + question_tokens_len - - end_of_question_offset, - ] = cross_attention_question_prior - # Using causal attention mask for whole input - batch_size = len(decoder_input_list) - attention_mask = torch.tril(torch.ones((batch_size, max_decoder_input_len_1, max_decoder_input_len_1))).view( - batch_size, 1, max_decoder_input_len_1, max_decoder_input_len_1 - ) - - # Convert attention mask from float to bool - attention_mask = attention_mask < 0.5 # Currently not used, not sure if correct either - - decoder_input = torch.stack(decoder_input_list) - decoder_input_p = decoder_input[:, 0, :] if decoder_input.dim() == 3 else decoder_input - position_ids = build_position_ids(decoder_input_p) - data_dict = { - "tokens": decoder_input, - "position_ids": position_ids, - "attention_mask": attention_mask, - "labels": torch.stack(decoder_labels_list), - "speech_mask": speech_mask, # For TTS, can just be loss_mask since answer will always be speech - "loss_mask": decoder_mask, # Mask out context and question and padding - "attention_prior": cross_attention_prior, - "context_question_mask": context_question_mask, - } - - return data_dict diff --git a/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py index 5bd6b8993525..02097f507e11 100644 --- a/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py +++ b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py @@ -168,7 +168,6 @@ class MegatronT5SpeechLMModel(MegatronBaseSpeechLM): """ def __init__(self, cfg: DictConfig, trainer: Trainer): - # torch.autograd.set_detect_anomaly(True) super().__init__(cfg, trainer) self.model_type = ModelType.encoder_and_decoder speech_codebook_size = cfg.data.get('speech_codebook_size', 1024) @@ -254,8 +253,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): gather_output=not self.frozen_model.enc_dec_model.parallel_output, init_method=init_method_normal(init_method_std), config=self.model_parallel_config, - # 
use_cpu_initialization=False, - # params_dtype=self.frozen_model.enc_dec_model.dtype, ) list_of_speech_heads.append(_speech_head) @@ -2133,9 +2130,6 @@ def predict_step( dec_input[:, :, t + 1] = dec_input_next_timestep * 1 else: dec_input[:, 0, t + 1] = output_tokens_curr_timestep.squeeze(1) - # # TF - # if t+1 < 10: - # dec_input[:, :, t + 1] = dec_input_raw[:, :, t+1] # end of for loop output_tokens_combined = torch.stack(output_token_list) # (T, B, 8) if speech else (T, B) @@ -2265,11 +2259,6 @@ def predict_step( phoneme_ver=1, phoneme_seq=None, ) - # ctc_loss = self.frozen_model.enc_dec_model.forward_sum_loss( - # attn_logprob=attention_probs_example[None,None,:,:], - # in_lens=torch.tensor([attention_probs_example.shape[1]]).to(device), - # out_lens=torch.tensor([attention_probs_example.shape[0]]).to(device) - # ) if global_step is not None: # During validation, step is simply global_step + i @@ -2278,7 +2267,6 @@ def predict_step( # During inference, step is the index of the sample step = batch_idx * test_dataloader_batch_size + i - # print("Ctc Loss: ", step, ctc_loss.item()) self.logger.experiment.add_image( "Inf Attention Map", alignment_image, @@ -2438,7 +2426,7 @@ def predict_step( else: context_wav = None - # raise NotImplementedError("During prediction, there was no context found.") + if context_wav is not None: self.logger.experiment.add_audio("Context Wav", context_wav, step, self.sample_rate) context_wav_fp = os.path.join(_exp_dir_path, f'context_wav_{wav_num}.wav') @@ -2569,7 +2557,6 @@ def predict_step( for i in range(0, len(greedy_transcripts) - 1, 2): assert all_audio_to_pred[i]["step"] == all_audio_to_pred[i + 1]["step"] - # step = batch_idx * self.test_dataloader().batch_size + all_audio_to_pred[i]["step"] step = batch_idx * test_dataloader_batch_size + all_audio_to_pred[i]["step"] question_text = question_texts[i // 2] From 3645a6aa5d8bc3086182f2826c7aa0745989a21d Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Fri, 8 Nov 2024 15:10:44 -0800 Subject: [PATCH 09/18] remove unused imports Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- nemo/collections/nlp/modules/common/megatron/attention.py | 1 - .../nlp/modules/common/megatron/megatron_decoders.py | 1 - .../nlp/modules/common/megatron/megatron_encoder_decoder.py | 1 - nemo/collections/nlp/modules/common/megatron/utils.py | 1 - .../tts/models/speechllm/megatron_t5_speechllm_model.py | 3 +-- nemo/collections/tts/parts/utils/helpers.py | 1 - 6 files changed, 1 insertion(+), 7 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py index d2cccfbb218c..d5784081f6f0 100644 --- a/nemo/collections/nlp/modules/common/megatron/attention.py +++ b/nemo/collections/nlp/modules/common/megatron/attention.py @@ -39,7 +39,6 @@ attention_mask_func, ) from nemo.core import adapter_mixins -from nemo.utils import logging try: from apex.transformer.enums import AttnMaskType, AttnType diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py b/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py index e70b26e5cb08..d2945a061584 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py @@ -13,7 +13,6 @@ # limitations under the License. 
"""Transformer based language model.""" -from ast import Mod from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType from nemo.collections.nlp.modules.common.megatron.megatron_transformer_decoder import MegatronTransformerDecoderModule from nemo.collections.nlp.modules.common.megatron.retrieval_transformer import ( diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py index 67753f08775e..744a6e18c8b1 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py @@ -13,7 +13,6 @@ # limitations under the License. """Transformer based language model.""" -from ast import Mod import torch diff --git a/nemo/collections/nlp/modules/common/megatron/utils.py b/nemo/collections/nlp/modules/common/megatron/utils.py index 1540170a8dc1..b0a6f755a9cc 100644 --- a/nemo/collections/nlp/modules/common/megatron/utils.py +++ b/nemo/collections/nlp/modules/common/megatron/utils.py @@ -18,7 +18,6 @@ from typing import Dict, Iterator, List, Optional, Tuple, Union import torch -import torch.nn as nn from torch import Tensor from nemo.utils import logging, logging_mode diff --git a/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py index 02097f507e11..786c8c0e6e75 100644 --- a/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py +++ b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py @@ -76,8 +76,7 @@ import time import librosa -import torchaudio -from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE +from torchaudio.pipelines import SQUIM_SUBJECTIVE from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector __all__ = ['MegatronT5SpeechLMModel'] diff --git a/nemo/collections/tts/parts/utils/helpers.py b/nemo/collections/tts/parts/utils/helpers.py index 1379fa169789..85d12a4261da 100644 --- a/nemo/collections/tts/parts/utils/helpers.py +++ b/nemo/collections/tts/parts/utils/helpers.py @@ -50,7 +50,6 @@ import numpy as np import seaborn as sns import torch -from einops import rearrange from numba import jit, prange from nemo.collections.tts.torch.tts_data_types import DATA_STR2DATA_CLASS, MAIN_DATA_TYPES, WithLens From d019909b6267d75eb8dfbf0dff0e18e81f5fa40f Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Fri, 8 Nov 2024 15:18:54 -0800 Subject: [PATCH 10/18] add copyright headers. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- examples/tts/speechllm/megatron_t5_speechllm.py | 2 +- .../speechllm/megatron_t5_speechllm_inference.py | 2 +- nemo/collections/tts/data/speechllm/__init__.py | 13 +++++++++++++ .../tts/data/speechllm/t5_speechllm_dataset.py | 2 +- .../data/speechllm/t5_speechllm_tarred_dataset.py | 2 +- nemo/collections/tts/models/speechllm/__init__.py | 13 +++++++++++++ .../megatron_base_speechllm_prompt_model.py | 2 +- .../models/speechllm/megatron_t5_speechllm_model.py | 3 ++- 8 files changed, 33 insertions(+), 6 deletions(-) diff --git a/examples/tts/speechllm/megatron_t5_speechllm.py b/examples/tts/speechllm/megatron_t5_speechllm.py index 755b72f3b322..c4ec1a77f944 100644 --- a/examples/tts/speechllm/megatron_t5_speechllm.py +++ b/examples/tts/speechllm/megatron_t5_speechllm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. 
All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/tts/speechllm/megatron_t5_speechllm_inference.py b/examples/tts/speechllm/megatron_t5_speechllm_inference.py index 65f0f79988af..48d46952a993 100644 --- a/examples/tts/speechllm/megatron_t5_speechllm_inference.py +++ b/examples/tts/speechllm/megatron_t5_speechllm_inference.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nemo/collections/tts/data/speechllm/__init__.py b/nemo/collections/tts/data/speechllm/__init__.py index e69de29bb2d1..9df65818d226 100644 --- a/nemo/collections/tts/data/speechllm/__init__.py +++ b/nemo/collections/tts/data/speechllm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py index 9252e92e028b..2b459e4345b9 100644 --- a/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py index 2f653eb05d6a..e33d7c7d6507 100644 --- a/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nemo/collections/tts/models/speechllm/__init__.py b/nemo/collections/tts/models/speechllm/__init__.py index e69de29bb2d1..9df65818d226 100644 --- a/nemo/collections/tts/models/speechllm/__init__.py +++ b/nemo/collections/tts/models/speechllm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py b/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py index aedc1b07d92f..f5c3ba720224 100644 --- a/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py +++ b/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py index 786c8c0e6e75..b452427d610e 100644 --- a/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py +++ b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import itertools import json import os From 982908584980e49cbe8bd51043adda19db641d26 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Fri, 8 Nov 2024 15:32:15 -0800 Subject: [PATCH 11/18] added TODO and detail error log info. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- nemo/collections/tts/g2p/models/zh_cn_pinyin.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nemo/collections/tts/g2p/models/zh_cn_pinyin.py b/nemo/collections/tts/g2p/models/zh_cn_pinyin.py index 82810f8736be..9368ae6749d5 100644 --- a/nemo/collections/tts/g2p/models/zh_cn_pinyin.py +++ b/nemo/collections/tts/g2p/models/zh_cn_pinyin.py @@ -202,9 +202,12 @@ def __call__(self, text: str) -> List[str]: tone_hyp = pinyin[-1] if tone_hyp in self.tone_dict: syllable = pinyin[:-1] + # TODO: skipping the syllable that does not exist in the dictionary will lead to deletion errors in the + # synthesized speech. Even though this case is uncommon, it should be fixed in future. if syllable not in self.phoneme_dict: err = True - logging.error(f"Syllable <{syllable}> does not exist in the dictionary.") + logging.error(f"Syllable <{syllable}> does not exist in the dictionary. 
You should expect symbol " + f"deletion risks!!") continue phoneme_seq += self.phoneme_dict[syllable] phoneme_seq.append(self.tone_dict[tone_hyp]) From 8062cf3b372028bb6a6b1decca4155315c817b02 Mon Sep 17 00:00:00 2001 From: XuesongYang Date: Fri, 8 Nov 2024 23:33:02 +0000 Subject: [PATCH 12/18] Apply isort and black reformatting Signed-off-by: XuesongYang --- nemo/collections/tts/g2p/models/zh_cn_pinyin.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/g2p/models/zh_cn_pinyin.py b/nemo/collections/tts/g2p/models/zh_cn_pinyin.py index 9368ae6749d5..2fe0ac3f6077 100644 --- a/nemo/collections/tts/g2p/models/zh_cn_pinyin.py +++ b/nemo/collections/tts/g2p/models/zh_cn_pinyin.py @@ -206,8 +206,10 @@ def __call__(self, text: str) -> List[str]: # synthesized speech. Even though this case is uncommon, it should be fixed in future. if syllable not in self.phoneme_dict: err = True - logging.error(f"Syllable <{syllable}> does not exist in the dictionary. You should expect symbol " - f"deletion risks!!") + logging.error( + f"Syllable <{syllable}> does not exist in the dictionary. You should expect symbol " + f"deletion risks!!" + ) continue phoneme_seq += self.phoneme_dict[syllable] phoneme_seq.append(self.tone_dict[tone_hyp]) From 6191fc8bf8c416dcc733b88f75872195512613bd Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Fri, 8 Nov 2024 15:40:55 -0800 Subject: [PATCH 13/18] fixed missing a corner case. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- nemo/collections/common/parts/preprocessing/collections.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index d4e96875be6b..915f406a3e88 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -421,15 +421,17 @@ def __init__( def _get_len(self, field_type, data, duration_data): if field_type == "SPEECH": - return duration_data * 76 + return duration_data * 76 # TODO: add explanation for the hardcoded value. elif field_type == "TEXT": if self.use_phoneme_tokenizer: # Approx len is number of characters return len(data) else: - return len(data.split(' ')) + 3 + return len(data.split(' ')) + 3 # TODO: add explanation for the hardcoded value. elif field_type == "TOKENS": return len(data) + 3 + else: + raise ValueError(f"Unknown field type {field_type}.") class ASRAudioText(AudioText): From f194f72c136e3e3b3d299536782ca5e8f7118e36 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Fri, 8 Nov 2024 16:27:47 -0800 Subject: [PATCH 14/18] removed unused codes.
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- .../megatron/token_level_encoder_decoder.py | 2 - .../speechllm/t5_speechllm_tarred_dataset.py | 247 +----------------- 2 files changed, 4 insertions(+), 245 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py index fb4404495e89..e0f0ca69c628 100644 --- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py @@ -1023,8 +1023,6 @@ def forward( # For flat seq_pattern we need all the logits token_logits = token_logits[:, :, :first_layer_vocabsize] speech_layers = self.num_speech_codebooks - 1 - last_layer_output = dec_output - last_layer_logits = token_logits # speech_logits_list will be used in loss calculation (parallel output) speech_logits_list = [] diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py index e33d7c7d6507..9b0a4f8d06c2 100644 --- a/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py @@ -38,7 +38,7 @@ from nemo.core.classes import IterableDataset from nemo.utils import logging -__all__ = ['T5SpeechLMTarredDataset', 'GPTSpeechLMTarredDataset'] +__all__ = ['T5SpeechLMTarredDataset'] @dataclass @@ -77,10 +77,6 @@ def pad_text_to_speech_dims(text_tensor, pad_id): return torch.cat((text_tensor.unsqueeze(0), empty_padding), dim=0) -# tokenizer_config = _get_default_text_tokenizer_conf() -# phoneme_tokenizer = instantiate(tokenizer_config).text_tokenizer - - class InstructionTuningManifestProcessor: """ Class that processes a manifest json file containing paths to audio files, transcripts, and durations (in seconds). 
@@ -474,9 +470,7 @@ def _build_sample(self, tup): taskname = "squad" prompt_template = self.task_templates[taskname]["prompt_template"] prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"] - total_virtual_tokens = self.task_templates[taskname]["total_virtual_tokens"] virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"] - truncation_field = self.task_templates[taskname]['truncate_field'] answer_field = self.task_templates[taskname]["answer_field"] input_example = prompt_template @@ -563,9 +557,7 @@ def _build_sample(self, tup): answer_text_ids = [self.tokenizer.pad_id] else: answer_text_ids = [self.tokenizer.bos_id] - # a trick to align with the data format in t5 pretraining - # if self.add_sentinel_to_input: - # answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) + answer_text_ids += answer_ids if self.add_eos_to_decoder_output: @@ -662,6 +654,8 @@ def _build_sample(self, tup): is_speech, cross_attention_prior, ) + else: + return None def _truncate_input_speech(self, context_tokens, question_tokens, virtual_tokens): total_len = self._get_len(context_tokens, question_tokens, virtual_tokens) @@ -990,236 +984,3 @@ def pad_batch_and_build_loss_mask(self, batch): } return data_dict - - -class GPTSpeechLMTarredDataset(T5SpeechLMTarredDataset): - """No support for cross attention here yet""" - - def _build_sample(self, tup): - audio_filename, self.encodec, self.ref_encodec, offset_id = tup - - file_id, _ = os.path.splitext(os.path.basename(audio_filename)) - manifest_idx = self.manifest_processor.collection.mapping[file_id][offset_id] - manifest_entry = self.manifest_processor.collection[manifest_idx] - doc = {} - doc['context'] = manifest_entry.context - doc['context_type'] = manifest_entry.context_type - doc['context_duration'] = manifest_entry.context_duration - doc['answer'] = manifest_entry.answer - doc['answer_type'] = manifest_entry.answer_type - doc['answer_duration'] = manifest_entry.answer_duration - doc['question'] = manifest_entry.question - doc['question_type'] = manifest_entry.question_type - - taskname = "squad" - prompt_template = self.task_templates[taskname]["prompt_template"] - prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"] - virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"] - answer_field = self.task_templates[taskname]["answer_field"] - - input_example = prompt_template - - # Format the input example according to the template - # Get context, question and answer codes in a dict. - input_dict = self._insert_data_in_template(input_example, prompt_template_fields, doc, answer_field) - context_tokens = input_dict['context'] - question_tokens = input_dict['question'] - - # Logic to prune context - # In case of TTS task, the entire reference speech is not required, so we randomly select a portion - # of the reference audio. - # In case of Next token prediction, We want context[:T] to go in the encoder and context[T+1:] to be - # predicted by the decoder. 
- start_token_index = 0 - end_token_index = -1 - - total_context_len = context_tokens[0].size()[1] - context_3s = 3 * 75 - if total_context_len > context_3s: - start_token_index = random.randint(0, total_context_len - context_3s) - # logging.debug(f"start_token_index: {start_token_index}") - end_token_index = start_token_index + min(context_3s, total_context_len) - # logging.debug(f"end_token_index: {end_token_index}") - context_tokens[0] = context_tokens[0][:, start_token_index:end_token_index] - - # Get virtual tokens - virtual_tokens = self._insert_virtual_token_placeholders(input_example.split(' ')[0], virtual_token_splits) - - # a trick to align with the data format in t5 pretraining - # new - virtual_tokens = self.tokenizer.text_to_ids(virtual_tokens) - if self.add_sentinel_to_input: - question_tokens = question_tokens + self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) - - # Add BOS/EOS to the input of encoder if desired, adds EOS by default - if self.ul2_prompt_token is not None: - ul2_prompt_token_id = self.tokenizer.text_to_ids(self.ul2_prompt_token) - assert len(ul2_prompt_token_id) == 1 - context_tokens = ul2_prompt_token_id + context_tokens - if self.add_bos: - context_tokens = [self.tokenizer.bos_id] + context_tokens - if self.add_eos: - question_tokens = [self.tokenizer.pad_id] + question_tokens + [self.tokenizer.pad_id] - - virtual_tokens, virtual_tokens_len = self.list_to_tensor(virtual_tokens) - context_tokens, context_tokens_len = self.list_to_tensor(context_tokens) - question_tokens, question_tokens_len = self.list_to_tensor(question_tokens) - - if doc["question_type"] != "SPEECH" and doc["context_type"] == "SPEECH": - question_tokens = pad_text_to_speech_dims(question_tokens, self.tokenizer.pad_id) - if doc["context_type"] != "SPEECH" and doc["question_type"] == "SPEECH": - context_tokens = pad_text_to_speech_dims(context_tokens, self.tokenizer.pad_id) - context_and_question_tokens = torch.cat([context_tokens, question_tokens], dim=1) - - # get answer ids - if answer_field in doc.keys(): # training and validation - answer_ids = self._get_tokens(doc, answer_field, doc[answer_field]) - answer_text_ids = answer_ids - - if self.add_eos_to_decoder_output: - answer_text_ids += [self.tokenizer.eos_id] - else: - answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.END.value) - - # Skip example if the final length doesn't fit length requirements even after truncation - input_ids = answer_text_ids - input_ids, input_ids_len = self.list_to_tensor(input_ids, True) - input_len = self._get_element_len(context_and_question_tokens) + self._get_element_len(answer_text_ids) - 1 - if input_len > self.max_seq_length: - # logging.debug(f"Overflow. input_len:{input_len}. self.max_seq_length:{self.max_seq_length}. overflow_len:{self.max_seq_length - input_len}.") - overflow_len = self.max_seq_length - input_len - # truncate context if context after truncation is at least 1s - # else truncate answer as final option - if context_tokens_len - overflow_len > 75: - # logging.debug(f"Cutting context. context_tokens:{context_tokens.shape}. context_tokens_len:{context_tokens_len}.") - context_tokens = context_tokens[:, : context_tokens_len - overflow_len] - context_tokens_len = context_tokens_len - overflow_len - # logging.debug(f"Cut context. context_tokens:{context_tokens.shape}. context_tokens_len:{context_tokens_len}.") - else: - # logging.debug(f"Cutting answer. input_ids:{input_ids.shape}. 
input_ids_len:{input_ids_len}.") - input_ids = input_ids[:, : input_ids_len - overflow_len] - input_ids_len = input_ids_len - overflow_len - # logging.debug(f"Cut answer. input_ids:{input_ids.shape}. input_ids_len:{input_ids_len}.") - - is_speech = True if doc["answer_type"] == "SPEECH" else False - if is_speech: - assert input_ids.dim() == 2 - if self.seq_pattern == "delay_parallel": - num_codebooks = input_ids.shape[0] - dec_input_padded = torch.cat( - [ - torch.zeros_like(input_ids[:, 0:num_codebooks]), - input_ids, - torch.zeros_like(input_ids[:, 0:num_codebooks]), - ], - dim=1, - ) - dec_input_new = [] - for _c in range(self.num_speech_codebooks): - st = num_codebooks - _c - et_decoder_input = dec_input_padded.shape[1] - _c - dec_input_new.append(dec_input_padded[_c, st:et_decoder_input]) - input_ids = torch.stack(dec_input_new, dim=0) - input_ids_len = torch.tensor(input_ids.shape[1]).long() - - return ( - context_tokens, - context_tokens_len, - question_tokens, - question_tokens_len, - input_ids, - input_ids_len, - ) - - def collate_fn(self, batch): - ( - _, - context_tokens_len, - _, - question_tokens_len, - _, - input_ids_len, - ) = zip(*batch) - - decoder_input_len = ( - torch.stack(context_tokens_len) + torch.stack(question_tokens_len) + torch.stack(input_ids_len) - ) - max_decoder_input_len = max(decoder_input_len).item() if decoder_input_len is not None else 0 - - decoder_mask = get_mask_from_lengths(decoder_input_len - 1) - speech_mask = get_mask_from_lengths(decoder_input_len - 1) - context_question_mask = torch.ones(speech_mask.shape) - ( - decoder_input_list, - decoder_labels_list, - ) = ( - [], - [], - ) - for i, sample_tuple in enumerate(batch): - ( - context_tokens, - context_tokens_len, - question_tokens, - question_tokens_len, - input_ids, - input_ids_len, - ) = sample_tuple - - context_tokens_input = context_tokens.clone().contiguous().detach() - for l in range(1, context_tokens_input.shape[0]): - context_tokens_input[l] += self.speech_offset + 1024 * l # TODO: fix hardcode - input_ids_shifted = input_ids.clone().contiguous().detach() - for l in range(1, input_ids_shifted.shape[0]): - input_ids_shifted[l] += self.speech_offset + 1024 * l # TODO: fix hardcode - - complete_input = torch.cat([context_tokens_input, question_tokens, input_ids_shifted], dim=1) - complete_input_padded = general_padding( - complete_input, - decoder_input_len[i].item(), - max_decoder_input_len, - pad_value=self.tokenizer.pad_id, - ) - complete_output = torch.cat([context_tokens, question_tokens, input_ids], dim=1) - complete_output_padded = general_padding( - complete_output, - decoder_input_len[i].item(), - max_decoder_input_len, - pad_value=self.tokenizer.pad_id, - ) - decoder_labels = complete_output_padded[:, 1:].contiguous() - decoder_input = complete_input_padded[:, :-1].contiguous() - - decoder_input_list.append(decoder_input) - decoder_labels_list.append(decoder_labels) - - decoder_mask[i, : context_tokens_len + question_tokens_len - 1] = 0 # Mask out context and question - speech_mask[i, context_tokens_len : context_tokens_len + question_tokens_len] = ( - 0 # Mask out context and question - ) - context_question_mask[i, : context_tokens_len + question_tokens_len] = 0 - - # Using causal attention mask for whole input - batch_size = len(decoder_input_list) - attention_mask = torch.tril( - torch.ones((batch_size, max_decoder_input_len - 1, max_decoder_input_len - 1)) - ).view(batch_size, 1, max_decoder_input_len - 1, max_decoder_input_len - 1) - - # Convert attention mask from float to 
bool - attention_mask = attention_mask < 0.5 - - decoder_input = torch.stack(decoder_input_list) - decoder_input_p = decoder_input[:, 0, :] if decoder_input.dim() == 3 else decoder_input - position_ids = build_position_ids(decoder_input_p) - data_dict = { - "tokens": decoder_input, - "position_ids": position_ids, - "attention_mask": attention_mask, - "labels": torch.stack(decoder_labels_list), - "speech_mask": speech_mask, # For TTS, can just be loss_mask since answer will always be speech - "loss_mask": decoder_mask, # Mask out context and question and padding - "attention_prior": None, - "context_question_mask": context_question_mask, - } - - return data_dict From 14a76ca9901efe6ce6af54de0aed6d1ae1a74b54 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Sat, 9 Nov 2024 23:13:45 -0800 Subject: [PATCH 15/18] added classes to __all__ Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- .../nlp/modules/common/megatron/token_level_encoder_decoder.py | 2 +- nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py index e0f0ca69c628..185c24946cde 100644 --- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py @@ -68,7 +68,7 @@ HAVE_MEGATRON_CORE = False -__all__ = ["MegatronTokenLevelHead", "MegatronTokenLevelEncoderDecoderModule"] +__all__ = ["MegatronTokenLevelHead", "MegatronTokenLevelEncoderDecoderModule", "MegatronTokenLevelEncoderDecoderSpeechLLMModule"] class MegatronTokenLevelHead(MegatronModule): diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py index 2b459e4345b9..32f0a14f5e65 100644 --- a/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py @@ -42,7 +42,7 @@ ) from nemo.utils import logging -__all__ = ['T5SpeechLMDataset'] +__all__ = ['T5SpeechLMDataset', "Lang"] def get_full_list_puncts(): From e62299065e31dea6558af23271e09112e23368c9 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Sat, 9 Nov 2024 23:19:31 -0800 Subject: [PATCH 16/18] removed unused lines Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- .../tts/models/speechllm/megatron_t5_speechllm_model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py index b452427d610e..b3b5d857155d 100644 --- a/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py +++ b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py @@ -363,7 +363,6 @@ def forward( Special forward method for p-tuning/prompt-tuning pretrained T5 style models. 
""" - multi_encoder = False if isinstance(context_and_question_tokens, list): multi_encoder = True assert isinstance(enc_mask, list) @@ -414,9 +413,7 @@ def forward( _encoder_input = encoder_input_list if not multi_encoder: - context_and_question_tokens = context_and_question_tokens[0] enc_mask = enc_mask[0] - position_ids = position_ids[0] cross_attention_prior = cross_attention_prior[0] _encoder_input = encoder_input_list[0] if encoder_input_list is not None else None From ca138504847eee17030c8848ea8a60506ddd24bb Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Sun, 10 Nov 2024 00:24:19 -0800 Subject: [PATCH 17/18] fixed to return either self-attention scores or cross-attention scores in ParallelTransformerLayer_ class. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- .../common/megatron/token_level_encoder_decoder.py | 2 +- .../nlp/modules/common/megatron/transformer.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py index 185c24946cde..10adde4c66bb 100644 --- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py @@ -854,7 +854,7 @@ def forward( if output_enc_hidden_only: # When pipeline parallel > 1 we need to make sure encoder exist (will be missing in decoder) - # Speecht5 should not go here for inference + # SpeechT5 should not go here for inference if enc_output is None and self.enc_dec_model.encoder is not None: enc_output = self.enc_dec_model.encode( enc_input=enc_input, diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index 89613a392dcf..c5108d8e3801 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -494,6 +494,12 @@ def forward( self_attention_pos_emb = None cross_attention_pos_emb = None + if return_crossattention_scores and return_selfattention_scores: + raise NotImplementedError( + "We can only return 1 of cross attention scores or self attention scores. Not both yet." + ) + attention_probs = None + if self.layer_type != LayerType.retrieval_decoder_after_self_attn: # hidden_states: [b, s, h] @@ -535,7 +541,7 @@ def forward( attention_bias = None # jit scripting for a nn.module (with dropout) is not - # trigerring the fusion kernel. For now, we use two + # triggering the fusion kernel. For now, we use two # different nn.functional routines to account for varying # dropout semantics during training and inference phases. @@ -562,6 +568,9 @@ def forward( elif self.transformer_block_type in ['pre_ln', 'normformer']: # Layer norm post the self attention. 
normalization_output = self.post_attention_layernorm(layernorm_input) + else: + normalization_output = None + logging.warning(f"This is a rare case since `normalization_output=None`") else: layernorm_input, normalization_output = hidden_states @@ -646,7 +655,7 @@ def forward( if get_key_value: output = [output, presents] - if return_crossattention_scores or return_selfattention_scores: + if attention_probs is not None: output = [output, attention_probs] return output From 22709038820bb63d036710c55beac2309ba2713e Mon Sep 17 00:00:00 2001 From: XuesongYang Date: Sun, 10 Nov 2024 08:25:46 +0000 Subject: [PATCH 18/18] Apply isort and black reformatting Signed-off-by: XuesongYang --- .../modules/common/megatron/token_level_encoder_decoder.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py index 10adde4c66bb..e68113949aa7 100644 --- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py @@ -68,7 +68,11 @@ HAVE_MEGATRON_CORE = False -__all__ = ["MegatronTokenLevelHead", "MegatronTokenLevelEncoderDecoderModule", "MegatronTokenLevelEncoderDecoderSpeechLLMModule"] +__all__ = [ + "MegatronTokenLevelHead", + "MegatronTokenLevelEncoderDecoderModule", + "MegatronTokenLevelEncoderDecoderSpeechLLMModule", +] class MegatronTokenLevelHead(MegatronModule):
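
A standalone sketch of the length heuristic touched by PATCH 13/18 above, for readers who want to see the branch structure outside the diff context. This is illustrative only: the function name estimate_len and the flattened signature are not the NeMo API (the real logic lives in AudioText._get_len and reads self.use_phoneme_tokenizer), and the constants 76 and +3 are copied from the diff, where they are still marked TODO rather than explained.

# Illustrative sketch, not the NeMo implementation: mirrors the branches of
# AudioText._get_len after PATCH 13/18, including the new ValueError fallback.
def estimate_len(field_type: str, data, duration_data, use_phoneme_tokenizer=False):
    if field_type == "SPEECH":
        # Hardcoded scaling factor taken from the source (flagged TODO there); it maps
        # an audio duration in seconds to an approximate token count.
        return duration_data * 76
    elif field_type == "TEXT":
        if use_phoneme_tokenizer:
            # Approximate length is the number of characters.
            return len(data)
        else:
            # Word count plus a small hardcoded margin (also flagged TODO in the source).
            return len(data.split(' ')) + 3
    elif field_type == "TOKENS":
        return len(data) + 3
    else:
        raise ValueError(f"Unknown field type {field_type}.")

# e.g. estimate_len("TEXT", "hello world how are you", 0.0) -> 8  (5 words + 3)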