Add support for Intel GPUs
Requires Intel Extension for PyTorch v1.13.120+xpu
https://intel.github.io/intel-extension-for-pytorch/xpu/latest/tutorials/installation.html

Tested on an Intel Arc A770 (16 GB VRAM) with the large model
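After installing the extension per the linked tutorial, a quick way to confirm the GPU is visible from PyTorch; a minimal sketch, assuming IPEX mirrors the torch.cuda helper names (its xpu module largely does), with the reported device name being illustrative:

import torch
import intel_extension_for_pytorch  # importing this registers the "xpu" device with torch

# Sanity-check the Intel GPU before transcribing.
print(torch.xpu.is_available())      # True on a working Arc setup
print(torch.xpu.get_device_name(0))  # e.g. "Intel(R) Arc(TM) A770 Graphics"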
leuc committed May 18, 2023
1 parent 248b6cb commit 299a658
Showing 2 changed files with 14 additions and 2 deletions.
11 changes: 10 additions & 1 deletion whisper/__init__.py
@@ -4,8 +4,11 @@
 import urllib
 import warnings
 from typing import List, Optional, Union
+from importlib.util import find_spec
 
 import torch
+if find_spec("intel_extension_for_pytorch") is not None:
+    import intel_extension_for_pytorch
 from tqdm import tqdm
 
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
@@ -122,7 +125,13 @@ def load_model(
     """
 
     if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif find_spec('torch.xpu') is not None and torch.xpu.is_available():
+            device = "xpu"
+        else:
+            device = "cpu"
+
     if download_root is None:
         default = os.path.join(os.path.expanduser("~"), ".cache")
         download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
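Leaving device as None now walks CUDA, then XPU, then CPU. A minimal sketch of the resulting behavior; the model size here is just whatever you would normally load:

import whisper

# device=None triggers the new auto-selection: "cuda" if available,
# else "xpu" (when intel_extension_for_pytorch is importable), else "cpu"
model = whisper.load_model("large")
print(model.device)  # e.g. xpu when an Arc GPU is picked up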
5 changes: 4 additions & 1 deletion whisper/transcribe.py
@@ -2,6 +2,7 @@
 import os
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, Union
+from importlib.util import find_spec
 
 import numpy as np
 import torch
@@ -110,6 +111,8 @@ def transcribe(
     if model.device == torch.device("cpu"):
         if torch.cuda.is_available():
             warnings.warn("Performing inference on CPU when CUDA is available")
+        if find_spec('torch.xpu') is not None and torch.xpu.is_available():
+            warnings.warn("Performing inference on CPU when XPU is available")
         if dtype == torch.float16:
             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
             dtype = torch.float32
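This mirrors the existing CUDA warning: transcribing on CPU while an idle XPU is present is almost certainly unintended. Forcing the device avoids it; a hedged sketch, with audio.mp3 standing in for any input file:

import whisper

model = whisper.load_model("large", device="xpu")  # place the model on the Intel GPU
result = model.transcribe("audio.mp3")  # fp16 decoding stays enabled off-CPU
print(result["text"])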
@@ -379,7 +382,7 @@ def cli():
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
-    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--device", default=None, help="device to use for PyTorch inference")
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
     parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
