SM-constraint-GEMM by triton persistent kernel #982
Pull Request Overview
This PR introduces an SM-constrained GEMM operation using a Triton persistent kernel to support Nanoflow intra-device parallelism. Key changes include:
- Adding new GEMM implementations (gemm_persistent and gemm) in flashinfer/triton/sm_constraint_gemm.py.
- Implementing Triton kernel functions in flashinfer/triton/kernels/sm_constraint_gemm.py.
- Adding comprehensive tests in tests/test_sm_constraint_gemm.py to validate the new functionality.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
File | Description |
---|---|
tests/test_sm_constraint_gemm.py | Adds tests for the new SM-constrained GEMM operations |
flashinfer/triton/sm_constraint_gemm.py | Implements persistent and non-constrained GEMM functions |
flashinfer/triton/kernels/sm_constraint_gemm.py | Defines Triton kernel implementations for GEMM operations |
flashinfer/triton/__init__.py | Includes the new SM-constrained GEMM module in the package API |
Comments suppressed due to low confidence (2)
tests/test_sm_constraint_gemm.py:48
- Consider replacing print statements with assert messages that include detailed failure diagnostics to improve test clarity and traceability.
print(f"c_torch: {c_torch}")
flashinfer/triton/kernels/sm_constraint_gemm.py:84
- [nitpick] The variable name 'tile_id_c' is not very descriptive; consider renaming it to something like 'adjusted_tile_id' and add an inline comment explaining its role as a workaround for pipelining limitations.
tile_id_c = start_pid - NUM_SMS
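The role of the duplicated counter is easier to see in a plain-Python model of the tile loop. This is only a sketch: it assumes the kernel follows the usual persistent pattern in which program `start_pid` visits tiles `start_pid, start_pid + NUM_SMS, ...`, and the function names are hypothetical, not taken from the PR.

```python
def tiles_for_program(start_pid, num_tiles, num_sms):
    """Tiles visited by one persistent program (main loop counter)."""
    return list(range(start_pid, num_tiles, num_sms))


def epilogue_tile_ids(start_pid, num_tiles, num_sms):
    """Second counter used for the output tile, mirroring `tile_id_c`.

    It starts one stride behind and is advanced at the top of each
    iteration, so it tracks the same tile ids as the main counter.
    """
    tile_id_c = start_pid - num_sms
    visited = []
    for _ in range(start_pid, num_tiles, num_sms):
        tile_id_c += num_sms
        visited.append(tile_id_c)
    return visited


if __name__ == "__main__":
    # With 10 tiles and 4 programs, program 1 handles tiles 1, 5, 9.
    print(tiles_for_program(1, 10, 4))
    print(epilogue_tile_ids(1, 10, 4))
```

Under that assumption, both counters always name the same tile, which is why a more descriptive name or a short comment would be enough to address the nitpick.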
Pull Request Overview
This PR adds an SM-constrained GEMM operation implemented via Triton persistent kernels to support Nanoflow intra-device parallelism.
- Introduces new tests for SM-constrained GEMM variants with various parameter configurations.
- Implements three GEMM variants (persistent, naive, and descriptor persistent) along with their corresponding Triton kernels.
- Updates the module initialization to include the new sm_constraint_gemm functionality.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
File | Description |
---|---|
tests/test_sm_constraint_gemm.py | Adds comprehensive tests for the new GEMM operations with various parameters. |
flashinfer/triton/sm_constraint_gemm.py | Provides implementations of GEMM variants including persistent and descriptor persistent modes. |
flashinfer/triton/kernels/sm_constraint_gemm.py | Introduces Triton kernel implementations for SM-constrained GEMM. |
flashinfer/triton/__init__.py | Updates module imports to register the sm_constraint_gemm module. |
Pull Request Overview
This PR introduces SM-constraint GEMM operations implemented via Triton persistent kernels to support Nanoflow intra-device parallelism. Key changes include new GEMM functions with varying SM constraint strategies, comprehensive test coverage for these operations, and Triton kernel implementations with persistent and descriptor variants.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
File | Description |
---|---|
tests/test_sm_constraint_gemm.py | Adds tests for SM constraint GEMM operations. |
flashinfer/triton/sm_constraint_gemm.py | Implements GEMM functions (persistent, naive, descriptor persistent). |
flashinfer/triton/kernels/sm_constraint_gemm.py | Provides Triton kernel implementations for SM-constraint GEMM. |
flashinfer/triton/__init__.py | Exposes the sm_constraint_gemm module. |
Comments suppressed due to low confidence (3)
flashinfer/triton/sm_constraint_gemm.py:32
- If 'c' is allowed to be None (to enable automatic allocation), consider checking for None before calling check_input(c), otherwise the branch for allocating c will never be reached.
check_input(c)
flashinfer/triton/sm_constraint_gemm.py:212
- [nitpick] The assertion logic is ambiguous due to operator precedence; adding parentheses (e.g., (K >= 16 and dtype == torch.float8_e4m3fn) or (K >= 8)) will clarify the intended condition.
assert (K >= 16 and dtype == torch.float8_e4m3fn or K >= 8), "Least chunk size must be 16B"
flashinfer/triton/sm_constraint_gemm.py:215
- [nitpick] The assertion here also could benefit from explicit parentheses to clarify the intended condition (e.g., (N >= 16 and dtype == torch.float8_e4m3fn) or (N >= 8)).
assert (N >= 16 and dtype == torch.float8_e4m3fn or N >= 8), "Least chunk size must be 16B"
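The precedence point can be checked without PyTorch. The sketch below uses string names as hypothetical stand-ins for the torch dtypes. It shows that the quoted condition and the parenthesized form behave identically, because `and` binds tighter than `or`. The third function is an assumption about what the "16B" message intends, not code from the PR.

```python
# Hypothetical stand-ins for torch dtypes: name -> bytes per element.
ITEMSIZE = {"float8_e4m3fn": 1, "float16": 2, "bfloat16": 2}


def original_check(k, dtype):
    # Same shape as the quoted assertion, with no explicit grouping.
    return k >= 16 and dtype == "float8_e4m3fn" or k >= 8


def parenthesized_check(k, dtype):
    # Grouping suggested in the review; Python already parses it this way.
    return (k >= 16 and dtype == "float8_e4m3fn") or (k >= 8)


def byte_size_check(k, dtype):
    # Assumed intent: the dimension must span at least 16 bytes.
    return k * ITEMSIZE[dtype] >= 16


if __name__ == "__main__":
    for dtype in ITEMSIZE:
        for k in (4, 8, 16):
            print(dtype, k, original_check(k, dtype), byte_size_check(k, dtype))
```

One consequence worth noting: for a 1-byte dtype with K = 8, the first two checks pass through the `K >= 8` branch even though the dimension is only 8 bytes wide, so the parentheses clarify the expression but do not tighten it.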
Pull Request Overview
This PR adds a new SM-constraint GEMM operation implemented using Triton persistent kernels to support Nanoflow intra-device parallelism. Key changes include the addition of tests for the new GEMM variants, new SM-constraint GEMM implementations (persistent, naive, and descriptor persistent) and accompanying Triton kernel definitions, and an update to the package initialization to expose the new functionality.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
tests/test_sm_constraint_gemm.py | Adds functional tests covering various parameter configurations |
flashinfer/triton/sm_constraint_gemm.py | Introduces three GEMM functions (persistent, naive, and descriptor persistent) with SM constraint logic |
flashinfer/triton/kernels/sm_constraint_gemm.py | Implements Triton kernels for the SM-constraint GEMM operation |
flashinfer/triton/__init__.py | Updates package initialization by importing the new module |
Comments suppressed due to low confidence (1)
flashinfer/triton/sm_constraint_gemm.py:199
- In gemm_descriptor_persistent, calling check_dim(2, c) without a preceding None check may cause issues when c is not provided. Consider guarding this call with 'if c is not None:' to ensure proper handling of optional output tensors.
check_dim(2, c)
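A minimal sketch of the suggested guard follows. `FakeTensor`, this `check_dim`, and `prepare_output` are hypothetical stand-ins written for illustration so the example runs without PyTorch; they are not the helpers from the PR.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor; it only carries a shape."""

    def __init__(self, shape):
        self.shape = tuple(shape)


def check_dim(ndim, tensor):
    # Stand-in validation helper; like the real one, it cannot accept None.
    if len(tensor.shape) != ndim:
        raise ValueError(f"expected {ndim} dims, got {len(tensor.shape)}")


def prepare_output(c, m, n):
    """Validate a caller-provided output, or allocate one when c is None."""
    if c is None:
        return FakeTensor((m, n))  # allocation branch stays reachable
    check_dim(2, c)
    if c.shape != (m, n):
        raise ValueError(f"expected shape {(m, n)}, got {c.shape}")
    return c
```

Testing for `None` first keeps the allocation branch reachable and limits validation to outputs the caller actually supplied.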
Pull Request Overview
This PR introduces a new SM-constraint GEMM operation powered by a Triton persistent kernel to support Nanoflow intra-device parallelism. The key changes include:
- Adding comprehensive tests in tests/test_sm_constraint_gemm.py with varied input parameters.
- Implementing three GEMM functions (persistent, naive, descriptor persistent) in flashinfer/triton/sm_constraint_gemm.py.
- Introducing new Triton kernel implementations in flashinfer/triton/kernels/sm_constraint_gemm.py and updating the module’s __init__.py.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
File | Description |
---|---|
tests/test_sm_constraint_gemm.py | New tests for SM-constraint GEMM covering multiple parameter combinations. |
flashinfer/triton/sm_constraint_gemm.py | Implementation of GEMM functions using Triton persistent kernels and related input validations. |
flashinfer/triton/kernels/sm_constraint_gemm.py | New Triton kernel implementations for persistent and descriptor-based GEMM. |
flashinfer/triton/__init__.py | Updated to import the new sm_constraint_gemm module. |
Comments suppressed due to low confidence (1)
flashinfer/triton/sm_constraint_gemm.py:236
- Setting the global Triton allocator within the gemm_descriptor_persistent function may introduce side effects affecting other operations. Consider documenting the potential impacts or encapsulating allocator changes to avoid unintended interference.
triton.set_allocator(alloc_fn)
Would you mind writing a simple benchmark, for example:
- Given problem shapes (M, N, K) = [(4096, 4096, 4096), (8192, 8192, 8192)], vary the number of SMs and measure performance with Triton's do_bench function; you can also compute TFLOPs/s.
Some reference: https://github.com/flashinfer-ai/flashinfer/blob/main/benchmarks/bench_deepseek_mla.py
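A benchmark along these lines could be sketched as below. The `gemm_fn(a, b, num_sms)` call signature is an assumption made for illustration, not the PR's API. The FLOP count uses the conventional 2·M·N·K for a GEMM, and `triton.testing.do_bench` is taken to return a time in milliseconds.

```python
def gemm_tflops(m, n, k, ms):
    """TFLOPs/s for an (m, k) x (k, n) GEMM that took `ms` milliseconds."""
    flops = 2.0 * m * n * k  # one multiply and one add per inner-product term
    return flops / (ms * 1e-3) / 1e12


def bench_gemm(gemm_fn, shapes, sm_counts):
    """Time gemm_fn(a, b, num_sms) for each shape and SM count.

    gemm_fn is a hypothetical callable. The imports are local so that
    gemm_tflops can be used on a machine without a GPU.
    """
    import torch
    import triton.testing

    results = {}
    for m, n, k in shapes:
        a = torch.randn((m, k), device="cuda", dtype=torch.float16)
        b = torch.randn((k, n), device="cuda", dtype=torch.float16)
        for num_sms in sm_counts:
            ms = triton.testing.do_bench(lambda: gemm_fn(a, b, num_sms))
            results[(m, n, k, num_sms)] = gemm_tflops(m, n, k, ms)
    return results
```

A call such as `bench_gemm(my_gemm, [(4096, 4096, 4096), (8192, 8192, 8192)], [16, 32, 64])` would then return one TFLOPs/s figure per shape and SM count; the SM counts here are placeholders.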
Pull Request Overview
This PR adds a new SM-constraint GEMM operation using Triton’s persistent kernels to support Nanoflow intra-device parallelism. The key changes include:
- New GEMM implementations (persistent, naive, and descriptor persistent) in the flashinfer/triton module and their corresponding kernels.
- A comprehensive test suite (tests/test_sm_constraint_gemm.py) to validate the implementations.
- Benchmark scripts (benchmarks/bench_persistent_gemm.py) to evaluate performance under various SM configurations.
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
File | Description |
---|---|
tests/test_sm_constraint_gemm.py | Adds multi-parameter tests comparing torch and Triton GEMM implementations. |
flashinfer/triton/sm_constraint_gemm.py | Provides GEMM operations with SM constraint using persistent kernels. |
flashinfer/triton/kernels/sm_constraint_gemm.py | Contains Triton kernel implementations for GEMM (persistent and descriptor). |
flashinfer/triton/__init__.py | Registers the new sm_constraint_gemm module. |
benchmarks/bench_persistent_gemm.py | Implements benchmarks to assess GEMM performance on various SM configurations. |
Pull Request Overview
This PR adds an SM-constraint GEMM operation via Triton persistent kernels to support Nanoflow intra-device parallelism.
- Implements three GEMM variants (persistent, naive, and descriptor persistent) in the flashinfer.triton module.
- Introduces corresponding Triton kernel implementations and integrates them into the existing FlashInfer framework.
- Includes unit tests and benchmark scripts to validate functionality and performance under various SM configurations.
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.
File | Description |
---|---|
tests/test_sm_constraint_gemm.py | New tests covering various GEMM variants and tolerance settings. |
flashinfer/triton/sm_constraint_gemm.py | New GEMM functions with SM constraints including persistent kernels. |
flashinfer/triton/kernels/sm_constraint_gemm.py | Triton kernel implementations for persistent and descriptor GEMM operations. |
flashinfer/triton/__init__.py | Updated module imports. |
benchmarks/bench_persistent_gemm.py | Added benchmark scripts for performance evaluation. |
Comments suppressed due to low confidence (1)
flashinfer/triton/sm_constraint_gemm.py:191
- The docstring for 'gemm_descriptor_persistent' specifies 'b' with shape (N, K) while other GEMM functions expect 'b' with shape (K, N). Consider clarifying the expected shape for consistency or explicitly handling the necessary transpose conversion.
b: The second input matrix. Shape: (N, K)
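The two layouts describe the same product when the (N, K) operand is read as the transpose of a (K, N) one. The pure-Python sketch below uses nested lists and hypothetical helper names to illustrate that relationship; it makes no claim about how the PR's kernels actually index `b`.

```python
def matmul_kn(a, b_kn):
    """C = A @ B with A of shape (M, K) and B of shape (K, N)."""
    k_dim, n_dim = len(b_kn), len(b_kn[0])
    return [
        [sum(row[p] * b_kn[p][j] for p in range(k_dim)) for j in range(n_dim)]
        for row in a
    ]


def matmul_nk(a, b_nk):
    """C = A @ B^T with A of shape (M, K) and B stored as (N, K)."""
    return [[sum(x * y for x, y in zip(row, b_row)) for b_row in b_nk] for row in a]


def transpose(matrix):
    return [list(col) for col in zip(*matrix)]


if __name__ == "__main__":
    a = [[1, 2, 3], [4, 5, 6]]        # (M, K) = (2, 3)
    b_kn = [[1, 0], [0, 1], [2, 3]]   # (K, N) = (3, 2)
    print(matmul_kn(a, b_kn))
    print(matmul_nk(a, transpose(b_kn)))
```

If the descriptor variant does expect (N, K), a caller holding a (K, N) tensor would need to transpose it first, which is why stating the layout in each docstring matters.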
LGTM, @yyihuang thanks for the contribution and let's merge this first and move on to the next step.
Add an SM-constraint GEMM operation using a Triton persistent kernel to support Nanoflow intra-device parallelism.
Checklist:
Benchmark results:
https://docs.google.com/document/d/189f1VdZ36B-iJTYlC2LDgWGDSWKCiZI6PTfg-ltjXv4/edit?usp=sharing
Nsys Results
Related issues:
#591
#675