Skip to content

added partial test setup #1659 #1663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions integrations/mongodb_atlas/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[![PyPI - Version](https://img.shields.io/pypi/v/mongodb-atlas-haystack.svg)](https://pypi.org/project/mongodb-atlas-haystack)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/mongodb-atlas-haystack.svg)](https://pypi.org/project/mongodb-atlas-haystack)

-----
---

**Table of Contents**

Expand All @@ -20,24 +20,29 @@ pip install mongodb-atlas-haystack
## Contributing

`hatch` is the best way to interact with this project, to install it:

```sh
pip install hatch
```

To run the linters `ruff` and `mypy`:

```
hatch run lint:all
```

To run all the tests:

```
hatch run test
```

Note: you need your own MongoDB Atlas account to run the tests: you can make one here:
Note: you need your own MongoDB Atlas account to run the tests: you can make one here:
https://www.mongodb.com/cloud/atlas/register. Once you have it, export the connection string
to the env var `MONGO_CONNECTION_STRING`. If you forget to do so, all the tests will be skipped.

Note: before the tests run a script creates a mongo database, atest collection, vector search indexes, and some sample documents in MongoDB for integration tests, if they don't exist.

## License

`mongodb-atlas-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license.
7 changes: 3 additions & 4 deletions integrations/mongodb_atlas/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,13 @@ dependencies = [
]

[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
test-cov = "coverage run -m pytest {args:tests}"
test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x"
test = ["python tests/mongo_atlas_setup.py", "pytest {args:tests}"]
test-cov = ["python tests/mongo_atlas_setup.py", "coverage run -m pytest {args:tests}"]
test-cov-retry = ["python tests/mongo_atlas_setup.py", "test-cov --reruns 3 --reruns-delay 30 -x"]
cov-report = ["- coverage combine", "coverage report"]
cov = ["test-cov", "cov-report"]
cov-retry = ["test-cov-retry", "cov-report"]
docs = ["pydoc-markdown pydoc/config.yml"]

[tool.hatch.envs.lint]
installer = "uv"
detached = true
Expand Down
98 changes: 98 additions & 0 deletions integrations/mongodb_atlas/tests/mongo_atlas_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import os
import logging
from pymongo import MongoClient, TEXT
from pymongo.operations import SearchIndexModel
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
from requests.auth import HTTPDigestAuth

# Logging for visibility
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
embedding_dimension = 768


DEFAULT_DOCS = [
{"content": "Document A", "embedding": [-1] + [0.2] * (embedding_dimension - 1)},
{"content": "Document B", "embedding": [0] + [0.15] * (embedding_dimension - 1)},
{"content": "Document C", "embedding": [0.1] * embedding_dimension},
]

VECTOR_INDEXES = [
{
"name": "cosine_index",
"fields": [
{"type": "vector", "path": "embedding", "numDimensions": 768, "similarity": "cosine"},
{"type": "filter", "path": "content"},
],
},
{
"name": "dotProduct_index",
"fields": [
{"type": "vector", "path": "embedding", "numDimensions": 768, "similarity": "dotProduct"},
],
},
{
"name": "euclidean_index",
"fields": [
{"type": "vector", "path": "embedding", "numDimensions": 768, "similarity": "euclidean"},
],
},
]

FULL_TEXT_INDEX = {
"name": "full_text_index",
"definition": {
"mappings": {"dynamic": True},
},
}


def get_collection(client, db_name, coll_name):
db = client[db_name]
if coll_name not in db.list_collection_names():
logger.info(f"Creating collection '{coll_name}' in DB '{db_name}'")
db.create_collection(coll_name)
return db[coll_name]


def setup_test_embeddings_collection(client):
collection = get_collection(client, "haystack_integration_test", "test_embeddings_collection")

if collection.count_documents({}) == 0:
collection.insert_many(DEFAULT_DOCS)

existing_index_names = {idx["name"] for idx in collection.list_search_indexes()}

for index in VECTOR_INDEXES:
if index["name"] not in existing_index_names:
logger.info(f"Creating vector search index: {index['name']}")
model = SearchIndexModel(definition={"fields": index["fields"]}, name=index["name"], type="vectorSearch")
collection.create_search_index(model=model)


def setup_test_full_text_search_collection(client):
collection = get_collection(client, "haystack_integration_test", "test_full_text_search_collection")

existing_index_names = {idx["name"] for idx in collection.list_search_indexes()}

if FULL_TEXT_INDEX["name"] not in existing_index_names:
logger.info(f"Creating full text search index: {FULL_TEXT_INDEX['name']}")
model = SearchIndexModel(definition=FULL_TEXT_INDEX["definition"], name=FULL_TEXT_INDEX["name"], type="search")
collection.create_search_index(model=model)


def setup_mongodb_for_tests():
connection_str = os.environ.get("MONGO_CONNECTION_STRING")
if not connection_str:
logger.warning("Skipping MongoDB Atlas setup: no MONGO_CONNECTION_STRING")
return

client = MongoClient(connection_str)

setup_test_embeddings_collection(client)
setup_test_full_text_search_collection(client)


if __name__ == "__main__":
setup_mongodb_for_tests()
10 changes: 6 additions & 4 deletions integrations/mongodb_atlas/tests/test_embedding_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore

embedding_dimension = 768


@pytest.mark.skipif(
not os.environ.get("MONGO_CONNECTION_STRING"),
Expand All @@ -23,7 +25,7 @@ def test_embedding_retrieval_cosine_similarity(self):
vector_search_index="cosine_index",
full_text_search_index="full_text_index",
)
query_embedding = [0.1] * 768
query_embedding = [0.1] * embedding_dimension
results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={})
assert len(results) == 2
assert results[0].content == "Document C"
Expand All @@ -37,7 +39,7 @@ def test_embedding_retrieval_dot_product(self):
vector_search_index="dotProduct_index",
full_text_search_index="full_text_index",
)
query_embedding = [0.1] * 768
query_embedding = [0.1] * embedding_dimension
results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={})
assert len(results) == 2
assert results[0].content == "Document A"
Expand All @@ -51,7 +53,7 @@ def test_embedding_retrieval_euclidean(self):
vector_search_index="euclidean_index",
full_text_search_index="full_text_index",
)
query_embedding = [0.1] * 768
query_embedding = [0.1] * embedding_dimension
results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={})
assert len(results) == 2
assert results[0].content == "Document C"
Expand Down Expand Up @@ -105,7 +107,7 @@ def test_embedding_retrieval_with_filters(self):
vector_search_index="cosine_index",
full_text_search_index="full_text_index",
)
query_embedding = [0.1] * 768
query_embedding = [0.1] * embedding_dimension
filters = {"field": "content", "operator": "!=", "value": "Document A"}
results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters=filters)
assert len(results) == 2
Expand Down
8 changes: 4 additions & 4 deletions integrations/mongodb_atlas/tests/test_fulltext_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@

def get_document_store():
return MongoDBAtlasDocumentStore(
mongo_connection_string=Secret.from_env_var("MONGO_CONNECTION_STRING_2"),
database_name="haystack_test",
collection_name="test_collection",
mongo_connection_string=Secret.from_env_var("MONGO_CONNECTION_STRING"),
database_name="haystack_integration_test",
collection_name="test_full_text_search_collection",
vector_search_index="cosine_index",
full_text_search_index="full_text_index",
)


@pytest.mark.skipif(
not os.environ.get("MONGO_CONNECTION_STRING_2"),
not os.environ.get("MONGO_CONNECTION_STRING"),
reason="No MongoDB Atlas connection string provided",
)
@pytest.mark.integration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,16 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.store._connection_async.close()


@pytest.mark.skipif(
not os.environ.get("MONGO_CONNECTION_STRING_2"), reason="No MongoDBAtlas connection string provided"
)
@pytest.mark.skipif(not os.environ.get("MONGO_CONNECTION_STRING"), reason="No MongoDBAtlas connection string provided")
@pytest.mark.integration
class TestFullTextRetrieval:

@pytest.fixture
async def document_store(self) -> MongoDBAtlasDocumentStore:
async with AsyncDocumentStoreContext(
mongo_connection_string=Secret.from_env_var("MONGO_CONNECTION_STRING_2"),
database_name="haystack_test",
collection_name="test_collection",
mongo_connection_string=Secret.from_env_var("MONGO_CONNECTION_STRING"),
database_name="haystack_integration_test",
collection_name="test_full_text_search_collection",
vector_search_index="cosine_index",
full_text_search_index="full_text_index",
) as store:
Expand Down
Loading