Skip to content

Commit ed114b6

Browse files
definelicht authored and gysit committed
[MLIR][LLVM] Copy byval attributes during inlining.
Support inlining of function calls with the byval attribute on function arguments by copying the pointee into a newly alloca'ed pointer at the callsite before inlining. The alignment attribute is not yet taken into account.

Reviewed By: ftynse, gysit

Differential Revision: https://reviews.llvm.org/D146616
1 parent a2033ff commit ed114b6

File tree

2 files changed

+99
-16
lines changed

2 files changed

+99
-16
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1414
#include "TypeDetail.h"
15+
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
1516
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
1617
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1718
#include "mlir/IR/Builders.h"
@@ -2854,6 +2855,39 @@ static void moveConstantAllocasToEntryBlock(
28542855
}
28552856
}
28562857

2858+
/// Prepares the pointer `argument` passed to a `byval` parameter of `callable`
/// so that the call can be inlined.
///
/// `callable` must be an `LLVM::LLVMFuncOp` and `byValAttribute` must carry a
/// `TypeAttr` naming the pointee type; both are accessed with `cast<>`.
///
/// If the function has a memory effects attribute whose argument-memory
/// component is neither `Mod` nor `ModRef`, `argument` is returned unchanged.
/// Otherwise a single element of the pointee type is allocated with
/// `llvm.alloca`, the pointee is copied into it with a non-volatile
/// `llvm.intr.memcpy`, and the new allocation is returned. All created
/// operations are placed at the current insertion point of `builder` and carry
/// the location of the function.
///
/// The alignment attribute is not yet taken into account.
static Value handleByValArgument(OpBuilder &builder, Operation *callable,
                                 Value argument,
                                 NamedAttribute byValAttribute) {
  auto func = cast<LLVM::LLVMFuncOp>(callable);
  LLVM::MemoryEffectsAttr memoryEffects = func.getMemoryAttr();
  // If there is no memory effects attribute, assume that the function is
  // not read-only.
  bool isReadOnly = memoryEffects &&
                    memoryEffects.getArgMem() != ModRefInfo::ModRef &&
                    memoryEffects.getArgMem() != ModRefInfo::Mod;
  if (isReadOnly)
    return argument;
  // Resolve the pointee type and its size.
  auto ptrType = cast<LLVM::LLVMPointerType>(argument.getType());
  Type elementType = cast<TypeAttr>(byValAttribute.getValue()).getValue();
  // Use `DataLayout::closest` instead of constructing a layout from
  // `getParentOfType<DataLayoutOpInterface>()`: the latter yields a null
  // interface when no enclosing operation implements it, which the
  // `DataLayout` constructor is not prepared for.
  // NOTE(review): `closest` is expected to fall back to the default layout in
  // that case -- confirm against the checked-out MLIR revision.
  unsigned int typeSize =
      DataLayout::closest(callable).getTypeSize(elementType);
  // Allocate the new value on the stack.
  Value one = builder.create<LLVM::ConstantOp>(
      func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(1));
  Value allocaOp =
      builder.create<LLVM::AllocaOp>(func.getLoc(), ptrType, elementType, one);
  // Copy the pointee to the newly allocated value.
  Value copySize = builder.create<LLVM::ConstantOp>(
      func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(typeSize));
  Value isVolatile = builder.create<LLVM::ConstantOp>(
      func.getLoc(), builder.getI1Type(), builder.getBoolAttr(false));
  builder.create<LLVM::MemcpyOp>(func.getLoc(), allocaOp, argument, copySize,
                                 isVolatile);
  return allocaOp;
}
2890+
28572891
namespace {
28582892
struct LLVMInlinerInterface : public DialectInlinerInterface {
28592893
using DialectInlinerInterface::DialectInlinerInterface;
@@ -2866,8 +2900,19 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
28662900
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
28672901
if (!callOp || !funcOp)
28682902
return false;
2869-
// TODO: Handle argument and result attributes;
2870-
if (funcOp.getArgAttrs() || funcOp.getResAttrs())
2903+
if (auto attrs = funcOp.getArgAttrs()) {
2904+
for (Attribute attr : *attrs) {
2905+
auto attrDict = cast<DictionaryAttr>(attr);
2906+
for (NamedAttribute attr : attrDict) {
2907+
if (attr.getName() == LLVMDialect::getByValAttrName())
2908+
continue;
2909+
// TODO: Handle all argument attributes;
2910+
return false;
2911+
}
2912+
}
2913+
}
2914+
// TODO: Handle result attributes;
2915+
if (funcOp.getResAttrs())
28712916
return false;
28722917
// TODO: Handle exceptions.
28732918
if (funcOp.getPersonality())
@@ -2942,6 +2987,14 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
29422987
dst.replaceAllUsesWith(src);
29432988
}
29442989

2990+
/// Argument hook of the inliner interface for the LLVM dialect.
///
/// When `argumentAttrs` contains the dialect's `byval` attribute, the argument
/// is forwarded to `handleByValArgument` and its result is returned. Any other
/// argument is returned untouched. `call` and `targetType` are not used.
Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
                     Value argument, Type targetType,
                     DictionaryAttr argumentAttrs) const final {
  auto byValAttr = argumentAttrs.getNamed(LLVMDialect::getByValAttrName());
  // Nothing to do for arguments that are not passed by value.
  if (!byValAttr)
    return argument;
  return handleByValArgument(builder, callable, argument, *byValAttr);
}
2997+
29452998
void processInlinedCallBlocks(
29462999
Operation *call,
29473000
iterator_range<Region::iterator> inlinedBlocks) const override {

mlir/test/Dialect/LLVMIR/inlining.mlir

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -187,20 +187,6 @@ llvm.func @caller() {
187187

188188
// -----
189189

190-
llvm.func @callee(%ptr : !llvm.ptr {llvm.byval = !llvm.ptr}) -> (!llvm.ptr) {
191-
llvm.return %ptr : !llvm.ptr
192-
}
193-
194-
// CHECK-LABEL: llvm.func @caller
195-
// CHECK-NEXT: llvm.call @callee
196-
// CHECK-NEXT: return
197-
llvm.func @caller(%ptr : !llvm.ptr) -> (!llvm.ptr) {
198-
%0 = llvm.call @callee(%ptr) : (!llvm.ptr) -> (!llvm.ptr)
199-
llvm.return %0 : !llvm.ptr
200-
}
201-
202-
// -----
203-
204190
llvm.func @static_alloca() -> f32 {
205191
%0 = llvm.mlir.constant(4 : i32) : i32
206192
%1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
@@ -349,3 +335,47 @@ llvm.func @test_inline(%cond0 : i1, %cond1 : i1, %funcArg : f32) -> f32 {
349335
^bb3(%blockArg: f32):
350336
llvm.return %blockArg : f32
351337
}
338+
339+
// -----
340+
341+
// Test case: the callee takes a pointer marked `llvm.byval = f64` and has no
// memory effects attribute. The FileCheck directives below expect @test_byval
// to contain an alloca of f64 followed by a memcpy from the caller's pointer
// argument into that alloca.
llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) {
342+
llvm.return
343+
}
344+
345+
// CHECK-LABEL: llvm.func @test_byval
346+
// CHECK-SAME: %[[PTR:[a-zA-Z0-9_]+]]: !llvm.ptr
347+
// CHECK: %[[ALLOCA:.+]] = llvm.alloca %{{.+}} x f64
348+
// CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[PTR]]
349+
llvm.func @test_byval(%ptr : !llvm.ptr) {
350+
llvm.call @with_byval_arg(%ptr) : (!llvm.ptr) -> ()
351+
llvm.return
352+
}
353+
354+
// -----
355+
356+
// Test case: the callee takes a `llvm.byval = f64` pointer and declares
// `argMem = read` in its memory effects. The FileCheck directives below expect
// that no llvm.call remains in @test_byval_read_only and that llvm.return is
// on the line right after the function label, i.e. no other operations appear
// before it.
llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite>} {
357+
llvm.return
358+
}
359+
360+
// CHECK-LABEL: llvm.func @test_byval_read_only
361+
// CHECK-NOT: llvm.call
362+
// CHECK-NEXT: llvm.return
363+
llvm.func @test_byval_read_only(%ptr : !llvm.ptr) {
364+
llvm.call @with_byval_arg(%ptr) : (!llvm.ptr) -> ()
365+
llvm.return
366+
}
367+
368+
// -----
369+
370+
// Test case: the callee takes a `llvm.byval = f64` pointer and declares
// `argMem = write` in its memory effects. The FileCheck directives below
// expect @test_byval_write_only to contain an alloca of f64 followed by a
// memcpy from the caller's pointer argument into that alloca.
llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory = #llvm.memory_effects<other = readwrite, argMem = write, inaccessibleMem = readwrite>} {
371+
llvm.return
372+
}
373+
374+
// CHECK-LABEL: llvm.func @test_byval_write_only
375+
// CHECK-SAME: %[[PTR:[a-zA-Z0-9_]+]]: !llvm.ptr
376+
// CHECK: %[[ALLOCA:.+]] = llvm.alloca %{{.+}} x f64
377+
// CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[PTR]]
378+
llvm.func @test_byval_write_only(%ptr : !llvm.ptr) {
379+
llvm.call @with_byval_arg(%ptr) : (!llvm.ptr) -> ()
380+
llvm.return
381+
}

0 commit comments

Comments
 (0)