Skip to content

Commit 857bfe5

Browse files
ThomasKulintkulinpuddly
authored
Add support for receiving basic fragmented messages (#669)
* add basic fragmentation support * add basic fragmentation support (remembering to cleanup leftover fragments) * code cleanup and improved partial fragment removal * minor change to frag ack callback handling, added test case for fragmentation * added more tests * Use generic types in favor of `typing` * Use an instance-specific fragment manager instance instead of a global * Keep track of fragmentation ACK tasks in an instance variable * Clean up formatting a little * Fix unit tests --------- Co-authored-by: tkulin <[email protected]> Co-authored-by: puddly <[email protected]>
1 parent ce8a126 commit 857bfe5

File tree

4 files changed

+563
-0
lines changed

4 files changed

+563
-0
lines changed

bellows/ezsp/fragmentation.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""Implements APS fragmentation reassembly on the EZSP Host side,
2+
mirroring the logic from fragmentation.c in the EmberZNet stack.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
import asyncio
8+
import logging
9+
10+
LOGGER = logging.getLogger(__name__)
11+
12+
# The maximum time (in seconds) we wait for all fragments of a given message.
13+
# If not all fragments arrive within this time, we discard the partial data.
14+
FRAGMENT_TIMEOUT = 10
15+
16+
# store partial data keyed by (sender, aps_sequence, profile_id, cluster_id)
17+
FragmentKey = tuple[int, int, int, int]
18+
19+
20+
class _FragmentEntry:
21+
def __init__(self, fragment_count: int):
22+
self.fragment_count = fragment_count
23+
self.fragments_received = 0
24+
self.fragment_data = {}
25+
26+
def add_fragment(self, index: int, data: bytes) -> None:
27+
if index not in self.fragment_data:
28+
self.fragment_data[index] = data
29+
self.fragments_received += 1
30+
31+
def is_complete(self) -> bool:
32+
return self.fragments_received == self.fragment_count
33+
34+
def assemble(self) -> bytes:
35+
return b"".join(
36+
self.fragment_data[i] for i in sorted(self.fragment_data.keys())
37+
)
38+
39+
40+
class FragmentManager:
41+
def __init__(self):
42+
self._partial: dict[FragmentKey, _FragmentEntry] = {}
43+
self._cleanup_timers: dict[FragmentKey, asyncio.TimerHandle] = {}
44+
45+
def handle_incoming_fragment(
46+
self,
47+
sender_nwk: int,
48+
aps_sequence: int,
49+
profile_id: int,
50+
cluster_id: int,
51+
fragment_count: int,
52+
fragment_index: int,
53+
payload: bytes,
54+
) -> tuple[bool, bytes | None, int, int]:
55+
"""Handle a newly received fragment.
56+
57+
:param sender_nwk: NWK address or the short ID of the sender.
58+
:param aps_sequence: The APS sequence from the incoming APS frame.
59+
:param profile_id: The APS frame's profileId.
60+
:param cluster_id: The APS frame's clusterId.
61+
:param fragment_count: The total number of expected message fragments.
62+
:param fragment_index: The index of the current fragment being processed.
63+
:param payload: The fragment of data for this message.
64+
:return: (complete, reassembled_data, fragment_count, fragment_index)
65+
complete = True if we have all fragments now, else False
66+
reassembled_data = the final complete payload (bytes) if complete is True
67+
fragment_coutn = the total number of fragments holding the complete packet
68+
fragment_index = the index of the current received fragment
69+
"""
70+
71+
key: FragmentKey = (sender_nwk, aps_sequence, profile_id, cluster_id)
72+
73+
# If we have never seen this message, create a reassembly entry.
74+
if key not in self._partial:
75+
entry = _FragmentEntry(fragment_count)
76+
self._partial[key] = entry
77+
else:
78+
entry = self._partial[key]
79+
80+
LOGGER.debug(
81+
"Received fragment %d/%d from %s (APS seq=%d, cluster=0x%04X)",
82+
fragment_index + 1,
83+
fragment_count,
84+
sender_nwk,
85+
aps_sequence,
86+
cluster_id,
87+
)
88+
89+
entry.add_fragment(fragment_index, payload)
90+
91+
loop = asyncio.get_running_loop()
92+
self._cleanup_timers[key] = loop.call_later(
93+
FRAGMENT_TIMEOUT, self.cleanup_partial, key
94+
)
95+
96+
if entry.is_complete():
97+
reassembled = entry.assemble()
98+
del self._partial[key]
99+
timer = self._cleanup_timers.pop(key, None)
100+
if timer:
101+
timer.cancel()
102+
LOGGER.debug(
103+
"Message reassembly complete. Total length=%d", len(reassembled)
104+
)
105+
return (True, reassembled, fragment_count, fragment_index)
106+
else:
107+
return (False, None, fragment_count, fragment_index)
108+
109+
def cleanup_partial(self, key: FragmentKey):
110+
# Called when FRAGMENT_TIMEOUT passes with no new fragments for that key.
111+
LOGGER.debug(
112+
"Timeout for partial reassembly of fragmented message, discarding key=%s",
113+
key,
114+
)
115+
self._partial.pop(key, None)
116+
self._cleanup_timers.pop(key, None)

bellows/ezsp/protocol.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from bellows.config import CONF_EZSP_POLICIES
2222
from bellows.exception import InvalidCommandError
23+
from bellows.ezsp.fragmentation import FragmentManager
2324
import bellows.types as t
2425

2526
if TYPE_CHECKING:
@@ -53,6 +54,8 @@ def __init__(self, cb_handler: Callable, gateway: Gateway) -> None:
5354

5455
# Cached by `set_extended_timeout` so subsequent calls are a little faster
5556
self._address_table_size: int | None = None
57+
self._fragment_manager = FragmentManager()
58+
self._fragment_ack_tasks: set[asyncio.Task] = set()
5659

5760
def _ezsp_frame(self, name: str, *args: Any, **kwargs: Any) -> bytes:
5861
"""Serialize the named frame and data."""
@@ -181,6 +184,52 @@ def __call__(self, data: bytes) -> None:
181184
if data:
182185
LOGGER.debug("Frame contains trailing data: %s", data)
183186

187+
if (
188+
frame_name == "incomingMessageHandler"
189+
and result[1].options & t.EmberApsOption.APS_OPTION_FRAGMENT
190+
):
191+
# Extract received APS frame and sender
192+
aps_frame = result[1]
193+
sender = result[4]
194+
195+
# The fragment count and index are encoded in the groupId field
196+
fragment_count = (aps_frame.groupId >> 8) & 0xFF
197+
fragment_index = aps_frame.groupId & 0xFF
198+
199+
(
200+
complete,
201+
reassembled,
202+
frag_count,
203+
frag_index,
204+
) = self._fragment_manager.handle_incoming_fragment(
205+
sender_nwk=sender,
206+
aps_sequence=aps_frame.sequence,
207+
profile_id=aps_frame.profileId,
208+
cluster_id=aps_frame.clusterId,
209+
fragment_count=fragment_count,
210+
fragment_index=fragment_index,
211+
payload=result[7],
212+
)
213+
214+
ack_task = asyncio.create_task(
215+
self._send_fragment_ack(sender, aps_frame, frag_count, frag_index)
216+
) # APS Ack
217+
218+
self._fragment_ack_tasks.add(ack_task)
219+
ack_task.add_done_callback(lambda t: self._fragment_ack_tasks.discard(t))
220+
221+
if not complete:
222+
# Do not pass partial data up the stack
223+
LOGGER.debug("Fragment reassembly not complete. waiting for more data.")
224+
return
225+
226+
# Replace partial data with fully reassembled data
227+
result[7] = reassembled
228+
229+
LOGGER.debug(
230+
"Reassembled fragmented message. Proceeding with normal handling."
231+
)
232+
184233
if sequence in self._awaiting:
185234
expected_id, schema, future = self._awaiting.pop(sequence)
186235
try:
@@ -205,6 +254,32 @@ def __call__(self, data: bytes) -> None:
205254
else:
206255
self._handle_callback(frame_name, result)
207256

257+
async def _send_fragment_ack(
258+
self,
259+
sender: int,
260+
incoming_aps: t.EmberApsFrame,
261+
fragment_count: int,
262+
fragment_index: int,
263+
) -> t.EmberStatus:
264+
ackFrame = t.EmberApsFrame(
265+
profileId=incoming_aps.profileId,
266+
clusterId=incoming_aps.clusterId,
267+
sourceEndpoint=incoming_aps.destinationEndpoint,
268+
destinationEndpoint=incoming_aps.sourceEndpoint,
269+
options=incoming_aps.options,
270+
groupId=((0xFF00) | (fragment_index & 0xFF)),
271+
sequence=incoming_aps.sequence,
272+
)
273+
274+
LOGGER.debug(
275+
"Sending fragment ack to 0x%04X for fragment index=%d/%d",
276+
sender,
277+
fragment_index + 1,
278+
fragment_count,
279+
)
280+
status = await self.sendReply(sender, ackFrame, b"")
281+
return status[0]
282+
208283
def __getattr__(self, name: str) -> Callable:
209284
if name not in self.COMMANDS:
210285
raise AttributeError(f"{name} not found in COMMANDS")

tests/test_ezsp_protocol.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,160 @@ async def test_parsing_schema_response(prot_hndl_v9):
133133

134134
rsp = await coro
135135
assert rsp == GetTokenDataRsp(status=t.EmberStatus.LIBRARY_NOT_PRESENT)
136+
137+
138+
async def test_send_fragment_ack(prot_hndl, caplog):
139+
"""Test the _send_fragment_ack method."""
140+
sender = 0x1D6F
141+
incoming_aps = t.EmberApsFrame(
142+
profileId=260,
143+
clusterId=65281,
144+
sourceEndpoint=2,
145+
destinationEndpoint=2,
146+
options=33088,
147+
groupId=512,
148+
sequence=238,
149+
)
150+
fragment_count = 2
151+
fragment_index = 0
152+
153+
expected_ack_frame = t.EmberApsFrame(
154+
profileId=260,
155+
clusterId=65281,
156+
sourceEndpoint=2,
157+
destinationEndpoint=2,
158+
options=33088,
159+
groupId=((0xFF00) | (fragment_index & 0xFF)),
160+
sequence=238,
161+
)
162+
163+
with patch.object(prot_hndl, "sendReply", new=AsyncMock()) as mock_send_reply:
164+
mock_send_reply.return_value = (t.EmberStatus.SUCCESS,)
165+
166+
caplog.set_level(logging.DEBUG)
167+
status = await prot_hndl._send_fragment_ack(
168+
sender, incoming_aps, fragment_count, fragment_index
169+
)
170+
171+
# Assertions
172+
assert status == t.EmberStatus.SUCCESS
173+
assert (
174+
"Sending fragment ack to 0x1d6f for fragment index=1/2".lower()
175+
in caplog.text.lower()
176+
)
177+
mock_send_reply.assert_called_once_with(sender, expected_ack_frame, b"")
178+
179+
180+
async def test_incoming_fragmented_message_incomplete(prot_hndl, caplog):
181+
"""Test handling of an incomplete fragmented message."""
182+
packet = b"\x90\x01\x45\x00\x05\x01\x01\xff\x02\x02\x40\x81\x00\x02\xee\xff\xf8\x6f\x1d\xff\xff\x01\xdd"
183+
184+
# Parse packet manually to extract parameters for assertions
185+
sender = 0x1D6F
186+
aps_frame = t.EmberApsFrame(
187+
profileId=261, # 0x0105
188+
clusterId=65281, # 0xFF01
189+
sourceEndpoint=2, # 0x02
190+
destinationEndpoint=2, # 0x02
191+
options=33088, # 0x8140 (APS_OPTION_FRAGMENT + others)
192+
groupId=512, # 0x0002 (fragment_count=2, fragment_index=0)
193+
sequence=238, # 0xEE
194+
)
195+
196+
with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack:
197+
mock_ack.return_value = None
198+
199+
caplog.set_level(logging.DEBUG)
200+
prot_hndl(packet)
201+
202+
assert len(prot_hndl._fragment_ack_tasks) == 1
203+
ack_task = next(iter(prot_hndl._fragment_ack_tasks))
204+
await asyncio.gather(ack_task) # Ensure task completes and triggers callback
205+
assert (
206+
len(prot_hndl._fragment_ack_tasks) == 0
207+
), "Done callback should have removed task"
208+
209+
prot_hndl._handle_callback.assert_not_called()
210+
assert "Fragment reassembly not complete. waiting for more data." in caplog.text
211+
mock_ack.assert_called_once_with(sender, aps_frame, 2, 0)
212+
213+
214+
async def test_incoming_fragmented_message_complete(prot_hndl, caplog):
215+
"""Test handling of a complete fragmented message."""
216+
packet1 = (
217+
b"\x90\x01\x45\x00\x04\x01\x01\xff\x02\x02\x40\x81\x00\x02\xee\xff\xf8\x6f\x1d\xff\xff\x09"
218+
+ b"complete "
219+
) # fragment index 0
220+
packet2 = (
221+
b"\x90\x01\x45\x00\x04\x01\x01\xff\x02\x02\x40\x81\x01\x02\xee\xff\xf8\x6f\x1d\xff\xff\x07"
222+
+ b"message"
223+
) # fragment index 1
224+
sender = 0x1D6F
225+
226+
aps_frame_1 = t.EmberApsFrame(
227+
profileId=260,
228+
clusterId=65281,
229+
sourceEndpoint=2,
230+
destinationEndpoint=2,
231+
options=33088, # Includes APS_OPTION_FRAGMENT
232+
groupId=512, # fragment_count=2, fragment_index=0
233+
sequence=238,
234+
)
235+
aps_frame_2 = t.EmberApsFrame(
236+
profileId=260,
237+
clusterId=65281,
238+
sourceEndpoint=2,
239+
destinationEndpoint=2,
240+
options=33088,
241+
groupId=513, # fragment_count=2, fragment_index=1
242+
sequence=238,
243+
)
244+
reassembled = b"complete message"
245+
246+
with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack:
247+
mock_ack.return_value = None
248+
caplog.set_level(logging.DEBUG)
249+
250+
# Packet 1
251+
prot_hndl(packet1)
252+
assert len(prot_hndl._fragment_ack_tasks) == 1
253+
ack_task = next(iter(prot_hndl._fragment_ack_tasks))
254+
await asyncio.gather(ack_task) # Ensure task completes and triggers callback
255+
assert (
256+
len(prot_hndl._fragment_ack_tasks) == 0
257+
), "Done callback should have removed task"
258+
259+
prot_hndl._handle_callback.assert_not_called()
260+
assert (
261+
"Reassembled fragmented message. Proceeding with normal handling."
262+
not in caplog.text
263+
)
264+
mock_ack.assert_called_with(sender, aps_frame_1, 2, 0)
265+
266+
# Packet 2
267+
prot_hndl(packet2)
268+
assert len(prot_hndl._fragment_ack_tasks) == 1
269+
ack_task = next(iter(prot_hndl._fragment_ack_tasks))
270+
await asyncio.gather(ack_task) # Ensure task completes and triggers callback
271+
assert (
272+
len(prot_hndl._fragment_ack_tasks) == 0
273+
), "Done callback should have removed task"
274+
275+
prot_hndl._handle_callback.assert_called_once_with(
276+
"incomingMessageHandler",
277+
[
278+
t.EmberIncomingMessageType.INCOMING_UNICAST, # 0x00
279+
aps_frame_2, # Parsed APS frame
280+
255, # lastHopLqi: 0xFF
281+
-8, # lastHopRssi: 0xF8
282+
sender, # 0x1D6F
283+
255, # bindingIndex: 0xFF
284+
255, # addressIndex: 0xFF
285+
reassembled, # Reassembled payload
286+
],
287+
)
288+
assert (
289+
"Reassembled fragmented message. Proceeding with normal handling."
290+
in caplog.text
291+
)
292+
mock_ack.assert_called_with(sender, aps_frame_2, 2, 1)

0 commit comments

Comments
 (0)