Skip to content

Commit cbd2c6e

Browse files
committed
Overloading operator+ for Embeddngs
1 parent d05856c commit cbd2c6e

File tree

3 files changed

+27
-0
lines changed

3 files changed

+27
-0
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ struct Embedding {
106106
const std::vector<double> &getData() const { return Data; }
107107

108108
/// Arithmetic operators
109+
Embedding operator+(const Embedding &RHS) const;
109110
Embedding &operator+=(const Embedding &RHS);
110111
Embedding &operator-=(const Embedding &RHS);
111112
Embedding &operator*=(double Factor);

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
7171
// Embedding
7272
//===----------------------------------------------------------------------===//
7373

74+
Embedding Embedding::operator+(const Embedding &RHS) const {
75+
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
76+
Embedding Result(*this);
77+
std::transform(this->begin(), this->end(), RHS.begin(), Result.begin(),
78+
std::plus<double>());
79+
return Result;
80+
}
81+
7482
Embedding &Embedding::operator+=(const Embedding &RHS) {
7583
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
7684
std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),

llvm/unittests/Analysis/IR2VecTest.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,18 @@ TEST(EmbeddingTest, ConstructorsAndAccessors) {
109109
}
110110
}
111111

112+
TEST(EmbeddingTest, AddVectorsOutOfPlace) {
113+
Embedding E1 = {1.0, 2.0, 3.0};
114+
Embedding E2 = {0.5, 1.5, -1.0};
115+
116+
Embedding E3 = E1 + E2;
117+
EXPECT_THAT(E3, ElementsAre(1.5, 3.5, 2.0));
118+
119+
// Check that E1 and E2 are unchanged
120+
EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
121+
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
122+
}
123+
112124
TEST(EmbeddingTest, AddVectors) {
113125
Embedding E1 = {1.0, 2.0, 3.0};
114126
Embedding E2 = {0.5, 1.5, -1.0};
@@ -180,6 +192,12 @@ TEST(EmbeddingTest, AccessOutOfBounds) {
180192
EXPECT_DEATH(E[4] = 4.0, "Index out of bounds");
181193
}
182194

195+
TEST(EmbeddingTest, MismatchedDimensionsAddVectorsOutOfPlace) {
196+
Embedding E1 = {1.0, 2.0};
197+
Embedding E2 = {1.0};
198+
EXPECT_DEATH(E1 + E2, "Vectors must have the same dimension");
199+
}
200+
183201
TEST(EmbeddingTest, MismatchedDimensionsAddVectors) {
184202
Embedding E1 = {1.0, 2.0};
185203
Embedding E2 = {1.0};

0 commit comments

Comments
 (0)