Skip to content

Commit 78af498

Browse files
authored
[LLVM][IR] Support target extension types in vectors (#140630)
This change is to support target extension types in vectors. The change allows sized target extension types to opt-in to being a valid vector element. Allowing target extension types as vector elements will allow backends to use vector operations such as `insertelement` and `extractelement` on their target types with minimal changes. RFC: https://discourse.llvm.org/t/rfc-supporting-sized-target-extension-types-in-vector/86431
1 parent 7433d8c commit 78af498

File tree

5 files changed

+53
-6
lines changed

5 files changed

+53
-6
lines changed

llvm/docs/LangRef.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4438,7 +4438,8 @@ the type size is smaller than the type's store size.
44384438
< vscale x <# elements> x <elementtype> > ; Scalable vector
44394439

44404440
The number of elements is a constant integer value larger than 0;
4441-
elementtype may be any integer, floating-point or pointer type. Vectors
4441+
elementtype may be any integer, floating-point, pointer type, or a sized
4442+
target extension type that has the ``CanBeVectorElement`` property. Vectors
44424443
of size zero are not allowed. For scalable vectors, the total number of
44434444
elements is a constant multiple (called vscale) of the specified number
44444445
of elements; vscale is a positive integer that is unknown at compile time

llvm/include/llvm/IR/DerivedTypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,8 @@ class TargetExtType : public Type {
845845
/// This type may be allocated on the stack, either as the allocated type
846846
/// of an alloca instruction or as a byval function parameter.
847847
CanBeLocal = 1U << 2,
848+
// This type may be used as an element in a vector.
849+
CanBeVectorElement = 1U << 3,
848850
};
849851

850852
/// Returns true if the target extension type contains the given property.

llvm/lib/IR/Type.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -790,8 +790,12 @@ VectorType *VectorType::get(Type *ElementType, ElementCount EC) {
790790
}
791791

792792
bool VectorType::isValidElementType(Type *ElemTy) {
793-
return ElemTy->isIntegerTy() || ElemTy->isFloatingPointTy() ||
794-
ElemTy->isPointerTy() || ElemTy->getTypeID() == TypedPointerTyID;
793+
if (ElemTy->isIntegerTy() || ElemTy->isFloatingPointTy() ||
794+
ElemTy->isPointerTy() || ElemTy->getTypeID() == TypedPointerTyID)
795+
return true;
796+
if (auto *TTy = dyn_cast<TargetExtType>(ElemTy))
797+
return TTy->hasProperty(TargetExtType::CanBeVectorElement);
798+
return false;
795799
}
796800

797801
//===----------------------------------------------------------------------===//
@@ -801,8 +805,9 @@ bool VectorType::isValidElementType(Type *ElemTy) {
801805
FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) {
802806
assert(NumElts > 0 && "#Elements of a VectorType must be greater than 0");
803807
assert(isValidElementType(ElementType) && "Element type of a VectorType must "
804-
"be an integer, floating point, or "
805-
"pointer type.");
808+
"be an integer, floating point, "
809+
"pointer type, or a valid target "
810+
"extension type.");
806811

807812
auto EC = ElementCount::getFixed(NumElts);
808813

@@ -968,7 +973,11 @@ struct TargetTypeInfo {
968973

969974
template <typename... ArgTys>
970975
TargetTypeInfo(Type *LayoutType, ArgTys... Properties)
971-
: LayoutType(LayoutType), Properties((0 | ... | Properties)) {}
976+
: LayoutType(LayoutType), Properties((0 | ... | Properties)) {
977+
assert((!(this->Properties & TargetExtType::CanBeVectorElement) ||
978+
LayoutType->isSized()) &&
979+
"Vector element type must be sized");
980+
}
972981
};
973982
} // anonymous namespace
974983

@@ -1037,6 +1046,13 @@ static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) {
10371046
TargetExtType::CanBeGlobal);
10381047
}
10391048

1049+
// Type used to test vector element target extension property.
1050+
// Can be removed once a public target extension type uses CanBeVectorElement.
1051+
if (Name == "llvm.test.vectorelement") {
1052+
return TargetTypeInfo(Type::getInt32Ty(C), TargetExtType::CanBeLocal,
1053+
TargetExtType::CanBeVectorElement);
1054+
}
1055+
10401056
return TargetTypeInfo(Type::getVoidTy(C));
10411057
}
10421058

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
; RUN: not llvm-as %s -o /dev/null 2>&1 | FileCheck %s
2+
3+
; CHECK: invalid vector element type
4+
5+
define void @bad() {
6+
%v = alloca <2 x target("spirv.Image")>
7+
ret void
8+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -passes=verify -S %s | FileCheck %s
3+
4+
define <2 x target("llvm.test.vectorelement")> @vec_ops(<2 x target("llvm.test.vectorelement")> %x) {
5+
; CHECK-LABEL: define <2 x target("llvm.test.vectorelement")> @vec_ops(
6+
; CHECK-SAME: <2 x target("llvm.test.vectorelement")> [[X:%.*]]) {
7+
; CHECK-NEXT: [[A:%.*]] = alloca <2 x target("llvm.test.vectorelement")>{{.*}}
8+
; CHECK-NEXT: store <2 x target("llvm.test.vectorelement")> [[X]], ptr [[A]], {{.*}}
9+
; CHECK-NEXT: [[LOAD:%.*]] = load <2 x target("llvm.test.vectorelement")>, ptr [[A]], {{.*}}
10+
; CHECK-NEXT: [[ELT:%.*]] = extractelement <2 x target("llvm.test.vectorelement")> [[LOAD]], i64 0
11+
; CHECK-NEXT: [[RES:%.*]] = insertelement <2 x target("llvm.test.vectorelement")> poison, target("llvm.test.vectorelement") [[ELT]], i64 1
12+
; CHECK-NEXT: ret <2 x target("llvm.test.vectorelement")> [[RES]]
13+
;
14+
%a = alloca <2 x target("llvm.test.vectorelement")>
15+
store <2 x target("llvm.test.vectorelement")> %x, ptr %a
16+
%load = load <2 x target("llvm.test.vectorelement")>, ptr %a
17+
%elt = extractelement <2 x target("llvm.test.vectorelement")> %load, i64 0
18+
%res = insertelement <2 x target("llvm.test.vectorelement")> poison, target("llvm.test.vectorelement") %elt, i64 1
19+
ret <2 x target("llvm.test.vectorelement")> %res
20+
}

0 commit comments

Comments
 (0)