@@ -4,6 +4,7 @@
 import torch
 import torch.distributed as dist
 
+from vllm import _custom_ops as custom_ops
 from flash_attn import flash_attn_varlen_func
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 
@@ -59,7 +60,7 @@ def add_rms_norm(
     weight: Tensor,
     epsilon: float,
 ) -> Tuple[Tensor, Tensor]:
-    vllm._custom_ops.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
+    custom_ops.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
     return hidden_states, residual
 
 
@@ -188,7 +189,7 @@ def fill_kv_cache(
     quant_bits: int,
 ) -> Tuple[Tensor, Tensor]:
     kv_indices = kv_indices.squeeze(-1)
-    vllm._custom_ops.reshape_and_cache_new(
+    custom_ops.reshape_and_cache_new(
         key, value, key_cache, value_cache, kv_indices, "auto", 1.0, 1.0
     )
     return key_cache, value_cache
@@ -220,7 +221,7 @@ def paged_decode_attention(
     num_kv_heads = value_cache.size(1)
     block_size = value_cache.size(2)
     output = torch.empty_like(query)
-    vllm._custom_ops.paged_attention_v1(
+    custom_ops.paged_attention_v1(
         output,
         query,
         key_cache,
@@ -304,7 +305,7 @@ def rms_norm(
     epsilon: float,
 ) -> Tensor:
     output = torch.empty_like(hidden_states)
-    vllm._custom_ops.rms_norm(output, hidden_states, weight, epsilon)
+    custom_ops.rms_norm(output, hidden_states, weight, epsilon)
     return output
 
 
@@ -322,7 +323,7 @@ def moe_gating_topk_softmax(
 
     token_expert_indicies = torch.empty_like(topk_ids)
 
-    vllm._custom_ops.topk_softmax(
+    custom_ops.topk_softmax(
         topk_weights,
         topk_ids,
         token_expert_indicies,
@@ -344,7 +345,7 @@ def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor:
     d = x.shape[-1] // 2
     output_shape = x.shape[:-1] + (d,)
     out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-    vllm._custom_ops.silu_and_mul(out, x)
+    custom_ops.silu_and_mul(out, x)
     return out
 
 
@@ -398,7 +399,7 @@ def weight_quant_matmul(
     group_size: Optional[int] = 0,
 ):
     offset = None if (offset is None or offset.numel() == 0) else offset
-    output = vllm._custom_ops.awq_gemm(x, qweight, scale, offset, group_size)
+    output = custom_ops.awq_gemm(x, qweight, scale, offset, group_size)
     if bias is not None:
         output += bias
     return output
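
The hunks above mechanically replace fully qualified vllm._custom_ops.* call sites with the module alias introduced in the import hunk. A minimal sketch of the resulting call pattern, using the rms_norm signature visible in the diff; the CUDA device, tensor shapes, and epsilon value below are illustrative assumptions, not taken from this commit:

# Sketch only: assumes a vLLM build that ships the _custom_ops extension.
import torch
from vllm import _custom_ops as custom_ops

hidden_states = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
weight = torch.ones(4096, dtype=torch.float16, device="cuda")

# As in the diff, rms_norm writes into a preallocated output tensor.
output = torch.empty_like(hidden_states)
custom_ops.rms_norm(output, hidden_states, weight, 1e-6)

Binding the module once under a short alias keeps every call site compact, and a later relocation of _custom_ops inside vLLM would only touch the single import line rather than seven call sites.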