1
1
from __future__ import annotations
2
2
3
- from typing import Dict , List
3
+ from typing import Any , Dict , List
4
4
from uuid import uuid4
5
5
6
6
from psycopg2 .extensions import connection
@@ -13,7 +13,7 @@ class MessageDB:
13
13
"""This class provides a Python interface to all MessageDB commands."""
14
14
15
15
@classmethod
16
- def from_url (cls , url : str , ** kwargs : str ) -> MessageDB :
16
+ def from_url (cls , url : str , ** kwargs : Any ) -> MessageDB :
17
17
"""Returns a MessageDB client object configured from the given URL.
18
18
19
19
The general form of a connection string is:
@@ -51,10 +51,10 @@ def _write(
51
51
connection : connection ,
52
52
stream_name : str ,
53
53
message_type : str ,
54
- data : Dict ,
55
- metadata : Dict = None ,
54
+ data : Dict [ str , Any ] ,
55
+ metadata : Dict [ str , Any ] = None ,
56
56
expected_version : int = None ,
57
- ):
57
+ ) -> int :
58
58
try :
59
59
with connection .cursor (cursor_factory = RealDictCursor ) as cursor :
60
60
cursor .execute (
@@ -85,7 +85,7 @@ def write(
85
85
data : Dict ,
86
86
metadata : Dict = None ,
87
87
expected_version : int = None ,
88
- ):
88
+ ) -> int :
89
89
conn = self .connection_pool .get_connection ()
90
90
91
91
try :
@@ -98,32 +98,34 @@ def write(
98
98
99
99
return position
100
100
101
- def write_batch (self , stream_name , data , expected_version : int = None ) -> None :
101
+ def write_batch (self , stream_name , data , expected_version : int = None ) -> int :
102
102
conn = self .connection_pool .get_connection ()
103
103
104
104
try :
105
105
with conn :
106
106
for record in data :
107
- expected_version = self ._write (
107
+ position = self ._write (
108
108
conn ,
109
109
stream_name ,
110
110
record [0 ],
111
111
record [1 ],
112
112
metadata = record [2 ] if len (record ) > 2 else None ,
113
113
expected_version = expected_version ,
114
114
)
115
+
116
+ expected_version = position
115
117
finally :
116
118
self .connection_pool .release (conn )
117
119
118
- return expected_version
120
+ return position
119
121
120
122
def read (
121
123
self ,
122
124
stream_name : str ,
123
125
sql : str = None ,
124
126
position : int = 0 ,
125
127
no_of_messages : int = 1000 ,
126
- ) -> List [Dict ]:
128
+ ) -> List [Dict [ str , Any ] ]:
127
129
conn = self .connection_pool .get_connection ()
128
130
cursor = conn .cursor (cursor_factory = RealDictCursor )
129
131
@@ -149,7 +151,9 @@ def read(
149
151
150
152
return messages
151
153
152
- def read_stream (self , stream_name , position = 0 , no_of_messages = 1000 ) -> List [Dict ]:
154
+ def read_stream (
155
+ self , stream_name : str , position : int = 0 , no_of_messages : int = 1000
156
+ ) -> List [Dict [str , Any ]]:
153
157
if "-" not in stream_name :
154
158
raise ValueError (f"{ stream_name } is not a stream" )
155
159
@@ -160,8 +164,8 @@ def read_stream(self, stream_name, position=0, no_of_messages=1000) -> List[Dict
160
164
)
161
165
162
166
def read_category (
163
- self , category_name , position = 0 , no_of_messages = 1000
164
- ) -> List [Dict ]:
167
+ self , category_name : str , position : int = 0 , no_of_messages : int = 1000
168
+ ) -> List [Dict [ str , Any ] ]:
165
169
if "-" in category_name :
166
170
raise ValueError (f"{ category_name } is not a category" )
167
171
@@ -171,7 +175,7 @@ def read_category(
171
175
category_name , sql = sql , position = position , no_of_messages = no_of_messages
172
176
)
173
177
174
- def read_last_message (self , stream_name ):
178
+ def read_last_message (self , stream_name ) -> Dict [ str , Any ] :
175
179
conn = self .connection_pool .get_connection ()
176
180
cursor = conn .cursor (cursor_factory = RealDictCursor )
177
181
0 commit comments