diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp index 125c9fe99f73..aeff5ea1c47a 100644 --- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp @@ -524,9 +524,9 @@ uint64_t RecordType::getElementOffset(const ::mlir::DataLayout &dataLayout, assert(idx < getNumElements()); auto members = getMembers(); - unsigned offset = 0; + unsigned offset = 0, recordSize = 0; - for (unsigned i = 0, e = idx; i != e; ++i) { + for (unsigned i = 0, e = idx; i != e + 1; ++i) { auto ty = members[i]; // This matches LLVM since it uses the ABI instead of preferred alignment. @@ -534,10 +534,13 @@ uint64_t RecordType::getElementOffset(const ::mlir::DataLayout &dataLayout, llvm::Align(getPacked() ? 1 : dataLayout.getTypeABIAlignment(ty)); // Add padding if necessary to align the data element properly. - offset = llvm::alignTo(offset, tyAlign); + recordSize = llvm::alignTo(recordSize, tyAlign); + + if (i == idx) + offset = recordSize; // Consume space for this data item - offset += dataLayout.getTypeSize(ty); + recordSize += dataLayout.getTypeSize(ty); } return offset; diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index 326152783980..f1a100ab85f6 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "LowerToMLIRHelpers.h" +#include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -32,8 +33,10 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Region.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" @@ -163,17 +166,17 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern { matchAndRewrite(cir::AllocaOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Type mlirType = - convertTypeForMemory(*getTypeConverter(), adaptor.getAllocaType()); + mlir::Type allocaType = adaptor.getAllocaType(); + mlir::Type mlirType = convertTypeForMemory(*getTypeConverter(), allocaType); // FIXME: Some types can not be converted yet (e.g. struct) if (!mlirType) return mlir::LogicalResult::failure(); auto memreftype = mlir::dyn_cast(mlirType); - if (memreftype && mlir::isa(adaptor.getAllocaType())) { - // if the type is an array, - // we don't need to wrap with memref. + if (memreftype && (mlir::isa(allocaType) || + mlir::isa(allocaType))) { + // Arrays and structs are already memref. No need to wrap another one. } else { memreftype = mlir::MemRefType::get({}, mlirType); } @@ -1240,6 +1243,36 @@ class CIRPtrStrideOpLowering } }; +class CIRGetMemberOpLowering + : public mlir::OpConversionPattern { +public: + CIRGetMemberOpLowering(mlir::TypeConverter &converter, mlir::MLIRContext *ctx, + const mlir::DataLayout &layout) + : OpConversionPattern(converter, ctx), layout(layout) {} + + mlir::LogicalResult + matchAndRewrite(cir::GetMemberOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto baseAddr = op.getAddr(); + auto structType = + mlir::cast(baseAddr.getType().getPointee()); + uint64_t byteOffset = structType.getElementOffset(layout, op.getIndex()); + + auto fieldType = op.getResult().getType(); + auto resultType = mlir::cast( + getTypeConverter()->convertType(fieldType)); + + mlir::Value offsetValue = + rewriter.create(op.getLoc(), byteOffset); + rewriter.replaceOpWithNewOp( + op, resultType, adaptor.getAddr(), offsetValue, mlir::ValueRange{}); + return mlir::success(); + } + +private: + const mlir::DataLayout &layout; +}; + class CIRUnreachableOpLowering : public mlir::OpConversionPattern { public: @@ -1271,7 +1304,8 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern { }; void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, - mlir::TypeConverter &converter) { + mlir::TypeConverter &converter, + mlir::DataLayout layout) { patterns.add(patterns.getContext()); patterns @@ -1292,16 +1326,20 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering, CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering, CIRTrapOpLowering>(converter, patterns.getContext()); + + patterns.add(converter, patterns.getContext(), + layout); } -static mlir::TypeConverter prepareTypeConverter() { +static mlir::TypeConverter prepareTypeConverter(mlir::DataLayout layout) { mlir::TypeConverter converter; converter.addConversion([&](cir::PointerType type) -> mlir::Type { - auto ty = convertTypeForMemory(converter, type.getPointee()); + auto pointee = type.getPointee(); + auto ty = convertTypeForMemory(converter, pointee); // FIXME: The pointee type might not be converted (e.g. struct) if (!ty) return nullptr; - if (isa(type.getPointee())) + if (isa(pointee) || isa(pointee)) return ty; return mlir::MemRefType::get({}, ty); }); @@ -1353,6 +1391,13 @@ static mlir::TypeConverter prepareTypeConverter() { return nullptr; return mlir::MemRefType::get(shape, elementType); }); + converter.addConversion([&](cir::RecordType type) -> mlir::Type { + // Reinterpret structs as raw bytes. Don't use tuples as they can't be put + // in memref. + auto size = type.getTypeSize(layout, {}); + auto i8 = mlir::IntegerType::get(type.getContext(), /*width=*/8); + return mlir::MemRefType::get(size.getFixedValue(), i8); + }); converter.addConversion([&](cir::VectorType type) -> mlir::Type { auto ty = converter.convertType(type.getEltType()); return mlir::VectorType::get(type.getSize(), ty); @@ -1363,13 +1408,15 @@ static mlir::TypeConverter prepareTypeConverter() { void ConvertCIRToMLIRPass::runOnOperation() { auto module = getOperation(); + mlir::DataLayoutAnalysis layoutAnalysis(module); + const mlir::DataLayout &layout = layoutAnalysis.getAtOrAbove(module); - auto converter = prepareTypeConverter(); + auto converter = prepareTypeConverter(layout); mlir::RewritePatternSet patterns(&getContext()); populateCIRLoopToSCFConversionPatterns(patterns, converter); - populateCIRToMLIRConversionPatterns(patterns, converter); + populateCIRToMLIRConversionPatterns(patterns, converter, layout); mlir::ConversionTarget target(getContext()); target.addLegalOp(); diff --git a/clang/test/CIR/Lowering/ThroughMLIR/struct.cir b/clang/test/CIR/Lowering/ThroughMLIR/struct.cir new file mode 100644 index 000000000000..dd599753cfc6 --- /dev/null +++ b/clang/test/CIR/Lowering/ThroughMLIR/struct.cir @@ -0,0 +1,25 @@ +// RUN: cir-opt %s -cir-to-mlir -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s + +!s32i = !cir.int +!u8i = !cir.int +!u32i = !cir.int +!ty_S = !cir.record + +module { + cir.func @test() { + %1 = cir.alloca !ty_S, !cir.ptr, ["x"] {alignment = 4 : i64} + %3 = cir.get_member %1[0] {name = "c"} : !cir.ptr -> !cir.ptr + %5 = cir.get_member %1[1] {name = "i"} : !cir.ptr -> !cir.ptr + cir.return + } + + // CHECK: func.func @test() { + // CHECK: %[[alloca:[a-z0-9]+]] = memref.alloca() {alignment = 4 : i64} : memref<8xi8> + // CHECK: %[[zero:[a-z0-9]+]] = arith.constant 0 : index + // CHECK: memref.view %[[alloca]][%[[zero]]][] : memref<8xi8> to memref + // CHECK: %[[four:[a-z0-9]+]] = arith.constant 4 : index + // CHECK: %view_0 = memref.view %[[alloca]][%[[four]]][] : memref<8xi8> to memref + // CHECK: return + // CHECK: } +}