Combine parallel dense Optimization pass in ONNX Dialect #3123
base: main
Conversation
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
@Arkar-Hema A general question: in what kind of models have you seen this kind of pattern (multiple Gemm ops followed by a Concat op), and also the similar patterns you have recently created PRs for? Just curious how practical it is. Thanks!
@tungld could you please verify this patch?
@Arkar-Hema thank you for the information! I have some general comments:
Thanks.
Thanks @Arkar-Hema for the experiments! Did you compile your programs with -O3?
Since this parallel fusion may not work for accelerators, could you create a compile option to enable it if needed, for example -fuse-parallel-onnx-gemm?
I don't think you need to handle the case where there is a Concat after multiple Gemms. Just emit a Split op; later you can write a simple canonicalization rule for Concat to fuse Split -> Concat.
Below are my first-round comments; most of them are for simplifying the code and making it easy to follow. However, the important thing is that you need to check the input C carefully because it is broadcastable.
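For reference, a flag like the suggested -fuse-parallel-onnx-gemm could be declared with LLVM's command-line library. This is only a minimal sketch under that assumption; onnx-mlir's actual option registration may be wired differently.

#include "llvm/Support/CommandLine.h"

// Sketch only: a boolean option, off by default, that the rewrite pattern
// would consult before fusing parallel Gemm ops. The flag name follows the
// reviewer's suggestion; the declaration style is an assumption.
static llvm::cl::opt<bool> fuseParallelONNXGemm("fuse-parallel-onnx-gemm",
    llvm::cl::desc("Enable fusing parallel ONNX Gemm ops into a single Gemm"),
    llvm::cl::init(false));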
Can one of the admins verify this patch?
I have added it, thanks!
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
@jenkins-droid test this please
auto aCShape = mlir::cast<ShapedType>(aC.getType()).getShape();
auto bCShape = mlir::cast<ShapedType>(bC.getType()).getShape();
if (aCShape.size() != 1 || bCShape.size() != 1)
  return false;
It seems you allow the case where aCShape is tensor<1xf32> and bCShape is tensor<5xf32> (5 is just an example to say it is not 1), and vice versa, but I don't see in the following code how you handle it. In this case, we need to broadcast aC to tensor<5xf32> before concatenating it with bC to make ConcatOp valid.
It's up to you to support this case or not, but if you do, please add a lit test. Otherwise, check it and return false here. Thanks!
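For concreteness, here is a plain C++ illustration of the broadcast-then-concat being described, with std::vector standing in for rank-1 tensors; the helper is hypothetical and not part of the PR.

#include <cassert>
#include <vector>

// Materialize the broadcast of a size-1 bias over its Gemm's output columns
// (n1 for the first Gemm, n2 for the second), then concatenate the two
// biases to form the bias of the combined Gemm.
std::vector<float> concatBiases(const std::vector<float> &aC, size_t n1,
    const std::vector<float> &bC, size_t n2) {
  std::vector<float> a = (aC.size() == 1) ? std::vector<float>(n1, aC[0]) : aC;
  std::vector<float> b = (bC.size() == 1) ? std::vector<float>(n2, bC[0]) : bC;
  assert(a.size() == n1 && b.size() == n2);
  a.insert(a.end(), b.begin(), b.end());
  return a; // size n1 + n2
}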
@Arkar-Hema could you please explain how you solved this comment, since you marked it as "solved"?
In the current implementation, I decided not to support the case where the bias shapes are different (e.g., tensor<1xf32> and tensor<5xf32>) since our concat operation would require explicit broadcasting to align the shapes before concatenation, and handling this properly would add additional complexity.
To handle this, I’ve added a check in areCompatible() to return false if the bias shapes differ in size when both biases are present - ensuring that we only merge Gemms where both bias tensors have the same shape. This preserves the correctness of the concat operation without requiring extra broadcasting logic.
Thank you!
I see you checked aCShape[0] != bCShape[0]. This is only valid if the shapes are static, e.g. tensor<5xf32>, but it does not work if the shapes are dynamic, e.g. when both aC and bC have a shape of tensor<?xf32>.
In the dynamic case, aCShape[0] == bCShape[0] at compile time, but at runtime aC can be tensor<1xf32> and bC can be tensor<5xf32>, for example.
For the static case, I wonder how you handle the following case, where both aC and bC have a shape of tensor<1xf32>. For example:
- gemm1: A: tensor<5x8x16xf32>, B: tensor<16x32xf32>, C: tensor<1xf32>
- gemm2: A: tensor<5x8x16xf32>, B: tensor<16x32xf32>, C: tensor<1xf32>
They satisfy your conditions here, so how do you combine them?
Thanks for the thorough review and the great questions!
I’ve updated the areCompatible() function to properly handle the edge cases you pointed out:
Dynamic Bias Shapes:
If either of the bias tensors has a dynamic shape at dimension 0 (i.e., tensor<?xf32>), I now conservatively return false since we can’t guarantee at compile time whether they’ll match or require broadcasting at runtime.
Both Biases as tensor<1xf32>:
If both biases are of shape tensor<1xf32>, I now check their corresponding Gemm output shapes and ensure their output channels (last dimension) match before considering them compatible. If they differ, the function returns false, as merging them without this check would be invalid.
This ensures that both static and dynamic cases are handled correctly and conservatively avoids undefined behavior at runtime.
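A minimal sketch of the checks described above, in the style of the snippet quoted earlier; the helper name, the extra output operands, and the surrounding structure of areCompatible() are assumptions rather than the PR's literal code.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// aC/bC are the Gemm bias operands, aY/bY the Gemm results.
static bool biasesAreCompatible(mlir::Value aC, mlir::Value bC, mlir::Value aY,
    mlir::Value bY) {
  auto aCShape = mlir::cast<mlir::ShapedType>(aC.getType()).getShape();
  auto bCShape = mlir::cast<mlir::ShapedType>(bC.getType()).getShape();
  // Only rank-1 biases are considered.
  if (aCShape.size() != 1 || bCShape.size() != 1)
    return false;
  // Dynamic bias sizes cannot be proven equal at compile time: bail out.
  if (mlir::ShapedType::isDynamic(aCShape[0]) ||
      mlir::ShapedType::isDynamic(bCShape[0]))
    return false;
  // Static sizes must match so the biases can be concatenated directly.
  if (aCShape[0] != bCShape[0])
    return false;
  // Size-1 biases broadcast over all output columns; additionally require the
  // two Gemm outputs to have the same last dimension, as described above.
  if (aCShape[0] == 1) {
    auto aYShape = mlir::cast<mlir::ShapedType>(aY.getType()).getShape();
    auto bYShape = mlir::cast<mlir::ShapedType>(bY.getType()).getShape();
    if (aYShape.empty() || bYShape.empty() || aYShape.back() != bYShape.back())
      return false;
  }
  return true;
}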
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
@jenkins-droid test this please
Hi @Arkar-Hema, when addressing a comment, could you please provide a brief explanation of how you did so? This will make the review process easier. Thanks!
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch?
Combine Parallel Dense
CombineParallelDense is an optimization pass designed to merge multiple parallel ONNXGemmOp (Dense/Fully Connected) operations into a single, more efficient Dense layer. This optimization reduces redundant computations, improves memory efficiency, and enhances hardware utilization.
The pass identifies parallel Dense (Gemm) operations that share the same input tensor and have compatible weight and bias shapes.
Let's assume an input case:
Before Optimization (Three Parallel Gemms)
- Memory Reads: the full input is read 3 times (once per Gemm)
After Optimization (Combined Dense)
- Memory Reads: the full input is read once
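As a plain C++ illustration with toy shapes (two branches for brevity; the same holds for three, and it is not the pass's implementation), the rewrite rests on the identity that concatenating the outputs of parallel Gemms over the same input equals a single Gemm over the column-wise concatenation of their weights and biases:

#include <cassert>
#include <vector>

using Mat = std::vector<std::vector<float>>;

// Y = X * B + C, with the rank-1 bias C broadcast along rows (row-major).
static Mat gemm(const Mat &X, const Mat &B, const std::vector<float> &C) {
  Mat Y(X.size(), std::vector<float>(B[0].size()));
  for (size_t m = 0; m < X.size(); ++m)
    for (size_t n = 0; n < B[0].size(); ++n) {
      float acc = C[n];
      for (size_t k = 0; k < B.size(); ++k)
        acc += X[m][k] * B[k][n];
      Y[m][n] = acc;
    }
  return Y;
}

// Concatenate two row-major matrices along the column axis.
static Mat concatCols(const Mat &A, const Mat &B) {
  Mat R = A;
  for (size_t r = 0; r < B.size(); ++r)
    R[r].insert(R[r].end(), B[r].begin(), B[r].end());
  return R;
}

int main() {
  Mat X = {{1, 2}, {3, 4}};                   // shared input
  Mat B1 = {{1, 0}, {0, 1}}, B2 = {{2}, {3}}; // per-branch weights
  std::vector<float> C1 = {0.5f, 0.5f}, C2 = {1.0f};

  // Before: two parallel Gemms over the same X, outputs concatenated.
  Mat separate = concatCols(gemm(X, B1, C1), gemm(X, B2, C2));

  // After: a single Gemm over concatenated weights and biases.
  std::vector<float> C = C1;
  C.insert(C.end(), C2.begin(), C2.end());
  Mat combined = gemm(X, concatCols(B1, B2), C);

  assert(separate == combined); // same values, but X is read only once
  return 0;
}

In the ONNX dialect, the combined result would then be split back for each original consumer, or the Split folded into a following Concat as suggested in the review above.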
Improvement in performance metrics:
- Latency improvement: 7-15%
- Throughput improvement: 8-14%
- Memory usage improvement: 10-12%