A pattern to simplify WhereOp #2818

Merged (1 commit) on May 9, 2024

Conversation

@tungld (Collaborator) commented May 8, 2024

This patch adds a canonicalization rule to WhereOp to simplify the following pattern (found in the model xlm-roberta-base-language-detection):

    %0 = onnx.Constant dense<-1> : tensor<2xi64>
    %1 = onnx.Constant dense<1> : tensor<2xi64>
    %2 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
    %3 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
    %4 = "onnx.Concat"(%2, %3) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
    %5 = "onnx.Equal"(%4, %0) : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1>
    %6 = "onnx.Where"(%5, %1, %4) : (tensor<2xi1>, tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64>

The condition of onnx.Where, i.e. %5, is always false: it compares dimension sizes with -1, and dimension sizes are never negative. onnx.Where can therefore be replaced by its "false" value, %4, which simplifies the pattern into:

    %0 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
    %1 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
    %2 = "onnx.Concat"(%0, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>

This simplification of WhereOp helps DimAnalysis work well with the model xlm-roberta-base-language-detection, so that all of its MatMuls can run on NNPA.
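
To make the reasoning above concrete, here is a minimal sketch of the rewrite on a toy IR in Python. It is not the onnx-mlir implementation: the classes `Dim`, `Const`, `Concat`, `Equal`, `Where` and the helper `simplify_where` are hypothetical names introduced only for illustration. The sketch also assumes that the match accepts the constant on either side of `Equal` and any all-negative constant, which may be broader than what the actual pattern checks.

```python
from dataclasses import dataclass
from typing import Tuple, Union


# Toy IR nodes (hypothetical; they only mirror the ops named in the pattern).
@dataclass(frozen=True)
class Dim:
    source: str  # name of the tensor whose dimension is read
    axis: int


@dataclass(frozen=True)
class Const:
    values: Tuple[int, ...]


@dataclass(frozen=True)
class Concat:
    inputs: Tuple["Node", ...]


@dataclass(frozen=True)
class Equal:
    lhs: "Node"
    rhs: "Node"


@dataclass(frozen=True)
class Where:
    cond: "Node"
    if_true: "Node"
    if_false: "Node"


Node = Union[Dim, Const, Concat, Equal, Where]


def is_dims(node: Node) -> bool:
    """True if node is a Concat whose inputs are all Dim ops."""
    return (
        isinstance(node, Concat)
        and len(node.inputs) > 0
        and all(isinstance(i, Dim) for i in node.inputs)
    )


def is_negative_const(node: Node) -> bool:
    """True if node is a constant whose elements are all negative."""
    return (
        isinstance(node, Const)
        and len(node.values) > 0
        and all(v < 0 for v in node.values)
    )


def always_false(cond: Node) -> bool:
    """Dimension sizes are never negative, so dims == negative is always false."""
    if not isinstance(cond, Equal):
        return False
    return (is_dims(cond.lhs) and is_negative_const(cond.rhs)) or (
        is_dims(cond.rhs) and is_negative_const(cond.lhs)
    )


def simplify_where(node: Node) -> Node:
    """Replace Where(cond, x, y) by y when cond is provably always false."""
    if isinstance(node, Where) and always_false(node.cond):
        return node.if_false
    return node


# Usage: rebuild the pattern from the description and apply the rewrite.
dims = Concat((Dim("arg0", 0), Dim("arg0", 1)))
where = Where(Equal(dims, Const((-1, -1))), Const((1, 1)), dims)
assert simplify_where(where) == dims
```

Once the `Where` is replaced by `dims`, the `Equal` and both constants have no remaining users, which corresponds to their absence from the simplified IR shown above.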

@AlexandreEichenberger (Collaborator) left a comment

LGTM, outstanding catch.

@tungld (Collaborator, Author) commented May 9, 2024

@jenkins-droid test this please

@tungld tungld merged commit 869f152 into onnx:main May 9, 2024
6 of 7 checks passed
@jenkins-droid (Collaborator) commented

Jenkins Linux s390x Build #14786 [push] A pttern to simplify Whe... started at 02:38

Jenkins Linux ppc64le Build #13811 [push] A pttern to simplify Whe... started at 02:50

Jenkins Linux amd64 Build #14781 [push] A pttern to simplify Whe... started at 01:38

Jenkins Linux amd64 Build #14781 [push] A pttern to simplify Whe... passed after 1 hr 19 min

Jenkins Linux s390x Build #14786 [push] A pttern to simplify Whe... passed after 1 hr 54 min

Jenkins Linux ppc64le Build #13811 [push] A pttern to simplify Whe... passed after 2 hr 8 min
