diff --git a/.github/workflows/mcore-tag-bump-bot.yml b/.github/workflows/mcore-tag-bump-bot.yml index 13f4059a3a6b..1b0712924101 100644 --- a/.github/workflows/mcore-tag-bump-bot.yml +++ b/.github/workflows/mcore-tag-bump-bot.yml @@ -6,54 +6,15 @@ on: - cron: 0 0 * * * jobs: - main: - runs-on: ubuntu-latest - environment: main - steps: - - name: Checkout NVIDIA/Megatron-LM - uses: actions/checkout@v4 - with: - repository: NVIDIA/Megatron-LM - ref: main - path: ${{ github.run_id }} - - - name: Get latest mcore commit - id: ref - run: | - cd ${{ github.run_id }} - sha=$(git rev-parse origin/main) - echo "sha=${sha}" >> "$GITHUB_OUTPUT" - echo "short_sha=${sha:0:7}" >> "$GITHUB_OUTPUT" - echo "date=$(date +%F)" >> "$GITHUB_OUTPUT" - - - name: Checkout ${{ github.repository }} - uses: actions/checkout@v4 - with: - path: ${{ github.run_id }} - token: ${{ secrets.PAT }} - - - name: Bump MCORE_TAG - run: | - cd ${{ github.run_id }} - sed -i 's/^ARG MCORE_TAG=.*$/ARG MCORE_TAG=${{ steps.ref.outputs.sha }}/' Dockerfile.ci - - - name: Create Bump PR - uses: peter-evans/create-pull-request@v6 - id: create-pull-request - with: - path: ${{ github.run_id }} - branch: bump-ci-container-${{ steps.ref.outputs.date }} - base: main - title: 'Bump `Dockerfile.ci` (${{ steps.ref.outputs.date }})' - token: ${{ secrets.PAT }} - body: | - 🚀 PR to Bump `Dockerfile.ci`. - - 📝 Please remember the following to-do's before merge: - - [ ] Verify the presubmit CI - - 🙏 Please merge this PR only if the CI workflow completed successfully. - commit-message: "[🤠]: Howdy folks, let's bump `Dockerfile.ci` to ${{ steps.ref.outputs.short_sha }} !" - signoff: true - reviewers: 'pablo-garay' - labels: 'Run CICD' + mcore: + uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_bump_dockerfile.yml@v0.11.0 + with: + source-repository: NVIDIA/Megatron-LM + source-ref: main + build-arg: MCORE_TAG + dockerfile: Dockerfile.ci + base-branch: main + cicd-label: Run CICD + pr-reviewers: 'pablo-garay' + secrets: + PAT: ${{ secrets.PAT }} \ No newline at end of file diff --git a/.github/workflows/secrets-detector.yml b/.github/workflows/secrets-detector.yml index cf8ccc189ab6..d81b5638e31f 100644 --- a/.github/workflows/secrets-detector.yml +++ b/.github/workflows/secrets-detector.yml @@ -25,13 +25,24 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 with: - path: ${{ github.run_id }} + # setup repository and ref for PRs, see + # https://github.com/EndBug/add-and-commit?tab=readme-ov-file#working-with-prs + repository: ${{ github.event.pull_request.head.repo.full_name }} + ref: ${{ github.event.pull_request.head.ref }} + # custom token is required to trigger actions after reformatting + pushing fetch-depth: 0 + token: ${{ secrets.NEMO_REFORMAT_TOKEN }} - name: Install secrets detector run: pip install detect-secrets - name: Run on change-set run: | - cd ${{ github.run_id }} - git diff --name-only --diff-filter=d --merge-base origin/main -z | xargs -0 detect-secrets-hook --baseline .secrets.baseline \ No newline at end of file + git diff --name-only --diff-filter=d --merge-base origin/main -z | xargs -0 detect-secrets-hook --baseline .secrets.baseline + + - uses: EndBug/add-and-commit@v9 + # Commit changes. Nothing is committed if no changes. + if: always() + with: + message: Update baseline + commit: --signoff diff --git a/Dockerfile.ci b/Dockerfile.ci index dddae1a8ec9f..ee04c79cd2ba 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -54,7 +54,7 @@ RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMO_RUN_T # Install NeMo requirements ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea ARG MODELOPT_VERSION=0.19.0 -ARG MCORE_TAG=bc8c4f356240ea4ccadce426251171e6e430c9d3 +ARG MCORE_TAG=47ff44e5b98061bf81295ce7df899ee62529d5e3 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c RUN \ diff --git a/docs/source/asr/intro.rst b/docs/source/asr/intro.rst index aae372765a8a..ade767e541a0 100644 --- a/docs/source/asr/intro.rst +++ b/docs/source/asr/intro.rst @@ -16,10 +16,39 @@ After :ref:`installing NeMo`, you can transcribe an audio file as asr_model = nemo_asr.models.ASRModel.from_pretrained("stt_en_fastconformer_transducer_large") transcript = asr_model.transcribe(["path/to/audio_file.wav"]) -Obtain word/segment timestamps -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Obtain timestamps +^^^^^^^^^^^^^^^^^ -You can also obtain timestamps for each word or segment in the transcription as follows: +Obtaining char(token), word or segment timestamps is also possible with NeMo ASR Models. + +Currently, timestamps are available for Parakeet Models with all types of decoders (CTC/RNNT/TDT). Support for AED models would be added soon. + +There are two ways to obtain timestamps: +1. By using the `timestamps=True` flag in the `transcribe` method. +2. For more control over the timestamps, you can update the decoding config to mention type of timestamps (char, word, segment) and also specify the segment seperators or word seperator for segment and word level timestamps. + +With the `timestamps=True` flag, you can obtain timestamps for each character in the transcription as follows: + +.. code-block:: python + + # import nemo_asr and instantiate asr_model as above + import nemo.collections.asr as nemo_asr + asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt_ctc-110m") + + # specify flag `timestamps=True` + hypotheses = asr_model.transcribe(["path/to/audio_file.wav"], timestamps=True) + + # by default, timestamps are enabled for char, word and segment level + word_timestamps = hypotheses[0][0].timestep['word'] # word level timestamps for first sample + segment_timestamps = hypotheses[0][0].timestep['segment'] # segment level timestamps + char_timestamps = hypotheses[0][0].timestep['char'] # char level timestamps + + for stamp in segment_timestamps: + print(f"{stamp['start']}s - {stamp['end']}s : {stamp['segment']}") + + # segment level timestamps (if model supports Punctuation and Capitalization, segment level timestamps are displayed based on punctuation otherwise complete transcription is considered as a single segment) + +For more control over the timestamps, you can update the decoding config to mention type of timestamps (char, word, segment) and also specify the segment seperators or word seperator for segment and word level timestamps as follows: .. code-block:: python diff --git a/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py b/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py index 0417522885b9..8188bcced14d 100644 --- a/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py +++ b/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py @@ -13,11 +13,13 @@ # limitations under the License. """ -This script chunks long audios into non-overlapping segments of `chunk_len_in_secs` seconds and performs inference on each +This script chunks long audios into non-overlapping segments of `chunk_len_in_secs` +seconds and performs inference on each segment individually. The results are then concatenated to form the final output. Below is an example of how to run this script with the Canary-1b model. -It's recommended to use manifest input, otherwise the model will perform English ASR with punctuations and capitalizations. +It's recommended to use manifest input, otherwise the model will perform English ASR +with punctuations and capitalizations. An example manifest line: { "audio_filepath": "/path/to/audio.wav", # path to the audio file @@ -41,11 +43,10 @@ """ -import contextlib import copy import glob import os -from dataclasses import dataclass, is_dataclass +from dataclasses import dataclass from typing import Optional import pytorch_lightning as pl @@ -67,6 +68,10 @@ @dataclass class TranscriptionConfig: + """ + Transcription config + """ + # Required configs model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model @@ -116,6 +121,10 @@ class TranscriptionConfig: @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> TranscriptionConfig: + """ + Transcribes the input audio and can be used to infer long audio files by chunking + them into smaller segments. + """ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') torch.set_grad_enabled(False) @@ -160,7 +169,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: if model_cfg.preprocessor.normalize != "per_feature": logging.error( - "Only EncDecMultiTaskModel models trained with per_feature normalization are supported currently" + "Only EncDecMultiTaskModel models trained with per_feature normalization are supported \ + currently" ) # Disable config overwriting @@ -206,7 +216,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: ) output_filename, pred_text_attr_name = write_transcription( - hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False + hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, timestamps=False ) logging.info(f"Finished writing predictions to {output_filename}!") diff --git a/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py b/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py index 77b97e0ab516..87370d278f98 100644 --- a/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py +++ b/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py @@ -35,12 +35,11 @@ You can use `DEBUG=1 python speech_to_text_buffered_infer_ctc.py ...` to print out the predictions of the model, and ground-truth text if presents in manifest. """ -import contextlib import copy import glob import math import os -from dataclasses import dataclass, is_dataclass +from dataclasses import dataclass from typing import Optional import pytorch_lightning as pl @@ -65,6 +64,10 @@ @dataclass class TranscriptionConfig: + """ + Transcription Configuration for buffered inference. + """ + # Required configs model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model @@ -114,6 +117,10 @@ class TranscriptionConfig: @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> TranscriptionConfig: + """ + Transcribes the input audio and can be used to infer long audio files by chunking + them into smaller segments. + """ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') torch.set_grad_enabled(False) @@ -221,7 +228,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: filepaths, ) output_filename, pred_text_attr_name = write_transcription( - hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False + hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, timestamps=False ) logging.info(f"Finished writing predictions to {output_filename}!") diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py index 501ca525c1ed..e6e84cdfa6c4 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py @@ -61,7 +61,7 @@ import glob import math import os -from dataclasses import dataclass, is_dataclass +from dataclasses import dataclass from typing import Optional import pytorch_lightning as pl @@ -87,6 +87,10 @@ @dataclass class TranscriptionConfig: + """ + Transcription Configuration for buffered inference. + """ + # Required configs model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model @@ -143,6 +147,10 @@ class TranscriptionConfig: @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> TranscriptionConfig: + """ + Transcribes the input audio and can be used to infer long audio files by chunking + them into smaller segments. + """ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') torch.set_grad_enabled(False) @@ -274,7 +282,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: ) output_filename, pred_text_attr_name = write_transcription( - hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False + hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, timestamps=False ) logging.info(f"Finished writing predictions to {output_filename}!") diff --git a/examples/asr/speech_translation/translate_speech.py b/examples/asr/speech_translation/translate_speech.py index 47717f562774..53599e1b3511 100644 --- a/examples/asr/speech_translation/translate_speech.py +++ b/examples/asr/speech_translation/translate_speech.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import json import os from dataclasses import dataclass, is_dataclass @@ -65,13 +64,19 @@ @dataclass class ModelChangeConfig: + """ + Sub-config for changes specific to the Conformer Encoder + """ - # Sub-config for changes specific to the Conformer Encoder conformer: ConformerChangeConfig = ConformerChangeConfig() @dataclass class TranslationConfig: + """ + Translation Configuration for audio to text translation. + """ + # Required configs model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model @@ -106,6 +111,9 @@ class TranslationConfig: @hydra_runner(config_name="TranslationConfig", schema=TranslationConfig) def main(cfg: TranslationConfig) -> Union[TranslationConfig, List[str]]: + """ + Main function to translate audio to text using a pretrained/finetuned model. + """ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') for key in cfg: diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index e529c988779a..a543fcf5e252 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import json import os import time @@ -48,14 +47,9 @@ model_path: path to .nemo ASR checkpoint pretrained_name: name of pretrained ASR model (from NGC registry) audio_dir: path to directory with audio files - dataset_manifest: path to dataset JSON manifest file (in NeMo format) - - compute_timestamps: Bool to request greedy time stamp information (if the model supports it) + dataset_manifest: path to dataset JSON manifest file (in NeMo formats compute_langs: Bool to request language ID information (if the model supports it) - - (Optionally: You can limit the type of timestamp computations using below overrides) - ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word]) - rnnt_decoding.rnnt_timestamp_type="all" # (default all, can be [all, char, word]) + timestamps: Bool to request greedy time stamp information (if the model supports it) by default None (Optionally: You can limit the type of timestamp computations using below overrides) ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word, segment]) @@ -98,7 +92,7 @@ clean_groundtruth_text=True \ langid='en' \ batch_size=32 \ - compute_timestamps=False \ + timestamps=False \ compute_langs=False \ cuda=0 \ amp=True \ @@ -109,13 +103,19 @@ @dataclass class ModelChangeConfig: + """ + Sub-config for changes specific to the Conformer Encoder + """ - # Sub-config for changes specific to the Conformer Encoder conformer: ConformerChangeConfig = field(default_factory=ConformerChangeConfig) @dataclass class TranscriptionConfig: + """ + Transcription Configuration for audio to text transcription. + """ + # Required configs model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model @@ -136,10 +136,11 @@ class TranscriptionConfig: pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one. random_seed: Optional[int] = None # seed number going to be used in seed_everything() - # Set to True to output greedy timestamp information (only supported models) - compute_timestamps: bool = False - # set to True if need to return full alignment information - preserve_alignment: bool = False + # Set to True to output greedy timestamp information (only supported models) and returns full alignment hypotheses + timestamps: Optional[bool] = None + + # Set to True to return hypotheses instead of text from the transcribe function + return_hypotheses: bool = False # Set to True to output language ID information compute_langs: bool = False @@ -171,7 +172,8 @@ class TranscriptionConfig: # Implicit single-turn assuming default role='user' (works with Canary-1B) # +prompt.source_lang=en +prompt.target_lang=es +prompt.task=asr +prompt.pnc=yes # Explicit single-turn prompt: - # +prompt.role=user +prompt.slots.source_lang=en +prompt.slots.target_lang=es +prompt.slots.task=s2t_translation +prompt.slots.pnc=yes + # +prompt.role=user +prompt.slots.source_lang=en +prompt.slots.target_lang=es + # +prompt.slots.task=s2t_translation +prompt.slots.pnc=yes # Explicit multi-turn prompt: # +prompt.turns='[{role:user,slots:{source_lang:en,target_lang:es,task:asr,pnc:yes}}]' prompt: dict = field(default_factory=dict) @@ -194,9 +196,6 @@ class TranscriptionConfig: # if True, will also skip writing anything to the output file return_transcriptions: bool = False - # Set to False to return text instead of hypotheses from the transcribe function, so as to save memory - return_hypotheses: bool = True - # key for groundtruth text in manifest gt_text_attr_name: str = "text" gt_lang_attr_name: str = "lang" @@ -208,6 +207,9 @@ class TranscriptionConfig: @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]: + """ + Transcribes the input audio and can be used to infer with Encoder-Decoder models. + """ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') for key in cfg: @@ -272,10 +274,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis asr_model.to(getattr(torch, cfg.compute_dtype)) # we will adjust this flag if the model does not support it - compute_timestamps = cfg.compute_timestamps compute_langs = cfg.compute_langs - # has to be True if timestamps are required - preserve_alignment = True if cfg.compute_timestamps else cfg.preserve_alignment # Check whether model and decoder type match if isinstance(asr_model, EncDecCTCModel): @@ -295,7 +294,6 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis if hasattr(asr_model, 'change_decoding_strategy') and hasattr(asr_model, 'decoding'): if isinstance(asr_model.decoding, MultiTaskDecoding): cfg.multitask_decoding.compute_langs = cfg.compute_langs - cfg.multitask_decoding.preserve_alignments = cfg.preserve_alignment if cfg.extract_nbest: cfg.multitask_decoding.beam.return_best_hypothesis = False cfg.return_hypotheses = True @@ -309,9 +307,6 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis if cfg.extract_nbest: decoding_cfg.beam.return_best_hypothesis = False cfg.return_hypotheses = True - decoding_cfg.compute_timestamps = cfg.compute_timestamps # both ctc and rnnt support it - if 'preserve_alignments' in decoding_cfg: - decoding_cfg.preserve_alignments = preserve_alignment if 'compute_langs' in decoding_cfg: decoding_cfg.compute_langs = cfg.compute_langs if hasattr(asr_model, 'cur_decoder'): @@ -325,16 +320,12 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis cfg.rnnt_decoding.beam.return_best_hypothesis = False cfg.return_hypotheses = True cfg.rnnt_decoding.fused_batch_size = -1 - cfg.rnnt_decoding.compute_timestamps = cfg.compute_timestamps cfg.rnnt_decoding.compute_langs = cfg.compute_langs - if 'preserve_alignments' in cfg.rnnt_decoding: - cfg.rnnt_decoding.preserve_alignments = preserve_alignment asr_model.change_decoding_strategy(cfg.rnnt_decoding) else: if cfg.compute_langs: raise ValueError("CTC models do not support `compute_langs` at the moment.") - cfg.ctc_decoding.compute_timestamps = cfg.compute_timestamps if cfg.extract_nbest: cfg.ctc_decoding.beam.return_best_hypothesis = False cfg.return_hypotheses = True @@ -379,7 +370,8 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis item = json.loads(line) if "duration" not in item: raise ValueError( - f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field." + f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} \ + lacks a 'duration' field." ) total_duration += item["duration"] @@ -396,6 +388,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis override_cfg.augmentor = augmentor override_cfg.text_field = cfg.gt_text_attr_name override_cfg.lang_field = cfg.gt_lang_attr_name + override_cfg.timestamps = cfg.timestamps if hasattr(override_cfg, "prompt"): override_cfg.prompt = parse_multitask_prompt(OmegaConf.to_container(cfg.prompt)) @@ -433,7 +426,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis model_name, filepaths=filepaths, compute_langs=compute_langs, - compute_timestamps=compute_timestamps, + timestamps=cfg.timestamps, ) logging.info(f"Finished writing predictions to {output_filename}!") diff --git a/examples/llm/sft/hf.py b/examples/llm/sft/hf.py index 39efe87de368..7d4cde7866a2 100644 --- a/examples/llm/sft/hf.py +++ b/examples/llm/sft/hf.py @@ -22,7 +22,7 @@ class SquadDataModuleWithPthDataloader(llm.SquadDataModule): - def _create_dataloader(self, dataset, **kwargs) -> DataLoader: + def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader: return DataLoader( dataset, num_workers=self.num_workers, diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index 268438c2e09d..f18fe02d2ed8 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -40,7 +40,6 @@ from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier -from nemo.collections.asr.parts.utils import manifest_utils from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.collections.common import tokenizers from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config @@ -68,6 +67,9 @@ def lens_to_mask(lens, max_length): + """ + Create a mask from a tensor of lengths. + """ batch_size = lens.shape[0] mask = torch.arange(max_length).repeat(batch_size, 1).to(lens.device) < lens[:, None] return mask @@ -222,7 +224,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.val_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True) - # TODO: PytorchMetrics lets you join two metrics together to save compute. But need to make wer and bleu have same outputs first + # TODO: PytorchMetrics lets you join two metrics together to save compute. + # But need to make wer and bleu have same outputs first self.wer = WER(self.decoding, log_prediction=self.cfg.get("log_prediction")) self.bleu = BLEU( self.decoding, tokenize=self.cfg.get('bleu_tokenizer', "13a"), log_prediction=False @@ -270,13 +273,15 @@ def change_vocabulary( prompt_format: Optional[str] = None, ): """ - Changes vocabulary used during AED decoding process. Use this method when fine-tuning on from pre-trained model. - This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would - use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need - model to learn capitalization, punctuation and/or special characters. + Changes vocabulary used during AED decoding process. Use this method when fine-tuning on + from pre-trained model. This method changes only decoder and leaves encoder and pre-processing + modules unchanged. For example, you would use it if you want to use pretrained encoder when + fine-tuning on data in another language, or when you'd need model to learn capitalization, + punctuation and/or special characters. Args: - new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is `agg`) + new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer + (if the tokenizer type is `agg`) new_tokenizer_type: Type of tokenizer. Can be either `agg`, `bpe` or `wpe`. decoding_cfg: A config for the decoding, which is optional. If the decoding type needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. @@ -291,7 +296,8 @@ def change_vocabulary( new_tokenizer_cfg = new_tokenizer_dir else: raise ValueError( - f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer type is: {new_tokenizer_type}' + f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this\ + tokenizer type is: {new_tokenizer_type}' ) else: new_tokenizer_cfg = None @@ -457,13 +463,15 @@ def transcribe( channel_selector: Optional[ChannelSelectorType] = None, augmentor: DictConfig = None, verbose: bool = True, + timestamps: Optional[bool] = None, override_config: Optional[MultiTaskTranscriptionConfig] = None, **prompt, ) -> Union[List[str], List[Hypothesis]]: """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. Args: - audio: (a single or list) of paths to audio files or a np.ndarray/tensor audio array or path to a manifest file. + audio: (a single or list) of paths to audio files or a np.ndarray/tensor audio array or path + to a manifest file. Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. @@ -472,15 +480,30 @@ def transcribe( return_hypotheses: (bool) Either return hypotheses or text With hypotheses can do some postprocessing like getting timestamp or rescoring num_workers: (int) number of workers for DataLoader - channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels + from multi-channel audio. If set to `'average'`, it performs averaging across channels. + Disabled if set to `None`. Defaults to `None`. augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. + timestamps: Optional(Bool): timestamps will be returned if set to True as part of hypothesis + object (output.timestep['segment']/output.timestep['word']). Refer to `Hypothesis` class + for more details. Default is None and would retain the previous state set by using + self.change_decoding_strategy(). + Note: Currently its not supported for AED models. verbose: (bool) whether to display tqdm progress bar - override_config: (Optional[MultiTaskTranscriptionConfig]) A config to override the default config. - **prompt: Optional input to construct the prompts for the model. Accepted formats are: 1) legacy Canary-1B API source_lang=, target_lang=, etc. 2) explicit single-turn role=, slots={: , ...} 3) explicit multi-turn: turns=[{"role": , "slots": {: , ...}}] + override_config: (Optional[MultiTaskTranscriptionConfig]) A config to override the + default config. + **prompt: Optional input to construct the prompts for the model. Accepted formats are: + 1) legacy Canary-1B API source_lang=, target_lang=, etc. + 2) explicit single-turn role=, slots={: , ...} + 3) explicit multi-turn: turns=[{"role": , "slots": {: , ...}}] Returns: - A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order + as paths2audio_files """ + if timestamps: + raise NotImplementedError("Computing timestamps are not supported for this model yet.") + if override_config is None: trcfg = MultiTaskTranscriptionConfig( batch_size=batch_size, @@ -889,7 +912,8 @@ def _transcribe_forward( ) @deprecated( - explanation='The return type of args will be updated in the upcoming release to ensure a consistent output format across all decoder types, such that a Hypothesis object is always returned.' + explanation='The return type of args will be updated in the upcoming release to ensure a consistent \ + output format across all decoder types, such that a Hypothesis object is always returned.' ) def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionConfig) -> GenericTranscriptionType: """ diff --git a/nemo/collections/asr/models/ctc_bpe_models.py b/nemo/collections/asr/models/ctc_bpe_models.py index 2e313ce3c928..79c22794de01 100644 --- a/nemo/collections/asr/models/ctc_bpe_models.py +++ b/nemo/collections/asr/models/ctc_bpe_models.py @@ -209,12 +209,14 @@ def change_vocabulary( """ Changes vocabulary of the tokenizer used during CTC decoding process. Use this method when fine-tuning on from pre-trained model. - This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would - use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need - model to learn capitalization, punctuation and/or special characters. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. + For example, you would use it if you want to use pretrained encoder when fine-tuning on a + data in another language, or when you'd need model to learn capitalization, punctuation + and/or special characters. Args: - new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is `agg`) + new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer + (if the tokenizer type is `agg`) new_tokenizer_type: Either `agg`, `bpe` or `wpe`. `bpe` is used for SentencePiece tokenizers, whereas `wpe` is used for `BertTokenizer`. new_tokenizer_cfg: A config for the new tokenizer. if provided, pre-empts the dir and type @@ -227,7 +229,8 @@ def change_vocabulary( new_tokenizer_cfg = new_tokenizer_dir else: raise ValueError( - f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer type is: {new_tokenizer_type}' + f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer \ + type is: {new_tokenizer_type}' ) else: new_tokenizer_cfg = None @@ -307,13 +310,14 @@ def change_vocabulary( logging.info(f"Changed tokenizer to {self.decoder.vocabulary} vocabulary.") - def change_decoding_strategy(self, decoding_cfg: DictConfig): + def change_decoding_strategy(self, decoding_cfg: DictConfig, verbose: bool = True): """ Changes decoding strategy used during CTC decoding process. Args: decoding_cfg: A config for the decoder, which is optional. If the decoding type needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + verbose: Whether to print the new config or not. """ if decoding_cfg is None: # Assume same decoding config as before @@ -343,7 +347,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): with open_dict(self.cfg.decoding): self.cfg.decoding = decoding_cfg - logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + if verbose: + logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") @classmethod def list_available_models(cls) -> List[PretrainedModelInfo]: @@ -378,7 +383,7 @@ def list_available_models(cls) -> List[PretrainedModelInfo]: model = PretrainedModelInfo( pretrained_model_name="stt_en_citrinet_256_gamma_0_25", - description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_citrinet_256_gamma_0_25", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:\nemo:stt_en_citrinet_256_gamma_0_25", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_citrinet_256_gamma_0_25/versions/1.0.0/files/stt_en_citrinet_256_gamma_0_25.nemo", ) results.append(model) diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index edf4f84a9f9b..993c7dc6b298 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy -import json import os -import tempfile from math import ceil from typing import Any, Dict, List, Optional, Union @@ -23,7 +21,6 @@ from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer from torch.utils.data import DataLoader -from tqdm.auto import tqdm from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text import _AudioTextDataset @@ -37,6 +34,7 @@ from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig from nemo.collections.asr.parts.utils.asr_batching import get_semi_sorted_batch_sampler +from nemo.collections.asr.parts.utils.transcribe_utils import process_timestamp_outputs from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.parts.preprocessing.parsers import make_parser from nemo.core.classes.common import PretrainedModelInfo, typecheck @@ -45,6 +43,7 @@ from nemo.utils import logging from nemo.utils.decorators import deprecated + __all__ = ['EncDecCTCModel'] @@ -128,13 +127,15 @@ def transcribe( channel_selector: Optional[ChannelSelectorType] = None, augmentor: DictConfig = None, verbose: bool = True, + timestamps: Optional[bool] = None, override_config: Optional[TranscribeConfig] = None, ) -> TranscriptionReturnType: """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. Args: - audio: (a single or list) of paths to audio files or a np.ndarray/tensor audio array or path to a manifest file. + audio: (a single or list) of paths to audio files or a np.ndarray/tensor audio array or + path to a manifest file. Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. @@ -143,16 +144,41 @@ def transcribe( return_hypotheses: (bool) Either return hypotheses or text With hypotheses can do some postprocessing like getting timestamp or rescoring num_workers: (int) number of workers for DataLoader - channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels + from multi-channel audio. If set to `'average'`, it performs averaging across channels. + Disabled if set to `None`. Defaults to `None`. augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. + timestamps: Optional(Bool): timestamps will be returned if set to True as part of hypothesis + object (output.timestep['segment']/output.timestep['word']). Refer to `Hypothesis` class + for more details. Default is None and would retain the previous state set by + using self.change_decoding_strategy(). verbose: (bool) whether to display tqdm progress bar override_config: (Optional[TranscribeConfig]) override transcription config pre-defined by the user. **Note**: All other arguments in the function will be ignored if override_config is passed. You should call this argument as `model.transcribe(audio, override_config=TranscribeConfig(...))`. Returns: - A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as + paths2audio_files """ + if timestamps is not None: + # else retain the decoder state (users can set it using change_decoding_strategy) + if timestamps or (override_config is not None and override_config.timestamps): + logging.info( + "Timestamps requested, setting decoding timestamps to True. Capture them in Hypothesis object, \ + with output[idx].timestep['word'/'segment'/'char']" + ) + return_hypotheses = True + with open_dict(self.cfg.decoding): + self.cfg.decoding.compute_timestamps = True + self.cfg.decoding.preserve_alignments = True + self.change_decoding_strategy(self.cfg.decoding, verbose=False) + else: # This is done to ensure the state is preserved when decoding_strategy is set outside + with open_dict(self.cfg.decoding): + self.cfg.decoding.compute_timestamps = self.cfg.decoding.get('compute_timestamps', False) + self.cfg.decoding.preserve_alignments = self.cfg.decoding.get('preserve_alignments', False) + self.change_decoding_strategy(self.cfg.decoding, verbose=False) + return super().transcribe( audio=audio, batch_size=batch_size, @@ -161,6 +187,7 @@ def transcribe( channel_selector=channel_selector, augmentor=augmentor, verbose=verbose, + timestamps=timestamps, override_config=override_config, ) @@ -235,13 +262,14 @@ def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[Di logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.") - def change_decoding_strategy(self, decoding_cfg: DictConfig): + def change_decoding_strategy(self, decoding_cfg: DictConfig, verbose: bool = True): """ Changes decoding strategy used during CTC decoding process. Args: decoding_cfg: A config for the decoder, which is optional. If the decoding type needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + verbose: (bool) whether to display logging information """ if decoding_cfg is None: # Assume same decoding config as before @@ -270,7 +298,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): with open_dict(self.cfg.decoding): self.cfg.decoding = decoding_cfg - logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + if verbose: + logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") def _setup_dataloader_from_config(self, config: Optional[Dict]): # Automatically inject args from model config to dataloader config @@ -670,7 +699,8 @@ def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): return output @deprecated( - explanation='The return type of args will be updated in the upcoming release to ensure a consistent output format across all decoder types, such that a Hypothesis object is always returned.' + explanation='The return type of args will be updated in the upcoming release to ensure a consistent output \ + format across all decoder types, such that a Hypothesis object is always returned.' ) def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> GenericTranscriptionType: logits = outputs.pop('logits') @@ -705,6 +735,14 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen # cleanup memory del logits, logits_len + if trcfg.timestamps: + current_hypotheses = process_timestamp_outputs( + current_hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + all_hyp = process_timestamp_outputs( + all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + hypotheses = [] if all_hyp is None: hypotheses += current_hypotheses @@ -767,7 +805,11 @@ def list_available_models(cls) -> List[PretrainedModelInfo]: model = PretrainedModelInfo( pretrained_model_name="QuartzNet15x5Base-En", - description="QuartzNet15x5 model trained on six datasets: LibriSpeech, Mozilla Common Voice (validated clips from en_1488h_2019-12-10), WSJ, Fisher, Switchboard, and NSC Singapore English. It was trained with Apex/Amp optimization level O1 for 600 epochs. The model achieves a WER of 3.79% on LibriSpeech dev-clean, and a WER of 10.05% on dev-other. Please visit https://ngc.nvidia.com/catalog/models/nvidia:nemospeechmodels for further details.", + description="QuartzNet15x5 model trained on six datasets: LibriSpeech, Mozilla Common Voice \ + (validated clips from en_1488h_2019-12-10), WSJ, Fisher, Switchboard, and NSC Singapore English. \ + It was trained with Apex/Amp optimization level O1 for 600 epochs. The model achieves a WER of \ + 3.79% on LibriSpeech dev-clean, and a WER of 10.05% on dev-other. Please visit \ + https://ngc.nvidia.com/catalog/models/nvidia:nemospeechmodels for further details.", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet15x5Base-En.nemo", ) results.append(model) diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py index 089c34d98884..1d437a19a86b 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -253,10 +253,11 @@ def change_vocabulary( ctc_decoding_cfg: Optional[DictConfig] = None, ): """ - Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning on from pre-trained model. - This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would - use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need - model to learn capitalization, punctuation and/or special characters. + Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning on + from pre-trained model. This method changes only decoder and leaves encoder and pre-processing + modules unchanged. For example, you would use it if you want to use pretrained encoder when + fine-tuning on data in another language, or when you'd need model to learn capitalization, + punctuation and/or special characters. Args: new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is `agg`) @@ -415,7 +416,9 @@ def change_vocabulary( logging.info(f"Changed tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.") - def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None): + def change_decoding_strategy( + self, decoding_cfg: DictConfig = None, decoder_type: str = None, verbose: bool = True + ): """ Changes decoding strategy used during RNNT decoding process. Args: @@ -424,6 +427,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type decoder_type: (str) Can be set to 'rnnt' or 'ctc' to switch between appropriate decoder in a model having both RNN-T and CTC decoders. Defaults to None, in which case RNN-T decoder is used. If set to 'ctc', it raises error if 'ctc_decoder' is not an attribute of the model. + verbose: bool whether to display change of decoder config or not. """ if decoder_type is None or decoder_type == 'rnnt': if decoding_cfg is None: @@ -466,7 +470,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type self.cfg.decoding = decoding_cfg self.cur_decoder = "rnnt" - logging.info(f"Changed decoding strategy of the RNNT decoder to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + if verbose: + logging.info( + f"Changed decoding strategy of the RNNT decoder to \n{OmegaConf.to_yaml(self.cfg.decoding)}" + ) elif decoder_type == 'ctc': if not hasattr(self, 'ctc_decoding'): @@ -497,9 +504,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type self.cfg.aux_ctc.decoding = decoding_cfg self.cur_decoder = "ctc" - logging.info( - f"Changed decoding strategy of the CTC decoder to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}" - ) + if verbose: + logging.info( + f"Changed decoding strategy of the CTC decoder to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}" + ) else: raise ValueError(f"decoder_type={decoder_type} is not supported. Supported values: [ctc,rnnt]") diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index c14265325985..028073d7ca7f 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -31,6 +31,7 @@ from nemo.collections.asr.parts.mixins.transcription import TranscriptionReturnType from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig +from nemo.collections.asr.parts.utils.transcribe_utils import process_timestamp_outputs from nemo.core.classes.common import PretrainedModelInfo from nemo.core.classes.mixins import AccessMixin from nemo.utils import logging, model_utils @@ -104,6 +105,7 @@ def transcribe( channel_selector: Optional[ChannelSelectorType] = None, augmentor: DictConfig = None, verbose: bool = True, + timestamps: bool = False, override_config: Optional[TranscribeConfig] = None, ) -> TranscriptionReturnType: """ @@ -120,8 +122,13 @@ def transcribe( return_hypotheses: (bool) Either return hypotheses or text With hypotheses can do some postprocessing like getting timestamp or rescoring num_workers: (int) number of workers for DataLoader - channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. + channel_selector (int | Iterable[int] | str): select a single channel or a subset of + channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. + Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. + timestamps: Optional(Bool): timestamps will be returned if set to True as part of hypothesis object + (output.timestep['segment']/output.timestep['word']). Refer to `Hypothesis` class for more details. + Default is None and would retain the previous state set by using self.change_decoding_strategy(). verbose: (bool) whether to display tqdm progress bar logprobs: (bool) whether to return ctc logits insted of hypotheses @@ -130,10 +137,29 @@ def transcribe( * A list of greedy transcript texts / Hypothesis * An optional list of beam search transcript texts / Hypothesis / NBestHypothesis. """ - if self.cur_decoder not in ["ctc", "rnnt"]: - raise ValueError( - f"{self.cur_decoder} is not supported for cur_decoder. Supported values are ['ctc', 'rnnt']" - ) + + if timestamps is not None: + if self.cur_decoder not in ["ctc", "rnnt"]: + raise ValueError( + f"{self.cur_decoder} is not supported for cur_decoder. Supported values are ['ctc', 'rnnt']" + ) + decoding_cfg = self.cfg.aux_ctc.decoding if self.cur_decoder == "ctc" else self.cfg.decoding + if timestamps or (override_config is not None and override_config.timestamps): + logging.info( + "Timestamps requested, setting decoding timestamps to True. Capture them in Hypothesis object, \ + with output[idx].timestep['word'/'segment'/'char']" + ) + return_hypotheses = True + with open_dict(decoding_cfg): + decoding_cfg.compute_timestamps = True + decoding_cfg.preserve_alignments = True + self.change_decoding_strategy(decoding_cfg, decoder_type=self.cur_decoder, verbose=False) + else: + return_hypotheses = False + with open_dict(decoding_cfg): + decoding_cfg.compute_timestamps = False + decoding_cfg.preserve_alignments = False + self.change_decoding_strategy(decoding_cfg, decoder_type=self.cur_decoder, verbose=False) return super().transcribe( audio=audio, @@ -144,6 +170,7 @@ def transcribe( channel_selector=channel_selector, augmentor=augmentor, verbose=verbose, + timestamps=timestamps, override_config=override_config, ) @@ -201,6 +228,14 @@ def _transcribe_output_processing( # for logit, elen in zip(logits, encoded_len): # logits_list.append(logit[:elen]) + if trcfg.timestamps: + best_hyp = process_timestamp_outputs( + best_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + all_hyp = process_timestamp_outputs( + all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + del logits, encoded_len hypotheses = [] @@ -221,10 +256,11 @@ def change_vocabulary( ctc_decoding_cfg: Optional[DictConfig] = None, ): """ - Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning a pre-trained model. - This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would - use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need - model to learn capitalization, punctuation and/or special characters. + Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning a + pre-trained model. This method changes only decoder and leaves encoder and pre-processing + modules unchanged. For example, you would use it if you want to use pretrained encoder + when fine-tuning on data in another language, or when you'd need model to learn capitalization, + punctuation and/or special characters. Args: new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ @@ -295,7 +331,9 @@ def change_vocabulary( logging.info(f"Changed the tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.") - def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None): + def change_decoding_strategy( + self, decoding_cfg: DictConfig = None, decoder_type: str = None, verbose: bool = True + ): """ Changes decoding strategy used during RNNT decoding process. @@ -305,10 +343,11 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type decoder_type: (str) Can be set to 'rnnt' or 'ctc' to switch between appropriate decoder in a model having RNN-T and CTC decoders. Defaults to None, in which case RNN-T decoder is used. If set to 'ctc', it raises error if 'ctc_decoder' is not an attribute of the model. + verbose: (bool) whether to display logging information """ if decoder_type is None or decoder_type == 'rnnt': self.cur_decoder = "rnnt" - return super().change_decoding_strategy(decoding_cfg=decoding_cfg) + return super().change_decoding_strategy(decoding_cfg=decoding_cfg, verbose=verbose) assert decoder_type == 'ctc' and hasattr(self, 'ctc_decoder') if decoding_cfg is None: @@ -337,7 +376,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type self.cfg.aux_ctc.decoding = decoding_cfg self.cur_decoder = "ctc" - logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}") + if verbose: + logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}") + + return None # PTL-specific methods def training_step(self, batch, batch_nb): diff --git a/nemo/collections/asr/models/rnnt_bpe_models.py b/nemo/collections/asr/models/rnnt_bpe_models.py index 9e09acd21a5d..25890ec716c8 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models.py +++ b/nemo/collections/asr/models/rnnt_bpe_models.py @@ -344,13 +344,15 @@ def change_vocabulary( decoding_cfg: Optional[DictConfig] = None, ): """ - Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning on from pre-trained model. - This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would - use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need - model to learn capitalization, punctuation and/or special characters. + Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning + on from pre-trained model. This method changes only decoder and leaves encoder and pre-processing + modules unchanged. For example, you would use it if you want to use pretrained encoder when fine-tuning + on data in another language, or when you'd need model to learn capitalization, punctuation + and/or special characters. Args: - new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is `agg`) + new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer + (if the tokenizer type is `agg`) new_tokenizer_type: Type of tokenizer. Can be either `agg`, `bpe` or `wpe`. decoding_cfg: A config for the decoder, which is optional. If the decoding type needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. @@ -363,7 +365,8 @@ def change_vocabulary( new_tokenizer_cfg = new_tokenizer_dir else: raise ValueError( - f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer type is: {new_tokenizer_type}' + f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer \ + type is: {new_tokenizer_type}' ) else: new_tokenizer_cfg = None @@ -451,13 +454,14 @@ def change_vocabulary( logging.info(f"Changed decoder to output to {self.joint.vocabulary} vocabulary.") - def change_decoding_strategy(self, decoding_cfg: DictConfig): + def change_decoding_strategy(self, decoding_cfg: DictConfig, verbose: bool = True): """ Changes decoding strategy used during RNNT decoding process. Args: decoding_cfg: A config for the decoder, which is optional. If the decoding type needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + verbose: A flag to enable/disable logging. """ if decoding_cfg is None: # Assume same decoding config as before @@ -498,7 +502,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): with open_dict(self.cfg.decoding): self.cfg.decoding = decoding_cfg - logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + if verbose: + logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index 2b319a3c7dec..ce3b6bc89bce 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -40,6 +40,7 @@ from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecoding, RNNTDecodingConfig from nemo.collections.asr.parts.utils.asr_batching import get_semi_sorted_batch_sampler +from nemo.collections.asr.parts.utils.transcribe_utils import process_timestamp_outputs from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.parts.preprocessing.parsers import make_parser from nemo.core.classes.common import PretrainedModelInfo, typecheck @@ -247,13 +248,15 @@ def transcribe( channel_selector: Optional[ChannelSelectorType] = None, augmentor: DictConfig = None, verbose: bool = True, + timestamps: Optional[bool] = None, override_config: Optional[TranscribeConfig] = None, ) -> TranscriptionReturnType: """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. Args: - audio: (a single or list) of paths to audio files or a np.ndarray/tensor audio array or path to a manifest file. + audio: (a single or list) of paths to audio files or a np.ndarray/tensor audio array or path + to a manifest file. Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. @@ -265,9 +268,14 @@ def transcribe( decoding. This is useful for streaming rnnt decoding. If this is not None, then the length of this list should be equal to the length of the audio list. num_workers: (int) number of workers for DataLoader - channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels + from multi-channel audio. If set to `'average'`, it performs averaging across channels. + Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. verbose: (bool) whether to display tqdm progress bar + timestamps: Optional(Bool): timestamps will be returned if set to True as part of hypothesis object + (output.timestep['segment']/output.timestep['word']). Refer to `Hypothesis` class for more details. + Default is None and would retain the previous state set by using self.change_decoding_strategy(). override_config: (Optional[TranscribeConfig]) override transcription config pre-defined by the user. **Note**: All other arguments in the function will be ignored if override_config is passed. You should call this argument as `model.transcribe(audio, override_config=TranscribeConfig(...))`. @@ -277,6 +285,25 @@ def transcribe( * A list of greedy transcript texts / Hypothesis * An optional list of beam search transcript texts / Hypothesis / NBestHypothesis. """ + + if timestamps is not None: + if timestamps or (override_config is not None and override_config.timestamps): + logging.info( + "Timestamps requested, setting decoding timestamps to True. Capture them in Hypothesis object, \ + with output[0][idx].timestep['word'/'segment'/'char']" + ) + return_hypotheses = True + with open_dict(self.cfg.decoding): + self.cfg.decoding.compute_timestamps = True + self.cfg.decoding.preserve_alignments = True + self.change_decoding_strategy(self.cfg.decoding, verbose=False) + else: + return_hypotheses = False + with open_dict(self.cfg.decoding): + self.cfg.decoding.compute_timestamps = False + self.cfg.decoding.preserve_alignments = False + self.change_decoding_strategy(self.cfg.decoding, verbose=False) + return super().transcribe( audio=audio, batch_size=batch_size, @@ -285,6 +312,7 @@ def transcribe( channel_selector=channel_selector, augmentor=augmentor, verbose=verbose, + timestamps=timestamps, override_config=override_config, # Additional arguments partial_hypothesis=partial_hypothesis, @@ -292,10 +320,11 @@ def transcribe( def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[DictConfig] = None): """ - Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning a pre-trained model. - This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would - use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need - model to learn capitalization, punctuation and/or special characters. + Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning a + pre-trained model. This method changes only decoder and leaves encoder and pre-processing + modules unchanged. For example, you would use it if you want to use pretrained encoder when + fine-tuning on data in another language, or when you'd need model to learn capitalization, + punctuation and/or special characters. Args: new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ @@ -381,13 +410,14 @@ def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[Di logging.info(f"Changed decoder to output to {self.joint.vocabulary} vocabulary.") - def change_decoding_strategy(self, decoding_cfg: DictConfig): + def change_decoding_strategy(self, decoding_cfg: DictConfig, verbose=True): """ Changes decoding strategy used during RNNT decoding process. Args: decoding_cfg: A config for the decoder, which is optional. If the decoding type needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + verbose: (bool) whether to display logging information """ if decoding_cfg is None: # Assume same decoding config as before @@ -428,7 +458,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): with open_dict(self.cfg.decoding): self.cfg.decoding = decoding_cfg - logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + if verbose: + logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") def _setup_dataloader_from_config(self, config: Optional[Dict]): # Automatically inject args from model config to dataloader config @@ -901,7 +932,8 @@ def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): return output @deprecated( - explanation='The return type of args will be updated in the upcoming release to ensure a consistent output format across all decoder types, such that a "Hypothesis" object is always returned.' + explanation='The return type of args will be updated in the upcoming release to ensure a consistent \ + output format across all decoder types, such that a "Hypothesis" object is always returned.' ) def _transcribe_output_processing( self, outputs, trcfg: TranscribeConfig @@ -915,10 +947,17 @@ def _transcribe_output_processing( return_hypotheses=trcfg.return_hypotheses, partial_hypotheses=trcfg.partial_hypothesis, ) - # cleanup memory del encoded, encoded_len + if trcfg.timestamps: + best_hyp = process_timestamp_outputs( + best_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + all_hyp = process_timestamp_outputs( + all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + hypotheses = [] all_hypotheses = [] diff --git a/nemo/collections/asr/modules/conv_asr.py b/nemo/collections/asr/modules/conv_asr.py index 3cb9ec13109b..e48d76a9b7a3 100644 --- a/nemo/collections/asr/modules/conv_asr.py +++ b/nemo/collections/asr/modules/conv_asr.py @@ -133,6 +133,7 @@ def __init__( residual_panes = [] encoder_layers = [] self.dense_residual = False + self._subsampling_factor = 1 for layer_idx, lcfg in enumerate(jasper): dense_res = [] if lcfg.get('residual_dense', False): @@ -181,6 +182,9 @@ def __init__( ) ) feat_in = lcfg['filters'] + self._subsampling_factor *= ( + int(lcfg['stride'][0]) if isinstance(lcfg['stride'], List) else int(lcfg['stride']) + ) self._feat_out = feat_in @@ -199,7 +203,9 @@ def forward(self, audio_signal, length): return s_input[-1], length def update_max_sequence_length(self, seq_length: int, device): - # Find global max audio length across all nodes + """ + Find global max audio length across all nodes in distributed training and update the max_audio_length + """ if torch.distributed.is_initialized(): global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device) @@ -229,6 +235,10 @@ def update_max_sequence_length(self, seq_length: int, device): elif isinstance(m, SqueezeExcite): m.set_max_len(self.max_audio_length, seq_range=self.seq_range) + @property + def subsampling_factor(self) -> int: + return self._subsampling_factor + class ParallelConvASREncoder(NeuralModule, Exportable): """ @@ -426,7 +436,8 @@ def __init__(self, feat_in, num_classes, init_mode="xavier_uniform", vocabulary= if vocabulary is not None: if num_classes != len(vocabulary): raise ValueError( - f"If vocabulary is specified, it's length should be equal to the num_classes. Instead got: num_classes={num_classes} and len(vocabulary)={len(vocabulary)}" + f"If vocabulary is specified, it's length should be equal to the num_classes. \ + Instead got: num_classes={num_classes} and len(vocabulary)={len(vocabulary)}" ) self.__vocabulary = vocabulary self._feat_in = feat_in @@ -765,8 +776,8 @@ class SpeakerDecoder(NeuralModule, Exportable): Args: feat_in (int): Number of channels being input to this module num_classes (int): Number of unique speakers in dataset - emb_sizes (list) : shapes of intermediate embedding layers (we consider speaker embbeddings from 1st of this layers) - Defaults to [1024,1024] + emb_sizes (list) : shapes of intermediate embedding layers (we consider speaker embbeddings + from 1st of this layers). Defaults to [1024,1024] pool_mode (str) : Pooling strategy type. options are 'xvector','tap', 'attention' Defaults to 'xvector (mean and variance)' tap (temporal average pooling: just mean) diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py index 104e6bff81af..ac928fe99272 100644 --- a/nemo/collections/asr/parts/mixins/transcription.py +++ b/nemo/collections/asr/parts/mixins/transcription.py @@ -16,8 +16,7 @@ import os import tempfile from abc import ABC, abstractmethod -from collections.abc import Iterable -from dataclasses import dataclass, fields, is_dataclass +from dataclasses import dataclass from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union @@ -61,6 +60,7 @@ class TranscribeConfig: num_workers: Optional[int] = None channel_selector: ChannelSelectorType = None augmentor: Optional[DictConfig] = None + timestamps: Optional[bool] = None # returns timestamps for each word and segments if model supports punctuations verbose: bool = True # Utility @@ -86,7 +86,8 @@ def get_value_from_transcription_config(trcfg, key, default): return getattr(trcfg, key) else: logging.debug( - f"Using default value of {default} for {key} because it is not present in the transcription config {trcfg}." + f"Using default value of {default} for {key} because it is not present \ + in the transcription config {trcfg}." ) return default @@ -179,6 +180,7 @@ def transcribe( channel_selector: Optional[ChannelSelectorType] = None, augmentor: DictConfig = None, verbose: bool = True, + timestamps: Optional[bool] = None, override_config: Optional[TranscribeConfig] = None, **config_kwargs, ) -> GenericTranscriptionType: @@ -200,6 +202,9 @@ def transcribe( to `None`. Defaults to `None`. Uses zero-based indexing. augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. verbose: (bool) whether to display tqdm progress bar + timestamps: Optional(Bool): timestamps will be returned if set to True as part of hypothesis object + (output.timestep['segment']/output.timestep['word']). Refer to `Hypothesis` class for more details. + Default is None and would retain the previous state set by using self.change_decoding_strategy(). override_config: (Optional[TranscribeConfig]) override transcription config pre-defined by the user. **Note**: All other arguments in the function will be ignored if override_config is passed. You should call this argument as `model.transcribe(audio, override_config=TranscribeConfig(...))`. @@ -229,6 +234,7 @@ def transcribe( channel_selector=channel_selector, augmentor=augmentor, verbose=verbose, + timestamps=timestamps, **config_kwargs, ) else: diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index e52c3f46423e..da280a0c6b3c 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -67,13 +67,15 @@ class AbstractRNNTDecoding(ConfidenceMixin): rnnt_timestamp_type: A str value, which represents the types of timestamps that should be calculated. Can take the following values - "char" for character/subword time stamps, "word" for word level - time stamps, "segment" for segment level time stamps and "all" (default), for character, word and segment level time stamps. + time stamps, "segment" for segment level time stamps and "all" (default), for character, + word and segment level time stamps. word_seperator: Str token representing the seperator between words. segment_seperators: List containing tokens representing the seperator(s) between segments. - segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for forming the segments. + segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary + for forming the segments. preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated during decoding (sample / batched). When set to true, the Hypothesis will contain @@ -102,10 +104,10 @@ class AbstractRNNTDecoding(ConfidenceMixin): The length of the list corresponds to the number of recognized words. exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded from the `token_confidence`. - aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. - Valid options are `mean`, `min`, `max`, `prod`. - tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and - attached to the regular frame confidence, + aggregation: Which aggregation type to use for collapsing per-token confidence into per-word + confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated + and attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -177,22 +179,23 @@ class AbstractRNNTDecoding(ConfidenceMixin): maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient, and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0. - maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep this as 1 - in order to reduce expensive beam search cost later. int >= 0. + maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to + keep this as 1 in order to reduce expensive beam search cost later. int >= 0. maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, - and affects the speed of inference since large values will perform large beam search in the next step. - - maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. - The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) - where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be - predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for - expansion apart from the "most likely" candidate. - Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed - but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, - thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally - tuned on a validation set. + and affects the speed of inference since large values will perform large beam search in the + next step. + + maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the + expansions. The default (2.3) is selected from the paper. It performs a comparison + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set + and max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin + of additional tokens which can be potential candidates for expansion apart from the "most likely" + candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, + thereby improving speed but hurting accuracy). Higher values will increase the number of expansions + (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy). + This is a hyper parameter to be experimentally tuned on a validation set. softmax_temperature: Scales the logits of the joint prior to computing log_softmax. @@ -887,7 +890,7 @@ def _compute_offsets( # Construct the start and end indices brackets end_indices = np.asarray(token_repetitions).cumsum() - start_indices = np.concatenate(([start_index], end_indices[:-1])) + start_indices = np.concatenate(([int(start_index)], end_indices[:-1])) # Process the TxU dangling alignment tensor, containing pairs of (logits, label) alignment_labels = [al_logits_labels for al_logits_labels in hypothesis.text[1]] @@ -950,7 +953,8 @@ def _refine_timestamps_tdt( # Check if token is a punctuation mark # If so, set its start and end offset as start and end of the previous token - # This is done because there was observed a behaviour, when punctuation marks are predicted long after preceding token (i.e. after silence) + # This is done because there was observed a behaviour, when punctuation marks are predicted long + # after preceding token (i.e. after silence) if offset['char'][0] in supported_punctuation and i > 0: encoded_char_offsets[i]['start_offset'] = offset['start_offset'] = char_offsets[i - 1]['end_offset'] encoded_char_offsets[i]['end_offset'] = offset['end_offset'] = offset['start_offset'] @@ -1237,10 +1241,10 @@ class RNNTDecoding(AbstractRNNTDecoding): The length of the list corresponds to the number of recognized words. exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded from the `token_confidence`. - aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. - Valid options are `mean`, `min`, `max`, `prod`. - tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and - attached to the regular frame confidence, + aggregation: Which aggregation type to use for collapsing per-token confidence into per-word + confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated + and attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -1313,8 +1317,8 @@ class RNNTDecoding(AbstractRNNTDecoding): per timestep of the acoustic model. Larger values will allow longer sentences to be decoded, at increased cost to execution time. - alsd_max_target_len: optional int or float, determines the potential maximum target sequence length. - If an integer is provided, it can decode sequences of that particular maximum length. + alsd_max_target_len: optional int or float, determines the potential maximum target sequence + length. If an integer is provided, it can decode sequences of that particular maximum length. If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len), where seq_len is the length of the acoustic model output (T). @@ -1326,22 +1330,24 @@ class RNNTDecoding(AbstractRNNTDecoding): maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient, and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0. - maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep this as 1 - in order to reduce expensive beam search cost later. int >= 0. + maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to + keep this as 1 in order to reduce expensive beam search cost later. int >= 0. maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, - and affects the speed of inference since large values will perform large beam search in the next step. - - maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. - The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) - where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be - predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for - expansion apart from the "most likely" candidate. - Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed - but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, - thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally - tuned on a validation set. + and affects the speed of inference since large values will perform large beam search in the + next step. + + maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the + expansions. + The default (2.3) is selected from the paper. It performs a comparison + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and + max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin of + additional tokens which can be potential candidates for expansion apart from the "most likely" + candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, + thereby improving speed but hurting accuracy). Higher values will increase the number of + expansions (by reducing pruning-by-value, thereby reducing speed but potentially improving + accuracy). This is a hyper parameter to be experimentally tuned on a validation set. softmax_temperature: Scales the logits of the joint prior to computing log_softmax. @@ -1492,7 +1498,8 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): segment_seperators: List containing tokens representing the seperator(s) between segments. - segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for forming the segments. + segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for + forming the segments. preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated during decoding (sample / batched). When set to true, the Hypothesis will contain @@ -1521,10 +1528,10 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): The length of the list corresponds to the number of recognized words. exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded from the `token_confidence`. - aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. - Valid options are `mean`, `min`, `max`, `prod`. - tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and - attached to the regular frame confidence, + aggregation: Which aggregation type to use for collapsing per-token confidence into per-word + confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be + calculated and attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -1594,8 +1601,8 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): per timestep of the acoustic model. Larger values will allow longer sentences to be decoded, at increased cost to execution time. - alsd_max_target_len: optional int or float, determines the potential maximum target sequence length. - If an integer is provided, it can decode sequences of that particular maximum length. + alsd_max_target_len: optional int or float, determines the potential maximum target sequence + length. If an integer is provided, it can decode sequences of that particular maximum length. If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len), where seq_len is the length of the acoustic model output (T). @@ -1607,22 +1614,24 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient, and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0. - maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep this as 1 - in order to reduce expensive beam search cost later. int >= 0. + maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to + keep this as 1 in order to reduce expensive beam search cost later. int >= 0. maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, - and affects the speed of inference since large values will perform large beam search in the next step. - - maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. - The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) - where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be - predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for - expansion apart from the "most likely" candidate. - Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed - but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, - thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally - tuned on a validation set. + and affects the speed of inference since large values will perform large beam search in the + next step. + + maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when + computing the expansions. The default (2.3) is selected from the paper. It performs a + comparison (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the + Vocab set and max_log_prob is the "most" likely token to be predicted. Gamma therefore + provides a margin of additional tokens which can be potential candidates for expansion + apart from the "most likely" candidate. Lower values will reduce the number of expansions + (by increasing pruning-by-value, thereby improving speed but hurting accuracy). Higher + values will increase the number of expansions (by reducing pruning-by-value, thereby + reducing speed but potentially improving accuracy). This is a hyper parameter to be + experimentally tuned on a validation set. softmax_temperature: Scales the logits of the joint prior to computing log_softmax. @@ -1750,7 +1759,8 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp hypotheses[ind].langs_chars = self.decode_ids_to_langs(prediction) else: logging.warning( - "Ignoring request for lang output in hypotheses since the model does not use an aggregate tokenizer" + "Ignoring request for lang output in hypotheses since the model does not use an aggregate\ + tokenizer" ) return hypotheses diff --git a/nemo/collections/asr/parts/utils/streaming_utils.py b/nemo/collections/asr/parts/utils/streaming_utils.py index 415096a0c9d5..cb272e3d0462 100644 --- a/nemo/collections/asr/parts/utils/streaming_utils.py +++ b/nemo/collections/asr/parts/utils/streaming_utils.py @@ -22,7 +22,7 @@ from torch.utils.data import DataLoader from nemo.collections.asr.data.audio_to_text_lhotse_prompted import PromptedAudioToTextMiniBatch -from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE +from nemo.collections.asr.models import ASRModel from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder from nemo.collections.asr.parts.preprocessing.features import normalize_batch from nemo.collections.asr.parts.preprocessing.segment import get_samples @@ -79,8 +79,8 @@ def longest_common_subsequence_merge(X, Y, filepath=None): Assumption is that the two chunks are consecutive chunks, and there exists at least small overlap acoustically. - It is a sub-word token merge algorithm, operating on the abstract notion of integer ids representing the subword ids. - It is independent of text or character encoding. + It is a sub-word token merge algorithm, operating on the abstract notion of integer ids representing + the subword ids. It is independent of text or character encoding. Since the algorithm is merge based, and depends on consecutive buffers, the very first buffer is processes using the "middle tokens" algorithm. @@ -292,8 +292,8 @@ def lcs_alignment_merge_buffer(buffer, data, delay, model, max_steps_per_timeste Merges the new text from the current frame with the previous text contained in the buffer. The alignment is based on a Longest Common Subsequence algorithm, with some additional heuristics leveraging - the notion that the chunk size is >= the context window. In case this assumptio is violated, the results of the merge - will be incorrect (or at least obtain worse WER overall). + the notion that the chunk size is >= the context window. In case this assumptio is violated, the results of the + merge will be incorrect (or at least obtain worse WER overall). """ # If delay timesteps is 0, that means no future context was used. Simply concatenate the buffer with new data. if delay < 1: @@ -327,8 +327,8 @@ def inplace_buffer_merge(buffer, data, timesteps, model): Merges the new text from the current frame with the previous text contained in the buffer. The alignment is based on a Longest Common Subsequence algorithm, with some additional heuristics leveraging - the notion that the chunk size is >= the context window. In case this assumptio is violated, the results of the merge - will be incorrect (or at least obtain worse WER overall). + the notion that the chunk size is >= the context window. In case this assumptio is violated, the results of + the merge will be incorrect (or at least obtain worse WER overall). """ # If delay timesteps is 0, that means no future context was used. Simply concatenate the buffer with new data. if timesteps < 1: @@ -391,7 +391,7 @@ def __init__(self, asr_model, chunk_size, buffer_size): cfg.preprocessor.dither = 0.0 cfg.preprocessor.pad_to = 0 cfg.preprocessor.normalize = "None" - self.raw_preprocessor = EncDecCTCModelBPE.from_config_dict(cfg.preprocessor) + self.raw_preprocessor = ASRModel.from_config_dict(cfg.preprocessor) self.raw_preprocessor.to(asr_model.device) def reset(self): @@ -756,7 +756,7 @@ def __init__( cfg.preprocessor.dither = 0.0 cfg.preprocessor.pad_to = 0 cfg.preprocessor.normalize = "None" - self.raw_preprocessor = EncDecCTCModelBPE.from_config_dict(cfg.preprocessor) + self.raw_preprocessor = ASRModel.from_config_dict(cfg.preprocessor) self.raw_preprocessor.to(asr_model.device) self.preprocessor = self.raw_preprocessor @@ -1091,12 +1091,15 @@ def _get_batch_preds(self): - For all samples, determine if signal has finished. - If so, skip calculation of mel-specs. - If not, compute mel spec and length - - Perform Encoder forward over this sub-batch of samples. Maintain the indices of samples that were processed. - - If performing stateful decoding, prior to decoder forward, remove the states of samples that were not processed. + - Perform Encoder forward over this sub-batch of samples. Maintain the indices of samples that + were processed. + - If performing stateful decoding, prior to decoder forward, remove the states of samples that + were not processed. - Perform Decoder + Joint forward for samples that were processed. - For all output RNNT alignment matrix of the joint do: - If signal has ended previously (this was last buffer of padding), skip alignment - - Otherwise, recalculate global index of this sample from the sub-batch index, and preserve alignment. + - Otherwise, recalculate global index of this sample from the sub-batch index, and preserve + alignment. - Same for preds - Update indices of sub-batch with global index map. - Redo steps until all samples were processed (sub-batch size == 0). @@ -1362,15 +1365,17 @@ def transcribe( class CacheAwareStreamingAudioBuffer: """ - A buffer to be used for cache-aware streaming. It can load a single or multiple audio files/processed signals, split them in chunks and return one on one. - It can be used to simulate streaming audio or audios. + A buffer to be used for cache-aware streaming. It can load a single or multiple audio + files/processed signals, split them in chunks and return one on one. It can be used to + simulate streaming audio or audios. """ def __init__(self, model, online_normalization=None, pad_and_drop_preencoded=False): ''' Args: model: An ASR model. - online_normalization (bool): whether to perform online normalization per chunk or normalize the whole audio before chunking + online_normalization (bool): whether to perform online normalization per chunk or + normalize the whole audio before chunking pad_and_drop_preencoded (bool): if true pad first audio chunk and always drop preencoded ''' self.model = model @@ -1430,7 +1435,8 @@ def __iter__(self): audio_chunk = self.buffer[:, :, self.buffer_idx : self.buffer_idx + chunk_size] if self.sampling_frames is not None: - # checking to make sure the audio chunk has enough frames to produce at least one output after downsampling + # checking to make sure the audio chunk has enough frames to produce at least one output after + # downsampling if self.buffer_idx == 0 and isinstance(self.sampling_frames, list): cur_sampling_frames = self.sampling_frames[0] else: diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index 0d4f4c895bcf..189d98537d3f 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -199,7 +199,8 @@ def get_buffered_pred_feat_multitaskAED( if filepaths: logging.info( - "Deteced audio files as input, default to English ASR with Punctuation and Capitalization output. Please use manifest input for other options." + "Deteced audio files as input, default to English ASR with Punctuation and Capitalization output. \ + Please use manifest input for other options." ) for audio_file in tqdm(filepaths, desc="Transcribing:", total=len(filepaths), ncols=80): meta = { @@ -281,12 +282,16 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]: - append_pred (bool): Flag indicating whether to append predictions to an existing dataset. - audio_type (str): Type of audio files to consider. - dataset_manifest (str): Path to the dataset manifest file. - - audio_key (str, optional): Key in the manifest file specifying the audio file path. Defaults to 'audio_filepath'. - - presort_manifest (bool, optional): Flag indicating whether to presort the manifest file. Defaults to True. + - audio_key (str, optional): Key in the manifest file specifying the audio file path. + Defaults to 'audio_filepath'. + - presort_manifest (bool, optional): Flag indicating whether to presort the manifest file. + Defaults to True. Returns: Tuple[List[str], bool]: A tuple containing the following: - - filepaths (List[str]): List of filepaths to the audio files if path to the directory containing audio files is provided. - - sorted_manifest_path (bool): Path to the sorted manifest file if path to the dataset manifest file is provided. + - filepaths (List[str]): List of filepaths to the audio files if path to the directory + containing audio files is provided. + - sorted_manifest_path (bool): Path to the sorted manifest file if path to the dataset + manifest file is provided. """ filepaths = None @@ -308,7 +313,8 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]: item[audio_key] = get_full_path(item[audio_key], cfg.dataset_manifest) if item.get("duration") is None and cfg.presort_manifest: raise ValueError( - f"Requested presort_manifest=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field." + f"Requested presort_manifest=True, but line {line} in manifest {cfg.dataset_manifest} \ + lacks a 'duration' field." ) with NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: @@ -388,7 +394,7 @@ def write_transcription( model_name: str, filepaths: List[str] = None, compute_langs: bool = False, - compute_timestamps: bool = False, + timestamps: bool = False, ) -> Tuple[str, str]: """Write generated transcription to output file.""" if cfg.append_pred: @@ -433,7 +439,7 @@ def write_transcription( else: # transcription is Hypothesis item = {'audio_filepath': filepaths[idx], pred_text_attr_name: transcription.text} - if compute_timestamps: + if timestamps: timestamps = transcription.timestep if timestamps is not None and isinstance(timestamps, dict): timestamps.pop( @@ -441,7 +447,7 @@ def write_transcription( ) # Pytorch tensor calculating index of each token, not needed. for key in timestamps.keys(): values = normalize_timestamp_output(timestamps[key]) - item[f'timestamps_{key}'] = values + item[f'{key}'] = values if compute_langs: item['pred_lang'] = transcription.langs @@ -458,7 +464,7 @@ def write_transcription( else: # transcription is Hypothesis item[pred_text_attr_name] = best_hyps[idx].text - if compute_timestamps: + if timestamps: timestamps = best_hyps[idx].timestep if timestamps is not None and isinstance(timestamps, dict): timestamps.pop( @@ -466,7 +472,7 @@ def write_transcription( ) # Pytorch tensor calculating index of each token, not needed. for key in timestamps.keys(): values = normalize_timestamp_output(timestamps[key]) - item[f'timestamps_{key}'] = values + item[f'{key}'] = values if compute_langs: item['pred_lang'] = best_hyps[idx].langs @@ -492,10 +498,14 @@ def compute_metrics_per_sample( Args: manifest_path: str, Required - path to dataset JSON manifest file (in NeMo format) - reference_field: str, Optional - name of field in .json manifest with the reference text ("text" by default). - hypothesis_field: str, Optional - name of field in .json manifest with the hypothesis text ("pred_text" by default). - metrics: list[str], Optional - list of metrics to be computed (currently supported "wer", "cer", "punct_er") - punctuation_marks: list[str], Optional - list of punctuation marks for computing punctuation error rate ([".", ",", "?"] by default). + reference_field: str, Optional - name of field in .json manifest with the reference text + ("text" by default). + hypothesis_field: str, Optional - name of field in .json manifest with the hypothesis text + ("pred_text" by default). + metrics: list[str], Optional - list of metrics to be computed + (currently supported "wer", "cer", "punct_er") + punctuation_marks: list[str], Optional - list of punctuation marks for computing + punctuation error rate ([".", ",", "?"] by default). output_manifest_path: str, Optional - path where .json manifest with calculated metrics will be saved. Returns: @@ -568,6 +578,61 @@ def compute_metrics_per_sample( return samples_with_metrics +def process_timestamp_outputs(outputs, subsampling_factor: int = 1, window_stride: float = 0.01): + """ + Process the timestamps from list of hypothesis to user friendly format. + Converts the start and end duration from frames to seconds. + Args: + outputs: List of Hypothesis objects. + subsampling_factor: int, Subsampling factor used in the model. + window_stride: float, Window stride used in the model. (sometimes referred to as hop length/shift) + Returns: + List of Hypothesis objects with processed timestamps + + """ + + if outputs is None: + return outputs + + if isinstance(outputs, rnnt_utils.Hypothesis): + outputs = [outputs] + + if not isinstance(outputs[0], rnnt_utils.Hypothesis): + raise ValueError(f"Expected Hypothesis object, got {type(outputs[0])}") + + def process_timestamp(timestamp, subsampling_factor, window_stride): + """ + Process the timestamp for a single hypothesis. + return the start and end duration in seconds. + """ + for idx, val in enumerate(timestamp): + start_offset = val['start_offset'] + end_offset = val['end_offset'] + start = start_offset * window_stride * subsampling_factor + end = end_offset * window_stride * subsampling_factor + val['start'] = start + val['end'] = end + + return timestamp + + for idx, hyp in enumerate(outputs): + if not hasattr(hyp, 'timestep'): + raise ValueError( + f"Expected Hypothesis object to have 'timestep' attribute, when compute_timestamps is \ + enabled but got {hyp}" + ) + timestep = hyp.timestep + if 'word' in timestep: + outputs[idx].timestep['word'] = process_timestamp(timestep['word'], subsampling_factor, window_stride) + if 'char' in timestep: + outputs[idx].timestep['char'] = process_timestamp(timestep['char'], subsampling_factor, window_stride) + if 'segment' in timestep: + outputs[idx].timestep['segment'] = process_timestamp( + timestep['segment'], subsampling_factor, window_stride + ) + return outputs + + class PunctuationCapitalization: def __init__(self, punctuation_marks: str): """ diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index 8f37ce24a23a..fb115faade2f 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -567,6 +567,72 @@ def generate( inference_params: Optional["CommonInferenceParams"] = None, text_only: bool = False, ) -> list[Union["InferenceRequest", str]]: + """ + Generates text using a NeMo LLM model. + + This function takes a checkpoint path and a list of prompts, + and generates text based on the loaded model and parameters. + It returns a list of generated text, either as a string or as an InferenceRequest object. + + Python Usage: + ```python + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=2, + pipeline_model_parallel_size=1, + context_parallel_size=1, + sequence_parallel=False, + setup_optimizers=False, + store_optimizer_states=False, + ) + + trainer = nl.Trainer( + accelerator="gpu", + devices=2, + num_nodes=1, + strategy=strategy, + plugins=nl.MegatronMixedPrecision( + precision="bf16-mixed", + params_dtype=torch.bfloat16, + pipeline_dtype=torch.bfloat16, + autocast_enabled=False, + grad_reduce_in_fp32=False, + ), + ) + prompts = [ + "Hello, how are you?", + "How many r's are in the word 'strawberry'?", + "Which number is bigger? 10.119 or 10.19?", + ] + + if __name__ == "__main__": + results = api.generate( + path=os.path.join(os.environ["NEMO_HOME"], "models", "meta-llama/Meta-Llama-3-8B"), + prompts=prompts, + trainer=trainer, + inference_params=CommonInferenceParams(temperature=0.1, top_k=10, num_tokens_to_generate=512), + text_only=True, + ) + ``` + + Args: + path (Union[Path, str]): The path to the model checkpoint. + prompts (list[str]): The list of prompts to generate text for. + trainer (nl.Trainer): The trainer object. + encoder_prompts (Optional[list[str]], optional): The list of encoder prompts. Defaults to None. + params_dtype (torch.dtype, optional): The data type of the model parameters. Defaults to torch.bfloat16. + add_BOS (bool, optional): Whether to add the beginning of sequence token. Defaults to False. + max_batch_size (int, optional): The maximum batch size. Defaults to 4. + random_seed (Optional[int], optional): The random seed. Defaults to None. + inference_batch_times_seqlen_threshold (int, optional): If batch-size times sequence-length is smaller than + this threshold then we will not use pipelining, otherwise we will. Defaults to 1000. + inference_params (Optional["CommonInferenceParams"], optional): The inference parameters defined in + Mcore's CommonInferenceParams. Defaults to None. + text_only (bool, optional): Whether to return only the generated text as a string. Defaults to False. + + Returns: + list[Union["InferenceRequest", str]]: A list of generated text, + either as a string or as an InferenceRequest object. + """ from nemo.collections.llm import inference inference_wrapped_model, mcore_tokenizer = inference.setup_model_and_tokenizer( diff --git a/nemo/collections/llm/gpt/data/fine_tuning.py b/nemo/collections/llm/gpt/data/fine_tuning.py index d7ed08a01ed4..9d16ea8aa021 100644 --- a/nemo/collections/llm/gpt/data/fine_tuning.py +++ b/nemo/collections/llm/gpt/data/fine_tuning.py @@ -22,6 +22,7 @@ from nemo.collections.common.tokenizers import AutoTokenizer from nemo.collections.llm.gpt.data.core import create_sft_dataset +from nemo.lightning.data import WrappedDataLoader from nemo.lightning.pytorch.plugins import MegatronDataSampler from nemo.utils import logging @@ -34,22 +35,26 @@ class FineTuningDataModule(pl.LightningDataModule): """Base class for fine-tuning an LLM. This class provides a foundation for building custom data modules for fine-tuning Nemo NLP models. It inherits from - `pl.LightningDataModule` from the PyTorch Lightning library and handles data loading, preprocessing, and batch creation - for training, validation, and testing. + `pl.LightningDataModule` from the PyTorch Lightning library and handles data loading, preprocessing, and batch + creation for training, validation, and testing. Args: dataset_root (Union[str, Path]): The root directory containing the training, validation, and test data. seq_length (int, optional): The maximum sequence length for the input and output text. Defaults to 2048. - tokenizer (Optional[TokenizerSpec], optional): The tokenizer to use for preprocessing the text. Defaults to None. + tokenizer (Optional[TokenizerSpec], optional): The tokenizer to use for preprocessing the text. If not provided, a Megatron GPT2 BPE tokenizer will be used. micro_batch_size (int, optional): The micro batch size for training. Defaults to 4. global_batch_size (int, optional): The global batch size for training. Defaults to 8. - rampup_batch_size (Optional[List[int]], optional): A list of batch sizes for ramping up during training. Defaults to None. + rampup_batch_size (Optional[List[int]], optional): A list of batch sizes for ramping up during training. + Defaults to None. seed (int, optional): The random seed for data shuffling. Defaults to 1234. - memmap_workers (int, optional): The number of worker processes for loading data using TextMemMapDataset. Defaults to 1. + memmap_workers (int, optional): The number of worker processes for loading data using TextMemMapDataset. + Defaults to 1. num_workers (int, optional): The number of worker processes for data loading. Defaults to 8. - pin_memory (bool, optional): Whether to pin memory during data loading for faster GPU training. Defaults to True. - persistent_workers (bool, optional): Whether to keep data loading workers persistent across epochs. Defaults to False. + pin_memory (bool, optional): Whether to pin memory during data loading for faster GPU training. + Defaults to True. + persistent_workers (bool, optional): Whether to keep data loading workers persistent across epochs. + Defaults to False. packed_sequence_specs (PackedSequenceSpecs, optional): See PackedSequenceSpecs for details dataset_kwargs (Optional[Dict[str, Any]], optional): Keyword arguments to pass into the GPTSFTDataset class """ @@ -90,18 +95,28 @@ def __init__( self.dataset_kwargs = dataset_kwargs or {} def validate_batch_size_for_packed_sequence(self): + """ + Validate that micro batch size must be 1 when using packed sequence. + """ if self.packed_sequence_size > 0 and self.micro_batch_size > 1: raise ValueError( "Micro batch size should be 1 when training with packed sequence, but your micro batch size " f"is {self.micro_batch_size}. \nThe following config is equivalent to your current setting for " f"a packed dataset. Please update your config to the following: \n" f"Set micro batch size to 1 (currently {self.micro_batch_size})\n" - f"Set global batch size to {self.global_batch_size // self.micro_batch_size} (currently {self.global_batch_size}) \n" - f"Set packed sequence length to {self.packed_sequence_size*self.micro_batch_size} (currently {self.packed_sequence_size}) \n" - f"For details please visit https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/optimizations/sequence_packing.html" + f"Set global batch size to {self.global_batch_size // self.micro_batch_size} " + f"(currently {self.global_batch_size}) \n" + f"Set packed sequence length to {self.packed_sequence_size*self.micro_batch_size} " + f"(currently {self.packed_sequence_size}) \n" + f"For details please visit " + f"https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/optimizations/" + f"sequence_packing.html" ) def prepare_data(self) -> None: + """ + Prepare packed sequence data + """ if self.packed_sequence_size > 0 and not self.train_path_packed.is_file(): from nemo.collections.llm.gpt.data.packed_sequence import prepare_packed_sequence_data @@ -115,6 +130,9 @@ def prepare_data(self) -> None: ) def setup(self, stage: str): + """Called by pytorch lightning in datamodule setup""" + + # data_sampler is used in `setup_data_sampler` in MegatronStrategy.setup self.data_sampler = MegatronDataSampler( seq_len=self.seq_length, micro_batch_size=self.micro_batch_size, @@ -127,36 +145,78 @@ def setup(self, stage: str): # base_dataset_utils.get_datasets_weights_and_num_samples self.max_train_samples = int(math.ceil(self.global_batch_size * self.trainer.max_steps * 1.005)) + def state_dict(self) -> Dict[str, Any]: + """Called when saving a checkpoint, implement to generate and save datamodule state. + + Returns: + A dictionary containing datamodule state. + + """ + consumed_samples = self.data_sampler.compute_consumed_samples( + self.trainer.global_step - self.data_sampler.init_global_step + ) + return {"consumed_samples": consumed_samples} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Called when loading a checkpoint, implement to reload datamodule state given datamodule stat + + Args: + state_dict: the datamodule state returned by ``state_dict``. + + """ + try: + from megatron.core.num_microbatches_calculator import update_num_microbatches + + except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import update_num_microbatches + consumed_samples = state_dict["consumed_samples"] + self.data_sampler.init_consumed_samples = consumed_samples + self.data_sampler.prev_consumed_samples = consumed_samples + + update_num_microbatches( + consumed_samples=consumed_samples, + consistency_check=False, + ) + self.data_sampler.if_first_step = 1 + def train_dataloader(self) -> DataLoader: + # pylint: disable=C0115,C0116 return self._create_dataloader( self._create_dataset( self.train_path if self.packed_sequence_size <= 0 else self.train_path_packed, max_num_samples=self.max_train_samples, **self.dataset_kwargs, - ) + ), + mode="train", ) def val_dataloader(self) -> DataLoader: + # pylint: disable=C0115,C0116 return self._create_dataloader( self._create_dataset( self.validation_path, is_test=True, **self.dataset_kwargs, ), + mode="validation", ) def test_dataloader(self) -> DataLoader: + # pylint: disable=C0115,C0116 return self._create_dataloader( self._create_dataset( self.test_path, tokens_to_generate=32, is_test=True, **self.dataset_kwargs, - ) + ), + mode="test", ) @lru_cache def _create_dataset(self, path, is_test=False, **kwargs): + # pylint: disable=C0115,C0116 return create_sft_dataset( path, tokenizer=self.tokenizer, @@ -167,9 +227,11 @@ def _create_dataset(self, path, is_test=False, **kwargs): **kwargs, ) - def _create_dataloader(self, dataset, **kwargs) -> DataLoader: - return DataLoader( - dataset, + def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader: + # pylint: disable=C0115,C0116 + return WrappedDataLoader( + mode=mode, + dataset=dataset, num_workers=self.num_workers, pin_memory=self.pin_memory, persistent_workers=self.persistent_workers, @@ -179,10 +241,13 @@ def _create_dataloader(self, dataset, **kwargs) -> DataLoader: @property def train_path(self) -> Path: + """Path to training dataset file""" return self.dataset_root / "training.jsonl" @property def train_path_packed(self) -> Path: + """Path to training dataset file for packed sequence. The file path contains a reference to the + tokenizer/model name since packed sequence dataset consists of tokenized indices.""" if self.packed_sequence_size > 0: if self.packed_sequence_specs.packed_data_path is not None: return self.packed_sequence_specs.packed_data_path @@ -195,13 +260,16 @@ def train_path_packed(self) -> Path: @property def validation_path(self) -> Path: + """Path to validation dataset file""" return self.dataset_root / "validation.jsonl" @property def test_path(self) -> Path: + """Path to test dataset file""" return self.dataset_root / "test.jsonl" def _extract_tokenizer_model_name(self) -> str: + """Automatically get the model name from model path.""" if self.packed_sequence_specs.tokenizer_model_name is not None: tokenizer_model_name = self.packed_sequence_specs.tokenizer_model_name elif isinstance(self.tokenizer, AutoTokenizer): diff --git a/nemo/collections/llm/inference/base.py b/nemo/collections/llm/inference/base.py index f3d202451c60..55d865ec238b 100644 --- a/nemo/collections/llm/inference/base.py +++ b/nemo/collections/llm/inference/base.py @@ -36,46 +36,105 @@ import nemo.lightning as nl from nemo.collections.llm.peft import LoRA from nemo.lightning import io -from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME, ckpt_to_context_subdir, ckpt_to_weights_subdir +from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME, ckpt_to_context_subdir +from nemo.lightning.io.pl import ckpt_to_weights_subdir from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy from nemo.lightning.pytorch.strategies.utils import RestoreConfig -# We need this wrapper since mcore generate uses methods/properties such as tokenizer.detokenize, tokenizer.tokenize, tokenizer.bos, tokenizer.pad, etc. to encode and decode prompts class MCoreTokenizerWrappper: + """ + We need this wrapper since mcore generate uses methods/properties such as + tokenizer.detokenize, tokenizer.tokenize, tokenizer.bos, tokenizer.pad, etc. to encode and decode prompts + """ + def __init__(self, tokenizer): self.tokenizer = tokenizer self.eod = tokenizer.eod self.vocab_size = tokenizer.vocab_size def detokenize(self, tokens, remove_special_tokens=False): + """ + Detokenizes a list of tokens into a string. + + Args: + tokens (list): The list of tokens to detokenize. + remove_special_tokens (bool, optional): Whether to remove special tokens. Defaults to False. + + Returns: + str: The detokenized string. + """ return self.tokenizer.ids_to_text(tokens, remove_special_tokens) def tokenize(self, prompt): + """ + Tokenizes a prompt into a list of tokens. + + Args: + prompt (str): The prompt to tokenize. + + Returns: + list: The list of tokens. + """ return self.tokenizer.text_to_ids(prompt) @property def additional_special_tokens_ids(self): + """ + Gets the IDs of additional special tokens. + + Returns: + list: The IDs of additional special tokens. + """ return self.tokenizer.additional_special_tokens_ids @property def bos(self): + """ + Gets the ID of the beginning of sequence token. + + Returns: + int: The ID of the beginning of sequence token. + """ return self.tokenizer.bos_id @property def pad(self): + """ + Gets the ID of the padding token. + + Returns: + int: The ID of the padding token. + """ return self.tokenizer.pad_id # TODO: Move to lightning Fabric API. def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.LightningModule): + """ + Sets up the trainer and restores the model from the given checkpoint path. + + It does the following: + - Defines a RestoreConfig to restore only model weights + - Disables setting up optimizers in the Trainer + - Calls strategy.setup_environment(), model.configure_model() and strategy.setup_megatron_parallel(trainer=trainer) + - Finally loads the model weights + + Args: + path (Path): The path to the checkpoint file. + trainer (nl.Trainer): The trainer object. + model (pl.LightningModule): The model object. + + Returns: + None + """ assert isinstance(trainer.strategy, MegatronStrategy), "Only MegatronStrategy is supported for trainer.strategy." assert trainer.strategy.context_parallel_size <= 1, "Context parallelism is not supported for inference." - if (adapter_meta_path := ckpt_to_weights_subdir(path) / ADAPTER_META_FILENAME).exists(): + if (adapter_meta_path := ckpt_to_weights_subdir(path, is_saving=False) / ADAPTER_META_FILENAME).exists(): with open(adapter_meta_path, "r") as f: metadata = json.load(f) restore_config = RestoreConfig( - path=metadata['model_ckpt_path'], + path=metadata["model_ckpt_path"], load_model_state=True, load_optim_state=False, ) @@ -107,7 +166,7 @@ def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl. model = lora(model) adapter_sharded_state_dict = {k: v for k, v in model.sharded_state_dict().items() if ".adapter." in k} adapter_state = trainer.strategy.checkpoint_io.load_checkpoint( - ckpt_to_weights_subdir(path), sharded_state_dict=adapter_sharded_state_dict + ckpt_to_weights_subdir(path, is_saving=False), sharded_state_dict=adapter_sharded_state_dict ) trainer.strategy.load_model_state_dict(adapter_state, strict=False) @@ -118,6 +177,24 @@ def setup_model_and_tokenizer( params_dtype: torch.dtype = torch.bfloat16, inference_batch_times_seqlen_threshold: int = 1000, ) -> tuple[MegatronModule, MCoreTokenizerWrappper]: + """ + Sets up the model and tokenizer for inference. + + This function loads the model and tokenizer from the given checkpoint path, + sets up the trainer, and returns the Megatron inference-wrapped model and tokenizer. + + Args: + path (Path): The path to the checkpoint file. + trainer (nl.Trainer): The trainer object. + params_dtype (torch.dtype, optional): The data type of the model parameters. + Defaults to torch.bfloat16. + inference_batch_times_seqlen_threshold (int, optional): If batch-size times sequence-length is smaller + than this threshold then we will not use pipelining, otherwise we will. + + Returns: + tuple[MegatronModule, MCoreTokenizerWrappper]: + A tuple containing the inference-wrapped model and Mcore wrapped tokenizer. + """ model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model") _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model) @@ -135,6 +212,26 @@ def generate( random_seed: Optional[int] = None, inference_params: Optional[CommonInferenceParams] = None, ) -> dict: + """ + Runs generate on the model with the given prompts. + + This function uses the loaded model, loaded tokenizer, and prompts to generate text. + It returns a dictionary containing the generated text. + + Args: + model (AbstractModelInferenceWrapper): The inference-wrapped model. + tokenizer (MCoreTokenizerWrappper): The tokenizer. + prompts (list[str]): The list of prompts to generate text for. + encoder_prompts (Optional[list[str]], optional): The list of encoder prompts. Defaults to None. + add_BOS (bool, optional): Whether to add the beginning of sequence token. Defaults to False. + max_batch_size (int, optional): The maximum batch size. Defaults to 4. + random_seed (Optional[int], optional): The random seed. Defaults to None. + inference_params (Optional[CommonInferenceParams], optional): The inference parameters defined in + Mcore's CommonInferenceParams. Defaults to None. + + Returns: + dict: A dictionary containing the generated results. + """ if encoder_prompts is not None: text_generation_controller = EncoderDecoderTextGenerationController( inference_wrapped_model=model, tokenizer=tokenizer diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index 15367cb25aba..2f3e0e1e986e 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import shutil from dataclasses import dataclass from typing import Optional, Union @@ -22,6 +23,7 @@ from tqdm import tqdm from nemo.collections import llm +from nemo.lightning.ckpt_utils import CONTEXT_PATH from nemo.utils import logging from .utils import get_unwrapped_mcore_model @@ -259,7 +261,7 @@ def loop(model): return loop - def export(self, model: llm.GPTModel) -> None: + def export(self, model: llm.GPTModel, model_dir: str) -> None: assert self.export_config is not None, "Export config is not set" # TODO: Add sample generate # TODO: Support megatron_amp_O2 @@ -277,15 +279,16 @@ def export(self, model: llm.GPTModel) -> None: use_nfs_workspace=use_nfs_workspace, ) - dist.barrier() # Wait until all ranks complete export_model_config step - logging.info(f"Export succeeded, model has been exported to {export_dir}. Saving tokenizer if possible...") + # Save the model context in order to restore its tokenizer later. The destination + # path is "nemo_context" as this name is used in nemo.export to setup tokenizer. + shutil.copytree( + os.path.join(model_dir, CONTEXT_PATH), + os.path.join(export_dir, "nemo_context"), + dirs_exist_ok=True, + ) + logging.info(f"Model context saved.") - if dist.get_rank() == 0: - try: - tokenizer_dst = os.path.join(export_dir, 'tokenizer') - model.tokenizer.tokenizer.save_pretrained(tokenizer_dst) - except Exception as err: - logging.warning("Could not save the tokenizer: " + str(err)) + logging.info(f"Export succeeded, model has been exported to {export_dir}.") def get_calib_data_iter( diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index fb43224d59a9..08b0b822cad4 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -37,12 +37,12 @@ from nemo.export.trt_llm.converter.utils import init_model_parallel_from_nemo from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import ( build_tokenizer, - get_tokenzier, + get_tokenizer, is_nemo_file, load_nemo_model, ) from nemo.export.trt_llm.qnemo import qnemo_to_tensorrt_llm -from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer +from nemo.export.trt_llm.qnemo.tokenizer_utils import TOKENIZER_CONFIG_FILE, get_nmt_tokenizer from nemo.export.trt_llm.qnemo.utils import is_qnemo_checkpoint from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine from nemo.export.trt_llm.tensorrt_llm_run import ( @@ -294,7 +294,14 @@ def export( else: unpack_tarball(nemo_checkpoint_path, tmp_dir.name) nemo_checkpoint_path = tmp_dir.name - self.tokenizer = get_nmt_tokenizer(nemo_checkpoint_path) + + if os.path.exists(os.path.join(nemo_checkpoint_path, TOKENIZER_CONFIG_FILE)): + # Instantiate tokenizer for a legacy "Nemo 1" quantized checkpoint from a tokenizer config. + # Note that using the config is deprecated and it will be removed in future releases. + LOGGER.warning("Detected legacy tokenizer_config.yaml, using it to build tokenizer.") + self.tokenizer = get_nmt_tokenizer(nemo_checkpoint_path) + else: + self.tokenizer = get_tokenizer(nemo_checkpoint_path) qnemo_to_tensorrt_llm( nemo_checkpoint_path=nemo_checkpoint_path, @@ -1092,7 +1099,7 @@ def _load(self): if len(folders) > 0: try: self._load_config_file() - self.tokenizer = get_tokenzier(Path(os.path.join(self.model_dir))) + self.tokenizer = get_tokenizer(self.model_dir) self.model = load( tokenizer=self.tokenizer, engine_dir=self.model_dir, diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index 407a7ce600c9..23d227d32acf 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -283,16 +283,17 @@ def copy_tokenizer_files(config, out_dir): outfile.write(infile.read()) -def get_tokenzier(tokenizer_dir_or_path: Path) -> PreTrainedTokenizer: - """Loads the tokenizer from the decoded NEMO weights dir.""" +def get_tokenizer(tokenizer_dir_or_path: Union[str, Path]) -> PreTrainedTokenizer: + """Loads the tokenizer from the decoded NeMo weights dir.""" + tokenizer_dir_or_path = Path(tokenizer_dir_or_path) if (tokenizer_dir_or_path / "nemo_context").exists(): from nemo.lightning import io tokenizer_spec = io.load_context((tokenizer_dir_or_path / "nemo_context"), subpath="model.tokenizer") return build_tokenizer(tokenizer_spec) else: - if os.path.isdir(os.path.join(tokenizer_dir_or_path, "huggingface_tokenizer")): - return AutoTokenizer.from_pretrained(os.path.join(tokenizer_dir_or_path, "huggingface_tokenizer")) + if (tokenizer_dir_or_path / "huggingface_tokenizer").is_dir(): + return AutoTokenizer.from_pretrained(tokenizer_dir_or_path / "huggingface_tokenizer") model_path = ( tokenizer_dir_or_path / "tokenizer.model" if tokenizer_dir_or_path.is_dir() else tokenizer_dir_or_path diff --git a/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py b/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py index 7a1f7a6cc31d..eac1ab743849 100644 --- a/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py +++ b/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py @@ -77,8 +77,6 @@ def qnemo_to_tensorrt_llm( use_qdq = quant_algo in ["FP8", "W8A8_SQ_PER_CHANNEL"] - builder_opt = 4 if "RecurrentGemma" not in config.architecture else 0 - speculative_decoding_mode = "medusa" if "Medusa" in config.architecture else None build_cmd = "trtllm-build " @@ -90,7 +88,6 @@ def qnemo_to_tensorrt_llm( build_cmd += f"--max_input_len {max_input_len} " build_cmd += f"--max_beam_width {max_beam_width} " build_cmd += f"--max_prompt_embedding_table_size {max_prompt_embedding_table_size} " - build_cmd += f"--builder_opt {builder_opt} " build_cmd += f"--paged_kv_cache {'enable' if paged_kv_cache else 'disable'} " build_cmd += f"--use_paged_context_fmha {'enable' if paged_context_fmha else 'disable'} " build_cmd += f"--remove_input_padding {'enable' if remove_input_padding else 'disable'} " diff --git a/nemo/export/trt_llm/qnemo/tokenizer_utils.py b/nemo/export/trt_llm/qnemo/tokenizer_utils.py index 36efa9259f9d..beca40bcd3d7 100644 --- a/nemo/export/trt_llm/qnemo/tokenizer_utils.py +++ b/nemo/export/trt_llm/qnemo/tokenizer_utils.py @@ -29,11 +29,6 @@ def get_nmt_tokenizer(nemo_checkpoint_path: str): """Build tokenizer from Nemo tokenizer config.""" - tokenizer_dir = os.path.join(nemo_checkpoint_path, TOKENIZER_DIR) - if os.path.exists(tokenizer_dir): - print(f"Initializing tokenizer from {TOKENIZER_DIR} directory") - return AutoTokenizer.from_pretrained(tokenizer_dir) - print(f"Initializing tokenizer from {TOKENIZER_CONFIG_FILE}") tokenizer_cfg = OmegaConf.load(os.path.join(nemo_checkpoint_path, TOKENIZER_CONFIG_FILE)) diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index b417c088b22e..5d7d019c6099 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -724,6 +724,10 @@ def save_artifacts(model, output_dir: str, use_abspath: bool = False) -> None: app_state = AppState() model_file = app_state.model_restore_path model_cfg = copy.deepcopy(model.cfg) + + if model_cfg.tokenizer.library == "huggingface": + model.tokenizer.save_pretrained(os.path.join(output_dir, "huggingface_tokenizer")) + if not hasattr(model, "artifacts"): if hasattr(model_cfg, "tokenizer"): OmegaConf.save(model_cfg.tokenizer, os.path.join(output_dir, "tokenizer_config.yaml")) diff --git a/scripts/llm/ptq.py b/scripts/llm/ptq.py index 0fd2c5682e8a..c04d32290e5f 100644 --- a/scripts/llm/ptq.py +++ b/scripts/llm/ptq.py @@ -92,7 +92,7 @@ def main(): quantizer = quantization.Quantizer(quantization_config, export_config) model = quantization.load_with_modelopt_layer_spec(args.nemo_checkpoint, args.calib_tp, args.calib_pp) model = quantizer.quantize(model) - quantizer.export(model) + quantizer.export(model, args.nemo_checkpoint) if __name__ == '__main__': diff --git a/tests/collections/asr/conftest.py b/tests/collections/asr/conftest.py index dba29f949fb0..a9bc13153164 100644 --- a/tests/collections/asr/conftest.py +++ b/tests/collections/asr/conftest.py @@ -19,6 +19,8 @@ import pytest import torch +from nemo.collections.asr.models import ASRModel + class RNNTTestHelper: @staticmethod @@ -353,3 +355,18 @@ def rnnt_test_helper() -> Type[RNNTTestHelper]: @pytest.fixture(scope="session") def rnn_loss_sample_data() -> Type[RnntLossSampleData]: return RnntLossSampleData + + +@pytest.fixture(scope='session') +def fast_conformer_transducer_model(): + return ASRModel.from_pretrained("stt_en_fastconformer_transducer_large") + + +@pytest.fixture(scope='session') +def fast_conformer_ctc_model(): + return ASRModel.from_pretrained("stt_en_fastconformer_ctc_large") + + +@pytest.fixture(scope='session') +def fast_conformer_hybrid_model(): + return ASRModel.from_pretrained("parakeet-tdt_ctc-110m") diff --git a/tests/collections/asr/mixins/test_transcription.py b/tests/collections/asr/mixins/test_transcription.py index 1a6f38681d0c..6e2d5fe16c68 100644 --- a/tests/collections/asr/mixins/test_transcription.py +++ b/tests/collections/asr/mixins/test_transcription.py @@ -23,7 +23,6 @@ from torch.utils.data import DataLoader, Dataset from nemo.collections.asr.data.audio_to_text import _speech_collate_fn -from nemo.collections.asr.models import ASRModel from nemo.collections.asr.parts.mixins import TranscribeConfig, TranscriptionMixin from nemo.collections.asr.parts.mixins.transcription import GenericTranscriptionType from nemo.collections.asr.parts.utils import Hypothesis @@ -44,6 +43,23 @@ def forward(self, x): return out +@pytest.mark.with_downloads() +@pytest.fixture() +def audio_files(test_data_dir): + """ + Returns a list of audio files for testing. + """ + import soundfile as sf + + audio_file1 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") + audio_file2 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an104-mrcb-b.wav") + + audio1, _ = sf.read(audio_file1, dtype='float32') + audio2, _ = sf.read(audio_file2, dtype='float32') + + return audio1, audio2 + + class TranscribableDummy(DummyModel, TranscriptionMixin): def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig): super()._transcribe_on_begin(audio, trcfg) @@ -297,12 +313,11 @@ class OverrideConfig(TranscribeConfig): pytest.mark.with_downloads() @pytest.mark.unit - def test_transcribe_return_hypothesis(self, test_data_dir): - model = ASRModel.from_pretrained("stt_en_conformer_ctc_small") + def test_transcribe_return_hypothesis(self, test_data_dir, fast_conformer_ctc_model): audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") - # Numpy array test - outputs = model.transcribe(audio_file, batch_size=1, return_hypotheses=True) + # Audio file test + outputs = fast_conformer_ctc_model.transcribe(audio_file, batch_size=1, return_hypotheses=True) assert len(outputs) == 1 assert isinstance(outputs[0], Hypothesis) @@ -313,62 +328,82 @@ def test_transcribe_return_hypothesis(self, test_data_dir): @pytest.mark.with_downloads() @pytest.mark.unit - def test_transcribe_tensor(self, test_data_dir): - model = ASRModel.from_pretrained("stt_en_conformer_ctc_small") - - # Load audio file - import soundfile as sf - - audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") - audio, sr = sf.read(audio_file, dtype='float32') + def test_transcribe_tensor(self, audio_files, fast_conformer_ctc_model): + audio, _ = audio_files # Numpy array test - outputs = model.transcribe(audio, batch_size=1) + outputs = fast_conformer_ctc_model.transcribe(audio, batch_size=1) assert len(outputs) == 1 assert isinstance(outputs[0], str) @pytest.mark.with_downloads() @pytest.mark.unit - def test_transcribe_multiple_tensor(self, test_data_dir): - model = ASRModel.from_pretrained("stt_en_conformer_ctc_small") - - # Load audio file - import soundfile as sf - - audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") - audio, sr = sf.read(audio_file, dtype='float32') + def test_transcribe_multiple_tensor(self, audio_files, fast_conformer_ctc_model): - audio_file_2 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an104-mrcb-b.wav") - audio_2, sr = sf.read(audio_file_2, dtype='float32') + audio, audio_2 = audio_files # Mix second audio to torch.tensor() audio_2 = torch.tensor(audio_2) # Numpy array test - outputs = model.transcribe([audio, audio_2], batch_size=2) + outputs = fast_conformer_ctc_model.transcribe([audio, audio_2], batch_size=2) assert len(outputs) == 2 assert isinstance(outputs[0], str) assert isinstance(outputs[1], str) @pytest.mark.with_downloads() @pytest.mark.unit - def test_transcribe_dataloader(self, test_data_dir): - model = ASRModel.from_pretrained("stt_en_conformer_ctc_small") - - # Load audio file - import soundfile as sf - - audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") - audio, sr = sf.read(audio_file, dtype='float32') + def test_transcribe_dataloader(self, audio_files, fast_conformer_ctc_model): - audio_file2 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an152-mwhw-b.wav") - audio2, sr = sf.read(audio_file2, dtype='float32') + audio, audio2 = audio_files dataset = DummyDataset([audio, audio2]) collate_fn = lambda x: _speech_collate_fn(x, pad_id=0) dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_fn) # DataLoader test - outputs = model.transcribe(dataloader, batch_size=1) + outputs = fast_conformer_ctc_model.transcribe(dataloader, batch_size=1) assert len(outputs) == 2 assert isinstance(outputs[0], str) assert isinstance(outputs[1], str) + + @pytest.mark.with_downloads() + @pytest.mark.unit + def test_timestamps_with_transcribe(self, audio_files, fast_conformer_ctc_model): + audio1, audio2 = audio_files + + output = fast_conformer_ctc_model.transcribe([audio1, audio2], timestamps=True) + + # check len of output + assert len(output) == 2 + + # check hypothesis object + assert isinstance(output[0], Hypothesis) + # check transcript + assert output[0].text == 'stop' + assert output[1].text == 'start' + + # check timestamp + assert output[0].timestep['segment'][0]['start'] == pytest.approx(0.4) + assert output[0].timestep['segment'][0]['end'] == pytest.approx(0.48) + + @pytest.mark.with_downloads() + @pytest.mark.unit + def test_timestamps_with_transcribe_hybrid(self, audio_files, fast_conformer_hybrid_model): + audio1, audio2 = audio_files + + output = fast_conformer_hybrid_model.transcribe([audio1, audio2], timestamps=True) + + # check len of output + assert len(output) == 2 + + output = output[1] # Transducer returns tuple + + # check hypothesis object + assert isinstance(output[0], Hypothesis) + # check transcript + assert output[0].text == 'Stop?' + assert output[1].text == 'Start.' + + # check timestamp + assert output[0].timestep['segment'][0]['start'] == pytest.approx(0.48) + assert output[0].timestep['segment'][0]['end'] == pytest.approx(0.72)