Skip to content

Commit 316593c

Browse files
ergawyantiagainst
authored andcommitted
[MLIR][SPIRV] Start module combiner
This commit adds a new library that merges/combines a number of spv modules into a combined one. The library has a single entry point: combine(...). To combine a number of MLIR spv modules, we move all the module-level ops from all the input modules into one big combined module. To that end, the combination process can proceed in 2 phases: (1) resolving conflicts between pairs of ops from different modules (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO) This patch implements only the first phase. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D90022
1 parent cea69fa commit 316593c

File tree

10 files changed

+1043
-0
lines changed

10 files changed

+1043
-0
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//===- ModuleCombiner.h - MLIR SPIR-V Module Combiner -----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the entry point to the SPIR-V module combiner library.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
14+
#define MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
15+
16+
#include "mlir/Dialect/SPIRV/SPIRVModule.h"
17+
#include "llvm/ADT/ArrayRef.h"
18+
#include "llvm/ADT/SmallVector.h"
19+
20+
namespace mlir {
21+
class OpBuilder;
22+
23+
namespace spirv {
24+
class ModuleOp;
25+
26+
/// To combine a number of MLIR SPIR-V modules, we move all the module-level ops
27+
/// from all the input modules into one big combined module. To that end, the
28+
/// combination process proceeds in 2 phases:
29+
///
30+
/// (1) resolve conflicts between pairs of ops from different modules
31+
/// (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO)
32+
///
33+
/// For the conflict resolution phase, the following rules are employed to
34+
/// resolve such conflicts:
35+
///
36+
/// - If 2 spv.func's have the same symbol name, then rename one of the
37+
/// functions.
38+
/// - If an spv.func and another op have the same symbol name, then rename the
39+
/// other symbol.
40+
/// - If none of the 2 conflicting ops are spv.func, then rename either.
41+
///
42+
/// In all cases, the references to the updated symbol are also updated to
43+
/// reflect the change.
44+
///
45+
/// \param modules the list of modules to combine. Input modules are not
46+
/// modified.
47+
/// \param combinedMdouleBuilder an OpBuilder to be used for
48+
/// building up the combined module.
49+
/// \param symbRenameListener a listener that gets called everytime a symbol in
50+
/// one of the input modules is renamed. The arguments
51+
/// passed to the listener are: the input
52+
/// spirv::ModuleOp that contains the renamed symbol,
53+
/// a StringRef to the old symbol name, and a
54+
/// StringRef to the new symbol name. Note that it is
55+
/// the responsibility of the caller to properly
56+
/// retain the storage underlying the passed
57+
/// StringRefs if the listener callback outlives this
58+
/// function call.
59+
///
60+
/// \return the combined module.
61+
OwningSPIRVModuleRef
62+
combine(llvm::MutableArrayRef<ModuleOp> modules,
63+
OpBuilder &combinedModuleBuilder,
64+
llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
65+
symbRenameListener);
66+
} // namespace spirv
67+
} // namespace mlir
68+
69+
#endif // MLIR_DIALECT_SPIRV_MODULECOMBINER_H_

mlir/lib/Dialect/SPIRV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,6 @@ add_mlir_dialect_library(MLIRSPIRV
3434
MLIRTransforms
3535
)
3636

37+
add_subdirectory(Linking)
3738
add_subdirectory(Serialization)
3839
add_subdirectory(Transforms)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(ModuleCombiner)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
add_mlir_dialect_library(MLIRSPIRVModuleCombiner
2+
ModuleCombiner.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
6+
)
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
//===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements the the SPIR-V module combiner library.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/SPIRV/ModuleCombiner.h"
14+
15+
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
16+
#include "mlir/IR/Builders.h"
17+
#include "mlir/IR/SymbolTable.h"
18+
#include "llvm/ADT/ArrayRef.h"
19+
#include "llvm/ADT/StringExtras.h"
20+
21+
using namespace mlir;
22+
23+
static constexpr unsigned maxFreeID = 1 << 20;
24+
25+
static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
26+
spirv::ModuleOp combinedModule) {
27+
SmallString<64> newSymName(oldSymName);
28+
newSymName.push_back('_');
29+
30+
while (lastUsedID < maxFreeID) {
31+
std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();
32+
33+
if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) {
34+
newSymName += llvm::utostr(lastUsedID);
35+
break;
36+
}
37+
}
38+
39+
return newSymName;
40+
}
41+
42+
/// Check if a symbol with the same name as op already exists in source. If so,
43+
/// rename op and update all its references in target.
44+
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
45+
spirv::ModuleOp target,
46+
spirv::ModuleOp source,
47+
unsigned &lastUsedID) {
48+
if (!SymbolTable::lookupSymbolIn(source, op.getName()))
49+
return success();
50+
51+
StringRef oldSymName = op.getName();
52+
SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target);
53+
54+
if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
55+
return op.emitError("unable to update all symbol uses for ")
56+
<< oldSymName << " to " << newSymName;
57+
58+
SymbolTable::setSymbolName(op, newSymName);
59+
return success();
60+
}
61+
62+
namespace mlir {
63+
namespace spirv {
64+
65+
// TODO Properly test symbol rename listener mechanism.
66+
67+
OwningSPIRVModuleRef
68+
combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
69+
OpBuilder &combinedModuleBuilder,
70+
llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
71+
symRenameListener) {
72+
unsigned lastUsedID = 0;
73+
74+
if (modules.empty())
75+
return nullptr;
76+
77+
auto addressingModel = modules[0].addressing_model();
78+
auto memoryModel = modules[0].memory_model();
79+
80+
auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
81+
modules[0].getLoc(), addressingModel, memoryModel);
82+
combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody());
83+
84+
// In some cases, a symbol in the (current state of the) combined module is
85+
// renamed in order to maintain the conflicting symbol in the input module
86+
// being merged. For example, if the conflict is between a global variable in
87+
// the current combined module and a function in the input module, the global
88+
// varaible is renamed. In order to notify listeners of the symbol updates in
89+
// such cases, we need to keep track of the module from which the renamed
90+
// symbol in the combined module originated. This map keeps such information.
91+
DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap;
92+
93+
for (auto module : modules) {
94+
if (module.addressing_model() != addressingModel ||
95+
module.memory_model() != memoryModel) {
96+
module.emitError(
97+
"input modules differ in addressing model and/or memory model");
98+
return nullptr;
99+
}
100+
101+
spirv::ModuleOp moduleClone = module.clone();
102+
103+
// In the combined module, rename all symbols that conflict with symbols
104+
// from the current input module. This renmaing applies to all ops except
105+
// for spv.funcs. This way, if the conflicting op in the input module is
106+
// non-spv.func, we rename that symbol instead and maintain the spv.func in
107+
// the combined module name as it is.
108+
for (auto &op : combinedModule.getBlock().without_terminator()) {
109+
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
110+
if (!symbolOp)
111+
continue;
112+
113+
StringRef oldSymName = symbolOp.getName();
114+
115+
if (!isa<FuncOp>(op) &&
116+
failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
117+
lastUsedID)))
118+
return nullptr;
119+
120+
StringRef newSymName = symbolOp.getName();
121+
122+
if (symRenameListener && oldSymName != newSymName) {
123+
spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
124+
125+
if (!originalModule) {
126+
module.emitError("unable to find original ModuleOp for symbol ")
127+
<< oldSymName;
128+
return nullptr;
129+
}
130+
131+
symRenameListener(originalModule, oldSymName, newSymName);
132+
133+
// Since the symbol name is updated, there is no need to maintain the
134+
// entry that assocaites the old symbol name with the original module.
135+
symNameToModuleMap.erase(oldSymName);
136+
// Instead, add a new entry to map the new symbol name to the original
137+
// module in case it gets renamed again later.
138+
symNameToModuleMap[newSymName] = originalModule;
139+
}
140+
}
141+
142+
// In the current input module, rename all symbols that conflict with
143+
// symbols from the combined module. This includes renaming spv.funcs.
144+
for (auto &op : moduleClone.getBlock().without_terminator()) {
145+
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
146+
if (!symbolOp)
147+
continue;
148+
149+
StringRef oldSymName = symbolOp.getName();
150+
151+
if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
152+
lastUsedID)))
153+
return nullptr;
154+
155+
StringRef newSymName = symbolOp.getName();
156+
157+
if (symRenameListener && oldSymName != newSymName) {
158+
symRenameListener(module, oldSymName, newSymName);
159+
160+
// Insert the module associated with the symbol name.
161+
auto emplaceResult = symNameToModuleMap.try_emplace(newSymName, module);
162+
163+
// If an entry with the same symbol name is already present, this must
164+
// be a problem with the implementation, specially clean-up of the map
165+
// while iterating over the combined module above.
166+
if (!emplaceResult.second) {
167+
module.emitError("did not expect to find an entry for symbol ")
168+
<< newSymName;
169+
return nullptr;
170+
}
171+
}
172+
}
173+
174+
// Clone all the module's ops to the combined module.
175+
for (auto &op : moduleClone.getBlock().without_terminator())
176+
combinedModuleBuilder.insert(op.clone());
177+
}
178+
179+
return combinedModule;
180+
}
181+
182+
} // namespace spirv
183+
} // namespace mlir
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
2+
3+
// CHECK: module {
4+
// CHECK-NEXT: spv.module Logical GLSL450 {
5+
// CHECK-NEXT: spv.specConstant @m1_sc
6+
// CHECK-NEXT: spv.specConstant @m2_sc
7+
// CHECK-NEXT: spv.func @variable_init_spec_constant
8+
// CHECK-NEXT: spv._reference_of @m2_sc
9+
// CHECK-NEXT: spv.Variable init
10+
// CHECK-NEXT: spv.Return
11+
// CHECK-NEXT: }
12+
// CHECK-NEXT: }
13+
// CHECK-NEXT: }
14+
15+
module {
16+
spv.module Logical GLSL450 {
17+
spv.specConstant @m1_sc = 42.42 : f32
18+
}
19+
20+
spv.module Logical GLSL450 {
21+
spv.specConstant @m2_sc = 42 : i32
22+
spv.func @variable_init_spec_constant() -> () "None" {
23+
%0 = spv._reference_of @m2_sc : i32
24+
%1 = spv.Variable init(%0) : !spv.ptr<i32, Function>
25+
spv.Return
26+
}
27+
}
28+
}
29+
30+
// -----
31+
32+
module {
33+
spv.module Physical64 GLSL450 {
34+
}
35+
36+
// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
37+
spv.module Logical GLSL450 {
38+
}
39+
}
40+
41+
// -----
42+
43+
module {
44+
spv.module Logical Simple {
45+
}
46+
47+
// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
48+
spv.module Logical GLSL450 {
49+
}
50+
}

0 commit comments

Comments
 (0)