Skip to content

Commit d851fda

Browse files
committed
[PROGRESS]
1 parent ab3627c commit d851fda

File tree

7 files changed

+152
-58
lines changed

7 files changed

+152
-58
lines changed

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
bidict==0.13.1

tests/unit/test_subject.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import unittest
2+
from unittest import mock
3+
4+
from yamq import subject, observer
5+
6+
7+
class TestSubjectSTOMP(unittest.TestCase):
8+
9+
def setUp(self):
10+
self.loop = unittest.mock.Mock()
11+
self.transport = unittest.mock.Mock()
12+
13+
def test_object_creation(self):
14+
obj_1 = subject.SubjectSTOMP(name="test_queue_1", loop=self.loop)
15+
16+
self.assertDictEqual(obj_1.observers, {})
17+
obj_1.delete()
18+
19+
def test_subjects_are_unique(self):
20+
self.assertDictEqual(subject.SubjectSTOMP._objects, {})
21+
22+
obj_1 = subject.SubjectSTOMP(name="test_queue_1", loop=self.loop)
23+
24+
self.assertDictEqual(subject.SubjectSTOMP._objects, {obj_1.name: obj_1})
25+
26+
obj_2 = subject.SubjectSTOMP(name="test_queue_2", loop=self.loop)
27+
28+
self.assertDictEqual(subject.SubjectSTOMP._objects, {
29+
obj_1.name: obj_1,
30+
obj_2.name: obj_2
31+
})
32+
33+
obj_3 = subject.SubjectSTOMP(name=obj_1.name, loop=self.loop)
34+
35+
self.assertIs(obj_3, obj_1)
36+
self.assertDictEqual(subject.SubjectSTOMP._objects, {
37+
obj_1.name: obj_1,
38+
obj_2.name: obj_2
39+
})
40+
41+
obj_1.delete()
42+
obj_2.delete()
43+
obj_3.delete()
44+
45+
def test_subscribers_are_preserved(self):
46+
observer_1 = observer.ObserverSTOMP(self.loop, self.transport)
47+
observer_2 = observer.ObserverSTOMP(self.loop, self.transport)
48+
49+
obj_1 = subject.SubjectSTOMP(name="test_queue_1", loop=self.loop)
50+
obj_1.subscribe(observer_1)
51+
obj_1.subscribe(observer_2)
52+
53+
self.assertDictEqual(obj_1.observers, {
54+
observer_1: "auto",
55+
observer_2: "auto"
56+
})
57+
58+
obj_2 = subject.SubjectSTOMP(name="test_queue_2", loop=self.loop)
59+
obj_2.subscribe(observer_1)
60+
61+
self.assertDictEqual(obj_2.observers, {observer_1: "auto"})
62+
63+
import pdb; pdb.set_trace() # XXX BREAKPOINT
64+
obj_3 = subject.SubjectSTOMP(name="test_queue_1", loop=self.loop)
65+
self.assertDictEqual(obj_3.observers, {
66+
observer_1: "auto",
67+
observer_2: "auto"
68+
})

yamq/message.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
from itertools import count
22

3-
from yamq import utils
4-
53

64
class Message():
75
"""Don't forget to call obj.delete() to delete object from the memory!"""
86

97
_messages = {}
108
_last_id = count()
119

12-
def __init__(self, message, content_type):
10+
def __init__(self, message, content_type='text/plain'):
1311
self._id = next(Message._last_id)
1412
Message._messages[self._id] = self
1513
self.message = message

yamq/observer.py

+38-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
import stomp
1+
from bidict import bidict
2+
3+
from yamq import stomp
24

35

46
class Observer:
@@ -16,16 +18,43 @@ class ObserverSTOMP:
1618
def __init__(self, loop, transport):
1719
self.loop = loop
1820
self.transport = transport
21+
self.subscriptions = bidict() # subscription_id: subject_obj
22+
23+
def subscribe(self, subject, ack_type, subscription_id):
24+
import pdb; pdb.set_trace() # XXX BREAKPOINT
25+
subject.subscribe(self, ack_type)
26+
self.subscriptions[subscription_id] = subject
27+
28+
def unsubscribe(self, subscription_id):
29+
subject = self.subscriptions.get(subscription_id)
30+
subject.unsubscribe(self)
31+
32+
def update_auto(self, message_frame):
33+
import pdb; pdb.set_trace() # XXX BREAKPOINT
34+
self.transport(stomp.dumps(message_frame))
1935

20-
async def ack_auto(self, frame):
36+
def update_client(self, message_frame):
2137
pass
2238

23-
async def update(self, subject, message, message_id, subscription_id, ack):
39+
def update_client_individual(self, mesasge_frame):
40+
pass
41+
42+
def delete(self):
43+
for _, subject in self.subscriptions:
44+
subject.unsubscribe(self)
45+
46+
def update(self, subject, message, ack):
47+
import pdb; pdb.set_trace() # XXX BREAKPOINT
48+
subscrption_id = self.subscriptions.inv[subject]
2449
frame = stomp.MessageFrame(
25-
subject,
26-
message_id,
27-
subscription_id,
28-
message,
29-
ack
50+
message=message,
51+
destination=subject.name,
52+
subscription_id=subscription_id,
53+
ack=ack
3054
)
31-
pass
55+
if ack == 'client-individual':
56+
self.update_client_inidividual(frame)
57+
elif ack == 'client':
58+
self.update_client(frame)
59+
else:
60+
self.update_auto(frame)

yamq/server.py

+13-36
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,12 @@
77
import observer
88
import message
99

10-
SUBJECTS = {}
11-
#MESSAGES = {}
12-
13-
#MESSAGE_ID = -1
14-
15-
16-
#def create_message(message):
17-
# MESSAGE_ID += 1
18-
# MESSAGES[MESSAGE_ID] = message
19-
# return MESSAGE_ID
20-
2110

2211
class STOMP_Server(asyncio.Protocol):
2312
"""Minimum supported version 1.2"""
2413

2514
def connection_made(self, transport):
2615
self.transport = transport
27-
self.subscription = {}
2816
self.observer = observer.ObserverSTOMP(event_loop, self.transport)
2917

3018
def data_received(self, data):
@@ -64,31 +52,18 @@ def connect(self):
6452
pass
6553

6654
def send(self, destination, raw_message, **headers):
67-
user_subject = SUBJECTS.get(destination)
68-
if not user_subject:
69-
user_subject = subject.SubjectSTOMP(
70-
name=destination, loop=event_loop
71-
)
72-
SUBJECTS[destination] = subject
55+
import pdb; pdb.set_trace() # XXX BREAKPOINT
56+
user_subject = subject.SubjectSTOMP(name=destination, loop=event_loop)
7357
message_obj = message.Message(raw_message)
74-
self.loop.call_soon(user_subject.notify(message_obj))
58+
user_subject.notify(message_obj)
7559

76-
def subscribe(self, subject_id, destination, ack="auto", **headers):
77-
user_subject = SUBJECTS.get(destination)
78-
if not user_subject:
79-
user_subject = subject.SubjectSTOMP(
80-
name=destination, loop=event_loop
81-
)
82-
SUBJECTS[destination] = subject
83-
self.subscription[subject_id] = subject
84-
user_subject.subscribe(self.observer, subscrption_id, ack)
60+
def subscribe(self, subscription_id, destination, ack="auto", **headers):
61+
import pdb; pdb.set_trace() # XXX BREAKPOINT
62+
user_subject = subject.SubjectSTOMP(destination, loop=event_loop)
63+
self.observer.subscribe(user_subject, ack, subscription_id)
8564

86-
def unsubscribe(self, subject_id, **headers):
87-
user_subject = self.subscription.get(subject_id)
88-
if not user_subject:
89-
# TODO: Return Error frame here
90-
pass
91-
user_subject.unsubscribe(self.observer)
65+
def unsubscribe(self, subscription_id, **headers):
66+
self.objserver.unsubscribe(subscription_id)
9267

9368
def ack(self, id, **headers):
9469
pass
@@ -97,8 +72,10 @@ def nack(self, id, **headers):
9772
pass
9873

9974
def disconnect(self, receipt, **headers):
100-
for user_subject in self.subscription.values():
101-
user_subject.unsubscribe(self.observer)
75+
self.observer.delete()
76+
77+
def connection_lost(self, exc):
78+
self.observer.delete()
10279

10380

10481
#class YampServer(asyncio.Protocol):

yamq/stomp/frame.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ def __init__(self, version="1.2", heart_beat=None, session=None,
2929

3030
class MessageFrame(Frame):
3131

32-
def __init__(self, destination, message_id, subscription_id, body="",
33-
ack=None):
32+
def __init__(self, message, destination, subscription_id, ack=None):
3433
command = "MESSAGE"
3534
headers = {
3635
'destination': destination,
37-
'message-id': message_id,
36+
'message-id': message._id,
37+
'content-type': message.content_type,
3838
'subscription': subscription_id
3939
}
4040

4141
if ack:
4242
headers['ack'] = ack
43-
super().__init__(command, headers, body)
43+
super().__init__(command, headers, message.message)

yamq/subject.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,36 @@ async def notify(self, message):
2626

2727
class SubjectSTOMP:
2828

29+
_objects = {}
30+
31+
def __new__(cls, name, loop):
32+
try:
33+
obj = cls._objects[name]
34+
except KeyError as e:
35+
obj = super().__new__(cls)
36+
cls._objects[name] = obj
37+
return obj
38+
2939
def __init__(self, name, loop):
40+
print("Init called for : {}".format(name))
3041
self.observers = {} # observer_object: ack_type
3142
self.name = name
3243
self.loop = loop
3344

3445
def __repr__(self):
3546
return "<Subject object: %s>" % (self.name,)
3647

37-
def subscribe(self, observer, subscription_id, ack="auto"):
48+
@classmethod
49+
def get(cls, name):
50+
return cls._objects.get(name)
51+
52+
def delete(self):
53+
try:
54+
del self.__class__._objects[self.name]
55+
except KeyError:
56+
pass
57+
58+
def subscribe(self, observer, ack="auto"):
3859
self.observers[observer] = ack
3960

4061
def unsubscribe(self, observer):
@@ -44,9 +65,9 @@ def unsubscribe(self, observer):
4465
# TODO: I am not sure you should silently pass
4566
pass
4667

68+
if not self.objservers:
69+
self.delete()
70+
4771
def notify(self, message):
48-
"""Follows PUSH stratergy."""
49-
for observer, ack in self.observers:
50-
await observer.update(
51-
self.name, message.message, message._id, ack
52-
)
72+
for observer, ack_type in self.observers:
73+
observer.update(message, ack_type, subject=self)

0 commit comments

Comments
 (0)