@@ -470,7 +470,7 @@ CIRGenModule::getOrCreateStaticVarDecl(const VarDecl &D,
470
470
Name = getStaticDeclName (*this , D);
471
471
472
472
mlir::Type LTy = getTypes ().convertTypeForMem (Ty);
473
- cir::AddressSpaceAttr AS =
473
+ cir::AddressSpaceAttr actualAS =
474
474
builder.getAddrSpaceAttr (getGlobalVarAddressSpace (&D));
475
475
476
476
// OpenCL variables in local address space and CUDA shared
@@ -482,8 +482,9 @@ CIRGenModule::getOrCreateStaticVarDecl(const VarDecl &D,
482
482
!D.hasAttr <CUDASharedAttr>())
483
483
Init = builder.getZeroInitAttr (convertType (Ty));
484
484
485
- cir::GlobalOp GV = builder.createVersionedGlobal (
486
- getModule (), getLoc (D.getLocation ()), Name, LTy, false , Linkage, AS);
485
+ cir::GlobalOp GV =
486
+ builder.createVersionedGlobal (getModule (), getLoc (D.getLocation ()), Name,
487
+ LTy, false , Linkage, actualAS);
487
488
// TODO(cir): infer visibility from linkage in global op builder.
488
489
GV.setVisibility (getMLIRVisibilityFromCIRLinkage (Linkage));
489
490
GV.setInitialValueAttr (Init);
@@ -497,14 +498,15 @@ CIRGenModule::getOrCreateStaticVarDecl(const VarDecl &D,
497
498
498
499
setGVProperties (GV, &D);
499
500
500
- // OG checks if the expected address space, denoted by the type, is the
501
- // same as the actual address space indicated by attributes. If they aren't
502
- // the same, an addrspacecast is emitted when this variable is accessed.
503
- // In CIR however, cir.get_global alreadys carries that information in
504
- // !cir.ptr type - if this global is in OpenCL local address space, then its
505
- // type would be !cir.ptr<..., addrspace(offload_local)>. Therefore we don't
506
- // need an explicit address space cast in CIR: they will get emitted when
507
- // lowering to LLVM IR.
501
+ // OG checks whether the expected address space (AS), denoted by
502
+ // __attributes__((addrspace(n))), is the same as the actual AS indicated by
503
+ // other attributes (such as __device__ in CUDA). If they aren't the same, an
504
+ // addrspacecast is emitted when this variable is accessed, which means we
505
+ // need it in this function. In CIR however, since we access globals by
506
+ // `cir.get_global`, we won't emit a cast for GlobalOp here. Instead, we
507
+ // record the AST, and create a CastOp in
508
+ // `CIRGenBaseBuilder::createGetGlobal`.
509
+ GV.setAstAttr (cir::ASTVarDeclAttr::get (&getMLIRContext (), &D));
508
510
509
511
// Ensure that the static local gets initialized by making sure the parent
510
512
// function gets emitted eventually.
@@ -617,7 +619,10 @@ void CIRGenFunction::emitStaticVarDecl(const VarDecl &D,
617
619
// TODO(cir): we should have a way to represent global ops as values without
618
620
// having to emit a get global op. Sometimes these emissions are not used.
619
621
auto addr = getBuilder ().createGetGlobal (globalOp);
620
- auto getAddrOp = mlir::cast<cir::GetGlobalOp>(addr.getDefiningOp ());
622
+ auto definingOp = addr.getDefiningOp ();
623
+ bool hasCast = isa<cir::CastOp>(definingOp);
624
+ auto getAddrOp = mlir::cast<cir::GetGlobalOp>(
625
+ hasCast ? definingOp->getOperand (0 ).getDefiningOp () : definingOp);
621
626
622
627
CharUnits alignment = getContext ().getDeclAlign (&D);
623
628
@@ -633,7 +638,7 @@ void CIRGenFunction::emitStaticVarDecl(const VarDecl &D,
633
638
llvm_unreachable (" VLAs are NYI" );
634
639
635
640
// Save the type in case adding the initializer forces a type change.
636
- auto expectedType = addr.getType ();
641
+ auto expectedType = cast<cir::PointerType>( addr.getType () );
637
642
638
643
auto var = globalOp;
639
644
@@ -678,7 +683,25 @@ void CIRGenFunction::emitStaticVarDecl(const VarDecl &D,
678
683
//
679
684
// FIXME: It is really dangerous to store this in the map; if anyone
680
685
// RAUW's the GV uses of this constant will be invalid.
681
- auto castedAddr = builder.createBitcast (getAddrOp.getAddr (), expectedType);
686
+ mlir::Value castedAddr;
687
+ if (!hasCast)
688
+ castedAddr = builder.createBitcast (getAddrOp.getAddr (), expectedType);
689
+ else {
690
+ // If there is an extra CastOp from createGetGlobal, we need to remove the
691
+ // existing addrspacecast, then supply a bitcast and a new addrspacecast:
692
+ // %1 = cir.get_global @addr
693
+ // %2 = cir.cast(addrspacecast, %1) <--- remove
694
+ // %2 = cir.cast(bitcast, %1) <--- insert
695
+ // %3 = cir.cast(addrspacecast, %2) <--- insert
696
+ definingOp->erase ();
697
+
698
+ auto expectedTypeWithAS = cir::PointerType::get (
699
+ expectedType.getPointee (), getAddrOp.getType ().getAddrSpace ());
700
+ auto converted =
701
+ builder.createBitcast (getAddrOp.getAddr (), expectedTypeWithAS);
702
+ castedAddr = builder.createAddrSpaceCast (converted, expectedType);
703
+ }
704
+
682
705
LocalDeclMap.find (&D)->second = Address (castedAddr, elemTy, alignment);
683
706
CGM.setStaticLocalDeclAddress (&D, var);
684
707
0 commit comments