From 299a658f5ea9809753a0dc1d4deaf8fa579a9f81 Mon Sep 17 00:00:00 2001 From: leuc Date: Thu, 18 May 2023 17:10:33 +0200 Subject: [PATCH] Add support for Intel GPU's Requires Intel Extension for PyTorch v1.13.120+xpu https://intel.github.io/intel-extension-for-pytorch/xpu/latest/tutorials/installation.html Tested on Intel ARC A770 16GB VRAM with large model --- whisper/__init__.py | 11 ++++++++++- whisper/transcribe.py | 5 ++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/whisper/__init__.py b/whisper/__init__.py index 379133b6a..8d33446aa 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -4,8 +4,11 @@ import urllib import warnings from typing import List, Optional, Union +from importlib.util import find_spec import torch +if find_spec("intel_extension_for_pytorch") is not None: + import intel_extension_for_pytorch from tqdm import tqdm from .audio import load_audio, log_mel_spectrogram, pad_or_trim @@ -122,7 +125,13 @@ def load_model( """ if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif find_spec('torch.xpu') is not None and torch.xpu.is_available(): + device = "xpu" + else: + device = "cpu" + if download_root is None: default = os.path.join(os.path.expanduser("~"), ".cache") download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") diff --git a/whisper/transcribe.py b/whisper/transcribe.py index ff73a5530..3a096ae0c 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -2,6 +2,7 @@ import os import warnings from typing import TYPE_CHECKING, Optional, Tuple, Union +from importlib.util import find_spec import numpy as np import torch @@ -110,6 +111,8 @@ def transcribe( if model.device == torch.device("cpu"): if torch.cuda.is_available(): warnings.warn("Performing inference on CPU when CUDA is available") + if find_spec('torch.xpu') is not None and torch.xpu.is_available(): + warnings.warn("Performing inference on CPU when XPU is available") if dtype == torch.float16: warnings.warn("FP16 is not supported on CPU; using FP32 instead") dtype = torch.float32 @@ -379,7 +382,7 @@ def cli(): parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") - parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") + parser.add_argument("--device", default=None, help="device to use for PyTorch inference") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")