Skip to content

Handle align argument attribute #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 79 additions & 27 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,44 +95,86 @@ static void moveConstantAllocasToEntryBlock(
}
}

/// Tries to find and return the alignment of the pointer `value` by looking for
/// an alignment attribute on the op or function argument that defines it:
///   * an `LLVM::AllocaOp` contributes its alignment, if set;
///   * an `LLVM::AddressOfOp` contributes the alignment of the `LLVM::GlobalOp`
///     it refers to, if the global can be resolved and has one set;
///   * an entry-block argument of an `LLVM::LLVMFuncOp` contributes the align
///     argument attribute of the matching function argument, if present.
/// In every other case returns 1 (i.e., assume that no alignment is
/// guaranteed).
static unsigned getAlignmentOf(Value value) {
  if (Operation *definingOp = value.getDefiningOp()) {
    if (auto alloca = dyn_cast<LLVM::AllocaOp>(definingOp))
      return alloca.getAlignment().value_or(1);
    if (auto addressOf = dyn_cast<LLVM::AddressOfOp>(definingOp))
      if (auto global = SymbolTable::lookupNearestSymbolFrom<LLVM::GlobalOp>(
              definingOp, addressOf.getGlobalNameAttr()))
        return global.getAlignment().value_or(1);
    // We don't currently handle this operation; assume no alignment.
    return 1;
  }
  // Since there is no defining op, this is a block argument.
  auto blockArg = value.cast<BlockArgument>();
  Block *owner = blockArg.getOwner();
  // Only arguments of the entry block correspond to function arguments. An
  // argument of any other block inside the function body has an unrelated
  // argument number, so using it to index the function's argument attributes
  // would read the wrong attribute (or go out of range).
  if (!owner->isEntryBlock())
    return 1;
  if (auto func = dyn_cast<LLVM::LLVMFuncOp>(owner->getParentOp())) {
    // Use the alignment attribute set for this argument in the parent
    // function if it has been set.
    if (Attribute alignAttr = func.getArgAttr(
            blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName()))
      return cast<IntegerAttr>(alignAttr).getValue().getLimitedValue();
  }
  // We didn't find anything useful; assume no alignment.
  return 1;
}

/// Creates a stack slot for one `elementType` value aligned to
/// `targetAlignment`, emits a non-volatile memcpy of `elementTypeSize` bytes
/// from the byval pointer `argument` into it, and returns the new slot. All
/// ops are created at `loc` through `builder`.
static Value handleByValArgumentInit(OpBuilder &builder, Location loc,
                                     Value argument, Type elementType,
                                     unsigned elementTypeSize,
                                     unsigned targetAlignment) {
  auto i64Type = builder.getI64Type();

  // Stack slot holding exactly one element; it reuses the pointer type of the
  // incoming argument.
  Value numElements = builder.create<LLVM::ConstantOp>(
      loc, i64Type, builder.getI64IntegerAttr(1));
  Value stackSlot = builder.create<LLVM::AllocaOp>(
      loc, argument.getType(), elementType, numElements, targetAlignment);

  // Byte count and volatility flag for the copy, then the copy itself
  // (destination first, source second).
  Value numBytes = builder.create<LLVM::ConstantOp>(
      loc, i64Type, builder.getI64IntegerAttr(elementTypeSize));
  Value volatileFlag = builder.create<LLVM::ConstantOp>(
      loc, builder.getI1Type(), builder.getBoolAttr(false));
  builder.create<LLVM::MemcpyOp>(loc, stackSlot, argument, numBytes,
                                 volatileFlag);

  return stackSlot;
}

/// Handles a function argument marked with the byval attribute by introducing a
/// memcpy if necessary, either due to the pointee being writeable in the
/// callee, and/or due to an alignment mismatch. `requestedAlignment` specifies
/// the alignment set in the "align" argument attribute (or 1 if no align
/// attribute was set).
///
/// `callable` is unconditionally cast to `LLVM::LLVMFuncOp`; its memory
/// effects attribute and the data layout of its enclosing
/// `DataLayoutOpInterface` op drive the decision. `elementType` is the pointee
/// type used for the ABI alignment and size queries. Returns `argument` itself
/// when no copy is required, otherwise the value produced by
/// `handleByValArgumentInit`.
static Value handleByValArgument(OpBuilder &builder, Operation *callable,
Value argument,
NamedAttribute byValAttribute) {
Value argument, Type elementType,
unsigned requestedAlignment) {
auto func = cast<LLVM::LLVMFuncOp>(callable);
LLVM::MemoryEffectsAttr memoryEffects = func.getMemoryAttr();
// If there is no memory effects attribute, assume that the function is
// not read-only.
bool isReadOnly = memoryEffects &&
memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef &&
memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod;
if (isReadOnly)
// Check if there's an alignment mismatch requiring us to copy.
DataLayout dataLayout(callable->getParentOfType<DataLayoutOpInterface>());
unsigned minimumAlignment = dataLayout.getTypeABIAlignment(elementType);
// A read-only callee can reuse the caller's pointer when the requested
// alignment is no stricter than the ABI alignment of the pointee type, or
// when the argument is already known to be at least as aligned as requested.
if (isReadOnly && (requestedAlignment <= minimumAlignment ||
getAlignmentOf(argument) >= requestedAlignment))
return argument;
// Resolve the pointee type and its size.
auto ptrType = cast<LLVM::LLVMPointerType>(argument.getType());
Type elementType = cast<TypeAttr>(byValAttribute.getValue()).getValue();
unsigned int typeSize =
DataLayout(callable->getParentOfType<DataLayoutOpInterface>())
.getTypeSize(elementType);
// Allocate the new value on the stack.
Value one = builder.create<LLVM::ConstantOp>(
func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(1));
Value allocaOp =
builder.create<LLVM::AllocaOp>(func.getLoc(), ptrType, elementType, one);
// Copy the pointee to the newly allocated value.
Value copySize = builder.create<LLVM::ConstantOp>(
func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(typeSize));
Value isVolatile = builder.create<LLVM::ConstantOp>(
func.getLoc(), builder.getI1Type(), builder.getBoolAttr(false));
builder.create<LLVM::MemcpyOp>(func.getLoc(), allocaOp, argument, copySize,
isVolatile);
return allocaOp;
// The copy has to satisfy both the requested alignment and the ABI alignment
// of the pointee type, so the larger of the two is used for the new alloca.
unsigned targetAlignment = std::max(requestedAlignment, minimumAlignment);
return handleByValArgumentInit(builder, func.getLoc(), argument, elementType,
dataLayout.getTypeSize(elementType),
targetAlignment);
}

/// Returns true if the given argument or result attribute is supported by the
/// inliner, false otherwise.
static bool isArgOrResAttrSupported(NamedAttribute attr) {
if (attr.getName() == LLVM::LLVMDialect::getAlignAttrName())
return false;
if (attr.getName() == LLVM::LLVMDialect::getInAllocaAttrName())
return false;
if (attr.getName() == LLVM::LLVMDialect::getNoAliasAttrName())
Expand Down Expand Up @@ -289,9 +331,19 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
/// Prepares a call argument for inlining. Only arguments whose attribute
/// dictionary contains the byval attribute are processed: the pointee type is
/// read from that attribute, the requested alignment from the align attribute
/// (defaulting to 1 when absent), and both are forwarded to
/// `handleByValArgument`. Every other argument is returned unchanged.
Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
Value argument, Type targetType,
DictionaryAttr argumentAttrs) const final {
if (auto attr =
argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName()))
return handleByValArgument(builder, callable, argument, *attr);
if (std::optional<NamedAttribute> attr =
argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) {
// The byval attribute carries the pointee type as a TypeAttr.
Type elementType = cast<TypeAttr>(attr->getValue()).getValue();
// Without an explicit align attribute, no particular alignment is requested.
unsigned requestedAlignment = 1;
if (std::optional<NamedAttribute> alignAttr =
argumentAttrs.getNamed(LLVM::LLVMDialect::getAlignAttrName())) {
requestedAlignment = cast<IntegerAttr>(alignAttr->getValue())
.getValue()
.getLimitedValue();
}
return handleByValArgument(builder, callable, argument, elementType,
requestedAlignment);
}
return argument;
}

Expand Down
64 changes: 63 additions & 1 deletion mlir/test/Dialect/LLVMIR/inlining.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,68 @@ llvm.func @test_byval_write_only(%ptr : !llvm.ptr) {

// -----

// Callee with read-only memory effects whose byval argument requests an
// alignment of 16.
llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
llvm.return
}

// The caller receives two pointers: %unaligned has no align attribute, while
// %aligned is annotated with llvm.align = 16, matching the callee's request.
// CHECK-LABEL: llvm.func @test_byval_input_aligned
// CHECK-SAME: %[[UNALIGNED:[a-zA-Z0-9_]+]]: !llvm.ptr
// CHECK-SAME: %[[ALIGNED:[a-zA-Z0-9_]+]]: !llvm.ptr
llvm.func @test_byval_input_aligned(%unaligned : !llvm.ptr, %aligned : !llvm.ptr { llvm.align = 16 }) {
// Make sure only the unaligned input triggers a memcpy.
// The copy is expected to go into a new alloca with alignment 16.
// CHECK: %[[ALLOCA:.+]] = llvm.alloca %{{.+}} x i16 {alignment = 16
// CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[UNALIGNED]]
llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
// CHECK-NOT: memcpy
llvm.call @aligned_byval_arg(%aligned) : (!llvm.ptr) -> ()
llvm.return
}

// -----

// Callee with read-only memory effects whose byval argument requests an
// alignment of 16.
llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
llvm.return
}

// Both call arguments are produced by allocas in the caller; they differ only
// in the alignment set on the alloca (1 versus 16).
// CHECK-LABEL: llvm.func @test_byval_alloca
llvm.func @test_byval_alloca() {
// Make sure only the unaligned alloca triggers a memcpy.
%size = llvm.mlir.constant(1 : i64) : i64
// Here ALLOCA binds the caller's alloca with alignment 1, which must then
// show up as the second (source) operand of the memcpy.
// CHECK: %[[ALLOCA:.+]] = llvm.alloca {{.+}}alignment = 1
// CHECK: "llvm.intr.memcpy"(%{{.+}}, %[[ALLOCA]]
%unaligned = llvm.alloca %size x i16 { alignment = 1 } : (i64) -> !llvm.ptr
llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
// CHECK-NOT: memcpy
%aligned = llvm.alloca %size x i16 { alignment = 16 } : (i64) -> !llvm.ptr
llvm.call @aligned_byval_arg(%aligned) : (!llvm.ptr) -> ()
llvm.return
}

// -----

// @unaligned_global declares no alignment; @aligned_global declares an
// alignment of 64, which exceeds the 16 requested by the callee below.
llvm.mlir.global private @unaligned_global(42 : i64) : i64
llvm.mlir.global private @aligned_global(42 : i64) { alignment = 64 } : i64

// Callee with read-only memory effects whose byval argument requests an
// alignment of 16.
llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
llvm.return
}

// CHECK-LABEL: llvm.func @test_byval_global
llvm.func @test_byval_global() {
// Make sure only the unaligned global triggers a memcpy.
// Exactly one alloca (the copy of the unaligned global) is expected; no
// further alloca may follow for the aligned global.
// CHECK: %[[UNALIGNED:.+]] = llvm.mlir.addressof @unaligned_global
// CHECK: %[[ALLOCA:.+]] = llvm.alloca
// CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[UNALIGNED]]
// CHECK-NOT: llvm.alloca
%unaligned = llvm.mlir.addressof @unaligned_global : !llvm.ptr
llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
%aligned = llvm.mlir.addressof @aligned_global : !llvm.ptr
llvm.call @aligned_byval_arg(%aligned) : (!llvm.ptr) -> ()
llvm.return
}

// -----

llvm.func @ignored_attrs(%ptr : !llvm.ptr { llvm.inreg, llvm.nocapture, llvm.nofree, llvm.preallocated = i32, llvm.returned, llvm.alignstack = 32 : i64, llvm.writeonly, llvm.noundef, llvm.nonnull }, %x : i32 { llvm.zeroext }) -> (!llvm.ptr { llvm.noundef, llvm.inreg, llvm.nonnull }) {
llvm.return %ptr : !llvm.ptr
}
Expand All @@ -413,7 +475,7 @@ llvm.func @test_ignored_attrs(%ptr : !llvm.ptr, %x : i32) {

// -----

llvm.func @disallowed_arg_attr(%ptr : !llvm.ptr { llvm.align = 16 : i32 }) {
llvm.func @disallowed_arg_attr(%ptr : !llvm.ptr { llvm.noalias }) {
llvm.return
}

Expand Down