Commit 4e76aa9 (1 parent: 90d8a9e)

fix maca_vllm import error. (#160)
File tree

1 file changed: 8 additions, 7 deletions

dlinfer/vendor/maca/maca_ops.py

Lines changed: 8 additions & 7 deletions
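
The patch adds a single import, from vllm import _custom_ops as custom_ops, and rewrites the seven fully qualified vllm._custom_ops.* call sites below to go through that alias.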
@@ -4,6 +4,7 @@
 import torch
 import torch.distributed as dist
 
+from vllm import _custom_ops as custom_ops
 from flash_attn import flash_attn_varlen_func
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 
@@ -59,7 +60,7 @@ def add_rms_norm(
     weight: Tensor,
     epsilon: float,
 ) -> Tuple[Tensor, Tensor]:
-    vllm._custom_ops.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
+    custom_ops.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
     return hidden_states, residual
 
 
@@ -188,7 +189,7 @@ def fill_kv_cache(
     quant_bits: int,
 ) -> Tuple[Tensor, Tensor]:
     kv_indices = kv_indices.squeeze(-1)
-    vllm._custom_ops.reshape_and_cache_new(
+    custom_ops.reshape_and_cache_new(
         key, value, key_cache, value_cache, kv_indices, "auto", 1.0, 1.0
     )
     return key_cache, value_cache
@@ -220,7 +221,7 @@ def paged_decode_attention(
     num_kv_heads = value_cache.size(1)
     block_size = value_cache.size(2)
     output = torch.empty_like(query)
-    vllm._custom_ops.paged_attention_v1(
+    custom_ops.paged_attention_v1(
         output,
         query,
         key_cache,
@@ -304,7 +305,7 @@ def rms_norm(
     epsilon: float,
 ) -> Tensor:
     output = torch.empty_like(hidden_states)
-    vllm._custom_ops.rms_norm(output, hidden_states, weight, epsilon)
+    custom_ops.rms_norm(output, hidden_states, weight, epsilon)
     return output
 
 
@@ -322,7 +323,7 @@ def moe_gating_topk_softmax(
 
     token_expert_indicies = torch.empty_like(topk_ids)
 
-    vllm._custom_ops.topk_softmax(
+    custom_ops.topk_softmax(
         topk_weights,
         topk_ids,
         token_expert_indicies,
@@ -344,7 +345,7 @@ def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor:
     d = x.shape[-1] // 2
     output_shape = x.shape[:-1] + (d,)
     out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-    vllm._custom_ops.silu_and_mul(out, x)
+    custom_ops.silu_and_mul(out, x)
     return out
 
 
@@ -398,7 +399,7 @@ def weight_quant_matmul(
     group_size: Optional[int] = 0,
 ):
     offset = None if (offset is None or offset.numel() == 0) else offset
-    output = vllm._custom_ops.awq_gemm(x, qweight, scale, offset, group_size)
+    output = custom_ops.awq_gemm(x, qweight, scale, offset, group_size)
     if bias is not None:
         output += bias
     return output
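
For readers hitting the same error elsewhere: before this change, maca_ops.py imported only vllm submodules (for example vllm.attention.ops.prefix_prefill), so the bare name vllm was presumably never bound in the module namespace and the fully qualified vllm._custom_ops.* calls failed at runtime; the exact exception depends on the environment and is an assumption here, not something stated in the commit. A minimal standalone sketch of the patched call pattern, using only names that appear in the diff:

import torch
from torch import Tensor

# Bind vllm's custom-ops module once, under an explicit alias, instead of
# relying on the top-level "vllm" name being bound in this module.
from vllm import _custom_ops as custom_ops


def rms_norm(hidden_states: Tensor, weight: Tensor, epsilon: float) -> Tensor:
    # Same call shape as the patched maca_ops.rms_norm: allocate the output
    # buffer and let the vllm kernel write into it in place.
    output = torch.empty_like(hidden_states)
    custom_ops.rms_norm(output, hidden_states, weight, epsilon)
    return output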
