Skip to content

Commit 96e4a8b

Browse files
committed
Vocab changes1
1 parent dfe59f2 commit 96e4a8b

File tree

3 files changed

+164
-67
lines changed

3 files changed

+164
-67
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131

3232
#include "llvm/ADT/DenseMap.h"
3333
#include "llvm/IR/PassManager.h"
34+
#include "llvm/Support/CommandLine.h"
3435
#include "llvm/Support/ErrorOr.h"
36+
#include "llvm/Support/JSON.h"
3537
#include <map>
3638

3739
namespace llvm {
@@ -43,6 +45,7 @@ class Function;
4345
class Type;
4446
class Value;
4547
class raw_ostream;
48+
class LLVMContext;
4649

4750
/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
4851
/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -53,7 +56,12 @@ class raw_ostream;
5356
enum class IR2VecKind { Symbolic };
5457

5558
namespace ir2vec {
56-
/// Embedding is a datavtype that wraps std::vector<double>. It provides
59+
60+
LLVM_ABI extern cl::opt<float> OpcWeight;
61+
LLVM_ABI extern cl::opt<float> TypeWeight;
62+
LLVM_ABI extern cl::opt<float> ArgWeight;
63+
64+
/// Embedding is a datatype that wraps std::vector<double>. It provides
5765
/// additional functionality for arithmetic and comparison operations.
5866
/// It is meant to be used *like* std::vector<double> but is more restrictive
5967
/// in the sense that it does not allow the user to change the size of the
@@ -226,10 +234,12 @@ class IR2VecVocabResult {
226234
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
227235
ir2vec::Vocab Vocabulary;
228236
Error readVocabulary();
237+
void emitError(Error Err, LLVMContext &Ctx);
229238

230239
public:
231240
static AnalysisKey Key;
232241
IR2VecVocabAnalysis() = default;
242+
explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
233243
using Result = IR2VecVocabResult;
234244
Result run(Module &M, ModuleAnalysisManager &MAM);
235245
};

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,11 @@
1616
#include "llvm/ADT/Statistic.h"
1717
#include "llvm/IR/Module.h"
1818
#include "llvm/IR/PassManager.h"
19-
#include "llvm/Support/CommandLine.h"
2019
#include "llvm/Support/Debug.h"
2120
#include "llvm/Support/Errc.h"
2221
#include "llvm/Support/Error.h"
2322
#include "llvm/Support/ErrorHandling.h"
2423
#include "llvm/Support/Format.h"
25-
#include "llvm/Support/JSON.h"
2624
#include "llvm/Support/MemoryBuffer.h"
2725

2826
using namespace llvm;
@@ -33,25 +31,29 @@ using namespace ir2vec;
3331
STATISTIC(VocabMissCounter,
3432
"Number of lookups to entites not present in the vocabulary");
3533

34+
namespace llvm {
35+
namespace ir2vec {
3636
static cl::OptionCategory IR2VecCategory("IR2Vec Options");
3737

3838
// FIXME: Use a default vocab when not specified
3939
static cl::opt<std::string>
4040
VocabFile("ir2vec-vocab-path", cl::Optional,
4141
cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
4242
cl::cat(IR2VecCategory));
43-
static cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
44-
cl::init(1.0),
45-
cl::desc("Weight for opcode embeddings"),
46-
cl::cat(IR2VecCategory));
47-
static cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
48-
cl::init(0.5),
49-
cl::desc("Weight for type embeddings"),
50-
cl::cat(IR2VecCategory));
51-
static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
52-
cl::init(0.2),
53-
cl::desc("Weight for argument embeddings"),
54-
cl::cat(IR2VecCategory));
43+
LLVM_ABI cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
44+
cl::init(1.0),
45+
cl::desc("Weight for opcode embeddings"),
46+
cl::cat(IR2VecCategory));
47+
LLVM_ABI cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
48+
cl::init(0.5),
49+
cl::desc("Weight for type embeddings"),
50+
cl::cat(IR2VecCategory));
51+
LLVM_ABI cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
52+
cl::init(0.2),
53+
cl::desc("Weight for argument embeddings"),
54+
cl::cat(IR2VecCategory));
55+
} // namespace ir2vec
56+
} // namespace llvm
5557

5658
AnalysisKey IR2VecVocabAnalysis::Key;
5759

@@ -251,49 +253,67 @@ bool IR2VecVocabResult::invalidate(
251253
// by auto-generating a default vocabulary during the build time.
252254
Error IR2VecVocabAnalysis::readVocabulary() {
253255
auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
254-
if (!BufOrError) {
256+
if (!BufOrError)
255257
return createFileError(VocabFile, BufOrError.getError());
256-
}
258+
257259
auto Content = BufOrError.get()->getBuffer();
258260
json::Path::Root Path("");
259261
Expected<json::Value> ParsedVocabValue = json::parse(Content);
260262
if (!ParsedVocabValue)
261263
return ParsedVocabValue.takeError();
262264

263265
bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path);
264-
if (!Res) {
266+
if (!Res)
265267
return createStringError(errc::illegal_byte_sequence,
266268
"Unable to parse the vocabulary");
267-
}
268-
assert(Vocabulary.size() > 0 && "Vocabulary is empty");
269+
270+
if (Vocabulary.empty())
271+
return createStringError(errc::illegal_byte_sequence,
272+
"Vocabulary is empty");
269273

270274
unsigned Dim = Vocabulary.begin()->second.size();
271-
assert(Dim > 0 && "Dimension of vocabulary is zero");
272-
(void)Dim;
273-
assert(std::all_of(Vocabulary.begin(), Vocabulary.end(),
274-
[Dim](const std::pair<StringRef, Embedding> &Entry) {
275-
return Entry.second.size() == Dim;
276-
}) &&
277-
"All vectors in the vocabulary are not of the same dimension");
275+
if (Dim == 0)
276+
return createStringError(errc::illegal_byte_sequence,
277+
"Dimension of vocabulary is zero");
278+
279+
if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
280+
[Dim](const std::pair<StringRef, Embedding> &Entry) {
281+
return Entry.second.size() == Dim;
282+
}))
283+
return createStringError(
284+
errc::illegal_byte_sequence,
285+
"All vectors in the vocabulary are not of the same dimension");
286+
278287
return Error::success();
279288
}
280289

290+
IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
291+
: Vocabulary(std::move(Vocabulary)) {}
292+
293+
void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
294+
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
295+
Ctx.emitError("Error reading vocabulary: " + EI.message());
296+
});
297+
}
298+
281299
IR2VecVocabAnalysis::Result
282300
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
283301
auto Ctx = &M.getContext();
302+
// FIXME: Scale the vocabulary once. This would avoid scaling per use later.
303+
// If vocabulary is already populated by the constructor, use it.
304+
if (!Vocabulary.empty())
305+
return IR2VecVocabResult(std::move(Vocabulary));
306+
307+
// Otherwise, try to read from the vocabulary file.
284308
if (VocabFile.empty()) {
285309
// FIXME: Use default vocabulary
286310
Ctx->emitError("IR2Vec vocabulary file path not specified");
287311
return IR2VecVocabResult(); // Return invalid result
288312
}
289313
if (auto Err = readVocabulary()) {
290-
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
291-
Ctx->emitError("Error reading vocabulary: " + EI.message());
292-
});
314+
emitError(std::move(Err), *Ctx);
293315
return IR2VecVocabResult();
294316
}
295-
// FIXME: Scale the vocabulary here once. This would avoid scaling per use
296-
// later.
297317
return IR2VecVocabResult(std::move(Vocabulary));
298318
}
299319

llvm/unittests/Analysis/IR2VecTest.cpp

Lines changed: 102 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -266,25 +266,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
266266
EXPECT_EQ(validResult.getDimension(), 2u);
267267
}
268268

269-
// Helper to create a minimal function and embedder for getter tests
270-
struct GetterTestEnv {
271-
Vocab V = {};
269+
// Fixture for IR2Vec tests requiring IR setup and weight management.
270+
class IR2VecTestFixture : public ::testing::Test {
271+
protected:
272+
Vocab V;
272273
LLVMContext Ctx;
273-
std::unique_ptr<Module> M = nullptr;
274+
std::unique_ptr<Module> M;
274275
Function *F = nullptr;
275276
BasicBlock *BB = nullptr;
276-
Instruction *Add = nullptr;
277-
Instruction *Ret = nullptr;
278-
std::unique_ptr<Embedder> Emb = nullptr;
277+
Instruction *AddInst = nullptr;
278+
Instruction *RetInst = nullptr;
279279

280-
GetterTestEnv() {
280+
float OriginalOpcWeight = ::OpcWeight;
281+
float OriginalTypeWeight = ::TypeWeight;
282+
float OriginalArgWeight = ::ArgWeight;
283+
284+
void SetUp() override {
281285
V = {{"add", {1.0, 2.0}},
282286
{"integerTy", {0.5, 0.5}},
283287
{"constant", {0.2, 0.3}},
284288
{"variable", {0.0, 0.0}},
285289
{"unknownTy", {0.0, 0.0}}};
286290

287-
M = std::make_unique<Module>("M", Ctx);
291+
// Setup IR
292+
M = std::make_unique<Module>("TestM", Ctx);
288293
FunctionType *FTy = FunctionType::get(
289294
Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
290295
false);
@@ -293,61 +298,82 @@ struct GetterTestEnv {
293298
Argument *Arg = F->getArg(0);
294299
llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
295300

296-
Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
297-
Ret = ReturnInst::Create(Ctx, Add, BB);
301+
AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
302+
RetInst = ReturnInst::Create(Ctx, AddInst, BB);
303+
}
304+
305+
void setWeights(float OpcWeight, float TypeWeight, float ArgWeight) {
306+
::OpcWeight = OpcWeight;
307+
::TypeWeight = TypeWeight;
308+
::ArgWeight = ArgWeight;
309+
}
298310

299-
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
300-
EXPECT_TRUE(static_cast<bool>(Result));
301-
Emb = std::move(*Result);
311+
void TearDown() override {
312+
// Restore original global weights
313+
::OpcWeight = OriginalOpcWeight;
314+
::TypeWeight = OriginalTypeWeight;
315+
::ArgWeight = OriginalArgWeight;
302316
}
303317
};
304318

305-
TEST(IR2VecTest, GetInstVecMap) {
306-
GetterTestEnv Env;
307-
const auto &InstMap = Env.Emb->getInstVecMap();
319+
TEST_F(IR2VecTestFixture, GetInstVecMap) {
320+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
321+
ASSERT_TRUE(static_cast<bool>(Result));
322+
auto Emb = std::move(*Result);
323+
324+
const auto &InstMap = Emb->getInstVecMap();
308325

309326
EXPECT_EQ(InstMap.size(), 2u);
310-
EXPECT_TRUE(InstMap.count(Env.Add));
311-
EXPECT_TRUE(InstMap.count(Env.Ret));
327+
EXPECT_TRUE(InstMap.count(AddInst));
328+
EXPECT_TRUE(InstMap.count(RetInst));
312329

313-
EXPECT_EQ(InstMap.at(Env.Add).size(), 2u);
314-
EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u);
330+
EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
331+
EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
315332

316333
// Check values for add: {1.29, 2.31}
317-
EXPECT_THAT(InstMap.at(Env.Add),
334+
EXPECT_THAT(InstMap.at(AddInst),
318335
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
319336

320337
// Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
321338
// vocab
322-
EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0));
339+
EXPECT_THAT(InstMap.at(RetInst), ElementsAre(0.0, 0.0));
323340
}
324341

325-
TEST(IR2VecTest, GetBBVecMap) {
326-
GetterTestEnv Env;
327-
const auto &BBMap = Env.Emb->getBBVecMap();
342+
TEST_F(IR2VecTestFixture, GetBBVecMap) {
343+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
344+
ASSERT_TRUE(static_cast<bool>(Result));
345+
auto Emb = std::move(*Result);
346+
347+
const auto &BBMap = Emb->getBBVecMap();
328348

329349
EXPECT_EQ(BBMap.size(), 1u);
330-
EXPECT_TRUE(BBMap.count(Env.BB));
331-
EXPECT_EQ(BBMap.at(Env.BB).size(), 2u);
350+
EXPECT_TRUE(BBMap.count(BB));
351+
EXPECT_EQ(BBMap.at(BB).size(), 2u);
332352

333353
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
334354
// {1.29, 2.31}
335-
EXPECT_THAT(BBMap.at(Env.BB),
355+
EXPECT_THAT(BBMap.at(BB),
336356
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
337357
}
338358

339-
TEST(IR2VecTest, GetBBVector) {
340-
GetterTestEnv Env;
341-
const auto &BBVec = Env.Emb->getBBVector(*Env.BB);
359+
TEST_F(IR2VecTestFixture, GetBBVector) {
360+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
361+
ASSERT_TRUE(static_cast<bool>(Result));
362+
auto Emb = std::move(*Result);
363+
364+
const auto &BBVec = Emb->getBBVector(*BB);
342365

343366
EXPECT_EQ(BBVec.size(), 2u);
344367
EXPECT_THAT(BBVec,
345368
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
346369
}
347370

348-
TEST(IR2VecTest, GetFunctionVector) {
349-
GetterTestEnv Env;
350-
const auto &FuncVec = Env.Emb->getFunctionVector();
371+
TEST_F(IR2VecTestFixture, GetFunctionVector) {
372+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
373+
ASSERT_TRUE(static_cast<bool>(Result));
374+
auto Emb = std::move(*Result);
375+
376+
const auto &FuncVec = Emb->getFunctionVector();
351377

352378
EXPECT_EQ(FuncVec.size(), 2u);
353379

@@ -356,4 +382,45 @@ TEST(IR2VecTest, GetFunctionVector) {
356382
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
357383
}
358384

385+
TEST_F(IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
386+
setWeights(1.0, 1.0, 1.0);
387+
388+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
389+
ASSERT_TRUE(static_cast<bool>(Result));
390+
auto Emb = std::move(*Result);
391+
392+
const auto &FuncVec = Emb->getFunctionVector();
393+
394+
EXPECT_EQ(FuncVec.size(), 2u);
395+
396+
// Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
397+
// 0.3] + [0.0 0.0])
398+
EXPECT_THAT(FuncVec,
399+
ElementsAre(DoubleNear(1.7, 1e-6), DoubleNear(2.8, 1e-6)));
400+
}
401+
402+
TEST(IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
403+
Vocab InitialVocab = {{"key1", {1.1, 2.2}}, {"key2", {3.3, 4.4}}};
404+
Vocab ExpectedVocab = InitialVocab;
405+
unsigned ExpectedDim = InitialVocab.begin()->second.size();
406+
407+
IR2VecVocabAnalysis VocabAnalysis(std::move(InitialVocab));
408+
409+
LLVMContext TestCtx;
410+
Module TestMod("TestModuleForVocabAnalysis", TestCtx);
411+
ModuleAnalysisManager MAM;
412+
IR2VecVocabResult Result = VocabAnalysis.run(TestMod, MAM);
413+
414+
EXPECT_TRUE(Result.isValid());
415+
ASSERT_FALSE(Result.getVocabulary().empty());
416+
EXPECT_EQ(Result.getDimension(), ExpectedDim);
417+
418+
const auto &ResultVocab = Result.getVocabulary();
419+
EXPECT_EQ(ResultVocab.size(), ExpectedVocab.size());
420+
for (const auto &pair : ExpectedVocab) {
421+
EXPECT_TRUE(ResultVocab.count(pair.first));
422+
EXPECT_THAT(ResultVocab.at(pair.first), ElementsAreArray(pair.second));
423+
}
424+
}
425+
359426
} // end anonymous namespace

0 commit comments

Comments
 (0)