From 9d311a7e3b1dd4c7c2f2bac6ecb6d098f8030836 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Wed, 15 Nov 2023 22:58:19 +0100 Subject: [PATCH] Fixes #112 : add support for intel arc gpus --- whisper_timestamped/transcribe.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py index c86af76..02a86ec 100755 --- a/whisper_timestamped/transcribe.py +++ b/whisper_timestamped/transcribe.py @@ -15,6 +15,13 @@ import torch import torch.nn.functional as F +from importlib.util import find_spec +if find_spec("intel_extension_for_pytorch") is not None: + try: + import intel_extension_for_pytorch + except ImportError: + pass + # For alignment import numpy as np import dtw @@ -2046,9 +2053,18 @@ def write_csv(transcript, file, sep = ",", text_first=True, format_timestamps=No # CUDA initialization may fail on old GPU card def force_cudnn_initialization(device=None, s=32): if device is None: - device = torch.device('cuda') + device = get_default_device() torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=device), torch.zeros(s, s, s, s, device=device)) +def get_default_device(): + if torch.cuda.is_available(): + device = "cuda" + elif find_spec('torch.xpu') is not None and torch.xpu.is_available(): + device = "xpu" + else: + device = "cpu" + return device + # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens. _ALIGNMENT_HEADS = { @@ -2242,7 +2258,7 @@ def get_do_write(output_format): parser.add_argument('audio', help="audio file(s) to transcribe", nargs='+') parser.add_argument('--model', help=f"name of the Whisper model to use. Examples: {', '.join(whisper.available_models())}", default="small") parser.add_argument("--model_dir", default=None, help="the path to save model files; uses ~/.cache/whisper by default", type=str) - parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") + parser.add_argument("--device", default=get_default_device(), help="device to use for PyTorch inference") parser.add_argument("--output_dir", "-o", default=None, help="directory to save the outputs", type=str) valid_formats = ["txt", "vtt", "srt", "tsv", "csv", "json"] def str2output_formats(string):