@@ -266,25 +266,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
266
266
EXPECT_EQ (validResult.getDimension (), 2u );
267
267
}
268
268
269
- // Helper to create a minimal function and embedder for getter tests
270
- struct GetterTestEnv {
271
- Vocab V = {};
269
+ // Fixture for IR2Vec tests requiring IR setup and weight management.
270
+ class IR2VecTestFixture : public ::testing::Test {
271
+ protected:
272
+ Vocab V;
272
273
LLVMContext Ctx;
273
- std::unique_ptr<Module> M = nullptr ;
274
+ std::unique_ptr<Module> M;
274
275
Function *F = nullptr ;
275
276
BasicBlock *BB = nullptr ;
276
- Instruction *Add = nullptr ;
277
- Instruction *Ret = nullptr ;
278
- std::unique_ptr<Embedder> Emb = nullptr ;
277
+ Instruction *AddInst = nullptr ;
278
+ Instruction *RetInst = nullptr ;
279
279
280
- GetterTestEnv () {
280
+ float OriginalOpcWeight = ::OpcWeight;
281
+ float OriginalTypeWeight = ::TypeWeight;
282
+ float OriginalArgWeight = ::ArgWeight;
283
+
284
+ void SetUp () override {
281
285
V = {{" add" , {1.0 , 2.0 }},
282
286
{" integerTy" , {0.5 , 0.5 }},
283
287
{" constant" , {0.2 , 0.3 }},
284
288
{" variable" , {0.0 , 0.0 }},
285
289
{" unknownTy" , {0.0 , 0.0 }}};
286
290
287
- M = std::make_unique<Module>(" M" , Ctx);
291
+ // Setup IR
292
+ M = std::make_unique<Module>(" TestM" , Ctx);
288
293
FunctionType *FTy = FunctionType::get (
289
294
Type::getInt32Ty (Ctx), {Type::getInt32Ty (Ctx), Type::getInt32Ty (Ctx)},
290
295
false );
@@ -293,61 +298,82 @@ struct GetterTestEnv {
293
298
Argument *Arg = F->getArg (0 );
294
299
llvm::Value *Const = ConstantInt::get (Type::getInt32Ty (Ctx), 42 );
295
300
296
- Add = BinaryOperator::CreateAdd (Arg, Const, " add" , BB);
297
- Ret = ReturnInst::Create (Ctx, Add, BB);
301
+ AddInst = BinaryOperator::CreateAdd (Arg, Const, " add" , BB);
302
+ RetInst = ReturnInst::Create (Ctx, AddInst, BB);
303
+ }
304
+
305
+ void setWeights (float OpcWeight, float TypeWeight, float ArgWeight) {
306
+ ::OpcWeight = OpcWeight;
307
+ ::TypeWeight = TypeWeight;
308
+ ::ArgWeight = ArgWeight;
309
+ }
298
310
299
- auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
300
- EXPECT_TRUE (static_cast <bool >(Result));
301
- Emb = std::move (*Result);
311
+ void TearDown () override {
312
+ // Restore original global weights
313
+ ::OpcWeight = OriginalOpcWeight;
314
+ ::TypeWeight = OriginalTypeWeight;
315
+ ::ArgWeight = OriginalArgWeight;
302
316
}
303
317
};
304
318
305
- TEST (IR2VecTest, GetInstVecMap) {
306
- GetterTestEnv Env;
307
- const auto &InstMap = Env.Emb ->getInstVecMap ();
319
+ TEST_F (IR2VecTestFixture, GetInstVecMap) {
320
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
321
+ ASSERT_TRUE (static_cast <bool >(Result));
322
+ auto Emb = std::move (*Result);
323
+
324
+ const auto &InstMap = Emb->getInstVecMap ();
308
325
309
326
EXPECT_EQ (InstMap.size (), 2u );
310
- EXPECT_TRUE (InstMap.count (Env. Add ));
311
- EXPECT_TRUE (InstMap.count (Env. Ret ));
327
+ EXPECT_TRUE (InstMap.count (AddInst ));
328
+ EXPECT_TRUE (InstMap.count (RetInst ));
312
329
313
- EXPECT_EQ (InstMap.at (Env. Add ).size (), 2u );
314
- EXPECT_EQ (InstMap.at (Env. Ret ).size (), 2u );
330
+ EXPECT_EQ (InstMap.at (AddInst ).size (), 2u );
331
+ EXPECT_EQ (InstMap.at (RetInst ).size (), 2u );
315
332
316
333
// Check values for add: {1.29, 2.31}
317
- EXPECT_THAT (InstMap.at (Env. Add ),
334
+ EXPECT_THAT (InstMap.at (AddInst ),
318
335
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
319
336
320
337
// Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
321
338
// vocab
322
- EXPECT_THAT (InstMap.at (Env. Ret ), ElementsAre (0.0 , 0.0 ));
339
+ EXPECT_THAT (InstMap.at (RetInst ), ElementsAre (0.0 , 0.0 ));
323
340
}
324
341
325
- TEST (IR2VecTest, GetBBVecMap) {
326
- GetterTestEnv Env;
327
- const auto &BBMap = Env.Emb ->getBBVecMap ();
342
+ TEST_F (IR2VecTestFixture, GetBBVecMap) {
343
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
344
+ ASSERT_TRUE (static_cast <bool >(Result));
345
+ auto Emb = std::move (*Result);
346
+
347
+ const auto &BBMap = Emb->getBBVecMap ();
328
348
329
349
EXPECT_EQ (BBMap.size (), 1u );
330
- EXPECT_TRUE (BBMap.count (Env. BB ));
331
- EXPECT_EQ (BBMap.at (Env. BB ).size (), 2u );
350
+ EXPECT_TRUE (BBMap.count (BB));
351
+ EXPECT_EQ (BBMap.at (BB).size (), 2u );
332
352
333
353
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
334
354
// {1.29, 2.31}
335
- EXPECT_THAT (BBMap.at (Env. BB ),
355
+ EXPECT_THAT (BBMap.at (BB),
336
356
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
337
357
}
338
358
339
- TEST (IR2VecTest, GetBBVector) {
340
- GetterTestEnv Env;
341
- const auto &BBVec = Env.Emb ->getBBVector (*Env.BB );
359
+ TEST_F (IR2VecTestFixture, GetBBVector) {
360
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
361
+ ASSERT_TRUE (static_cast <bool >(Result));
362
+ auto Emb = std::move (*Result);
363
+
364
+ const auto &BBVec = Emb->getBBVector (*BB);
342
365
343
366
EXPECT_EQ (BBVec.size (), 2u );
344
367
EXPECT_THAT (BBVec,
345
368
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
346
369
}
347
370
348
- TEST (IR2VecTest, GetFunctionVector) {
349
- GetterTestEnv Env;
350
- const auto &FuncVec = Env.Emb ->getFunctionVector ();
371
+ TEST_F (IR2VecTestFixture, GetFunctionVector) {
372
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
373
+ ASSERT_TRUE (static_cast <bool >(Result));
374
+ auto Emb = std::move (*Result);
375
+
376
+ const auto &FuncVec = Emb->getFunctionVector ();
351
377
352
378
EXPECT_EQ (FuncVec.size (), 2u );
353
379
@@ -356,4 +382,45 @@ TEST(IR2VecTest, GetFunctionVector) {
356
382
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
357
383
}
358
384
385
+ TEST_F (IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
386
+ setWeights (1.0 , 1.0 , 1.0 );
387
+
388
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
389
+ ASSERT_TRUE (static_cast <bool >(Result));
390
+ auto Emb = std::move (*Result);
391
+
392
+ const auto &FuncVec = Emb->getFunctionVector ();
393
+
394
+ EXPECT_EQ (FuncVec.size (), 2u );
395
+
396
+ // Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
397
+ // 0.3] + [0.0 0.0])
398
+ EXPECT_THAT (FuncVec,
399
+ ElementsAre (DoubleNear (1.7 , 1e-6 ), DoubleNear (2.8 , 1e-6 )));
400
+ }
401
+
402
+ TEST (IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
403
+ Vocab InitialVocab = {{" key1" , {1.1 , 2.2 }}, {" key2" , {3.3 , 4.4 }}};
404
+ Vocab ExpectedVocab = InitialVocab;
405
+ unsigned ExpectedDim = InitialVocab.begin ()->second .size ();
406
+
407
+ IR2VecVocabAnalysis VocabAnalysis (std::move (InitialVocab));
408
+
409
+ LLVMContext TestCtx;
410
+ Module TestMod (" TestModuleForVocabAnalysis" , TestCtx);
411
+ ModuleAnalysisManager MAM;
412
+ IR2VecVocabResult Result = VocabAnalysis.run (TestMod, MAM);
413
+
414
+ EXPECT_TRUE (Result.isValid ());
415
+ ASSERT_FALSE (Result.getVocabulary ().empty ());
416
+ EXPECT_EQ (Result.getDimension (), ExpectedDim);
417
+
418
+ const auto &ResultVocab = Result.getVocabulary ();
419
+ EXPECT_EQ (ResultVocab.size (), ExpectedVocab.size ());
420
+ for (const auto &pair : ExpectedVocab) {
421
+ EXPECT_TRUE (ResultVocab.count (pair.first ));
422
+ EXPECT_THAT (ResultVocab.at (pair.first ), ElementsAreArray (pair.second ));
423
+ }
424
+ }
425
+
359
426
} // end anonymous namespace
0 commit comments