Skip to content

Commit 9218afe

Browse files
committed
Add auto-bitcasts from/to x86amx and i32x256 for AMX intrinsics
1 parent 3ef8e64 commit 9218afe

File tree

16 files changed

+99
-32
lines changed

16 files changed

+99
-32
lines changed

compiler/rustc_codegen_gcc/src/type_of.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::fmt::Write;
22

3-
use gccjit::{Struct, Type};
3+
use gccjit::{RValue, Struct, Type};
44
use rustc_abi as abi;
55
use rustc_abi::Primitive::*;
66
use rustc_abi::{
@@ -373,7 +373,11 @@ impl<'gcc, 'tcx> LayoutTypeCodegenMethods<'tcx> for CodegenCx<'gcc, 'tcx> {
373373
unimplemented!();
374374
}
375375

376-
fn fn_decl_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> Type<'gcc> {
376+
fn fn_decl_backend_type(
377+
&self,
378+
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
379+
_fn_ptr: RValue<'gcc>,
380+
) -> Type<'gcc> {
377381
// FIXME(antoyo): Should we do something with `FnAbiGcc::fn_attributes`?
378382
let FnAbiGcc { return_type, arguments_type, is_c_variadic, .. } = fn_abi.gcc_type(self);
379383
self.context.new_function_pointer_type(None, return_type, &arguments_type, is_c_variadic)

compiler/rustc_codegen_llvm/src/abi.rs

+30-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::cmp;
44
use libc::c_uint;
55
use rustc_abi::{BackendRepr, HasDataLayout, Primitive, Reg, RegKind, Size};
66
use rustc_codegen_ssa::MemFlags;
7+
use rustc_codegen_ssa::common::TypeKind;
78
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
89
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
910
use rustc_codegen_ssa::traits::*;
@@ -308,7 +309,12 @@ impl<'ll, 'tcx> ArgAbiBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
308309
}
309310

310311
pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
311-
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
312+
fn llvm_type(
313+
&self,
314+
cx: &CodegenCx<'ll, 'tcx>,
315+
name: &[u8],
316+
is_llvm_intrinsic: bool,
317+
) -> &'ll Type;
312318
fn ptr_to_llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
313319
fn llvm_cconv(&self, cx: &CodegenCx<'ll, 'tcx>) -> llvm::CallConv;
314320

@@ -325,26 +331,45 @@ pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
325331
}
326332

327333
impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
328-
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {
334+
fn llvm_type(
335+
&self,
336+
cx: &CodegenCx<'ll, 'tcx>,
337+
name: &[u8],
338+
is_llvm_intrinsic: bool,
339+
) -> &'ll Type {
329340
// Ignore "extra" args from the call site for C variadic functions.
330341
// Only the "fixed" args are part of the LLVM function signature.
331342
let args =
332343
if self.c_variadic { &self.args[..self.fixed_count as usize] } else { &self.args };
333344

345+
let amx_intrinsic =
346+
is_llvm_intrinsic && name.starts_with(b"llvm.x86.") && name.ends_with(b".internal");
347+
let adjust_ty = |ty| {
348+
// Change type to `x86amx` from `i32x256` for x86_64 AMX intrinsics
349+
if amx_intrinsic && cx.type_kind(ty) == TypeKind::Vector && cx.vector_length(ty) == 256
350+
{
351+
let element_ty = cx.element_type(ty);
352+
if cx.type_kind(element_ty) == TypeKind::Integer && cx.int_width(element_ty) == 32 {
353+
return cx.type_x86amx();
354+
}
355+
}
356+
ty
357+
};
358+
334359
// This capacity calculation is approximate.
335360
let mut llargument_tys = Vec::with_capacity(
336361
self.args.len() + if let PassMode::Indirect { .. } = self.ret.mode { 1 } else { 0 },
337362
);
338363

339-
let llreturn_ty = match &self.ret.mode {
364+
let llreturn_ty = adjust_ty(match &self.ret.mode {
340365
PassMode::Ignore => cx.type_void(),
341366
PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.immediate_llvm_type(cx),
342367
PassMode::Cast { cast, pad_i32: _ } => cast.llvm_type(cx),
343368
PassMode::Indirect { .. } => {
344369
llargument_tys.push(cx.type_ptr());
345370
cx.type_void()
346371
}
347-
};
372+
});
348373

349374
for arg in args {
350375
// Note that the exact number of arguments pushed here is carefully synchronized with
@@ -388,7 +413,7 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
388413
cast.llvm_type(cx)
389414
}
390415
};
391-
llargument_tys.push(llarg_ty);
416+
llargument_tys.push(adjust_ty(llarg_ty));
392417
}
393418

394419
if self.c_variadic {

compiler/rustc_codegen_llvm/src/builder.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -1435,7 +1435,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
14351435
if let Some(fn_abi) = fn_abi {
14361436
fn_abi.apply_attrs_callsite(self, call);
14371437
}
1438-
call
1438+
1439+
if self.cx.type_kind(self.cx.val_ty(call)) == TypeKind::X86_AMX {
1440+
self.bitcast(call, self.cx.type_vector(self.cx.type_i32(), 256))
1441+
} else {
1442+
call
1443+
}
14391444
}
14401445

14411446
fn zext(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {

compiler/rustc_codegen_llvm/src/callee.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
//! and methods are represented as just a fn ptr and not a full
55
//! closure.
66
7-
use rustc_codegen_ssa::common;
7+
use rustc_codegen_ssa::{base, common};
88
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv};
99
use rustc_middle::ty::{self, Instance, TypeVisitableExt};
1010
use tracing::debug;
@@ -36,6 +36,8 @@ pub(crate) fn get_fn<'ll, 'tcx>(cx: &CodegenCx<'ll, 'tcx>, instance: Instance<'t
3636
llfn
3737
} else {
3838
let instance_def_id = instance.def_id();
39+
let is_llvm_intrinsic = base::is_llvm_intrinsic(tcx, instance_def_id);
40+
3941
let llfn = if tcx.sess.target.arch == "x86"
4042
&& let Some(dllimport) = crate::common::get_dllimport(tcx, instance_def_id, sym)
4143
{
@@ -53,6 +55,7 @@ pub(crate) fn get_fn<'ll, 'tcx>(cx: &CodegenCx<'ll, 'tcx>, instance: Instance<'t
5355
),
5456
fn_abi,
5557
Some(instance),
58+
is_llvm_intrinsic,
5659
);
5760

5861
// Fix for https://github.com/rust-lang/rust/issues/104453
@@ -69,7 +72,7 @@ pub(crate) fn get_fn<'ll, 'tcx>(cx: &CodegenCx<'ll, 'tcx>, instance: Instance<'t
6972
llvm::set_dllimport_storage_class(llfn);
7073
llfn
7174
} else {
72-
cx.declare_fn(sym, fn_abi, Some(instance))
75+
cx.declare_fn(sym, fn_abi, Some(instance), is_llvm_intrinsic)
7376
};
7477
debug!("get_fn: not casting pointer!");
7578

compiler/rustc_codegen_llvm/src/consts.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,12 @@ fn check_and_apply_linkage<'ll, 'tcx>(
191191
let fn_sig = sig.with(*header);
192192

193193
let fn_abi = cx.fn_abi_of_fn_ptr(fn_sig, ty::List::empty());
194-
cx.declare_fn(sym, &fn_abi, None)
194+
cx.declare_fn(
195+
sym,
196+
&fn_abi,
197+
None,
198+
rustc_codegen_ssa::base::is_llvm_intrinsic(cx.tcx, def_id),
199+
)
195200
} else {
196201
cx.declare_global(sym, cx.type_i8())
197202
}

compiler/rustc_codegen_llvm/src/declare.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
147147
name: &str,
148148
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
149149
instance: Option<Instance<'tcx>>,
150+
is_llvm_intrinsic: bool,
150151
) -> &'ll Value {
151152
debug!("declare_rust_fn(name={:?}, fn_abi={:?})", name, fn_abi);
152153

@@ -158,7 +159,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
158159
fn_abi.llvm_cconv(self),
159160
llvm::UnnamedAddr::Global,
160161
llvm::Visibility::Default,
161-
fn_abi.llvm_type(self),
162+
fn_abi.llvm_type(self, name.as_ref(), is_llvm_intrinsic),
162163
);
163164
fn_abi.apply_attrs_llfn(self, llfn, instance);
164165

compiler/rustc_codegen_llvm/src/intrinsic.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -1088,10 +1088,11 @@ fn gen_fn<'a, 'll, 'tcx>(
10881088
name: &str,
10891089
rust_fn_sig: ty::PolyFnSig<'tcx>,
10901090
codegen: &mut dyn FnMut(Builder<'a, 'll, 'tcx>),
1091+
is_llvm_intrinsic: bool,
10911092
) -> (&'ll Type, &'ll Value) {
10921093
let fn_abi = cx.fn_abi_of_fn_ptr(rust_fn_sig, ty::List::empty());
1093-
let llty = fn_abi.llvm_type(cx);
1094-
let llfn = cx.declare_fn(name, fn_abi, None);
1094+
let llty = fn_abi.llvm_type(cx, name.as_ref(), is_llvm_intrinsic);
1095+
let llfn = cx.declare_fn(name, fn_abi, None, is_llvm_intrinsic);
10951096
cx.set_frame_pointer_type(llfn);
10961097
cx.apply_target_cpu_attr(llfn);
10971098
// FIXME(eddyb) find a nicer way to do this.
@@ -1147,7 +1148,7 @@ fn get_rust_try_fn<'a, 'll, 'tcx>(
11471148
hir::Safety::Unsafe,
11481149
ExternAbi::Rust,
11491150
));
1150-
let rust_try = gen_fn(cx, "__rust_try", rust_fn_sig, codegen);
1151+
let rust_try = gen_fn(cx, "__rust_try", rust_fn_sig, codegen, false);
11511152
cx.rust_try_fn.set(Some(rust_try));
11521153
rust_try
11531154
}

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

+4
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,9 @@ unsafe extern "C" {
10551055
pub(crate) fn LLVMPointerTypeInContext(C: &Context, AddressSpace: c_uint) -> &Type;
10561056
pub(crate) fn LLVMVectorType(ElementType: &Type, ElementCount: c_uint) -> &Type;
10571057

1058+
// Special X86 Type for AMX
1059+
pub(crate) fn LLVMX86AMXTypeInContext(C: &Context) -> &Type;
1060+
10581061
pub(crate) fn LLVMGetElementType(Ty: &Type) -> &Type;
10591062
pub(crate) fn LLVMGetVectorSize(VectorTy: &Type) -> c_uint;
10601063

@@ -1177,6 +1180,7 @@ unsafe extern "C" {
11771180

11781181
// Operations on functions
11791182
pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint);
1183+
pub(crate) fn LLVMGetIntrinsicID(Fn: &Value) -> c_uint;
11801184

11811185
// Operations on parameters
11821186
pub(crate) fn LLVMIsAArgument(Val: &Value) -> Option<&Value>;

compiler/rustc_codegen_llvm/src/mono_item.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,12 @@ impl<'tcx> PreDefineCodegenMethods<'tcx> for CodegenCx<'_, 'tcx> {
5353
assert!(!instance.args.has_infer());
5454

5555
let fn_abi = self.fn_abi_of_instance(instance, ty::List::empty());
56-
let lldecl = self.declare_fn(symbol_name, fn_abi, Some(instance));
56+
let lldecl = self.declare_fn(
57+
symbol_name,
58+
fn_abi,
59+
Some(instance),
60+
rustc_codegen_ssa::base::is_llvm_intrinsic(self.tcx, instance.def_id()),
61+
);
5762
llvm::set_linkage(lldecl, base::linkage_to_llvm(linkage));
5863
let attrs = self.tcx.codegen_fn_attrs(instance.def_id());
5964
base::set_link_section(lldecl, attrs);

compiler/rustc_codegen_llvm/src/type_.rs

+12-2
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
154154
)
155155
}
156156
}
157+
158+
pub(crate) fn type_x86amx(&self) -> &'ll Type {
159+
unsafe { llvm::LLVMX86AMXTypeInContext(self.llcx()) }
160+
}
157161
}
158162

159163
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
@@ -284,8 +288,14 @@ impl<'ll, 'tcx> LayoutTypeCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {
284288
fn cast_backend_type(&self, ty: &CastTarget) -> &'ll Type {
285289
ty.llvm_type(self)
286290
}
287-
fn fn_decl_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> &'ll Type {
288-
fn_abi.llvm_type(self)
291+
fn fn_decl_backend_type(
292+
&self,
293+
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
294+
fn_ptr: &'ll Value,
295+
) -> &'ll Type {
296+
let intrinsic_id = unsafe { llvm::LLVMGetIntrinsicID(fn_ptr) };
297+
// When the function is not an intrinsic, `Intrinsic::getIntrinsicID` returns `Intrinsic::not_intrinsic`, which is always defined to be 0
298+
fn_abi.llvm_type(self, llvm::get_value_name(fn_ptr), intrinsic_id != 0)
289299
}
290300
fn fn_ptr_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> &'ll Type {
291301
fn_abi.ptr_to_llvm_type(self)

compiler/rustc_codegen_ssa/src/base.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,14 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
914914
ongoing_codegen
915915
}
916916

917+
pub fn is_llvm_intrinsic(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
918+
if let Some(name) = tcx.codegen_fn_attrs(def_id).link_name {
919+
name.as_str().starts_with("llvm.")
920+
} else {
921+
false
922+
}
923+
}
924+
917925
/// Returns whether a call from the current crate to the [`Instance`] would produce a call
918926
/// from `compiler_builtins` to a symbol the linker must resolve.
919927
///
@@ -927,14 +935,6 @@ pub fn is_call_from_compiler_builtins_to_upstream_monomorphization<'tcx>(
927935
tcx: TyCtxt<'tcx>,
928936
instance: Instance<'tcx>,
929937
) -> bool {
930-
fn is_llvm_intrinsic(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
931-
if let Some(name) = tcx.codegen_fn_attrs(def_id).link_name {
932-
name.as_str().starts_with("llvm.")
933-
} else {
934-
false
935-
}
936-
}
937-
938938
let def_id = instance.def_id();
939939
!def_id.is_local()
940940
&& tcx.is_compiler_builtins(LOCAL_CRATE)

compiler/rustc_codegen_ssa/src/mir/block.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ impl<'a, 'tcx> TerminatorCodegenHelper<'tcx> {
187187

188188
// If there is a cleanup block and the function we're calling can unwind, then
189189
// do an invoke, otherwise do a call.
190-
let fn_ty = bx.fn_decl_backend_type(fn_abi);
190+
let fn_ty = bx.fn_decl_backend_type(fn_abi, fn_ptr);
191191

192192
let fn_attrs = if bx.tcx().def_kind(fx.instance.def_id()).has_codegen_attrs() {
193193
Some(bx.tcx().codegen_fn_attrs(fx.instance.def_id()))
@@ -1806,7 +1806,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
18061806
if is_call_from_compiler_builtins_to_upstream_monomorphization(bx.tcx(), instance) {
18071807
bx.abort();
18081808
} else {
1809-
let fn_ty = bx.fn_decl_backend_type(fn_abi);
1809+
let fn_ty = bx.fn_decl_backend_type(fn_abi, fn_ptr);
18101810

18111811
let llret = bx.call(fn_ty, None, Some(fn_abi), fn_ptr, &[], funclet.as_ref(), None);
18121812
bx.apply_attrs_to_cleanup_callsite(llret);

compiler/rustc_codegen_ssa/src/mir/rvalue.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
779779
};
780780
let fn_ptr = bx.get_fn_addr(instance);
781781
let fn_abi = bx.fn_abi_of_instance(instance, ty::List::empty());
782-
let fn_ty = bx.fn_decl_backend_type(fn_abi);
782+
let fn_ty = bx.fn_decl_backend_type(fn_abi, fn_ptr);
783783
let fn_attrs = if bx.tcx().def_kind(instance.def_id()).has_codegen_attrs() {
784784
Some(bx.tcx().codegen_fn_attrs(instance.def_id()))
785785
} else {

compiler/rustc_codegen_ssa/src/size_of_val.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ pub fn size_and_align_of_dst<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
6767
// Generate the call. Cannot use `do_call` since we don't have a MIR terminator so we
6868
// can't create a `TerminationCodegenHelper`. (But we are in good company, this code is
6969
// duplicated plenty of times.)
70-
let fn_ty = bx.fn_decl_backend_type(fn_abi);
70+
let fn_ty = bx.fn_decl_backend_type(fn_abi, llfn);
7171

7272
bx.call(
7373
fn_ty,

compiler/rustc_codegen_ssa/src/traits/type_.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@ pub trait LayoutTypeCodegenMethods<'tcx>: BackendTypes {
9696
/// such as when it's stack-allocated or when it's being loaded or stored.
9797
fn backend_type(&self, layout: TyAndLayout<'tcx>) -> Self::Type;
9898
fn cast_backend_type(&self, ty: &CastTarget) -> Self::Type;
99-
fn fn_decl_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> Self::Type;
99+
fn fn_decl_backend_type(
100+
&self,
101+
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
102+
fn_ptr: Self::Value,
103+
) -> Self::Type;
100104
fn fn_ptr_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> Self::Type;
101105
fn reg_backend_type(&self, ty: &Reg) -> Self::Type;
102106
/// The backend type used for a rust type when it's in an SSA register.

compiler/rustc_target/src/target_features.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ pub fn all_rust_features() -> impl Iterator<Item = (&'static str, Stability)> {
786786
// certain size to have their "proper" ABI on each architecture.
787787
// Note that they must be kept sorted by vector size.
788788
const X86_FEATURES_FOR_CORRECT_VECTOR_ABI: &'static [(u64, &'static str)] =
789-
&[(128, "sse"), (256, "avx"), (512, "avx512f")]; // FIXME: might need changes for AVX10.
789+
&[(128, "sse"), (256, "avx"), (512, "avx512f"), (8192, "amx-tile")];
790790
const AARCH64_FEATURES_FOR_CORRECT_VECTOR_ABI: &'static [(u64, &'static str)] = &[(128, "neon")];
791791

792792
// We might want to add "helium" too.

0 commit comments

Comments
 (0)