-
Notifications
You must be signed in to change notification settings - Fork 13.7k
Add Dead Block Elimination to NVVMReflect #144171
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-backend-nvptx Author: Yonah Goldberg (YonahGoldberg) ChangesCurrently, NVVMReflect replaces calls to __nvvm_reflect with a constant, and then constant propagates/folds the result, but doesn't handle dead block elimination. The most common use case of reflect calls is to query the arch number and select valid code depending on the arch. Therefore, the blocks that become dead after reflect replacement need to be deleted as a matter of correctness. The way this gets cleaned up now in llc is with UnreachableBlockElim followed by CodegenPrepare, which I've observed work together to delete the dead blocks. It's better to just have this pass handle deleting the dead blocks right away. This PR introduces some additional code to handle the dead block deletion. I think what I've written is actually pretty general, it's kind've like a lightweight version of SCCP. I wonder if I missed somewhere where this is already implemented in LLVM so I don't duplicate code. If I didn't, would it ever be useful to put this somewhere more general where others can use it instead of in NVVMReflect? Note that I also removed running simplifycfg in two test cases, which shows that this pass is now able to handle the dead block elimination without simplifycfg. Full diff: https://github.com/llvm/llvm-project/pull/144171.diff 3 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 208bab52284a3..2585ff45bde4c 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -19,6 +19,7 @@
//===----------------------------------------------------------------------===//
#include "NVPTX.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ConstantFolding.h"
@@ -59,7 +60,10 @@ class NVVMReflect {
StringMap<unsigned> ReflectMap;
bool handleReflectFunction(Module &M, StringRef ReflectName);
void populateReflectMap(Module &M);
- void foldReflectCall(CallInst *Call, Constant *NewValue);
+ void replaceReflectCalls(
+ SmallVector<std::pair<CallInst *, Constant *>, 8> &ReflectReplacements,
+ const DataLayout &DL);
+ SetVector<BasicBlock *> findTransitivelyDeadBlocks(BasicBlock *DeadBB);
public:
// __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
@@ -138,6 +142,8 @@ bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
assert(F->getReturnType()->isIntegerTy() &&
"_reflect's return type should be integer");
+ SmallVector<std::pair<CallInst *, Constant *>, 8> ReflectReplacements;
+
const bool Changed = !F->use_empty();
for (User *U : make_early_inc_range(F->users())) {
// Reflect function calls look like:
@@ -178,38 +184,111 @@ bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
<< "(" << ReflectArg << ") with value " << ReflectVal
<< "\n");
auto *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
- foldReflectCall(Call, NewValue);
- Call->eraseFromParent();
+ ReflectReplacements.push_back({Call, NewValue});
}
- // Remove the __nvvm_reflect function from the module
+ replaceReflectCalls(ReflectReplacements, M.getDataLayout());
F->eraseFromParent();
return Changed;
}
-void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
+/// Find all blocks that become dead transitively from an initial dead block.
+/// Returns the complete set including the original dead block and any blocks
+/// that lose all their predecessors due to the deletion cascade.
+SetVector<BasicBlock *>
+NVVMReflect::findTransitivelyDeadBlocks(BasicBlock *DeadBB) {
+ SmallVector<BasicBlock *, 8> Worklist({DeadBB});
+ SetVector<BasicBlock *> DeadBlocks;
+ while (!Worklist.empty()) {
+ auto *BB = Worklist.pop_back_val();
+ DeadBlocks.insert(BB);
+
+ for (BasicBlock *Succ : successors(BB))
+ if (pred_size(Succ) == 1 && DeadBlocks.insert(Succ))
+ Worklist.push_back(Succ);
+ }
+ return DeadBlocks;
+}
+
+/// Replace calls to __nvvm_reflect with corresponding constant values. Then
+/// clean up through constant folding and propagation and dead block
+/// elimination.
+///
+/// The purpose of this cleanup is not optimization because that could be
+/// handled by later passes
+/// (i.e. SCCP, SimplifyCFG, etc.), but for correctness. Reflect calls are most
+/// commonly used to query the arch number and select a valid instruction for
+/// the arch. Therefore, you need to eliminate blocks that become dead because
+/// they may contain invalid instructions for the arch. The purpose of the
+/// cleanup is to do the minimal amount of work to leave the code in a valid
+/// state.
+void NVVMReflect::replaceReflectCalls(
+ SmallVector<std::pair<CallInst *, Constant *>, 8> &ReflectReplacements,
+ const DataLayout &DL) {
SmallVector<Instruction *, 8> Worklist;
- // Replace an instruction with a constant and add all users of the instruction
- // to the worklist
+ SetVector<BasicBlock *> DeadBlocks;
+
+ // Replace an instruction with a constant and add all users to the worklist,
+ // then delete the instruction
auto ReplaceInstructionWithConst = [&](Instruction *I, Constant *C) {
for (auto *U : I->users())
if (auto *UI = dyn_cast<Instruction>(U))
Worklist.push_back(UI);
I->replaceAllUsesWith(C);
+ if (isInstructionTriviallyDead(I))
+ I->eraseFromParent();
};
- ReplaceInstructionWithConst(Call, NewValue);
+ for (auto &[Call, NewValue] : ReflectReplacements)
+ ReplaceInstructionWithConst(Call, NewValue);
- auto &DL = Call->getModule()->getDataLayout();
- while (!Worklist.empty()) {
- auto *I = Worklist.pop_back_val();
- if (auto *C = ConstantFoldInstruction(I, DL)) {
- ReplaceInstructionWithConst(I, C);
- if (isInstructionTriviallyDead(I))
- I->eraseFromParent();
- } else if (I->isTerminator()) {
- ConstantFoldTerminator(I->getParent());
+ // Alternate between constant folding/propagation and dead block elimination.
+ // Terminator folding may create new dead blocks. When those dead blocks are
+ // deleted, their live successors may have PHIs that can be simplified, which
+ // may yield more work for folding/propagation.
+ while (true) {
+ // Iterate folding and propagating constants until the worklist is empty.
+ while (!Worklist.empty()) {
+ auto *I = Worklist.pop_back_val();
+ if (auto *C = ConstantFoldInstruction(I, DL)) {
+ ReplaceInstructionWithConst(I, C);
+ } else if (I->isTerminator()) {
+ BasicBlock *BB = I->getParent();
+ SmallVector<BasicBlock *, 8> Succs(successors(BB));
+ // Some blocks may become dead if the terminator is folded because
+ // a conditional branch is turned into a direct branch.
+ if (ConstantFoldTerminator(BB)) {
+ for (BasicBlock *Succ : Succs) {
+ if (pred_empty(Succ) &&
+ Succ != &Succ->getParent()->getEntryBlock()) {
+ SetVector<BasicBlock *> TransitivelyDead =
+ findTransitivelyDeadBlocks(Succ);
+ DeadBlocks.insert(TransitivelyDead.begin(),
+ TransitivelyDead.end());
+ }
+ }
+ }
+ }
}
+ // No more constants to fold and no more dead blocks
+ // to create more work. We're done.
+ if (DeadBlocks.empty())
+ break;
+ // PHI nodes of live successors of dead blocks get eliminated when the dead
+ // blocks are eliminated. Their users can now be simplified further, so add
+ // them to the worklist.
+ for (BasicBlock *DeadBB : DeadBlocks)
+ for (BasicBlock *Succ : successors(DeadBB))
+ if (!DeadBlocks.contains(Succ))
+ for (PHINode &PHI : Succ->phis())
+ for (auto *U : PHI.users())
+ if (auto *UI = dyn_cast<Instruction>(U))
+ Worklist.push_back(UI);
+ // Delete all dead blocks
+ for (BasicBlock *DeadBB : DeadBlocks)
+ DeleteDeadBlock(DeadBB);
+
+ DeadBlocks.clear();
}
}
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll
index 19c74df303702..7bb1af707001a 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll
@@ -3,12 +3,12 @@
; RUN: cat %s > %t.noftz
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 0}' >> %t.noftz
-; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
; RUN: | FileCheck %s --check-prefix=USE_FTZ_0 --check-prefix=CHECK
; RUN: cat %s > %t.ftz
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.ftz
-; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
; RUN: | FileCheck %s --check-prefix=USE_FTZ_1 --check-prefix=CHECK
@str = private unnamed_addr addrspace(4) constant [11 x i8] c"__CUDA_FTZ\00"
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll
index 244b44fea9b83..581dbf353c1ff 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll
@@ -3,12 +3,12 @@
; RUN: cat %s > %t.noftz
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 0}' >> %t.noftz
-; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
; RUN: | FileCheck %s --check-prefix=USE_FTZ_0 --check-prefix=CHECK
; RUN: cat %s > %t.ftz
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.ftz
-; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
; RUN: | FileCheck %s --check-prefix=USE_FTZ_1 --check-prefix=CHECK
@str = private unnamed_addr addrspace(4) constant [11 x i8] c"__CUDA_FTZ\00"
|
I'm not quite sure what is the problem the patch is intended to solve.
How is that better than the optimization passes we already have to do exactly that job? NVVMReflect is normally added very early in the pipeline, and subsequent passes do a pretty good job eliminating the dead code after the call is replaced with a constant. llvm-project/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp Lines 245 to 250 in d7e64d9
|
Currently, NVVMReflect replaces calls to __nvvm_reflect with a constant, and then constant propagates/folds the result, but doesn't handle dead block elimination.
The most common use case of reflect calls is to query the arch number and select valid code depending on the arch. Therefore, the blocks that become dead after reflect replacement need to be deleted as a matter of correctness.
The way this gets cleaned up now in llc is with UnreachableBlockElim followed by CodegenPrepare, which I've observed work together to delete the dead blocks. It's better to just have this pass handle deleting the dead blocks right away.
This PR introduces some additional code to handle the dead block deletion. I think what I've written is actually pretty general, it's kind've like a lightweight version of SCCP. I wonder if I missed somewhere where this is already implemented in LLVM so I don't duplicate code. If I didn't, would it ever be useful to put this somewhere more general where others can use it instead of in NVVMReflect?
Note that I also removed running simplifycfg in two test cases, which shows that this pass is now able to handle the dead block elimination without simplifycfg.