Skip to content

SM-constraint-GEMM by triton persistent kernel #982

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Apr 1, 2025

Conversation

yyihuang
Copy link
Contributor

@yyihuang yyihuang commented Mar 29, 2025

Add SM-constraint GEMM operation by triton persistent kernel to support Nanoflow infra-device parallelism.

Checklist:

  • functional test passed
  • benchmark
  • SM usage by nsys profile
  • (optional for this PR) tune: get best config for gemm

Benchmark results:
https://docs.google.com/document/d/189f1VdZ36B-iJTYlC2LDgWGDSWKCiZI6PTfg-ltjXv4/edit?usp=sharing

Nsys Results

        When num_sm = 1:
          gemm_kernel_persistent
          Begins: 3.89591s
          Ends: 3.89592s (+5.248 μs)
          grid:  <<<1, 1, 1>>>
          block: <<<128, 1, 1>>>
        
        When num_sm = 32:
          gemm_kernel_persistent
          Begins: 3.91269s
          Ends: 3.92016s (+7.466 ms)
          grid:  <<<32, 1, 1>>>
          block: <<<128, 1, 1>>>
        
        When num_sm = 64:
          gemm_kernel_persistent
          Begins: 3.59851s
          Ends: 3.60234s (+3.829 ms)
          grid:  <<<64, 1, 1>>>
          block: <<<128, 1, 1>>>
          Launch Type: Regular
        
        When num_sm = 128:
          gemm_kernel_persistent
          Begins: 3.17387s
          Ends: 3.17586s (+1.992 ms)
          grid:  <<<128, 1, 1>>>
          block: <<<128, 1, 1>>>
        
        When num_sm = 133:
          gemm_kernel_persistent
          Begins: 3.51542s
          Ends: 3.5173s (+1.879 ms)
          grid:  <<<132, 1, 1>>>
          block: <<<128, 1, 1>>>

Related issues:
#591
#675

@yyihuang yyihuang marked this pull request as draft March 29, 2025 08:32
@yzh119 yzh119 marked this pull request as ready for review March 29, 2025 17:58
@yyihuang yyihuang requested a review from yzh119 March 30, 2025 01:16
@yzh119 yzh119 requested a review from Copilot March 30, 2025 20:33
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR introduces an SM-constrained GEMM operation using a Triton persistent kernel to support Nanoflow infra-device parallelism. Key changes include:

  • Adding new GEMM implementations (gemm_persistent and gemm) in flashinfer/triton/sm_constraint_gemm.py.
  • Implementing Triton kernel functions in flashinfer/triton/kernels/sm_constraint_gemm.py.
  • Adding comprehensive tests in tests/test_sm_constraint_gemm.py to validate the new functionality.

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.

File Description
tests/test_sm_constraint_gemm.py Adds tests for the new SM-constrained GEMM operations
flashinfer/triton/sm_constraint_gemm.py Implements persistent and non-constrained GEMM functions
flashinfer/triton/kernels/sm_constraint_gemm.py Defines Triton kernel implementations for GEMM operations
flashinfer/triton/init.py Includes the new SM-constrained GEMM module in the package API
Comments suppressed due to low confidence (2)

tests/test_sm_constraint_gemm.py:48

  • Consider replacing print statements with assert messages that include detailed failure diagnostics to improve test clarity and traceability.
print(f"c_torch: {c_torch}")

flashinfer/triton/kernels/sm_constraint_gemm.py:84

  • [nitpick] The variable name 'tile_id_c' is not very descriptive; consider renaming it to something like 'adjusted_tile_id' and add an inline comment explaining its role as a workaround for pipelining limitations.
tile_id_c = start_pid - NUM_SMS

@yyihuang yyihuang marked this pull request as draft March 31, 2025 00:17
@yyihuang yyihuang requested review from Copilot and yzh119 April 1, 2025 00:08
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds an SM-constrained GEMM operation implemented via Triton persistent kernels to support Nanoflow infra-device parallelism.

  • Introduces new tests for SM-constrained GEMM variants with various parameter configurations.
  • Implements three GEMM variants (persistent, naive, and descriptor persistent) along with their corresponding Triton kernels.
  • Updates the module initialization to include the new sm_constraint_gemm functionality.

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

File Description
tests/test_sm_constraint_gemm.py Adds comprehensive tests for the new GEMM operations with various parameters.
flashinfer/triton/sm_constraint_gemm.py Provides implementations of GEMM variants including persistent and descriptor persistent modes.
flashinfer/triton/kernels/sm_constraint_gemm.py Introduces Triton kernel implementations for SM-constrained GEMM.
flashinfer/triton/init.py Updates module imports to register the sm_constraint_gemm module.

@yyihuang yyihuang marked this pull request as ready for review April 1, 2025 00:17
@yyihuang yyihuang requested a review from Copilot April 1, 2025 01:41
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR introduces SM-constraint GEMM operations implemented via Triton persistent kernels to support Nanoflow infra‐device parallelism. Key changes include new GEMM functions with varying SM constraint strategies, comprehensive test coverage for these operations, and Triton kernel implementations with persistent and descriptor variants.

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.

File Description
tests/test_sm_constraint_gemm.py Adds tests for SM constraint GEMM operations.
flashinfer/triton/sm_constraint_gemm.py Implements GEMM functions (persistent, naive, descriptor persistent).
flashinfer/triton/kernels/sm_constraint_gemm.py Provides Triton kernel implementations for SM-constraint GEMM.
flashinfer/triton/init.py Exposes the sm_constraint_gemm module.
Comments suppressed due to low confidence (3)

flashinfer/triton/sm_constraint_gemm.py:32

  • If 'c' is allowed to be None (to enable automatic allocation), consider checking for None before calling check_input(c), otherwise the branch for allocating c will never be reached.
    check_input(c)

flashinfer/triton/sm_constraint_gemm.py:212

  • [nitpick] The assertion logic is ambiguous due to operator precedence; adding parentheses (e.g., (K >= 16 and dtype == torch.float8_e4m3fn) or (K >= 8)) will clarify the intended condition.
    assert (K >= 16 and dtype == torch.float8_e4m3fn or K >= 8), "Least chunk size must be 16B"

flashinfer/triton/sm_constraint_gemm.py:215

  • [nitpick] The assertion here also could benefit from explicit parentheses to clarify the intended condition (e.g., (N >= 16 and dtype == torch.float8_e4m3fn) or (N >= 8)).
    assert (N >= 16 and dtype == torch.float8_e4m3fn or N >= 8), "Least chunk size must be 16B"

@yyihuang yyihuang requested a review from Copilot April 1, 2025 02:03
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds a new SM-constraint GEMM operation implemented using Triton persistent kernels to support Nanoflow infra-device parallelism. Key changes include the addition of tests for the new GEMM variants, new SM-constraint GEMM implementations (persistent, naive, and descriptor persistent) and accompanying Triton kernel definitions, and an update to the package initialization to expose the new functionality.

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File Description
tests/test_sm_constraint_gemm.py Adds functional tests covering various parameter configurations
flashinfer/triton/sm_constraint_gemm.py Introduces three GEMM functions (persistent, naive, and descriptor persistent) with SM constraint logic
flashinfer/triton/kernels/sm_constraint_gemm.py Implements Triton kernels for the SM-constraint GEMM operation
flashinfer/triton/init.py Updates package initialization by importing the new module
Comments suppressed due to low confidence (1)

flashinfer/triton/sm_constraint_gemm.py:199

  • In gemm_descriptor_persistent, calling check_dim(2, c) without a preceding None check may cause issues when c is not provided. Consider guarding this call with 'if c is not None:' to ensure proper handling of optional output tensors.
check_dim(2, c)

@yyihuang yyihuang requested a review from Copilot April 1, 2025 02:07
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR introduces a new SM-constraint GEMM operation powered by a Triton persistent kernel to support Nanoflow infra-device parallelism. The key changes include:

  • Adding comprehensive tests in tests/test_sm_constraint_gemm.py with varied input parameters.
  • Implementing three GEMM functions (persistent, naive, descriptor persistent) in flashinfer/triton/sm_constraint_gemm.py.
  • Introducing new Triton kernel implementations in flashinfer/triton/kernels/sm_constraint_gemm.py and updating the module’s init.py.

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

File Description
tests/test_sm_constraint_gemm.py New tests for SM-constraint GEMM covering multiple parameter combinations.
flashinfer/triton/sm_constraint_gemm.py Implementation of GEMM functions using Triton persistent kernels and related input validations.
flashinfer/triton/kernels/sm_constraint_gemm.py New Triton kernel implementations for persistent and descriptor-based GEMM.
flashinfer/triton/init.py Updated to import the new sm_constraint_gemm module.
Comments suppressed due to low confidence (1)

flashinfer/triton/sm_constraint_gemm.py:236

  • Setting the global Triton allocator within the gemm_descriptor_persistent function may introduce side effects affecting other operations. Consider documenting the potential impacts or encapsulating allocator changes to avoid unintended interference.
    triton.set_allocator(alloc_fn)

Copy link
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind writing some simple benchmark like:

  • Given different problem shapes (M, N, K) = [(4096, 4096, 4096), (8192, 8192, 8192)], varying the number of SMs and measuring the performance using triton's do_bench function, you can also compute TFLOPs/s.

Some reference: https://github.com/flashinfer-ai/flashinfer/blob/main/benchmarks/bench_deepseek_mla.py

@yyihuang yyihuang requested review from Copilot and yzh119 April 1, 2025 05:34
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds a new SM-constraint GEMM operation using Triton’s persistent kernels to support Nanoflow infra-device parallelism. The key changes include:

  • New GEMM implementations (persistent, naive, and descriptor persistent) in the flashinfer/triton module and their corresponding kernels.
  • A comprehensive test suite (tests/test_sm_constraint_gemm.py) to validate the implementations.
  • Benchmark scripts (benchmarks/bench_persistent_gemm.py) to evaluate performance under various SM configurations.

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.

Show a summary per file
File Description
tests/test_sm_constraint_gemm.py Adds multi-parameter tests comparing torch and Triton GEMM implementations.
flashinfer/triton/sm_constraint_gemm.py Provides GEMM operations with SM constraint using persistent kernels.
flashinfer/triton/kernels/sm_constraint_gemm.py Contains Triton kernel implementations for GEMM (persistent and descriptor).
flashinfer/triton/init.py Registers the new sm_constraint_gemm module.
benchmarks/bench_persistent_gemm.py Implements benchmarks to assess GEMM performance on various SM configurations.

@yyihuang yyihuang requested a review from yzh119 April 1, 2025 06:58
@yyihuang yyihuang requested a review from Copilot April 1, 2025 17:56
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds an SM-constraint GEMM operation via Triton persistent kernels to support Nanoflow infra-device parallelism.

  • Implements three GEMM variants (persistent, naive, and descriptor persistent) in the flashinfer.triton module.
  • Introduces corresponding Triton kernel implementations and integrates them into the existing FlashInfer framework.
  • Includes unit tests and benchmark scripts to validate functionality and performance under various SM configurations.

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
tests/test_sm_constraint_gemm.py New tests covering various GEMM variants and tolerance settings.
flashinfer/triton/sm_constraint_gemm.py New GEMM functions with SM constraints including persistent kernels.
flashinfer/triton/kernels/sm_constraint_gemm.py Triton kernel implementations for persistent and descriptor GEMM operations.
flashinfer/triton/init.py Updated module imports.
benchmarks/bench_persistent_gemm.py Added benchmark scripts for performance evaluation.
Comments suppressed due to low confidence (1)

flashinfer/triton/sm_constraint_gemm.py:191

  • The docstring for 'gemm_descriptor_persistent' specifies 'b' with shape (N, K) while other GEMM functions expect 'b' with shape (K, N). Consider clarifying the expected shape for consistency or explicitly handling the necessary transpose conversion.
b: The second input matrix. Shape: (N, K)

Copy link
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, @yyihuang thanks for the contribution and let's merge this first and move on to the next step.

@yzh119 yzh119 merged commit 5751fc6 into flashinfer-ai:main Apr 1, 2025
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants