@@ -85,6 +85,12 @@ Embedding &Embedding::operator-=(const Embedding &RHS) {
85
85
return *this ;
86
86
}
87
87
88
+ Embedding &Embedding::operator *=(double Factor) {
89
+ std::transform (this ->begin (), this ->end (), this ->begin (),
90
+ [Factor](double Elem) { return Elem * Factor; });
91
+ return *this ;
92
+ }
93
+
88
94
Embedding &Embedding::scaleAndAdd (const Embedding &Src, float Factor) {
89
95
assert (this ->size () == Src.size () && " Vectors must have the same dimension" );
90
96
for (size_t Itr = 0 ; Itr < this ->size (); ++Itr)
@@ -101,6 +107,13 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
101
107
return true ;
102
108
}
103
109
110
+ void Embedding::print (raw_ostream &OS) const {
111
+ OS << " [" ;
112
+ for (const auto &Elem : Data)
113
+ OS << " " << format (" %.2f" , Elem) << " " ;
114
+ OS << " ]\n " ;
115
+ }
116
+
104
117
// ==----------------------------------------------------------------------===//
105
118
// Embedder and its subclasses
106
119
// ===----------------------------------------------------------------------===//
@@ -196,18 +209,12 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
196
209
for (const auto &I : BB.instructionsWithoutDebug ()) {
197
210
Embedding InstVector (Dimension, 0 );
198
211
199
- const auto OpcVec = lookupVocab (I.getOpcodeName ());
200
- InstVector.scaleAndAdd (OpcVec, OpcWeight);
201
-
202
212
// FIXME: Currently lookups are string based. Use numeric Keys
203
213
// for efficiency.
204
- const auto Type = I.getType ();
205
- const auto TypeVec = getTypeEmbedding (Type);
206
- InstVector.scaleAndAdd (TypeVec, TypeWeight);
207
-
214
+ InstVector += lookupVocab (I.getOpcodeName ());
215
+ InstVector += getTypeEmbedding (I.getType ());
208
216
for (const auto &Op : I.operands ()) {
209
- const auto OperandVec = getOperandEmbedding (Op.get ());
210
- InstVector.scaleAndAdd (OperandVec, ArgWeight);
217
+ InstVector += getOperandEmbedding (Op.get ());
211
218
}
212
219
InstVecMap[&I] = InstVector;
213
220
BBVector += InstVector;
@@ -251,6 +258,46 @@ bool IR2VecVocabResult::invalidate(
251
258
return !(PAC.preservedWhenStateless ());
252
259
}
253
260
261
+ Error IR2VecVocabAnalysis::parseVocabSection (const char *Key,
262
+ const json::Value ParsedVocabValue,
263
+ ir2vec::Vocab &TargetVocab,
264
+ unsigned &Dim) {
265
+ assert (Key && " Key cannot be null" );
266
+
267
+ json::Path::Root Path (" " );
268
+ const json::Object *RootObj = ParsedVocabValue.getAsObject ();
269
+ if (!RootObj)
270
+ return createStringError (errc::invalid_argument,
271
+ " JSON root is not an object" );
272
+
273
+ const json::Value *SectionValue = RootObj->get (Key);
274
+ if (!SectionValue)
275
+ return createStringError (errc::invalid_argument,
276
+ " Missing '" + std::string (Key) +
277
+ " ' section in vocabulary file" );
278
+ if (!json::fromJSON (*SectionValue, TargetVocab, Path))
279
+ return createStringError (errc::illegal_byte_sequence,
280
+ " Unable to parse '" + std::string (Key) +
281
+ " ' section from vocabulary" );
282
+
283
+ Dim = TargetVocab.begin ()->second .size ();
284
+ if (Dim == 0 )
285
+ return createStringError (errc::illegal_byte_sequence,
286
+ " Dimension of '" + std::string (Key) +
287
+ " ' section of the vocabulary is zero" );
288
+
289
+ if (!std::all_of (TargetVocab.begin (), TargetVocab.end (),
290
+ [Dim](const std::pair<StringRef, Embedding> &Entry) {
291
+ return Entry.second .size () == Dim;
292
+ }))
293
+ return createStringError (
294
+ errc::illegal_byte_sequence,
295
+ " All vectors in the '" + std::string (Key) +
296
+ " ' section of the vocabulary are not of the same dimension" );
297
+
298
+ return Error::success ();
299
+ };
300
+
254
301
// FIXME: Make this optional. We can avoid file reads
255
302
// by auto-generating a default vocabulary during the build time.
256
303
Error IR2VecVocabAnalysis::readVocabulary () {
@@ -259,32 +306,40 @@ Error IR2VecVocabAnalysis::readVocabulary() {
259
306
return createFileError (VocabFile, BufOrError.getError ());
260
307
261
308
auto Content = BufOrError.get ()->getBuffer ();
262
- json::Path::Root Path ( " " );
309
+
263
310
Expected<json::Value> ParsedVocabValue = json::parse (Content);
264
311
if (!ParsedVocabValue)
265
312
return ParsedVocabValue.takeError ();
266
313
267
- bool Res = json::fromJSON (*ParsedVocabValue, Vocabulary, Path);
268
- if (!Res)
269
- return createStringError (errc::illegal_byte_sequence,
270
- " Unable to parse the vocabulary" );
314
+ ir2vec::Vocab OpcodeVocab, TypeVocab, ArgVocab;
315
+ unsigned OpcodeDim, TypeDim, ArgDim;
316
+ if (auto Err = parseVocabSection (" Opcodes" , *ParsedVocabValue, OpcodeVocab,
317
+ OpcodeDim))
318
+ return Err;
271
319
272
- if (Vocabulary. empty ())
273
- return createStringError (errc::illegal_byte_sequence,
274
- " Vocabulary is empty " ) ;
320
+ if (auto Err =
321
+ parseVocabSection ( " Types " , *ParsedVocabValue, TypeVocab, TypeDim))
322
+ return Err ;
275
323
276
- unsigned Dim = Vocabulary.begin ()->second .size ();
277
- if (Dim == 0 )
324
+ if (auto Err =
325
+ parseVocabSection (" Arguments" , *ParsedVocabValue, ArgVocab, ArgDim))
326
+ return Err;
327
+
328
+ if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
278
329
return createStringError (errc::illegal_byte_sequence,
279
- " Dimension of vocabulary is zero " );
330
+ " Vocabulary sections have different dimensions " );
280
331
281
- if (!std::all_of (Vocabulary.begin (), Vocabulary.end (),
282
- [Dim](const std::pair<StringRef, Embedding> &Entry) {
283
- return Entry.second .size () == Dim;
284
- }))
285
- return createStringError (
286
- errc::illegal_byte_sequence,
287
- " All vectors in the vocabulary are not of the same dimension" );
332
+ auto scaleVocabSection = [](ir2vec::Vocab &Vocab, double Weight) {
333
+ for (auto &Entry : Vocab)
334
+ Entry.second *= Weight;
335
+ };
336
+ scaleVocabSection (OpcodeVocab, OpcWeight);
337
+ scaleVocabSection (TypeVocab, TypeWeight);
338
+ scaleVocabSection (ArgVocab, ArgWeight);
339
+
340
+ Vocabulary.insert (OpcodeVocab.begin (), OpcodeVocab.end ());
341
+ Vocabulary.insert (TypeVocab.begin (), TypeVocab.end ());
342
+ Vocabulary.insert (ArgVocab.begin (), ArgVocab.end ());
288
343
289
344
return Error::success ();
290
345
}
@@ -304,7 +359,7 @@ void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
304
359
IR2VecVocabAnalysis::Result
305
360
IR2VecVocabAnalysis::run (Module &M, ModuleAnalysisManager &AM) {
306
361
auto Ctx = &M.getContext ();
307
- // FIXME: Scale the vocabulary once. This would avoid scaling per use later.
362
+
308
363
// If vocabulary is already populated by the constructor, use it.
309
364
if (!Vocabulary.empty ())
310
365
return IR2VecVocabResult (std::move (Vocabulary));
@@ -323,16 +378,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
323
378
}
324
379
325
380
// ==----------------------------------------------------------------------===//
326
- // IR2VecPrinterPass
381
+ // Printer Passes
327
382
// ===----------------------------------------------------------------------===//
328
383
329
- void IR2VecPrinterPass::printVector (const Embedding &Vec) const {
330
- OS << " [" ;
331
- for (const auto &Elem : Vec)
332
- OS << " " << format (" %.2f" , Elem) << " " ;
333
- OS << " ]\n " ;
334
- }
335
-
336
384
PreservedAnalyses IR2VecPrinterPass::run (Module &M,
337
385
ModuleAnalysisManager &MAM) {
338
386
auto IR2VecVocabResult = MAM.getResult <IR2VecVocabAnalysis>(M);
@@ -353,15 +401,15 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
353
401
354
402
OS << " IR2Vec embeddings for function " << F.getName () << " :\n " ;
355
403
OS << " Function vector: " ;
356
- printVector ( Emb->getFunctionVector ());
404
+ Emb->getFunctionVector (). print (OS );
357
405
358
406
OS << " Basic block vectors:\n " ;
359
407
const auto &BBMap = Emb->getBBVecMap ();
360
408
for (const BasicBlock &BB : F) {
361
409
auto It = BBMap.find (&BB);
362
410
if (It != BBMap.end ()) {
363
411
OS << " Basic block: " << BB.getName () << " :\n " ;
364
- printVector ( It->second );
412
+ It->second . print (OS );
365
413
}
366
414
}
367
415
@@ -373,10 +421,24 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
373
421
if (It != InstMap.end ()) {
374
422
OS << " Instruction: " ;
375
423
I.print (OS);
376
- printVector ( It->second );
424
+ It->second . print (OS );
377
425
}
378
426
}
379
427
}
380
428
}
381
429
return PreservedAnalyses::all ();
382
430
}
431
+
432
+ PreservedAnalyses IR2VecVocabPrinterPass::run (Module &M,
433
+ ModuleAnalysisManager &MAM) {
434
+ auto IR2VecVocabResult = MAM.getResult <IR2VecVocabAnalysis>(M);
435
+ assert (IR2VecVocabResult.isValid () && " IR2Vec Vocabulary is invalid" );
436
+
437
+ auto Vocab = IR2VecVocabResult.getVocabulary ();
438
+ for (const auto &Entry : Vocab) {
439
+ OS << " Key: " << Entry.first << " : " ;
440
+ Entry.second .print (OS);
441
+ }
442
+
443
+ return PreservedAnalyses::all ();
444
+ }
0 commit comments