Skip to content

Commit b6c26fa

Browse files
committed
Add enhanced + missing test cases
1 parent 593c717 commit b6c26fa

File tree

6 files changed

+109
-8
lines changed

6 files changed

+109
-8
lines changed

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,8 @@ mdb.write("your_stream_name", "your_message_type", {"data": "value"})
3333
# Read a message
3434
message = mdb.read_last_message("your_stream_name")
3535
print(message)
36-
```
36+
```
37+
38+
## License
39+
40+
The Postgres Message Store is released under the [MIT License](https://github.com/subhashb/message-db-py/blob/main/LICENSE).

src/message_db/client.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _write(
5555
data: Dict[str, Any],
5656
metadata: Dict[str, Any] | None = None,
5757
expected_version: int | None = None,
58-
) -> int:
58+
) -> int | None:
5959
try:
6060
with connection.cursor(cursor_factory=RealDictCursor) as cursor:
6161
cursor.execute(
@@ -74,8 +74,6 @@ def _write(
7474
)
7575

7676
result = cursor.fetchone()
77-
if result is None:
78-
raise ValueError("No result returned from the database operation.")
7977
except Exception as exc:
8078
raise ValueError(
8179
f"{getattr(exc, 'pgcode')}-{getattr(exc, 'pgerror').splitlines()[0]}"
@@ -90,7 +88,7 @@ def write(
9088
data: Dict,
9189
metadata: Dict | None = None,
9290
expected_version: int | None = None,
93-
) -> int:
91+
) -> int | None:
9492
conn = self.connection_pool.get_connection()
9593

9694
try:
@@ -105,7 +103,7 @@ def write(
105103

106104
def write_batch(
107105
self, stream_name, data, expected_version: int | None = None
108-
) -> int:
106+
) -> int | None:
109107
conn = self.connection_pool.get_connection()
110108

111109
try:

tests/test_client.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from unittest.mock import patch
2+
3+
from psycopg2 import OperationalError
4+
15
from message_db.client import MessageDB
26
from message_db.connection import ConnectionPool
37

@@ -18,3 +22,14 @@ def test_client_construction_from_args(self):
1822

1923
assert isinstance(store, MessageDB)
2024
assert isinstance(store.connection_pool, ConnectionPool)
25+
26+
def test_reconnection_after_failure(self, client):
27+
with patch(
28+
"psycopg2.connect", side_effect=[OperationalError("Connection lost"), None]
29+
):
30+
try:
31+
client.read("testStream-123")
32+
except OperationalError:
33+
# Retry reading after catching the first failure
34+
messages = client.read("testStream-123")
35+
assert messages is not None

tests/test_connection.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1+
from unittest.mock import patch
2+
13
import pytest
24
from psycopg2 import OperationalError, ProgrammingError
35
from psycopg2.extensions import TRANSACTION_STATUS_ACTIVE
6+
from psycopg2.pool import PoolError
47

58
from message_db.connection import ConnectionPool
69

10+
CONNECT_URL = "postgresql://message_store@localhost:5432/message_store"
11+
712

813
def test_constructing_connection_pool_from_url(pool):
914
assert pool is not None
@@ -19,14 +24,14 @@ def test_error_on_invalid_url():
1924

2025
def test_error_on_invalid_role():
2126
with pytest.raises(OperationalError) as exc:
22-
ConnectionPool("postgresql://foo@localhost:5432/postgres")
27+
ConnectionPool("postgresql://foo@localhost:5432/message_store")
2328

2429
assert 'role "foo" does not exist' in exc.value.args[0]
2530

2631

2732
def test_error_on_invalid_max_connections():
2833
with pytest.raises(ValueError):
29-
ConnectionPool("postgresql://foo@localhost:5432/postgres", max_connections=-1)
34+
ConnectionPool(CONNECT_URL, max_connections=-1)
3035

3136

3237
def test_retrieving_connection_from_pool(pool):
@@ -46,3 +51,59 @@ def test_releasing_a_connection(pool):
4651

4752
pool.release(conn)
4853
assert len(pool._connection_pool._used) == used_count - 1
54+
55+
56+
def test_connection_pool_initialization():
57+
"""Test the connection pool initialization with different parameters."""
58+
pool = ConnectionPool(CONNECT_URL, max_connections=10)
59+
assert pool.max_connections == 10
60+
assert pool._connection_pool.maxconn == 10
61+
62+
63+
def test_multiple_simultaneous_connections():
64+
"""Test handling multiple simultaneous connections within max limit."""
65+
pool = ConnectionPool(CONNECT_URL, max_connections=5)
66+
connections = [pool.get_connection() for _ in range(5)]
67+
assert len(pool._connection_pool._used) == 5
68+
# Clean up
69+
for conn in connections:
70+
pool.release(conn)
71+
72+
73+
def test_exceeding_connection_limit():
74+
"""Test behavior when more connections than max are requested."""
75+
pool = ConnectionPool(CONNECT_URL, max_connections=2)
76+
conn1 = pool.get_connection()
77+
conn2 = pool.get_connection()
78+
with pytest.raises(Exception) as exc:
79+
pool.get_connection()
80+
assert "connection pool exhausted" in str(exc.value)
81+
# Clean up
82+
pool.release(conn1)
83+
pool.release(conn2)
84+
85+
86+
def test_release_invalid_connection():
87+
"""Test releasing a connection that was never retrieved."""
88+
pool = ConnectionPool(CONNECT_URL, max_connections=5)
89+
fake_conn = None # Simulating an invalid connection
90+
with pytest.raises(PoolError) as exc:
91+
pool.release(fake_conn)
92+
assert "trying to put unkeyed connection" in str(exc.value)
93+
94+
95+
def test_close_all_connections():
96+
"""Test closing all connections."""
97+
pool = ConnectionPool(CONNECT_URL, max_connections=5)
98+
connections = [pool.get_connection() for _ in range(5)]
99+
pool.closeall()
100+
assert all(conn.closed for conn in connections)
101+
102+
103+
def test_network_error_on_connection():
104+
"""Test handling network errors during connection retrieval."""
105+
with patch("psycopg2.connect", side_effect=OSError("Network Error")):
106+
with pytest.raises(OSError) as exc:
107+
pool = ConnectionPool(CONNECT_URL, max_connections=5)
108+
pool.get_connection()
109+
assert "Network Error" in str(exc.value)

tests/test_read.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ def test_read_stream_last_message(self, client):
3434
assert message["position"] == 4
3535
assert message["data"] == {"foo": "bar4"}
3636

37+
def test_last_message_empty_stream(self, client):
38+
result = client.read_last_message("emptyStream-123")
39+
assert result is None
40+
41+
def test_read_from_non_existing_stream(self, client):
42+
messages = client.read("nonExistingStream-123")
43+
assert messages == []
44+
3745
def test_read_specific_stream_message(self, client):
3846
for i in range(5):
3947
client.write("testStream-123", "Event1", {"foo": f"bar{i}"})

tests/test_write.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,18 @@ def test_that_write_fails_on_expected_version_mismatch(self, client):
6868
exc.value.args[0]
6969
== "P0001-ERROR: Wrong expected version: 1 (Stream: testStream-123, Stream Version: 2)"
7070
)
71+
72+
def test_concurrent_writes(self, client):
73+
from threading import Thread
74+
75+
def write_msg():
76+
client.write("concurrentStream-123", "Event", {"thread": "value"})
77+
78+
threads = [Thread(target=write_msg) for _ in range(10)]
79+
for thread in threads:
80+
thread.start()
81+
for thread in threads:
82+
thread.join()
83+
84+
messages = client.read("concurrentStream-123")
85+
assert len(messages) == 10

0 commit comments

Comments
 (0)