Skip to content

Commit 8e397f7

Browse files
committed
[IR2Vec] Scale vocab
1 parent b7ec652 commit 8e397f7

File tree

16 files changed

+395
-135
lines changed

16 files changed

+395
-135
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ struct Embedding {
108108
/// Arithmetic operators
109109
Embedding &operator+=(const Embedding &RHS);
110110
Embedding &operator-=(const Embedding &RHS);
111+
Embedding &operator*=(double Factor);
111112

112113
/// Adds Src Embedding scaled by Factor with the called Embedding.
113114
/// Called_Embedding += Src * Factor
@@ -116,6 +117,8 @@ struct Embedding {
116117
/// Returns true if the embedding is approximately equal to the RHS embedding
117118
/// within the specified tolerance.
118119
bool approximatelyEquals(const Embedding &RHS, double Tolerance = 1e-6) const;
120+
121+
void print(raw_ostream &OS) const;
119122
};
120123

121124
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
@@ -234,6 +237,8 @@ class IR2VecVocabResult {
234237
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
235238
ir2vec::Vocab Vocabulary;
236239
Error readVocabulary();
240+
Error parseVocabSection(const char *Key, const json::Value ParsedVocabValue,
241+
ir2vec::Vocab &TargetVocab, unsigned &Dim);
237242
void emitError(Error Err, LLVMContext &Ctx);
238243

239244
public:
@@ -249,14 +254,23 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
249254
/// functions.
250255
class IR2VecPrinterPass : public PassInfoMixin<IR2VecPrinterPass> {
251256
/// Output stream bound at construction time; run() writes to it.
raw_ostream &OS;
252-
void printVector(const ir2vec::Embedding &Vec) const;
253257

254258
public:
255259
/// Binds the pass to the stream \p OS.
explicit IR2VecPrinterPass(raw_ostream &OS) : OS(OS) {}
256260
/// Module-level entry point of the pass.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
257261
/// Always returns true.
static bool isRequired() { return true; }
258262
};
259263

264+
/// This pass prints the embeddings in the vocabulary
265+
class IR2VecVocabPrinterPass : public PassInfoMixin<IR2VecVocabPrinterPass> {
266+
/// Output stream bound at construction time; run() writes to it.
raw_ostream &OS;
267+
268+
public:
269+
/// Binds the pass to the stream \p OS.
explicit IR2VecVocabPrinterPass(raw_ostream &OS) : OS(OS) {}
270+
/// Module-level entry point: writes every vocabulary entry (key followed by
/// its embedding) to the bound stream.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
271+
/// Always returns true.
static bool isRequired() { return true; }
272+
};
273+
260274
} // namespace llvm
261275

262276
#endif // LLVM_ANALYSIS_IR2VEC_H

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 102 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ Embedding &Embedding::operator-=(const Embedding &RHS) {
8585
return *this;
8686
}
8787

88+
/// Multiplies every element of this embedding by \p Factor in place and
/// returns a reference to this embedding so calls can be chained.
Embedding &Embedding::operator*=(double Factor) {
  for (auto &Component : *this)
    Component *= Factor;
  return *this;
}
93+
8894
Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
8995
assert(this->size() == Src.size() && "Vectors must have the same dimension");
9096
for (size_t Itr = 0; Itr < this->size(); ++Itr)
@@ -101,6 +107,13 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
101107
return true;
102108
}
103109

110+
/// Writes the embedding to \p OS as a single bracketed line: " [", then each
/// element formatted with two decimals and surrounded by single spaces, then
/// "]" and a newline.
void Embedding::print(raw_ostream &OS) const {
  OS << " [";
  for (const auto &Value : Data) {
    OS << " ";
    OS << format("%.2f", Value);
    OS << " ";
  }
  OS << "]\n";
}
116+
104117
// ==----------------------------------------------------------------------===//
105118
// Embedder and its subclasses
106119
//===----------------------------------------------------------------------===//
@@ -196,18 +209,12 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
196209
for (const auto &I : BB.instructionsWithoutDebug()) {
197210
Embedding InstVector(Dimension, 0);
198211

199-
const auto OpcVec = lookupVocab(I.getOpcodeName());
200-
InstVector.scaleAndAdd(OpcVec, OpcWeight);
201-
202212
// FIXME: Currently lookups are string based. Use numeric Keys
203213
// for efficiency.
204-
const auto Type = I.getType();
205-
const auto TypeVec = getTypeEmbedding(Type);
206-
InstVector.scaleAndAdd(TypeVec, TypeWeight);
207-
214+
InstVector += lookupVocab(I.getOpcodeName());
215+
InstVector += getTypeEmbedding(I.getType());
208216
for (const auto &Op : I.operands()) {
209-
const auto OperandVec = getOperandEmbedding(Op.get());
210-
InstVector.scaleAndAdd(OperandVec, ArgWeight);
217+
InstVector += getOperandEmbedding(Op.get());
211218
}
212219
InstVecMap[&I] = InstVector;
213220
BBVector += InstVector;
@@ -251,6 +258,47 @@ bool IR2VecVocabResult::invalidate(
251258
return !(PAC.preservedWhenStateless());
252259
}
253260

261+
/// Extracts the section named \p Key from the parsed vocabulary document
/// \p ParsedVocabValue into \p TargetVocab and reports the embedding dimension
/// shared by all entries of that section through \p Dim.
///
/// \param Key              Name of the top-level JSON member to read; must not
///                         be null.
/// \param ParsedVocabValue Parsed JSON document; its root must be an object.
/// \param TargetVocab      Receives the entries of the section.
/// \param Dim              Receives the size of the section's embeddings. It is
///                         only meaningful when the function succeeds.
/// \returns Error::success() on success. Otherwise a string error when the
///          root is not an object, the section is missing, cannot be
///          converted, is empty, has zero-sized embeddings, or contains
///          embeddings of differing sizes.
Error IR2VecVocabAnalysis::parseVocabSection(const char *Key,
                                             const json::Value ParsedVocabValue,
                                             ir2vec::Vocab &TargetVocab,
                                             unsigned &Dim) {
  assert(Key && "Key cannot be null");

  const json::Object *RootObj = ParsedVocabValue.getAsObject();
  if (!RootObj)
    return createStringError(errc::invalid_argument,
                             "JSON root is not an object");

  const json::Value *SectionValue = RootObj->get(Key);
  if (!SectionValue)
    return createStringError(errc::invalid_argument,
                             "Missing '" + std::string(Key) +
                                 "' section in vocabulary file");

  json::Path::Root Path("");
  if (!json::fromJSON(*SectionValue, TargetVocab, Path))
    return createStringError(errc::illegal_byte_sequence,
                             "Unable to parse '" + std::string(Key) +
                                 "' section from vocabulary");

  // A section such as `{}` converts successfully but leaves the map empty.
  // Reject it here: dereferencing begin() of an empty container below would
  // be undefined behaviour.
  if (TargetVocab.empty())
    return createStringError(errc::illegal_byte_sequence,
                             "'" + std::string(Key) +
                                 "' section of the vocabulary is empty");

  // The first entry fixes the dimension every other entry must match.
  Dim = TargetVocab.begin()->second.size();
  if (Dim == 0)
    return createStringError(errc::illegal_byte_sequence,
                             "Dimension of '" + std::string(Key) +
                                 "' section of the vocabulary is zero");

  // `const auto &` binds directly to the container's element type; spelling
  // out a different pair type could force a converting temporary (and a copy
  // of every embedding) on each iteration.
  if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
                   [Dim](const auto &Entry) {
                     return Entry.second.size() == Dim;
                   }))
    return createStringError(
        errc::illegal_byte_sequence,
        "All vectors in the '" + std::string(Key) +
            "' section of the vocabulary are not of the same dimension");

  return Error::success();
}
301+
254302
// FIXME: Make this optional. We can avoid file reads
255303
// by auto-generating a default vocabulary during the build time.
256304
Error IR2VecVocabAnalysis::readVocabulary() {
@@ -259,32 +307,40 @@ Error IR2VecVocabAnalysis::readVocabulary() {
259307
return createFileError(VocabFile, BufOrError.getError());
260308

261309
auto Content = BufOrError.get()->getBuffer();
262-
json::Path::Root Path("");
310+
263311
Expected<json::Value> ParsedVocabValue = json::parse(Content);
264312
if (!ParsedVocabValue)
265313
return ParsedVocabValue.takeError();
266314

267-
bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path);
268-
if (!Res)
269-
return createStringError(errc::illegal_byte_sequence,
270-
"Unable to parse the vocabulary");
315+
ir2vec::Vocab OpcodeVocab, TypeVocab, ArgVocab;
316+
unsigned OpcodeDim, TypeDim, ArgDim;
317+
if (auto Err = parseVocabSection("Opcodes", *ParsedVocabValue, OpcodeVocab,
318+
OpcodeDim))
319+
return Err;
271320

272-
if (Vocabulary.empty())
273-
return createStringError(errc::illegal_byte_sequence,
274-
"Vocabulary is empty");
321+
if (auto Err =
322+
parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim))
323+
return Err;
275324

276-
unsigned Dim = Vocabulary.begin()->second.size();
277-
if (Dim == 0)
325+
if (auto Err =
326+
parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim))
327+
return Err;
328+
329+
if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
278330
return createStringError(errc::illegal_byte_sequence,
279-
"Dimension of vocabulary is zero");
331+
"Vocabulary sections have different dimensions");
280332

281-
if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
282-
[Dim](const std::pair<StringRef, Embedding> &Entry) {
283-
return Entry.second.size() == Dim;
284-
}))
285-
return createStringError(
286-
errc::illegal_byte_sequence,
287-
"All vectors in the vocabulary are not of the same dimension");
333+
auto scaleVocabSection = [](ir2vec::Vocab &Vocab, double Weight) {
334+
for (auto &Entry : Vocab)
335+
Entry.second *= Weight;
336+
};
337+
scaleVocabSection(OpcodeVocab, OpcWeight);
338+
scaleVocabSection(TypeVocab, TypeWeight);
339+
scaleVocabSection(ArgVocab, ArgWeight);
340+
341+
Vocabulary.insert(OpcodeVocab.begin(), OpcodeVocab.end());
342+
Vocabulary.insert(TypeVocab.begin(), TypeVocab.end());
343+
Vocabulary.insert(ArgVocab.begin(), ArgVocab.end());
288344

289345
return Error::success();
290346
}
@@ -304,7 +360,7 @@ void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
304360
IR2VecVocabAnalysis::Result
305361
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
306362
auto Ctx = &M.getContext();
307-
// FIXME: Scale the vocabulary once. This would avoid scaling per use later.
363+
308364
// If vocabulary is already populated by the constructor, use it.
309365
if (!Vocabulary.empty())
310366
return IR2VecVocabResult(std::move(Vocabulary));
@@ -323,16 +379,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
323379
}
324380

325381
// ==----------------------------------------------------------------------===//
326-
// IR2VecPrinterPass
382+
// Printer Passes
327383
//===----------------------------------------------------------------------===//
328384

329-
void IR2VecPrinterPass::printVector(const Embedding &Vec) const {
330-
OS << " [";
331-
for (const auto &Elem : Vec)
332-
OS << " " << format("%.2f", Elem) << " ";
333-
OS << "]\n";
334-
}
335-
336385
PreservedAnalyses IR2VecPrinterPass::run(Module &M,
337386
ModuleAnalysisManager &MAM) {
338387
auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
@@ -353,15 +402,15 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
353402

354403
OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
355404
OS << "Function vector: ";
356-
printVector(Emb->getFunctionVector());
405+
Emb->getFunctionVector().print(OS);
357406

358407
OS << "Basic block vectors:\n";
359408
const auto &BBMap = Emb->getBBVecMap();
360409
for (const BasicBlock &BB : F) {
361410
auto It = BBMap.find(&BB);
362411
if (It != BBMap.end()) {
363412
OS << "Basic block: " << BB.getName() << ":\n";
364-
printVector(It->second);
413+
It->second.print(OS);
365414
}
366415
}
367416

@@ -373,10 +422,24 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
373422
if (It != InstMap.end()) {
374423
OS << "Instruction: ";
375424
I.print(OS);
376-
printVector(It->second);
425+
It->second.print(OS);
377426
}
378427
}
379428
}
380429
}
381430
return PreservedAnalyses::all();
382431
}
432+
433+
/// Prints every entry of the IR2Vec vocabulary to the pass's output stream.
///
/// Each entry is emitted as "Key: <key>: " followed by the embedding as
/// rendered by Embedding::print. The vocabulary is obtained from
/// IR2VecVocabAnalysis for \p M via \p MAM and is asserted to be valid.
///
/// \returns PreservedAnalyses::all().
PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
                                              ModuleAnalysisManager &MAM) {
  // Bind by reference: copying the analysis result (and then the vocabulary)
  // would duplicate every embedding just to print it.
  auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
  assert(VocabResult.isValid() && "IR2Vec Vocabulary is invalid");

  const auto &Vocab = VocabResult.getVocabulary();
  for (const auto &Entry : Vocab) {
    OS << "Key: " << Entry.first << ": ";
    Entry.second.print(OS);
  }

  return PreservedAnalyses::all();
}

0 commit comments

Comments
 (0)