diff --git a/src/StackExchange.Redis/Interfaces/ISubscriber.cs b/src/StackExchange.Redis/Interfaces/ISubscriber.cs
index e0c509f49..170188550 100644
--- a/src/StackExchange.Redis/Interfaces/ISubscriber.cs
+++ b/src/StackExchange.Redis/Interfaces/ISubscriber.cs
@@ -61,6 +61,39 @@ public interface ISubscriber : IRedis
///
Task SubscribeAsync(RedisChannel channel, Action handler, CommandFlags flags = CommandFlags.None);
+ ///
+ /// Subscribe to perform some operation when a message to any primary node is broadcast, without any guarantee of ordered handling.
+ /// This is most useful when addressing key space notifications. For user controlled pub/sub channels, it should be used with caution and further it is not advised.
+ ///
+ /// The channel to subscribe to.
+ /// The handler to invoke when a message is received on .
+ /// The command flags to use.
+ ///
+ /// See
+ /// ,
+ /// ,
+ /// .
+ ///
+ Task SubscribeAllPrimariesAsync(RedisChannel channel, Action handler, CommandFlags flags = CommandFlags.None);
+
+ ///
+ /// Unsubscribe from a specified message channel on all primary nodes.
+ /// Note: if no handler is specified, the subscription is canceled regardless of the subscribers.
+ /// If a handler is specified, the subscription is only canceled if this handler is the last handler remaining against the channel.
+ /// This is used in combination with and as mentioned there,
+ /// it is intended to use with key space notitications and not advised for user controlled pub/sub channels.
+ ///
+ /// The channel that was subscribed to.
+ /// The handler to no longer invoke when a message is received on .
+ /// The command flags to use.
+ ///
+ /// See
+ /// ,
+ /// ,
+ /// .
+ ///
+ Task UnsubscribeAllPrimariesAsync(RedisChannel channel, Action? handler = null, CommandFlags flags = CommandFlags.None);
+
///
/// Subscribe to perform some operation when a message to the preferred/active node is broadcast, as a queue that guarantees ordered handling.
///
diff --git a/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt b/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt
index a24333c8e..13fb1d4d9 100644
--- a/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt
+++ b/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt
@@ -1893,4 +1893,5 @@ virtual StackExchange.Redis.RedisResult.Length.get -> int
virtual StackExchange.Redis.RedisResult.this[int index].get -> StackExchange.Redis.RedisResult!
StackExchange.Redis.ConnectionMultiplexer.AddLibraryNameSuffix(string! suffix) -> void
StackExchange.Redis.IConnectionMultiplexer.AddLibraryNameSuffix(string! suffix) -> void
-
+StackExchange.Redis.ISubscriber.SubscribeAllPrimariesAsync(StackExchange.Redis.RedisChannel channel, System.Action! handler, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task!
+StackExchange.Redis.ISubscriber.UnsubscribeAllPrimariesAsync(StackExchange.Redis.RedisChannel channel, System.Action? handler = null, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task!
diff --git a/src/StackExchange.Redis/RedisSubscriber.cs b/src/StackExchange.Redis/RedisSubscriber.cs
index ee28f4c56..cebe4253c 100644
--- a/src/StackExchange.Redis/RedisSubscriber.cs
+++ b/src/StackExchange.Redis/RedisSubscriber.cs
@@ -1,6 +1,8 @@
using System;
using System.Collections.Concurrent;
+using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
+using System.Linq;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
@@ -432,6 +434,86 @@ public Task SubscribeAsync(RedisChannel channel, Action handler, CommandFlags flags)
+ => SubscribeAllPrimariesAsync(channel, handler, null, flags);
+
+ public Task SubscribeAllPrimariesAsync(RedisChannel channel, Action? handler, ChannelMessageQueue? queue, CommandFlags flags)
+ {
+ ThrowIfNull(channel);
+ if (handler == null && queue == null) { return CompletedTask.Default(null); }
+
+ var sub = multiplexer.GetOrAddSubscription(channel, flags);
+ sub.Add(handler, queue);
+ return EnsureSubscribedToPrimariesAsync(sub, channel, flags, false);
+ }
+
+ private Task EnsureSubscribedToPrimariesAsync(Subscription sub, RedisChannel channel, CommandFlags flags, bool internalCall)
+ {
+ if (sub.IsConnected) { return CompletedTask.Default(null); }
+
+ // TODO: Cleanup old hangers here?
+ sub.SetCurrentServer(null); // we're not appropriately connected, so blank it out for eligible reconnection
+ var tasks = new List>();
+ foreach (var server in multiplexer.GetServerSnapshot())
+ {
+ if (!server.IsReplica)
+ {
+ var message = sub.GetMessage(channel, SubscriptionAction.Subscribe, flags, internalCall);
+ tasks.Add(ExecuteAsync(message, sub.Processor, server));
+ }
+ }
+
+ if (tasks.Count == 0)
+ {
+ return CompletedTask.Default(false);
+ }
+
+ // Create a new task that will collect all results and observe errors
+ return Task.Run(async () =>
+ {
+ // Wait for all tasks to complete
+ var results = await Task.WhenAll(tasks).ObserveErrors();
+ return results.All(result => result);
+ }).ObserveErrors();
+ }
+
+ Task ISubscriber.UnsubscribeAllPrimariesAsync(RedisChannel channel, Action? handler, CommandFlags flags)
+ => UnsubscribeAllPrimariesAsync(channel, handler, null, flags);
+
+ public Task UnsubscribeAllPrimariesAsync(in RedisChannel channel, Action? handler, ChannelMessageQueue? queue, CommandFlags flags)
+ {
+ ThrowIfNull(channel);
+ // Unregister the subscription handler/queue, and if that returns true (last handler removed), also disconnect from the server
+ return UnregisterSubscription(channel, handler, queue, out var sub)
+ ? UnsubscribeFromPrimariesAsync(sub, channel, flags, asyncState, false)
+ : CompletedTask.Default(asyncState);
+ }
+
+ private Task UnsubscribeFromPrimariesAsync(Subscription sub, RedisChannel channel, CommandFlags flags, object? asyncState, bool internalCall)
+ {
+ var tasks = new List>();
+ foreach (var server in multiplexer.GetServerSnapshot())
+ {
+ if (!server.IsReplica)
+ {
+ var message = sub.GetMessage(channel, SubscriptionAction.Unsubscribe, flags, internalCall);
+ tasks.Add(multiplexer.ExecuteAsyncImpl(message, sub.Processor, asyncState, server));
+ }
+ }
+
+ if (tasks.Count == 0)
+ {
+ return CompletedTask.Default(false);
+ }
+
+ // Create a new task that will collect all results and observe errors
+ return Task.Run(async () =>
+ {
+ // Wait for all tasks to complete
+ var results = await Task.WhenAll(tasks).ObserveErrors();
+ return results.All(result => result);
+ }).ObserveErrors();
+ }
public Task EnsureSubscribedToServerAsync(Subscription sub, RedisChannel channel, CommandFlags flags, bool internalCall)
{
if (sub.IsConnected) { return CompletedTask.Default(null); }
diff --git a/tests/StackExchange.Redis.Tests/ClusterTests.cs b/tests/StackExchange.Redis.Tests/ClusterTests.cs
index 742ce51bb..e23bab0a0 100644
--- a/tests/StackExchange.Redis.Tests/ClusterTests.cs
+++ b/tests/StackExchange.Redis.Tests/ClusterTests.cs
@@ -1,4 +1,5 @@
using System;
+using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
@@ -746,4 +747,84 @@ public void ConnectIncludesSubscriber()
Assert.Equal(PhysicalBridge.State.ConnectedEstablished, server.SubscriptionConnectionState);
}
}
+
+ [Fact]
+ public async void SubscribeAllPrimariesAsync()
+ {
+ var receivedNofitications = new ConcurrentDictionary();
+ var expectedNofitications = new HashSet();
+ Action Accept = (channel, message) => receivedNofitications[$"{message} on channel {channel}"] = 0;
+
+ var redis = Create(keepAlive: 1, connectTimeout: 3000, shared: false, allowAdmin: true);
+ var db = redis.GetDatabase();
+ SwitchKeySpaceNofitications(redis);
+
+ var sub = redis.GetSubscriber();
+ var channel = new RedisChannel("__key*@*__:*", RedisChannel.PatternMode.Pattern);
+ await sub.SubscribeAllPrimariesAsync(channel, Accept);
+
+ for (int i = 1; i < 3; i++)
+ {
+ await db.StringSetAsync($"k{i}", i);
+ expectedNofitications.Add($"set on channel __keyspace@0__:k{i}");
+ expectedNofitications.Add($"k{i} on channel __keyevent@0__:set");
+ }
+
+ // Wait for notifications to be processed
+ await Task.Delay(1000).ForAwait();
+ foreach (var notification in expectedNofitications)
+ {
+ Assert.Contains(notification, receivedNofitications);
+ }
+ SwitchKeySpaceNofitications(redis, false);
+
+ // Assert that all expected notifications are contained in received notifications
+ Assert.True(
+ expectedNofitications.IsSubsetOf(receivedNofitications.Keys),
+ $"Expected notifications were not received. Expected: {string.Join(", ", expectedNofitications)}, Received: {string.Join(", ", receivedNofitications)}");
+ }
+
+ [Fact]
+ public async void UnsubscribeAllPrimariesAsync()
+ {
+ var receivedNofitications = new ConcurrentDictionary();
+ Action Accept = (channel, message) => receivedNofitications[$"{message} on channel {channel}"] = 0;
+
+ var redis = Create(keepAlive: 1, connectTimeout: 3000, shared: false, allowAdmin: true);
+ var db = redis.GetDatabase();
+ SwitchKeySpaceNofitications(redis);
+
+ var sub = redis.GetSubscriber();
+ var channel = new RedisChannel("__key*@*__:*", RedisChannel.PatternMode.Pattern);
+ await sub.SubscribeAllPrimariesAsync(channel, Accept);
+
+ for (int i = 1; i < 3; i++)
+ {
+ await db.StringSetAsync($"k{i}", i);
+ }
+
+ // Wait for notifications to be processed
+ await Task.Delay(1000).ForAwait();
+ Assert.NotEmpty(receivedNofitications);
+ receivedNofitications.Clear();
+ await sub.UnsubscribeAllPrimariesAsync(channel, Accept);
+ for (int i = 1; i < 3; i++)
+ {
+ await db.StringSetAsync($"k{i}", i);
+ }
+ await Task.Delay(1000).ForAwait();
+ Assert.Empty(receivedNofitications);
+
+ SwitchKeySpaceNofitications(redis, false);
+ }
+
+ private static void SwitchKeySpaceNofitications(IInternalConnectionMultiplexer redis, bool enable = true)
+ {
+ string onOff = enable ? "KEA" : "";
+ foreach (var endpoint in redis.GetEndPoints())
+ {
+ var server = redis.GetServer(endpoint);
+ if (!server.IsReplica) server.ConfigSet("notify-keyspace-events", onOff);
+ }
+ }
}