Skip to content

Commit 065de25

Browse files
committed
Add mypy checks along with pre-commit hook
1 parent e7ec855 commit 065de25

File tree

4 files changed

+30
-16
lines changed

4 files changed

+30
-16
lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,9 @@ repos:
2121
hooks:
2222
- id: isort
2323
args: ["--profile", "black", "--filter-files"]
24+
25+
- repo: https://github.com/pre-commit/mirrors-mypy
26+
rev: v0.910
27+
hooks:
28+
- id: mypy
29+
exclude: ^tests/

message_db/client.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Dict, List
3+
from typing import Any, Dict, List
44
from uuid import uuid4
55

66
from psycopg2.extensions import connection
@@ -13,7 +13,7 @@ class MessageDB:
1313
"""This class provides a Python interface to all MessageDB commands."""
1414

1515
@classmethod
16-
def from_url(cls, url: str, **kwargs: str) -> MessageDB:
16+
def from_url(cls, url: str, **kwargs: Any) -> MessageDB:
1717
"""Returns a MessageDB client object configured from the given URL.
1818
1919
The general form of a connection string is:
@@ -51,10 +51,10 @@ def _write(
5151
connection: connection,
5252
stream_name: str,
5353
message_type: str,
54-
data: Dict,
55-
metadata: Dict = None,
54+
data: Dict[str, Any],
55+
metadata: Dict[str, Any] = None,
5656
expected_version: int = None,
57-
):
57+
) -> int:
5858
try:
5959
with connection.cursor(cursor_factory=RealDictCursor) as cursor:
6060
cursor.execute(
@@ -85,7 +85,7 @@ def write(
8585
data: Dict,
8686
metadata: Dict = None,
8787
expected_version: int = None,
88-
):
88+
) -> int:
8989
conn = self.connection_pool.get_connection()
9090

9191
try:
@@ -98,32 +98,34 @@ def write(
9898

9999
return position
100100

101-
def write_batch(self, stream_name, data, expected_version: int = None) -> None:
101+
def write_batch(self, stream_name, data, expected_version: int = None) -> int:
102102
conn = self.connection_pool.get_connection()
103103

104104
try:
105105
with conn:
106106
for record in data:
107-
expected_version = self._write(
107+
position = self._write(
108108
conn,
109109
stream_name,
110110
record[0],
111111
record[1],
112112
metadata=record[2] if len(record) > 2 else None,
113113
expected_version=expected_version,
114114
)
115+
116+
expected_version = position
115117
finally:
116118
self.connection_pool.release(conn)
117119

118-
return expected_version
120+
return position
119121

120122
def read(
121123
self,
122124
stream_name: str,
123125
sql: str = None,
124126
position: int = 0,
125127
no_of_messages: int = 1000,
126-
) -> List[Dict]:
128+
) -> List[Dict[str, Any]]:
127129
conn = self.connection_pool.get_connection()
128130
cursor = conn.cursor(cursor_factory=RealDictCursor)
129131

@@ -149,7 +151,9 @@ def read(
149151

150152
return messages
151153

152-
def read_stream(self, stream_name, position=0, no_of_messages=1000) -> List[Dict]:
154+
def read_stream(
155+
self, stream_name: str, position: int = 0, no_of_messages: int = 1000
156+
) -> List[Dict[str, Any]]:
153157
if "-" not in stream_name:
154158
raise ValueError(f"{stream_name} is not a stream")
155159

@@ -160,8 +164,8 @@ def read_stream(self, stream_name, position=0, no_of_messages=1000) -> List[Dict
160164
)
161165

162166
def read_category(
163-
self, category_name, position=0, no_of_messages=1000
164-
) -> List[Dict]:
167+
self, category_name: str, position: int = 0, no_of_messages: int = 1000
168+
) -> List[Dict[str, Any]]:
165169
if "-" in category_name:
166170
raise ValueError(f"{category_name} is not a category")
167171

@@ -171,7 +175,7 @@ def read_category(
171175
category_name, sql=sql, position=position, no_of_messages=no_of_messages
172176
)
173177

174-
def read_last_message(self, stream_name):
178+
def read_last_message(self, stream_name) -> Dict[str, Any]:
175179
conn = self.connection_pool.get_connection()
176180
cursor = conn.cursor(cursor_factory=RealDictCursor)
177181

message_db/connection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from __future__ import annotations
22

3+
from typing import Any
4+
35
from psycopg2.extensions import connection
46
from psycopg2.pool import SimpleConnectionPool
57

68

79
class ConnectionPool:
810
@classmethod
911
def from_url(
10-
cls, *args: str, max_connections: int = 100, **kwargs: str
12+
cls, *args: str, max_connections: int = 100, **kwargs: Any
1113
) -> ConnectionPool:
1214
"""Return a Connection Pool configured from the given URL.
1315
@@ -21,7 +23,7 @@ def from_url(
2123
"""
2224
return cls(*args, max_connections=max_connections, **kwargs)
2325

24-
def __init__(self, *args: str, max_connections: int = 100, **kwargs: str) -> None:
26+
def __init__(self, *args: str, max_connections: int = 100, **kwargs: Any) -> None:
2527
if not isinstance(max_connections, int) or max_connections < 0:
2628
raise ValueError('"max_connections" must be a positive integer')
2729

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ pre-commit = "^2.15.0"
1919
isort = "^5.10.1"
2020
autoflake = "^1.4"
2121
pytest-cov = "^3.0.0"
22+
mypy = "^0.910"
23+
types-psycopg2 = "^2.9.1"
2224

2325
[build-system]
2426
requires = ["poetry-core>=1.0.0"]

0 commit comments

Comments
 (0)