diff --git a/whisper/__init__.py b/whisper/__init__.py
index 379133b6a..8d33446aa 100644
--- a/whisper/__init__.py
+++ b/whisper/__init__.py
@@ -4,8 +4,11 @@
 import urllib
 import warnings
+from importlib.util import find_spec
 from typing import List, Optional, Union
 
 import torch
+if find_spec("intel_extension_for_pytorch") is not None:
+    import intel_extension_for_pytorch  # noqa: F401  (imported for its side effect of enabling XPU support)
 from tqdm import tqdm
 
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
@@ -122,7 +125,13 @@ def load_model(
     """
 
     if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif find_spec("torch.xpu") is not None and torch.xpu.is_available():
+            device = "xpu"
+        else:
+            device = "cpu"
+
     if download_root is None:
         default = os.path.join(os.path.expanduser("~"), ".cache")
         download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 6e43a22fa..74e8c5103 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -2,6 +2,7 @@
 import os
 import warnings
+from importlib.util import find_spec
 from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -110,6 +111,8 @@ def transcribe(
     if model.device == torch.device("cpu"):
         if torch.cuda.is_available():
             warnings.warn("Performing inference on CPU when CUDA is available")
+        if find_spec("torch.xpu") is not None and torch.xpu.is_available():
+            warnings.warn("Performing inference on CPU when XPU is available")
         if dtype == torch.float16:
             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
             dtype = torch.float32
@@ -383,7 +386,7 @@ def cli():
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
-    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--device", default=None, help="device to use for PyTorch inference; auto-detected (cuda, xpu, cpu) when not given")
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
     parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
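
A minimal smoke test for the device-selection logic this patch introduces, assuming intel_extension_for_pytorch is installed and the PyTorch build exposes torch.xpu; the pick_device helper and the file name are illustrative sketches, not part of the patch itself:

    # xpu_check.py -- hypothetical helper mirroring the patched load_model() fallback
    from importlib.util import find_spec

    import torch

    # intel_extension_for_pytorch is imported only for its side effects:
    # when present, it registers Intel GPU ("xpu") support with PyTorch.
    if find_spec("intel_extension_for_pytorch") is not None:
        import intel_extension_for_pytorch  # noqa: F401


    def pick_device() -> str:
        # Same precedence as the patch: CUDA first, then Intel XPU, then CPU.
        if torch.cuda.is_available():
            return "cuda"
        if find_spec("torch.xpu") is not None and torch.xpu.is_available():
            return "xpu"
        return "cpu"


    if __name__ == "__main__":
        print("selected device:", pick_device())

Because the CLI's --device default is now None, load_model() runs the same auto-detection, so on an Intel GPU machine a plain "whisper audio.wav --model small" should land on the XPU without an explicit --device xpu.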