Skip to content

Commit 280b326

Browse files
authored
Unrolled build for rust-lang#137953
Rollup merge of rust-lang#137953 - RalfJung:simd-intrinsic-masks, r=WaffleLapkin simd intrinsics with mask: accept unsigned integer masks, and fix some of the errors It's not clear at all why the mask would have to be signed, it is anyway interpreted bitwise. The backend should just make sure that works no matter the surface-level type; our LLVM backend already does this correctly. The note of "the mask may be widened, which only has the correct behavior for signed integers" explains... nothing? Why can't the code do the widening correctly? If necessary, just cast to the signed type first... Also while we are at it, fix the errors. For simd_masked_load/store, the errors talked about the "third argument" but they meant the first argument (the mask is the first argument there). They also used the wrong type for `expected_element`. I have extremely low confidence in the GCC part of this PR. See [discussion on Zulip](https://rust-lang.zulipchat.com/#narrow/channel/257879-project-portable-simd/topic/On.20the.20sign.20of.20masks)
2 parents b8c54d6 + 566dfd1 commit 280b326

18 files changed

+119
-148
lines changed

compiler/rustc_codegen_gcc/src/intrinsic/simd.rs

+11-14
Original file line numberDiff line numberDiff line change
@@ -447,9 +447,14 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
447447
m_len == v_len,
448448
InvalidMonomorphization::MismatchedLengths { span, name, m_len, v_len }
449449
);
450+
// TODO: also support unsigned integers.
450451
match *m_elem_ty.kind() {
451452
ty::Int(_) => {}
452-
_ => return_error!(InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }),
453+
_ => return_error!(InvalidMonomorphization::MaskWrongElementType {
454+
span,
455+
name,
456+
ty: m_elem_ty
457+
}),
453458
}
454459
return Ok(bx.vector_select(args[0].immediate(), args[1].immediate(), args[2].immediate()));
455460
}
@@ -991,19 +996,15 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
991996
assert_eq!(pointer_count - 1, ptr_count(element_ty0));
992997
assert_eq!(underlying_ty, non_ptr(element_ty0));
993998

994-
// The element type of the third argument must be a signed integer type of any width:
999+
// The element type of the third argument must be an integer type of any width:
1000+
// TODO: also support unsigned integers.
9951001
let (_, element_ty2) = arg_tys[2].simd_size_and_type(bx.tcx());
9961002
match *element_ty2.kind() {
9971003
ty::Int(_) => (),
9981004
_ => {
9991005
require!(
10001006
false,
1001-
InvalidMonomorphization::ThirdArgElementType {
1002-
span,
1003-
name,
1004-
expected_element: element_ty2,
1005-
third_arg: arg_tys[2]
1006-
}
1007+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
10071008
);
10081009
}
10091010
}
@@ -1109,17 +1110,13 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
11091110
assert_eq!(underlying_ty, non_ptr(element_ty0));
11101111

11111112
// The element type of the third argument must be a signed integer type of any width:
1113+
// TODO: also support unsigned integers.
11121114
match *element_ty2.kind() {
11131115
ty::Int(_) => (),
11141116
_ => {
11151117
require!(
11161118
false,
1117-
InvalidMonomorphization::ThirdArgElementType {
1118-
span,
1119-
name,
1120-
expected_element: element_ty2,
1121-
third_arg: arg_tys[2]
1122-
}
1119+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
11231120
);
11241121
}
11251122
}

compiler/rustc_codegen_llvm/src/intrinsic.rs

+12-44
Original file line numberDiff line numberDiff line change
@@ -1184,18 +1184,6 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
11841184
}};
11851185
}
11861186

1187-
/// Returns the bitwidth of the `$ty` argument if it is an `Int` type.
1188-
macro_rules! require_int_ty {
1189-
($ty: expr, $diag: expr) => {
1190-
match $ty {
1191-
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
1192-
_ => {
1193-
return_error!($diag);
1194-
}
1195-
}
1196-
};
1197-
}
1198-
11991187
/// Returns the bitwidth of the `$ty` argument if it is an `Int` or `Uint` type.
12001188
macro_rules! require_int_or_uint_ty {
12011189
($ty: expr, $diag: expr) => {
@@ -1485,9 +1473,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
14851473
m_len == v_len,
14861474
InvalidMonomorphization::MismatchedLengths { span, name, m_len, v_len }
14871475
);
1488-
let in_elem_bitwidth = require_int_ty!(
1476+
let in_elem_bitwidth = require_int_or_uint_ty!(
14891477
m_elem_ty.kind(),
1490-
InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }
1478+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: m_elem_ty }
14911479
);
14921480
let m_i1s = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, m_len);
14931481
return Ok(bx.select(m_i1s, args[1].immediate(), args[2].immediate()));
@@ -1508,7 +1496,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
15081496
// Integer vector <i{in_bitwidth} x in_len>:
15091497
let in_elem_bitwidth = require_int_or_uint_ty!(
15101498
in_elem.kind(),
1511-
InvalidMonomorphization::VectorArgument { span, name, in_ty, in_elem }
1499+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: in_elem }
15121500
);
15131501

15141502
let i1xn = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, in_len);
@@ -1732,14 +1720,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
17321720
}
17331721
);
17341722

1735-
let mask_elem_bitwidth = require_int_ty!(
1723+
let mask_elem_bitwidth = require_int_or_uint_ty!(
17361724
element_ty2.kind(),
1737-
InvalidMonomorphization::ThirdArgElementType {
1738-
span,
1739-
name,
1740-
expected_element: element_ty2,
1741-
third_arg: arg_tys[2]
1742-
}
1725+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
17431726
);
17441727

17451728
// Alignment of T, must be a constant integer value:
@@ -1834,14 +1817,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
18341817
}
18351818
);
18361819

1837-
let m_elem_bitwidth = require_int_ty!(
1820+
let m_elem_bitwidth = require_int_or_uint_ty!(
18381821
mask_elem.kind(),
1839-
InvalidMonomorphization::ThirdArgElementType {
1840-
span,
1841-
name,
1842-
expected_element: values_elem,
1843-
third_arg: mask_ty,
1844-
}
1822+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: mask_elem }
18451823
);
18461824

18471825
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
@@ -1924,14 +1902,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19241902
}
19251903
);
19261904

1927-
let m_elem_bitwidth = require_int_ty!(
1905+
let m_elem_bitwidth = require_int_or_uint_ty!(
19281906
mask_elem.kind(),
1929-
InvalidMonomorphization::ThirdArgElementType {
1930-
span,
1931-
name,
1932-
expected_element: values_elem,
1933-
third_arg: mask_ty,
1934-
}
1907+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: mask_elem }
19351908
);
19361909

19371910
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
@@ -2019,15 +1992,10 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
20191992
}
20201993
);
20211994

2022-
// The element type of the third argument must be a signed integer type of any width:
2023-
let mask_elem_bitwidth = require_int_ty!(
1995+
// The element type of the third argument must be an integer type of any width:
1996+
let mask_elem_bitwidth = require_int_or_uint_ty!(
20241997
element_ty2.kind(),
2025-
InvalidMonomorphization::ThirdArgElementType {
2026-
span,
2027-
name,
2028-
expected_element: element_ty2,
2029-
third_arg: arg_tys[2]
2030-
}
1998+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
20311999
);
20322000

20332001
// Alignment of T, must be a constant integer value:

compiler/rustc_codegen_ssa/messages.ftl

+1-6
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,7 @@ codegen_ssa_invalid_monomorphization_inserted_type = invalid monomorphization of
125125
126126
codegen_ssa_invalid_monomorphization_invalid_bitmask = invalid monomorphization of `{$name}` intrinsic: invalid bitmask `{$mask_ty}`, expected `u{$expected_int_bits}` or `[u8; {$expected_bytes}]`
127127
128-
codegen_ssa_invalid_monomorphization_mask_type = invalid monomorphization of `{$name}` intrinsic: found mask element type is `{$ty}`, expected a signed integer type
129-
.note = the mask may be widened, which only has the correct behavior for signed integers
128+
codegen_ssa_invalid_monomorphization_mask_wrong_element_type = invalid monomorphization of `{$name}` intrinsic: expected mask element type to be an integer, found `{$ty}`
130129
131130
codegen_ssa_invalid_monomorphization_mismatched_lengths = invalid monomorphization of `{$name}` intrinsic: mismatched lengths: mask length `{$m_len}` != other vector length `{$v_len}`
132131
@@ -158,8 +157,6 @@ codegen_ssa_invalid_monomorphization_simd_shuffle = invalid monomorphization of
158157
159158
codegen_ssa_invalid_monomorphization_simd_third = invalid monomorphization of `{$name}` intrinsic: expected SIMD third type, found non-SIMD `{$ty}`
160159
161-
codegen_ssa_invalid_monomorphization_third_arg_element_type = invalid monomorphization of `{$name}` intrinsic: expected element type `{$expected_element}` of third argument `{$third_arg}` to be a signed integer type
162-
163160
codegen_ssa_invalid_monomorphization_third_argument_length = invalid monomorphization of `{$name}` intrinsic: expected third argument with length {$in_len} (same as input type `{$in_ty}`), found `{$arg_ty}` with length {$out_len}
164161
165162
codegen_ssa_invalid_monomorphization_unrecognized_intrinsic = invalid monomorphization of `{$name}` intrinsic: unrecognized intrinsic `{$name}`
@@ -172,8 +169,6 @@ codegen_ssa_invalid_monomorphization_unsupported_symbol = invalid monomorphizati
172169
173170
codegen_ssa_invalid_monomorphization_unsupported_symbol_of_size = invalid monomorphization of `{$name}` intrinsic: unsupported {$symbol} from `{$in_ty}` with element `{$in_elem}` of size `{$size}` to `{$ret_ty}`
174171
175-
codegen_ssa_invalid_monomorphization_vector_argument = invalid monomorphization of `{$name}` intrinsic: vector argument `{$in_ty}`'s element type `{$in_elem}`, expected integer element type
176-
177172
codegen_ssa_invalid_no_sanitize = invalid argument for `no_sanitize`
178173
.note = expected one of: `address`, `cfi`, `hwaddress`, `kcfi`, `memory`, `memtag`, `shadow-call-stack`, or `thread`
179174

compiler/rustc_codegen_ssa/src/errors.rs

+2-21
Original file line numberDiff line numberDiff line change
@@ -1037,24 +1037,14 @@ pub enum InvalidMonomorphization<'tcx> {
10371037
v_len: u64,
10381038
},
10391039

1040-
#[diag(codegen_ssa_invalid_monomorphization_mask_type, code = E0511)]
1041-
#[note]
1042-
MaskType {
1040+
#[diag(codegen_ssa_invalid_monomorphization_mask_wrong_element_type, code = E0511)]
1041+
MaskWrongElementType {
10431042
#[primary_span]
10441043
span: Span,
10451044
name: Symbol,
10461045
ty: Ty<'tcx>,
10471046
},
10481047

1049-
#[diag(codegen_ssa_invalid_monomorphization_vector_argument, code = E0511)]
1050-
VectorArgument {
1051-
#[primary_span]
1052-
span: Span,
1053-
name: Symbol,
1054-
in_ty: Ty<'tcx>,
1055-
in_elem: Ty<'tcx>,
1056-
},
1057-
10581048
#[diag(codegen_ssa_invalid_monomorphization_cannot_return, code = E0511)]
10591049
CannotReturn {
10601050
#[primary_span]
@@ -1077,15 +1067,6 @@ pub enum InvalidMonomorphization<'tcx> {
10771067
mutability: ExpectedPointerMutability,
10781068
},
10791069

1080-
#[diag(codegen_ssa_invalid_monomorphization_third_arg_element_type, code = E0511)]
1081-
ThirdArgElementType {
1082-
#[primary_span]
1083-
span: Span,
1084-
name: Symbol,
1085-
expected_element: Ty<'tcx>,
1086-
third_arg: Ty<'tcx>,
1087-
},
1088-
10891070
#[diag(codegen_ssa_invalid_monomorphization_unsupported_symbol_of_size, code = E0511)]
10901071
UnsupportedSymbolOfSize {
10911072
#[primary_span]

library/core/src/intrinsics/simd.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ pub unsafe fn simd_shuffle<T, U, V>(x: T, y: T, idx: U) -> V;
304304
///
305305
/// `U` must be a vector of pointers to the element type of `T`, with the same length as `T`.
306306
///
307-
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
307+
/// `V` must be a vector of integers with the same length as `T` (but any element size).
308308
///
309309
/// For each pointer in `ptr`, if the corresponding value in `mask` is `!0`, read the pointer.
310310
/// Otherwise if the corresponding value in `mask` is `0`, return the corresponding value from
@@ -325,7 +325,7 @@ pub unsafe fn simd_gather<T, U, V>(val: T, ptr: U, mask: V) -> T;
325325
///
326326
/// `U` must be a vector of pointers to the element type of `T`, with the same length as `T`.
327327
///
328-
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
328+
/// `V` must be a vector of integers with the same length as `T` (but any element size).
329329
///
330330
/// For each pointer in `ptr`, if the corresponding value in `mask` is `!0`, write the
331331
/// corresponding value in `val` to the pointer.
@@ -349,7 +349,7 @@ pub unsafe fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);
349349
///
350350
/// `U` must be a pointer to the element type of `T`
351351
///
352-
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
352+
/// `V` must be a vector of integers with the same length as `T` (but any element size).
353353
///
354354
/// For each element, if the corresponding value in `mask` is `!0`, read the corresponding
355355
/// pointer offset from `ptr`.
@@ -372,7 +372,7 @@ pub unsafe fn simd_masked_load<V, U, T>(mask: V, ptr: U, val: T) -> T;
372372
///
373373
/// `U` must be a pointer to the element type of `T`
374374
///
375-
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
375+
/// `V` must be a vector of integers with the same length as `T` (but any element size).
376376
///
377377
/// For each element, if the corresponding value in `mask` is `!0`, write the corresponding
378378
/// value in `val` to the pointer offset from `ptr`.
@@ -556,7 +556,7 @@ pub unsafe fn simd_bitmask<T, U>(x: T) -> U;
556556
///
557557
/// `T` must be a vector.
558558
///
559-
/// `M` must be a signed integer vector with the same length as `T` (but any element size).
559+
/// `M` must be an integer vector with the same length as `T` (but any element size).
560560
///
561561
/// For each element, if the corresponding value in `mask` is `!0`, select the element from
562562
/// `if_true`. If the corresponding value in `mask` is `0`, select the element from

src/tools/miri/src/helpers.rs

+5
Original file line numberDiff line numberDiff line change
@@ -1382,6 +1382,11 @@ pub(crate) fn bool_to_simd_element(b: bool, size: Size) -> Scalar {
13821382
}
13831383

13841384
pub(crate) fn simd_element_to_bool(elem: ImmTy<'_>) -> InterpResult<'_, bool> {
1385+
assert!(
1386+
matches!(elem.layout.ty.kind(), ty::Int(_) | ty::Uint(_)),
1387+
"SIMD mask element type must be an integer, but this is `{}`",
1388+
elem.layout.ty
1389+
);
13851390
let val = elem.to_scalar().to_int(elem.layout.size)?;
13861391
interp_ok(match val {
13871392
0 => false,

tests/codegen/simd-intrinsic/simd-intrinsic-generic-gather.rs

+13
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@ pub unsafe fn gather_f32x2(
2929
simd_gather(values, pointers, mask)
3030
}
3131

32+
// CHECK-LABEL: @gather_f32x2_unsigned
33+
#[no_mangle]
34+
pub unsafe fn gather_f32x2_unsigned(
35+
pointers: Vec2<*const f32>,
36+
mask: Vec2<u32>,
37+
values: Vec2<f32>,
38+
) -> Vec2<f32> {
39+
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
40+
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
41+
// CHECK: call <2 x float> @llvm.masked.gather.v2f32.v2p0(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]], <2 x float> {{.*}})
42+
simd_gather(values, pointers, mask)
43+
}
44+
3245
// CHECK-LABEL: @gather_pf32x2
3346
#[no_mangle]
3447
pub unsafe fn gather_pf32x2(

tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-load.rs

+13
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@ pub unsafe fn load_f32x2(mask: Vec2<i32>, pointer: *const f32, values: Vec2<f32>
2323
simd_masked_load(mask, pointer, values)
2424
}
2525

26+
// CHECK-LABEL: @load_f32x2_unsigned
27+
#[no_mangle]
28+
pub unsafe fn load_f32x2_unsigned(
29+
mask: Vec2<u32>,
30+
pointer: *const f32,
31+
values: Vec2<f32>,
32+
) -> Vec2<f32> {
33+
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
34+
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
35+
// CHECK: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 4, <2 x i1> [[B]], <2 x float> {{.*}})
36+
simd_masked_load(mask, pointer, values)
37+
}
38+
2639
// CHECK-LABEL: @load_pf32x4
2740
#[no_mangle]
2841
pub unsafe fn load_pf32x4(

tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-store.rs

+9
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@ pub unsafe fn store_f32x2(mask: Vec2<i32>, pointer: *mut f32, values: Vec2<f32>)
2323
simd_masked_store(mask, pointer, values)
2424
}
2525

26+
// CHECK-LABEL: @store_f32x2_unsigned
27+
#[no_mangle]
28+
pub unsafe fn store_f32x2_unsigned(mask: Vec2<u32>, pointer: *mut f32, values: Vec2<f32>) {
29+
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
30+
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
31+
// CHECK: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 4, <2 x i1> [[B]])
32+
simd_masked_store(mask, pointer, values)
33+
}
34+
2635
// CHECK-LABEL: @store_pf32x4
2736
#[no_mangle]
2837
pub unsafe fn store_pf32x4(mask: Vec4<i32>, pointer: *mut *const f32, values: Vec4<*const f32>) {

tests/codegen/simd-intrinsic/simd-intrinsic-generic-scatter.rs

+9
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ pub unsafe fn scatter_f32x2(pointers: Vec2<*mut f32>, mask: Vec2<i32>, values: V
2525
simd_scatter(values, pointers, mask)
2626
}
2727

28+
// CHECK-LABEL: @scatter_f32x2_unsigned
29+
#[no_mangle]
30+
pub unsafe fn scatter_f32x2_unsigned(pointers: Vec2<*mut f32>, mask: Vec2<u32>, values: Vec2<f32>) {
31+
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
32+
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
33+
// CHECK: call void @llvm.masked.scatter.v2f32.v2p0(<2 x float> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]]
34+
simd_scatter(values, pointers, mask)
35+
}
36+
2837
// CHECK-LABEL: @scatter_pf32x2
2938
#[no_mangle]
3039
pub unsafe fn scatter_pf32x2(

tests/codegen/simd-intrinsic/simd-intrinsic-generic-select.rs

+13
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ pub struct b8x4(pub [i8; 4]);
2222
#[derive(Copy, Clone, PartialEq, Debug)]
2323
pub struct i32x4([i32; 4]);
2424

25+
#[repr(simd)]
26+
#[derive(Copy, Clone, PartialEq, Debug)]
27+
pub struct u32x4([u32; 4]);
28+
2529
// CHECK-LABEL: @select_m8
2630
#[no_mangle]
2731
pub unsafe fn select_m8(m: b8x4, a: f32x4, b: f32x4) -> f32x4 {
@@ -40,6 +44,15 @@ pub unsafe fn select_m32(m: i32x4, a: f32x4, b: f32x4) -> f32x4 {
4044
simd_select(m, a, b)
4145
}
4246

47+
// CHECK-LABEL: @select_m32_unsigned
48+
#[no_mangle]
49+
pub unsafe fn select_m32_unsigned(m: u32x4, a: f32x4, b: f32x4) -> f32x4 {
50+
// CHECK: [[A:%[0-9]+]] = lshr <4 x i32> %{{.*}}, {{<i32 31, i32 31, i32 31, i32 31>|splat \(i32 31\)}}
51+
// CHECK: [[B:%[0-9]+]] = trunc <4 x i32> [[A]] to <4 x i1>
52+
// CHECK: select <4 x i1> [[B]]
53+
simd_select(m, a, b)
54+
}
55+
4356
// CHECK-LABEL: @select_bitmask
4457
#[no_mangle]
4558
pub unsafe fn select_bitmask(m: i8, a: f32x8, b: f32x8) -> f32x8 {

0 commit comments

Comments
 (0)