Skip to content

Commit 8add186

Browse files
committed
Limit to_device EDU size to 65536
1 parent ae877aa commit 8add186

File tree

4 files changed

+224
-50
lines changed

4 files changed

+224
-50
lines changed

synapse/api/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
# the max size of a (canonical-json-encoded) event
3030
MAX_PDU_SIZE = 65536
31+
# the max size of a (canonical-json-encoded) EDU
MAX_EDU_SIZE = 65536
3132

3233
# Max/min size of ints in canonical JSON
3334
CANONICALJSON_MAX_INT = (2**53) - 1

synapse/handlers/devicemessage.py

Lines changed: 83 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,19 @@
2020
#
2121

2222
import logging
23+
from copy import deepcopy
2324
from http import HTTPStatus
24-
from typing import TYPE_CHECKING, Any, Dict, Optional
25+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
2526

26-
from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes
27-
from synapse.api.errors import Codes, SynapseError
27+
from canonicaljson import encode_canonical_json
28+
29+
from synapse.api.constants import (
30+
MAX_EDU_SIZE,
31+
EduTypes,
32+
EventContentFields,
33+
ToDeviceEventTypes,
34+
)
35+
from synapse.api.errors import Codes, EventSizeError, SynapseError
2836
from synapse.api.ratelimiting import Ratelimiter
2937
from synapse.logging.context import run_in_background
3038
from synapse.logging.opentracing import (
@@ -293,18 +301,18 @@ async def send_device_message(
293301

294302
remote_edu_contents = {}
295303
for destination, messages in remote_messages.items():
296-
# The EDU contains a "message_id" property which is used for
297-
# idempotence. Make up a random one.
298-
message_id = random_string(16)
299-
log_kv({"destination": destination, "message_id": message_id})
300-
301-
remote_edu_contents[destination] = {
302-
"messages": messages,
303-
"sender": sender_user_id,
304-
"type": message_type,
305-
"message_id": message_id,
306-
"org.matrix.opentracing_context": json_encoder.encode(context),
307-
}
304+
edu_contents = get_device_message_edu_contents(
305+
sender_user_id, message_type, messages, context
306+
)
307+
remote_edu_contents[destination] = edu_contents
308+
log_kv(
309+
{
310+
"destination": destination,
311+
"message_ids": [
312+
edu_content["message_id"] for edu_content in edu_contents
313+
],
314+
}
315+
)
308316

309317
# Add messages to the database.
310318
# Retrieve the stream id of the last-processed to-device message.
@@ -409,3 +417,63 @@ async def get_events_for_dehydrated_device(
409417
"events": messages,
410418
"next_batch": f"d{stream_id}",
411419
}
420+
421+
422+
def get_device_message_edu_contents(
    sender_user_id: str,
    message_type: str,
    messages: Dict[str, Dict[str, JsonDict]],
    context: Dict[str, Any],
) -> List[JsonDict]:
    """
    Build the to-device EDU content(s) needed to carry `messages`.

    Entries of `messages` are packed greedily, in iteration order, into as few
    EDUs as possible while keeping the estimated size of each EDU at or below
    `MAX_EDU_SIZE`. A top-level entry of `messages` is never split across EDUs.
    Every returned EDU content carries its own random `message_id`.

    Args:
        sender_user_id: Stored under the "sender" key of every EDU content.
        message_type: Stored under the "type" key of every EDU content.
        messages: The messages to send, keyed by recipient.
        context: Encoded with `json_encoder` and stored under the
            "org.matrix.opentracing_context" key of every EDU content. It is
            not counted towards the size estimate.

    Returns:
        The list of EDU contents. Empty if `messages` is empty.

    Raises:
        EventSizeError: if a single entry of `messages` is too large to fit
            into an EDU on its own.
    """

    base_edu_content: JsonDict = {
        "messages": {},
        "sender": sender_user_id,
        "type": message_type,
        "message_id": random_string(16),
    }
    # This is the size of the full EDU without any messages and without the
    # opentracing context.
    # NOTE(review): the context is deliberately left out of the estimate here;
    # confirm that it really does not count towards the limit on the wire.
    base_edu_size = len(
        encode_canonical_json(
            {
                "edu_type": "m.direct_to_device",
                "content": base_edu_content,
            }
        )
    )
    base_edu_content["org.matrix.opentracing_context"] = json_encoder.encode(context)

    edu_contents: List[JsonDict] = []

    current_edu_content: JsonDict = deepcopy(base_edu_content)
    current_edu_size = base_edu_size

    for recipient, message in messages.items():
        # Encoding `{recipient: message}` counts two curly braces that will not
        # appear in the final EDU, whereas an entry only costs one separating
        # comma there. The estimate therefore overshoots by a byte or two per
        # entry, which errs on the safe side.
        message_entry_size = len(encode_canonical_json({recipient: message}))

        # An entry that does not fit into an otherwise empty EDU can never be
        # sent, regardless of its position in `messages`.
        if base_edu_size + message_entry_size > MAX_EDU_SIZE:
            raise EventSizeError("device message too large", unpersistable=True)

        if current_edu_size + message_entry_size > MAX_EDU_SIZE:
            # The current EDU is full (and necessarily non-empty, given the
            # check above): close it and start a fresh one with its own
            # message_id.
            edu_contents.append(current_edu_content)

            current_edu_content = deepcopy(base_edu_content)
            current_edu_content["message_id"] = random_string(16)
            current_edu_size = base_edu_size

        # Always account for the entry, including when it is the first entry of
        # a freshly started EDU; otherwise later entries could push that EDU
        # over the limit.
        current_edu_content["messages"][recipient] = message
        current_edu_size += message_entry_size

    if current_edu_content["messages"]:
        edu_contents.append(current_edu_content)

    return edu_contents

synapse/storage/databases/main/deviceinbox.py

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -718,15 +718,15 @@ def get_all_new_device_messages_txn(
718718
async def add_messages_to_device_inbox(
719719
self,
720720
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
721-
remote_messages_by_destination: Dict[str, JsonDict],
721+
remote_edu_contents: Dict[str, List[JsonDict]],
722722
) -> int:
723723
"""Used to send messages from this server.
724724
725725
Args:
726726
local_messages_by_user_then_device:
727727
Dictionary of recipient user_id to recipient device_id to message.
728-
remote_messages_by_destination:
729-
Dictionary of destination server_name to the EDU JSON to send.
728+
remote_edu_contents:
729+
Dictionary of destination server_name to the list of EDU contents to send.
730730
731731
Returns:
732732
The new stream_id.
@@ -760,42 +760,53 @@ def add_messages_txn(
760760
destination,
761761
stream_id,
762762
now_ms,
763-
json_encoder.encode(edu),
763+
json_encoder.encode(edu_content),
764764
self._instance_name,
765765
)
766-
for destination, edu in remote_messages_by_destination.items()
766+
for destination, edu_contents in remote_edu_contents.items()
767+
for edu_content in edu_contents
767768
],
768769
)
769770

770-
for destination, edu in remote_messages_by_destination.items():
771-
if issue9533_logger.isEnabledFor(logging.DEBUG):
772-
issue9533_logger.debug(
773-
"Queued outgoing to-device messages with "
774-
"stream_id %i, EDU message_id %s, type %s for %s: %s",
775-
stream_id,
776-
edu["message_id"],
777-
edu["type"],
778-
destination,
779-
[
780-
f"{user_id}/{device_id} (msgid "
781-
f"{msg.get(EventContentFields.TO_DEVICE_MSGID)})"
782-
for (user_id, messages_by_device) in edu["messages"].items()
783-
for (device_id, msg) in messages_by_device.items()
784-
],
785-
)
771+
for destination, edu_contents in remote_edu_contents.items():
772+
for edu_content in edu_contents:
773+
if issue9533_logger.isEnabledFor(logging.DEBUG):
774+
issue9533_logger.debug(
775+
"Queued outgoing to-device messages with "
776+
"stream_id %i, EDU message_id %s, type %s for %s: %s",
777+
stream_id,
778+
edu_content["message_id"],
779+
edu_content["type"],
780+
destination,
781+
[
782+
f"{user_id}/{device_id} (msgid "
783+
f"{msg.get(EventContentFields.TO_DEVICE_MSGID)})"
784+
for (user_id, messages_by_device) in edu_content[
785+
"messages"
786+
].items()
787+
for (device_id, msg) in messages_by_device.items()
788+
],
789+
)
786790

787-
for user_id, messages_by_device in edu["messages"].items():
788-
for device_id, msg in messages_by_device.items():
789-
with start_active_span("store_outgoing_to_device_message"):
790-
set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"])
791-
set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"])
792-
set_tag(SynapseTags.TO_DEVICE_TYPE, edu["type"])
793-
set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id)
794-
set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id)
795-
set_tag(
796-
SynapseTags.TO_DEVICE_MSGID,
797-
msg.get(EventContentFields.TO_DEVICE_MSGID),
798-
)
791+
for user_id, messages_by_device in edu_content["messages"].items():
792+
for device_id, msg in messages_by_device.items():
793+
with start_active_span("store_outgoing_to_device_message"):
794+
set_tag(
795+
SynapseTags.TO_DEVICE_EDU_ID, edu_content["sender"]
796+
)
797+
set_tag(
798+
SynapseTags.TO_DEVICE_EDU_ID,
799+
edu_content["message_id"],
800+
)
801+
set_tag(SynapseTags.TO_DEVICE_TYPE, edu_content["type"])
802+
set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id)
803+
set_tag(
804+
SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id
805+
)
806+
set_tag(
807+
SynapseTags.TO_DEVICE_MSGID,
808+
msg.get(EventContentFields.TO_DEVICE_MSGID),
809+
)
799810

800811
async with self._to_device_msg_id_gen.get_next() as stream_id:
801812
now_ms = self._clock.time_msec()
@@ -804,7 +815,7 @@ def add_messages_txn(
804815
)
805816
for user_id in local_messages_by_user_then_device.keys():
806817
self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
807-
for destination in remote_messages_by_destination.keys():
818+
for destination in remote_edu_contents.keys():
808819
self._device_federation_outbox_stream_cache.entity_has_changed(
809820
destination, stream_id
810821
)

tests/rest/client/test_sendtodevice.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,20 @@
1818
# [This file includes modifications made by New Vector Limited]
1919
#
2020
#
21+
from unittest.mock import AsyncMock, Mock
22+
2123
from parameterized import parameterized_class
2224

23-
from synapse.api.constants import EduTypes
25+
from twisted.test.proto_helpers import MemoryReactor
26+
27+
from synapse.api.constants import MAX_EDU_SIZE, EduTypes
28+
from synapse.api.errors import Codes
2429
from synapse.rest import admin
2530
from synapse.rest.client import login, sendtodevice, sync
31+
from synapse.server import HomeServer
2632
from synapse.types import JsonDict
33+
from synapse.util import Clock
34+
from synapse.util.stringutils import random_string
2735

2836
from tests.unittest import HomeserverTestCase, override_config
2937

@@ -61,8 +69,18 @@ class SendToDeviceTestCase(HomeserverTestCase):
6169
def default_config(self) -> JsonDict:
    """Return the base test config with this test case's overrides applied.

    Sets "experimental_features" from the test case attribute of the same name
    and sets "federation_sender_instances" to None.
    """
    config = super().default_config()
    config.update(
        {
            "experimental_features": self.experimental_features,
            "federation_sender_instances": None,
        }
    )
    return config
6574

75+
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
    """Create the test homeserver with a mocked federation transport client.

    The mock only exposes `send_transaction` (as an AsyncMock) and is kept on
    `self.federation_transport_client` so tests can inspect what was sent.
    """
    transport_client = Mock(spec=["send_transaction"])
    transport_client.send_transaction = AsyncMock()
    self.federation_transport_client = transport_client

    return self.setup_test_homeserver(
        federation_transport_client=transport_client,
    )
83+
6684
def test_user_to_user(self) -> None:
6785
"""A to-device message from one user to another should get delivered"""
6886

@@ -113,6 +131,82 @@ def test_user_to_user(self) -> None:
113131
self.assertEqual(channel.code, 200, channel.result)
114132
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
115133

134+
def test_large_remote_todevice(self) -> None:
    """A to-device message needs to fit in the EDU size limit"""
    self.register_user("u1", "pass")
    access_token = self.login("u1", "pass", "d1")

    # A single message whose payload alone is already as large as the limit.
    oversized_msg = {"foo": random_string(MAX_EDU_SIZE)}
    body = {"messages": {"@remote_user:secondserver": {"device": oversized_msg}}}

    # send the message
    channel = self.make_request(
        "PUT",
        "/_matrix/client/r0/sendToDevice/m.test/12345",
        content=body,
        access_token=access_token,
    )

    # The request must be rejected as too large.
    self.assertEqual(channel.code, 413, channel.result)
    self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
150+
def test_edu_splitting(self) -> None:
    """Test that a bunch of to-device messages are split into multiple EDUs if they are too large"""
    send_transaction: AsyncMock = self.federation_transport_client.send_transaction
    send_transaction.return_value = {}

    federation_sender = self.hs.get_federation_sender()

    self.register_user("u1", "pass")
    access_token = self.login("u1", "pass", "d1")
    destination = "secondserver"
    recipients = [f"@remote_user{i}:" + destination for i in range(2)]

    def send_and_count_edus(path: str, payload_len: int) -> int:
        """Send one message of `payload_len` random characters to each
        recipient, flush the federation queue for `destination`, and return
        the number of EDUs in the last transaction handed to the transport
        client."""
        messages = {
            recipient: {"device": {"foo": random_string(payload_len)}}
            for recipient in recipients
        }

        channel = self.make_request(
            "PUT",
            path,
            content={"messages": messages},
            access_token=access_token,
        )
        self.assertEqual(channel.code, 200, channel.result)

        self.get_success(federation_sender.send_device_messages([destination]))
        self.pump()

        # The second positional argument of send_transaction is a callback
        # producing the transaction JSON.
        json_cb = send_transaction.call_args[0][1]
        return len(json_cb()["edus"])

    # 2 small messages that should fit in a single EDU
    self.assertEqual(
        send_and_count_edus("/_matrix/client/r0/sendToDevice/m.test/123456", 100),
        1,
    )

    send_transaction.reset_mock()

    # 2 messages, each just big enough to fit in an EDU
    self.assertEqual(
        send_and_count_edus(
            "/_matrix/client/r0/sendToDevice/m.test/1234567", MAX_EDU_SIZE - 1000
        ),
        2,
    )
209+
116210
@override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
117211
def test_local_room_key_request(self) -> None:
118212
"""m.room_key_request has special-casing; test from local user"""

0 commit comments

Comments
 (0)