
Commit 9dbcce8

[Neuron] [Bugfix] Fix neuron startup (vllm-project#9374)

xendo and Jerzy Zagorski authored
Co-authored-by: Jerzy Zagorski <[email protected]>

1 parent a48e3ec · commit 9dbcce8

File tree

7 files changed: +37 -18 lines changed

  vllm/_custom_ops.py
  vllm/config.py
  vllm/platforms/__init__.py
  vllm/platforms/interface.py
  vllm/platforms/neuron.py
  vllm/triton_utils/importing.py
  vllm/utils.py

vllm/_custom_ops.py

Lines changed: 2 additions & 1 deletion

@@ -26,7 +26,8 @@
     import vllm._moe_C  # noqa: F401
     supports_moe_ops = True
 
-if TYPE_CHECKING:
+# neuron has torch version that doesn't even have impl_abstract
+if TYPE_CHECKING or current_platform.is_neuron():
 
     def register_fake(fn):
         return lambda name: fn
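
For context, here is a minimal sketch of the fallback chain this hunk lives in. Only the comment and the widened if line come from the diff; the else branch is an assumption about the surrounding module, included so the sketch is self-contained.

from typing import TYPE_CHECKING

from vllm.platforms import current_platform

if TYPE_CHECKING or current_platform.is_neuron():

    def register_fake(fn):
        # No-op stand-in: Neuron's torch predates impl_abstract, so
        # fake-op registration is skipped entirely.
        return lambda name: fn
else:
    try:
        from torch.library import register_fake  # newer torch
    except ImportError:
        # older torch spells the same helper impl_abstract
        from torch.library import impl_abstract as register_fake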

vllm/config.py

Lines changed: 7 additions & 6 deletions

@@ -17,8 +17,7 @@
                                         get_hf_image_processor_config,
                                         get_hf_text_config)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        is_hip, is_neuron, is_openvino, is_xpu,
-                        print_warning_once)
+                        is_hip, is_openvino, is_xpu, print_warning_once)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup

@@ -215,8 +214,10 @@ def __init__(self,
         self.is_attention_free = self._init_attention_free()
         self.has_inner_state = self._init_has_inner_state()
 
-        self.override_neuron_config = override_neuron_config if is_neuron(
-        ) else None
+        if current_platform.is_neuron():
+            self.override_neuron_config = override_neuron_config
+        else:
+            self.override_neuron_config = None
 
         supported_tasks, task = self._resolve_task(task, self.hf_config)
         self.supported_tasks = supported_tasks

@@ -368,7 +369,7 @@ def _verify_quantization(self) -> None:
                 "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                 " is not set, enabling VLLM_USE_TRITON_AWQ.")
             envs.VLLM_USE_TRITON_AWQ = True
-        if is_neuron(
+        if current_platform.is_neuron(
         ) and self.quantization not in neuron_supported_quantization:
             raise ValueError(
                 f"{self.quantization} quantization is currently not "

@@ -1112,7 +1113,7 @@ def __init__(self, device: str = "auto") -> None:
             # Automated device type detection
             if current_platform.is_cuda_alike():
                 self.device_type = "cuda"
-            elif is_neuron():
+            elif current_platform.is_neuron():
                 self.device_type = "neuron"
             elif is_openvino():
                 self.device_type = "openvino"
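
Taken together, the config.py hunks swap the ad-hoc vllm.utils.is_neuron() probe for the shared platform object. A hedged sketch of the device-detection cascade after this change, trimmed to the branches the diff touches; detect_device_type and the "auto" fallback are illustrative names, not the real code:

from vllm.platforms import current_platform


def detect_device_type() -> str:
    if current_platform.is_cuda_alike():
        return "cuda"
    if current_platform.is_neuron():
        return "neuron"
    return "auto"  # illustrative; the real code also checks openvino, xpu, etc.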

vllm/platforms/__init__.py

Lines changed: 10 additions & 0 deletions

@@ -58,6 +58,13 @@
 except Exception:
     pass
 
+is_neuron = False
+try:
+    import transformers_neuronx  # noqa: F401
+    is_neuron = True
+except ImportError:
+    pass
+
 if is_tpu:
     # people might install pytorch built with cuda but run on tpu
     # so we need to check tpu first

@@ -75,6 +82,9 @@
 elif is_cpu:
     from .cpu import CpuPlatform
     current_platform = CpuPlatform()
+elif is_neuron:
+    from .neuron import NeuronPlatform
+    current_platform = NeuronPlatform()
 else:
     current_platform = UnspecifiedPlatform()
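
The detection added here follows the probe-by-import pattern used for the other backends: if the vendor package imports, the platform is assumed present. A self-contained sketch; probe() is a hypothetical helper, and only the transformers_neuronx probe itself comes from the diff:

def probe(package_name: str) -> bool:
    # Importing the vendor package is the cheapest reliable signal that
    # the platform's software stack is installed.
    try:
        __import__(package_name)
        return True
    except ImportError:
        return False


is_neuron = probe("transformers_neuronx")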

vllm/platforms/interface.py

Lines changed: 4 additions & 0 deletions

@@ -10,6 +10,7 @@ class PlatformEnum(enum.Enum):
     TPU = enum.auto()
     XPU = enum.auto()
     CPU = enum.auto()
+    NEURON = enum.auto()
     UNSPECIFIED = enum.auto()

@@ -48,6 +49,9 @@ def is_xpu(self) -> bool:
     def is_cpu(self) -> bool:
         return self._enum == PlatformEnum.CPU
 
+    def is_neuron(self) -> bool:
+        return self._enum == PlatformEnum.NEURON
+
     def is_cuda_alike(self) -> bool:
         """Stateless version of :func:`torch.cuda.is_available`."""
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
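
The new predicate follows the file's existing pattern: each Platform subclass pins _enum once, and the base class answers every is_*() query by comparing against it. A trimmed, runnable sketch of that pattern, with only two enum members for brevity:

import enum


class PlatformEnum(enum.Enum):
    NEURON = enum.auto()
    UNSPECIFIED = enum.auto()


class Platform:
    _enum: PlatformEnum

    def is_neuron(self) -> bool:
        return self._enum == PlatformEnum.NEURON


class NeuronPlatform(Platform):
    _enum = PlatformEnum.NEURON


assert NeuronPlatform().is_neuron()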

vllm/platforms/neuron.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+from .interface import Platform, PlatformEnum
+
+
+class NeuronPlatform(Platform):
+    _enum = PlatformEnum.NEURON
+
+    @classmethod
+    def get_device_name(cls, device_id: int = 0) -> str:
+        return "neuron"

vllm/triton_utils/importing.py

Lines changed: 4 additions & 1 deletion

@@ -1,10 +1,13 @@
 from importlib.util import find_spec
 
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
-HAS_TRITON = find_spec("triton") is not None
+# neuron has too old torch
+HAS_TRITON = find_spec(
+    "triton") is not None and not current_platform.is_neuron()
 
 if not HAS_TRITON:
     logger.info("Triton not installed; certain GPU-related functions"
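
The guard combines two independent checks. Restated as a self-contained function, where has_triton and its parameter are hypothetical names for illustration:

from importlib.util import find_spec


def has_triton(on_neuron: bool) -> bool:
    # find_spec() returns None when the package is absent and avoids
    # actually importing triton; the platform check keeps Triton disabled
    # on Neuron even when the package happens to be installed.
    return find_spec("triton") is not None and not on_neuron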

vllm/utils.py

Lines changed: 1 addition & 10 deletions

@@ -327,15 +327,6 @@ def is_openvino() -> bool:
     return False
 
 
-@lru_cache(maxsize=None)
-def is_neuron() -> bool:
-    try:
-        import transformers_neuronx
-    except ImportError:
-        transformers_neuronx = None
-    return transformers_neuronx is not None
-
-
 @lru_cache(maxsize=None)
 def is_xpu() -> bool:
     from importlib.metadata import PackageNotFoundError, version

@@ -786,7 +777,7 @@ def is_pin_memory_available() -> bool:
     elif is_xpu():
         print_warning_once("Pin memory is not supported on XPU.")
         return False
-    elif is_neuron():
+    elif current_platform.is_neuron():
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
     elif current_platform.is_cpu() or is_openvino():
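
The deleted helper memoized its import probe with lru_cache so repeated calls stayed cheap; after this commit the probe runs exactly once, when vllm.platforms is imported, and callers compare an enum instead. For contrast, the old style reconstructed from the deleted lines above:

from functools import lru_cache


@lru_cache(maxsize=None)
def is_neuron() -> bool:
    try:
        import transformers_neuronx  # noqa: F401
    except ImportError:
        return False
    return True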

0 commit comments