Skip to content

Commit 4888f55

Browse files
committed
[CIR][CUDA] Emit address space casts on lowering
1 parent 4d2fb6c commit 4888f55

File tree

7 files changed

+96
-26
lines changed

7 files changed

+96
-26
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

+17-1
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,25 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
396396

397397
mlir::Value createGetGlobal(mlir::Location loc, cir::GlobalOp global,
398398
bool threadLocal = false) {
399-
return create<cir::GetGlobalOp>(
399+
auto getGlobal = create<cir::GetGlobalOp>(
400400
loc, getPointerTo(global.getSymType(), global.getAddrSpaceAttr()),
401401
global.getName(), threadLocal);
402+
403+
// When AST attribute is not present, the global is a temporary,
404+
// and actual & expected address spaces won't mismatch for temporaries.
405+
if (global.getAst()) {
406+
cir::ASTVarDeclInterface varDecl = global.getAstAttr();
407+
auto gpuAS = getAddrSpaceAttr(varDecl.getExpectedAS());
408+
if (gpuAS != global.getAddrSpaceAttr()) {
409+
auto oldTy = mlir::cast<cir::PointerType>(getGlobal.getType());
410+
auto newTy =
411+
cir::PointerType::get(oldTy.getPointee(), /*addrspace=*/gpuAS);
412+
auto cast = createAddrSpaceCast(loc, getGlobal, newTy);
413+
return cast;
414+
}
415+
}
416+
417+
return getGlobal;
402418
}
403419

404420
mlir::Value createGetGlobal(cir::GlobalOp global, bool threadLocal = false) {

clang/include/clang/CIR/Interfaces/ASTAttrInterfaces.td

+5
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ let cppNamespace = "::cir" in {
6868
/*defaultImplementation=*/ [{
6969
return $_attr.getAst()->getTLSKind();
7070
}]
71+
>,
72+
InterfaceMethod<"", "clang::LangAS", "getExpectedAS", (ins), [{}],
73+
/*defaultImplementation=*/[{
74+
return $_attr.getAst()->getType().getAddressSpace();
75+
}]
7176
>
7277
];
7378
}

clang/lib/CIR/CodeGen/CIRGenDecl.cpp

+37-14
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ CIRGenModule::getOrCreateStaticVarDecl(const VarDecl &D,
470470
Name = getStaticDeclName(*this, D);
471471

472472
mlir::Type LTy = getTypes().convertTypeForMem(Ty);
473-
cir::AddressSpaceAttr AS =
473+
cir::AddressSpaceAttr actualAS =
474474
builder.getAddrSpaceAttr(getGlobalVarAddressSpace(&D));
475475

476476
// OpenCL variables in local address space and CUDA shared
@@ -482,8 +482,9 @@ CIRGenModule::getOrCreateStaticVarDecl(const VarDecl &D,
482482
!D.hasAttr<CUDASharedAttr>())
483483
Init = builder.getZeroInitAttr(convertType(Ty));
484484

485-
cir::GlobalOp GV = builder.createVersionedGlobal(
486-
getModule(), getLoc(D.getLocation()), Name, LTy, false, Linkage, AS);
485+
cir::GlobalOp GV =
486+
builder.createVersionedGlobal(getModule(), getLoc(D.getLocation()), Name,
487+
LTy, false, Linkage, actualAS);
487488
// TODO(cir): infer visibility from linkage in global op builder.
488489
GV.setVisibility(getMLIRVisibilityFromCIRLinkage(Linkage));
489490
GV.setInitialValueAttr(Init);
@@ -497,14 +498,15 @@ CIRGenModule::getOrCreateStaticVarDecl(const VarDecl &D,
497498

498499
setGVProperties(GV, &D);
499500

500-
// OG checks if the expected address space, denoted by the type, is the
501-
// same as the actual address space indicated by attributes. If they aren't
502-
// the same, an addrspacecast is emitted when this variable is accessed.
503-
// In CIR however, cir.get_global alreadys carries that information in
504-
// !cir.ptr type - if this global is in OpenCL local address space, then its
505-
// type would be !cir.ptr<..., addrspace(offload_local)>. Therefore we don't
506-
// need an explicit address space cast in CIR: they will get emitted when
507-
// lowering to LLVM IR.
501+
// OG checks whether the expected address space (AS), denoted by
502+
// __attribute__((address_space(n))), is the same as the actual AS indicated by
503+
// other attributes (such as __device__ in CUDA). If they aren't the same, an
504+
// addrspacecast is emitted when this variable is accessed, which means we
505+
// need it in this function. In CIR however, since we access globals by
506+
// `cir.get_global`, we won't emit a cast for GlobalOp here. Instead, we
507+
// record the AST, and create a CastOp in
508+
// `CIRBaseBuilderTy::createGetGlobal`.
509+
GV.setAstAttr(cir::ASTVarDeclAttr::get(&getMLIRContext(), &D));
508510

509511
// Ensure that the static local gets initialized by making sure the parent
510512
// function gets emitted eventually.
@@ -617,7 +619,10 @@ void CIRGenFunction::emitStaticVarDecl(const VarDecl &D,
617619
// TODO(cir): we should have a way to represent global ops as values without
618620
// having to emit a get global op. Sometimes these emissions are not used.
619621
auto addr = getBuilder().createGetGlobal(globalOp);
620-
auto getAddrOp = mlir::cast<cir::GetGlobalOp>(addr.getDefiningOp());
622+
auto definingOp = addr.getDefiningOp();
623+
bool hasCast = isa<cir::CastOp>(definingOp);
624+
auto getAddrOp = mlir::cast<cir::GetGlobalOp>(
625+
hasCast ? definingOp->getOperand(0).getDefiningOp() : definingOp);
621626

622627
CharUnits alignment = getContext().getDeclAlign(&D);
623628

@@ -633,7 +638,7 @@ void CIRGenFunction::emitStaticVarDecl(const VarDecl &D,
633638
llvm_unreachable("VLAs are NYI");
634639

635640
// Save the type in case adding the initializer forces a type change.
636-
auto expectedType = addr.getType();
641+
auto expectedType = cast<cir::PointerType>(addr.getType());
637642

638643
auto var = globalOp;
639644

@@ -678,7 +683,25 @@ void CIRGenFunction::emitStaticVarDecl(const VarDecl &D,
678683
//
679684
// FIXME: It is really dangerous to store this in the map; if anyone
680685
// RAUW's the GV uses of this constant will be invalid.
681-
auto castedAddr = builder.createBitcast(getAddrOp.getAddr(), expectedType);
686+
mlir::Value castedAddr;
687+
if (!hasCast)
688+
castedAddr = builder.createBitcast(getAddrOp.getAddr(), expectedType);
689+
else {
690+
// If there is an extra CastOp from createGetGlobal, we need to remove the
691+
// existing addrspacecast, then supply a bitcast and a new addrspacecast:
692+
// %1 = cir.get_global @addr
693+
// %2 = cir.cast(address_space, %1) <--- remove
694+
// %2 = cir.cast(bitcast, %1) <--- insert
695+
// %3 = cir.cast(address_space, %2) <--- insert
696+
definingOp->erase();
697+
698+
auto expectedTypeWithAS = cir::PointerType::get(
699+
expectedType.getPointee(), getAddrOp.getType().getAddrSpace());
700+
auto converted =
701+
builder.createBitcast(getAddrOp.getAddr(), expectedTypeWithAS);
702+
castedAddr = builder.createAddrSpaceCast(converted, expectedType);
703+
}
704+
682705
LocalDeclMap.find(&D)->second = Address(castedAddr, elemTy, alignment);
683706
CGM.setStaticLocalDeclAddress(&D, var);
684707

clang/lib/CIR/CodeGen/CIRGenModule.cpp

+9-3
Original file line numberDiff line numberDiff line change
@@ -1267,9 +1267,15 @@ mlir::Value CIRGenModule::getAddrOfGlobalVar(const VarDecl *d, mlir::Type ty,
12671267

12681268
bool tlsAccess = d->getTLSKind() != VarDecl::TLS_None;
12691269
auto g = getOrCreateCIRGlobal(d, ty, isForDefinition);
1270-
auto ptrTy = builder.getPointerTo(g.getSymType(), g.getAddrSpaceAttr());
1271-
return builder.create<cir::GetGlobalOp>(getLoc(d->getSourceRange()), ptrTy,
1272-
g.getSymName(), tlsAccess);
1270+
mlir::Value globalAddr =
1271+
builder.createGetGlobal(getLoc(d->getSourceRange()), g);
1272+
auto definingOp = globalAddr.getDefiningOp();
1273+
auto getGlobalOp = mlir::cast<cir::GetGlobalOp>(
1274+
isa<cir::GetGlobalOp>(definingOp)
1275+
? definingOp
1276+
: definingOp->getOperand(0).getDefiningOp());
1277+
getGlobalOp.setTls(tlsAccess);
1278+
return globalAddr;
12731279
}
12741280

12751281
cir::GlobalViewAttr

clang/test/CIR/CodeGen/CUDA/address-spaces.cu

+25-5
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,35 @@
55
// RUN: %s -o %t.cir
66
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
77

8+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
9+
// RUN: -fcuda-is-device -emit-llvm -target-sdk-version=12.3 \
10+
// RUN: %s -o %t.ll
11+
// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s
12+
13+
__device__ int k;
14+
815
__global__ void fn() {
916
int i = 0;
1017
__shared__ int j;
1118
j = i;
19+
k = i;
1220
}
1321

1422
// CIR: cir.global "private" internal dsolocal addrspace(offload_local) @_ZZ2fnvE1j : !s32i
15-
// CIR: cir.func @_Z2fnv
16-
// CIR: [[Local:%[0-9]+]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["i", init]
17-
// CIR: [[Shared:%[0-9]+]] = cir.get_global @_ZZ2fnvE1j : !cir.ptr<!s32i, addrspace(offload_local)>
18-
// CIR: [[Tmp:%[0-9]+]] = cir.load [[Local]] : !cir.ptr<!s32i>, !s32i
19-
// CIR: cir.store [[Tmp]], [[Shared]] : !s32i, !cir.ptr<!s32i, addrspace(offload_local)>
23+
// CIR: cir.func @_Z2fnv() {{.*}} {
24+
// CIR: %[[#Local:]] = cir.alloca !s32i, !cir.ptr<!s32i>
25+
// CIR: %[[#Shared:]] = cir.get_global @_ZZ2fnvE1j : !cir.ptr<!s32i, addrspace(offload_local)>
26+
// CIR: %[[#Converted:]] = cir.cast(address_space, %[[#Shared]] :
27+
// CIR-SAME: !cir.ptr<!s32i, addrspace(offload_local)>), !cir.ptr<!s32i>
28+
// CIR: %[[#Tmp:]] = cir.load %[[#Local]] : !cir.ptr<!s32i>, !s32i
29+
// CIR: cir.store %[[#Tmp]], %[[#Converted]] : !s32i, !cir.ptr<!s32i>
30+
// CIR: }
31+
32+
// LLVM: @_ZZ2fnvE1j = internal addrspace(3) global i32 undef
33+
// LLVM: define dso_local ptx_kernel void @_Z2fnv() #{{.*}} {
34+
// LLVM: %[[#T1:]] = alloca i32, i64 1
35+
// LLVM: store i32 0, ptr %[[#T1]]
36+
// LLVM: %[[#T2:]] = load i32, ptr %[[#T1]]
37+
// LLVM: store i32 %[[#T2]], ptr addrspacecast (ptr addrspace(3) @_ZZ2fnvE1j to ptr)
38+
// LLVM: ret void
39+
// LLVM: }

clang/test/CIR/CodeGen/OpenCL/static-vardecl.cl

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
kernel void test_static(int i) {
77
static global int b = 15;
8-
// CIR-DAG: cir.global "private" internal dsolocal addrspace(offload_global) @test_static.b = #cir.int<15> : !s32i {alignment = 4 : i64}
8+
// CIR-DAG: cir.global "private" internal dsolocal addrspace(offload_global) @test_static.b = #cir.int<15> : !s32i {alignment = 4 : i64, ast = {{.*}}}
99
// LLVM-DAG: @test_static.b = internal addrspace(1) global i32 15
1010

1111
local int c;
12-
// CIR-DAG: cir.global "private" internal dsolocal addrspace(offload_local) @test_static.c : !s32i {alignment = 4 : i64}
12+
// CIR-DAG: cir.global "private" internal dsolocal addrspace(offload_local) @test_static.c : !s32i {alignment = 4 : i64, ast = {{.*}}}
1313
// LLVM-DAG: @test_static.c = internal addrspace(3) global i32 undef
1414

1515
// CIR-DAG: %[[#ADDRB:]] = cir.get_global @test_static.b : !cir.ptr<!s32i, addrspace(offload_global)>

clang/test/CIR/CodeGen/const-array.c

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ void bar() {
44
const int arr[1] = {1};
55
}
66

7-
// CHECK: cir.global "private" constant internal dsolocal @bar.arr = #cir.const_array<[#cir.int<1> : !s32i]> : !cir.array<!s32i x 1> {alignment = 4 : i64}
7+
// CHECK: cir.global "private" constant internal dsolocal @bar.arr = #cir.const_array<[#cir.int<1> : !s32i]> : !cir.array<!s32i x 1> {alignment = 4 : i64, ast = {{.*}}}
88
// CHECK: cir.func no_proto @bar()
99
// CHECK: {{.*}} = cir.get_global @bar.arr : !cir.ptr<!cir.array<!s32i x 1>>
1010

0 commit comments

Comments
 (0)