From 3403af4638e9cb2c4e40469b1a74f3050378c22a Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Wed, 23 Aug 2023 16:54:57 +0100 Subject: [PATCH 1/3] first stab at raw API for redis notifications --- .../ConnectionMultiplexer.cs | 101 ++++++++++++- src/StackExchange.Redis/Enums/CommandFlags.cs | 5 +- .../Interfaces/IConnectionMultiplexer.cs | 3 + src/StackExchange.Redis/Message.cs | 4 +- src/StackExchange.Redis/PhysicalBridge.cs | 23 ++- src/StackExchange.Redis/PhysicalConnection.cs | 140 +++++++++++++++++- .../PublicAPI/PublicAPI.Unshipped.txt | 8 +- src/StackExchange.Redis/RedisLiterals.cs | 9 ++ src/StackExchange.Redis/RedisSubscriber.cs | 16 +- .../ClientTrackingTests.cs | 73 +++++++++ .../Helpers/SharedConnectionFixture.cs | 2 + 11 files changed, 359 insertions(+), 25 deletions(-) create mode 100644 tests/StackExchange.Redis.Tests/ClientTrackingTests.cs diff --git a/src/StackExchange.Redis/ConnectionMultiplexer.cs b/src/StackExchange.Redis/ConnectionMultiplexer.cs index cc239ad3f..410a5ed34 100644 --- a/src/StackExchange.Redis/ConnectionMultiplexer.cs +++ b/src/StackExchange.Redis/ConnectionMultiplexer.cs @@ -1,4 +1,7 @@ -using System; +using Microsoft.Extensions.Logging; +using Pipelines.Sockets.Unofficial; +using StackExchange.Redis.Profiling; +using System; using System.Collections; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; @@ -9,10 +12,8 @@ using System.Runtime.InteropServices; using System.Text; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; -using Microsoft.Extensions.Logging; -using Pipelines.Sockets.Unofficial; -using StackExchange.Redis.Profiling; namespace StackExchange.Redis { @@ -355,6 +356,11 @@ internal void CheckMessage(Message message) { throw ExceptionFactory.TooManyArgs(message.CommandAndKey, message.ArgCount); } + + if (message.IsClientCaching && ClientSideTracking is null) + { + throw new InvalidOperationException("The " + nameof(CommandFlags.ClientCaching) + " flag can only be used if " + nameof(EnableServerAssistedClientSideTracking) + " has been called"); + } } internal bool TryResend(int hashSlot, Message message, EndPoint endpoint, bool isMoved) @@ -2268,7 +2274,7 @@ public async ValueTask DisposeAsync() public void Close(bool allowCommandsToComplete = true) { if (_isDisposed) return; - + _clientSideTracking?.Shutdown(); OnClosing(false); _isDisposed = true; _profilingSessionProvider = null; @@ -2295,6 +2301,7 @@ public void Close(bool allowCommandsToComplete = true) public async Task CloseAsync(bool allowCommandsToComplete = true) { _isDisposed = true; + _clientSideTracking?.Shutdown(); using (var tmp = pulse) { pulse = null; @@ -2341,5 +2348,89 @@ private Task[] QuitAllServers() long? IInternalConnectionMultiplexer.GetConnectionId(EndPoint endpoint, ConnectionType type) => GetServerEndPoint(endpoint)?.GetBridge(type)?.ConnectionId; + + /// + /// Enable the client tracking feature of redis + /// + /// see also https://redis.io/docs/manual/client-side-caching/ + /// The callback to be invoked when keys are determined to be invalidated + /// Additional flags to influence the behavior of client tracking + /// Optionally restricts client-side caching notifications for these connections to a subset of key prefixes; this has performance implications (see the PREFIX option in CLIENT TRACKING) + public void EnableServerAssistedClientSideTracking(Func keyInvalidated, ClientTrackingOptions options = ClientTrackingOptions.None, ReadOnlyMemory prefixes = default) + { + if (_clientSideTracking is not null) ThrowOnceOnly(); + var obj = new ClientSideTrackingState(keyInvalidated, options, prefixes); + if (Interlocked.CompareExchange(ref _clientSideTracking, obj, null) is not null) ThrowOnceOnly(); + + static void ThrowOnceOnly() => throw new InvalidOperationException("The " + nameof(EnableServerAssistedClientSideTracking) + " method can be invoked once-only per multiplexer instance"); + } + + private ClientSideTrackingState? _clientSideTracking; + internal ClientSideTrackingState? ClientSideTracking => _clientSideTracking; + internal sealed class ClientSideTrackingState + { + public bool IsAlive { get; private set; } + private readonly Func _keyInvalidated; + public ClientTrackingOptions Options { get; } + public ReadOnlyMemory Prefixes { get; } + + private readonly Channel _notifications; + + public ClientSideTrackingState(Func keyInvalidated, ClientTrackingOptions options, ReadOnlyMemory prefixes) + { + _keyInvalidated = keyInvalidated; + Options = options; + Prefixes = prefixes; + _notifications = Channel.CreateUnbounded(ChannelOptions); + _ = Task.Run(RunAsync); + IsAlive = true; + } + private async Task RunAsync() + { + while (await _notifications.Reader.WaitToReadAsync().ConfigureAwait(false)) + { + while (_notifications.Reader.TryRead(out var key)) + { + await _keyInvalidated(key).ConfigureAwait(false); + } + } + } + + public void Write(RedisKey key) => _notifications.Writer.TryWrite(key); + + public void Shutdown() + { + IsAlive = false; + _notifications.Writer.TryComplete(null); + } + + private static readonly UnboundedChannelOptions ChannelOptions = new UnboundedChannelOptions { SingleReader = true, SingleWriter = false, AllowSynchronousContinuations = true }; + + + } + } + + /// + /// Additional flags to influence the behavior of client tracking + /// + [Flags] + public enum ClientTrackingOptions + { + /// + /// No additional options + /// + None = 0, + /// + /// Enable tracking in broadcasting mode. In this mode invalidation messages are reported for all the prefixes specified, regardless of the keys requested by the connection. Instead when the broadcasting mode is not enabled, Redis will track which keys are fetched using read-only commands, and will report invalidation messages only for such keys. + /// + /// This corresponds to CLIENT TRACKING ... BCAST + Broadcast = 1 << 0, + /// + /// Send notifications about keys modified by this connection itself. + /// + /// This corresponds to the inverse of CLIENT TRACKING ... NOLOOP + NotifyForOwnCommands = 1 << 1, + + // to think about: OPTIN / OPTOUT ? I'm happy to implement on the basis of OPTIN for now, though } } diff --git a/src/StackExchange.Redis/Enums/CommandFlags.cs b/src/StackExchange.Redis/Enums/CommandFlags.cs index bc93f328e..ddcec208d 100644 --- a/src/StackExchange.Redis/Enums/CommandFlags.cs +++ b/src/StackExchange.Redis/Enums/CommandFlags.cs @@ -81,7 +81,10 @@ public enum CommandFlags /// NoScriptCache = 512, - // 1024: Removed - was used for async timeout checks; never user-specified, so not visible on the public API + /// + /// Indicates a command that relates to server-assisted client-side caching; this corresponds to CLIENT CACHING YES being issues before the command + /// + ClientCaching = 1024, // 2048: Use subscription connection type; never user-specified, so not visible on the public API } diff --git a/src/StackExchange.Redis/Interfaces/IConnectionMultiplexer.cs b/src/StackExchange.Redis/Interfaces/IConnectionMultiplexer.cs index 58973df68..3340d132f 100644 --- a/src/StackExchange.Redis/Interfaces/IConnectionMultiplexer.cs +++ b/src/StackExchange.Redis/Interfaces/IConnectionMultiplexer.cs @@ -281,5 +281,8 @@ public interface IConnectionMultiplexer : IDisposable, IAsyncDisposable /// The destination stream to write the export to. /// The options to use for this export. void ExportConfiguration(Stream destination, ExportOptions options = ExportOptions.All); + + /// + void EnableServerAssistedClientSideTracking(Func keyInvalidated, ClientTrackingOptions options = ClientTrackingOptions.None, ReadOnlyMemory prefixes = default); } } diff --git a/src/StackExchange.Redis/Message.cs b/src/StackExchange.Redis/Message.cs index 65df7d0e7..c954b6e2f 100644 --- a/src/StackExchange.Redis/Message.cs +++ b/src/StackExchange.Redis/Message.cs @@ -75,7 +75,8 @@ internal abstract class Message : ICompletable #pragma warning restore CS0618 | CommandFlags.FireAndForget | CommandFlags.NoRedirect - | CommandFlags.NoScriptCache; + | CommandFlags.NoScriptCache + | CommandFlags.ClientCaching; private IResultBox? resultBox; private ResultProcessor? resultProcessor; @@ -197,6 +198,7 @@ public bool IsAdmin } public bool IsAsking => (Flags & AskingFlag) != 0; + public bool IsClientCaching => (Flags & CommandFlags.ClientCaching) != 0; internal bool IsScriptUnavailable => (Flags & ScriptUnavailableFlag) != 0; diff --git a/src/StackExchange.Redis/PhysicalBridge.cs b/src/StackExchange.Redis/PhysicalBridge.cs index 7041cf0af..6da85fbc6 100644 --- a/src/StackExchange.Redis/PhysicalBridge.cs +++ b/src/StackExchange.Redis/PhysicalBridge.cs @@ -23,7 +23,9 @@ internal sealed class PhysicalBridge : IDisposable private const double ProfileLogSeconds = (1000 /* ms */ * ProfileLogSamples) / 1000.0; - private static readonly Message ReusableAskingCommand = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.ASKING); + private static readonly Message + ReusableAskingCommand = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.ASKING), + ReusableClientCachingYesCommand = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.CLIENT, RedisLiterals.CACHING, RedisLiterals.yes); private readonly long[] profileLog = new long[ProfileLogSamples]; @@ -1494,13 +1496,13 @@ private WriteResult WriteMessageToServerInsideWriteLock(PhysicalConnection conne } if (message.IsAsking) { - var asking = ReusableAskingCommand; - connection.EnqueueInsideWriteLock(asking); - asking.WriteTo(connection); - asking.SetRequestSent(); - IncrementOpCount(); + RawWriteInternalMessageInsideWriteLock(connection, ReusableAskingCommand); } } + if (message.IsClientCaching && connection.EnsureServerAssistedClientSideTrackingInsideWriteLock()) + { + RawWriteInternalMessageInsideWriteLock(connection, ReusableClientCachingYesCommand); + } switch (cmd) { case RedisCommand.WATCH: @@ -1570,6 +1572,15 @@ private WriteResult WriteMessageToServerInsideWriteLock(PhysicalConnection conne } } + internal void RawWriteInternalMessageInsideWriteLock(PhysicalConnection connection, Message message) + { + message.SetInternalCall(); + connection.EnqueueInsideWriteLock(message); + message.WriteTo(connection); + message.SetRequestSent(); + IncrementOpCount(); + } + /// /// For testing only /// diff --git a/src/StackExchange.Redis/PhysicalConnection.cs b/src/StackExchange.Redis/PhysicalConnection.cs index eb0787606..8453db220 100644 --- a/src/StackExchange.Redis/PhysicalConnection.cs +++ b/src/StackExchange.Redis/PhysicalConnection.cs @@ -19,6 +19,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; using static StackExchange.Redis.Message; +using System.Threading.Channels; namespace StackExchange.Redis { @@ -1599,10 +1600,10 @@ private void MatchResult(in RawResult result) Trace("MESSAGE: " + channel); if (!channel.IsNull) { - if (TryGetPubSubPayload(items[2], out var payload)) + if (TryGetPubSubPayload(items[2], out var payload, out var source)) { _readStatus = ReadStatus.InvokePubSub; - muxer.OnMessage(channel, channel, payload); + muxer.OnMessage(channel, channel, payload, source); } // could be multi-message: https://github.com/StackExchange/StackExchange.Redis/issues/2507 else if (TryGetMultiPubSubPayload(items[2], out var payloads)) @@ -1621,11 +1622,11 @@ private void MatchResult(in RawResult result) Trace("PMESSAGE: " + channel); if (!channel.IsNull) { - if (TryGetPubSubPayload(items[3], out var payload)) + if (TryGetPubSubPayload(items[3], out var payload, out var source)) { var sub = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.PatternMode.Pattern); _readStatus = ReadStatus.InvokePubSub; - muxer.OnMessage(sub, channel, payload); + muxer.OnMessage(sub, channel, payload, source); } else if (TryGetMultiPubSubPayload(items[3], out var payloads)) { @@ -1661,8 +1662,9 @@ private void MatchResult(in RawResult result) _readStatus = ReadStatus.MatchResultComplete; _activeMessage = null; - static bool TryGetPubSubPayload(in RawResult value, out RedisValue parsed, bool allowArraySingleton = true) + static bool TryGetPubSubPayload(in RawResult value, out RedisValue parsed, out RawResult source, bool allowArraySingleton = true) { + source = value; if (value.IsNull) { parsed = RedisValue.Null; @@ -1676,7 +1678,7 @@ static bool TryGetPubSubPayload(in RawResult value, out RedisValue parsed, bool parsed = value.AsRedisValue(); return true; case ResultType.MultiBulk when allowArraySingleton && value.ItemsCount == 1: - return TryGetPubSubPayload(in value[0], out parsed, allowArraySingleton: false); + return TryGetPubSubPayload(in value[0], out parsed, out source, allowArraySingleton: false); } parsed = default; return false; @@ -2071,5 +2073,131 @@ internal bool HasPendingCallerFacingItems() if (lockTaken) Monitor.Exit(_writtenAwaitingResponse); } } + + private ClientTrackingState _clientTrackingState = ClientTrackingState.NotInitialized; + private enum ClientTrackingState + { + NotInitialized = 0, + Active = 1, + Broken = 2, // was active, now not + } + + internal bool EnsureServerAssistedClientSideTrackingInsideWriteLock() => + _clientTrackingState == ClientTrackingState.Active // optimize for this + || InitializeServerAssistedClientSideTrackingInsideWriteLock(); + + private bool InitializeServerAssistedClientSideTrackingInsideWriteLock() + { + var bridge = BridgeCouldBeNull; + if (bridge is null) + { + return false; // shutting down, be gentle in our nope + } + + var config = bridge.Multiplexer.ClientSideTracking; + if (config is not { IsAlive: true }) + { + return false; // not enabled (should alrady have faulted, note), or: already dead + } + + switch (_clientTrackingState) + { + case ClientTrackingState.Active: + return true; // we shouldn't be here, but: whatever + case ClientTrackingState.Broken: + bridge.RawWriteInternalMessageInsideWriteLock(this, Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.CLIENT, RedisLiterals.TRACKING, RedisLiterals.OFF)); + _clientTrackingState = ClientTrackingState.NotInitialized; + goto case ClientTrackingState.NotInitialized; + case ClientTrackingState.NotInitialized: + // note: this check will need to be removed in RESP3 + if (BridgeCouldBeNull?.ServerEndPoint is { SupportsSubscriptions: true } sep + && sep.GetBridge(ConnectionType.Subscription) is { IsConnected: true } sub) + { + var subId = sub.ConnectionId; + if (subId is not null) + { + // subscribe + bridge.Multiplexer.ExecuteSyncImpl(ReusableSubscribeClientCachingSubscribeMessage, null, sep, 0); + bridge.RawWriteInternalMessageInsideWriteLock(this, new ClientTrackingMessage(config, subId)); + _clientTrackingState = ClientTrackingState.Active; + return true; + } + } + return false; + default: + return false; + } + } + + private static readonly Message ReusableSubscribeClientCachingSubscribeMessage = Message.Create( + -1, CommandFlags.FireAndForget, RedisCommand.SUBSCRIBE, ConnectionMultiplexer.ClientCachingChannel); + + private sealed class ClientTrackingMessage : Message + { + private readonly ConnectionMultiplexer.ClientSideTrackingState _state; + private readonly long? _subId; // will be NULL in RESP3 + + public ClientTrackingMessage(ConnectionMultiplexer.ClientSideTrackingState state, long? subId) : base(-1, CommandFlags.FireAndForget, RedisCommand.CLIENT) + { + _state = state; + _subId = subId; + } + + public override int ArgCount + { + get + { + var count = 3; // TRACKING ON [OPTIN] + if (_subId is not null) + { + count += 2; // [REDIRECT client-id] + } + if (!_state.Prefixes.IsEmpty) + { + count += _state.Prefixes.Length + 1; // [PREFIX prefix ...] + } + var options = _state.Options; + if ((options & ClientTrackingOptions.Broadcast) != 0) + { + count++; // [BCAST] + } + if ((options & ClientTrackingOptions.NotifyForOwnCommands) == 0) + { + count++; // [NOLOOP] + } + return count; + } + } + + protected override void WriteImpl(PhysicalConnection physical) + { + physical.WriteHeader(Command, ArgCount); + physical.WriteBulkString(RedisLiterals.TRACKING); + physical.WriteBulkString(RedisLiterals.ON); + if (_subId is not null) + { + physical.WriteBulkString(RedisLiterals.REDIRECT); + physical.WriteBulkString(_subId.GetValueOrDefault()); + } + if (!_state.Prefixes.IsEmpty) + { + physical.WriteBulkString(RedisLiterals.PREFIX); + foreach (ref readonly RedisKey prefix in _state.Prefixes.Span) + { + physical.Write(in prefix); + } + } + var options = _state.Options; + if ((options & ClientTrackingOptions.Broadcast) != 0) + { + physical.WriteBulkString(RedisLiterals.BCAST); + } + physical.WriteBulkString(RedisLiterals.OPTIN); + if ((options & ClientTrackingOptions.NotifyForOwnCommands) == 0) + { + physical.WriteBulkString(RedisLiterals.NOLOOP); + } + } + } } } diff --git a/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt b/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt index 5f282702b..b165bb107 100644 --- a/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt +++ b/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt @@ -1 +1,7 @@ - \ No newline at end of file +StackExchange.Redis.ClientTrackingOptions +StackExchange.Redis.ClientTrackingOptions.Broadcast = 1 -> StackExchange.Redis.ClientTrackingOptions +StackExchange.Redis.ClientTrackingOptions.None = 0 -> StackExchange.Redis.ClientTrackingOptions +StackExchange.Redis.ClientTrackingOptions.NotifyForOwnCommands = 2 -> StackExchange.Redis.ClientTrackingOptions +StackExchange.Redis.CommandFlags.ClientCaching = 1024 -> StackExchange.Redis.CommandFlags +StackExchange.Redis.ConnectionMultiplexer.EnableServerAssistedClientSideTracking(System.Func! keyInvalidated, StackExchange.Redis.ClientTrackingOptions options = StackExchange.Redis.ClientTrackingOptions.None, System.ReadOnlyMemory prefixes = default(System.ReadOnlyMemory)) -> void +StackExchange.Redis.IConnectionMultiplexer.EnableServerAssistedClientSideTracking(System.Func! keyInvalidated, StackExchange.Redis.ClientTrackingOptions options = StackExchange.Redis.ClientTrackingOptions.None, System.ReadOnlyMemory prefixes = default(System.ReadOnlyMemory)) -> void \ No newline at end of file diff --git a/src/StackExchange.Redis/RedisLiterals.cs b/src/StackExchange.Redis/RedisLiterals.cs index e926b6da4..f2429b640 100644 --- a/src/StackExchange.Redis/RedisLiterals.cs +++ b/src/StackExchange.Redis/RedisLiterals.cs @@ -50,12 +50,14 @@ public static readonly RedisValue AND = "AND", ANY = "ANY", ASC = "ASC", + BCAST = "BCAST", BEFORE = "BEFORE", BIT = "BIT", BY = "BY", BYLEX = "BYLEX", BYSCORE = "BYSCORE", BYTE = "BYTE", + CACHING = "CACHING", CH = "CH", CHANNELS = "CHANNELS", COPY = "COPY", @@ -97,21 +99,27 @@ public static readonly RedisValue MINMATCHLEN = "MINMATCHLEN", MODULE = "MODULE", NODES = "NODES", + NOLOOP = "NOLOOP", NOSAVE = "NOSAVE", NOT = "NOT", NUMPAT = "NUMPAT", NUMSUB = "NUMSUB", NX = "NX", OBJECT = "OBJECT", + OFF = "OFF", + OPTIN = "OPTIN", OR = "OR", + ON = "ON", PATTERN = "PATTERN", PAUSE = "PAUSE", PERSIST = "PERSIST", PING = "PING", + PREFIX = "PREFIX", PURGE = "PURGE", PX = "PX", PXAT = "PXAT", RANK = "RANK", + REDIRECT = "REDIRECT", REFCOUNT = "REFCOUNT", REPLACE = "REPLACE", RESET = "RESET", @@ -127,6 +135,7 @@ public static readonly RedisValue SKIPME = "SKIPME", STATS = "STATS", STORE = "STORE", + TRACKING = "TRACKING", TYPE = "TYPE", WEIGHTS = "WEIGHTS", WITHMATCHLEN = "WITHMATCHLEN", diff --git a/src/StackExchange.Redis/RedisSubscriber.cs b/src/StackExchange.Redis/RedisSubscriber.cs index 5a24a716e..067f5c263 100644 --- a/src/StackExchange.Redis/RedisSubscriber.cs +++ b/src/StackExchange.Redis/RedisSubscriber.cs @@ -75,7 +75,7 @@ internal bool GetSubscriberCounts(in RedisChannel channel, out int handlers, out /// /// Handler that executes whenever a message comes in, this doles out messages to any registered handlers. /// - internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, in RedisValue payload) + internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, in RedisValue payload, in RawResult rawPayload) { ICompletable? completable = null; ChannelMessageQueue? queues = null; @@ -91,22 +91,28 @@ internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, i { CompleteAsWorker(completable); } + if (subscription == ClientCachingChannel) + { + _clientSideTracking?.Write(rawPayload.AsRedisKey()); + } } + internal static readonly RedisChannel ClientCachingChannel = RedisChannel.Literal("__redis__:invalidate"); + internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, Sequence payload) { if (payload.IsSingleSegment) { - foreach (var message in payload.FirstSpan) + foreach (ref readonly RawResult message in payload.FirstSpan) { - OnMessage(subscription, channel, message.AsRedisValue()); + OnMessage(subscription, channel, message.AsRedisValue(), in message); } } else { - foreach (var message in payload) + foreach (ref readonly RawResult message in payload) { - OnMessage(subscription, channel, message.AsRedisValue()); + OnMessage(subscription, channel, message.AsRedisValue(), in message); } } } diff --git a/tests/StackExchange.Redis.Tests/ClientTrackingTests.cs b/tests/StackExchange.Redis.Tests/ClientTrackingTests.cs new file mode 100644 index 000000000..561109daf --- /dev/null +++ b/tests/StackExchange.Redis.Tests/ClientTrackingTests.cs @@ -0,0 +1,73 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Xunit; +using Xunit.Abstractions; + +namespace StackExchange.Redis.Tests; + +/// +/// Tests for . +/// +[Collection(SharedConnectionFixture.Key)] +public class ClientTrackingTests : TestBase +{ + public ClientTrackingTests(ITestOutputHelper output, SharedConnectionFixture fixture) : base(output, fixture) { } + + [Fact] + public async Task UseFlagWithoutEnabling() + { + using var conn = Create(shared: false); + var key = Me(); + var ex = await Assert.ThrowsAsync( + async () => await conn.GetDatabase().StringGetAsync(key, CommandFlags.ClientCaching) + ); + Assert.Equal("The ClientCaching flag can only be used if EnableServerAssistedClientSideTracking has been called", ex.Message); + } + + [Fact] + public void CallEnableTwice() + { + using var conn = Create(shared: false); + conn.EnableServerAssistedClientSideTracking(key => default); + var ex = Assert.Throws(() => conn.EnableServerAssistedClientSideTracking(key => default)); + Assert.Equal("The EnableServerAssistedClientSideTracking method can be invoked once-only per multiplexer instance", ex.Message); + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, false)] + [InlineData(false, true)] + public async void GetNotification(bool listenToSelf, bool externalConnectionMakesChange) + { + bool expectNotification = listenToSelf || externalConnectionMakesChange; + + using var listen = Create(shared: false); + using var send = externalConnectionMakesChange ? Create() : listen; + + int value = (new Random().Next() % 1024) + 1024, notifyCount = 0; + + var key = Me(); + var db = listen.GetDatabase(); + db.KeyDelete(key); + db.StringSet(key, value); + + var options = listenToSelf ? ClientTrackingOptions.NotifyForOwnCommands : ClientTrackingOptions.None; + listen.EnableServerAssistedClientSideTracking(rkey => + { + if (rkey == key) Interlocked.Increment(ref notifyCount); + return default; + }, options); + + Assert.Equal(value, db.StringGet(key, CommandFlags.ClientCaching)); + Assert.Equal(0, Volatile.Read(ref notifyCount)); + + send.GetDatabase().StringIncrement(key, 5); + await Task.Delay(100); // allow time for the magic to happen + + Assert.Equal(expectNotification ? 1 : 0, Volatile.Read(ref notifyCount)); + Assert.Equal(value + 5, db.StringGet(key, CommandFlags.ClientCaching)); + + } +} diff --git a/tests/StackExchange.Redis.Tests/Helpers/SharedConnectionFixture.cs b/tests/StackExchange.Redis.Tests/Helpers/SharedConnectionFixture.cs index f61e73e32..78e20a539 100644 --- a/tests/StackExchange.Redis.Tests/Helpers/SharedConnectionFixture.cs +++ b/tests/StackExchange.Redis.Tests/Helpers/SharedConnectionFixture.cs @@ -185,6 +185,8 @@ public void ExportConfiguration(Stream destination, ExportOptions options = Expo public override string ToString() => _inner.ToString(); long? IInternalConnectionMultiplexer.GetConnectionId(EndPoint endPoint, ConnectionType type) => _inner.GetConnectionId(endPoint, type); + public void EnableServerAssistedClientSideTracking(Func keyInvalidated, ClientTrackingOptions options = ClientTrackingOptions.None, ReadOnlyMemory prefixes = default) + => _inner.EnableServerAssistedClientSideTracking(keyInvalidated, options, prefixes); } public void Dispose() From 27380e5b750f2074161605a1e73ff1d2a3f718fa Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Tue, 29 Aug 2023 13:03:01 +0100 Subject: [PATCH 2/3] - fix broadcast nuances - implement concurrent callback --- ...onnectionMultiplexer.ClientSideTracking.cs | 176 ++++++++++++++++++ .../ConnectionMultiplexer.Events.cs | 6 + .../ConnectionMultiplexer.cs | 85 --------- src/StackExchange.Redis/PhysicalBridge.cs | 1 + src/StackExchange.Redis/PhysicalConnection.cs | 60 ++++-- .../PublicAPI/PublicAPI.Unshipped.txt | 1 + src/StackExchange.Redis/ServerEndPoint.cs | 2 + .../ClientTrackingTests.cs | 39 +++- 8 files changed, 261 insertions(+), 109 deletions(-) create mode 100644 src/StackExchange.Redis/ConnectionMultiplexer.ClientSideTracking.cs diff --git a/src/StackExchange.Redis/ConnectionMultiplexer.ClientSideTracking.cs b/src/StackExchange.Redis/ConnectionMultiplexer.ClientSideTracking.cs new file mode 100644 index 000000000..88705057d --- /dev/null +++ b/src/StackExchange.Redis/ConnectionMultiplexer.ClientSideTracking.cs @@ -0,0 +1,176 @@ +using Microsoft.Extensions.Logging; +using System; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; + +namespace StackExchange.Redis; + +public partial class ConnectionMultiplexer +{ + /// + /// Enable the client tracking feature of redis + /// + /// see also https://redis.io/docs/manual/client-side-caching/ + /// The callback to be invoked when keys are determined to be invalidated + /// Additional flags to influence the behavior of client tracking + /// Optionally restricts client-side caching notifications for these connections to a subset of key prefixes; this has performance implications (see the PREFIX option in CLIENT TRACKING) + public void EnableServerAssistedClientSideTracking(Func keyInvalidated, ClientTrackingOptions options = ClientTrackingOptions.None, ReadOnlyMemory prefixes = default) + { + if (_clientSideTracking is not null) ThrowOnceOnly(); + if (!prefixes.IsEmpty && (options & ClientTrackingOptions.Broadcast) == 0) ThrowPrefixNeedsBroadcast(); + var obj = new ClientSideTrackingState(this, keyInvalidated, options, prefixes); + if (Interlocked.CompareExchange(ref _clientSideTracking, obj, null) is not null) ThrowOnceOnly(); + + static void ThrowOnceOnly() => throw new InvalidOperationException("The " + nameof(EnableServerAssistedClientSideTracking) + " method can be invoked once-only per multiplexer instance"); + static void ThrowPrefixNeedsBroadcast() => throw new ArgumentException("Prefixes can only be specified when " + nameof(ClientTrackingOptions) + "." + nameof(ClientTrackingOptions.Broadcast) + " is used", nameof(prefixes)); + } + + private ClientSideTrackingState? _clientSideTracking; + internal ClientSideTrackingState? ClientSideTracking => _clientSideTracking; + internal sealed class ClientSideTrackingState + { + public bool IsAlive { get; private set; } + private readonly Func _keyInvalidated; + public ClientTrackingOptions Options { get; } + public ReadOnlyMemory Prefixes { get; } + + private readonly Channel _notifications; + private readonly WeakReference _multiplexer; +#if NETCOREAPP3_1_OR_GREATER + private readonly Action? _concurrentCallback; +#else + private readonly WaitCallback? _concurrentCallback; +#endif + + public ClientSideTrackingState(ConnectionMultiplexer multiplexer, Func keyInvalidated, ClientTrackingOptions options, ReadOnlyMemory prefixes) + { + _keyInvalidated = keyInvalidated; + Options = options; + Prefixes = prefixes; + _notifications = Channel.CreateUnbounded(ChannelOptions); + _ = Task.Run(RunAsync); + IsAlive = true; + _multiplexer = new(multiplexer); + + if ((options & ClientTrackingOptions.ConcurrentInvalidation) != 0) + { + _concurrentCallback = OnInvalidate; + } + } + +#if !NETCOREAPP3_1_OR_GREATER + private void OnInvalidate(object state) => OnInvalidate((RedisKey)state); +#endif + + private void OnInvalidate(RedisKey key) + { + try // not optimized for sync completions + { + var pending = _keyInvalidated(key); + if (pending.IsCompleted) + { // observe result + pending.GetAwaiter().GetResult(); + } + else + { + _ = ObserveAsyncInvalidation(pending); + } + } + catch (Exception ex) // handle sync failure (via immediate throw or faulted ValueTask) + { + OnCallbackError(ex); + } + } + + private async Task ObserveAsyncInvalidation(ValueTask pending) + { + try + { + await pending.ConfigureAwait(false); + } + catch (Exception ex) + { + OnCallbackError(ex); + } + } + + private ConnectionMultiplexer? Multiplexer => _multiplexer.TryGetTarget(out var multiplexer) ? multiplexer : null; + + + private void OnCallbackError(Exception error) => Multiplexer?.Logger?.LogError(error, "Client-side tracking invalidation callback failure"); + + private async Task RunAsync() + { + while (await _notifications.Reader.WaitToReadAsync().ConfigureAwait(false)) + { + while (_notifications.Reader.TryRead(out var key)) + { + if (_concurrentCallback is not null) + { +#if NETCOREAPP3_1_OR_GREATER + ThreadPool.QueueUserWorkItem(_concurrentCallback, key, preferLocal: false); +#else + // eat the box + ThreadPool.QueueUserWorkItem(_concurrentCallback, key); +#endif + } + else + { + try + { + await _keyInvalidated(key).ConfigureAwait(false); + } + catch (Exception ex) + { + OnCallbackError(ex); + } + } + } + } + } + + public void Write(RedisKey key) => _notifications.Writer.TryWrite(key); + + public void Shutdown() + { + IsAlive = false; + _notifications.Writer.TryComplete(null); + } + + private static readonly UnboundedChannelOptions ChannelOptions = new UnboundedChannelOptions { SingleReader = true, SingleWriter = false, AllowSynchronousContinuations = true }; + + + } +} + +/// +/// Additional flags to influence the behavior of client tracking +/// +[Flags] +public enum ClientTrackingOptions +{ + /// + /// No additional options + /// + None = 0, + /// + /// Enable tracking in broadcasting mode. In this mode invalidation messages are reported for all the prefixes specified, regardless of the keys requested by the connection. Instead when the broadcasting mode is not enabled, Redis will track which keys are fetched using read-only commands, and will report invalidation messages only for such keys. + /// + /// This corresponds to CLIENT TRACKING ... BCAST; using mode consumes less server memory, at the cost of more invalidation messages (i.e. clients are + /// likely to receive invalidation messages for keys that the individual client is not using); this can be partially mitigated by using prefixes + Broadcast = 1 << 0, + /// + /// Send notifications about keys modified by this connection itself. + /// + /// This corresponds to the inverse of CLIENT TRACKING ... NOLOOP; setting means that your own writes will cause self-notification; this + /// may mean that you discard a locally updated copy of the new value, hence this is disabled by default + NotifyForOwnCommands = 1 << 1, + + /// + /// Indicates that the callback specified for key invalidation should be invoked concurrently rather than sequentially + /// + ConcurrentInvalidation = 1 << 2, + + // to think about: OPTIN / OPTOUT ? I'm happy to implement on the basis of OPTIN for now, though +} diff --git a/src/StackExchange.Redis/ConnectionMultiplexer.Events.cs b/src/StackExchange.Redis/ConnectionMultiplexer.Events.cs index 0a8b95be5..707ba7bf6 100644 --- a/src/StackExchange.Redis/ConnectionMultiplexer.Events.cs +++ b/src/StackExchange.Redis/ConnectionMultiplexer.Events.cs @@ -14,6 +14,12 @@ public partial class ConnectionMultiplexer internal void OnConnectionFailed(EndPoint endpoint, ConnectionType connectionType, ConnectionFailureType failureType, Exception exception, bool reconfigure, string? physicalName) { if (_isDisposed) return; + + if (connectionType is ConnectionType.Subscription) + { + GetServerEndPoint(endpoint, activate: false)?.OnSubscriberFailed(); + } + var handler = ConnectionFailed; if (handler != null) { diff --git a/src/StackExchange.Redis/ConnectionMultiplexer.cs b/src/StackExchange.Redis/ConnectionMultiplexer.cs index 410a5ed34..75060fb01 100644 --- a/src/StackExchange.Redis/ConnectionMultiplexer.cs +++ b/src/StackExchange.Redis/ConnectionMultiplexer.cs @@ -12,7 +12,6 @@ using System.Runtime.InteropServices; using System.Text; using System.Threading; -using System.Threading.Channels; using System.Threading.Tasks; namespace StackExchange.Redis @@ -2348,89 +2347,5 @@ private Task[] QuitAllServers() long? IInternalConnectionMultiplexer.GetConnectionId(EndPoint endpoint, ConnectionType type) => GetServerEndPoint(endpoint)?.GetBridge(type)?.ConnectionId; - - /// - /// Enable the client tracking feature of redis - /// - /// see also https://redis.io/docs/manual/client-side-caching/ - /// The callback to be invoked when keys are determined to be invalidated - /// Additional flags to influence the behavior of client tracking - /// Optionally restricts client-side caching notifications for these connections to a subset of key prefixes; this has performance implications (see the PREFIX option in CLIENT TRACKING) - public void EnableServerAssistedClientSideTracking(Func keyInvalidated, ClientTrackingOptions options = ClientTrackingOptions.None, ReadOnlyMemory prefixes = default) - { - if (_clientSideTracking is not null) ThrowOnceOnly(); - var obj = new ClientSideTrackingState(keyInvalidated, options, prefixes); - if (Interlocked.CompareExchange(ref _clientSideTracking, obj, null) is not null) ThrowOnceOnly(); - - static void ThrowOnceOnly() => throw new InvalidOperationException("The " + nameof(EnableServerAssistedClientSideTracking) + " method can be invoked once-only per multiplexer instance"); - } - - private ClientSideTrackingState? _clientSideTracking; - internal ClientSideTrackingState? ClientSideTracking => _clientSideTracking; - internal sealed class ClientSideTrackingState - { - public bool IsAlive { get; private set; } - private readonly Func _keyInvalidated; - public ClientTrackingOptions Options { get; } - public ReadOnlyMemory Prefixes { get; } - - private readonly Channel _notifications; - - public ClientSideTrackingState(Func keyInvalidated, ClientTrackingOptions options, ReadOnlyMemory prefixes) - { - _keyInvalidated = keyInvalidated; - Options = options; - Prefixes = prefixes; - _notifications = Channel.CreateUnbounded(ChannelOptions); - _ = Task.Run(RunAsync); - IsAlive = true; - } - private async Task RunAsync() - { - while (await _notifications.Reader.WaitToReadAsync().ConfigureAwait(false)) - { - while (_notifications.Reader.TryRead(out var key)) - { - await _keyInvalidated(key).ConfigureAwait(false); - } - } - } - - public void Write(RedisKey key) => _notifications.Writer.TryWrite(key); - - public void Shutdown() - { - IsAlive = false; - _notifications.Writer.TryComplete(null); - } - - private static readonly UnboundedChannelOptions ChannelOptions = new UnboundedChannelOptions { SingleReader = true, SingleWriter = false, AllowSynchronousContinuations = true }; - - - } - } - - /// - /// Additional flags to influence the behavior of client tracking - /// - [Flags] - public enum ClientTrackingOptions - { - /// - /// No additional options - /// - None = 0, - /// - /// Enable tracking in broadcasting mode. In this mode invalidation messages are reported for all the prefixes specified, regardless of the keys requested by the connection. Instead when the broadcasting mode is not enabled, Redis will track which keys are fetched using read-only commands, and will report invalidation messages only for such keys. - /// - /// This corresponds to CLIENT TRACKING ... BCAST - Broadcast = 1 << 0, - /// - /// Send notifications about keys modified by this connection itself. - /// - /// This corresponds to the inverse of CLIENT TRACKING ... NOLOOP - NotifyForOwnCommands = 1 << 1, - - // to think about: OPTIN / OPTOUT ? I'm happy to implement on the basis of OPTIN for now, though } } diff --git a/src/StackExchange.Redis/PhysicalBridge.cs b/src/StackExchange.Redis/PhysicalBridge.cs index 6da85fbc6..8394c06e1 100644 --- a/src/StackExchange.Redis/PhysicalBridge.cs +++ b/src/StackExchange.Redis/PhysicalBridge.cs @@ -1594,5 +1594,6 @@ internal void SimulateConnectionFailure(SimulatedFailureType failureType) } internal RedisCommand? GetActiveMessage() => Volatile.Read(ref _activeMessage)?.Command; + internal void OnSubscriberFailed() => physical?.OnSubscriberFailed(); } } diff --git a/src/StackExchange.Redis/PhysicalConnection.cs b/src/StackExchange.Redis/PhysicalConnection.cs index 8453db220..eb66f0bc7 100644 --- a/src/StackExchange.Redis/PhysicalConnection.cs +++ b/src/StackExchange.Redis/PhysicalConnection.cs @@ -2074,17 +2074,31 @@ internal bool HasPendingCallerFacingItems() } } - private ClientTrackingState _clientTrackingState = ClientTrackingState.NotInitialized; + private volatile ClientTrackingState _clientTrackingState = ClientTrackingState.NotInitialized; private enum ClientTrackingState { NotInitialized = 0, - Active = 1, - Broken = 2, // was active, now not + ActiveSingleConnectionPerItemTracking = 1, + ActiveSplitConnectionPerItemTracking = 2, + ActiveSingleConnectionBroadcast = 3, + ActiveSplitConnectionBroadcast = 4, + Broken = 10, // was active, now not } + /// + /// initializes client caching state and returns True if CLIENT CACHING YES should be sent + /// internal bool EnsureServerAssistedClientSideTrackingInsideWriteLock() => - _clientTrackingState == ClientTrackingState.Active // optimize for this - || InitializeServerAssistedClientSideTrackingInsideWriteLock(); + _clientTrackingState switch + { + ClientTrackingState.ActiveSingleConnectionPerItemTracking => true, + ClientTrackingState.ActiveSplitConnectionPerItemTracking => true, + // don't add CLIENT CACHING per-item when in broadcast mode + ClientTrackingState.ActiveSingleConnectionBroadcast => false, + ClientTrackingState.ActiveSplitConnectionBroadcast => false, + // anything else? slow mode + _ => InitializeServerAssistedClientSideTrackingInsideWriteLock() + }; private bool InitializeServerAssistedClientSideTrackingInsideWriteLock() { @@ -2097,13 +2111,17 @@ private bool InitializeServerAssistedClientSideTrackingInsideWriteLock() var config = bridge.Multiplexer.ClientSideTracking; if (config is not { IsAlive: true }) { - return false; // not enabled (should alrady have faulted, note), or: already dead + return false; // not enabled (should already have faulted, note), or: already dead } switch (_clientTrackingState) { - case ClientTrackingState.Active: + case ClientTrackingState.ActiveSingleConnectionPerItemTracking: + case ClientTrackingState.ActiveSplitConnectionPerItemTracking: return true; // we shouldn't be here, but: whatever + case ClientTrackingState.ActiveSingleConnectionBroadcast: + case ClientTrackingState.ActiveSplitConnectionBroadcast: + return false; // we shouldn't be here, but: whatever case ClientTrackingState.Broken: bridge.RawWriteInternalMessageInsideWriteLock(this, Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.CLIENT, RedisLiterals.TRACKING, RedisLiterals.OFF)); _clientTrackingState = ClientTrackingState.NotInitialized; @@ -2119,13 +2137,22 @@ private bool InitializeServerAssistedClientSideTrackingInsideWriteLock() // subscribe bridge.Multiplexer.ExecuteSyncImpl(ReusableSubscribeClientCachingSubscribeMessage, null, sep, 0); bridge.RawWriteInternalMessageInsideWriteLock(this, new ClientTrackingMessage(config, subId)); - _clientTrackingState = ClientTrackingState.Active; + _clientTrackingState = (config.Options & ClientTrackingOptions.Broadcast) == 0 ? ClientTrackingState.ActiveSplitConnectionPerItemTracking : ClientTrackingState.ActiveSplitConnectionBroadcast; return true; } } - return false; + return false; // unable to initialize; connections unavailable or similar default: - return false; + return false; // unknown state + } + } + + internal void OnSubscriberFailed() + { + // if in split connection mode, then: our notifications have failed and we need to reset + if (_clientTrackingState is ClientTrackingState.ActiveSplitConnectionPerItemTracking or ClientTrackingState.ActiveSplitConnectionBroadcast) + { + _clientTrackingState = ClientTrackingState.Broken; } } @@ -2147,7 +2174,7 @@ public override int ArgCount { get { - var count = 3; // TRACKING ON [OPTIN] + var count = 3; // TRACKING ON {OPTIN|BCAST} if (_subId is not null) { count += 2; // [REDIRECT client-id] @@ -2157,10 +2184,6 @@ public override int ArgCount count += _state.Prefixes.Length + 1; // [PREFIX prefix ...] } var options = _state.Options; - if ((options & ClientTrackingOptions.Broadcast) != 0) - { - count++; // [BCAST] - } if ((options & ClientTrackingOptions.NotifyForOwnCommands) == 0) { count++; // [NOLOOP] @@ -2188,11 +2211,14 @@ protected override void WriteImpl(PhysicalConnection physical) } } var options = _state.Options; - if ((options & ClientTrackingOptions.Broadcast) != 0) + if ((options & ClientTrackingOptions.Broadcast) == 0) + { + physical.WriteBulkString(RedisLiterals.OPTIN); + } + else { physical.WriteBulkString(RedisLiterals.BCAST); } - physical.WriteBulkString(RedisLiterals.OPTIN); if ((options & ClientTrackingOptions.NotifyForOwnCommands) == 0) { physical.WriteBulkString(RedisLiterals.NOLOOP); diff --git a/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt b/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt index b165bb107..0e2cf28f5 100644 --- a/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt +++ b/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt @@ -1,5 +1,6 @@ StackExchange.Redis.ClientTrackingOptions StackExchange.Redis.ClientTrackingOptions.Broadcast = 1 -> StackExchange.Redis.ClientTrackingOptions +StackExchange.Redis.ClientTrackingOptions.ConcurrentInvalidation = 4 -> StackExchange.Redis.ClientTrackingOptions StackExchange.Redis.ClientTrackingOptions.None = 0 -> StackExchange.Redis.ClientTrackingOptions StackExchange.Redis.ClientTrackingOptions.NotifyForOwnCommands = 2 -> StackExchange.Redis.ClientTrackingOptions StackExchange.Redis.CommandFlags.ClientCaching = 1024 -> StackExchange.Redis.CommandFlags diff --git a/src/StackExchange.Redis/ServerEndPoint.cs b/src/StackExchange.Redis/ServerEndPoint.cs index a90023580..0791d5ab6 100644 --- a/src/StackExchange.Redis/ServerEndPoint.cs +++ b/src/StackExchange.Redis/ServerEndPoint.cs @@ -1032,5 +1032,7 @@ internal bool HasPendingCallerFacingItems() if (interactive?.HasPendingCallerFacingItems() == true) return true; return subscription?.HasPendingCallerFacingItems() ?? false; } + + internal void OnSubscriberFailed() => interactive?.OnSubscriberFailed(); } } diff --git a/tests/StackExchange.Redis.Tests/ClientTrackingTests.cs b/tests/StackExchange.Redis.Tests/ClientTrackingTests.cs index 561109daf..7fc806587 100644 --- a/tests/StackExchange.Redis.Tests/ClientTrackingTests.cs +++ b/tests/StackExchange.Redis.Tests/ClientTrackingTests.cs @@ -34,14 +34,40 @@ public void CallEnableTwice() Assert.Equal("The EnableServerAssistedClientSideTracking method can be invoked once-only per multiplexer instance", ex.Message); } + [Fact] + public void UsePrefixesWithoutBroadcast() + { + using var conn = Create(shared: false); + var ex = Assert.Throws(() => conn.EnableServerAssistedClientSideTracking(key => default, prefixes: new RedisKey[] { "abc" })); + Assert.StartsWith("Prefixes can only be specified when ClientTrackingOptions.Broadcast is used", ex.Message); + Assert.Equal("prefixes", ex.ParamName); + } + + [Theory] + [InlineData(ClientTrackingOptions.None)] + [InlineData(ClientTrackingOptions.Broadcast)] + [InlineData(ClientTrackingOptions.NotifyForOwnCommands)] + [InlineData(ClientTrackingOptions.Broadcast | ClientTrackingOptions.NotifyForOwnCommands)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.Broadcast)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.NotifyForOwnCommands)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.Broadcast | ClientTrackingOptions.NotifyForOwnCommands)] + public Task GetNotificationFromOwnConnection(ClientTrackingOptions options) => GetNotification(options, false); + [Theory] - [InlineData(true, true)] - [InlineData(true, false)] - [InlineData(false, false)] - [InlineData(false, true)] - public async void GetNotification(bool listenToSelf, bool externalConnectionMakesChange) + [InlineData(ClientTrackingOptions.None)] + [InlineData(ClientTrackingOptions.Broadcast)] + [InlineData(ClientTrackingOptions.NotifyForOwnCommands)] + [InlineData(ClientTrackingOptions.Broadcast | ClientTrackingOptions.NotifyForOwnCommands)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.Broadcast)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.NotifyForOwnCommands)] + [InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.Broadcast | ClientTrackingOptions.NotifyForOwnCommands)] + public Task GetNotificationFromExternalConnection(ClientTrackingOptions options) => GetNotification(options, true); + + private async Task GetNotification(ClientTrackingOptions options, bool externalConnectionMakesChange) { - bool expectNotification = listenToSelf || externalConnectionMakesChange; + bool expectNotification = ((options & ClientTrackingOptions.NotifyForOwnCommands) != 0) || externalConnectionMakesChange; using var listen = Create(shared: false); using var send = externalConnectionMakesChange ? Create() : listen; @@ -53,7 +79,6 @@ public async void GetNotification(bool listenToSelf, bool externalConnectionMake db.KeyDelete(key); db.StringSet(key, value); - var options = listenToSelf ? ClientTrackingOptions.NotifyForOwnCommands : ClientTrackingOptions.None; listen.EnableServerAssistedClientSideTracking(rkey => { if (rkey == key) Interlocked.Increment(ref notifyCount); From 3cacf8fa0bb294032cc1a3c71050e50cccc793cd Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Tue, 29 Aug 2023 13:22:13 +0100 Subject: [PATCH 3/3] deal with concurrency on OnSubscriberFailed() --- src/StackExchange.Redis/PhysicalConnection.cs | 77 +++++++++++-------- 1 file changed, 44 insertions(+), 33 deletions(-) diff --git a/src/StackExchange.Redis/PhysicalConnection.cs b/src/StackExchange.Redis/PhysicalConnection.cs index eb66f0bc7..79926ada1 100644 --- a/src/StackExchange.Redis/PhysicalConnection.cs +++ b/src/StackExchange.Redis/PhysicalConnection.cs @@ -2074,7 +2074,7 @@ internal bool HasPendingCallerFacingItems() } } - private volatile ClientTrackingState _clientTrackingState = ClientTrackingState.NotInitialized; + private int _clientTrackingState = (int)ClientTrackingState.NotInitialized; private enum ClientTrackingState { NotInitialized = 0, @@ -2085,11 +2085,13 @@ private enum ClientTrackingState Broken = 10, // was active, now not } + private ClientTrackingState GetClientTrackingState() => (ClientTrackingState)Volatile.Read(ref _clientTrackingState); + /// /// initializes client caching state and returns True if CLIENT CACHING YES should be sent /// internal bool EnsureServerAssistedClientSideTrackingInsideWriteLock() => - _clientTrackingState switch + GetClientTrackingState() switch { ClientTrackingState.ActiveSingleConnectionPerItemTracking => true, ClientTrackingState.ActiveSplitConnectionPerItemTracking => true, @@ -2114,45 +2116,54 @@ private bool InitializeServerAssistedClientSideTrackingInsideWriteLock() return false; // not enabled (should already have faulted, note), or: already dead } - switch (_clientTrackingState) - { - case ClientTrackingState.ActiveSingleConnectionPerItemTracking: - case ClientTrackingState.ActiveSplitConnectionPerItemTracking: - return true; // we shouldn't be here, but: whatever - case ClientTrackingState.ActiveSingleConnectionBroadcast: - case ClientTrackingState.ActiveSplitConnectionBroadcast: - return false; // we shouldn't be here, but: whatever - case ClientTrackingState.Broken: - bridge.RawWriteInternalMessageInsideWriteLock(this, Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.CLIENT, RedisLiterals.TRACKING, RedisLiterals.OFF)); - _clientTrackingState = ClientTrackingState.NotInitialized; - goto case ClientTrackingState.NotInitialized; - case ClientTrackingState.NotInitialized: - // note: this check will need to be removed in RESP3 - if (BridgeCouldBeNull?.ServerEndPoint is { SupportsSubscriptions: true } sep - && sep.GetBridge(ConnectionType.Subscription) is { IsConnected: true } sub) - { - var subId = sub.ConnectionId; - if (subId is not null) + ClientTrackingState oldState, newState; + do + { + switch (oldState = GetClientTrackingState()) + { + case ClientTrackingState.ActiveSingleConnectionPerItemTracking: + case ClientTrackingState.ActiveSplitConnectionPerItemTracking: + return true; // we shouldn't be here, but: whatever + case ClientTrackingState.ActiveSingleConnectionBroadcast: + case ClientTrackingState.ActiveSplitConnectionBroadcast: + return false; // we shouldn't be here, but: whatever + case ClientTrackingState.Broken: + bridge.RawWriteInternalMessageInsideWriteLock(this, Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.CLIENT, RedisLiterals.TRACKING, RedisLiterals.OFF)); + oldState = ClientTrackingState.NotInitialized; // ack that we've reset things + Volatile.Write(ref _clientTrackingState, (int)oldState); + goto case ClientTrackingState.NotInitialized; + case ClientTrackingState.NotInitialized: + // note: this check will need to be removed in RESP3 + if (BridgeCouldBeNull?.ServerEndPoint is { SupportsSubscriptions: true } sep + && sep.GetBridge(ConnectionType.Subscription) is { IsConnected: true } sub) { - // subscribe - bridge.Multiplexer.ExecuteSyncImpl(ReusableSubscribeClientCachingSubscribeMessage, null, sep, 0); - bridge.RawWriteInternalMessageInsideWriteLock(this, new ClientTrackingMessage(config, subId)); - _clientTrackingState = (config.Options & ClientTrackingOptions.Broadcast) == 0 ? ClientTrackingState.ActiveSplitConnectionPerItemTracking : ClientTrackingState.ActiveSplitConnectionBroadcast; - return true; + var subId = sub.ConnectionId; + if (subId is not null) + { + // subscribe + bridge.Multiplexer.ExecuteSyncImpl(ReusableSubscribeClientCachingSubscribeMessage, null, sep, 0); + bridge.RawWriteInternalMessageInsideWriteLock(this, new ClientTrackingMessage(config, subId)); + newState = (config.Options & ClientTrackingOptions.Broadcast) == 0 ? ClientTrackingState.ActiveSplitConnectionPerItemTracking : ClientTrackingState.ActiveSplitConnectionBroadcast; + break; + } } - } - return false; // unable to initialize; connections unavailable or similar - default: - return false; // unknown state - } + return false; // unable to initialize; connections unavailable or similar + default: + return false; // unknown state + } + } while (Interlocked.CompareExchange(ref _clientTrackingState, (int)newState, (int)oldState) != (int)oldState); // redo from start if fighting with OnSubscriberFailed, which is only for the "Broken" scenario + + // we're now in a known state; we only issue CLIENT CACHING YES if we're in per-item tracking mode + return newState is ClientTrackingState.ActiveSingleConnectionPerItemTracking or ClientTrackingState.ActiveSplitConnectionPerItemTracking; + } internal void OnSubscriberFailed() { // if in split connection mode, then: our notifications have failed and we need to reset - if (_clientTrackingState is ClientTrackingState.ActiveSplitConnectionPerItemTracking or ClientTrackingState.ActiveSplitConnectionBroadcast) + if (GetClientTrackingState() is ClientTrackingState.ActiveSplitConnectionPerItemTracking or ClientTrackingState.ActiveSplitConnectionBroadcast) { - _clientTrackingState = ClientTrackingState.Broken; + Volatile.Write(ref _clientTrackingState, (int)ClientTrackingState.Broken); } }