
Commit e01ad15

[camb] add w8a8 support (#176)
1 parent 06f8580 commit e01ad15

File tree: 2 files changed, +139 −25 lines


dlinfer/vendor/camb/__init__.py

Lines changed: 25 additions & 0 deletions
@@ -1,3 +1,28 @@
 import torch
 
 from . import pytorch_patch, camb_ops
+
+# TODO. weitao: camb torch-mlu-ops-v1.2.0 per_token_smooth_quantize need smooth_vec
+SMOOTH_VEC = torch.ones(8192, dtype=torch.float32, device="mlu")
+
+
+def next_power_of_2(n: int):
+    """Return the smallest power of 2 greater than or equal to n."""
+    n -= 1
+    n |= n >> 1
+    n |= n >> 2
+    n |= n >> 4
+    n |= n >> 8
+    n |= n >> 16
+    n |= n >> 32
+    n += 1
+    return n
+
+
+def update_smooth(length):
+    global SMOOTH_VEC
+    if length > SMOOTH_VEC.shape[0]:
+        SMOOTH_VEC = torch.ones(
+            next_power_of_2(length), dtype=torch.float32, device="mlu"
+        )
+    return SMOOTH_VEC
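
For reference, a quick sanity check of the new helpers (a minimal sketch; it assumes an MLU device is available, since importing the package allocates SMOOTH_VEC with device="mlu" and pulls in pytorch_patch and camb_ops):

    from dlinfer.vendor.camb import next_power_of_2, update_smooth

    # Exact powers of two are returned unchanged; anything else rounds up.
    assert next_power_of_2(8192) == 8192
    assert next_power_of_2(8193) == 16384

    # Asking for a smooth vector longer than the current 8192-element default
    # re-allocates SMOOTH_VEC to the next power of two.
    vec = update_smooth(10000)
    assert vec.shape[0] == 16384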

dlinfer/vendor/camb/camb_ops.py

Lines changed: 114 additions & 25 deletions
@@ -6,6 +6,7 @@
 from dlinfer.utils.registry import register_ops
 from dlinfer.utils.type_annotation import Tensor, Optional, Sequence, Tuple
 
+
 __all__ = [
     "add_rms_norm",
     "apply_rotary_pos_emb",
@@ -17,6 +18,10 @@
     "moe_gating_topk_softmax",
     "fused_moe",
     "linear",
+    "dynamic_quant",
+    "linear_w8a8",
+    "rms_norm_w8a8",
+    "add_rms_norm_w8a8",
 ]
 
 
@@ -368,35 +373,119 @@ def fused_moe(
     return out
 
 
+def _process_input_dim(x: Tensor, scale: Optional[Tensor] = None):
+    """
+    NOTE: torch_mlu_ops matmul kernels require a two-dimensional input,
+    so 3D inputs are flattened to 2D here and restored by _process_output_dim.
+    """
+    original_shape = x.shape
+    x = x.view(-1, x.shape[-1])
+    if scale is not None:
+        scale = scale.view(-1)
+    return x, scale, original_shape
+
+
+def _process_output_dim(out: Tensor, original_shape: Tuple[int, ...]):
+    if original_shape is not None:
+        return out.view(*original_shape[:-1], -1)
+    return out
+
+
 @register_ops(vendor_ops_registry)
 def linear(
     x: Tensor,
     weight: Tensor,
     bias: Optional[Tensor],
     all_reduce: Optional[bool],
 ) -> Tensor:
-    if x.dim() == 2:
-        if all_reduce:
-            cncl_comm = (
-                torch.distributed.distributed_c10d._world.default_pg._get_backend(
-                    x.device
-                ).get_cncl_comm(x.device.index)
-            )
-            out = tmo.matmul_allreduce(cncl_comm, x, weight, bias)
-        else:
-            out = tmo.matmul(x, weight, bias)
-    elif x.dim() == 3:
-        bsz, seq_len, _ = x.size()
-        x_reshaped = x.view(bsz * seq_len, -1)
-        if all_reduce:
-            cncl_comm = (
-                torch.distributed.distributed_c10d._world.default_pg._get_backend(
-                    x.device
-                ).get_cncl_comm(x.device.index)
-            )
-            out = tmo.matmul_allreduce(cncl_comm, x_reshaped, weight, bias).view(
-                bsz, seq_len, -1
-            )
-        else:
-            out = tmo.matmul(x_reshaped, weight, bias).view(bsz, seq_len, -1)
-    return out
+    x, _, original_shape = _process_input_dim(x, None)
+    if all_reduce:
+        cncl_comm = torch.distributed.distributed_c10d._world.default_pg._get_backend(
+            x.device
+        ).get_cncl_comm(x.device.index)
+        out = tmo.matmul_allreduce(cncl_comm, x, weight, bias)
+    else:
+        out = tmo.matmul(x, weight, bias)
+    return _process_output_dim(out, original_shape)
+
+
+@register_ops(vendor_ops_registry)
+def dynamic_quant(
+    x: Tensor, quant_dtype: torch.dtype, quant_granularity: str = "PER_TOKEN"
+):
+    assert quant_dtype == torch.int8
+    assert quant_granularity == "PER_TOKEN"
+    from . import SMOOTH_VEC
+
+    if x.shape[-1] > SMOOTH_VEC.shape[0]:
+        from . import update_smooth
+
+        SMOOTH_VEC = update_smooth(x.shape[-1])
+    smooth = SMOOTH_VEC[: x.shape[-1]]
+    return tmo.per_token_smooth_quantize(x, smooth=smooth)
+
+
+@register_ops(vendor_ops_registry)
+def linear_w8a8(
+    a: Tensor,
+    b: Tensor,
+    rms_scale: float,
+    linear_scale: float,
+    out_dtype: torch.dtype,
+    quant_dtype: torch.dtype = torch.int8,
+    bias: Tensor = None,
+):
+    assert quant_dtype == torch.int8
+    input_quant, input_scale, original_shape = _process_input_dim(a, rms_scale)
+    out = tmo.smooth_quant_matmul(
+        input_quant, input_scale, b, linear_scale, out_dtype, bias
+    )
+    return _process_output_dim(out, original_shape)
+
+
+@register_ops(vendor_ops_registry)
+def rms_norm_w8a8(
+    hidden_states: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    quant_dtype: torch.dtype = torch.int8,
+):
+    assert quant_dtype == torch.int8
+    store_output_before_norm = False
+    normed_hidden_states = tmo.fused_rms_norm(
+        hidden_states,
+        None,
+        weight,
+        None,
+        None,
+        epsilon,
+        store_output_before_norm,
+        None,
+        None,
+    )
+    x, rms_scale = dynamic_quant(normed_hidden_states, quant_dtype, "PER_TOKEN")
+    return x, rms_scale
+
+
+@register_ops(vendor_ops_registry)
+def add_rms_norm_w8a8(
+    hidden_states: Tensor,
+    residual: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    quant_dtype: torch.dtype = torch.int8,
+):
+    assert quant_dtype == torch.int8
+    store_output_before_norm = True
+    normed_hidden_states, added_hidden_states = tmo.fused_rms_norm(
+        hidden_states,
+        residual,
+        weight,
+        None,
+        None,
+        epsilon,
+        store_output_before_norm,
+        None,
+    )
+    x, rms_scale = dynamic_quant(normed_hidden_states, quant_dtype, "PER_TOKEN")
+    return x, rms_scale, added_hidden_states
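
Taken together, the new ops implement the usual w8a8 flow: rms_norm_w8a8 (or add_rms_norm_w8a8, which also returns the residual-added hidden states for the next layer) normalizes the activations and quantizes them per token via dynamic_quant, and linear_w8a8 then runs the int8 matmul and dequantizes using the per-token scales plus the weight scale. A minimal sketch of chaining them is below; it assumes an MLU device, an already int8-quantized weight, and a placeholder weight_scale — the exact weight layout and scale format expected by tmo.smooth_quant_matmul are not part of this diff.

    import torch
    from dlinfer.vendor.camb.camb_ops import rms_norm_w8a8, linear_w8a8

    hidden = torch.randn(2, 16, 4096, dtype=torch.float16, device="mlu")
    norm_weight = torch.ones(4096, dtype=torch.float16, device="mlu")

    # Hypothetical offline-quantized weight and its scale; the real layout and
    # scale format are defined by tmo.smooth_quant_matmul, not by this commit.
    weight_int8 = torch.randint(-128, 128, (4096, 11008), dtype=torch.int8, device="mlu")
    weight_scale = 0.01

    # RMSNorm the activations, then per-token int8 quantization (PER_TOKEN granularity).
    x_int8, token_scale = rms_norm_w8a8(hidden, norm_weight, epsilon=1e-6)

    # int8 x int8 matmul; per-token and weight scales dequantize the result to fp16.
    out = linear_w8a8(x_int8, weight_int8, token_scale, weight_scale, out_dtype=torch.float16)

Note that although linear_w8a8 annotates rms_scale as float, _process_input_dim calls rms_scale.view(-1) on it, so in practice it receives the per-token scale tensor produced by rms_norm_w8a8 / dynamic_quant.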
