diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index 5f5543fe11e53..d54ee34c18cd8 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -16,6 +16,7 @@ import glob import json import os +import time from dataclasses import dataclass, field, is_dataclass from tempfile import NamedTemporaryFile from typing import List, Optional, Union @@ -84,6 +85,8 @@ langid: Str used for convert_num_to_words during groundtruth cleaning use_cer: Bool to use Character Error Rate (CER) or Word Error Rate (WER) + calculate_rtfx: Bool to calculate the RTFx throughput to transcribe the input dataset. + # Usage ASR model can be specified by either "model_path" or "pretrained_name". Data for transcription can be defined with either "audio_dir" or "dataset_manifest". @@ -153,6 +156,7 @@ class TranscriptionConfig: allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU) amp: bool = False amp_dtype: str = "float16" # can be set to "float16" or "bfloat16" when using amp + compute_dtype: str = "float32" matmul_precision: str = "highest" # Literal["highest", "high", "medium"] audio_type: str = "wav" @@ -208,6 +212,8 @@ class TranscriptionConfig: allow_partial_transcribe: bool = False extract_nbest: bool = False # Extract n-best hypotheses from the model + calculate_rtfx: bool = False + @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]: @@ -266,6 +272,14 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis asr_model.set_trainer(trainer) asr_model = asr_model.eval() + if cfg.compute_dtype != "float32" and cfg.amp: + raise ValueError("amp=true is mutually exclusive with a compute_dtype other than float32") + + amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16 + + if cfg.compute_dtype != "float32": + asr_model.to(getattr(torch, cfg.compute_dtype)) + # we will adjust this flag if the model does not support it compute_timestamps = cfg.compute_timestamps compute_langs = cfg.compute_langs @@ -378,7 +392,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis else: @contextlib.contextmanager - def autocast(dtype=None): + def autocast(dtype=None, enabled=True): yield # Compute output filename @@ -394,10 +408,22 @@ def autocast(dtype=None): # transcribe audio - amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16 + if cfg.calculate_rtfx: + total_duration = 0.0 + + with open(cfg.dataset_manifest, "rt") as fh: + for line in fh: + item = json.loads(line) + if "duration" not in item: + raise ValueError( + f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field." 
+ ) + total_duration += item["duration"] - with autocast(dtype=amp_dtype): + with autocast(dtype=amp_dtype, enabled=cfg.amp): with torch.no_grad(): + if cfg.calculate_rtfx: + start_time = time.time() if partial_audio: transcriptions = transcribe_partial_audio( asr_model=asr_model, @@ -420,10 +446,13 @@ def autocast(dtype=None): override_cfg.lang_field = cfg.gt_lang_attr_name if hasattr(override_cfg, "prompt"): override_cfg.prompt = parse_multitask_prompt(OmegaConf.to_container(cfg.prompt)) + transcriptions = asr_model.transcribe( audio=filepaths, override_config=override_cfg, ) + if cfg.calculate_rtfx: + transcribe_time = time.time() - start_time if cfg.dataset_manifest is not None: logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}") @@ -475,6 +504,9 @@ def autocast(dtype=None): logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!") logging.info(f"{total_res}") + if cfg.calculate_rtfx: + logging.info(f"Dataset RTFx {(total_duration/transcribe_time)}") + return cfg diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 177da81f85f28..093419c3ca0ca 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -668,7 +668,7 @@ def test_dataloader(self): def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig): super()._transcribe_on_begin(audio, trcfg) - # Freeze the encoder and decoure_exder modules + # Freeze the encoder and decoder modules self.encoder.freeze() self.decoder.freeze() @@ -706,7 +706,11 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen logits_len = logits_len.cpu() # dump log probs per file for idx in range(logits_cpu.shape[0]): - current_hypotheses[idx].y_sequence = logits_cpu[idx][: logits_len[idx]] + # We clone because we don't want references to the + # cudaMallocHost()-allocated tensor to be floating + # around. Were that to be the case, then the pinned + # memory cache would always miss. + current_hypotheses[idx].y_sequence = logits_cpu[idx, : logits_len[idx]].clone() if current_hypotheses[idx].alignments is None: current_hypotheses[idx].alignments = current_hypotheses[idx].y_sequence del logits_cpu diff --git a/nemo/collections/asr/modules/audio_preprocessing.py b/nemo/collections/asr/modules/audio_preprocessing.py index d45c0acf314fb..2dca468fab359 100644 --- a/nemo/collections/asr/modules/audio_preprocessing.py +++ b/nemo/collections/asr/modules/audio_preprocessing.py @@ -39,7 +39,7 @@ ) from nemo.core.utils import numba_utils from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ -from nemo.utils import logging +from nemo.utils import logging, logging_mode try: import torchaudio @@ -85,11 +85,27 @@ def __init__(self, win_length, hop_length): None: torch.ones, } + # Normally, when you call to(dtype) on a torch.nn.Module, all + # floating point parameters and buffers will change to that + # dtype, rather than being float32. The AudioPreprocessor + # classes, uniquely, don't actually have any parameters or + # buffers from what I see. In addition, we want the input to + # the preprocessor to be float32, but need to create the + # output in appropriate precision. We have this empty tensor + # here just to detect which dtype tensor this module should + # output at the end of execution. 
+        self.register_buffer("dtype_sentinel_tensor", torch.tensor((), dtype=torch.float32), persistent=False)
+
     @typecheck()
     @torch.no_grad()
     def forward(self, input_signal, length):
-        processed_signal, processed_length = self.get_features(input_signal, length)
-
+        if input_signal.dtype != torch.float32:
+            logging.warning(
+                f"AudioPreprocessor received an input signal of dtype {input_signal.dtype}, rather than torch.float32. In sweeps across multiple datasets, we have found that the preprocessor is not robust to low precision mathematics. As such, it runs in float32. Your input will be cast to float32, but this is not necessarily enough to recover full accuracy. For example, simply casting input_signal from torch.float32 to torch.bfloat16, then back to torch.float32 before running AudioPreprocessor causes drops in absolute WER of up to 0.1%. torch.bfloat16 simply does not have enough mantissa bits to represent enough values in the range [-1.0,+1.0] correctly.",
+                mode=logging_mode.ONCE,
+            )
+        processed_signal, processed_length = self.get_features(input_signal.to(torch.float32), length)
+        processed_signal = processed_signal.to(self.dtype_sentinel_tensor.dtype)
         return processed_signal, processed_length

     @abstractmethod
diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py
index d8f0e58833f71..d723ce85d2ce7 100644
--- a/nemo/collections/asr/modules/conformer_encoder.py
+++ b/nemo/collections/asr/modules/conformer_encoder.py
@@ -679,7 +679,8 @@ def set_max_audio_length(self, max_audio_length):
         """
         self.max_audio_length = max_audio_length
         device = next(self.parameters()).device
-        self.pos_enc.extend_pe(max_audio_length, device)
+        dtype = next(self.parameters()).dtype
+        self.pos_enc.extend_pe(max_audio_length, device, dtype)

     def _create_masks(self, att_context_size, padding_length, max_audio_length, offset, device):
         if self.self_attention_model != "rel_pos_local_attn":
diff --git a/nemo/collections/asr/modules/squeezeformer_encoder.py b/nemo/collections/asr/modules/squeezeformer_encoder.py
index ce0d49843d4f8..ae779380edf65 100644
--- a/nemo/collections/asr/modules/squeezeformer_encoder.py
+++ b/nemo/collections/asr/modules/squeezeformer_encoder.py
@@ -99,8 +99,7 @@ def input_example(self, max_batch=1, max_dim=256):

     @property
     def input_types(self):
-        """Returns definitions of module input ports.
-        """
+        """Returns definitions of module input ports."""
         return OrderedDict(
             {
                 "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
@@ -110,8 +109,7 @@ def input_types(self):

     @property
     def output_types(self):
-        """Returns definitions of module output ports.
-        """
+        """Returns definitions of module output ports."""
         return OrderedDict(
             {
                 "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
@@ -253,7 +251,11 @@ def __init__(
         # Chose same type of positional encoding as the originally determined above
         if self_attention_model == "rel_pos":
             self.time_reduce_pos_enc = RelPositionalEncoding(
-                d_model=d_model, dropout_rate=0.0, max_len=pos_emb_max_len, xscale=None, dropout_rate_emb=0.0,
+                d_model=d_model,
+                dropout_rate=0.0,
+                max_len=pos_emb_max_len,
+                xscale=None,
+                dropout_rate_emb=0.0,
             )
         else:
             self.time_reduce_pos_enc = PositionalEncoding(
@@ -275,20 +277,21 @@ def __init__(
         self.interctc_capture_at_layers = None

     def set_max_audio_length(self, max_audio_length):
-        """ Sets maximum input length.
-            Pre-calculates internal seq_range mask.
+        """Sets maximum input length.
+        Pre-calculates internal seq_range mask.
""" self.max_audio_length = max_audio_length device = next(self.parameters()).device + dtype = next(self.parameters()).dtype seq_range = torch.arange(0, self.max_audio_length, device=device) if hasattr(self, 'seq_range'): self.seq_range = seq_range else: self.register_buffer('seq_range', seq_range, persistent=False) - self.pos_enc.extend_pe(max_audio_length, device) + self.pos_enc.extend_pe(max_audio_length, device, dtype) if self.time_reduce_pos_enc is not None: - self.time_reduce_pos_enc.extend_pe(max_audio_length, device) + self.time_reduce_pos_enc.extend_pe(max_audio_length, device, dtype) @typecheck() def forward(self, audio_signal, length=None): @@ -434,7 +437,9 @@ def _update_adapter_cfg_input_dim(self, cfg: DictConfig): cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model) return cfg - def get_accepted_adapter_types(self,) -> Set[type]: + def get_accepted_adapter_types( + self, + ) -> Set[type]: types = super().get_accepted_adapter_types() if len(types) == 0: diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py index 261e97a225dd5..5b9461d0a3896 100644 --- a/nemo/collections/asr/parts/mixins/transcription.py +++ b/nemo/collections/asr/parts/mixins/transcription.py @@ -181,7 +181,7 @@ class TranscriptionMixin(ABC): """ - @torch.no_grad() + @torch.inference_mode() def transcribe( self, audio: Union[str, List[str], np.ndarray, DataLoader], @@ -381,7 +381,6 @@ def transcribe_generator(self, audio, override_config: Optional[TranscribeConfig for test_batch in tqdm(dataloader, desc="Transcribing", disable=not verbose): # Move batch to device test_batch = move_to_device(test_batch, transcribe_cfg._internal.device) - # Run forward pass model_outputs = self._transcribe_forward(test_batch, transcribe_cfg) processed_outputs = self._transcribe_output_processing(model_outputs, transcribe_cfg) diff --git a/nemo/collections/asr/parts/submodules/conformer_modules.py b/nemo/collections/asr/parts/submodules/conformer_modules.py index efd23ef446288..093cde63c4393 100644 --- a/nemo/collections/asr/parts/submodules/conformer_modules.py +++ b/nemo/collections/asr/parts/submodules/conformer_modules.py @@ -377,7 +377,7 @@ def forward(self, x, pad_mask=None, cache=None): x = self.pointwise_activation(x) if pad_mask is not None: - x = x.float().masked_fill(pad_mask.unsqueeze(1), 0.0) + x = x.masked_fill(pad_mask.unsqueeze(1), 0.0) x = self.depthwise_conv(x, cache=cache) if cache is not None: diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py index 19d7134059538..de86132a721bc 100644 --- a/nemo/collections/asr/parts/submodules/multi_head_attention.py +++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py @@ -669,7 +669,10 @@ def _compute_out_global_to_all( global_attn_scores = global_attn_scores.view(batch_size * self.h, max_num_global_attn_indices, seq_len) # compute global attn probs - global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1, dtype=torch.float32) + if self.training: + global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1, dtype=torch.float32) + else: + global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1) global_attn_probs = self.dropout(global_attn_probs_float) @@ -906,7 +909,7 @@ def __init__(self, d_model, dropout_rate, max_len=5000, xscale=None, dropout_rat else: self.dropout_emb = None - def create_pe(self, positions): + def 
create_pe(self, positions, dtype): pos_length = positions.size(0) pe = torch.zeros(pos_length, self.d_model, device=positions.device) div_term = torch.exp( @@ -915,18 +918,18 @@ def create_pe(self, positions): ) pe[:, 0::2] = torch.sin(positions * div_term) pe[:, 1::2] = torch.cos(positions * div_term) - pe = pe.unsqueeze(0) + pe = pe.unsqueeze(0).to(dtype) if hasattr(self, 'pe'): self.pe = pe else: self.register_buffer('pe', pe, persistent=False) - def extend_pe(self, length, device): + def extend_pe(self, length, device, dtype): """Reset and extend the positional encodings if needed.""" if hasattr(self, 'pe') and self.pe.size(1) >= length: return positions = torch.arange(0, length, dtype=torch.float32, device=device).unsqueeze(1) - self.create_pe(positions=positions) + self.create_pe(positions=positions, dtype=dtype) def forward(self, x: torch.Tensor, cache_len=0): """Adds positional encoding. @@ -958,7 +961,7 @@ class RelPositionalEncoding(PositionalEncoding): dropout_rate_emb (float): dropout rate for the positional embeddings """ - def extend_pe(self, length, device): + def extend_pe(self, length, device, dtype): """Reset and extend the positional encodings if needed.""" needed_size = 2 * length - 1 if hasattr(self, 'pe') and self.pe.size(1) >= needed_size: @@ -966,7 +969,7 @@ def extend_pe(self, length, device): # positions would be from negative numbers to positive # positive positions would be used for left positions and negative for right positions positions = torch.arange(length - 1, -length, -1, dtype=torch.float32, device=device).unsqueeze(1) - self.create_pe(positions=positions) + self.create_pe(positions=positions, dtype=dtype) def forward(self, x, cache_len=0): """Compute positional encoding. @@ -1012,7 +1015,7 @@ def __init__(self, att_context_size, **kwargs): self.left_context = att_context_size[0] self.right_context = att_context_size[1] - def extend_pe(self, length, device): + def extend_pe(self, length, device, dtype): """Reset and extend the positional encodings only at the beginning""" if hasattr(self, 'pe'): return @@ -1020,7 +1023,7 @@ def extend_pe(self, length, device): positions = torch.arange( self.left_context, -self.right_context - 1, -1, dtype=torch.float32, device=device ).unsqueeze(1) - self.create_pe(positions=positions) + self.create_pe(positions=positions, dtype=dtype) def forward(self, x, cache_len=0): """Compute positional encoding. 
diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index 8465406224e7d..c270e5c3a0f7b 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -232,7 +232,7 @@ def get_buffered_pred_feat_multitaskAED( def wrap_transcription(hyps: List[str]) -> List[rnnt_utils.Hypothesis]: - """ Wrap transcription to the expected format in func write_transcription """ + """Wrap transcription to the expected format in func write_transcription""" wrapped_hyps = [] for hyp in hyps: hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], text=hyp) @@ -241,7 +241,7 @@ def wrap_transcription(hyps: List[str]) -> List[rnnt_utils.Hypothesis]: def setup_model(cfg: DictConfig, map_location: torch.device) -> Tuple[ASRModel, str]: - """ Setup model from cfg and return model and model name for next step """ + """Setup model from cfg and return model and model name for next step""" if cfg.model_path is not None and cfg.model_path != "None": # restore model from .nemo file path model_cfg = ASRModel.restore_from(restore_path=cfg.model_path, return_config=True) @@ -249,13 +249,15 @@ def setup_model(cfg: DictConfig, map_location: torch.device) -> Tuple[ASRModel, imported_class = model_utils.import_class_by_path(classpath) # type: ASRModel logging.info(f"Restoring model : {imported_class.__name__}") asr_model = imported_class.restore_from( - restore_path=cfg.model_path, map_location=map_location, + restore_path=cfg.model_path, + map_location=map_location, ) # type: ASRModel model_name = os.path.splitext(os.path.basename(cfg.model_path))[0] else: # restore model by name asr_model = ASRModel.from_pretrained( - model_name=cfg.pretrained_name, map_location=map_location, + model_name=cfg.pretrained_name, + map_location=map_location, ) # type: ASRModel model_name = cfg.pretrained_name @@ -269,7 +271,7 @@ def setup_model(cfg: DictConfig, map_location: torch.device) -> Tuple[ASRModel, def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]: - """ Prepare audio data and decide whether it's partial_audio condition. """ + """Prepare audio data and decide whether it's partial_audio condition.""" # this part may need refactor alongsides with refactor of transcribe partial_audio = False @@ -282,11 +284,20 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]: logging.error(f"The input dataset_manifest {cfg.dataset_manifest} is empty. Exiting!") return None + audio_key = cfg.get('audio_key', 'audio_filepath') + + with open(cfg.dataset_manifest, "rt") as fh: + for line in fh: + item = json.loads(line) + item["audio_filepath"] = get_full_path(item["audio_filepath"], cfg.dataset_manifest) + if item.get("duration") is None and cfg.presort_manifest: + raise ValueError( + f"Requested presort_manifest=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field." 
+ ) all_entries_have_offset_and_duration = True for item in read_and_maybe_sort_manifest(cfg.dataset_manifest, try_sort=cfg.presort_manifest): if not ("offset" in item and "duration" in item): all_entries_have_offset_and_duration = False - audio_key = cfg.get('audio_key', 'audio_filepath') audio_file = get_full_path(audio_file=item[audio_key], manifest_file=cfg.dataset_manifest) filepaths.append(audio_file) partial_audio = all_entries_have_offset_and_duration @@ -322,7 +333,7 @@ def restore_transcription_order(manifest_path: str, transcriptions: list) -> lis def compute_output_filename(cfg: DictConfig, model_name: str) -> DictConfig: - """ Compute filename of output manifest and update cfg""" + """Compute filename of output manifest and update cfg""" if cfg.output_filename is None: # create default output filename if cfg.audio_dir is not None: @@ -363,7 +374,7 @@ def write_transcription( compute_langs: bool = False, compute_timestamps: bool = False, ) -> Tuple[str, str]: - """ Write generated transcription to output file. """ + """Write generated transcription to output file.""" if cfg.append_pred: logging.info(f'Transcripts will be written in "{cfg.output_filename}" file') if cfg.pred_name_postfix is not None: @@ -533,7 +544,11 @@ def transcribe_partial_audio( lg = logits[idx][: logits_len[idx]] hypotheses.append(lg) else: - current_hypotheses, _ = decode_function(logits, logits_len, return_hypotheses=return_hypotheses,) + current_hypotheses, _ = decode_function( + logits, + logits_len, + return_hypotheses=return_hypotheses, + ) if return_hypotheses: # dump log probs per file @@ -567,10 +582,9 @@ def compute_metrics_per_sample( punctuation_marks: List[str] = [".", ",", "?"], output_manifest_path: str = None, ) -> dict: - ''' Computes metrics per sample for given manifest - + Args: manifest_path: str, Required - path to dataset JSON manifest file (in NeMo format) reference_field: str, Optional - name of field in .json manifest with the reference text ("text" by default). @@ -578,7 +592,7 @@ def compute_metrics_per_sample( metrics: list[str], Optional - list of metrics to be computed (currently supported "wer", "cer", "punct_er") punctuation_marks: list[str], Optional - list of punctuation marks for computing punctuation error rate ([".", ",", "?"] by default). output_manifest_path: str, Optional - path where .json manifest with calculated metrics will be saved. 
- + Returns: samples: dict - Dict of samples with calculated metrics ''' diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py index 2637e33ebd2aa..c4ee4b97a2a6a 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py @@ -150,7 +150,7 @@ def test_relmha_adapter_init(self, n_head, proj_dim): relpos_enc = adapter_modules.RelPositionalEncodingAdapter(d_model=50) pad_mask, att_mask = get_mask(lengths) - relpos_enc.extend_pe(lengths.max(), device='cpu') + relpos_enc.extend_pe(lengths.max(), device='cpu', dtype=torch.float32) with torch.no_grad(): assert adapter.linear_out.weight.sum() == 0 @@ -171,7 +171,7 @@ def test_abspos_encoding_init(self): relpos_enc = adapter_modules.PositionalEncodingAdapter(d_model=50) - relpos_enc.extend_pe(lengths.max(), device='cpu') + relpos_enc.extend_pe(lengths.max(), device='cpu', dtype=torch.float32) with torch.no_grad(): out, pos_emb = relpos_enc(x) @@ -187,7 +187,7 @@ def test_relpos_encoding_init(self): relpos_enc = adapter_modules.RelPositionalEncodingAdapter(d_model=50) - relpos_enc.extend_pe(lengths.max(), device='cpu') + relpos_enc.extend_pe(lengths.max(), device='cpu', dtype=torch.float32) with torch.no_grad(): out, pos_emb = relpos_enc(x)
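
Note on the dtype sentinel added in nemo/collections/asr/modules/audio_preprocessing.py above: the pattern generalizes to any parameter-free torch.nn.Module that must do its math in float32 but hand results back in whatever precision the caller cast the module to. The sketch below is an illustrative, self-contained example, not NeMo code; the class name, buffer name, and the log1p computation are made up for demonstration.

import torch


class Float32ComputeModule(torch.nn.Module):
    # A toy module with no parameters. Registering an empty, non-persistent
    # float32 buffer gives module.to(dtype) something to re-cast, so forward()
    # can detect which output dtype the caller asked for.
    def __init__(self):
        super().__init__()
        self.register_buffer("dtype_sentinel", torch.tensor((), dtype=torch.float32), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Run the numerically sensitive part in float32 regardless of the input dtype...
        y = torch.log1p(x.to(torch.float32).abs())
        # ...then return the result in the module's current dtype.
        return y.to(self.dtype_sentinel.dtype)


if __name__ == "__main__":
    m = Float32ComputeModule()
    x = torch.randn(4, dtype=torch.bfloat16)
    print(m(x).dtype)                      # torch.float32: module was never cast
    print(m.to(torch.bfloat16)(x).dtype)   # torch.bfloat16: buffer followed module.to()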