Skip to content

first stab at raw API for redis notifications #2527

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions src/StackExchange.Redis/ConnectionMultiplexer.ClientSideTracking.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
using Microsoft.Extensions.Logging;
using System;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

namespace StackExchange.Redis;

public partial class ConnectionMultiplexer
{
/// <summary>
/// Enable the <a href="https://redis.io/commands/client-tracking/">client tracking</a> feature of redis
/// </summary>
/// <remarks>see also https://redis.io/docs/manual/client-side-caching/</remarks>
/// <param name="keyInvalidated">The callback to be invoked when keys are determined to be invalidated</param>
/// <param name="options">Additional flags to influence the behavior of client tracking</param>
/// <param name="prefixes">Optionally restricts client-side caching notifications for these connections to a subset of key prefixes; this has performance implications (see the PREFIX option in CLIENT TRACKING)</param>
public void EnableServerAssistedClientSideTracking(Func<RedisKey, ValueTask> keyInvalidated, ClientTrackingOptions options = ClientTrackingOptions.None, ReadOnlyMemory<RedisKey> prefixes = default)
{
if (_clientSideTracking is not null) ThrowOnceOnly();
if (!prefixes.IsEmpty && (options & ClientTrackingOptions.Broadcast) == 0) ThrowPrefixNeedsBroadcast();
var obj = new ClientSideTrackingState(this, keyInvalidated, options, prefixes);
if (Interlocked.CompareExchange(ref _clientSideTracking, obj, null) is not null) ThrowOnceOnly();

static void ThrowOnceOnly() => throw new InvalidOperationException("The " + nameof(EnableServerAssistedClientSideTracking) + " method can be invoked once-only per multiplexer instance");
static void ThrowPrefixNeedsBroadcast() => throw new ArgumentException("Prefixes can only be specified when " + nameof(ClientTrackingOptions) + "." + nameof(ClientTrackingOptions.Broadcast) + " is used", nameof(prefixes));
}

private ClientSideTrackingState? _clientSideTracking;
internal ClientSideTrackingState? ClientSideTracking => _clientSideTracking;
internal sealed class ClientSideTrackingState
{
public bool IsAlive { get; private set; }
private readonly Func<RedisKey, ValueTask> _keyInvalidated;
public ClientTrackingOptions Options { get; }
public ReadOnlyMemory<RedisKey> Prefixes { get; }

private readonly Channel<RedisKey> _notifications;
private readonly WeakReference<ConnectionMultiplexer> _multiplexer;
#if NETCOREAPP3_1_OR_GREATER
private readonly Action<RedisKey>? _concurrentCallback;
#else
private readonly WaitCallback? _concurrentCallback;
#endif

public ClientSideTrackingState(ConnectionMultiplexer multiplexer, Func<RedisKey, ValueTask> keyInvalidated, ClientTrackingOptions options, ReadOnlyMemory<RedisKey> prefixes)
{
_keyInvalidated = keyInvalidated;
Options = options;
Prefixes = prefixes;
_notifications = Channel.CreateUnbounded<RedisKey>(ChannelOptions);
_ = Task.Run(RunAsync);
IsAlive = true;
_multiplexer = new(multiplexer);

if ((options & ClientTrackingOptions.ConcurrentInvalidation) != 0)
{
_concurrentCallback = OnInvalidate;
}
}

#if !NETCOREAPP3_1_OR_GREATER
private void OnInvalidate(object state) => OnInvalidate((RedisKey)state);
#endif

private void OnInvalidate(RedisKey key)
{
try // not optimized for sync completions
{
var pending = _keyInvalidated(key);
if (pending.IsCompleted)
{ // observe result
pending.GetAwaiter().GetResult();
}
else
{
_ = ObserveAsyncInvalidation(pending);
}
}
catch (Exception ex) // handle sync failure (via immediate throw or faulted ValueTask)
{
OnCallbackError(ex);
}
}

private async Task ObserveAsyncInvalidation(ValueTask pending)
{
try
{
await pending.ConfigureAwait(false);
}
catch (Exception ex)
{
OnCallbackError(ex);
}
}

private ConnectionMultiplexer? Multiplexer => _multiplexer.TryGetTarget(out var multiplexer) ? multiplexer : null;


private void OnCallbackError(Exception error) => Multiplexer?.Logger?.LogError(error, "Client-side tracking invalidation callback failure");

private async Task RunAsync()
{
while (await _notifications.Reader.WaitToReadAsync().ConfigureAwait(false))
{
while (_notifications.Reader.TryRead(out var key))
{
if (_concurrentCallback is not null)
{
#if NETCOREAPP3_1_OR_GREATER
ThreadPool.QueueUserWorkItem(_concurrentCallback, key, preferLocal: false);
#else
// eat the box
ThreadPool.QueueUserWorkItem(_concurrentCallback, key);
#endif
}
else
{
try
{
await _keyInvalidated(key).ConfigureAwait(false);
}
catch (Exception ex)
{
OnCallbackError(ex);
}
}
}
}
}

public void Write(RedisKey key) => _notifications.Writer.TryWrite(key);

public void Shutdown()
{
IsAlive = false;
_notifications.Writer.TryComplete(null);
}

private static readonly UnboundedChannelOptions ChannelOptions = new UnboundedChannelOptions { SingleReader = true, SingleWriter = false, AllowSynchronousContinuations = true };


}
}

/// <summary>
/// Additional flags to influence the behavior of client tracking
/// </summary>
[Flags]
public enum ClientTrackingOptions
{
/// <summary>
/// No additional options
/// </summary>
None = 0,
/// <summary>
/// Enable tracking in broadcasting mode. In this mode invalidation messages are reported for all the prefixes specified, regardless of the keys requested by the connection. Instead when the broadcasting mode is not enabled, Redis will track which keys are fetched using read-only commands, and will report invalidation messages only for such keys.
/// </summary>
/// <remarks>This corresponds to CLIENT TRACKING ... BCAST; using <see cref="Broadcast"/> mode consumes less server memory, at the cost of more invalidation messages (i.e. clients are
/// likely to receive invalidation messages for keys that the individual client is not using); this can be partially mitigated by using prefixes</remarks>
Broadcast = 1 << 0,
/// <summary>
/// Send notifications about keys modified by this connection itself.
/// </summary>
/// <remarks>This corresponds to the <b>inverse</b> of CLIENT TRACKING ... NOLOOP; setting <see cref="NotifyForOwnCommands"/> means that your own writes will cause self-notification; this
/// may mean that you discard a locally updated copy of the new value, hence this is disabled by default</remarks>
NotifyForOwnCommands = 1 << 1,

/// <summary>
/// Indicates that the callback specified for key invalidation should be invoked concurrently rather than sequentially
/// </summary>
ConcurrentInvalidation = 1 << 2,

// to think about: OPTIN / OPTOUT ? I'm happy to implement on the basis of OPTIN for now, though
}
6 changes: 6 additions & 0 deletions src/StackExchange.Redis/ConnectionMultiplexer.Events.cs
Original file line number Diff line number Diff line change
@@ -14,6 +14,12 @@ public partial class ConnectionMultiplexer
internal void OnConnectionFailed(EndPoint endpoint, ConnectionType connectionType, ConnectionFailureType failureType, Exception exception, bool reconfigure, string? physicalName)
{
if (_isDisposed) return;

if (connectionType is ConnectionType.Subscription)
{
GetServerEndPoint(endpoint, activate: false)?.OnSubscriberFailed();
}

var handler = ConnectionFailed;
if (handler != null)
{
16 changes: 11 additions & 5 deletions src/StackExchange.Redis/ConnectionMultiplexer.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using System;
using Microsoft.Extensions.Logging;
using Pipelines.Sockets.Unofficial;
using StackExchange.Redis.Profiling;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
@@ -10,9 +13,6 @@
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Pipelines.Sockets.Unofficial;
using StackExchange.Redis.Profiling;

namespace StackExchange.Redis
{
@@ -355,6 +355,11 @@ internal void CheckMessage(Message message)
{
throw ExceptionFactory.TooManyArgs(message.CommandAndKey, message.ArgCount);
}

if (message.IsClientCaching && ClientSideTracking is null)
{
throw new InvalidOperationException("The " + nameof(CommandFlags.ClientCaching) + " flag can only be used if " + nameof(EnableServerAssistedClientSideTracking) + " has been called");
}
}

internal bool TryResend(int hashSlot, Message message, EndPoint endpoint, bool isMoved)
@@ -2268,7 +2273,7 @@ public async ValueTask DisposeAsync()
public void Close(bool allowCommandsToComplete = true)
{
if (_isDisposed) return;

_clientSideTracking?.Shutdown();
OnClosing(false);
_isDisposed = true;
_profilingSessionProvider = null;
@@ -2295,6 +2300,7 @@ public void Close(bool allowCommandsToComplete = true)
public async Task CloseAsync(bool allowCommandsToComplete = true)
{
_isDisposed = true;
_clientSideTracking?.Shutdown();
using (var tmp = pulse)
{
pulse = null;
5 changes: 4 additions & 1 deletion src/StackExchange.Redis/Enums/CommandFlags.cs
Original file line number Diff line number Diff line change
@@ -81,7 +81,10 @@ public enum CommandFlags
/// </summary>
NoScriptCache = 512,

// 1024: Removed - was used for async timeout checks; never user-specified, so not visible on the public API
/// <summary>
/// Indicates a command that relates to server-assisted client-side caching; this corresponds to CLIENT CACHING YES being issues before the command
/// </summary>
ClientCaching = 1024,

// 2048: Use subscription connection type; never user-specified, so not visible on the public API
}
3 changes: 3 additions & 0 deletions src/StackExchange.Redis/Interfaces/IConnectionMultiplexer.cs
Original file line number Diff line number Diff line change
@@ -281,5 +281,8 @@ public interface IConnectionMultiplexer : IDisposable, IAsyncDisposable
/// <param name="destination">The destination stream to write the export to.</param>
/// <param name="options">The options to use for this export.</param>
void ExportConfiguration(Stream destination, ExportOptions options = ExportOptions.All);

/// <inheritdoc cref="ConnectionMultiplexer.EnableServerAssistedClientSideTracking(Func{RedisKey, ValueTask}, ClientTrackingOptions, ReadOnlyMemory{RedisKey})"/>
void EnableServerAssistedClientSideTracking(Func<RedisKey, ValueTask> keyInvalidated, ClientTrackingOptions options = ClientTrackingOptions.None, ReadOnlyMemory<RedisKey> prefixes = default);
}
}
4 changes: 3 additions & 1 deletion src/StackExchange.Redis/Message.cs
Original file line number Diff line number Diff line change
@@ -75,7 +75,8 @@ internal abstract class Message : ICompletable
#pragma warning restore CS0618
| CommandFlags.FireAndForget
| CommandFlags.NoRedirect
| CommandFlags.NoScriptCache;
| CommandFlags.NoScriptCache
| CommandFlags.ClientCaching;
private IResultBox? resultBox;

private ResultProcessor? resultProcessor;
@@ -197,6 +198,7 @@ public bool IsAdmin
}

public bool IsAsking => (Flags & AskingFlag) != 0;
public bool IsClientCaching => (Flags & CommandFlags.ClientCaching) != 0;

internal bool IsScriptUnavailable => (Flags & ScriptUnavailableFlag) != 0;

24 changes: 18 additions & 6 deletions src/StackExchange.Redis/PhysicalBridge.cs
Original file line number Diff line number Diff line change
@@ -23,7 +23,9 @@ internal sealed class PhysicalBridge : IDisposable

private const double ProfileLogSeconds = (1000 /* ms */ * ProfileLogSamples) / 1000.0;

private static readonly Message ReusableAskingCommand = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.ASKING);
private static readonly Message
ReusableAskingCommand = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.ASKING),
ReusableClientCachingYesCommand = Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.CLIENT, RedisLiterals.CACHING, RedisLiterals.yes);

private readonly long[] profileLog = new long[ProfileLogSamples];

@@ -1494,13 +1496,13 @@ private WriteResult WriteMessageToServerInsideWriteLock(PhysicalConnection conne
}
if (message.IsAsking)
{
var asking = ReusableAskingCommand;
connection.EnqueueInsideWriteLock(asking);
asking.WriteTo(connection);
asking.SetRequestSent();
IncrementOpCount();
RawWriteInternalMessageInsideWriteLock(connection, ReusableAskingCommand);
}
}
if (message.IsClientCaching && connection.EnsureServerAssistedClientSideTrackingInsideWriteLock())
{
RawWriteInternalMessageInsideWriteLock(connection, ReusableClientCachingYesCommand);
}
switch (cmd)
{
case RedisCommand.WATCH:
@@ -1570,6 +1572,15 @@ private WriteResult WriteMessageToServerInsideWriteLock(PhysicalConnection conne
}
}

internal void RawWriteInternalMessageInsideWriteLock(PhysicalConnection connection, Message message)
{
message.SetInternalCall();
connection.EnqueueInsideWriteLock(message);
message.WriteTo(connection);
message.SetRequestSent();
IncrementOpCount();
}

/// <summary>
/// For testing only
/// </summary>
@@ -1583,5 +1594,6 @@ internal void SimulateConnectionFailure(SimulatedFailureType failureType)
}

internal RedisCommand? GetActiveMessage() => Volatile.Read(ref _activeMessage)?.Command;
internal void OnSubscriberFailed() => physical?.OnSubscriberFailed();
}
}
177 changes: 171 additions & 6 deletions src/StackExchange.Redis/PhysicalConnection.cs
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using static StackExchange.Redis.Message;
using System.Threading.Channels;

namespace StackExchange.Redis
{
@@ -1599,10 +1600,10 @@ private void MatchResult(in RawResult result)
Trace("MESSAGE: " + channel);
if (!channel.IsNull)
{
if (TryGetPubSubPayload(items[2], out var payload))
if (TryGetPubSubPayload(items[2], out var payload, out var source))
{
_readStatus = ReadStatus.InvokePubSub;
muxer.OnMessage(channel, channel, payload);
muxer.OnMessage(channel, channel, payload, source);
}
// could be multi-message: https://github.com/StackExchange/StackExchange.Redis/issues/2507
else if (TryGetMultiPubSubPayload(items[2], out var payloads))
@@ -1621,11 +1622,11 @@ private void MatchResult(in RawResult result)
Trace("PMESSAGE: " + channel);
if (!channel.IsNull)
{
if (TryGetPubSubPayload(items[3], out var payload))
if (TryGetPubSubPayload(items[3], out var payload, out var source))
{
var sub = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.PatternMode.Pattern);
_readStatus = ReadStatus.InvokePubSub;
muxer.OnMessage(sub, channel, payload);
muxer.OnMessage(sub, channel, payload, source);
}
else if (TryGetMultiPubSubPayload(items[3], out var payloads))
{
@@ -1661,8 +1662,9 @@ private void MatchResult(in RawResult result)
_readStatus = ReadStatus.MatchResultComplete;
_activeMessage = null;

static bool TryGetPubSubPayload(in RawResult value, out RedisValue parsed, bool allowArraySingleton = true)
static bool TryGetPubSubPayload(in RawResult value, out RedisValue parsed, out RawResult source, bool allowArraySingleton = true)
{
source = value;
if (value.IsNull)
{
parsed = RedisValue.Null;
@@ -1676,7 +1678,7 @@ static bool TryGetPubSubPayload(in RawResult value, out RedisValue parsed, bool
parsed = value.AsRedisValue();
return true;
case ResultType.MultiBulk when allowArraySingleton && value.ItemsCount == 1:
return TryGetPubSubPayload(in value[0], out parsed, allowArraySingleton: false);
return TryGetPubSubPayload(in value[0], out parsed, out source, allowArraySingleton: false);
}
parsed = default;
return false;
@@ -2071,5 +2073,168 @@ internal bool HasPendingCallerFacingItems()
if (lockTaken) Monitor.Exit(_writtenAwaitingResponse);
}
}

private int _clientTrackingState = (int)ClientTrackingState.NotInitialized;
private enum ClientTrackingState
{
NotInitialized = 0,
ActiveSingleConnectionPerItemTracking = 1,
ActiveSplitConnectionPerItemTracking = 2,
ActiveSingleConnectionBroadcast = 3,
ActiveSplitConnectionBroadcast = 4,
Broken = 10, // was active, now not
}

private ClientTrackingState GetClientTrackingState() => (ClientTrackingState)Volatile.Read(ref _clientTrackingState);

/// <summary>
/// initializes client caching state and returns True if CLIENT CACHING YES should be sent
/// </summary>
internal bool EnsureServerAssistedClientSideTrackingInsideWriteLock() =>
GetClientTrackingState() switch
{
ClientTrackingState.ActiveSingleConnectionPerItemTracking => true,
ClientTrackingState.ActiveSplitConnectionPerItemTracking => true,
// don't add CLIENT CACHING per-item when in broadcast mode
ClientTrackingState.ActiveSingleConnectionBroadcast => false,
ClientTrackingState.ActiveSplitConnectionBroadcast => false,
// anything else? slow mode
_ => InitializeServerAssistedClientSideTrackingInsideWriteLock()
};

private bool InitializeServerAssistedClientSideTrackingInsideWriteLock()
{
var bridge = BridgeCouldBeNull;
if (bridge is null)
{
return false; // shutting down, be gentle in our nope
}

var config = bridge.Multiplexer.ClientSideTracking;
if (config is not { IsAlive: true })
{
return false; // not enabled (should already have faulted, note), or: already dead
}

ClientTrackingState oldState, newState;
do
{
switch (oldState = GetClientTrackingState())
{
case ClientTrackingState.ActiveSingleConnectionPerItemTracking:
case ClientTrackingState.ActiveSplitConnectionPerItemTracking:
return true; // we shouldn't be here, but: whatever
case ClientTrackingState.ActiveSingleConnectionBroadcast:
case ClientTrackingState.ActiveSplitConnectionBroadcast:
return false; // we shouldn't be here, but: whatever
case ClientTrackingState.Broken:
bridge.RawWriteInternalMessageInsideWriteLock(this, Message.Create(-1, CommandFlags.FireAndForget, RedisCommand.CLIENT, RedisLiterals.TRACKING, RedisLiterals.OFF));
oldState = ClientTrackingState.NotInitialized; // ack that we've reset things
Volatile.Write(ref _clientTrackingState, (int)oldState);
goto case ClientTrackingState.NotInitialized;
case ClientTrackingState.NotInitialized:
// note: this check will need to be removed in RESP3
if (BridgeCouldBeNull?.ServerEndPoint is { SupportsSubscriptions: true } sep
&& sep.GetBridge(ConnectionType.Subscription) is { IsConnected: true } sub)
{
var subId = sub.ConnectionId;
if (subId is not null)
{
// subscribe
bridge.Multiplexer.ExecuteSyncImpl<int>(ReusableSubscribeClientCachingSubscribeMessage, null, sep, 0);
bridge.RawWriteInternalMessageInsideWriteLock(this, new ClientTrackingMessage(config, subId));
newState = (config.Options & ClientTrackingOptions.Broadcast) == 0 ? ClientTrackingState.ActiveSplitConnectionPerItemTracking : ClientTrackingState.ActiveSplitConnectionBroadcast;
break;
}
}
return false; // unable to initialize; connections unavailable or similar
default:
return false; // unknown state
}
} while (Interlocked.CompareExchange(ref _clientTrackingState, (int)newState, (int)oldState) != (int)oldState); // redo from start if fighting with OnSubscriberFailed, which is only for the "Broken" scenario

// we're now in a known state; we only issue CLIENT CACHING YES if we're in per-item tracking mode
return newState is ClientTrackingState.ActiveSingleConnectionPerItemTracking or ClientTrackingState.ActiveSplitConnectionPerItemTracking;

}

internal void OnSubscriberFailed()
{
// if in split connection mode, then: our notifications have failed and we need to reset
if (GetClientTrackingState() is ClientTrackingState.ActiveSplitConnectionPerItemTracking or ClientTrackingState.ActiveSplitConnectionBroadcast)
{
Volatile.Write(ref _clientTrackingState, (int)ClientTrackingState.Broken);
}
}

private static readonly Message ReusableSubscribeClientCachingSubscribeMessage = Message.Create(
-1, CommandFlags.FireAndForget, RedisCommand.SUBSCRIBE, ConnectionMultiplexer.ClientCachingChannel);

private sealed class ClientTrackingMessage : Message
{
private readonly ConnectionMultiplexer.ClientSideTrackingState _state;
private readonly long? _subId; // will be NULL in RESP3

public ClientTrackingMessage(ConnectionMultiplexer.ClientSideTrackingState state, long? subId) : base(-1, CommandFlags.FireAndForget, RedisCommand.CLIENT)
{
_state = state;
_subId = subId;
}

public override int ArgCount
{
get
{
var count = 3; // TRACKING ON {OPTIN|BCAST}
if (_subId is not null)
{
count += 2; // [REDIRECT client-id]
}
if (!_state.Prefixes.IsEmpty)
{
count += _state.Prefixes.Length + 1; // [PREFIX prefix ...]
}
var options = _state.Options;
if ((options & ClientTrackingOptions.NotifyForOwnCommands) == 0)
{
count++; // [NOLOOP]
}
return count;
}
}

protected override void WriteImpl(PhysicalConnection physical)
{
physical.WriteHeader(Command, ArgCount);
physical.WriteBulkString(RedisLiterals.TRACKING);
physical.WriteBulkString(RedisLiterals.ON);
if (_subId is not null)
{
physical.WriteBulkString(RedisLiterals.REDIRECT);
physical.WriteBulkString(_subId.GetValueOrDefault());
}
if (!_state.Prefixes.IsEmpty)
{
physical.WriteBulkString(RedisLiterals.PREFIX);
foreach (ref readonly RedisKey prefix in _state.Prefixes.Span)
{
physical.Write(in prefix);
}
}
var options = _state.Options;
if ((options & ClientTrackingOptions.Broadcast) == 0)
{
physical.WriteBulkString(RedisLiterals.OPTIN);
}
else
{
physical.WriteBulkString(RedisLiterals.BCAST);
}
if ((options & ClientTrackingOptions.NotifyForOwnCommands) == 0)
{
physical.WriteBulkString(RedisLiterals.NOLOOP);
}
}
}
}
}
9 changes: 8 additions & 1 deletion src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
@@ -1 +1,8 @@

StackExchange.Redis.ClientTrackingOptions
StackExchange.Redis.ClientTrackingOptions.Broadcast = 1 -> StackExchange.Redis.ClientTrackingOptions
StackExchange.Redis.ClientTrackingOptions.ConcurrentInvalidation = 4 -> StackExchange.Redis.ClientTrackingOptions
StackExchange.Redis.ClientTrackingOptions.None = 0 -> StackExchange.Redis.ClientTrackingOptions
StackExchange.Redis.ClientTrackingOptions.NotifyForOwnCommands = 2 -> StackExchange.Redis.ClientTrackingOptions
StackExchange.Redis.CommandFlags.ClientCaching = 1024 -> StackExchange.Redis.CommandFlags
StackExchange.Redis.ConnectionMultiplexer.EnableServerAssistedClientSideTracking(System.Func<StackExchange.Redis.RedisKey, System.Threading.Tasks.ValueTask>! keyInvalidated, StackExchange.Redis.ClientTrackingOptions options = StackExchange.Redis.ClientTrackingOptions.None, System.ReadOnlyMemory<StackExchange.Redis.RedisKey> prefixes = default(System.ReadOnlyMemory<StackExchange.Redis.RedisKey>)) -> void
StackExchange.Redis.IConnectionMultiplexer.EnableServerAssistedClientSideTracking(System.Func<StackExchange.Redis.RedisKey, System.Threading.Tasks.ValueTask>! keyInvalidated, StackExchange.Redis.ClientTrackingOptions options = StackExchange.Redis.ClientTrackingOptions.None, System.ReadOnlyMemory<StackExchange.Redis.RedisKey> prefixes = default(System.ReadOnlyMemory<StackExchange.Redis.RedisKey>)) -> void
9 changes: 9 additions & 0 deletions src/StackExchange.Redis/RedisLiterals.cs
Original file line number Diff line number Diff line change
@@ -50,12 +50,14 @@ public static readonly RedisValue
AND = "AND",
ANY = "ANY",
ASC = "ASC",
BCAST = "BCAST",
BEFORE = "BEFORE",
BIT = "BIT",
BY = "BY",
BYLEX = "BYLEX",
BYSCORE = "BYSCORE",
BYTE = "BYTE",
CACHING = "CACHING",
CH = "CH",
CHANNELS = "CHANNELS",
COPY = "COPY",
@@ -97,21 +99,27 @@ public static readonly RedisValue
MINMATCHLEN = "MINMATCHLEN",
MODULE = "MODULE",
NODES = "NODES",
NOLOOP = "NOLOOP",
NOSAVE = "NOSAVE",
NOT = "NOT",
NUMPAT = "NUMPAT",
NUMSUB = "NUMSUB",
NX = "NX",
OBJECT = "OBJECT",
OFF = "OFF",
OPTIN = "OPTIN",
OR = "OR",
ON = "ON",
PATTERN = "PATTERN",
PAUSE = "PAUSE",
PERSIST = "PERSIST",
PING = "PING",
PREFIX = "PREFIX",
PURGE = "PURGE",
PX = "PX",
PXAT = "PXAT",
RANK = "RANK",
REDIRECT = "REDIRECT",
REFCOUNT = "REFCOUNT",
REPLACE = "REPLACE",
RESET = "RESET",
@@ -127,6 +135,7 @@ public static readonly RedisValue
SKIPME = "SKIPME",
STATS = "STATS",
STORE = "STORE",
TRACKING = "TRACKING",
TYPE = "TYPE",
WEIGHTS = "WEIGHTS",
WITHMATCHLEN = "WITHMATCHLEN",
16 changes: 11 additions & 5 deletions src/StackExchange.Redis/RedisSubscriber.cs
Original file line number Diff line number Diff line change
@@ -75,7 +75,7 @@ internal bool GetSubscriberCounts(in RedisChannel channel, out int handlers, out
/// <summary>
/// Handler that executes whenever a message comes in, this doles out messages to any registered handlers.
/// </summary>
internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, in RedisValue payload)
internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, in RedisValue payload, in RawResult rawPayload)
{
ICompletable? completable = null;
ChannelMessageQueue? queues = null;
@@ -91,22 +91,28 @@ internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, i
{
CompleteAsWorker(completable);
}
if (subscription == ClientCachingChannel)
{
_clientSideTracking?.Write(rawPayload.AsRedisKey());
}
}

internal static readonly RedisChannel ClientCachingChannel = RedisChannel.Literal("__redis__:invalidate");

internal void OnMessage(in RedisChannel subscription, in RedisChannel channel, Sequence<RawResult> payload)
{
if (payload.IsSingleSegment)
{
foreach (var message in payload.FirstSpan)
foreach (ref readonly RawResult message in payload.FirstSpan)
{
OnMessage(subscription, channel, message.AsRedisValue());
OnMessage(subscription, channel, message.AsRedisValue(), in message);
}
}
else
{
foreach (var message in payload)
foreach (ref readonly RawResult message in payload)
{
OnMessage(subscription, channel, message.AsRedisValue());
OnMessage(subscription, channel, message.AsRedisValue(), in message);
}
}
}
2 changes: 2 additions & 0 deletions src/StackExchange.Redis/ServerEndPoint.cs
Original file line number Diff line number Diff line change
@@ -1032,5 +1032,7 @@ internal bool HasPendingCallerFacingItems()
if (interactive?.HasPendingCallerFacingItems() == true) return true;
return subscription?.HasPendingCallerFacingItems() ?? false;
}

internal void OnSubscriberFailed() => interactive?.OnSubscriberFailed();
}
}
98 changes: 98 additions & 0 deletions tests/StackExchange.Redis.Tests/ClientTrackingTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;

namespace StackExchange.Redis.Tests;

/// <summary>
/// Tests for <see href="https://redis.io/commands/client-tracking/"/>.
/// </summary>
[Collection(SharedConnectionFixture.Key)]
public class ClientTrackingTests : TestBase
{
public ClientTrackingTests(ITestOutputHelper output, SharedConnectionFixture fixture) : base(output, fixture) { }

[Fact]
public async Task UseFlagWithoutEnabling()
{
using var conn = Create(shared: false);
var key = Me();
var ex = await Assert.ThrowsAsync<InvalidOperationException>(
async () => await conn.GetDatabase().StringGetAsync(key, CommandFlags.ClientCaching)
);
Assert.Equal("The ClientCaching flag can only be used if EnableServerAssistedClientSideTracking has been called", ex.Message);
}

[Fact]
public void CallEnableTwice()
{
using var conn = Create(shared: false);
conn.EnableServerAssistedClientSideTracking(key => default);
var ex = Assert.Throws<InvalidOperationException>(() => conn.EnableServerAssistedClientSideTracking(key => default));
Assert.Equal("The EnableServerAssistedClientSideTracking method can be invoked once-only per multiplexer instance", ex.Message);
}

[Fact]
public void UsePrefixesWithoutBroadcast()
{
using var conn = Create(shared: false);
var ex = Assert.Throws<ArgumentException>(() => conn.EnableServerAssistedClientSideTracking(key => default, prefixes: new RedisKey[] { "abc" }));
Assert.StartsWith("Prefixes can only be specified when ClientTrackingOptions.Broadcast is used", ex.Message);
Assert.Equal("prefixes", ex.ParamName);
}

[Theory]
[InlineData(ClientTrackingOptions.None)]
[InlineData(ClientTrackingOptions.Broadcast)]
[InlineData(ClientTrackingOptions.NotifyForOwnCommands)]
[InlineData(ClientTrackingOptions.Broadcast | ClientTrackingOptions.NotifyForOwnCommands)]
[InlineData(ClientTrackingOptions.ConcurrentInvalidation)]
[InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.Broadcast)]
[InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.NotifyForOwnCommands)]
[InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.Broadcast | ClientTrackingOptions.NotifyForOwnCommands)]
public Task GetNotificationFromOwnConnection(ClientTrackingOptions options) => GetNotification(options, false);

[Theory]
[InlineData(ClientTrackingOptions.None)]
[InlineData(ClientTrackingOptions.Broadcast)]
[InlineData(ClientTrackingOptions.NotifyForOwnCommands)]
[InlineData(ClientTrackingOptions.Broadcast | ClientTrackingOptions.NotifyForOwnCommands)]
[InlineData(ClientTrackingOptions.ConcurrentInvalidation)]
[InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.Broadcast)]
[InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.NotifyForOwnCommands)]
[InlineData(ClientTrackingOptions.ConcurrentInvalidation | ClientTrackingOptions.Broadcast | ClientTrackingOptions.NotifyForOwnCommands)]
public Task GetNotificationFromExternalConnection(ClientTrackingOptions options) => GetNotification(options, true);

private async Task GetNotification(ClientTrackingOptions options, bool externalConnectionMakesChange)
{
bool expectNotification = ((options & ClientTrackingOptions.NotifyForOwnCommands) != 0) || externalConnectionMakesChange;

using var listen = Create(shared: false);
using var send = externalConnectionMakesChange ? Create() : listen;

int value = (new Random().Next() % 1024) + 1024, notifyCount = 0;

var key = Me();
var db = listen.GetDatabase();
db.KeyDelete(key);
db.StringSet(key, value);

listen.EnableServerAssistedClientSideTracking(rkey =>
{
if (rkey == key) Interlocked.Increment(ref notifyCount);
return default;
}, options);

Assert.Equal(value, db.StringGet(key, CommandFlags.ClientCaching));
Assert.Equal(0, Volatile.Read(ref notifyCount));

send.GetDatabase().StringIncrement(key, 5);
await Task.Delay(100); // allow time for the magic to happen

Assert.Equal(expectNotification ? 1 : 0, Volatile.Read(ref notifyCount));
Assert.Equal(value + 5, db.StringGet(key, CommandFlags.ClientCaching));

}
}
Original file line number Diff line number Diff line change
@@ -185,6 +185,8 @@ public void ExportConfiguration(Stream destination, ExportOptions options = Expo
public override string ToString() => _inner.ToString();
long? IInternalConnectionMultiplexer.GetConnectionId(EndPoint endPoint, ConnectionType type)
=> _inner.GetConnectionId(endPoint, type);
public void EnableServerAssistedClientSideTracking(Func<RedisKey, ValueTask> keyInvalidated, ClientTrackingOptions options = ClientTrackingOptions.None, ReadOnlyMemory<RedisKey> prefixes = default)
=> _inner.EnableServerAssistedClientSideTracking(keyInvalidated, options, prefixes);
}

public void Dispose()