Skip to content

Commit df7fcb1

Browse files
committed
Use the same mask-conversion code for masked stores and masked loads
1 parent 2abd0f0 commit df7fcb1

File tree

4 files changed

+122
-63
lines changed

4 files changed

+122
-63
lines changed

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 62 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -965,6 +965,33 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
965965
}};
966966
}
967967

968+
/// Returns the bitwidth of the `$ty` argument if it is an `Int` type.
969+
macro_rules! require_int_ty {
970+
($ty: expr, $diag: expr) => {
971+
match $ty {
972+
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
973+
_ => {
974+
return_error!($diag);
975+
}
976+
}
977+
};
978+
}
979+
980+
/// Returns the bitwidth of the `$ty` argument if it is an `Int` or `Uint` type.
981+
macro_rules! require_int_or_uint_ty {
982+
($ty: expr, $diag: expr) => {
983+
match $ty {
984+
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
985+
ty::Uint(i) => {
986+
i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
987+
}
988+
_ => {
989+
return_error!($diag);
990+
}
991+
}
992+
};
993+
}
994+
968995
/// Converts a vector mask, where each element has a bit width equal to the data elements it is used with,
969996
/// down to an i1 based mask that can be used by llvm intrinsics.
970997
///
@@ -1252,10 +1279,10 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
12521279
m_len == v_len,
12531280
InvalidMonomorphization::MismatchedLengths { span, name, m_len, v_len }
12541281
);
1255-
let in_elem_bitwidth = match m_elem_ty.kind() {
1256-
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
1257-
_ => return_error!(InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }),
1258-
};
1282+
let in_elem_bitwidth = require_int_ty!(
1283+
m_elem_ty.kind(),
1284+
InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }
1285+
);
12591286
let m_i1s = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, m_len);
12601287
return Ok(bx.select(m_i1s, args[1].immediate(), args[2].immediate()));
12611288
}
@@ -1274,24 +1301,12 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
12741301
let expected_bytes = expected_int_bits / 8 + ((expected_int_bits % 8 > 0) as u64);
12751302

12761303
// Integer vector <i{in_bitwidth} x in_len>:
1277-
let (i_xn, in_elem_bitwidth) = match in_elem.kind() {
1278-
ty::Int(i) => (
1279-
args[0].immediate(),
1280-
i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
1281-
),
1282-
ty::Uint(i) => (
1283-
args[0].immediate(),
1284-
i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
1285-
),
1286-
_ => return_error!(InvalidMonomorphization::VectorArgument {
1287-
span,
1288-
name,
1289-
in_ty,
1290-
in_elem
1291-
}),
1292-
};
1304+
let in_elem_bitwidth = require_int_or_uint_ty!(
1305+
in_elem.kind(),
1306+
InvalidMonomorphization::VectorArgument { span, name, in_ty, in_elem }
1307+
);
12931308

1294-
let i1xn = vector_mask_to_bitmask(bx, i_xn, in_elem_bitwidth, in_len);
1309+
let i1xn = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, in_len);
12951310
// Bitcast <i1 x N> to iN:
12961311
let i_ = bx.bitcast(i1xn, bx.type_ix(in_len));
12971312

@@ -1509,17 +1524,15 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
15091524
}
15101525
);
15111526

1512-
let mask_elem_bitwidth = match element_ty2.kind() {
1513-
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
1514-
_ => {
1515-
return_error!(InvalidMonomorphization::ThirdArgElementType {
1516-
span,
1517-
name,
1518-
expected_element: element_ty2,
1519-
third_arg: arg_tys[2]
1520-
})
1527+
let mask_elem_bitwidth = require_int_ty!(
1528+
element_ty2.kind(),
1529+
InvalidMonomorphization::ThirdArgElementType {
1530+
span,
1531+
name,
1532+
expected_element: element_ty2,
1533+
third_arg: arg_tys[2]
15211534
}
1522-
};
1535+
);
15231536

15241537
// Alignment of T, must be a constant integer value:
15251538
let alignment_ty = bx.type_i32();
@@ -1612,8 +1625,8 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
16121625
}
16131626
);
16141627

1615-
require!(
1616-
matches!(mask_elem.kind(), ty::Int(_)),
1628+
let m_elem_bitwidth = require_int_ty!(
1629+
mask_elem.kind(),
16171630
InvalidMonomorphization::ThirdArgElementType {
16181631
span,
16191632
name,
@@ -1622,17 +1635,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
16221635
}
16231636
);
16241637

1638+
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
1639+
let mask_ty = bx.type_vector(bx.type_i1(), mask_len);
1640+
16251641
// Alignment of T, must be a constant integer value:
16261642
let alignment_ty = bx.type_i32();
16271643
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
16281644

1629-
// Truncate the mask vector to a vector of i1s:
1630-
let (mask, mask_ty) = {
1631-
let i1 = bx.type_i1();
1632-
let i1xn = bx.type_vector(i1, mask_len);
1633-
(bx.trunc(args[0].immediate(), i1xn), i1xn)
1634-
};
1635-
16361645
let llvm_pointer = bx.type_ptr();
16371646

16381647
// Type of the vector of elements:
@@ -1704,8 +1713,8 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
17041713
}
17051714
);
17061715

1707-
require!(
1708-
matches!(mask_elem.kind(), ty::Int(_)),
1716+
let m_elem_bitwidth = require_int_ty!(
1717+
mask_elem.kind(),
17091718
InvalidMonomorphization::ThirdArgElementType {
17101719
span,
17111720
name,
@@ -1714,17 +1723,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
17141723
}
17151724
);
17161725

1726+
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
1727+
let mask_ty = bx.type_vector(bx.type_i1(), mask_len);
1728+
17171729
// Alignment of T, must be a constant integer value:
17181730
let alignment_ty = bx.type_i32();
17191731
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
17201732

1721-
// Truncate the mask vector to a vector of i1s:
1722-
let (mask, mask_ty) = {
1723-
let i1 = bx.type_i1();
1724-
let i1xn = bx.type_vector(i1, in_len);
1725-
(bx.trunc(args[0].immediate(), i1xn), i1xn)
1726-
};
1727-
17281733
let ret_t = bx.type_void();
17291734

17301735
let llvm_pointer = bx.type_ptr();
@@ -1803,17 +1808,15 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
18031808
);
18041809

18051810
// The element type of the third argument must be a signed integer type of any width:
1806-
let mask_elem_bitwidth = match element_ty2.kind() {
1807-
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
1808-
_ => {
1809-
return_error!(InvalidMonomorphization::ThirdArgElementType {
1810-
span,
1811-
name,
1812-
expected_element: element_ty2,
1813-
third_arg: arg_tys[2]
1814-
});
1811+
let mask_elem_bitwidth = require_int_ty!(
1812+
element_ty2.kind(),
1813+
InvalidMonomorphization::ThirdArgElementType {
1814+
span,
1815+
name,
1816+
expected_element: element_ty2,
1817+
third_arg: arg_tys[2]
18151818
}
1816-
};
1819+
);
18171820

18181821
// Alignment of T, must be a constant integer value:
18191822
let alignment_ty = bx.type_i32();
Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,48 @@
1+
// verify that simd mask reductions do not introduce additional bit shift operations
2+
//@ revisions: x86 aarch64
3+
//@ [x86] compile-flags: --target=x86_64-unknown-linux-gnu -C llvm-args=-x86-asm-syntax=intel
4+
//@ [x86] needs-llvm-components: x86
5+
//@ [aarch64] compile-flags: --target=aarch64-unknown-linux-gnu
6+
//@ [aarch64] needs-llvm-components: aarch64
7+
//@ [aarch64] min-llvm-version: 15.0
8+
//@ assembly-output: emit-asm
9+
//@ compile-flags: --crate-type=lib -O
10+
11+
#![feature(no_core, lang_items, repr_simd, intrinsics)]
12+
#![no_core]
13+
#![allow(non_camel_case_types)]
14+
15+
// Because we don't have core yet.
16+
#[lang = "sized"]
17+
pub trait Sized {}
18+
19+
#[lang = "copy"]
20+
trait Copy {}
21+
22+
#[repr(simd)]
23+
pub struct mask8x16([i8; 16]);
24+
25+
extern "rust-intrinsic" {
26+
fn simd_reduce_all<T>(x: T) -> bool;
27+
fn simd_reduce_any<T>(x: T) -> bool;
28+
}
29+
30+
// CHECK-LABEL: mask_reduce_all:
31+
#[no_mangle]
32+
pub unsafe fn mask_reduce_all(m: mask8x16) -> bool {
33+
// x86: movdqa
34+
// x86-NEXT: pmovmskb
35+
// aarch64: cmge
36+
// aarch64-NEXT: umaxv
37+
simd_reduce_all(m)
38+
}
39+
40+
// CHECK-LABEL: mask_reduce_any:
41+
#[no_mangle]
42+
pub unsafe fn mask_reduce_any(m: mask8x16) -> bool {
43+
// x86: movdqa
44+
// x86-NEXT: pmovmskb
45+
// aarch64: cmlt
46+
// aarch64-NEXT: umaxv
47+
simd_reduce_any(m)
48+
}

tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-load.rs

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,14 +21,18 @@ extern "rust-intrinsic" {
2121
#[no_mangle]
2222
pub unsafe fn load_f32x2(mask: Vec2<i32>, pointer: *const f32,
2323
values: Vec2<f32>) -> Vec2<f32> {
24-
// CHECK: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 4, <2 x i1> {{.*}}, <2 x float> {{.*}})
24+
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, <i32 31, i32 31>
25+
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
26+
// CHECK: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 4, <2 x i1> [[B]], <2 x float> {{.*}})
2527
simd_masked_load(mask, pointer, values)
2628
}
2729

2830
// CHECK-LABEL: @load_pf32x4
2931
#[no_mangle]
3032
pub unsafe fn load_pf32x4(mask: Vec4<i32>, pointer: *const *const f32,
3133
values: Vec4<*const f32>) -> Vec4<*const f32> {
32-
// CHECK: call <4 x ptr> @llvm.masked.load.v4p0.p0(ptr {{.*}}, i32 {{.*}}, <4 x i1> {{.*}}, <4 x ptr> {{.*}})
34+
// CHECK: [[A:%[0-9]+]] = lshr <4 x i32> {{.*}}, <i32 31, i32 31, i32 31, i32 31>
35+
// CHECK: [[B:%[0-9]+]] = trunc <4 x i32> [[A]] to <4 x i1>
36+
// CHECK: call <4 x ptr> @llvm.masked.load.v4p0.p0(ptr {{.*}}, i32 {{.*}}, <4 x i1> [[B]], <4 x ptr> {{.*}})
3337
simd_masked_load(mask, pointer, values)
3438
}

tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-store.rs

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,13 +20,17 @@ extern "rust-intrinsic" {
2020
// CHECK-LABEL: @store_f32x2
2121
#[no_mangle]
2222
pub unsafe fn store_f32x2(mask: Vec2<i32>, pointer: *mut f32, values: Vec2<f32>) {
23-
// CHECK: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 4, <2 x i1> {{.*}})
23+
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, <i32 31, i32 31>
24+
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
25+
// CHECK: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 4, <2 x i1> [[B]])
2426
simd_masked_store(mask, pointer, values)
2527
}
2628

2729
// CHECK-LABEL: @store_pf32x4
2830
#[no_mangle]
2931
pub unsafe fn store_pf32x4(mask: Vec4<i32>, pointer: *mut *const f32, values: Vec4<*const f32>) {
30-
// CHECK: call void @llvm.masked.store.v4p0.p0(<4 x ptr> {{.*}}, ptr {{.*}}, i32 {{.*}}, <4 x i1> {{.*}})
32+
// CHECK: [[A:%[0-9]+]] = lshr <4 x i32> {{.*}}, <i32 31, i32 31, i32 31, i32 31>
33+
// CHECK: [[B:%[0-9]+]] = trunc <4 x i32> [[A]] to <4 x i1>
34+
// CHECK: call void @llvm.masked.store.v4p0.p0(<4 x ptr> {{.*}}, ptr {{.*}}, i32 {{.*}}, <4 x i1> [[B]])
3135
simd_masked_store(mask, pointer, values)
3236
}

0 commit comments

Comments (0)