Skip to content

Commit c960f18

Browse files
authored
Merge pull request #26668 from JuliaLang/kf/addrspacephi
Handle PHI nodes in address space propagation
2 parents 59bffa5 + 4577c0f commit c960f18

12 files changed

+149
-57
lines changed

src/llvm-propagate-addrspaces.cpp

Lines changed: 120 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ struct PropagateJuliaAddrspaces : public FunctionPass, public InstVisitor<Propag
5151
void visitLoadInst(LoadInst &LI);
5252
void visitMemSetInst(MemSetInst &MI);
5353
void visitMemTransferInst(MemTransferInst &MTI);
54+
55+
private:
56+
void PoisonValues(std::vector<Value *> &Worklist);
5457
};
5558

5659
bool PropagateJuliaAddrspaces::runOnFunction(Function &F) {
@@ -74,67 +77,137 @@ static bool isSpecialAS(unsigned AS) {
7477
return AddressSpace::FirstSpecial <= AS && AS <= AddressSpace::LastSpecial;
7578
}
7679

80+
void PropagateJuliaAddrspaces::PoisonValues(std::vector<Value *> &Worklist) {
81+
while (!Worklist.empty()) {
82+
Value *CurrentV = Worklist.back();
83+
Worklist.pop_back();
84+
for (Value *User : CurrentV->users()) {
85+
if (Visited.count(User))
86+
continue;
87+
Visited.insert(CurrentV);
88+
Worklist.push_back(User);
89+
}
90+
}
91+
}
92+
7793
Value *PropagateJuliaAddrspaces::LiftPointer(Value *V, Type *LocTy, Instruction *InsertPt) {
7894
SmallVector<Value *, 4> Stack;
79-
Value *CurrentV = V;
95+
std::vector<Value *> Worklist;
96+
std::set<Value *> LocalVisited;
97+
Worklist.push_back(V);
8098
// Follow pointer casts back, see if we're based on a pointer in
8199
// an untracked address space, in which case we're allowed to drop
82100
// intermediate addrspace casts.
83-
while (true) {
84-
Stack.push_back(CurrentV);
85-
if (isa<BitCastInst>(CurrentV))
86-
CurrentV = cast<BitCastInst>(CurrentV)->getOperand(0);
87-
else if (isa<AddrSpaceCastInst>(CurrentV)) {
88-
CurrentV = cast<AddrSpaceCastInst>(CurrentV)->getOperand(0);
89-
if (!isSpecialAS(getValueAddrSpace(CurrentV)))
90-
break;
101+
while (!Worklist.empty()) {
102+
Value *CurrentV = Worklist.back();
103+
Worklist.pop_back();
104+
if (LocalVisited.count(CurrentV)) {
105+
continue;
91106
}
92-
else if (isa<GetElementPtrInst>(CurrentV)) {
93-
if (LiftingMap.count(CurrentV)) {
94-
CurrentV = LiftingMap[CurrentV];
107+
while (true) {
108+
if (auto *BCI = dyn_cast<BitCastInst>(CurrentV))
109+
CurrentV = BCI->getOperand(0);
110+
else if (auto *ACI = dyn_cast<AddrSpaceCastInst>(CurrentV)) {
111+
CurrentV = ACI->getOperand(0);
112+
if (!isSpecialAS(getValueAddrSpace(ACI)))
113+
break;
114+
}
115+
else if (auto *GEP = dyn_cast<GetElementPtrInst>(CurrentV)) {
116+
if (LiftingMap.count(GEP)) {
117+
CurrentV = LiftingMap[GEP];
118+
break;
119+
} else if (Visited.count(GEP)) {
120+
return nullptr;
121+
}
122+
Stack.push_back(GEP);
123+
LocalVisited.insert(GEP);
124+
CurrentV = GEP->getOperand(0);
125+
} else if (auto *Phi = dyn_cast<PHINode>(CurrentV)) {
126+
if (LiftingMap.count(Phi)) {
127+
break;
128+
}
129+
for (Value *Incoming : Phi->incoming_values()) {
130+
Worklist.push_back(Incoming);
131+
}
132+
Stack.push_back(Phi);
133+
LocalVisited.insert(Phi);
134+
break;
135+
} else {
136+
// Ok, we've reached a leaf - check if it is eligible for lifting
137+
if (!CurrentV->getType()->isPointerTy() ||
138+
isSpecialAS(getValueAddrSpace(CurrentV))) {
139+
// If not, poison all (recursive) users of this value, to prevent
140+
// looking at them again in future iterations.
141+
Worklist.clear();
142+
Worklist.push_back(CurrentV);
143+
Visited.insert(CurrentV);
144+
PoisonValues(Worklist);
145+
return nullptr;
146+
}
95147
break;
96-
} else if (Visited.count(CurrentV)) {
97-
return nullptr;
98148
}
99-
Visited.insert(CurrentV);
100-
CurrentV = cast<GetElementPtrInst>(CurrentV)->getOperand(0);
101-
} else
102-
break;
149+
}
103150
}
104-
if (!CurrentV->getType()->isPointerTy())
105-
return nullptr;
106-
if (isSpecialAS(getValueAddrSpace(CurrentV)))
107-
return nullptr;
108-
// Ok, we're allowed to change the address space of this load, go back and
109-
// reconstitute any GEPs in the new address space.
110-
for (Value *V : llvm::reverse(Stack)) {
111-
GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V);
112-
if (!GEP)
113-
continue;
114-
if (LiftingMap.count(GEP)) {
115-
CurrentV = LiftingMap[GEP];
151+
152+
// Go through and insert lifted versions of all instructions on the list.
153+
std::vector<Value *> ToRevisit;
154+
for (Value *V : Stack) {
155+
if (LiftingMap.count(V))
116156
continue;
157+
if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
158+
auto *NewGEP = cast<GetElementPtrInst>(GEP->clone());
159+
ToInsert.push_back(std::make_pair(NewGEP, GEP));
160+
Type *NewRetTy = cast<PointerType>(GEP->getType())->getElementType()->getPointerTo(0);
161+
NewGEP->mutateType(NewRetTy);
162+
LiftingMap[GEP] = NewGEP;
163+
ToRevisit.push_back(NewGEP);
164+
} else if (auto *Phi = dyn_cast<PHINode>(V)) {
165+
auto *NewPhi = cast<PHINode>(Phi->clone());
166+
ToInsert.push_back(std::make_pair(NewPhi, Phi));
167+
Type *NewRetTy = cast<PointerType>(Phi->getType())->getElementType()->getPointerTo(0);
168+
NewPhi->mutateType(NewRetTy);
169+
LiftingMap[Phi] = NewPhi;
170+
ToRevisit.push_back(NewPhi);
171+
}
172+
}
173+
174+
auto CollapseCastsAndLift = [&](Value *CurrentV, Instruction *InsertPt) {
175+
Type *TargetType = cast<PointerType>(CurrentV->getType())->getElementType()->getPointerTo(0);
176+
while (!LiftingMap.count(CurrentV)) {
177+
if (isa<BitCastInst>(CurrentV))
178+
CurrentV = cast<BitCastInst>(CurrentV)->getOperand(0);
179+
else if (isa<AddrSpaceCastInst>(CurrentV))
180+
CurrentV = cast<AddrSpaceCastInst>(CurrentV)->getOperand(0);
181+
else
182+
break;
117183
}
118-
GetElementPtrInst *NewGEP = cast<GetElementPtrInst>(GEP->clone());
119-
ToInsert.push_back(std::make_pair(NewGEP, GEP));
120-
Type *GEPTy = GEP->getSourceElementType();
121-
Type *NewRetTy = cast<PointerType>(GEP->getType())->getElementType()->getPointerTo(getValueAddrSpace(CurrentV));
122-
NewGEP->mutateType(NewRetTy);
123-
if (cast<PointerType>(CurrentV->getType())->getElementType() != GEPTy) {
124-
auto *BCI = new BitCastInst(CurrentV, GEPTy->getPointerTo());
125-
ToInsert.push_back(std::make_pair(BCI, NewGEP));
184+
if (LiftingMap.count(CurrentV))
185+
CurrentV = LiftingMap[CurrentV];
186+
if (CurrentV->getType() != TargetType) {
187+
auto *BCI = new BitCastInst(CurrentV, TargetType);
188+
ToInsert.push_back(std::make_pair(BCI, InsertPt));
126189
CurrentV = BCI;
127190
}
128-
NewGEP->setOperand(GetElementPtrInst::getPointerOperandIndex(), CurrentV);
129-
LiftingMap[GEP] = NewGEP;
130-
CurrentV = NewGEP;
131-
}
132-
if (LocTy && cast<PointerType>(CurrentV->getType())->getElementType() != LocTy) {
133-
auto *BCI = new BitCastInst(CurrentV, LocTy->getPointerTo());
134-
ToInsert.push_back(std::make_pair(BCI, InsertPt));
135-
CurrentV = BCI;
191+
return CurrentV;
192+
};
193+
194+
// Now go through and update the operands
195+
for (Value *V : ToRevisit) {
196+
if (GetElementPtrInst *NewGEP = dyn_cast<GetElementPtrInst>(V)) {
197+
NewGEP->setOperand(GetElementPtrInst::getPointerOperandIndex(),
198+
CollapseCastsAndLift(NewGEP->getOperand(GetElementPtrInst::getPointerOperandIndex()),
199+
NewGEP));
200+
} else if (PHINode *NewPhi = dyn_cast<PHINode>(V)) {
201+
for (size_t i = 0; i < NewPhi->getNumIncomingValues(); ++i) {
202+
NewPhi->setIncomingValue(i, CollapseCastsAndLift(NewPhi->getIncomingValue(i),
203+
NewPhi->getIncomingBlock(i)->getTerminator()));
204+
}
205+
} else {
206+
assert(false && "Shouldn't have reached here");
207+
}
136208
}
137-
return CurrentV;
209+
210+
return CollapseCastsAndLift(V, cast<Instruction>(V));
138211
}
139212

140213
void PropagateJuliaAddrspaces::visitLoadInst(LoadInst &LI) {

test/llvmpasses/alloc-opt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3-
# RUN: julia --startup-file=no %s | opt -load libjulia.so -AllocOpt -LateLowerGCFrame -S - | FileCheck %s
3+
# RUN: julia --startup-file=no %s | opt -load libjulia%shlibext -AllocOpt -LateLowerGCFrame -S - | FileCheck %s
44

55
isz = sizeof(UInt) == 8 ? "i64" : "i32"
66

test/llvmpasses/alloc-opt2.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3-
# RUN: julia --startup-file=no %s | opt -load libjulia.so -AllocOpt -S - | FileCheck %s
3+
# RUN: julia --startup-file=no %s | opt -load libjulia%shlibext -AllocOpt -S - | FileCheck %s
44

55
isz = sizeof(UInt) == 8 ? "i64" : "i32"
66

test/llvmpasses/gcroots.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: opt -load libjulia.so -LateLowerGCFrame -S %s | FileCheck %s
1+
; RUN: opt -load libjulia%shlibext -LateLowerGCFrame -S %s | FileCheck %s
22

33
%jl_value_t = type opaque
44

test/llvmpasses/lit.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ config.name = 'Julia'
1010
config.suffixes = ['.ll','.jl']
1111
config.test_source_root = os.path.dirname(__file__)
1212
config.test_format = lit.formats.ShTest(True)
13+
config.substitutions.append(('%shlibext', '.dylib' if platform.system() == 'Darwin' else '.dll' if
14+
platform.system() == 'Windows' else '.so'))
1315

1416
path = os.path.pathsep.join((os.path.join(os.path.dirname(__file__),"../../usr/tools"), os.path.join(os.path.dirname(__file__),"../../usr/bin"), config.environment['PATH']))
1517
config.environment['PATH'] = path

test/llvmpasses/lower-handlers.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: opt -load libjulia.so -LowerExcHandlers -S %s | FileCheck %s
1+
; RUN: opt -load libjulia%shlibext -LowerExcHandlers -S %s | FileCheck %s
22

33
attributes #1 = { returns_twice }
44
declare i32 @julia.except_enter() #1

test/llvmpasses/muladd.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: opt -load libjulia.so -CombineMulAdd -S %s | FileCheck %s
1+
; RUN: opt -load libjulia%shlibext -CombineMulAdd -S %s | FileCheck %s
22

33
define double @fast_muladd1(double %a, double %b, double %c) {
44
top:

test/llvmpasses/propagate-addrspace.ll

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: opt -load libjulia.so -PropagateJuliaAddrspaces -dce -S %s | FileCheck %s
1+
; RUN: opt -load libjulia%shlibext -PropagateJuliaAddrspaces -dce -S %s | FileCheck %s
22

33
define i64 @simple() {
44
; CHECK-LABEL: @simple
@@ -19,3 +19,20 @@ define i64 @twogeps() {
1919
%loaded = load i64, i64 addrspace(11)* %gep2
2020
ret i64 %loaded
2121
}
22+
23+
define i64 @phi(i1 %cond) {
24+
; CHECK-LABEL: @phi
25+
; CHECK-NOT: addrspace(11)
26+
top:
27+
%stack1 = alloca i64
28+
%stack2 = alloca i64
29+
%stack1_casted = addrspacecast i64 *%stack1 to i64 addrspace(11)*
30+
%stack2_casted = addrspacecast i64 *%stack2 to i64 addrspace(11)*
31+
br i1 %cond, label %A, label %B
32+
A:
33+
br label %B
34+
B:
35+
%phi = phi i64 addrspace(11)* [ %stack1_casted, %top ], [ %stack2_casted, %A ]
36+
%load = load i64, i64 addrspace(11)* %phi
37+
ret i64 %load
38+
}

test/llvmpasses/refinements.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: opt -load libjulia.so -LateLowerGCFrame -S %s | FileCheck %s
1+
; RUN: opt -load libjulia%shlibext -LateLowerGCFrame -S %s | FileCheck %s
22

33
%jl_value_t = type opaque
44

test/llvmpasses/returnstwicegc.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: opt -load libjulia.so -LateLowerGCFrame -S %s | FileCheck %s
1+
; RUN: opt -load libjulia%shlibext -LateLowerGCFrame -S %s | FileCheck %s
22

33
%jl_value_t = type opaque
44

test/llvmpasses/safepoint_stress.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3-
# RUN: julia --startup-file=no %s | opt -load libjulia.so -LateLowerGCFrame -S - | FileCheck %s
3+
# RUN: julia --startup-file=no %s | opt -load libjulia%shlibext -LateLowerGCFrame -S - | FileCheck %s
44

55
println("""
66
%jl_value_t = type opaque

test/llvmpasses/simdloop.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: opt -load libjulia.so -LowerSIMDLoop -S %s | FileCheck %s
1+
; RUN: opt -load libjulia%shlibext -LowerSIMDLoop -S %s | FileCheck %s
22

33
define void @simd_test(double *%a, double *%b) {
44
top:

0 commit comments

Comments
 (0)