Skip to content

Commit cae9c25

Browse files
committed
Added runtime detection
Cannot add a `cupid`-based test because `cupid` does not support `amx`.
1 parent 1cbf7c7 commit cae9c25

File tree

3 files changed

+50
-19
lines changed

3 files changed

+50
-19
lines changed

crates/std_detect/src/detect/arch/x86.rs

+15
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ features! {
8181
/// * `"avxneconvert"`
8282
/// * `"avxvnniint8"`
8383
/// * `"avxvnniint16"`
84+
/// * `"amx-tile"`
85+
/// * `"amx-int8"`
86+
/// * `"amx-bf16"`
87+
/// * `"amx-fp16"`
88+
/// * `"amx-complex"`
8489
/// * `"f16c"`
8590
/// * `"fma"`
8691
/// * `"bmi1"`
@@ -187,6 +192,16 @@ features! {
187192
/// AVX-VNNI_INT16 (VNNI with 16-bit Integers)
188193
@FEATURE: #[unstable(feature = "avx512_target_feature", issue = "44839")] avxvnniint8: "avxvnniint8";
189194
/// AVX-VNNI_INT8 (VNNI with 8-bit integers)
195+
@FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_tile: "amx-tile";
196+
/// AMX (Advanced Matrix Extensions) - Tile load/store
197+
@FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_int8: "amx-int8";
198+
/// AMX-INT8 (Operations on 8-bit integers)
199+
@FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_bf16: "amx-bf16";
200+
/// AMX-BF16 (BFloat16 Operations)
201+
@FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_fp16: "amx-fp16";
202+
/// AMX-FP16 (Float16 Operations)
203+
@FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_complex: "amx-complex";
204+
/// AMX-COMPLEX (Complex number Operations)
190205
@FEATURE: #[stable(feature = "simd_x86", since = "1.27.0")] f16c: "f16c";
191206
/// F16C (Conversions between IEEE-754 `binary16` and `binary32` formats)
192207
@FEATURE: #[stable(feature = "simd_x86", since = "1.27.0")] fma: "fma";

crates/std_detect/src/detect/os/x86.rs

+11
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ pub(crate) fn detect_features() -> cache::Initializer {
165165
// * SSE -> `XCR0.SSE[1]`
166166
// * AVX -> `XCR0.AVX[2]`
167167
// * AVX-512 -> `XCR0.AVX-512[7:5]`.
168+
// * AMX -> `XCR0.AMX[18:17]`
168169
//
169170
// by setting the corresponding bits of `XCR0` to `1`.
170171
//
@@ -175,6 +176,8 @@ pub(crate) fn detect_features() -> cache::Initializer {
175176
let os_avx_support = xcr0 & 6 == 6;
176177
// Test `XCR0.AVX-512[7:5]` with the mask `0b1110_0000 == 0xe0`:
177178
let os_avx512_support = xcr0 & 0xe0 == 0xe0;
179+
// Test `XCR0.AMX[18:17]` with the mask `0b110_0000_0000_0000_0000 == 0x60000`:
180+
let os_amx_support = xcr0 & 0x60000 == 0x60000;
178181

179182
// Only if the OS and the CPU support saving/restoring the AVX
180183
// registers we enable `xsave` support:
@@ -237,6 +240,14 @@ pub(crate) fn detect_features() -> cache::Initializer {
237240
enable(extended_features_edx, 8, Feature::avx512vp2intersect);
238241
enable(extended_features_edx, 23, Feature::avx512fp16);
239242
enable(extended_features_eax_leaf_1, 5, Feature::avx512bf16);
243+
244+
if os_amx_support {
245+
enable(extended_features_edx, 24, Feature::amx_tile);
246+
enable(extended_features_edx, 25, Feature::amx_int8);
247+
enable(extended_features_edx, 22, Feature::amx_bf16);
248+
enable(extended_features_eax_leaf_1, 21, Feature::amx_fp16);
249+
enable(extended_features_edx_leaf_1, 8, Feature::amx_complex);
250+
}
240251
}
241252
}
242253
}

crates/std_detect/tests/x86-specific.rs

+24-19
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#![cfg(any(target_arch = "x86", target_arch = "x86_64"))]
22
#![allow(internal_features)]
3-
#![feature(stdarch_internal, avx512_target_feature)]
3+
#![feature(stdarch_internal, avx512_target_feature, x86_amx_intrinsics)]
44

55
extern crate cupid;
66
#[macro_use]
@@ -24,34 +24,34 @@ fn dump() {
2424
println!("f16c: {:?}", is_x86_feature_detected!("f16c"));
2525
println!("avx: {:?}", is_x86_feature_detected!("avx"));
2626
println!("avx2: {:?}", is_x86_feature_detected!("avx2"));
27-
println!("avx512f {:?}", is_x86_feature_detected!("avx512f"));
28-
println!("avx512cd {:?}", is_x86_feature_detected!("avx512cd"));
29-
println!("avx512er {:?}", is_x86_feature_detected!("avx512er"));
30-
println!("avx512pf {:?}", is_x86_feature_detected!("avx512pf"));
31-
println!("avx512bw {:?}", is_x86_feature_detected!("avx512bw"));
32-
println!("avx512dq {:?}", is_x86_feature_detected!("avx512dq"));
33-
println!("avx512vl {:?}", is_x86_feature_detected!("avx512vl"));
34-
println!("avx512_ifma {:?}", is_x86_feature_detected!("avx512ifma"));
27+
println!("avx512f: {:?}", is_x86_feature_detected!("avx512f"));
28+
println!("avx512cd: {:?}", is_x86_feature_detected!("avx512cd"));
29+
println!("avx512er: {:?}", is_x86_feature_detected!("avx512er"));
30+
println!("avx512pf: {:?}", is_x86_feature_detected!("avx512pf"));
31+
println!("avx512bw: {:?}", is_x86_feature_detected!("avx512bw"));
32+
println!("avx512dq: {:?}", is_x86_feature_detected!("avx512dq"));
33+
println!("avx512vl: {:?}", is_x86_feature_detected!("avx512vl"));
34+
println!("avx512_ifma: {:?}", is_x86_feature_detected!("avx512ifma"));
3535
println!("avx512vbmi {:?}", is_x86_feature_detected!("avx512vbmi"));
3636
println!(
37-
"avx512_vpopcntdq {:?}",
37+
"avx512_vpopcntdq: {:?}",
3838
is_x86_feature_detected!("avx512vpopcntdq")
3939
);
40-
println!("avx512vbmi2 {:?}", is_x86_feature_detected!("avx512vbmi2"));
41-
println!("gfni {:?}", is_x86_feature_detected!("gfni"));
42-
println!("vaes {:?}", is_x86_feature_detected!("vaes"));
43-
println!("vpclmulqdq {:?}", is_x86_feature_detected!("vpclmulqdq"));
44-
println!("avx512vnni {:?}", is_x86_feature_detected!("avx512vnni"));
40+
println!("avx512vbmi2: {:?}", is_x86_feature_detected!("avx512vbmi2"));
41+
println!("gfni: {:?}", is_x86_feature_detected!("gfni"));
42+
println!("vaes: {:?}", is_x86_feature_detected!("vaes"));
43+
println!("vpclmulqdq: {:?}", is_x86_feature_detected!("vpclmulqdq"));
44+
println!("avx512vnni: {:?}", is_x86_feature_detected!("avx512vnni"));
4545
println!(
46-
"avx512bitalg {:?}",
46+
"avx512bitalg: {:?}",
4747
is_x86_feature_detected!("avx512bitalg")
4848
);
49-
println!("avx512bf16 {:?}", is_x86_feature_detected!("avx512bf16"));
49+
println!("avx512bf16: {:?}", is_x86_feature_detected!("avx512bf16"));
5050
println!(
51-
"avx512vp2intersect {:?}",
51+
"avx512vp2intersect: {:?}",
5252
is_x86_feature_detected!("avx512vp2intersect")
5353
);
54-
println!("avx512fp16 {:?}", is_x86_feature_detected!("avx512fp16"));
54+
println!("avx512fp16: {:?}", is_x86_feature_detected!("avx512fp16"));
5555
println!("fma: {:?}", is_x86_feature_detected!("fma"));
5656
println!("abm: {:?}", is_x86_feature_detected!("abm"));
5757
println!("bmi: {:?}", is_x86_feature_detected!("bmi1"));
@@ -79,6 +79,11 @@ fn dump() {
7979
"avxvnniint16: {:?}",
8080
is_x86_feature_detected!("avxvnniint16")
8181
);
82+
println!("amx-bf16: {:?}", is_x86_feature_detected!("amx-bf16"));
83+
println!("amx-tile: {:?}", is_x86_feature_detected!("amx-tile"));
84+
println!("amx-int8: {:?}", is_x86_feature_detected!("amx-int8"));
85+
println!("amx-fp16: {:?}", is_x86_feature_detected!("amx-fp16"));
86+
println!("amx-complex: {:?}", is_x86_feature_detected!("amx-complex"));
8287
}
8388

8489
#[cfg(feature = "std_detect_env_override")]

0 commit comments

Comments
 (0)