@@ -6,6 +6,7 @@
 from dlinfer.utils.registry import register_ops
 from dlinfer.utils.type_annotation import Tensor, Optional, Sequence, Tuple
 
+
 __all__ = [
     "add_rms_norm",
     "apply_rotary_pos_emb",
@@ -17,6 +18,10 @@
     "moe_gating_topk_softmax",
     "fused_moe",
     "linear",
+    "dynamic_quant",
+    "linear_w8a8",
+    "rms_norm_w8a8",
+    "add_rms_norm_w8a8",
 ]
 
 
@@ -368,35 +373,119 @@ def fused_moe(
     return out
 
 
+def _process_input_dim(x: Tensor, scale: Optional[Tensor] = None):
+    """
+    NOTE: The torch_mlu_ops matmul kernels require two-dimensional input,
+    so reshape a 3-D input tensor to 2-D before dispatching.
+    """
+    original_shape = x.shape
+    x = x.view(-1, x.shape[-1])
+    if scale is not None:
+        scale = scale.view(-1)
+    return x, scale, original_shape
+
+
+def _process_output_dim(out: Tensor, original_shape: Tuple[int, ...]):
+    if original_shape is not None:
+        return out.view(*original_shape[:-1], -1)
+    return out
+
+
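As a quick sanity check, this is the flatten/restore round trip the two helpers perform, in pure torch (shapes and names below are illustrative only):

    import torch

    x = torch.randn(2, 8, 16)            # (batch, seq_len, hidden)
    x2d = x.view(-1, x.shape[-1])        # (batch * seq_len, hidden), as in _process_input_dim
    out2d = x2d @ torch.randn(16, 32)    # stand-in for the 2-D matmul kernel
    out = out2d.view(*x.shape[:-1], -1)  # restore leading dims, as in _process_output_dim
    assert out.shape == (2, 8, 32)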
 @register_ops(vendor_ops_registry)
 def linear(
     x: Tensor,
     weight: Tensor,
     bias: Optional[Tensor],
     all_reduce: Optional[bool],
 ) -> Tensor:
-    if x.dim() == 2:
-        if all_reduce:
-            cncl_comm = (
-                torch.distributed.distributed_c10d._world.default_pg._get_backend(
-                    x.device
-                ).get_cncl_comm(x.device.index)
-            )
-            out = tmo.matmul_allreduce(cncl_comm, x, weight, bias)
-        else:
-            out = tmo.matmul(x, weight, bias)
-    elif x.dim() == 3:
-        bsz, seq_len, _ = x.size()
-        x_reshaped = x.view(bsz * seq_len, -1)
-        if all_reduce:
-            cncl_comm = (
-                torch.distributed.distributed_c10d._world.default_pg._get_backend(
-                    x.device
-                ).get_cncl_comm(x.device.index)
-            )
-            out = tmo.matmul_allreduce(cncl_comm, x_reshaped, weight, bias).view(
-                bsz, seq_len, -1
-            )
-        else:
-            out = tmo.matmul(x_reshaped, weight, bias).view(bsz, seq_len, -1)
-    return out
+    x, _, original_shape = _process_input_dim(x, None)
+    if all_reduce:
+        cncl_comm = torch.distributed.distributed_c10d._world.default_pg._get_backend(
+            x.device
+        ).get_cncl_comm(x.device.index)
+        out = tmo.matmul_allreduce(cncl_comm, x, weight, bias)
+    else:
+        out = tmo.matmul(x, weight, bias)
+    return _process_output_dim(out, original_shape)
+
+
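The rewrite collapses the duplicated 2-D/3-D branches into the shared helpers above, so only one all_reduce branch remains. For intuition, a pure-torch sketch of the non-all_reduce path, assuming tmo.matmul(x, weight, bias) matches torch.nn.functional.linear semantics (whether the kernel expects the weight pre-transposed is not visible in this diff):

    import torch
    import torch.nn.functional as F

    def linear_ref(x, weight, bias=None):
        x2d = x.view(-1, x.shape[-1])       # flatten leading dims
        out = F.linear(x2d, weight, bias)   # x2d @ weight.T + bias
        return out.view(*x.shape[:-1], -1)  # restore leading dims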
+@register_ops(vendor_ops_registry)
+def dynamic_quant(
+    x: Tensor, quant_dtype: torch.dtype, quant_granularity: str = "PER_TOKEN"
+):
+    assert quant_dtype == torch.int8
+    assert quant_granularity == "PER_TOKEN"
+    from . import SMOOTH_VEC
+
+    if x.shape[-1] > SMOOTH_VEC.shape[0]:
+        from . import update_smooth
+
+        SMOOTH_VEC = update_smooth(x.shape[-1])
+    smooth = SMOOTH_VEC[: x.shape[-1]]
+    return tmo.per_token_smooth_quantize(x, smooth=smooth)
+
+
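For readers new to per-token quantization, the underlying math in plain torch (a sketch only; tmo.per_token_smooth_quantize additionally folds the SMOOTH_VEC smoothing vector into the computation):

    import torch

    def per_token_quant_ref(x: torch.Tensor):
        absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
        scale = absmax / 127.0                # one fp scale per token (row)
        q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
        return q, scale.squeeze(-1)

    x = torch.randn(4, 16)
    q, scale = per_token_quant_ref(x)
    deq = q.float() * scale.unsqueeze(-1)     # dequantize to check the round trip
    assert torch.allclose(deq, x, atol=scale.max().item())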
+@register_ops(vendor_ops_registry)
+def linear_w8a8(
+    a: Tensor,
+    b: Tensor,
+    rms_scale: float,
+    linear_scale: float,
+    out_dtype: torch.dtype,
+    quant_dtype: torch.dtype = torch.int8,
+    bias: Optional[Tensor] = None,
+):
+    assert quant_dtype == torch.int8
+    input_quant, input_scale, original_shape = _process_input_dim(a, rms_scale)
+    out = tmo.smooth_quant_matmul(
+        input_quant, input_scale, b, linear_scale, out_dtype, bias
+    )
+    return _process_output_dim(out, original_shape)
+
+
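In floating-point terms the w8a8 matmul computes roughly the following (a sketch assuming per-token activation scales, a per-tensor weight scale, and weight stored as (out_features, in_features); the actual kernel accumulates and rescales internally):

    import torch

    def linear_w8a8_ref(a_q, a_scale, b_q, b_scale, bias=None):
        # a_q: (tokens, in) int8, a_scale: (tokens,) fp
        # b_q: (out, in) int8, b_scale: scalar fp
        acc = a_q.to(torch.int32) @ b_q.to(torch.int32).T  # integer accumulation
        out = acc.to(torch.float32) * a_scale.unsqueeze(-1) * b_scale
        return out if bias is None else out + bias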
+@register_ops(vendor_ops_registry)
+def rms_norm_w8a8(
+    hidden_states: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    quant_dtype: torch.dtype = torch.int8,
+):
+    assert quant_dtype == torch.int8
+    store_output_before_norm = False
+    normed_hidden_states = tmo.fused_rms_norm(
+        hidden_states,
+        None,
+        weight,
+        None,
+        None,
+        epsilon,
+        store_output_before_norm,
+        None,
+        None,
+    )
+    x, rms_scale = dynamic_quant(normed_hidden_states, quant_dtype, "PER_TOKEN")
+    return x, rms_scale
+
+
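The fused kernel combines RMSNorm with the per-token quantization above; the normalization itself, in plain torch, is the standard formulation (a sketch; the positional None arguments to tmo.fused_rms_norm are presumably optional bias/residual slots):

    import torch

    def rms_norm_ref(x, weight, epsilon=1e-6):
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        return (x.float() * torch.rsqrt(variance + epsilon)).to(x.dtype) * weight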
+@register_ops(vendor_ops_registry)
+def add_rms_norm_w8a8(
+    hidden_states: Tensor,
+    residual: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    quant_dtype: torch.dtype = torch.int8,
+):
+    assert quant_dtype == torch.int8
+    store_output_before_norm = True
+    normed_hidden_states, added_hidden_states = tmo.fused_rms_norm(
+        hidden_states,
+        residual,
+        weight,
+        None,
+        None,
+        epsilon,
+        store_output_before_norm,
+        None,
+    )
+    x, rms_scale = dynamic_quant(normed_hidden_states, quant_dtype, "PER_TOKEN")
+    return x, rms_scale, added_hidden_states
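With store_output_before_norm=True the kernel also returns the pre-norm residual sum, which callers thread into the next layer's residual connection. A plain-torch sketch of the same flow:

    import torch

    def add_rms_norm_ref(hidden_states, residual, weight, epsilon=1e-6):
        added = hidden_states + residual  # returned so the caller can reuse it as the next residual
        variance = added.float().pow(2).mean(dim=-1, keepdim=True)
        normed = (added.float() * torch.rsqrt(variance + epsilon)).to(added.dtype) * weight
        return normed, added              # quantization of `normed` happens in dynamic_quant above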