Skip to content

Commit 08fae46

Browse files
authored
[DAG] fold avgs(sext(x), sext(y)) -> sext(avgs(x, y)) (#95365)
Follow-up to #95134. Context: #95134 (comment).
1 parent d5297b7 commit 08fae46

File tree

3 files changed

+95
-4
lines changed

3 files changed

+95
-4
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -5237,6 +5237,7 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
52375237
DAG.getShiftAmountConstant(1, VT, DL));
52385238

52395239
// fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
5240+
// fold avgs(sext(x), sext(y)) -> sext(avgs(x, y))
52405241
if (sd_match(
52415242
N, m_BinOp(ISD::AVGFLOORU, m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
52425243
X.getValueType() == Y.getValueType() &&
@@ -5251,6 +5252,20 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
52515252
SDValue AvgCeilU = DAG.getNode(ISD::AVGCEILU, DL, X.getValueType(), X, Y);
52525253
return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, AvgCeilU);
52535254
}
5255+
if (sd_match(
5256+
N, m_BinOp(ISD::AVGFLOORS, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
5257+
X.getValueType() == Y.getValueType() &&
5258+
hasOperation(ISD::AVGFLOORS, X.getValueType())) {
5259+
SDValue AvgFloorS = DAG.getNode(ISD::AVGFLOORS, DL, X.getValueType(), X, Y);
5260+
return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgFloorS);
5261+
}
5262+
if (sd_match(
5263+
N, m_BinOp(ISD::AVGCEILS, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
5264+
X.getValueType() == Y.getValueType() &&
5265+
hasOperation(ISD::AVGCEILS, X.getValueType())) {
5266+
SDValue AvgCeilS = DAG.getNode(ISD::AVGCEILS, DL, X.getValueType(), X, Y);
5267+
return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgCeilS);
5268+
}
52545269

52555270
// Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
52565271
// Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0

llvm/test/CodeGen/AArch64/aarch64-known-bits-hadd.ll

+2-4
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,8 @@ define <8 x i16> @urhadd_sext(<8 x i8> %a0, <8 x i8> %a1) {
9595
define <8 x i16> @hadds_sext(<8 x i8> %a0, <8 x i8> %a1) {
9696
; CHECK-LABEL: hadds_sext:
9797
; CHECK: // %bb.0:
98+
; CHECK-NEXT: shadd v0.8b, v0.8b, v1.8b
9899
; CHECK-NEXT: sshll v0.8h, v0.8b, #0
99-
; CHECK-NEXT: sshll v1.8h, v1.8b, #0
100-
; CHECK-NEXT: shadd v0.8h, v0.8h, v1.8h
101100
; CHECK-NEXT: bic v0.8h, #254, lsl #8
102101
; CHECK-NEXT: ret
103102
%x0 = sext <8 x i8> %a0 to <8 x i16>
@@ -110,9 +109,8 @@ define <8 x i16> @hadds_sext(<8 x i8> %a0, <8 x i8> %a1) {
110109
define <8 x i16> @shaddu_sext(<8 x i8> %a0, <8 x i8> %a1) {
111110
; CHECK-LABEL: shaddu_sext:
112111
; CHECK: // %bb.0:
112+
; CHECK-NEXT: srhadd v0.8b, v0.8b, v1.8b
113113
; CHECK-NEXT: sshll v0.8h, v0.8b, #0
114-
; CHECK-NEXT: sshll v1.8h, v1.8b, #0
115-
; CHECK-NEXT: srhadd v0.8h, v0.8h, v1.8h
116114
; CHECK-NEXT: bic v0.8h, #254, lsl #8
117115
; CHECK-NEXT: ret
118116
%x0 = sext <8 x i8> %a0 to <8 x i16>

llvm/test/CodeGen/AArch64/avg.ll

+78
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,81 @@ define <16 x i16> @zext_avgceilu_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
6868
%avg = sub <16 x i16> %or, %shift
6969
ret <16 x i16> %avg
7070
}
71+
72+
define <16 x i16> @sext_avgfloors(<16 x i8> %a0, <16 x i8> %a1) {
73+
; CHECK-LABEL: sext_avgfloors:
74+
; CHECK: // %bb.0:
75+
; CHECK-NEXT: shadd v0.16b, v0.16b, v1.16b
76+
; CHECK-NEXT: sshll2 v1.8h, v0.16b, #0
77+
; CHECK-NEXT: sshll v0.8h, v0.8b, #0
78+
; CHECK-NEXT: ret
79+
%x0 = sext <16 x i8> %a0 to <16 x i16>
80+
%x1 = sext <16 x i8> %a1 to <16 x i16>
81+
%and = and <16 x i16> %x0, %x1
82+
%xor = xor <16 x i16> %x0, %x1
83+
%shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
84+
%avg = add <16 x i16> %and, %shift
85+
ret <16 x i16> %avg
86+
}
87+
88+
define <16 x i16> @sext_avgfloors_mismatch(<16 x i8> %a0, <16 x i4> %a1) {
89+
; CHECK-LABEL: sext_avgfloors_mismatch:
90+
; CHECK: // %bb.0:
91+
; CHECK-NEXT: ushll2 v2.8h, v1.16b, #0
92+
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
93+
; CHECK-NEXT: sshll v3.8h, v0.8b, #0
94+
; CHECK-NEXT: sshll2 v0.8h, v0.16b, #0
95+
; CHECK-NEXT: shl v1.8h, v1.8h, #12
96+
; CHECK-NEXT: shl v2.8h, v2.8h, #12
97+
; CHECK-NEXT: sshr v4.8h, v1.8h, #12
98+
; CHECK-NEXT: sshr v1.8h, v2.8h, #12
99+
; CHECK-NEXT: shadd v1.8h, v0.8h, v1.8h
100+
; CHECK-NEXT: shadd v0.8h, v3.8h, v4.8h
101+
; CHECK-NEXT: ret
102+
%x0 = sext <16 x i8> %a0 to <16 x i16>
103+
%x1 = sext <16 x i4> %a1 to <16 x i16>
104+
%and = and <16 x i16> %x0, %x1
105+
%xor = xor <16 x i16> %x0, %x1
106+
%shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
107+
%avg = add <16 x i16> %and, %shift
108+
ret <16 x i16> %avg
109+
}
110+
111+
define <16 x i16> @sext_avgceils(<16 x i8> %a0, <16 x i8> %a1) {
112+
; CHECK-LABEL: sext_avgceils:
113+
; CHECK: // %bb.0:
114+
; CHECK-NEXT: srhadd v0.16b, v0.16b, v1.16b
115+
; CHECK-NEXT: sshll2 v1.8h, v0.16b, #0
116+
; CHECK-NEXT: sshll v0.8h, v0.8b, #0
117+
; CHECK-NEXT: ret
118+
%x0 = sext <16 x i8> %a0 to <16 x i16>
119+
%x1 = sext <16 x i8> %a1 to <16 x i16>
120+
%or = or <16 x i16> %x0, %x1
121+
%xor = xor <16 x i16> %x0, %x1
122+
%shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
123+
%avg = sub <16 x i16> %or, %shift
124+
ret <16 x i16> %avg
125+
}
126+
127+
define <16 x i16> @sext_avgceils_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
128+
; CHECK-LABEL: sext_avgceils_mismatch:
129+
; CHECK: // %bb.0:
130+
; CHECK-NEXT: ushll v2.8h, v0.8b, #0
131+
; CHECK-NEXT: ushll2 v0.8h, v0.16b, #0
132+
; CHECK-NEXT: sshll v3.8h, v1.8b, #0
133+
; CHECK-NEXT: sshll2 v1.8h, v1.16b, #0
134+
; CHECK-NEXT: shl v2.8h, v2.8h, #12
135+
; CHECK-NEXT: shl v0.8h, v0.8h, #12
136+
; CHECK-NEXT: sshr v2.8h, v2.8h, #12
137+
; CHECK-NEXT: sshr v0.8h, v0.8h, #12
138+
; CHECK-NEXT: srhadd v1.8h, v0.8h, v1.8h
139+
; CHECK-NEXT: srhadd v0.8h, v2.8h, v3.8h
140+
; CHECK-NEXT: ret
141+
%x0 = sext <16 x i4> %a0 to <16 x i16>
142+
%x1 = sext <16 x i8> %a1 to <16 x i16>
143+
%or = or <16 x i16> %x0, %x1
144+
%xor = xor <16 x i16> %x0, %x1
145+
%shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
146+
%avg = sub <16 x i16> %or, %shift
147+
ret <16 x i16> %avg
148+
}

0 commit comments

Comments
 (0)