Skip to content

Commit 7811330

Browse files
authored
[mlir][ArmSVE] Add arm_sve.psel operation (llvm#95764)
This adds a new operation for the SME/SVE2.1 psel instruction. It selects a predicate based on a bit within another predicate, essentially allowing for 2-D predication. Informally, the semantics are:

```mlir
%pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
```

=>

```
if p2[index % num_elements(p2)] == 1:
  pd = p1 : type(p1)
else:
  pd = all-false : type(p1)
```
1 parent 6244d87 commit 7811330

File tree

6 files changed

+166
-3
lines changed

6 files changed

+166
-3
lines changed

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

Lines changed: 53 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,10 +37,16 @@ def ArmSVE_Dialect : Dialect {
3737
//===----------------------------------------------------------------------===//
3838

3939
// Type constraint for the SVE `svbool` predicate type, vector<[16]xi1>.
// This hunk shows both the old form (row `40-`, no body) and the new form,
// which adds a body whose `summary` names the accepted type.
def SVBool : ScalableVectorOfRankAndLengthAndType<
40-
[1], [16], [I1]>;
40+
[1], [16], [I1]>
41+
{
42+
let summary = "vector<[16]xi1>";
43+
}
4144

4245
// Type constraint for any SVE predicate: an i1 scalable vector whose length
// is one of 16, 8, 4, 2 or 1. The `summary` added here is the text that
// operand-type diagnostics print for this constraint (the psel negative test
// in invalid.mlir matches it verbatim).
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
43-
[1], [16, 8, 4, 2, 1], [I1]>;
46+
[1], [16, 8, 4, 2, 1], [I1]>
47+
{
48+
let summary = "vector<[1]xi1>, vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, or vector<[16]xi1>";
49+
}
4450

4551
// Generalizations of SVBool and SVEPredicate to ranks >= 1.
4652
// These are masks with a single trailing scalable dimension.
@@ -442,6 +448,43 @@ def ZipX4Op : ArmSVE_Op<"zip.x4", [
442448
}];
443449
}
444450

451+
// `arm_sve.psel`: predicate select. Yields `p1` when the bit of `p2` selected
// by `index` is set, and an all-false predicate of `p1`'s type otherwise.
// `AllTypesMatch<["p1", "result"]>` ties the result type to `p1`, so the
// assembly format below only spells out the types of `p1` and `p2`.
def PselOp : ArmSVE_Op<"psel", [
452+
Pure,
453+
AllTypesMatch<["p1", "result"]>,
454+
]> {
455+
let summary = "Predicate select";
456+
457+
let description = [{
458+
This operation returns the input predicate `p1` or an all-false predicate
459+
based on the bit at `p2[index]`. Informally, the semantics are:
460+
```
461+
if p2[index % num_elements(p2)] == 1:
462+
return p1 : type(p1)
463+
return all-false : type(p1)
464+
```
465+
466+
Example:
467+
```mlir
468+
// Note: p1 and p2 can have different sizes.
469+
%pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
470+
```
471+
472+
Note: This requires SME or SVE2.1 (`+sme` or `+sve2p1` in LLVM target features).
473+
}];
474+
475+
// `p1` and `p2` are constrained independently (each may be any SVE predicate
// type); only `p1` and the result are required to have the same type.
let arguments = (ins SVEPredicate:$p1, SVEPredicate:$p2, Index:$index);
476+
let results = (outs SVEPredicate:$result);
477+
478+
// Convenience builder that takes no explicit result type: it forwards
// `p1.getType()` as the result type.
let builders = [
479+
OpBuilder<(ins "Value":$p1, "Value":$p2, "Value":$index), [{
480+
build($_builder, $_state, p1.getType(), p1, p2, index);
481+
}]>];
482+
483+
// Printed form: `%p1, %p2[%index] : type(p1), type(p2)`.
let assemblyFormat = [{
484+
$p1 `,` $p2 `[` $index `]` attr-dict `:` type($p1) `,` type($p2)
485+
}];
486+
}
487+
445488
// Masked integer addition op (`masked.addi`), declared `Commutative`.
def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
446489
[Commutative]>;
447490

@@ -552,6 +595,14 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
552595
Arg<AnyScalableVector, "v3">:$v3,
553596
Arg<AnyScalableVector, "v3">:$v4)>;
554597

598+
// Intrinsic-level form of psel.
// - `p1` and the result are always svbool (`SVBool`, vector<[16]xi1>).
// - `p2` may be any SVE predicate type; it is operand #1, the one listed in
//   `overloadedOperands`.
// - `index` is an i32.
// Note: This intrinsic requires SME or SVE2.1.
599+
def PselIntrOp : ArmSVE_IntrOp<"psel",
600+
/*traits=*/[Pure, TypeIs<"res", SVBool>],
601+
/*overloadedOperands=*/[1]>,
602+
Arguments<(ins Arg<SVBool, "p1">:$p1,
603+
Arg<SVEPredicate, "p2">:$p2,
604+
Arg<I32, "index">:$index)>;
605+
555606
def WhileLTIntrOp :
556607
ArmSVE_IntrOp<"whilelt",
557608
[TypeIs<"res", SVEPredicate>, Pure],

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 25 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,28 @@ using ConvertFromSvboolOpLowering =
140140
/// `ZipX2Op` and `ZipX4Op` are lowered with the generic one-to-one pattern,
/// each to its matching intrinsic op (`ZipX2IntrOp` / `ZipX4IntrOp`).
using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
141141
using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
142142

143+
/// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion
144+
/// but first input (P1) and result predicates need conversion to/from svbool.
///
/// Emitted sequence:
///   1. convert P1 to svbool (vector<[16]xi1>),
///   2. cast the index to i32,
///   3. create the psel intrinsic op on (svbool P1, P2, i32 index),
///   4. convert the svbool result back to P1's type, replacing the original op.
/// The pattern always succeeds.
145+
struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
146+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
147+
148+
LogicalResult
149+
matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
150+
ConversionPatternRewriter &rewriter) const override {
151+
// Scalable vector<[16]xi1>: the type the psel intrinsic op uses for P1 and
// for its result.
auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
152+
auto loc = pselOp.getLoc();
153+
auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
154+
adaptor.getP1());
155+
// The intrinsic op takes an i32 index. The cast reads the op's own operand,
// not the adaptor's. The legalization test expects it to end up as
// `llvm.trunc ... : i64 to i32`.
// NOTE(review): presumably the un-converted value is used because
// arith.index_cast needs an `index`-typed operand -- confirm.
auto indexI32 = rewriter.create<arith::IndexCastOp>(
156+
loc, rewriter.getI32Type(), pselOp.getIndex());
157+
// P2 is passed through without an svbool conversion.
// NOTE(review): P2 is also read from the op rather than the adaptor; this
// assumes the type converter leaves P2's predicate type unchanged -- confirm.
auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
158+
pselOp.getP2(), indexI32);
159+
// Convert the svbool result back to the (converted) type of P1.
rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
160+
pselOp, adaptor.getP1().getType(), pselIntr);
161+
return success();
162+
}
163+
};
164+
143165
/// Converts `vector.create_mask` ops that match the size of an SVE predicate
144166
/// to the `whilelt` intrinsic. This produces more canonical codegen than the
145167
/// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
@@ -202,7 +224,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
202224
ConvertToSvboolOpLowering,
203225
ConvertFromSvboolOpLowering,
204226
ZipX2OpLowering,
205-
ZipX4OpLowering>(converter);
227+
ZipX4OpLowering,
228+
PselOpLowering>(converter);
206229
// Add vector.create_mask conversion with a high benefit as it produces much
207230
// nicer code than the generic lowering.
208231
patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
@@ -229,6 +252,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
229252
ConvertFromSvboolIntrOp,
230253
ZipX2IntrOp,
231254
ZipX4IntrOp,
255+
PselIntrOp,
232256
WhileLTIntrOp>();
233257
target.addIllegalOp<SdotOp,
234258
SmmlaOp,

mlir/test/Dialect/ArmSVE/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,3 +64,11 @@ func.func @arm_sve_zip_x4_bad_vector_type(%a : vector<[5]xf64>) {
6464
arm_sve.zip.x4 %a, %a, %a, %a : vector<[5]xf64>
6565
return
6666
}
67+
68+
// -----
69+
70+
// Negative test: vector<[7]xi1> is not one of the SVE predicate types, so
// arm_sve.psel must be rejected with a diagnostic on operand #0.
func.func @arm_sve_psel_bad_vector_type(%a : vector<[7]xi1>, %index: index) {
71+
// expected-error@+1 {{op operand #0 must be vector<[1]xi1>, vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, or vector<[16]xi1>, but got 'vector<[7]xi1>'}}
72+
arm_sve.psel %a, %a[%index] : vector<[7]xi1>, vector<[7]xi1>
73+
return
74+
}

mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -239,3 +239,35 @@ func.func @arm_sve_unsupported_create_masks(%index: index) -> (vector<[1]xi1>, v
239239
%2 = vector.create_mask %index : vector<[32]xi1>
240240
return %0, %1, %2 : vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>
241241
}
242+
243+
// -----
244+
245+
// Legalization of arm_sve.psel when both predicates are vector<[4]xi1>:
// the first predicate is widened to svbool, the index is truncated to i32,
// the psel intrinsic op is emitted, and its result is narrowed back to
// vector<[4]xi1>.
// NOTE(review): RES is captured below but no later check line uses it.
// CHECK-LABEL: @arm_sve_psel_matching_predicate_types(
246+
// CHECK-SAME: %[[P0:[a-z0-9]+]]: vector<[4]xi1>,
247+
// CHECK-SAME: %[[P1:[a-z0-9]+]]: vector<[4]xi1>,
248+
// CHECK-SAME: %[[INDEX:[a-z0-9]+]]: i64
249+
func.func @arm_sve_psel_matching_predicate_types(%p0: vector<[4]xi1>, %p1: vector<[4]xi1>, %index: index) -> vector<[4]xi1>
250+
{
251+
// CHECK-DAG: %[[INDEX_I32:.*]] = llvm.trunc %[[INDEX]] : i64 to i32
252+
// CHECK-DAG: %[[P0_IN:.*]] = "arm_sve.intr.convert.to.svbool"(%[[P0]]) : (vector<[4]xi1>) -> vector<[16]xi1>
253+
// CHECK-NEXT: %[[PSEL:.*]] = "arm_sve.intr.psel"(%[[P0_IN]], %[[P1]], %[[INDEX_I32]]) : (vector<[16]xi1>, vector<[4]xi1>, i32) -> vector<[16]xi1>
254+
// CHECK-NEXT: %[[RES:.*]] = "arm_sve.intr.convert.from.svbool"(%[[PSEL]]) : (vector<[16]xi1>) -> vector<[4]xi1>
255+
%0 = arm_sve.psel %p0, %p1[%index] : vector<[4]xi1>, vector<[4]xi1>
256+
return %0 : vector<[4]xi1>
257+
}
258+
259+
// -----
260+
261+
// Legalization of arm_sve.psel with differing predicate types
// (vector<[8]xi1> selected by a bit of vector<[16]xi1>): only the first
// predicate and the result go through the svbool conversions; the second
// predicate is passed to the intrinsic op with its original type.
// NOTE(review): RES is captured below but no later check line uses it.
// CHECK-LABEL: @arm_sve_psel_mixed_predicate_types(
262+
// CHECK-SAME: %[[P0:[a-z0-9]+]]: vector<[8]xi1>,
263+
// CHECK-SAME: %[[P1:[a-z0-9]+]]: vector<[16]xi1>,
264+
// CHECK-SAME: %[[INDEX:[a-z0-9]+]]: i64
265+
func.func @arm_sve_psel_mixed_predicate_types(%p0: vector<[8]xi1>, %p1: vector<[16]xi1>, %index: index) -> vector<[8]xi1>
266+
{
267+
// CHECK-DAG: %[[INDEX_I32:.*]] = llvm.trunc %[[INDEX]] : i64 to i32
268+
// CHECK-DAG: %[[P0_IN:.*]] = "arm_sve.intr.convert.to.svbool"(%[[P0]]) : (vector<[8]xi1>) -> vector<[16]xi1>
269+
// CHECK-NEXT: %[[PSEL:.*]] = "arm_sve.intr.psel"(%[[P0_IN]], %[[P1]], %[[INDEX_I32]]) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
270+
// CHECK-NEXT: %[[RES:.*]] = "arm_sve.intr.convert.from.svbool"(%[[PSEL]]) : (vector<[16]xi1>) -> vector<[8]xi1>
271+
%0 = arm_sve.psel %p0, %p1[%index] : vector<[8]xi1>, vector<[16]xi1>
272+
return %0 : vector<[8]xi1>
273+
}

mlir/test/Dialect/ArmSVE/roundtrip.mlir

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -225,3 +225,32 @@ func.func @arm_sve_zip_x4(
225225
%a8, %b8, %c8, %d8 = arm_sve.zip.x4 %v8, %v8, %v8, %v8 : vector<[16]xi8>
226226
return
227227
}
228+
229+
// -----
230+
231+
// Round-trip coverage for the arm_sve.psel assembly format: first with the
// same predicate type for both operands, then with differing types
// (including a wider first operand, vector<[16]xi1> with vector<[2]xi1>).
// The results are intentionally unused.
// NOTE(review): vector<[1]xi1> is accepted by SVEPredicate but is not
// exercised here.
func.func @arm_sve_psel(
232+
%p0: vector<[2]xi1>,
233+
%p1: vector<[4]xi1>,
234+
%p2: vector<[8]xi1>,
235+
%p3: vector<[16]xi1>,
236+
%index: index
237+
) {
238+
// CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[2]xi1>, vector<[2]xi1>
239+
%0 = arm_sve.psel %p0, %p0[%index] : vector<[2]xi1>, vector<[2]xi1>
240+
// CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[4]xi1>, vector<[4]xi1>
241+
%1 = arm_sve.psel %p1, %p1[%index] : vector<[4]xi1>, vector<[4]xi1>
242+
// CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[8]xi1>, vector<[8]xi1>
243+
%2 = arm_sve.psel %p2, %p2[%index] : vector<[8]xi1>, vector<[8]xi1>
244+
// CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[16]xi1>, vector<[16]xi1>
245+
%3 = arm_sve.psel %p3, %p3[%index] : vector<[16]xi1>, vector<[16]xi1>
246+
/// Some mixed predicate type examples:
247+
// CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[2]xi1>, vector<[4]xi1>
248+
%4 = arm_sve.psel %p0, %p1[%index] : vector<[2]xi1>, vector<[4]xi1>
249+
// CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[4]xi1>, vector<[8]xi1>
250+
%5 = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
251+
// CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[8]xi1>, vector<[16]xi1>
252+
%6 = arm_sve.psel %p2, %p3[%index] : vector<[8]xi1>, vector<[16]xi1>
253+
// CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[16]xi1>, vector<[2]xi1>
254+
%7 = arm_sve.psel %p3, %p0[%index] : vector<[16]xi1>, vector<[2]xi1>
255+
return
256+
}

mlir/test/Target/LLVMIR/arm-sve.mlir

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -371,3 +371,22 @@ llvm.func @arm_sve_whilelt(%base: i64, %n: i64) {
371371
%4 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[16]xi1>
372372
llvm.return
373373
}
374+
375+
// Translation of the psel intrinsic op to LLVM IR. The first operand and the
// result are always <vscale x 16 x i1>; the called intrinsic's name suffix
// (nxv2i1, nxv4i1, nxv8i1, nxv16i1) follows the type of the second operand.
// CHECK-LABEL: arm_sve_psel(
376+
// CHECK-SAME: <vscale x 16 x i1> %[[PN:[0-9]+]],
377+
// CHECK-SAME: <vscale x 2 x i1> %[[P1:[0-9]+]],
378+
// CHECK-SAME: <vscale x 4 x i1> %[[P2:[0-9]+]],
379+
// CHECK-SAME: <vscale x 8 x i1> %[[P3:[0-9]+]],
380+
// CHECK-SAME: <vscale x 16 x i1> %[[P4:[0-9]+]],
381+
// CHECK-SAME: i32 %[[INDEX:[0-9]+]])
382+
llvm.func @arm_sve_psel(%pn: vector<[16]xi1>, %p1: vector<[2]xi1>, %p2: vector<[4]xi1>, %p3: vector<[8]xi1>, %p4: vector<[16]xi1>, %index: i32) {
383+
// CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv2i1(<vscale x 16 x i1> %[[PN]], <vscale x 2 x i1> %[[P1]], i32 %[[INDEX]])
384+
"arm_sve.intr.psel"(%pn, %p1, %index) : (vector<[16]xi1>, vector<[2]xi1>, i32) -> vector<[16]xi1>
385+
// CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv4i1(<vscale x 16 x i1> %[[PN]], <vscale x 4 x i1> %[[P2]], i32 %[[INDEX]])
386+
"arm_sve.intr.psel"(%pn, %p2, %index) : (vector<[16]xi1>, vector<[4]xi1>, i32) -> vector<[16]xi1>
387+
// CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv8i1(<vscale x 16 x i1> %[[PN]], <vscale x 8 x i1> %[[P3]], i32 %[[INDEX]])
388+
"arm_sve.intr.psel"(%pn, %p3, %index) : (vector<[16]xi1>, vector<[8]xi1>, i32) -> vector<[16]xi1>
389+
// CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv16i1(<vscale x 16 x i1> %[[PN]], <vscale x 16 x i1> %[[P4]], i32 %[[INDEX]])
390+
"arm_sve.intr.psel"(%pn, %p4, %index) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
391+
llvm.return
392+
}

0 commit comments

Comments
 (0)