Skip to content

Commit 90e95e7

Browse files
committed
Reorder the arguments to simd_masked_load
1 parent 54ab131 commit 90e95e7

File tree

8 files changed

+91
-80
lines changed

8 files changed

+91
-80
lines changed

compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1010,7 +1010,7 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
10101010
}
10111011

10121012
sym::simd_masked_load => {
1013-
intrinsic_args!(fx, args => (val, ptr, mask); intrinsic);
1013+
intrinsic_args!(fx, args => (mask, ptr, val); intrinsic);
10141014

10151015
let (val_lane_count, val_lane_ty) = val.layout().ty.simd_size_and_type(fx.tcx);
10161016
let (mask_lane_count, _mask_lane_ty) = mask.layout().ty.simd_size_and_type(fx.tcx);

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,85 +1493,92 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
14931493
}
14941494

14951495
if name == sym::simd_masked_load {
1496-
// simd_masked_load(values: <N x T>, pointer: *_ T, mask: <N x i{M}>) -> <N x T>
1496+
// simd_masked_load(mask: <N x i{M}>, pointer: *_ T, values: <N x T>) -> <N x T>
14971497
// * N: number of elements in the input vectors
14981498
// * T: type of the element to load
14991499
// * M: any integer width is supported, will be truncated to i1
15001500
// Loads contiguous elements from memory behind `pointer`, but only for
15011501
// those lanes whose `mask` bit is enabled.
15021502
// The memory addresses corresponding to the “off” lanes are not accessed.
15031503

1504-
// The first argument is a passthrough vector providing values for disabled lanes
1505-
15061504
// The element type of the "mask" argument must be a signed integer type of any width
1507-
let (mask_len, mask_elem) = require_simd!(arg_tys[2], SimdThird);
1505+
let mask_ty = in_ty;
1506+
let (mask_len, mask_elem) = (in_len, in_elem);
1507+
1508+
// The second argument must be a pointer matching the element type
1509+
let pointer_ty = arg_tys[1];
1510+
1511+
// The last argument is a passthrough vector providing values for disabled lanes
1512+
let values_ty = arg_tys[2];
1513+
let (values_len, values_elem) = require_simd!(values_ty, SimdThird);
1514+
15081515
require_simd!(ret_ty, SimdReturn);
15091516

15101517
// Of the same length:
15111518
require!(
1512-
in_len == mask_len,
1519+
values_len == mask_len,
15131520
InvalidMonomorphization::ThirdArgumentLength {
15141521
span,
15151522
name,
1516-
in_len,
1517-
in_ty,
1518-
arg_ty: arg_tys[2],
1519-
out_len: mask_len
1523+
in_len: mask_len,
1524+
in_ty: mask_ty,
1525+
arg_ty: values_ty,
1526+
out_len: values_len
15201527
}
15211528
);
15221529

1523-
// The return type must match the first argument type
1530+
// The return type must match the last argument type
15241531
require!(
1525-
ret_ty == in_ty,
1526-
InvalidMonomorphization::ExpectedReturnType { span, name, in_ty, ret_ty }
1532+
ret_ty == values_ty,
1533+
InvalidMonomorphization::ExpectedReturnType { span, name, in_ty: values_ty, ret_ty }
15271534
);
15281535

1529-
// The second argument must be a pointer matching the element type
15301536
require!(
15311537
matches!(
1532-
arg_tys[1].kind(),
1533-
ty::RawPtr(p) if p.ty == in_elem && p.ty.kind() == in_elem.kind()
1538+
pointer_ty.kind(),
1539+
ty::RawPtr(p) if p.ty == values_elem && p.ty.kind() == values_elem.kind()
15341540
),
15351541
InvalidMonomorphization::ExpectedElementType {
15361542
span,
15371543
name,
1538-
expected_element: in_elem,
1539-
second_arg: arg_tys[1],
1540-
in_elem,
1541-
in_ty,
1544+
expected_element: values_elem,
1545+
second_arg: pointer_ty,
1546+
in_elem: values_elem,
1547+
in_ty: values_ty,
15421548
mutability: ExpectedPointerMutability::Not,
15431549
}
15441550
);
15451551

1546-
// Mask needs to be an integer type
1547-
match mask_elem.kind() {
1548-
ty::Int(_) => (),
1549-
_ => {
1550-
return_error!(InvalidMonomorphization::ThirdArgElementType {
1551-
span,
1552-
name,
1553-
expected_element: mask_elem,
1554-
third_arg: arg_tys[2]
1555-
});
1552+
let expected_int_bits = (mask_len.max(8) - 1).next_power_of_two();
1553+
let expected_bytes = mask_len / 8 + ((mask_len % 8 > 0) as u64);
1554+
1555+
require!(
1556+
matches!(mask_elem.kind(), ty::Int(_)),
1557+
InvalidMonomorphization::InvalidBitmask {
1558+
span,
1559+
name,
1560+
mask_ty,
1561+
expected_int_bits,
1562+
expected_bytes
15561563
}
1557-
}
1564+
);
15581565

15591566
// Alignment of T, must be a constant integer value:
15601567
let alignment_ty = bx.type_i32();
1561-
let alignment = bx.const_i32(bx.align_of(in_ty).bytes() as i32);
1568+
let alignment = bx.const_i32(bx.align_of(values_ty).bytes() as i32);
15621569

15631570
// Truncate the mask vector to a vector of i1s:
15641571
let (mask, mask_ty) = {
15651572
let i1 = bx.type_i1();
15661573
let i1xn = bx.type_vector(i1, mask_len);
1567-
(bx.trunc(args[2].immediate(), i1xn), i1xn)
1574+
(bx.trunc(args[0].immediate(), i1xn), i1xn)
15681575
};
15691576

15701577
let llvm_pointer = bx.type_ptr();
15711578

15721579
// Type of the vector of elements:
1573-
let llvm_elem_vec_ty = llvm_vector_ty(bx, in_elem, mask_len);
1574-
let llvm_elem_vec_str = llvm_vector_str(bx, in_elem, mask_len);
1580+
let llvm_elem_vec_ty = llvm_vector_ty(bx, values_elem, values_len);
1581+
let llvm_elem_vec_str = llvm_vector_str(bx, values_elem, values_len);
15751582

15761583
let llvm_intrinsic = format!("llvm.masked.load.{llvm_elem_vec_str}.p0");
15771584
let fn_ty = bx
@@ -1582,7 +1589,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
15821589
None,
15831590
None,
15841591
f,
1585-
&[args[1].immediate(), alignment, mask, args[0].immediate()],
1592+
&[args[1].immediate(), alignment, mask, args[2].immediate()],
15861593
None,
15871594
);
15881595
return Ok(v);

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
540540
sym::simd_fpowi => (1, 0, vec![param(0), tcx.types.i32], param(0)),
541541
sym::simd_fma => (1, 0, vec![param(0), param(0), param(0)], param(0)),
542542
sym::simd_gather => (3, 0, vec![param(0), param(1), param(2)], param(0)),
543-
sym::simd_masked_load => (3, 0, vec![param(0), param(1), param(2)], param(0)),
543+
sym::simd_masked_load => (3, 0, vec![param(0), param(1), param(2)], param(2)),
544544
sym::simd_masked_store => (3, 0, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
545545
sym::simd_scatter => (3, 0, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
546546
sym::simd_insert => (2, 0, vec![param(0), tcx.types.u32, param(1)], param(0)),

tests/ui/simd/masked-load-store-build-fail.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#![feature(repr_simd, platform_intrinsics)]
33

44
extern "platform-intrinsic" {
5-
fn simd_masked_load<T, P, M>(values: T, pointer: P, mask: M) -> T;
5+
fn simd_masked_load<M, P, T>(mask: M, pointer: P, values: T) -> T;
66
fn simd_masked_store<T, P, M>(values: T, pointer: P, mask: M) -> ();
77
}
88

@@ -16,32 +16,32 @@ fn main() {
1616
let default = Simd::<u8, 4>([9; 4]);
1717

1818
simd_masked_load(
19-
default,
19+
Simd::<i8, 8>([-1, 0, -1, -1, 0, 0, 0, 0]),
2020
arr.as_ptr(),
21-
Simd::<i8, 8>([-1, 0, -1, -1, 0, 0, 0, 0])
21+
default
2222
);
23-
//~^^^^^ ERROR expected third argument with length 4 (same as input type `Simd<u8, 4>`), found `Simd<i8, 8>` with length 8
23+
//~^^^^^ ERROR expected third argument with length 8 (same as input type `Simd<i8, 8>`), found `Simd<u8, 4>` with length 4
2424

2525
simd_masked_load(
26-
default,
26+
Simd::<i8, 4>([-1, 0, -1, -1]),
2727
arr.as_ptr() as *const i8,
28-
Simd::<i8, 4>([-1, 0, -1, -1])
28+
default
2929
);
3030
//~^^^^^ ERROR expected element type `u8` of second argument `*const i8` to be a pointer to the element type `u8` of the first argument `Simd<u8, 4>`, found `u8` != `*_ u8`
3131

3232
simd_masked_load(
33-
Simd::<u32, 4>([9; 4]),
33+
Simd::<i8, 4>([-1, 0, -1, -1]),
3434
arr.as_ptr(),
35-
Simd::<i8, 4>([-1, 0, -1, -1])
35+
Simd::<u32, 4>([9; 4])
3636
);
3737
//~^^^^^ ERROR expected element type `u32` of second argument `*const u8` to be a pointer to the element type `u32` of the first argument `Simd<u32, 4>`, found `u32` != `*_ u32`
3838

3939
simd_masked_load(
40-
default,
40+
Simd::<u8, 4>([1, 0, 1, 1]),
4141
arr.as_ptr(),
42-
Simd::<u8, 4>([1, 0, 1, 1])
42+
default
4343
);
44-
//~^^^^^ ERROR expected element type `u8` of third argument `Simd<u8, 4>` to be a signed integer type
44+
//~^^^^^ ERROR invalid bitmask `Simd<u8, 4>`, expected `u8` or `[u8; 1]`
4545

4646
simd_masked_store(
4747
Simd([5u32; 4]),

tests/ui/simd/masked-load-store-build-fail.stderr

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,40 @@
1-
error[E0511]: invalid monomorphization of `simd_masked_load` intrinsic: expected third argument with length 4 (same as input type `Simd<u8, 4>`), found `Simd<i8, 8>` with length 8
1+
error[E0511]: invalid monomorphization of `simd_masked_load` intrinsic: expected third argument with length 8 (same as input type `Simd<i8, 8>`), found `Simd<u8, 4>` with length 4
22
--> $DIR/masked-load-store-build-fail.rs:18:9
33
|
44
LL | / simd_masked_load(
5-
LL | | default,
5+
LL | | Simd::<i8, 8>([-1, 0, -1, -1, 0, 0, 0, 0]),
66
LL | | arr.as_ptr(),
7-
LL | | Simd::<i8, 8>([-1, 0, -1, -1, 0, 0, 0, 0])
7+
LL | | default
88
LL | | );
99
| |_________^
1010

1111
error[E0511]: invalid monomorphization of `simd_masked_load` intrinsic: expected element type `u8` of second argument `*const i8` to be a pointer to the element type `u8` of the first argument `Simd<u8, 4>`, found `u8` != `*_ u8`
1212
--> $DIR/masked-load-store-build-fail.rs:25:9
1313
|
1414
LL | / simd_masked_load(
15-
LL | | default,
15+
LL | | Simd::<i8, 4>([-1, 0, -1, -1]),
1616
LL | | arr.as_ptr() as *const i8,
17-
LL | | Simd::<i8, 4>([-1, 0, -1, -1])
17+
LL | | default
1818
LL | | );
1919
| |_________^
2020

2121
error[E0511]: invalid monomorphization of `simd_masked_load` intrinsic: expected element type `u32` of second argument `*const u8` to be a pointer to the element type `u32` of the first argument `Simd<u32, 4>`, found `u32` != `*_ u32`
2222
--> $DIR/masked-load-store-build-fail.rs:32:9
2323
|
2424
LL | / simd_masked_load(
25-
LL | | Simd::<u32, 4>([9; 4]),
25+
LL | | Simd::<i8, 4>([-1, 0, -1, -1]),
2626
LL | | arr.as_ptr(),
27-
LL | | Simd::<i8, 4>([-1, 0, -1, -1])
27+
LL | | Simd::<u32, 4>([9; 4])
2828
LL | | );
2929
| |_________^
3030

31-
error[E0511]: invalid monomorphization of `simd_masked_load` intrinsic: expected element type `u8` of third argument `Simd<u8, 4>` to be a signed integer type
31+
error[E0511]: invalid monomorphization of `simd_masked_load` intrinsic: invalid bitmask `Simd<u8, 4>`, expected `u8` or `[u8; 1]`
3232
--> $DIR/masked-load-store-build-fail.rs:39:9
3333
|
3434
LL | / simd_masked_load(
35-
LL | | default,
35+
LL | | Simd::<u8, 4>([1, 0, 1, 1]),
3636
LL | | arr.as_ptr(),
37-
LL | | Simd::<u8, 4>([1, 0, 1, 1])
37+
LL | | default
3838
LL | | );
3939
| |_________^
4040

tests/ui/simd/masked-load-store-check-fail.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#![feature(repr_simd, platform_intrinsics)]
33

44
extern "platform-intrinsic" {
5-
fn simd_masked_load<T, P, M>(values: T, pointer: P, mask: M) -> T;
5+
fn simd_masked_load<M, P, T>(mask: M, pointer: P, values: T) -> T;
66
fn simd_masked_store<T, P, M>(values: T, pointer: P, mask: M) -> ();
77
}
88

@@ -16,17 +16,17 @@ fn main() {
1616
let default = Simd::<u8, 4>([9; 4]);
1717

1818
let _x: Simd<u8, 2> = simd_masked_load(
19-
Simd::<u8, 4>([9; 4]),
19+
Simd::<i8, 4>([-1, 0, -1, -1]),
2020
arr.as_ptr(),
21-
Simd::<i8, 4>([-1, 0, -1, -1])
21+
Simd::<u8, 4>([9; 4])
2222
);
23-
//~^^^^ ERROR mismatched types
23+
//~^^ ERROR mismatched types
2424

2525
let _x: Simd<u32, 4> = simd_masked_load(
26-
default,
26+
Simd::<u8, 4>([1, 0, 1, 1]),
2727
arr.as_ptr(),
28-
Simd::<u8, 4>([1, 0, 1, 1])
28+
default
2929
);
30-
//~^^^^ ERROR mismatched types
30+
//~^^ ERROR mismatched types
3131
}
3232
}

tests/ui/simd/masked-load-store-check-fail.stderr

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
error[E0308]: mismatched types
2-
--> $DIR/masked-load-store-check-fail.rs:19:13
2+
--> $DIR/masked-load-store-check-fail.rs:21:13
33
|
44
LL | let _x: Simd<u8, 2> = simd_masked_load(
55
| ---------------- arguments to this function are incorrect
6-
LL | Simd::<u8, 4>([9; 4]),
6+
...
7+
LL | Simd::<u8, 4>([9; 4])
78
| ^^^^^^^^^^^^^^^^^^^^^ expected `2`, found `4`
89
|
910
= note: expected struct `Simd<_, 2>`
@@ -13,24 +14,25 @@ help: the return type of this call is `Simd<u8, 4>` due to the type of the argum
1314
|
1415
LL | let _x: Simd<u8, 2> = simd_masked_load(
1516
| _______________________________^
16-
LL | | Simd::<u8, 4>([9; 4]),
17-
| | --------------------- this argument influences the return type of `simd_masked_load`
17+
LL | | Simd::<i8, 4>([-1, 0, -1, -1]),
1818
LL | | arr.as_ptr(),
19-
LL | | Simd::<i8, 4>([-1, 0, -1, -1])
19+
LL | | Simd::<u8, 4>([9; 4])
20+
| | --------------------- this argument influences the return type of `simd_masked_load`
2021
LL | | );
2122
| |_________^
2223
note: function defined here
2324
--> $DIR/masked-load-store-check-fail.rs:5:8
2425
|
25-
LL | fn simd_masked_load<T, P, M>(values: T, pointer: P, mask: M) -> T;
26+
LL | fn simd_masked_load<M, P, T>(mask: M, pointer: P, values: T) -> T;
2627
| ^^^^^^^^^^^^^^^^
2728

2829
error[E0308]: mismatched types
29-
--> $DIR/masked-load-store-check-fail.rs:26:13
30+
--> $DIR/masked-load-store-check-fail.rs:28:13
3031
|
3132
LL | let _x: Simd<u32, 4> = simd_masked_load(
3233
| ---------------- arguments to this function are incorrect
33-
LL | default,
34+
...
35+
LL | default
3436
| ^^^^^^^ expected `Simd<u32, 4>`, found `Simd<u8, 4>`
3537
|
3638
= note: expected struct `Simd<u32, _>`
@@ -40,16 +42,16 @@ help: the return type of this call is `Simd<u8, 4>` due to the type of the argum
4042
|
4143
LL | let _x: Simd<u32, 4> = simd_masked_load(
4244
| ________________________________^
43-
LL | | default,
44-
| | ------- this argument influences the return type of `simd_masked_load`
45+
LL | | Simd::<u8, 4>([1, 0, 1, 1]),
4546
LL | | arr.as_ptr(),
46-
LL | | Simd::<u8, 4>([1, 0, 1, 1])
47+
LL | | default
48+
| | ------- this argument influences the return type of `simd_masked_load`
4749
LL | | );
4850
| |_________^
4951
note: function defined here
5052
--> $DIR/masked-load-store-check-fail.rs:5:8
5153
|
52-
LL | fn simd_masked_load<T, P, M>(values: T, pointer: P, mask: M) -> T;
54+
LL | fn simd_masked_load<M, P, T>(mask: M, pointer: P, values: T) -> T;
5355
| ^^^^^^^^^^^^^^^^
5456

5557
error: aborting due to 2 previous errors

tests/ui/simd/masked-load-store.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#![feature(repr_simd, platform_intrinsics)]
33

44
extern "platform-intrinsic" {
5-
fn simd_masked_load<T, P, M>(values: T, pointer: P, mask: M) -> T;
5+
fn simd_masked_load<M, P, T>(mask: M, pointer: P, values: T) -> T;
66
fn simd_masked_store<T, P, M>(values: T, pointer: P, mask: M) -> ();
77
}
88

@@ -16,11 +16,13 @@ fn main() {
1616
let b_src = [4u8, 5, 6, 7];
1717
let b_default = Simd::<u8, 4>([9; 4]);
1818
let b: Simd::<u8, 4> = simd_masked_load(
19-
b_default,
19+
Simd::<i8, 4>([-1, 0, -1, -1]),
2020
b_src.as_ptr(),
21-
Simd::<i8, 4>([-1, 0, -1, -1])
21+
b_default
2222
);
2323

24+
assert_eq!(&b.0, &[4, 9, 6, 7]);
25+
2426
let mut output = [u8::MAX; 5];
2527

2628
simd_masked_store(a, output.as_mut_ptr(), Simd::<i8, 4>([-1, -1, -1, 0]));

0 commit comments

Comments
 (0)