diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py index a80e12d..306b630 100644 --- a/channels_redis/pubsub.py +++ b/channels_redis/pubsub.py @@ -3,6 +3,7 @@ import logging import uuid +import aioredlock from redis import asyncio as aioredis from .serializers import registry @@ -117,6 +118,11 @@ def __init__( RedisSingleShardConnection(host, self) for host in decode_hosts(hosts) ] + # Create lock manager for all redis connections + redis_connections = [shard.host for shard in self._shards] + self._lock_manager = aioredlock.Aioredlock() + self._lock_manager.redis_connections = redis_connections + def _get_shard(self, channel_or_group_name): """ Return the shard that is used exclusively for this channel or group. @@ -135,10 +141,21 @@ def _get_group_channel_name(self, group): """ return f"{self.prefix}__group__{group}" + async def _acquire_lock(self, channel): + try: + await self._lock_manager.lock(channel, lock_timeout=60) + except aioredlock.LockError: + logger.debug("Failed to acquire lock on channel %s", channel) + return False + + return True + async def _subscribe_to_channel(self, channel): self.channels[channel] = asyncio.Queue() - shard = self._get_shard(channel) - await shard.subscribe(channel) + + if await self._acquire_lock(channel): + shard = self._get_shard(channel) + await shard.subscribe(channel) extensions = ["groups", "flush"] diff --git a/setup.py b/setup.py index a368709..96a3713 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ include_package_data=True, python_requires=">=3.8", install_requires=[ + "aioredlock>=0.7.3,<1", "redis>=4.6", "msgpack~=1.0", "asgiref>=3.2.10,<4", diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 3c00dd6..5ec33d9 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -5,8 +5,8 @@ import async_timeout import pytest - from asgiref.sync import async_to_sync + from channels_redis.pubsub import RedisPubSubChannelLayer from channels_redis.utils import _close_redis @@ -261,3 +261,45 @@ async def test_discard_before_add(channel_layer): channel_name = await channel_layer.new_channel(prefix="test-channel") # Make sure that we can remove a group before it was ever added without crashing. await channel_layer.group_discard("test-group", channel_name) + + +@pytest.mark.asyncio +async def test_guarantee_at_most_once_delivery() -> None: + """ + Tests that at most once delivery is guaranteed. + + If two consumers are listening on the same channel, + the message should be delivered to only one of them. + """ + + channel_name = "same-channel" + loop = asyncio.get_running_loop() + + channel_layer = RedisPubSubChannelLayer(hosts=TEST_HOSTS) + channel_layer_2 = RedisPubSubChannelLayer(hosts=TEST_HOSTS) + future_channel_layer = loop.create_future() + future_channel_layer_2 = loop.create_future() + + async def receive_task( + channel_layer: RedisPubSubChannelLayer, future: asyncio.Future + ) -> None: + message = await channel_layer.receive(channel_name) + future.set_result(message) + + # Ensure that receive_task_2 is scheduled first and accquires the lock + asyncio.create_task(receive_task(channel_layer_2, future_channel_layer_2)) + await asyncio.sleep(1) + asyncio.create_task(receive_task(channel_layer, future_channel_layer)) + await asyncio.sleep(1) + + await channel_layer.send(channel_name, {"type": "test.message", "text": "Hello!"}) + + result = await future_channel_layer_2 + assert result["type"] == "test.message" + assert result["text"] == "Hello!" + + # Channel layer 1 should not receive the message + # as it is already consumed by channel layer 2 + with pytest.raises(asyncio.TimeoutError): + async with async_timeout.timeout(1): + await future_channel_layer