Combine parallel dense Optimization pass in ONNX Dialect #3123

Open
wants to merge 23 commits into main

Conversation

@Arkar-Hema (Contributor)

Combine Parallel Dense

CombineParallelDense is an optimization pass that merges multiple parallel ONNXGemmOp (Dense/Fully Connected) operations into a single, larger Gemm. This reduces per-op overhead and redundant reads of the shared input, improves memory efficiency, and enhances hardware utilization.

The pass identifies Dense (Gemm) operations that:

  • Share the same input tensor.
  • Have identical alpha, beta, transA, and transB attributes (ensuring compatibility).
  • May have different output dimensions (number of neurons) but maintain compatible weight shapes for concatenation.
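
For illustration, a compatibility check along these lines could be written against the MLIR C++ API as follows (a condensed sketch; the helper name and exact conditions are assumptions for exposition, not necessarily the code in this PR):

static bool areGemmsCompatible(ONNXGemmOp a, ONNXGemmOp b) {
  // Both Gemms must consume the same input tensor A.
  if (a.getA() != b.getA())
    return false;
  // alpha, beta, transA, and transB must match exactly; comparing the
  // attributes avoids floating-point comparison subtleties.
  if (a.getAlphaAttr() != b.getAlphaAttr() ||
      a.getBetaAttr() != b.getBetaAttr() ||
      a.getTransA() != b.getTransA() || a.getTransB() != b.getTransB())
    return false;
  // Weights must have static shapes and agree on the reduction (K) axis so
  // they can be concatenated along the output-channel axis.
  auto aB = mlir::cast<ShapedType>(a.getB().getType());
  auto bB = mlir::cast<ShapedType>(b.getB().getType());
  if (!aB.hasStaticShape() || !bB.hasStaticShape())
    return false;
  int64_t kAxis = a.getTransB() ? 1 : 0;
  return aB.getShape()[kAxis] == bB.getShape()[kAxis];
}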

Let's assume an example input:

  • Input Shape: (1, 512)
  • Dense A: out_features = 256
  • Dense B: out_features = 128
  • Dense C: out_features = 64
  • Attributes: transB = 0, alpha = 1.0, beta = 1.0

Before Optimization (Three Parallel Gemms)

  • Each Gemm performs one full matrix multiplication (1×512 × 512×N).
  • Three separate weight and bias tensors produce three separate outputs.
    • Memory reads: the full input is read three times (once per Gemm).
  • Post-processing: a Concat(axis=1) merges them into one output Y (1×448).

After Optimization (Combined Dense)

  • Total Output Features: 256 + 128 + 64 = 448
  • All weights are concatenated along the output-channel axis → new weight shape: (512, 448)
  • Biases are also concatenated
  • A single ONNXGemmOp computes Y (1×448) directly
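
The rewrite is exact because concatenation along the output-channel axis commutes with the matrix multiply: concat(X·W_A + b_A, X·W_B + b_B) = X·concat(W_A, W_B) + concat(b_A, b_B). A small self-contained C++ check of this identity, with sizes scaled down from the example above (all names illustrative):

#include <cassert>
#include <cstdio>
#include <vector>

// y = x * w + bias, with x of shape (1 x k) and w of shape (k x n), row-major.
static std::vector<float> gemm(const std::vector<float> &x,
    const std::vector<float> &w, const std::vector<float> &bias, int k, int n) {
  std::vector<float> y(n);
  for (int j = 0; j < n; ++j) {
    float acc = bias[j];
    for (int i = 0; i < k; ++i)
      acc += x[i] * w[i * n + j];
    y[j] = acc;
  }
  return y;
}

int main() {
  const int k = 8, nA = 4, nB = 2; // scaled-down stand-ins for 512/256/128
  std::vector<float> x(k), wA(k * nA), wB(k * nB), bA(nA, 0.5f), bB(nB, -1.0f);
  for (int i = 0; i < k; ++i) x[i] = 0.1f * i;
  for (size_t i = 0; i < wA.size(); ++i) wA[i] = 0.01f * i;
  for (size_t i = 0; i < wB.size(); ++i) wB[i] = -0.02f * i;

  // Concatenate weights along the output-channel axis: (k x nA) and (k x nB)
  // become one (k x (nA + nB)) matrix; biases are concatenated likewise.
  std::vector<float> wAB(k * (nA + nB));
  for (int i = 0; i < k; ++i) {
    for (int j = 0; j < nA; ++j) wAB[i * (nA + nB) + j] = wA[i * nA + j];
    for (int j = 0; j < nB; ++j) wAB[i * (nA + nB) + nA + j] = wB[i * nB + j];
  }
  std::vector<float> bAB(bA);
  bAB.insert(bAB.end(), bB.begin(), bB.end());

  auto yA = gemm(x, wA, bA, k, nA);         // Dense A
  auto yB = gemm(x, wB, bB, k, nB);         // Dense B
  auto yAB = gemm(x, wAB, bAB, k, nA + nB); // combined Dense

  // Concat(yA, yB) must equal the combined Gemm's output bit-for-bit, since
  // each output element is computed by the same sequence of operations.
  for (int j = 0; j < nA; ++j) assert(yA[j] == yAB[j]);
  for (int j = 0; j < nB; ++j) assert(yB[j] == yAB[nA + j]);
  std::printf("combined Gemm matches concatenated outputs\n");
  return 0;
}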

Improvement in performance metrics

Latency Improvement: 7-15%
Throughput Improvement: 8-14%
Memory Usage Improvement: 10-12%

@jenkins-droid (Collaborator)

Can one of the admins verify this patch?

Signed-off-by: Arkar-Hema <[email protected]>

@tungld (Collaborator) commented Apr 17, 2025

@Arkar-Hema A general question: in what kind of models have you seen this kind of pattern: multiple Gemm ops followed by a Concat op? and also similar patterns you have recently created PRs for? Just curious on how practical it is. Thanks!

@Arkar-Hema (Contributor, Author)

> @Arkar-Hema A general question: in what kind of models have you seen this kind of pattern: multiple Gemm ops followed by a Concat op? and also similar patterns you have recently created PRs for? Just curious on how practical it is. Thanks!

  • Models with the CombineParallelDense pattern (#3123):
    These contain multiple Gemm ops, though not always followed by a Concat. I added the Concat condition to the pass so it still handles those cases gracefully when present. Some models with this pattern include:
  1. Bertsquad-8
  2. Bertsquad-10
  3. Bertsquad-12
  4. FasterRCNN-10
  5. ResNet101-DUC-12
  6. ResNet101-DUC-7
  7. emotion-ferplus models
  8. caffenet models
  9. Densenet models
  10. googlenet models
  11. inception models
  12. rcnn-ilsvrc13 models
  13. resnet models
  14. vgg models
  15. retinanet models
  16. version-RFB-320
  17. version-RFB-640
  18. squeezenet models

@Arkar-Hema (Contributor, Author)

@tungld could you please verify this patch?

@tungld (Collaborator) commented Apr 22, 2025

@Arkar-Hema thank you for the information!

I have some general comments:

  • I think that when multiple GEMM ops are followed by a Concat, the performance should in theory be better. But could you run with multiple input sizes to see what the performance benefit is in practice?
  • When multiple GEMM ops are NOT followed by a Concat (this is the case for the models you listed), you need a Split, and I think the split axis is the innermost dimension. I am not sure how slow the Split is and whether we can still get a speedup. Could you do a performance comparison to see if you can achieve a speedup in this case? (See the sketch of the splitting step after this comment.)
  • Are you targeting this optimization at CPUs, or is it beneficial for AI accelerators as well, given that accelerators may use special data layouts that are not convenient for Concat or Split?

Thanks.
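
For context on the Split concern: splitting on the innermost axis makes each output row a contiguous slice of the combined row, so in plain C++ terms the step looks like the following sketch (hypothetical shapes and names; the pass itself would emit an ONNXSplitOp rather than explicit loops):

#include <algorithm>
#include <vector>

// Split a row-major (rows x (nA + nB)) combined result back into
// (rows x nA) and (rows x nB). On the innermost axis each piece of a row
// is a contiguous slice, so the cost is one sequential pass over the data.
static void splitLastAxis(const std::vector<float> &y, int rows, int nA,
    int nB, std::vector<float> &yA, std::vector<float> &yB) {
  yA.resize(static_cast<size_t>(rows) * nA);
  yB.resize(static_cast<size_t>(rows) * nB);
  for (int r = 0; r < rows; ++r) {
    const float *row = &y[static_cast<size_t>(r) * (nA + nB)];
    std::copy(row, row + nA, &yA[static_cast<size_t>(r) * nA]);
    std::copy(row + nA, row + nA + nB, &yB[static_cast<size_t>(r) * nB]);
  }
}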

@Arkar-Hema (Contributor, Author)

I ran performance benchmarks across a range of input sizes for both the GEMM → Concat and the Combined GEMM → Split cases. Results show that:

  • In the Concat case, the optimization provides a consistent latency improvement of 2-7% and a throughput improvement of 1-5%.

  • In the Split case, the optimization provides a consistent latency improvement of 1-7% and a throughput improvement of 1-8%.

  • I’ve currently targeted this pass at CPU backends only.

@tungld (Collaborator) left a comment


Thanks @Arkar-Hema for the experiments! Did you compile your programs with -O3?

Since this parallel fusion may not work for accelerators, could you create a compile option to enable this if needed, for example -fuse-parallel-onnx-gemm?

I don't think you need to handle the case where there is a concat after multiple gemms. Just emit a split op, then later you can write a simple canonicalization rule for concat to fuse Split -> Concat.

Below are my first-round comments; most of them are about simplifying the code and making it easier to follow. The important thing, however, is that you need to check the input C carefully because it is broadcastable.
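
The suggested Split -> Concat fusion could be a small canonicalization pattern roughly like the sketch below (the accessor names follow onnx-mlir's generated op interfaces as I understand them; the exact rule in the repository may differ):

// Sketch: fold Concat(Split(x)) -> x when the Concat consumes every result
// of one Split exactly once, in order, along the same axis.
LogicalResult matchAndRewrite(
    ONNXConcatOp concatOp, PatternRewriter &rewriter) const {
  auto splitOp = concatOp->getOperand(0).getDefiningOp<ONNXSplitOp>();
  if (!splitOp || splitOp.getAxis() != concatOp.getAxis())
    return failure();
  if (concatOp->getNumOperands() != splitOp->getNumResults())
    return failure();
  for (unsigned i = 0, e = concatOp->getNumOperands(); i < e; ++i)
    if (concatOp->getOperand(i) != splitOp->getResult(i))
      return failure();
  rewriter.replaceOp(concatOp, splitOp.getInput());
  return success();
}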


@Arkar-Hema (Contributor, Author)

> Since this parallel fusion may not work for accelerators, could you create a compile option to enable this if needed, for example -fuse-parallel-onnx-gemm?

I have added it, thanks.

Arkar-Hema added 3 commits May 2, 2025 05:00
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>

Signed-off-by: Arkar-Hema <[email protected]>

@AlexandreEichenberger (Collaborator)

@jenkins-droid test this please

// Only rank-1 bias (C) tensors are considered for merging.
auto aCShape = mlir::cast<ShapedType>(aC.getType()).getShape();
auto bCShape = mlir::cast<ShapedType>(bC.getType()).getShape();
if (aCShape.size() != 1 || bCShape.size() != 1)
  return false;
Collaborator

It seems you allow the case where aCShape is tensor<1xf32> and bCShape is tensor<5xf32> (5 is just an example to say it is not 1) and vice versa, but I don't see in the following code how you handle it. In this case, we need to broadcast aC to tensor<5xf32> before concatenating it with bC to make the ConcatOp valid.

It's up to you to support this case or not, but if you do, please add a lit test. Otherwise, check it and return false here. Thanks!

Collaborator

@Arkar-Hema please explain how you solved this comment, given that you marked it "solved"?

Contributor Author

In the current implementation, I decided not to support the case where the bias shapes are different (e.g., tensor<1xf32> and tensor<5xf32>) since our concat operation would require explicit broadcasting to align the shapes before concatenation, and handling this properly would add additional complexity.

To handle this, I’ve added a check in areCompatible() to return false if the bias shapes differ in size when both biases are present - ensuring that we only merge Gemms where both bias tensors have the same shape. This preserves the correctness of the concat operation without requiring extra broadcasting logic.

Collaborator

Thank you!

I see you checked aCShape[0] != bCShape[0]. This is only valid if the shapes are static, e.g. tensor<5xf32>, but it does not work if the shapes are dynamic, e.g. both aC and bC have a shape of tensor<?xf32>.

In the dynamic case, aCShape[0] == bCShape[0] at compile time, but at runtime aC can be tensor<1xf32> and bC can be tensor<5xf32>, for example.

Collaborator

For the static case, I wonder how you handle the following case where both aC and bC have a shape of tensor<1xf32>. For example:

  • gemm1: A: tensor<5x8x16xf32>, B: tensor<16x32xf32>, C: tensor<1xf32>
  • gemm2: A: tensor<5x8x16xf32>, B: tensor<16x32xf32>, C: tensor<1xf32>

They satisfy your conditions here, so how do you combine them?

Contributor Author

Thanks for the thorough review and the great questions!

I’ve updated the areCompatible() function to properly handle the edge cases you pointed out:

Dynamic Bias Shapes:
If either of the bias tensors has a dynamic shape at dimension 0 (i.e., tensor<?xf32>), I now conservatively return false since we can’t guarantee at compile time whether they’ll match or require broadcasting at runtime.

Both Biases as tensor<1xf32>:
If both biases are of shape tensor<1xf32>, I now check their corresponding Gemm output shapes and ensure their output channels (last dimension) match before considering them compatible. If they differ, the function returns false, as merging them without this check would be invalid.

This ensures that both static and dynamic cases are handled correctly and conservatively avoids undefined behavior at runtime.
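
In code, the checks described above might look roughly like this (a sketch against the MLIR shape API; aGemm, bGemm, aC, and bC are assumed names, and this is not necessarily the exact code in the PR):

// Hypothetical excerpt of areCompatible(); aC/bC are the bias (C) operands
// and aGemm/bGemm the two candidate Gemm ops.
auto aCType = mlir::cast<ShapedType>(aC.getType());
auto bCType = mlir::cast<ShapedType>(bC.getType());
if (aCType.getRank() != 1 || bCType.getRank() != 1)
  return false;
// A dynamic size cannot be proven equal at compile time: be conservative.
if (aCType.isDynamicDim(0) || bCType.isDynamicDim(0))
  return false;
if (aCType.getDimSize(0) != bCType.getDimSize(0))
  return false;
// A tensor<1xf32> bias broadcasts over the output channels, so two such
// biases are merged only when both Gemms produce the same output width.
if (aCType.getDimSize(0) == 1) {
  auto aOut = mlir::cast<ShapedType>(aGemm.getY().getType());
  auto bOut = mlir::cast<ShapedType>(bGemm.getY().getType());
  if (aOut.getShape().back() != bOut.getShape().back())
    return false;
}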

Signed-off-by: Arkar-Hema <[email protected]>

Signed-off-by: Arkar-Hema <[email protected]>

@AlexandreEichenberger (Collaborator)

@jenkins-droid test this please

@tungld (Collaborator) commented May 12, 2025

Hi @Arkar-Hema When addressing a comment, could you please provide a brief explanation of how you did so? This will make the review process easier. Thanks!

Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>