Skip to content

[CIR] Add side effect attribute to call operations #144201

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,22 +227,26 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
//===--------------------------------------------------------------------===//

cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee,
mlir::Type returnType, mlir::ValueRange operands) {
return create<cir::CallOp>(loc, callee, returnType, operands);
mlir::Type returnType, mlir::ValueRange operands,
cir::SideEffect sideEffect = cir::SideEffect::All) {
return create<cir::CallOp>(loc, callee, returnType, operands, sideEffect);
}

cir::CallOp createCallOp(mlir::Location loc, cir::FuncOp callee,
mlir::ValueRange operands) {
mlir::ValueRange operands,
cir::SideEffect sideEffect = cir::SideEffect::All) {
return createCallOp(loc, mlir::SymbolRefAttr::get(callee),
callee.getFunctionType().getReturnType(), operands);
callee.getFunctionType().getReturnType(), operands,
sideEffect);
}

cir::CallOp createIndirectCallOp(mlir::Location loc,
mlir::Value indirectTarget,
cir::FuncType funcType,
mlir::ValueRange operands) {
mlir::ValueRange operands,
cir::SideEffect sideEffect) {
return create<cir::CallOp>(loc, indirectTarget, funcType.getReturnType(),
operands);
operands, sideEffect);
}

//===--------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class CIR_TypedAttr<string name, string attrMnemonic, list<Trait> traits = []>
let assemblyFormat = [{}];
}

class CIR_I32EnumAttr<string name, string summary, list<I32EnumAttrCase> cases>
: I32EnumAttr<name, summary, cases> {
let cppNamespace = "::cir";
}

class CIRUnitAttr<string name, string attrMnemonic, list<Trait> traits = []>
: CIR_Attr<name, attrMnemonic, traits> {
let returnType = "bool";
Expand Down
43 changes: 40 additions & 3 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1858,6 +1858,36 @@ def FuncOp : CIR_Op<"func", [
// CallOp
//===----------------------------------------------------------------------===//

def CIR_SideEffect : CIR_I32EnumAttr<
"SideEffect", "allowed side effects of a function", [
I32EnumAttrCase<"All", 1, "all">,
I32EnumAttrCase<"Pure", 2, "pure">,
I32EnumAttrCase<"Const", 3, "const">
]> {
let description = [{
The side effect attribute specifies the possible side effects of the callee
of a call operation. This is an enumeration attribute and all possible
enumerators are:

- all: The callee can have any side effects. This is the default if no side
effects are explicitly listed.
- pure: The callee may read data from memory, but it cannot write data to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really want to model these using the GCC attribute names? I guess this is OK for now, but it seems like we might be better off with more general modeling that is consistent with the LLVM IR side effect handling (memory(read) = pure and memory(none) = const).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My take is that we should be closer to source than LLVM here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't necessarily mean to say I wanted to change this in order to be consistent with the LLVM IR representation. I just think that's a better, more general representation. To me pure and const are unclear, particularly given that both terms have conflicting meanings in C++.

Do we represent side-effects on LLVM intrinsic calls? Those would require a richer representation. It would also be nice to mark things like whether a function modifies errno or whether it depends on or modifies the floating-point environment.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, clarity is indeed key, didn't think from that perspective.

Do we represent side-effects on LLVM intrinsic calls?

If you mean whether LLVMIntrinsicCallOp does it, the answer is no but we should (likewise for the whole family of CIR operations that mostly cover intrinsics). We'll definitely need them for things like llvm/test/Verifier/fp-intrinsics-pass.ll, etc.

This is a good reminder so we pay close attention for when PRs for all those arrive, so we can make sure those operations will include a way to handle such attributes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the idea to extend this attribute beyond memory effects to model more stuff that could be called "side effects". errno could be a good extension, maybe we could even encode whether the callee throws into this attribute.

memory. This has the same effect as the GNU C/C++ attribute
`__attribute__((pure))`.
- const: The callee may not read or write data from memory. This has the
same effect as the GNU C/C++ attribute `__attribute__((const))`.

Examples:

```mlir
%2 = cir.call @add(%0, %1) : (!s32i, !s32i) -> !s32i
%2 = cir.call @add(%0, %1) : (!s32i, !s32i) -> !s32i side_effect(pure)
%2 = cir.call @add(%0, %1) : (!s32i, !s32i) -> !s32i side_effect(const)
```
}];
let cppNamespace = "::cir";
}

class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
: Op<CIR_Dialect, mnemonic,
!listconcat(extra_traits,
Expand Down Expand Up @@ -1911,7 +1941,8 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
// will add in the future.

dag commonArgs = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<CIR_AnyType>:$args);
Variadic<CIR_AnyType>:$args,
DefaultValuedAttr<CIR_SideEffect, "SideEffect::All">:$side_effect);
}

def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
Expand Down Expand Up @@ -1942,20 +1973,26 @@ def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
let builders = [
// Build a call op for a direct call
OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType,
"mlir::ValueRange":$operands), [{
"mlir::ValueRange":$operands,
CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
assert(callee && "callee attribute is required for direct call");
$_state.addOperands(operands);
$_state.addAttribute("callee", callee);
$_state.addAttribute("side_effect",
SideEffectAttr::get($_builder.getContext(), sideEffect));
if (resType && !isa<VoidType>(resType))
$_state.addTypes(resType);
}]>,
// Build a call op for an indirect call
OpBuilder<(ins "mlir::Value":$calleePtr, "mlir::Type":$resType,
"mlir::ValueRange":$operands), [{
"mlir::ValueRange":$operands,
CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
$_state.addOperands(calleePtr);
$_state.addOperands(operands);
if (resType && !isa<VoidType>(resType))
$_state.addTypes(resType);
$_state.addAttribute("side_effect",
SideEffectAttr::get($_builder.getContext(), sideEffect));
}]>,
];
}
Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/CIR/Interfaces/CIROpInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ let cppNamespace = "::cir" in {
"Return the number of operands, accounts for indirect call or "
"exception info",
"unsigned", "getNumArgOperands", (ins)>,
InterfaceMethod<"Return the side effects of the call operation",
"cir::SideEffect", "getSideEffect", (ins)>,
];
}

Expand Down
1 change: 0 additions & 1 deletion clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ struct MissingFeatures {
static bool opCallReturn() { return false; }
static bool opCallArgEvaluationOrder() { return false; }
static bool opCallCallConv() { return false; }
static bool opCallSideEffect() { return false; }
static bool opCallNoPrototypeFunc() { return false; }
static bool opCallMustTail() { return false; }
static bool opCallVirtual() { return false; }
Expand Down
46 changes: 39 additions & 7 deletions clang/lib/CIR/CodeGen/CIRGenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,35 @@ void CIRGenFunction::emitAggregateStore(mlir::Value value, Address dest) {
builder.createStore(*currSrcLoc, value, dest);
}

/// Construct the CIR attribute list of a function or call.
void CIRGenModule::constructAttributeList(CIRGenCalleeInfo calleeInfo,
cir::SideEffect &sideEffect) {
assert(!cir::MissingFeatures::opCallCallConv());
sideEffect = cir::SideEffect::All;

assert(!cir::MissingFeatures::opCallAttrs());

const Decl *targetDecl = calleeInfo.getCalleeDecl().getDecl();

if (targetDecl) {
assert(!cir::MissingFeatures::opCallAttrs());

// 'const', 'pure' and 'noalias' attributed functions are also nounwind.
if (targetDecl->hasAttr<ConstAttr>()) {
// gcc specifies that 'const' functions have greater restrictions than
// 'pure' functions, so they also cannot have infinite loops.
sideEffect = cir::SideEffect::Const;
} else if (targetDecl->hasAttr<PureAttr>()) {
// gcc specifies that 'pure' functions cannot have infinite loops.
sideEffect = cir::SideEffect::Pure;
}

assert(!cir::MissingFeatures::opCallAttrs());
}

assert(!cir::MissingFeatures::opCallAttrs());
}

/// Returns the canonical formal type of the given C++ method.
static CanQual<FunctionProtoType> getFormalType(const CXXMethodDecl *md) {
return md->getType()
Expand Down Expand Up @@ -386,7 +415,8 @@ static cir::CIRCallOpInterface
emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
cir::FuncType indirectFuncTy, mlir::Value indirectFuncVal,
cir::FuncOp directFuncOp,
const SmallVectorImpl<mlir::Value> &cirCallArgs) {
const SmallVectorImpl<mlir::Value> &cirCallArgs,
cir::SideEffect sideEffect) {
CIRGenBuilderTy &builder = cgf.getBuilder();

assert(!cir::MissingFeatures::opCallSurroundingTry());
Expand All @@ -397,11 +427,11 @@ emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
if (indirectFuncTy) {
// TODO(cir): Set calling convention for indirect calls.
assert(!cir::MissingFeatures::opCallCallConv());
return builder.createIndirectCallOp(callLoc, indirectFuncVal,
indirectFuncTy, cirCallArgs);
return builder.createIndirectCallOp(
callLoc, indirectFuncVal, indirectFuncTy, cirCallArgs, sideEffect);
}

return builder.createCallOp(callLoc, directFuncOp, cirCallArgs);
return builder.createCallOp(callLoc, directFuncOp, cirCallArgs, sideEffect);
}

const CIRGenFunctionInfo &
Expand Down Expand Up @@ -513,8 +543,9 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
funcName = calleeFuncOp.getName();

assert(!cir::MissingFeatures::opCallCallConv());
assert(!cir::MissingFeatures::opCallSideEffect());
assert(!cir::MissingFeatures::opCallAttrs());
cir::SideEffect sideEffect;
cgm.constructAttributeList(callee.getAbstractInfo(), sideEffect);

assert(!cir::MissingFeatures::invokeOp());

Expand All @@ -538,8 +569,9 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
assert(!cir::MissingFeatures::opCallAttrs());

mlir::Location callLoc = loc;
cir::CIRCallOpInterface theCall = emitCallLikeOp(
*this, loc, indirectFuncTy, indirectFuncVal, directFuncOp, cirCallArgs);
cir::CIRCallOpInterface theCall =
emitCallLikeOp(*this, loc, indirectFuncTy, indirectFuncVal, directFuncOp,
cirCallArgs, sideEffect);

if (callOp)
*callOp = theCall;
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ class CIRGenCallee {
/// callee
CIRGenCallee prepareConcreteCallee(CIRGenFunction &cgf) const;

CIRGenCalleeInfo getAbstractInfo() const {
assert(!cir::MissingFeatures::opCallVirtual());
assert(isOrdinary());
return abstractInfo;
}

mlir::Operation *getFunctionPointer() const {
assert(isOrdinary());
return reinterpret_cast<mlir::Operation *>(kindOrFunctionPtr);
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define LLVM_CLANG_LIB_CIR_CODEGEN_CIRGENMODULE_H

#include "CIRGenBuilder.h"
#include "CIRGenCall.h"
#include "CIRGenTypeCache.h"
#include "CIRGenTypes.h"
#include "CIRGenValue.h"
Expand Down Expand Up @@ -158,6 +159,15 @@ class CIRGenModule : public CIRGenTypeCache {
const CXXRecordDecl *derivedClass,
llvm::iterator_range<CastExpr::path_const_iterator> path);

/// Get the CIR attributes and calling convention to use for a particular
/// function type.
///
/// \param calleeInfo - The callee information these attributes are being
/// constructed for. If valid, the attributes applied to this decl may
/// contribute to the function attributes and calling convention.
void constructAttributeList(CIRGenCalleeInfo calleeInfo,
cir::SideEffect &sideEffect);

/// Return a constant array for the given string.
mlir::Attribute getConstantArrayFromStringLiteral(const StringLiteral *e);

Expand Down
66 changes: 63 additions & 3 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,46 @@ Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
// Helpers
//===----------------------------------------------------------------------===//

// Parses one of the keywords provided in the list `keywords` and returns the
// position of the parsed keyword in the list. If none of the keywords from the
// list is parsed, returns -1.
static int parseOptionalKeywordAlternative(AsmParser &parser,
ArrayRef<llvm::StringRef> keywords) {
for (auto en : llvm::enumerate(keywords)) {
if (succeeded(parser.parseOptionalKeyword(en.value())))
return en.index();
}
return -1;
}

namespace {
template <typename Ty> struct EnumTraits {};

#define REGISTER_ENUM_TYPE(Ty) \
template <> struct EnumTraits<cir::Ty> { \
static llvm::StringRef stringify(cir::Ty value) { \
return stringify##Ty(value); \
} \
static unsigned getMaxEnumVal() { return cir::getMaxEnumValFor##Ty(); } \
}

REGISTER_ENUM_TYPE(SideEffect);
} // namespace

/// Parse an enum from the keyword, return failure if the keyword is not found.
template <typename EnumTy, typename RetTy = EnumTy>
static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) {
llvm::SmallVector<llvm::StringRef, 10> names;
for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));

int index = parseOptionalKeywordAlternative(parser, names);
if (index == -1)
return failure();
result = static_cast<RetTy>(index);
return success();
}

// Check if a region's termination omission is valid and, if so, creates and
// inserts the omitted terminator into the region.
static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
Expand Down Expand Up @@ -534,6 +574,18 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
if (parser.parseRParen())
return mlir::failure();

if (parser.parseOptionalKeyword("side_effect").succeeded()) {
if (parser.parseLParen().failed())
return failure();
cir::SideEffect sideEffect;
if (parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed())
return failure();
if (parser.parseRParen().failed())
return failure();
auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
result.addAttribute("side_effect", attr);
}

if (parser.parseOptionalAttrDict(result.attributes))
return ::mlir::failure();

Expand All @@ -556,7 +608,8 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
static void printCallCommon(mlir::Operation *op,
mlir::FlatSymbolRefAttr calleeSym,
mlir::Value indirectCallee,
mlir::OpAsmPrinter &printer) {
mlir::OpAsmPrinter &printer,
cir::SideEffect sideEffect) {
printer << ' ';

auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
Expand All @@ -572,7 +625,13 @@ static void printCallCommon(mlir::Operation *op,
}
printer << "(" << ops << ")";

printer.printOptionalAttrDict(op->getAttrs(), {"callee"});
if (sideEffect != cir::SideEffect::All) {
printer << " side_effect(";
printer << stringifySideEffect(sideEffect);
printer << ")";
}

printer.printOptionalAttrDict(op->getAttrs(), {"callee", "side_effect"});

printer << " : ";
printer.printFunctionalType(op->getOperands().getTypes(),
Expand All @@ -586,7 +645,8 @@ mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,

void cir::CallOp::print(mlir::OpAsmPrinter &p) {
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
printCallCommon(*this, getCalleeAttr(), indirectCallee, p);
cir::SideEffect sideEffect = getSideEffect();
printCallCommon(*this, getCalleeAttr(), indirectCallee, p, sideEffect);
}

static LogicalResult
Expand Down
Loading