diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 480b834077b86..f6c40d36f8026 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -106,6 +106,7 @@ struct Embedding { const std::vector &getData() const { return Data; } /// Arithmetic operators + Embedding operator+(const Embedding &RHS) const; Embedding &operator+=(const Embedding &RHS); Embedding &operator-=(const Embedding &RHS); Embedding &operator*=(double Factor); diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 27cc2a4109879..d5d27db8bd2bf 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -71,6 +71,14 @@ inline bool fromJSON(const llvm::json::Value &E, Embedding &Out, // Embedding //===----------------------------------------------------------------------===// +Embedding Embedding::operator+(const Embedding &RHS) const { + assert(this->size() == RHS.size() && "Vectors must have the same dimension"); + Embedding Result(*this); + std::transform(this->begin(), this->end(), RHS.begin(), Result.begin(), + std::plus()); + return Result; +} + Embedding &Embedding::operator+=(const Embedding &RHS) { assert(this->size() == RHS.size() && "Vectors must have the same dimension"); std::transform(this->begin(), this->end(), RHS.begin(), this->begin(), diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index 33ac16828eb6c..50eb7f73c6f50 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -109,6 +109,18 @@ TEST(EmbeddingTest, ConstructorsAndAccessors) { } } +TEST(EmbeddingTest, AddVectorsOutOfPlace) { + Embedding E1 = {1.0, 2.0, 3.0}; + Embedding E2 = {0.5, 1.5, -1.0}; + + Embedding E3 = E1 + E2; + EXPECT_THAT(E3, ElementsAre(1.5, 3.5, 2.0)); + + // Check that E1 and E2 are unchanged + EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0)); + EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0)); +} + TEST(EmbeddingTest, AddVectors) { Embedding E1 = {1.0, 2.0, 3.0}; Embedding E2 = {0.5, 1.5, -1.0}; @@ -180,6 +192,12 @@ TEST(EmbeddingTest, AccessOutOfBounds) { EXPECT_DEATH(E[4] = 4.0, "Index out of bounds"); } +TEST(EmbeddingTest, MismatchedDimensionsAddVectorsOutOfPlace) { + Embedding E1 = {1.0, 2.0}; + Embedding E2 = {1.0}; + EXPECT_DEATH(E1 + E2, "Vectors must have the same dimension"); +} + TEST(EmbeddingTest, MismatchedDimensionsAddVectors) { Embedding E1 = {1.0, 2.0}; Embedding E2 = {1.0};