diff --git a/src/StackExchange.Redis/Interfaces/ISubscriber.cs b/src/StackExchange.Redis/Interfaces/ISubscriber.cs index e0c509f49..170188550 100644 --- a/src/StackExchange.Redis/Interfaces/ISubscriber.cs +++ b/src/StackExchange.Redis/Interfaces/ISubscriber.cs @@ -61,6 +61,39 @@ public interface ISubscriber : IRedis /// Task SubscribeAsync(RedisChannel channel, Action handler, CommandFlags flags = CommandFlags.None); + /// + /// Subscribe to perform some operation when a message to any primary node is broadcast, without any guarantee of ordered handling. + /// This is most useful when addressing key space notifications. For user controlled pub/sub channels, it should be used with caution and further it is not advised. + /// + /// The channel to subscribe to. + /// The handler to invoke when a message is received on . + /// The command flags to use. + /// + /// See + /// , + /// , + /// . + /// + Task SubscribeAllPrimariesAsync(RedisChannel channel, Action handler, CommandFlags flags = CommandFlags.None); + + /// + /// Unsubscribe from a specified message channel on all primary nodes. + /// Note: if no handler is specified, the subscription is canceled regardless of the subscribers. + /// If a handler is specified, the subscription is only canceled if this handler is the last handler remaining against the channel. + /// This is used in combination with and as mentioned there, + /// it is intended to use with key space notitications and not advised for user controlled pub/sub channels. + /// + /// The channel that was subscribed to. + /// The handler to no longer invoke when a message is received on . + /// The command flags to use. + /// + /// See + /// , + /// , + /// . + /// + Task UnsubscribeAllPrimariesAsync(RedisChannel channel, Action? handler = null, CommandFlags flags = CommandFlags.None); + /// /// Subscribe to perform some operation when a message to the preferred/active node is broadcast, as a queue that guarantees ordered handling. /// diff --git a/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt b/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt index a24333c8e..13fb1d4d9 100644 --- a/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt +++ b/src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt @@ -1893,4 +1893,5 @@ virtual StackExchange.Redis.RedisResult.Length.get -> int virtual StackExchange.Redis.RedisResult.this[int index].get -> StackExchange.Redis.RedisResult! StackExchange.Redis.ConnectionMultiplexer.AddLibraryNameSuffix(string! suffix) -> void StackExchange.Redis.IConnectionMultiplexer.AddLibraryNameSuffix(string! suffix) -> void - +StackExchange.Redis.ISubscriber.SubscribeAllPrimariesAsync(StackExchange.Redis.RedisChannel channel, System.Action! handler, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task! +StackExchange.Redis.ISubscriber.UnsubscribeAllPrimariesAsync(StackExchange.Redis.RedisChannel channel, System.Action? handler = null, StackExchange.Redis.CommandFlags flags = StackExchange.Redis.CommandFlags.None) -> System.Threading.Tasks.Task! diff --git a/src/StackExchange.Redis/RedisSubscriber.cs b/src/StackExchange.Redis/RedisSubscriber.cs index ee28f4c56..cebe4253c 100644 --- a/src/StackExchange.Redis/RedisSubscriber.cs +++ b/src/StackExchange.Redis/RedisSubscriber.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Concurrent; +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Linq; using System.Net; using System.Threading; using System.Threading.Tasks; @@ -432,6 +434,86 @@ public Task SubscribeAsync(RedisChannel channel, Action handler, CommandFlags flags) + => SubscribeAllPrimariesAsync(channel, handler, null, flags); + + public Task SubscribeAllPrimariesAsync(RedisChannel channel, Action? handler, ChannelMessageQueue? queue, CommandFlags flags) + { + ThrowIfNull(channel); + if (handler == null && queue == null) { return CompletedTask.Default(null); } + + var sub = multiplexer.GetOrAddSubscription(channel, flags); + sub.Add(handler, queue); + return EnsureSubscribedToPrimariesAsync(sub, channel, flags, false); + } + + private Task EnsureSubscribedToPrimariesAsync(Subscription sub, RedisChannel channel, CommandFlags flags, bool internalCall) + { + if (sub.IsConnected) { return CompletedTask.Default(null); } + + // TODO: Cleanup old hangers here? + sub.SetCurrentServer(null); // we're not appropriately connected, so blank it out for eligible reconnection + var tasks = new List>(); + foreach (var server in multiplexer.GetServerSnapshot()) + { + if (!server.IsReplica) + { + var message = sub.GetMessage(channel, SubscriptionAction.Subscribe, flags, internalCall); + tasks.Add(ExecuteAsync(message, sub.Processor, server)); + } + } + + if (tasks.Count == 0) + { + return CompletedTask.Default(false); + } + + // Create a new task that will collect all results and observe errors + return Task.Run(async () => + { + // Wait for all tasks to complete + var results = await Task.WhenAll(tasks).ObserveErrors(); + return results.All(result => result); + }).ObserveErrors(); + } + + Task ISubscriber.UnsubscribeAllPrimariesAsync(RedisChannel channel, Action? handler, CommandFlags flags) + => UnsubscribeAllPrimariesAsync(channel, handler, null, flags); + + public Task UnsubscribeAllPrimariesAsync(in RedisChannel channel, Action? handler, ChannelMessageQueue? queue, CommandFlags flags) + { + ThrowIfNull(channel); + // Unregister the subscription handler/queue, and if that returns true (last handler removed), also disconnect from the server + return UnregisterSubscription(channel, handler, queue, out var sub) + ? UnsubscribeFromPrimariesAsync(sub, channel, flags, asyncState, false) + : CompletedTask.Default(asyncState); + } + + private Task UnsubscribeFromPrimariesAsync(Subscription sub, RedisChannel channel, CommandFlags flags, object? asyncState, bool internalCall) + { + var tasks = new List>(); + foreach (var server in multiplexer.GetServerSnapshot()) + { + if (!server.IsReplica) + { + var message = sub.GetMessage(channel, SubscriptionAction.Unsubscribe, flags, internalCall); + tasks.Add(multiplexer.ExecuteAsyncImpl(message, sub.Processor, asyncState, server)); + } + } + + if (tasks.Count == 0) + { + return CompletedTask.Default(false); + } + + // Create a new task that will collect all results and observe errors + return Task.Run(async () => + { + // Wait for all tasks to complete + var results = await Task.WhenAll(tasks).ObserveErrors(); + return results.All(result => result); + }).ObserveErrors(); + } public Task EnsureSubscribedToServerAsync(Subscription sub, RedisChannel channel, CommandFlags flags, bool internalCall) { if (sub.IsConnected) { return CompletedTask.Default(null); } diff --git a/tests/StackExchange.Redis.Tests/ClusterTests.cs b/tests/StackExchange.Redis.Tests/ClusterTests.cs index 742ce51bb..e23bab0a0 100644 --- a/tests/StackExchange.Redis.Tests/ClusterTests.cs +++ b/tests/StackExchange.Redis.Tests/ClusterTests.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Linq; @@ -746,4 +747,84 @@ public void ConnectIncludesSubscriber() Assert.Equal(PhysicalBridge.State.ConnectedEstablished, server.SubscriptionConnectionState); } } + + [Fact] + public async void SubscribeAllPrimariesAsync() + { + var receivedNofitications = new ConcurrentDictionary(); + var expectedNofitications = new HashSet(); + Action Accept = (channel, message) => receivedNofitications[$"{message} on channel {channel}"] = 0; + + var redis = Create(keepAlive: 1, connectTimeout: 3000, shared: false, allowAdmin: true); + var db = redis.GetDatabase(); + SwitchKeySpaceNofitications(redis); + + var sub = redis.GetSubscriber(); + var channel = new RedisChannel("__key*@*__:*", RedisChannel.PatternMode.Pattern); + await sub.SubscribeAllPrimariesAsync(channel, Accept); + + for (int i = 1; i < 3; i++) + { + await db.StringSetAsync($"k{i}", i); + expectedNofitications.Add($"set on channel __keyspace@0__:k{i}"); + expectedNofitications.Add($"k{i} on channel __keyevent@0__:set"); + } + + // Wait for notifications to be processed + await Task.Delay(1000).ForAwait(); + foreach (var notification in expectedNofitications) + { + Assert.Contains(notification, receivedNofitications); + } + SwitchKeySpaceNofitications(redis, false); + + // Assert that all expected notifications are contained in received notifications + Assert.True( + expectedNofitications.IsSubsetOf(receivedNofitications.Keys), + $"Expected notifications were not received. Expected: {string.Join(", ", expectedNofitications)}, Received: {string.Join(", ", receivedNofitications)}"); + } + + [Fact] + public async void UnsubscribeAllPrimariesAsync() + { + var receivedNofitications = new ConcurrentDictionary(); + Action Accept = (channel, message) => receivedNofitications[$"{message} on channel {channel}"] = 0; + + var redis = Create(keepAlive: 1, connectTimeout: 3000, shared: false, allowAdmin: true); + var db = redis.GetDatabase(); + SwitchKeySpaceNofitications(redis); + + var sub = redis.GetSubscriber(); + var channel = new RedisChannel("__key*@*__:*", RedisChannel.PatternMode.Pattern); + await sub.SubscribeAllPrimariesAsync(channel, Accept); + + for (int i = 1; i < 3; i++) + { + await db.StringSetAsync($"k{i}", i); + } + + // Wait for notifications to be processed + await Task.Delay(1000).ForAwait(); + Assert.NotEmpty(receivedNofitications); + receivedNofitications.Clear(); + await sub.UnsubscribeAllPrimariesAsync(channel, Accept); + for (int i = 1; i < 3; i++) + { + await db.StringSetAsync($"k{i}", i); + } + await Task.Delay(1000).ForAwait(); + Assert.Empty(receivedNofitications); + + SwitchKeySpaceNofitications(redis, false); + } + + private static void SwitchKeySpaceNofitications(IInternalConnectionMultiplexer redis, bool enable = true) + { + string onOff = enable ? "KEA" : ""; + foreach (var endpoint in redis.GetEndPoints()) + { + var server = redis.GetServer(endpoint); + if (!server.IsReplica) server.ConfigSet("notify-keyspace-events", onOff); + } + } }