Skip to content

Commit

Permalink
Fixes #112 : add support for intel arc gpus
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeronymous committed Nov 15, 2023
1 parent eda426b commit 9d311a7
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions whisper_timestamped/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
import torch
import torch.nn.functional as F

from importlib.util import find_spec
if find_spec("intel_extension_for_pytorch") is not None:
try:
import intel_extension_for_pytorch
except ImportError:
pass

# For alignment
import numpy as np
import dtw
Expand Down Expand Up @@ -2046,9 +2053,18 @@ def write_csv(transcript, file, sep = ",", text_first=True, format_timestamps=No
# CUDA initialization may fail on old GPU card
def force_cudnn_initialization(device=None, s=32):
if device is None:
device = torch.device('cuda')
device = get_default_device()
torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=device), torch.zeros(s, s, s, s, device=device))

def get_default_device():
if torch.cuda.is_available():
device = "cuda"
elif find_spec('torch.xpu') is not None and torch.xpu.is_available():
device = "xpu"
else:
device = "cpu"
return device

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
Expand Down Expand Up @@ -2242,7 +2258,7 @@ def get_do_write(output_format):
parser.add_argument('audio', help="audio file(s) to transcribe", nargs='+')
parser.add_argument('--model', help=f"name of the Whisper model to use. Examples: {', '.join(whisper.available_models())}", default="small")
parser.add_argument("--model_dir", default=None, help="the path to save model files; uses ~/.cache/whisper by default", type=str)
parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--device", default=get_default_device(), help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", default=None, help="directory to save the outputs", type=str)
valid_formats = ["txt", "vtt", "srt", "tsv", "csv", "json"]
def str2output_formats(string):
Expand Down

0 comments on commit 9d311a7

Please sign in to comment.