diff --git a/src/MongoDB.Driver/Authentication/MongoDBX509Authenticator.cs b/src/MongoDB.Driver/Authentication/MongoDBX509Authenticator.cs index 07612031c97..dd4b9ca0688 100644 --- a/src/MongoDB.Driver/Authentication/MongoDBX509Authenticator.cs +++ b/src/MongoDB.Driver/Authentication/MongoDBX509Authenticator.cs @@ -58,7 +58,8 @@ public void Authenticate(IConnection connection, ConnectionDescription descripti try { var protocol = CreateAuthenticateProtocol(); - protocol.Execute(connection, cancellationToken); + // TODO: CSOT: implement operationContext support for Auth. + protocol.Execute(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection); } catch (MongoCommandException ex) { @@ -79,7 +80,8 @@ public async Task AuthenticateAsync(IConnection connection, ConnectionDescriptio try { var protocol = CreateAuthenticateProtocol(); - await protocol.ExecuteAsync(connection, cancellationToken).ConfigureAwait(false); + // TODO: CSOT: implement operationContext support for Auth. + await protocol.ExecuteAsync(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection).ConfigureAwait(false); } catch (MongoCommandException ex) { diff --git a/src/MongoDB.Driver/Authentication/SaslAuthenticator.cs b/src/MongoDB.Driver/Authentication/SaslAuthenticator.cs index fddb5953b60..d42558ddee6 100644 --- a/src/MongoDB.Driver/Authentication/SaslAuthenticator.cs +++ b/src/MongoDB.Driver/Authentication/SaslAuthenticator.cs @@ -109,7 +109,8 @@ public void Authenticate(IConnection connection, ConnectionDescription descripti try { var protocol = CreateCommandProtocol(command); - result = protocol.Execute(connection, cancellationToken); + // TODO: CSOT: implement operationContext support for Auth. + result = protocol.Execute(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection); conversationId ??= result?.GetValue("conversationId").AsInt32; } catch (MongoException ex) @@ -172,7 +173,8 @@ public async Task AuthenticateAsync(IConnection connection, ConnectionDescriptio try { var protocol = CreateCommandProtocol(command); - result = await protocol.ExecuteAsync(connection, cancellationToken).ConfigureAwait(false); + // TODO: CSOT: implement operationContext support for Auth. + result = await protocol.ExecuteAsync(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection).ConfigureAwait(false); conversationId ??= result?.GetValue("conversationId").AsInt32; } catch (MongoException ex) diff --git a/src/MongoDB.Driver/Core/Bindings/ChannelChannelSource.cs b/src/MongoDB.Driver/Core/Bindings/ChannelChannelSource.cs index f48a8038428..16734bf1a32 100644 --- a/src/MongoDB.Driver/Core/Bindings/ChannelChannelSource.cs +++ b/src/MongoDB.Driver/Core/Bindings/ChannelChannelSource.cs @@ -26,31 +26,26 @@ internal sealed class ChannelChannelSource : IChannelSource private readonly IChannelHandle _channel; private bool _disposed; private readonly IServer _server; + private readonly TimeSpan _roundTripTime; private readonly ICoreSessionHandle _session; // constructors - public ChannelChannelSource(IServer server, IChannelHandle channel, ICoreSessionHandle session) + public ChannelChannelSource(IServer server, TimeSpan roundTripTime, IChannelHandle channel, ICoreSessionHandle session) { _server = Ensure.IsNotNull(server, nameof(server)); + _roundTripTime = Ensure.IsGreaterThanZero(roundTripTime, nameof(roundTripTime)); _channel = Ensure.IsNotNull(channel, nameof(channel)); _session = Ensure.IsNotNull(session, nameof(session)); } // properties - public IServer Server - { - get { return _server; } - } + public IServer Server => _server; - public ServerDescription ServerDescription - { - get { return _server.Description; } - } + public ServerDescription ServerDescription => _server.Description; - public ICoreSessionHandle Session - { - get { return _session; } - } + public TimeSpan RoundTripTime => _roundTripTime; + + public ICoreSessionHandle Session => _session; // methods public void Dispose() diff --git a/src/MongoDB.Driver/Core/Bindings/ChannelReadBinding.cs b/src/MongoDB.Driver/Core/Bindings/ChannelReadBinding.cs index 63dab353cf4..01faf8dcd39 100644 --- a/src/MongoDB.Driver/Core/Bindings/ChannelReadBinding.cs +++ b/src/MongoDB.Driver/Core/Bindings/ChannelReadBinding.cs @@ -27,25 +27,21 @@ internal sealed class ChannelReadBinding : IReadBinding private bool _disposed; private readonly ReadPreference _readPreference; private readonly IServer _server; + private readonly TimeSpan _roundTripTime; private readonly ICoreSessionHandle _session; - public ChannelReadBinding(IServer server, IChannelHandle channel, ReadPreference readPreference, ICoreSessionHandle session) + public ChannelReadBinding(IServer server, TimeSpan roundTripTime, IChannelHandle channel, ReadPreference readPreference, ICoreSessionHandle session) { _server = Ensure.IsNotNull(server, nameof(server)); + _roundTripTime = Ensure.IsGreaterThanZero(roundTripTime, nameof(roundTripTime)); _channel = Ensure.IsNotNull(channel, nameof(channel)); _readPreference = Ensure.IsNotNull(readPreference, nameof(readPreference)); _session = Ensure.IsNotNull(session, nameof(session)); } - public ReadPreference ReadPreference - { - get { return _readPreference; } - } + public ReadPreference ReadPreference => _readPreference; - public ICoreSessionHandle Session - { - get { return _session; } - } + public ICoreSessionHandle Session => _session; public void Dispose() { @@ -81,7 +77,7 @@ public Task GetReadChannelSourceAsync(OperationContext ope private IChannelSourceHandle GetReadChannelSourceHelper() { - return new ChannelSourceHandle(new ChannelChannelSource(_server, _channel.Fork(), _session.Fork())); + return new ChannelSourceHandle(new ChannelChannelSource(_server, _roundTripTime, _channel.Fork(), _session.Fork())); } private void ThrowIfDisposed() diff --git a/src/MongoDB.Driver/Core/Bindings/ChannelReadWriteBinding.cs b/src/MongoDB.Driver/Core/Bindings/ChannelReadWriteBinding.cs index 17ae75966bc..809fd86a5e1 100644 --- a/src/MongoDB.Driver/Core/Bindings/ChannelReadWriteBinding.cs +++ b/src/MongoDB.Driver/Core/Bindings/ChannelReadWriteBinding.cs @@ -26,24 +26,20 @@ internal sealed class ChannelReadWriteBinding : IReadWriteBinding private readonly IChannelHandle _channel; private bool _disposed; private readonly IServer _server; + private readonly TimeSpan _serverRoundTripTime; private readonly ICoreSessionHandle _session; - public ChannelReadWriteBinding(IServer server, IChannelHandle channel, ICoreSessionHandle session) + public ChannelReadWriteBinding(IServer server, TimeSpan roundTripTime, IChannelHandle channel, ICoreSessionHandle session) { _server = Ensure.IsNotNull(server, nameof(server)); + _serverRoundTripTime = Ensure.IsGreaterThanZero(roundTripTime, nameof(roundTripTime)); _channel = Ensure.IsNotNull(channel, nameof(channel)); _session = Ensure.IsNotNull(session, nameof(session)); } - public ReadPreference ReadPreference - { - get { return ReadPreference.Primary; } - } + public ReadPreference ReadPreference => ReadPreference.Primary; - public ICoreSessionHandle Session - { - get { return _session; } - } + public ICoreSessionHandle Session => _session; public void Dispose() { @@ -121,7 +117,7 @@ public Task GetWriteChannelSourceAsync(OperationContext op private IChannelSourceHandle GetChannelSourceHelper() { - return new ChannelSourceHandle(new ChannelChannelSource(_server, _channel.Fork(), _session.Fork())); + return new ChannelSourceHandle(new ChannelChannelSource(_server, _serverRoundTripTime, _channel.Fork(), _session.Fork())); } private void ThrowIfDisposed() diff --git a/src/MongoDB.Driver/Core/Bindings/ChannelSourceHandle.cs b/src/MongoDB.Driver/Core/Bindings/ChannelSourceHandle.cs index 3b08ff9da33..fa8deacc1b6 100644 --- a/src/MongoDB.Driver/Core/Bindings/ChannelSourceHandle.cs +++ b/src/MongoDB.Driver/Core/Bindings/ChannelSourceHandle.cs @@ -38,20 +38,13 @@ private ChannelSourceHandle(ReferenceCounted reference) } // properties - public IServer Server - { - get { return _reference.Instance.Server; } - } + public IServer Server => _reference.Instance.Server; - public ServerDescription ServerDescription - { - get { return _reference.Instance.ServerDescription; } - } + public ServerDescription ServerDescription => _reference.Instance.ServerDescription; - public ICoreSessionHandle Session - { - get { return _reference.Instance.Session; } - } + public TimeSpan RoundTripTime => _reference.Instance.RoundTripTime; + + public ICoreSessionHandle Session => _reference.Instance.Session; // methods public IChannelHandle GetChannel(OperationContext operationContext) @@ -72,7 +65,6 @@ public void Dispose() { _reference.DecrementReferenceCount(); _disposed = true; - GC.SuppressFinalize(this); } } diff --git a/src/MongoDB.Driver/Core/Bindings/CoreTransaction.cs b/src/MongoDB.Driver/Core/Bindings/CoreTransaction.cs index 53747c8530c..0ffe8e203a8 100644 --- a/src/MongoDB.Driver/Core/Bindings/CoreTransaction.cs +++ b/src/MongoDB.Driver/Core/Bindings/CoreTransaction.cs @@ -13,6 +13,7 @@ * limitations under the License. */ +using System; using MongoDB.Bson; using MongoDB.Driver.Core.Servers; @@ -27,6 +28,7 @@ public class CoreTransaction private bool _isEmpty; private IChannelHandle _pinnedChannel = null; private IServer _pinnedServer; + private TimeSpan _pinnedServerRoundTripTime; private BsonDocument _recoveryToken; private CoreTransactionState _state; private readonly long _transactionNumber; @@ -64,10 +66,7 @@ public CoreTransaction(long transactionNumber, TransactionOptions transactionOpt /// public CoreTransactionState State => _state; - internal IChannelHandle PinnedChannel - { - get => _pinnedChannel; - } + internal IChannelHandle PinnedChannel => _pinnedChannel; /// /// Gets or sets pinned server for the current transaction. @@ -76,11 +75,9 @@ internal IChannelHandle PinnedChannel /// /// The pinned server for the current transaction. /// - internal IServer PinnedServer - { - get => _pinnedServer; - set => _pinnedServer = value; - } + internal IServer PinnedServer => _pinnedServer; + + internal TimeSpan PinnedServerRoundTripTime => _pinnedServerRoundTripTime; /// /// Gets the transaction number. @@ -120,6 +117,12 @@ internal void PinChannel(IChannelHandle channel) } } + internal void PinServer(IServer server, TimeSpan roundTripTime) + { + _pinnedServer = server; + _pinnedServerRoundTripTime = roundTripTime; + } + internal void SetState(CoreTransactionState state) { _state = state; @@ -135,6 +138,7 @@ internal void UnpinAll() { _pinnedChannel?.Dispose(); _pinnedChannel = null; + _pinnedServerRoundTripTime = default; _pinnedServer = null; } } diff --git a/src/MongoDB.Driver/Core/Bindings/IBinding.cs b/src/MongoDB.Driver/Core/Bindings/IBinding.cs index 275131043ba..0b2d0456a36 100644 --- a/src/MongoDB.Driver/Core/Bindings/IBinding.cs +++ b/src/MongoDB.Driver/Core/Bindings/IBinding.cs @@ -15,7 +15,6 @@ using System; using System.Collections.Generic; -using System.Threading; using System.Threading.Tasks; using MongoDB.Driver.Core.Servers; diff --git a/src/MongoDB.Driver/Core/Bindings/IChannel.cs b/src/MongoDB.Driver/Core/Bindings/IChannel.cs index 275d6cdebbc..109276dc837 100644 --- a/src/MongoDB.Driver/Core/Bindings/IChannel.cs +++ b/src/MongoDB.Driver/Core/Bindings/IChannel.cs @@ -15,7 +15,6 @@ using System; using System.Collections.Generic; -using System.Threading; using System.Threading.Tasks; using MongoDB.Bson; using MongoDB.Bson.IO; @@ -31,8 +30,10 @@ internal interface IChannel : IDisposable { IConnectionHandle Connection { get; } ConnectionDescription ConnectionDescription { get; } + TimeSpan RoundTripTimeout { get; } TResult Command( + OperationContext operationContext, ICoreSession session, ReadPreference readPreference, DatabaseNamespace databaseNamespace, @@ -43,10 +44,10 @@ TResult Command( Action postWriteAction, CommandResponseHandling responseHandling, IBsonSerializer resultSerializer, - MessageEncoderSettings messageEncoderSettings, - CancellationToken cancellationToken); + MessageEncoderSettings messageEncoderSettings); Task CommandAsync( + OperationContext operationContext, ICoreSession session, ReadPreference readPreference, DatabaseNamespace databaseNamespace, @@ -57,8 +58,7 @@ Task CommandAsync( Action postWriteAction, CommandResponseHandling responseHandling, IBsonSerializer resultSerializer, - MessageEncoderSettings messageEncoderSettings, - CancellationToken cancellationToken); + MessageEncoderSettings messageEncoderSettings); } internal interface IChannelHandle : IChannel diff --git a/src/MongoDB.Driver/Core/Bindings/IChannelSource.cs b/src/MongoDB.Driver/Core/Bindings/IChannelSource.cs index c9bd90ec61b..53d91db4028 100644 --- a/src/MongoDB.Driver/Core/Bindings/IChannelSource.cs +++ b/src/MongoDB.Driver/Core/Bindings/IChannelSource.cs @@ -23,6 +23,7 @@ internal interface IChannelSource : IDisposable { IServer Server { get; } ServerDescription ServerDescription { get; } + TimeSpan RoundTripTime { get; } ICoreSessionHandle Session { get; } IChannelHandle GetChannel(OperationContext operationContext); diff --git a/src/MongoDB.Driver/Core/Bindings/ReadPreferenceBinding.cs b/src/MongoDB.Driver/Core/Bindings/ReadPreferenceBinding.cs index 32106f0efcd..650cdf0e9f3 100644 --- a/src/MongoDB.Driver/Core/Bindings/ReadPreferenceBinding.cs +++ b/src/MongoDB.Driver/Core/Bindings/ReadPreferenceBinding.cs @@ -41,15 +41,9 @@ public ReadPreferenceBinding(IClusterInternal cluster, ReadPreference readPrefer _serverSelector = new ReadPreferenceServerSelector(readPreference); } - public ReadPreference ReadPreference - { - get { return _readPreference; } - } + public ReadPreference ReadPreference => _readPreference; - public ICoreSessionHandle Session - { - get { return _session; } - } + public ICoreSessionHandle Session => _session; public IChannelSourceHandle GetReadChannelSource(OperationContext operationContext) { @@ -75,9 +69,9 @@ public async Task GetReadChannelSourceAsync(OperationConte return GetChannelSourceHelper(server); } - private IChannelSourceHandle GetChannelSourceHelper(IServer server) + private IChannelSourceHandle GetChannelSourceHelper((IServer Server, TimeSpan RoundTripTime) server) { - return new ChannelSourceHandle(new ServerChannelSource(server, _session.Fork())); + return new ChannelSourceHandle(new ServerChannelSource(server.Server, server.RoundTripTime, _session.Fork())); } public void Dispose() diff --git a/src/MongoDB.Driver/Core/Bindings/ServerChannelSource.cs b/src/MongoDB.Driver/Core/Bindings/ServerChannelSource.cs index c5fbc55cea1..ae5a3028768 100644 --- a/src/MongoDB.Driver/Core/Bindings/ServerChannelSource.cs +++ b/src/MongoDB.Driver/Core/Bindings/ServerChannelSource.cs @@ -25,30 +25,25 @@ internal sealed class ServerChannelSource : IChannelSource // fields private bool _disposed; private readonly IServer _server; + private readonly TimeSpan _serverRoundTripTime; private readonly ICoreSessionHandle _session; // constructors - public ServerChannelSource(IServer server, ICoreSessionHandle session) + public ServerChannelSource(IServer server, TimeSpan roundTripTime, ICoreSessionHandle session) { _server = Ensure.IsNotNull(server, nameof(server)); + _serverRoundTripTime = Ensure.IsGreaterThanZero(roundTripTime, nameof(roundTripTime)); _session = Ensure.IsNotNull(session, nameof(session)); } // properties - public IServer Server - { - get { return _server; } - } + public IServer Server => _server; - public ServerDescription ServerDescription - { - get { return _server.Description; } - } + public ServerDescription ServerDescription => _server.Description; - public ICoreSessionHandle Session - { - get { return _session; } - } + public TimeSpan RoundTripTime => _serverRoundTripTime; + + public ICoreSessionHandle Session => _session; // methods public void Dispose() @@ -63,13 +58,15 @@ public void Dispose() public IChannelHandle GetChannel(OperationContext operationContext) { ThrowIfDisposed(); - return _server.GetChannel(operationContext); + var connection = _server.GetConnection(operationContext); + return new ServerChannel(_server, connection, _serverRoundTripTime); } - public Task GetChannelAsync(OperationContext operationContext) + public async Task GetChannelAsync(OperationContext operationContext) { ThrowIfDisposed(); - return _server.GetChannelAsync(operationContext); + var connection = await _server.GetConnectionAsync(operationContext).ConfigureAwait(false); + return new ServerChannel(_server, connection, _serverRoundTripTime); } private void ThrowIfDisposed() diff --git a/src/MongoDB.Driver/Core/Bindings/SingleServerReadBinding.cs b/src/MongoDB.Driver/Core/Bindings/SingleServerReadBinding.cs index 04a65fbd4b3..8f891a34ce7 100644 --- a/src/MongoDB.Driver/Core/Bindings/SingleServerReadBinding.cs +++ b/src/MongoDB.Driver/Core/Bindings/SingleServerReadBinding.cs @@ -28,25 +28,21 @@ internal sealed class SingleServerReadBinding : IReadBinding private bool _disposed; private readonly ReadPreference _readPreference; + private readonly TimeSpan _roundTripTime; private readonly IServer _server; private readonly ICoreSessionHandle _session; - public SingleServerReadBinding(IServer server, ReadPreference readPreference, ICoreSessionHandle session) + public SingleServerReadBinding(IServer server, TimeSpan roundTripTime, ReadPreference readPreference, ICoreSessionHandle session) { _server = Ensure.IsNotNull(server, nameof(server)); + _roundTripTime = Ensure.IsGreaterThanZero(roundTripTime, nameof(roundTripTime)); _readPreference = Ensure.IsNotNull(readPreference, nameof(readPreference)); _session = Ensure.IsNotNull(session, nameof(session)); } - public ReadPreference ReadPreference - { - get { return _readPreference; } - } + public ReadPreference ReadPreference => _readPreference; - public ICoreSessionHandle Session - { - get { return _session; } - } + public ICoreSessionHandle Session => _session; public IChannelSourceHandle GetReadChannelSource(OperationContext operationContext) { @@ -90,7 +86,7 @@ private IChannelSourceHandle GetChannelSourceHelper() SpinWait.SpinUntil(() => _server.Description.State == ServerState.Connected, SingleServerSelectionTimeoutMS); } - return new ChannelSourceHandle(new ServerChannelSource(_server, _session.Fork())); + return new ChannelSourceHandle(new ServerChannelSource(_server, _roundTripTime, _session.Fork())); } private void ThrowIfDisposed() diff --git a/src/MongoDB.Driver/Core/Bindings/SingleServerReadWriteBinding.cs b/src/MongoDB.Driver/Core/Bindings/SingleServerReadWriteBinding.cs index 5113baa09c0..2369bae6298 100644 --- a/src/MongoDB.Driver/Core/Bindings/SingleServerReadWriteBinding.cs +++ b/src/MongoDB.Driver/Core/Bindings/SingleServerReadWriteBinding.cs @@ -25,23 +25,19 @@ internal sealed class SingleServerReadWriteBinding : IReadWriteBinding { private bool _disposed; private readonly IServer _server; + private readonly TimeSpan _serverRoundTripTime; private readonly ICoreSessionHandle _session; - public SingleServerReadWriteBinding(IServer server, ICoreSessionHandle session) + public SingleServerReadWriteBinding(IServer server, TimeSpan roundTripTime, ICoreSessionHandle session) { _server = Ensure.IsNotNull(server, nameof(server)); + _serverRoundTripTime = Ensure.IsGreaterThanZero(roundTripTime, nameof(roundTripTime)); _session = Ensure.IsNotNull(session, nameof(session)); } - public ReadPreference ReadPreference - { - get { return ReadPreference.Primary; } - } + public ReadPreference ReadPreference => ReadPreference.Primary; - public ICoreSessionHandle Session - { - get { return _session; } - } + public ICoreSessionHandle Session => _session; public void Dispose() { @@ -118,7 +114,7 @@ public Task GetWriteChannelSourceAsync(OperationContext op private IChannelSourceHandle GetChannelSourceHelper() { - return new ChannelSourceHandle(new ServerChannelSource(_server, _session.Fork())); + return new ChannelSourceHandle(new ServerChannelSource(_server, _serverRoundTripTime, _session.Fork())); } private void ThrowIfDisposed() diff --git a/src/MongoDB.Driver/Core/Bindings/WritableServerBinding.cs b/src/MongoDB.Driver/Core/Bindings/WritableServerBinding.cs index 764bdc0e0ae..efae91bb031 100644 --- a/src/MongoDB.Driver/Core/Bindings/WritableServerBinding.cs +++ b/src/MongoDB.Driver/Core/Bindings/WritableServerBinding.cs @@ -37,25 +37,15 @@ public WritableServerBinding(IClusterInternal cluster, ICoreSessionHandle sessio _session = Ensure.IsNotNull(session, nameof(session)); } - public ReadPreference ReadPreference - { - get { return ReadPreference.Primary; } - } + public ReadPreference ReadPreference => ReadPreference.Primary; - public ICoreSessionHandle Session - { - get { return _session; } - } + public ICoreSessionHandle Session => _session; public IChannelSourceHandle GetReadChannelSource(OperationContext operationContext) - { - return GetReadChannelSource(operationContext, null); - } + => GetReadChannelSource(operationContext, null); public Task GetReadChannelSourceAsync(OperationContext operationContext) - { - return GetReadChannelSourceAsync(operationContext, null); - } + => GetReadChannelSourceAsync(operationContext, null); public IChannelSourceHandle GetReadChannelSource(OperationContext operationContext, IReadOnlyCollection deprioritizedServers) { @@ -139,9 +129,9 @@ public async Task GetWriteChannelSourceAsync(OperationCont return CreateServerChannelSource(server); } - private IChannelSourceHandle CreateServerChannelSource(IServer server) + private IChannelSourceHandle CreateServerChannelSource((IServer Server, TimeSpan RoundTripTime) server) { - return new ChannelSourceHandle(new ServerChannelSource(server, _session.Fork())); + return new ChannelSourceHandle(new ServerChannelSource(server.Server, server.RoundTripTime, _session.Fork())); } public void Dispose() diff --git a/src/MongoDB.Driver/Core/ChannelPinningHelper.cs b/src/MongoDB.Driver/Core/ChannelPinningHelper.cs index 740ffef8fa4..c926cebdf61 100644 --- a/src/MongoDB.Driver/Core/ChannelPinningHelper.cs +++ b/src/MongoDB.Driver/Core/ChannelPinningHelper.cs @@ -32,6 +32,7 @@ public static IReadBindingHandle CreateReadBinding(IClusterInternal cluster, ICo { readBinding = new ChannelReadWriteBinding( session.CurrentTransaction.PinnedServer, + session.CurrentTransaction.PinnedServerRoundTripTime, session.CurrentTransaction.PinnedChannel.Fork(), session); } @@ -57,6 +58,7 @@ public static IReadWriteBindingHandle CreateReadWriteBinding(IClusterInternal cl { readWriteBinding = new ChannelReadWriteBinding( session.CurrentTransaction.PinnedServer, + session.CurrentTransaction.PinnedServerRoundTripTime, session.CurrentTransaction.PinnedChannel.Fork(), session); } @@ -85,12 +87,13 @@ internal static IChannelSourceHandle CreateGetMoreChannelSource(IChannelSourceHa effectiveChannelSource = new ChannelChannelSource( channelSource.Server, + channelSource.RoundTripTime, channel.Fork(), channelSource.Session.Fork()); } else { - effectiveChannelSource = new ServerChannelSource(channelSource.Server, channelSource.Session.Fork()); + effectiveChannelSource = new ServerChannelSource(channelSource.Server, channelSource.RoundTripTime, channelSource.Session.Fork()); } return new ChannelSourceHandle(effectiveChannelSource); @@ -110,7 +113,7 @@ internal static void PinChannellIfRequired( checkOutReasonTracker.SetCheckOutReasonIfNotAlreadySet(CheckOutReason.Transaction); } session.CurrentTransaction.PinChannel(channel.Fork()); - session.CurrentTransaction.PinnedServer = channelSource.Server; + session.CurrentTransaction.PinServer(channelSource.Server, channelSource.RoundTripTime); } } diff --git a/src/MongoDB.Driver/Core/Clusters/Cluster.cs b/src/MongoDB.Driver/Core/Clusters/Cluster.cs index 287a40657aa..a7b937b5c72 100644 --- a/src/MongoDB.Driver/Core/Clusters/Cluster.cs +++ b/src/MongoDB.Driver/Core/Clusters/Cluster.cs @@ -153,13 +153,13 @@ protected void OnDescriptionChanged(ClusterDescription oldDescription, ClusterDe DescriptionChanged?.Invoke(this, new ClusterDescriptionChangedEventArgs(oldDescription, newDescription)); } - public IServer SelectServer(OperationContext operationContext, IServerSelector selector) + public (IServer, TimeSpan) SelectServer(OperationContext operationContext, IServerSelector selector) { Ensure.IsNotNull(selector, nameof(selector)); Ensure.IsNotNull(operationContext, nameof(operationContext)); ThrowIfDisposedOrNotOpen(); - operationContext = operationContext.WithTimeout(Settings.ServerSelectionTimeout); + using var serverSelectionOperationContext = operationContext.WithTimeout(Settings.ServerSelectionTimeout); var expirableClusterDescription = _expirableClusterDescription; IDisposable serverSelectionWaitQueueDisposer = null; (selector, var operationCountSelector, var stopwatch) = BeginServerSelection(expirableClusterDescription.ClusterDescription, selector); @@ -168,16 +168,16 @@ public IServer SelectServer(OperationContext operationContext, IServerSelector s { while (true) { - var result = SelectServer(expirableClusterDescription, selector, operationCountSelector); - if (result != default) + var (server, description) = SelectServer(expirableClusterDescription, selector, operationCountSelector); + if (server != null) { - EndServerSelection(expirableClusterDescription.ClusterDescription, selector, result.ServerDescription, stopwatch); - return result.Server; + EndServerSelection(expirableClusterDescription.ClusterDescription, selector, description, stopwatch); + return (server, description.AverageRoundTripTime); } - serverSelectionWaitQueueDisposer ??= _serverSelectionWaitQueue.Enter(operationContext, selector, expirableClusterDescription.ClusterDescription, EventContext.OperationId); + serverSelectionWaitQueueDisposer ??= _serverSelectionWaitQueue.Enter(serverSelectionOperationContext, selector, expirableClusterDescription.ClusterDescription, EventContext.OperationId); - operationContext.WaitTask(expirableClusterDescription.Expired); + serverSelectionOperationContext.WaitTask(expirableClusterDescription.Expired); expirableClusterDescription = _expirableClusterDescription; } } @@ -191,13 +191,13 @@ public IServer SelectServer(OperationContext operationContext, IServerSelector s } } - public async Task SelectServerAsync(OperationContext operationContext, IServerSelector selector) + public async Task<(IServer, TimeSpan)> SelectServerAsync(OperationContext operationContext, IServerSelector selector) { Ensure.IsNotNull(selector, nameof(selector)); Ensure.IsNotNull(operationContext, nameof(operationContext)); ThrowIfDisposedOrNotOpen(); - operationContext = operationContext.WithTimeout(Settings.ServerSelectionTimeout); + using var serverSelectionOperationContext = operationContext.WithTimeout(Settings.ServerSelectionTimeout); var expirableClusterDescription = _expirableClusterDescription; IDisposable serverSelectionWaitQueueDisposer = null; (selector, var operationCountSelector, var stopwatch) = BeginServerSelection(expirableClusterDescription.ClusterDescription, selector); @@ -206,16 +206,16 @@ public async Task SelectServerAsync(OperationContext operationContext, { while (true) { - var result = SelectServer(expirableClusterDescription, selector, operationCountSelector); - if (result != default) + var (server, description) = SelectServer(expirableClusterDescription, selector, operationCountSelector); + if (server != null) { - EndServerSelection(expirableClusterDescription.ClusterDescription, selector, result.ServerDescription, stopwatch); - return result.Server; + EndServerSelection(expirableClusterDescription.ClusterDescription, selector, description, stopwatch); + return (server, description.AverageRoundTripTime); } - serverSelectionWaitQueueDisposer ??= _serverSelectionWaitQueue.Enter(operationContext, selector, expirableClusterDescription.ClusterDescription, EventContext.OperationId); + serverSelectionWaitQueueDisposer ??= _serverSelectionWaitQueue.Enter(serverSelectionOperationContext, selector, expirableClusterDescription.ClusterDescription, EventContext.OperationId); - await operationContext.WaitTaskAsync(expirableClusterDescription.Expired).ConfigureAwait(false); + await serverSelectionOperationContext.WaitTaskAsync(expirableClusterDescription.Expired).ConfigureAwait(false); expirableClusterDescription = _expirableClusterDescription; } } diff --git a/src/MongoDB.Driver/Core/Clusters/ICluster.cs b/src/MongoDB.Driver/Core/Clusters/ICluster.cs index ea31d13bc12..41102a1fe3d 100644 --- a/src/MongoDB.Driver/Core/Clusters/ICluster.cs +++ b/src/MongoDB.Driver/Core/Clusters/ICluster.cs @@ -61,8 +61,8 @@ internal interface IClusterInternal : ICluster void Initialize(); - IServer SelectServer(OperationContext operationContext, IServerSelector selector); - Task SelectServerAsync(OperationContext operationContext, IServerSelector selector); + (IServer, TimeSpan) SelectServer(OperationContext operationContext, IServerSelector selector); + Task<(IServer, TimeSpan)> SelectServerAsync(OperationContext operationContext, IServerSelector selector); ICoreSessionHandle StartSession(CoreSessionOptions options = null); } diff --git a/src/MongoDB.Driver/Core/Clusters/IClusterExtensions.cs b/src/MongoDB.Driver/Core/Clusters/IClusterExtensions.cs index e8060a75a07..830f46509a1 100644 --- a/src/MongoDB.Driver/Core/Clusters/IClusterExtensions.cs +++ b/src/MongoDB.Driver/Core/Clusters/IClusterExtensions.cs @@ -13,6 +13,7 @@ * limitations under the License. */ +using System; using System.Collections.Generic; using System.Threading.Tasks; using MongoDB.Driver.Core.Bindings; @@ -23,7 +24,7 @@ namespace MongoDB.Driver.Core.Clusters { internal static class IClusterExtensions { - public static IServer SelectServerAndPinIfNeeded( + public static (IServer, TimeSpan) SelectServerAndPinIfNeeded( this IClusterInternal cluster, OperationContext operationContext, ICoreSessionHandle session, @@ -31,7 +32,7 @@ public static IServer SelectServerAndPinIfNeeded( IReadOnlyCollection deprioritizedServers) { var pinnedServer = GetPinnedServerIfValid(cluster, session); - if (pinnedServer != null) + if (pinnedServer != default) { return pinnedServer; } @@ -42,12 +43,12 @@ public static IServer SelectServerAndPinIfNeeded( // Server selection also updates the cluster type, allowing us to determine if the server // should be pinned. - var server = cluster.SelectServer(operationContext, selector); - PinServerIfNeeded(cluster, session, server); - return server; + var (server, serverRoundTripTime) = cluster.SelectServer(operationContext, selector); + PinServerIfNeeded(cluster, session, server, serverRoundTripTime); + return (server, serverRoundTripTime); } - public static async Task SelectServerAndPinIfNeededAsync( + public static async Task<(IServer, TimeSpan)> SelectServerAndPinIfNeededAsync( this IClusterInternal cluster, OperationContext operationContext, ICoreSessionHandle session, @@ -55,7 +56,7 @@ public static async Task SelectServerAndPinIfNeededAsync( IReadOnlyCollection deprioritizedServers) { var pinnedServer = GetPinnedServerIfValid(cluster, session); - if (pinnedServer != null) + if (pinnedServer != default) { return pinnedServer; } @@ -66,32 +67,30 @@ public static async Task SelectServerAndPinIfNeededAsync( // Server selection also updates the cluster type, allowing us to determine if the server // should be pinned. - var server = await cluster.SelectServerAsync(operationContext, selector).ConfigureAwait(false); - PinServerIfNeeded(cluster, session, server); + var (server, serverRoundTripTime) = await cluster.SelectServerAsync(operationContext, selector).ConfigureAwait(false); + PinServerIfNeeded(cluster, session, server, serverRoundTripTime); - return server; + return (server, serverRoundTripTime); } - private static void PinServerIfNeeded(ICluster cluster, ICoreSessionHandle session, IServer server) + private static void PinServerIfNeeded(ICluster cluster, ICoreSessionHandle session, IServer server, TimeSpan serverRoundTripTime) { if (cluster.Description.Type == ClusterType.Sharded && session.IsInTransaction) { - session.CurrentTransaction.PinnedServer = server; + session.CurrentTransaction.PinServer(server, serverRoundTripTime); } } - private static IServer GetPinnedServerIfValid(ICluster cluster, ICoreSessionHandle session) + private static (IServer, TimeSpan) GetPinnedServerIfValid(ICluster cluster, ICoreSessionHandle session) { if (cluster.Description.Type == ClusterType.Sharded && session.IsInTransaction && session.CurrentTransaction.State != CoreTransactionState.Starting) { - return session.CurrentTransaction.PinnedServer; - } - else - { - return null; + return (session.CurrentTransaction.PinnedServer, session.CurrentTransaction.PinnedServerRoundTripTime); } + + return default; } } } diff --git a/src/MongoDB.Driver/Core/Clusters/LoadBalancedCluster.cs b/src/MongoDB.Driver/Core/Clusters/LoadBalancedCluster.cs index c77d2d45241..3581a45a3a9 100644 --- a/src/MongoDB.Driver/Core/Clusters/LoadBalancedCluster.cs +++ b/src/MongoDB.Driver/Core/Clusters/LoadBalancedCluster.cs @@ -170,13 +170,13 @@ public void Initialize() } } - public IServer SelectServer(OperationContext operationContext, IServerSelector selector) + public (IServer, TimeSpan) SelectServer(OperationContext operationContext, IServerSelector selector) { Ensure.IsNotNull(selector, nameof(selector)); Ensure.IsNotNull(operationContext, nameof(operationContext)); ThrowIfDisposed(); - var serverSelectionOperationContext = operationContext.WithTimeout(_settings.ServerSelectionTimeout); + using var serverSelectionOperationContext = operationContext.WithTimeout(_settings.ServerSelectionTimeout); _serverSelectionEventLogger.LogAndPublish(new ClusterSelectingServerEvent( _description, @@ -205,19 +205,20 @@ public IServer SelectServer(OperationContext operationContext, IServerSelector s stopwatch.Elapsed, null, EventContext.OperationName)); + + return (_server, _server.Description.AverageRoundTripTime); } - return _server ?? - throw new InvalidOperationException("The server must be created before usage."); // should not be reached + throw new InvalidOperationException("The server must be created before usage."); // should not be reached } - public async Task SelectServerAsync(OperationContext operationContext, IServerSelector selector) + public async Task<(IServer, TimeSpan)> SelectServerAsync(OperationContext operationContext, IServerSelector selector) { Ensure.IsNotNull(selector, nameof(selector)); Ensure.IsNotNull(operationContext, nameof(operationContext)); ThrowIfDisposed(); - var serverSelectionOperationContext = operationContext.WithTimeout(_settings.ServerSelectionTimeout); + using var serverSelectionOperationContext = operationContext.WithTimeout(_settings.ServerSelectionTimeout); _serverSelectionEventLogger.LogAndPublish(new ClusterSelectingServerEvent( _description, @@ -245,10 +246,11 @@ public async Task SelectServerAsync(OperationContext operationContext, stopwatch.Elapsed, null, EventContext.OperationName)); + + return (_server, _server.Description.AverageRoundTripTime); } - return _server ?? - throw new InvalidOperationException("The server must be created before usage."); // should not be reached + throw new InvalidOperationException("The server must be created before usage."); // should not be reached } public ICoreSessionHandle StartSession(CoreSessionOptions options = null) diff --git a/src/MongoDB.Driver/Core/Compression/SnappyCompressor.cs b/src/MongoDB.Driver/Core/Compression/SnappyCompressor.cs index 7419244a46c..44542d88da8 100644 --- a/src/MongoDB.Driver/Core/Compression/SnappyCompressor.cs +++ b/src/MongoDB.Driver/Core/Compression/SnappyCompressor.cs @@ -15,7 +15,6 @@ using Snappier; using System.IO; -using System.Threading; using MongoDB.Driver.Core.Misc; namespace MongoDB.Driver.Core.Compression @@ -34,7 +33,7 @@ public void Compress(Stream input, Stream output) { var uncompressedSize = (int)(input.Length - input.Position); var uncompressedBytes = new byte[uncompressedSize]; // does not include uncompressed message headers - input.ReadBytes(uncompressedBytes, offset: 0, count: uncompressedSize, Timeout.InfiniteTimeSpan, CancellationToken.None); + input.ReadBytes(OperationContext.NoTimeout, uncompressedBytes, offset: 0, count: uncompressedSize); var maxCompressedSize = Snappy.GetMaxCompressedLength(uncompressedSize); var compressedBytes = new byte[maxCompressedSize]; var compressedSize = Snappy.Compress(uncompressedBytes, compressedBytes); @@ -50,7 +49,7 @@ public void Decompress(Stream input, Stream output) { var compressedSize = (int)(input.Length - input.Position); var compressedBytes = new byte[compressedSize]; - input.ReadBytes(compressedBytes, offset: 0, count: compressedSize, Timeout.InfiniteTimeSpan, CancellationToken.None); + input.ReadBytes(OperationContext.NoTimeout, compressedBytes, offset: 0, count: compressedSize); var uncompressedSize = Snappy.GetUncompressedLength(compressedBytes); var decompressedBytes = new byte[uncompressedSize]; var decompressedSize = Snappy.Decompress(compressedBytes, decompressedBytes); diff --git a/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.Helpers.cs b/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.Helpers.cs index 5ff4a0f0845..da18a88012d 100644 --- a/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.Helpers.cs +++ b/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.Helpers.cs @@ -401,11 +401,11 @@ public void Dispose() } } - public void Open(CancellationToken cancellationToken) + public void Open(OperationContext operationContext) { try { - _connection.Open(cancellationToken); + _connection.Open(operationContext); SetEffectiveGenerationIfRequired(_connection.Description); } catch (MongoConnectionException ex) @@ -416,11 +416,11 @@ public void Open(CancellationToken cancellationToken) } } - public async Task OpenAsync(CancellationToken cancellationToken) + public async Task OpenAsync(OperationContext operationContext) { try { - await _connection.OpenAsync(cancellationToken).ConfigureAwait(false); + await _connection.OpenAsync(operationContext).ConfigureAwait(false); SetEffectiveGenerationIfRequired(_connection.Description); } catch (MongoConnectionException ex) @@ -435,11 +435,11 @@ public async Task OpenAsync(CancellationToken cancellationToken) public Task ReauthenticateAsync(CancellationToken cancellationToken) => _connection.ReauthenticateAsync(cancellationToken); - public ResponseMessage ReceiveMessage(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public ResponseMessage ReceiveMessage(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings) { try { - return _connection.ReceiveMessage(responseTo, encoderSelector, messageEncoderSettings, cancellationToken); + return _connection.ReceiveMessage(operationContext, responseTo, encoderSelector, messageEncoderSettings); } catch (MongoConnectionException ex) { @@ -448,11 +448,11 @@ public ResponseMessage ReceiveMessage(int responseTo, IMessageEncoderSelector en } } - public async Task ReceiveMessageAsync(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public async Task ReceiveMessageAsync(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings) { try { - return await _connection.ReceiveMessageAsync(responseTo, encoderSelector, messageEncoderSettings, cancellationToken).ConfigureAwait(false); + return await _connection.ReceiveMessageAsync(operationContext, responseTo, encoderSelector, messageEncoderSettings).ConfigureAwait(false); } catch (MongoConnectionException ex) { @@ -461,11 +461,11 @@ public async Task ReceiveMessageAsync(int responseTo, IMessageE } } - public void SendMessage(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public void SendMessage(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { try { - _connection.SendMessage(message, messageEncoderSettings, cancellationToken); + _connection.SendMessage(operationContext, message, messageEncoderSettings); } catch (MongoConnectionException ex) { @@ -474,11 +474,11 @@ public void SendMessage(RequestMessage message, MessageEncoderSettings messageEn } } - public async Task SendMessageAsync(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public async Task SendMessageAsync(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { try { - await _connection.SendMessageAsync(message, messageEncoderSettings, cancellationToken).ConfigureAwait(false); + await _connection.SendMessageAsync(operationContext, message, messageEncoderSettings).ConfigureAwait(false); } catch (MongoConnectionException ex) { @@ -587,16 +587,16 @@ public IConnectionHandle Fork() return new AcquiredConnection(_connectionPool, _reference); } - public void Open(CancellationToken cancellationToken) + public void Open(OperationContext operationContext) { ThrowIfDisposed(); - _reference.Instance.Open(cancellationToken); + _reference.Instance.Open(operationContext); } - public Task OpenAsync(CancellationToken cancellationToken) + public Task OpenAsync(OperationContext operationContext) { ThrowIfDisposed(); - return _reference.Instance.OpenAsync(cancellationToken); + return _reference.Instance.OpenAsync(operationContext); } public void Reauthenticate(CancellationToken cancellationToken) @@ -611,28 +611,28 @@ public Task ReauthenticateAsync(CancellationToken cancellationToken) return _reference.Instance.ReauthenticateAsync(cancellationToken); } - public Task ReceiveMessageAsync(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public Task ReceiveMessageAsync(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings) { ThrowIfDisposed(); - return _reference.Instance.ReceiveMessageAsync(responseTo, encoderSelector, messageEncoderSettings, cancellationToken); + return _reference.Instance.ReceiveMessageAsync(operationContext, responseTo, encoderSelector, messageEncoderSettings); } - public ResponseMessage ReceiveMessage(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public ResponseMessage ReceiveMessage(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings) { ThrowIfDisposed(); - return _reference.Instance.ReceiveMessage(responseTo, encoderSelector, messageEncoderSettings, cancellationToken); + return _reference.Instance.ReceiveMessage(operationContext, responseTo, encoderSelector, messageEncoderSettings); } - public void SendMessage(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public void SendMessage(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { ThrowIfDisposed(); - _reference.Instance.SendMessage(message, messageEncoderSettings, cancellationToken); + _reference.Instance.SendMessage(operationContext, message, messageEncoderSettings); } - public Task SendMessageAsync(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public Task SendMessageAsync(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { ThrowIfDisposed(); - return _reference.Instance.SendMessageAsync(message, messageEncoderSettings, cancellationToken); + return _reference.Instance.SendMessageAsync(operationContext, message, messageEncoderSettings); } public void SetCheckOutReasonIfNotAlreadySet(CheckOutReason reason) @@ -974,8 +974,7 @@ private PooledConnection CreateOpenedInternal(OperationContext operationContext) { var stopwatch = StartCreating(operationContext); - // TODO: CSOT add support of CSOT timeout in connection open code too. - _connection.Open(operationContext.CancellationToken); + _connection.Open(operationContext); FinishCreating(_connection.Description, stopwatch); @@ -986,8 +985,7 @@ private async Task CreateOpenedInternalAsync(OperationContext { var stopwatch = StartCreating(operationContext); - // TODO: CSOT add support of CSOT timeout in connection open code too. - await _connection.OpenAsync(operationContext.CancellationToken).ConfigureAwait(false); + await _connection.OpenAsync(operationContext).ConfigureAwait(false); FinishCreating(_connection.Description, stopwatch); diff --git a/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.cs b/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.cs index 7489d714081..c22fcb2f431 100644 --- a/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.cs +++ b/src/MongoDB.Driver/Core/ConnectionPools/ExclusiveConnectionPool.cs @@ -141,16 +141,16 @@ public int UsedCount // public methods public IConnectionHandle AcquireConnection(OperationContext operationContext) { - operationContext = operationContext.WithTimeout(Settings.WaitQueueTimeout); + using var waitQueueOperationContext = operationContext.WithTimeout(Settings.WaitQueueTimeout); using var helper = new AcquireConnectionHelper(this); - return helper.AcquireConnection(operationContext); + return helper.AcquireConnection(waitQueueOperationContext); } public async Task AcquireConnectionAsync(OperationContext operationContext) { - operationContext = operationContext.WithTimeout(Settings.WaitQueueTimeout); + using var waitQueueOperationContext = operationContext.WithTimeout(Settings.WaitQueueTimeout); using var helper = new AcquireConnectionHelper(this); - return await helper.AcquireConnectionAsync(operationContext).ConfigureAwait(false); + return await helper.AcquireConnectionAsync(waitQueueOperationContext).ConfigureAwait(false); } public void Clear(bool closeInUseConnections = false) diff --git a/src/MongoDB.Driver/Core/Connections/BinaryConnection.cs b/src/MongoDB.Driver/Core/Connections/BinaryConnection.cs index 210e33cc14c..2c49c91891c 100644 --- a/src/MongoDB.Driver/Core/Connections/BinaryConnection.cs +++ b/src/MongoDB.Driver/Core/Connections/BinaryConnection.cs @@ -62,7 +62,8 @@ internal sealed class BinaryConnection : IConnection private readonly EventLogger _eventLogger; // constructors - public BinaryConnection(ServerId serverId, + public BinaryConnection( + ServerId serverId, EndPoint endPoint, ConnectionSettings settings, IStreamFactory streamFactory, @@ -203,9 +204,9 @@ private void EnsureMessageSizeIsValid(int messageSize) } } - public void Open(CancellationToken cancellationToken) + public void Open(OperationContext operationContext) { - ThrowIfCancelledOrDisposed(cancellationToken); + ThrowIfCancelledOrDisposed(operationContext); TaskCompletionSource taskCompletionSource = null; var connecting = false; @@ -225,7 +226,7 @@ public void Open(CancellationToken cancellationToken) { try { - OpenHelper(cancellationToken); + OpenHelper(operationContext); taskCompletionSource.TrySetResult(true); } catch (Exception ex) @@ -240,33 +241,37 @@ public void Open(CancellationToken cancellationToken) } } - public Task OpenAsync(CancellationToken cancellationToken) + public Task OpenAsync(OperationContext operationContext) { - ThrowIfCancelledOrDisposed(cancellationToken); + ThrowIfCancelledOrDisposed(operationContext); lock (_openLock) { if (_state.TryChange(State.Initial, State.Connecting)) { _openedAtUtc = DateTime.UtcNow; - _openTask = OpenHelperAsync(cancellationToken); + _openTask = OpenHelperAsync(operationContext); } return _openTask; } } - private void OpenHelper(CancellationToken cancellationToken) + private void OpenHelper(OperationContext operationContext) { var helper = new OpenConnectionHelper(this); ConnectionDescription handshakeDescription = null; try { helper.OpeningConnection(); - _stream = _streamFactory.CreateStream(_endPoint, cancellationToken); +#pragma warning disable CS0618 // Type or member is obsolete + _stream = _streamFactory.CreateStream(_endPoint, operationContext.CombinedCancellationToken); +#pragma warning restore CS0618 // Type or member is obsolete helper.InitializingConnection(); - _connectionInitializerContext = _connectionInitializer.SendHello(this, cancellationToken); + // TODO: CSOT: Implement operation context support for MongoDB Handshake + _connectionInitializerContext = _connectionInitializer.SendHello(this, operationContext.CancellationToken); handshakeDescription = _connectionInitializerContext.Description; - _connectionInitializerContext = _connectionInitializer.Authenticate(this, _connectionInitializerContext, cancellationToken); + // TODO: CSOT: Implement operation context support for Auth + _connectionInitializerContext = _connectionInitializer.Authenticate(this, _connectionInitializerContext, operationContext.CancellationToken); _description = _connectionInitializerContext.Description; _sendCompressorType = ChooseSendCompressorTypeIfAny(_description); @@ -281,18 +286,22 @@ private void OpenHelper(CancellationToken cancellationToken) } } - private async Task OpenHelperAsync(CancellationToken cancellationToken) + private async Task OpenHelperAsync(OperationContext operationContext) { var helper = new OpenConnectionHelper(this); ConnectionDescription handshakeDescription = null; try { helper.OpeningConnection(); - _stream = await _streamFactory.CreateStreamAsync(_endPoint, cancellationToken).ConfigureAwait(false); +#pragma warning disable CS0618 // Type or member is obsolete + _stream = await _streamFactory.CreateStreamAsync(_endPoint, operationContext.CombinedCancellationToken).ConfigureAwait(false); +#pragma warning restore CS0618 // Type or member is obsolete helper.InitializingConnection(); - _connectionInitializerContext = await _connectionInitializer.SendHelloAsync(this, cancellationToken).ConfigureAwait(false); + // TODO: CSOT: Implement operation context support for MongoDB Handshake + _connectionInitializerContext = await _connectionInitializer.SendHelloAsync(this, operationContext.CancellationToken).ConfigureAwait(false); handshakeDescription = _connectionInitializerContext.Description; - _connectionInitializerContext = await _connectionInitializer.AuthenticateAsync(this, _connectionInitializerContext, cancellationToken).ConfigureAwait(false); + // TODO: CSOT: Implement operation context support for Auth + _connectionInitializerContext = await _connectionInitializer.AuthenticateAsync(this, _connectionInitializerContext, operationContext.CancellationToken).ConfigureAwait(false); _description = _connectionInitializerContext.Description; _sendCompressorType = ChooseSendCompressorTypeIfAny(_description); helper.OpenedConnection(); @@ -326,20 +335,19 @@ private void InvalidateAuthenticator() } } - private IByteBuffer ReceiveBuffer(CancellationToken cancellationToken) + private IByteBuffer ReceiveBuffer(OperationContext operationContext) { try { var messageSizeBytes = new byte[4]; - var readTimeout = _stream.CanTimeout ? TimeSpan.FromMilliseconds(_stream.ReadTimeout) : Timeout.InfiniteTimeSpan; - _stream.ReadBytes(messageSizeBytes, 0, 4, readTimeout, cancellationToken); + _stream.ReadBytes(operationContext, messageSizeBytes, 0, 4); var messageSize = BinaryPrimitives.ReadInt32LittleEndian(messageSizeBytes); EnsureMessageSizeIsValid(messageSize); var inputBufferChunkSource = new InputBufferChunkSource(BsonChunkPool.Default); var buffer = ByteBufferFactory.Create(inputBufferChunkSource, messageSize); buffer.Length = messageSize; buffer.SetBytes(0, messageSizeBytes, 0, 4); - _stream.ReadBytes(buffer, 4, messageSize - 4, readTimeout, cancellationToken); + _stream.ReadBytes(operationContext, buffer, 4, messageSize - 4); _lastUsedAtUtc = DateTime.UtcNow; buffer.MakeReadOnly(); return buffer; @@ -352,9 +360,9 @@ private IByteBuffer ReceiveBuffer(CancellationToken cancellationToken) } } - private IByteBuffer ReceiveBuffer(int responseTo, CancellationToken cancellationToken) + private IByteBuffer ReceiveBuffer(OperationContext operationContext, int responseTo) { - using (var receiveLockRequest = new SemaphoreSlimRequest(_receiveLock, cancellationToken)) + using (var receiveLockRequest = new SemaphoreSlimRequest(_receiveLock, operationContext.RemainingTimeout, operationContext.CancellationToken)) { var messageTask = _dropbox.GetMessageAsync(responseTo); try @@ -370,7 +378,7 @@ private IByteBuffer ReceiveBuffer(int responseTo, CancellationToken cancellation { try { - var buffer = ReceiveBuffer(cancellationToken); + var buffer = ReceiveBuffer(operationContext); _dropbox.AddMessage(buffer); } catch (Exception ex) @@ -383,7 +391,7 @@ private IByteBuffer ReceiveBuffer(int responseTo, CancellationToken cancellation return _dropbox.RemoveMessage(responseTo); // also propagates exception if any } - cancellationToken.ThrowIfCancellationRequested(); + operationContext.ThrowIfTimedOutOrCanceled(); } } catch @@ -396,20 +404,19 @@ private IByteBuffer ReceiveBuffer(int responseTo, CancellationToken cancellation } } - private async Task ReceiveBufferAsync(CancellationToken cancellationToken) + private async Task ReceiveBufferAsync(OperationContext operationContext) { try { var messageSizeBytes = new byte[4]; - var readTimeout = _stream.CanTimeout ? TimeSpan.FromMilliseconds(_stream.ReadTimeout) : Timeout.InfiniteTimeSpan; - await _stream.ReadBytesAsync(messageSizeBytes, 0, 4, readTimeout, cancellationToken).ConfigureAwait(false); + await _stream.ReadBytesAsync(operationContext, messageSizeBytes, 0, 4).ConfigureAwait(false); var messageSize = BinaryPrimitives.ReadInt32LittleEndian(messageSizeBytes); EnsureMessageSizeIsValid(messageSize); var inputBufferChunkSource = new InputBufferChunkSource(BsonChunkPool.Default); var buffer = ByteBufferFactory.Create(inputBufferChunkSource, messageSize); buffer.Length = messageSize; buffer.SetBytes(0, messageSizeBytes, 0, 4); - await _stream.ReadBytesAsync(buffer, 4, messageSize - 4, readTimeout, cancellationToken).ConfigureAwait(false); + await _stream.ReadBytesAsync(operationContext, buffer, 4, messageSize - 4).ConfigureAwait(false); _lastUsedAtUtc = DateTime.UtcNow; buffer.MakeReadOnly(); return buffer; @@ -422,9 +429,9 @@ private async Task ReceiveBufferAsync(CancellationToken cancellatio } } - private async Task ReceiveBufferAsync(int responseTo, CancellationToken cancellationToken) + private async Task ReceiveBufferAsync(OperationContext operationContext, int responseTo) { - using (var receiveLockRequest = new SemaphoreSlimRequest(_receiveLock, cancellationToken)) + using (var receiveLockRequest = new SemaphoreSlimRequest(_receiveLock, operationContext.RemainingTimeout, operationContext.CancellationToken)) { var messageTask = _dropbox.GetMessageAsync(responseTo); try @@ -435,12 +442,12 @@ private async Task ReceiveBufferAsync(int responseTo, CancellationT return _dropbox.RemoveMessage(responseTo); // also propagates exception if any } - receiveLockRequest.Task.GetAwaiter().GetResult(); // propagate exceptions + await receiveLockRequest.Task.ConfigureAwait(false); // propagate exceptions while (true) { try { - var buffer = await ReceiveBufferAsync(cancellationToken).ConfigureAwait(false); + var buffer = await ReceiveBufferAsync(operationContext).ConfigureAwait(false); _dropbox.AddMessage(buffer); } catch (Exception ex) @@ -453,7 +460,7 @@ private async Task ReceiveBufferAsync(int responseTo, CancellationT return _dropbox.RemoveMessage(responseTo); // also propagates exception if any } - cancellationToken.ThrowIfCancellationRequested(); + operationContext.ThrowIfTimedOutOrCanceled(); } } catch @@ -467,21 +474,21 @@ private async Task ReceiveBufferAsync(int responseTo, CancellationT } public ResponseMessage ReceiveMessage( + OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, - MessageEncoderSettings messageEncoderSettings, - CancellationToken cancellationToken) + MessageEncoderSettings messageEncoderSettings) { Ensure.IsNotNull(encoderSelector, nameof(encoderSelector)); - ThrowIfCancelledOrDisposedOrNotOpen(cancellationToken); + ThrowIfCancelledOrDisposedOrNotOpen(operationContext); var helper = new ReceiveMessageHelper(this, responseTo, messageEncoderSettings, _compressorSource); try { helper.ReceivingMessage(); - using (var buffer = ReceiveBuffer(responseTo, cancellationToken)) + using (var buffer = ReceiveBuffer(operationContext, responseTo)) { - var message = helper.DecodeMessage(buffer, encoderSelector, cancellationToken); + var message = helper.DecodeMessage(operationContext, buffer, encoderSelector); helper.ReceivedMessage(buffer, message); return message; } @@ -494,22 +501,20 @@ public ResponseMessage ReceiveMessage( } } - public async Task ReceiveMessageAsync( - int responseTo, + public async Task ReceiveMessageAsync(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, - MessageEncoderSettings messageEncoderSettings, - CancellationToken cancellationToken) + MessageEncoderSettings messageEncoderSettings) { Ensure.IsNotNull(encoderSelector, nameof(encoderSelector)); - ThrowIfCancelledOrDisposedOrNotOpen(cancellationToken); + ThrowIfCancelledOrDisposedOrNotOpen(operationContext); var helper = new ReceiveMessageHelper(this, responseTo, messageEncoderSettings, _compressorSource); try { helper.ReceivingMessage(); - using (var buffer = await ReceiveBufferAsync(responseTo, cancellationToken).ConfigureAwait(false)) + using (var buffer = await ReceiveBufferAsync(operationContext, responseTo).ConfigureAwait(false)) { - var message = helper.DecodeMessage(buffer, encoderSelector, cancellationToken); + var message = helper.DecodeMessage(operationContext, buffer, encoderSelector); helper.ReceivedMessage(buffer, message); return message; } @@ -522,9 +527,9 @@ public async Task ReceiveMessageAsync( } } - private void SendBuffer(IByteBuffer buffer, CancellationToken cancellationToken) + private void SendBuffer(OperationContext operationContext, IByteBuffer buffer) { - _sendLock.Wait(cancellationToken); + _sendLock.Wait(operationContext.RemainingTimeout, operationContext.CancellationToken); try { if (_state.Value == State.Failed) @@ -534,8 +539,7 @@ private void SendBuffer(IByteBuffer buffer, CancellationToken cancellationToken) try { - var writeTimeout = _stream.CanTimeout ? TimeSpan.FromMilliseconds(_stream.WriteTimeout) : Timeout.InfiniteTimeSpan; - _stream.WriteBytes(buffer, 0, buffer.Length, writeTimeout, cancellationToken); + _stream.WriteBytes(operationContext, buffer, 0, buffer.Length); _lastUsedAtUtc = DateTime.UtcNow; } catch (Exception ex) @@ -551,9 +555,9 @@ private void SendBuffer(IByteBuffer buffer, CancellationToken cancellationToken) } } - private async Task SendBufferAsync(IByteBuffer buffer, CancellationToken cancellationToken) + private async Task SendBufferAsync(OperationContext operationContext, IByteBuffer buffer) { - await _sendLock.WaitAsync(cancellationToken).ConfigureAwait(false); + await _sendLock.WaitAsync(operationContext.RemainingTimeout, operationContext.CancellationToken).ConfigureAwait(false); try { if (_state.Value == State.Failed) @@ -563,8 +567,7 @@ private async Task SendBufferAsync(IByteBuffer buffer, CancellationToken cancell try { - var writeTimeout = _stream.CanTimeout ? TimeSpan.FromMilliseconds(_stream.WriteTimeout) : Timeout.InfiniteTimeSpan; - await _stream.WriteBytesAsync(buffer, 0, buffer.Length, writeTimeout, cancellationToken).ConfigureAwait(false); + await _stream.WriteBytesAsync(operationContext, buffer, 0, buffer.Length).ConfigureAwait(false); _lastUsedAtUtc = DateTime.UtcNow; } catch (Exception ex) @@ -580,16 +583,16 @@ private async Task SendBufferAsync(IByteBuffer buffer, CancellationToken cancell } } - public void SendMessage(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public void SendMessage(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { Ensure.IsNotNull(message, nameof(message)); - ThrowIfCancelledOrDisposedOrNotOpen(cancellationToken); + ThrowIfCancelledOrDisposedOrNotOpen(operationContext); var helper = new SendMessageHelper(this, message, messageEncoderSettings); try { helper.EncodingMessage(); - using (var uncompressedBuffer = helper.EncodeMessage(cancellationToken, out var sentMessage)) + using (var uncompressedBuffer = helper.EncodeMessage(operationContext, out var sentMessage)) { helper.SendingMessage(uncompressedBuffer); int sentLength; @@ -597,13 +600,13 @@ public void SendMessage(RequestMessage message, MessageEncoderSettings messageEn { using (var compressedBuffer = CompressMessage(sentMessage, uncompressedBuffer, messageEncoderSettings)) { - SendBuffer(compressedBuffer, cancellationToken); + SendBuffer(operationContext, compressedBuffer); sentLength = compressedBuffer.Length; } } else { - SendBuffer(uncompressedBuffer, cancellationToken); + SendBuffer(operationContext, uncompressedBuffer); sentLength = uncompressedBuffer.Length; } helper.SentMessage(sentLength); @@ -617,16 +620,16 @@ public void SendMessage(RequestMessage message, MessageEncoderSettings messageEn } } - public async Task SendMessageAsync(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public async Task SendMessageAsync(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { Ensure.IsNotNull(message, nameof(message)); - ThrowIfCancelledOrDisposedOrNotOpen(cancellationToken); + ThrowIfCancelledOrDisposedOrNotOpen(operationContext); var helper = new SendMessageHelper(this, message, messageEncoderSettings); try { helper.EncodingMessage(); - using (var uncompressedBuffer = helper.EncodeMessage(cancellationToken, out var sentMessage)) + using (var uncompressedBuffer = helper.EncodeMessage(operationContext, out var sentMessage)) { helper.SendingMessage(uncompressedBuffer); int sentLength; @@ -634,13 +637,13 @@ public async Task SendMessageAsync(RequestMessage message, MessageEncoderSetting { using (var compressedBuffer = CompressMessage(sentMessage, uncompressedBuffer, messageEncoderSettings)) { - await SendBufferAsync(compressedBuffer, cancellationToken).ConfigureAwait(false); + await SendBufferAsync(operationContext, compressedBuffer).ConfigureAwait(false); sentLength = compressedBuffer.Length; } } else { - await SendBufferAsync(uncompressedBuffer, cancellationToken).ConfigureAwait(false); + await SendBufferAsync(operationContext, uncompressedBuffer).ConfigureAwait(false); sentLength = uncompressedBuffer.Length; } helper.SentMessage(sentLength); @@ -717,15 +720,15 @@ private void CompressMessage( compressedMessageEncoder.WriteMessage(compressedMessage); } - private void ThrowIfCancelledOrDisposed(CancellationToken cancellationToken = default) + private void ThrowIfCancelledOrDisposed(OperationContext operationContext) { - cancellationToken.ThrowIfCancellationRequested(); + operationContext.ThrowIfTimedOutOrCanceled(); ThrowIfDisposed(); } - private void ThrowIfCancelledOrDisposedOrNotOpen(CancellationToken cancellationToken) + private void ThrowIfCancelledOrDisposedOrNotOpen(OperationContext operationContext) { - ThrowIfCancelledOrDisposed(cancellationToken); + ThrowIfCancelledOrDisposed(operationContext); if (_state.Value == State.Failed) { throw new MongoConnectionClosedException(_connectionId); @@ -905,9 +908,9 @@ public ReceiveMessageHelper(BinaryConnection connection, int responseTo, Message _messageEncoderSettings = messageEncoderSettings; } - public ResponseMessage DecodeMessage(IByteBuffer buffer, IMessageEncoderSelector encoderSelector, CancellationToken cancellationToken) + public ResponseMessage DecodeMessage(OperationContext operationContext, IByteBuffer buffer, IMessageEncoderSelector encoderSelector) { - cancellationToken.ThrowIfCancellationRequested(); + operationContext.ThrowIfTimedOutOrCanceled(); _stopwatch.Stop(); _networkDuration = _stopwatch.Elapsed; @@ -992,10 +995,10 @@ public SendMessageHelper(BinaryConnection connection, RequestMessage message, Me _commandStopwatch = Stopwatch.StartNew(); } - public IByteBuffer EncodeMessage(CancellationToken cancellationToken, out RequestMessage sentMessage) + public IByteBuffer EncodeMessage(OperationContext operationContext, out RequestMessage sentMessage) { sentMessage = null; - cancellationToken.ThrowIfCancellationRequested(); + operationContext.ThrowIfTimedOutOrCanceled(); var serializationStopwatch = Stopwatch.StartNew(); var outputBufferChunkSource = new OutputBufferChunkSource(BsonChunkPool.Default); @@ -1012,7 +1015,7 @@ public IByteBuffer EncodeMessage(CancellationToken cancellationToken, out Reques // Encoding messages includes serializing the // documents, so encoding message could be expensive // and worthy of us honoring cancellation here. - cancellationToken.ThrowIfCancellationRequested(); + operationContext.ThrowIfTimedOutOrCanceled(); buffer.Length = (int)stream.Length; buffer.MakeReadOnly(); diff --git a/src/MongoDB.Driver/Core/Connections/ConnectionInitializer.cs b/src/MongoDB.Driver/Core/Connections/ConnectionInitializer.cs index 851fd96b82f..fdf95dbfe05 100644 --- a/src/MongoDB.Driver/Core/Connections/ConnectionInitializer.cs +++ b/src/MongoDB.Driver/Core/Connections/ConnectionInitializer.cs @@ -68,7 +68,8 @@ public ConnectionInitializerContext Authenticate(IConnection connection, Connect try { var getLastErrorProtocol = CreateGetLastErrorProtocol(_serverApi); - var getLastErrorResult = getLastErrorProtocol.Execute(connection, cancellationToken); + // TODO: CSOT: Implement operation context support for MongoDB Handshake + var getLastErrorResult = getLastErrorProtocol.Execute(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection); description = UpdateConnectionIdWithServerValue(description, getLastErrorResult); } @@ -103,8 +104,9 @@ public async Task AuthenticateAsync(IConnection co try { var getLastErrorProtocol = CreateGetLastErrorProtocol(_serverApi); + // TODO: CSOT: Implement operation context support for MongoDB Handshake var getLastErrorResult = await getLastErrorProtocol - .ExecuteAsync(connection, cancellationToken) + .ExecuteAsync(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection) .ConfigureAwait(false); description = UpdateConnectionIdWithServerValue(description, getLastErrorResult); diff --git a/src/MongoDB.Driver/Core/Connections/HelloHelper.cs b/src/MongoDB.Driver/Core/Connections/HelloHelper.cs index 70194498f5c..2ebebe12078 100644 --- a/src/MongoDB.Driver/Core/Connections/HelloHelper.cs +++ b/src/MongoDB.Driver/Core/Connections/HelloHelper.cs @@ -90,7 +90,8 @@ internal static HelloResult GetResult( { try { - var helloResultDocument = helloProtocol.Execute(connection, cancellationToken); + // TODO: CSOT: Implement operation context support for MongoDB Handshake + var helloResultDocument = helloProtocol.Execute(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection); return new HelloResult(helloResultDocument); } catch (MongoCommandException ex) when (ex.Code == 11) @@ -109,7 +110,8 @@ internal static async Task GetResultAsync( { try { - var helloResultDocument = await helloProtocol.ExecuteAsync(connection, cancellationToken).ConfigureAwait(false); + // TODO: CSOT: Implement operation context support for MongoDB Handshake + var helloResultDocument = await helloProtocol.ExecuteAsync(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), connection).ConfigureAwait(false); return new HelloResult(helloResultDocument); } catch (MongoCommandException ex) when (ex.Code == 11) diff --git a/src/MongoDB.Driver/Core/Connections/IConnection.cs b/src/MongoDB.Driver/Core/Connections/IConnection.cs index a82dfc3eda3..5a7af78169f 100644 --- a/src/MongoDB.Driver/Core/Connections/IConnection.cs +++ b/src/MongoDB.Driver/Core/Connections/IConnection.cs @@ -14,7 +14,6 @@ */ using System; -using System.Collections.Generic; using System.Net; using System.Threading; using System.Threading.Tasks; @@ -33,15 +32,16 @@ internal interface IConnection : IDisposable bool IsExpired { get; } ConnectionSettings Settings { get; } + // TODO: CSOT: remove this in scope of MongoDB Handshake void SetReadTimeout(TimeSpan timeout); - void Open(CancellationToken cancellationToken); - Task OpenAsync(CancellationToken cancellationToken); + void Open(OperationContext operationContext); + Task OpenAsync(OperationContext operationContext); void Reauthenticate(CancellationToken cancellationToken); Task ReauthenticateAsync(CancellationToken cancellationToken); - ResponseMessage ReceiveMessage(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken); - Task ReceiveMessageAsync(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken); - void SendMessage(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken); - Task SendMessageAsync(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken); + ResponseMessage ReceiveMessage(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings); + Task ReceiveMessageAsync(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings); + void SendMessage(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings); + Task SendMessageAsync(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings); } internal interface IConnectionHandle : IConnection diff --git a/src/MongoDB.Driver/Core/Connections/IStreamFactory.cs b/src/MongoDB.Driver/Core/Connections/IStreamFactory.cs index e8540e391f4..cd378d8cbd1 100644 --- a/src/MongoDB.Driver/Core/Connections/IStreamFactory.cs +++ b/src/MongoDB.Driver/Core/Connections/IStreamFactory.cs @@ -13,6 +13,7 @@ * limitations under the License. */ +using System; using System.IO; using System.Net; using System.Threading; diff --git a/src/MongoDB.Driver/Core/Misc/SemaphoreSlimRequest.cs b/src/MongoDB.Driver/Core/Misc/SemaphoreSlimRequest.cs index eed0ee4a21e..6ef047c725a 100644 --- a/src/MongoDB.Driver/Core/Misc/SemaphoreSlimRequest.cs +++ b/src/MongoDB.Driver/Core/Misc/SemaphoreSlimRequest.cs @@ -39,12 +39,23 @@ public sealed class SemaphoreSlimRequest : IDisposable /// The semaphore. /// The cancellation token. public SemaphoreSlimRequest(SemaphoreSlim semaphore, CancellationToken cancellationToken) + : this(semaphore, Timeout.InfiniteTimeSpan, cancellationToken) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The semaphore. + /// The timeout. + /// The cancellation token. + public SemaphoreSlimRequest(SemaphoreSlim semaphore, TimeSpan timeout, CancellationToken cancellationToken) { _semaphore = Ensure.IsNotNull(semaphore, nameof(semaphore)); _disposeCancellationTokenSource = new CancellationTokenSource(); _linkedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _disposeCancellationTokenSource.Token); - _task = semaphore.WaitAsync(_linkedCancellationTokenSource.Token); + _task = semaphore.WaitAsync(timeout, _linkedCancellationTokenSource.Token); } // public properties @@ -56,7 +67,7 @@ public SemaphoreSlimRequest(SemaphoreSlim semaphore, CancellationToken cancellat /// public Task Task => _task; - // public methods + // public methods /// public void Dispose() { diff --git a/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs b/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs index 1cb4cd5181f..c135a232dec 100644 --- a/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs +++ b/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs @@ -109,16 +109,20 @@ public static async Task ReadAsync(this Stream stream, byte[] buffer, int o } } - public static void ReadBytes(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + public static void ReadBytes(this Stream stream, OperationContext operationContext, byte[] buffer, int offset, int count) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); Ensure.IsBetween(offset, 0, buffer.Length, nameof(offset)); Ensure.IsBetween(count, 0, buffer.Length - offset, nameof(count)); + var hasOperationTimeout = operationContext.IsOperationTimeoutConfigured(); + var streamTimeout = stream.CanTimeout ? TimeSpan.FromMilliseconds(stream.ReadTimeout) : Timeout.InfiniteTimeSpan; + while (count > 0) { - var bytesRead = stream.Read(buffer, offset, count, timeout, cancellationToken); + var timeout = hasOperationTimeout ? operationContext.RemainingTimeout : streamTimeout; + var bytesRead = stream.Read(buffer, offset, count, timeout, operationContext.CancellationToken); if (bytesRead == 0) { throw new EndOfStreamException(); @@ -128,18 +132,22 @@ public static void ReadBytes(this Stream stream, byte[] buffer, int offset, int } } - public static void ReadBytes(this Stream stream, IByteBuffer buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + public static void ReadBytes(this Stream stream, OperationContext operationContext, IByteBuffer buffer, int offset, int count) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); Ensure.IsBetween(offset, 0, buffer.Length, nameof(offset)); Ensure.IsBetween(count, 0, buffer.Length - offset, nameof(count)); + var hasOperationTimeout = operationContext.IsOperationTimeoutConfigured(); + var streamTimeout = stream.CanTimeout ? TimeSpan.FromMilliseconds(stream.ReadTimeout) : Timeout.InfiniteTimeSpan; + while (count > 0) { + var timeout = hasOperationTimeout ? operationContext.RemainingTimeout : streamTimeout; var backingBytes = buffer.AccessBackingBytes(offset); var bytesToRead = Math.Min(count, backingBytes.Count); - var bytesRead = stream.Read(backingBytes.Array, backingBytes.Offset, bytesToRead, timeout, cancellationToken); + var bytesRead = stream.Read(backingBytes.Array, backingBytes.Offset, bytesToRead, timeout, operationContext.CancellationToken); if (bytesRead == 0) { throw new EndOfStreamException(); @@ -149,16 +157,20 @@ public static void ReadBytes(this Stream stream, IByteBuffer buffer, int offset, } } - public static async Task ReadBytesAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + public static async Task ReadBytesAsync(this Stream stream, OperationContext operationContext, byte[] buffer, int offset, int count) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); Ensure.IsBetween(offset, 0, buffer.Length, nameof(offset)); Ensure.IsBetween(count, 0, buffer.Length - offset, nameof(count)); + var hasOperationTimeout = operationContext.IsOperationTimeoutConfigured(); + var streamTimeout = stream.CanTimeout ? TimeSpan.FromMilliseconds(stream.ReadTimeout) : Timeout.InfiniteTimeSpan; + while (count > 0) { - var bytesRead = await stream.ReadAsync(buffer, offset, count, timeout, cancellationToken).ConfigureAwait(false); + var timeout = hasOperationTimeout ? operationContext.RemainingTimeout : streamTimeout; + var bytesRead = await stream.ReadAsync(buffer, offset, count, timeout, operationContext.CancellationToken).ConfigureAwait(false); if (bytesRead == 0) { throw new EndOfStreamException(); @@ -168,18 +180,22 @@ public static async Task ReadBytesAsync(this Stream stream, byte[] buffer, int o } } - public static async Task ReadBytesAsync(this Stream stream, IByteBuffer buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + public static async Task ReadBytesAsync(this Stream stream, OperationContext operationContext, IByteBuffer buffer, int offset, int count) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); Ensure.IsBetween(offset, 0, buffer.Length, nameof(offset)); Ensure.IsBetween(count, 0, buffer.Length - offset, nameof(count)); + var hasOperationTimeout = operationContext.IsOperationTimeoutConfigured(); + var streamTimeout = stream.CanTimeout ? TimeSpan.FromMilliseconds(stream.ReadTimeout) : Timeout.InfiniteTimeSpan; + while (count > 0) { + var timeout = hasOperationTimeout ? operationContext.RemainingTimeout : streamTimeout; var backingBytes = buffer.AccessBackingBytes(offset); var bytesToRead = Math.Min(count, backingBytes.Count); - var bytesRead = await stream.ReadAsync(backingBytes.Array, backingBytes.Offset, bytesToRead, timeout, cancellationToken).ConfigureAwait(false); + var bytesRead = await stream.ReadAsync(backingBytes.Array, backingBytes.Offset, bytesToRead, timeout, operationContext.CancellationToken).ConfigureAwait(false); if (bytesRead == 0) { throw new EndOfStreamException(); @@ -264,36 +280,43 @@ public static async Task WriteAsync(this Stream stream, byte[] buffer, int offse } } - public static void WriteBytes(this Stream stream, IByteBuffer buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + public static void WriteBytes(this Stream stream, OperationContext operationContext, IByteBuffer buffer, int offset, int count) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); Ensure.IsBetween(offset, 0, buffer.Length, nameof(offset)); Ensure.IsBetween(count, 0, buffer.Length - offset, nameof(count)); + var hasOperationTimeout = operationContext.IsOperationTimeoutConfigured(); + var streamTimeout = stream.CanTimeout ? TimeSpan.FromMilliseconds(stream.WriteTimeout) : Timeout.InfiniteTimeSpan; + while (count > 0) { - cancellationToken.ThrowIfCancellationRequested(); + var timeout = hasOperationTimeout ? operationContext.RemainingTimeout : streamTimeout; var backingBytes = buffer.AccessBackingBytes(offset); var bytesToWrite = Math.Min(count, backingBytes.Count); - stream.Write(backingBytes.Array, backingBytes.Offset, bytesToWrite, timeout, cancellationToken); + stream.Write(backingBytes.Array, backingBytes.Offset, bytesToWrite, timeout, operationContext.CancellationToken); offset += bytesToWrite; count -= bytesToWrite; } } - public static async Task WriteBytesAsync(this Stream stream, IByteBuffer buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + public static async Task WriteBytesAsync(this Stream stream, OperationContext operationContext, IByteBuffer buffer, int offset, int count) { Ensure.IsNotNull(stream, nameof(stream)); Ensure.IsNotNull(buffer, nameof(buffer)); Ensure.IsBetween(offset, 0, buffer.Length, nameof(offset)); Ensure.IsBetween(count, 0, buffer.Length - offset, nameof(count)); + var hasOperationTimeout = operationContext.IsOperationTimeoutConfigured(); + var streamTimeout = stream.CanTimeout ? TimeSpan.FromMilliseconds(stream.WriteTimeout) : Timeout.InfiniteTimeSpan; + while (count > 0) { + var timeout = hasOperationTimeout ? operationContext.RemainingTimeout : streamTimeout; var backingBytes = buffer.AccessBackingBytes(offset); var bytesToWrite = Math.Min(count, backingBytes.Count); - await stream.WriteAsync(backingBytes.Array, backingBytes.Offset, bytesToWrite, timeout, cancellationToken).ConfigureAwait(false); + await stream.WriteAsync(backingBytes.Array, backingBytes.Offset, bytesToWrite, timeout, operationContext.CancellationToken).ConfigureAwait(false); offset += bytesToWrite; count -= bytesToWrite; } diff --git a/src/MongoDB.Driver/Core/Operations/AggregateToCollectionOperation.cs b/src/MongoDB.Driver/Core/Operations/AggregateToCollectionOperation.cs index 79c684d9752..25e2a45b014 100644 --- a/src/MongoDB.Driver/Core/Operations/AggregateToCollectionOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/AggregateToCollectionOperation.cs @@ -155,7 +155,7 @@ public BsonDocument Execute(OperationContext operationContext, IWriteBinding bin using (BeginOperation()) using (var channelSource = binding.GetWriteChannelSource(operationContext, mayUseSecondary)) using (var channel = channelSource.GetChannel(operationContext)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session, channel.ConnectionDescription, mayUseSecondary.EffectiveReadPreference); return operation.Execute(operationContext, channelBinding); @@ -170,7 +170,7 @@ public async Task ExecuteAsync(OperationContext operationContext, using (BeginOperation()) using (var channelSource = await binding.GetWriteChannelSourceAsync(operationContext, mayUseSecondary).ConfigureAwait(false)) using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session, channel.ConnectionDescription, mayUseSecondary.EffectiveReadPreference); return await operation.ExecuteAsync(operationContext, channelBinding).ConfigureAwait(false); diff --git a/src/MongoDB.Driver/Core/Operations/AsyncCursor.cs b/src/MongoDB.Driver/Core/Operations/AsyncCursor.cs index bb0b2daa92a..dd5c7e0ba9f 100644 --- a/src/MongoDB.Driver/Core/Operations/AsyncCursor.cs +++ b/src/MongoDB.Driver/Core/Operations/AsyncCursor.cs @@ -219,7 +219,10 @@ private CursorBatch ExecuteGetMoreCommand(IChannelHandle channel, Can BsonDocument result; try { + // TODO: CSOT: Implement operation context support for Cursors + var operationContext = new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken); result = channel.Command( + operationContext, _channelSource.Session, null, // readPreference _collectionNamespace.DatabaseNamespace, @@ -230,8 +233,7 @@ private CursorBatch ExecuteGetMoreCommand(IChannelHandle channel, Can null, // postWriteAction CommandResponseHandling.Return, __getMoreCommandResultSerializer, - _messageEncoderSettings, - cancellationToken); + _messageEncoderSettings); } catch (MongoCommandException ex) when (IsMongoCursorNotFoundException(ex)) { @@ -247,7 +249,10 @@ private async Task> ExecuteGetMoreCommandAsync(IChannelHa BsonDocument result; try { + // TODO: CSOT: Implement operation context support for Cursors + var operationContext = new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken); result = await channel.CommandAsync( + operationContext, _channelSource.Session, null, // readPreference _collectionNamespace.DatabaseNamespace, @@ -258,8 +263,7 @@ private async Task> ExecuteGetMoreCommandAsync(IChannelHa null, // postWriteAction CommandResponseHandling.Return, __getMoreCommandResultSerializer, - _messageEncoderSettings, - cancellationToken).ConfigureAwait(false); + _messageEncoderSettings).ConfigureAwait(false); } catch (MongoCommandException ex) when (IsMongoCursorNotFoundException(ex)) { @@ -271,8 +275,11 @@ private async Task> ExecuteGetMoreCommandAsync(IChannelHa private void ExecuteKillCursorsCommand(IChannelHandle channel, CancellationToken cancellationToken) { + // TODO: CSOT: Implement operation context support for Cursors + var operationContext = new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken); var command = CreateKillCursorsCommand(); var result = channel.Command( + operationContext, _channelSource.Session, null, // readPreference _collectionNamespace.DatabaseNamespace, @@ -283,16 +290,18 @@ private void ExecuteKillCursorsCommand(IChannelHandle channel, CancellationToken null, // postWriteAction CommandResponseHandling.Return, BsonDocumentSerializer.Instance, - _messageEncoderSettings, - cancellationToken); + _messageEncoderSettings); ThrowIfKillCursorsCommandFailed(result, channel.ConnectionDescription.ConnectionId); } private async Task ExecuteKillCursorsCommandAsync(IChannelHandle channel, CancellationToken cancellationToken) { + // TODO: CSOT: Implement operation context support for Cursors + var operationContext = new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken); var command = CreateKillCursorsCommand(); var result = await channel.CommandAsync( + operationContext, _channelSource.Session, null, // readPreference _collectionNamespace.DatabaseNamespace, @@ -303,8 +312,7 @@ private async Task ExecuteKillCursorsCommandAsync(IChannelHandle channel, Cancel null, // postWriteAction CommandResponseHandling.Return, BsonDocumentSerializer.Instance, - _messageEncoderSettings, - cancellationToken) + _messageEncoderSettings) .ConfigureAwait(false); ThrowIfKillCursorsCommandFailed(result, channel.ConnectionDescription.ConnectionId); diff --git a/src/MongoDB.Driver/Core/Operations/CommandOperationBase.cs b/src/MongoDB.Driver/Core/Operations/CommandOperationBase.cs index 62eca2d6992..e5f94ebe200 100644 --- a/src/MongoDB.Driver/Core/Operations/CommandOperationBase.cs +++ b/src/MongoDB.Driver/Core/Operations/CommandOperationBase.cs @@ -13,7 +13,6 @@ * limitations under the License. */ -using System.Threading; using System.Threading.Tasks; using MongoDB.Bson; using MongoDB.Bson.IO; @@ -85,11 +84,12 @@ public IBsonSerializer ResultSerializer get { return _resultSerializer; } } - protected TCommandResult ExecuteProtocol(IChannelHandle channel, ICoreSessionHandle session, ReadPreference readPreference, CancellationToken cancellationToken) + protected TCommandResult ExecuteProtocol(OperationContext operationContext, IChannelHandle channel, ICoreSessionHandle session, ReadPreference readPreference) { var additionalOptions = GetEffectiveAdditionalOptions(); return channel.Command( + operationContext, session, readPreference, _databaseNamespace, @@ -100,8 +100,7 @@ protected TCommandResult ExecuteProtocol(IChannelHandle channel, ICoreSessionHan null, // postWriteAction, CommandResponseHandling.Return, _resultSerializer, - _messageEncoderSettings, - cancellationToken); + _messageEncoderSettings); } protected TCommandResult ExecuteProtocol( @@ -112,15 +111,16 @@ protected TCommandResult ExecuteProtocol( { using (var channel = channelSource.GetChannel(operationContext)) { - return ExecuteProtocol(channel, session, readPreference, operationContext.CancellationToken); + return ExecuteProtocol(operationContext, channel, session, readPreference); } } - protected Task ExecuteProtocolAsync(IChannelHandle channel, ICoreSessionHandle session, ReadPreference readPreference, CancellationToken cancellationToken) + protected Task ExecuteProtocolAsync(OperationContext operationContext, IChannelHandle channel, ICoreSessionHandle session, ReadPreference readPreference) { var additionalOptions = GetEffectiveAdditionalOptions(); return channel.CommandAsync( + operationContext, session, readPreference, _databaseNamespace, @@ -131,8 +131,7 @@ protected Task ExecuteProtocolAsync(IChannelHandle channel, ICor null, // postWriteAction, CommandResponseHandling.Return, _resultSerializer, - _messageEncoderSettings, - cancellationToken); + _messageEncoderSettings); } protected async Task ExecuteProtocolAsync( @@ -143,7 +142,7 @@ protected async Task ExecuteProtocolAsync( { using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) { - return await ExecuteProtocolAsync(channel, session, readPreference, operationContext.CancellationToken).ConfigureAwait(false); + return await ExecuteProtocolAsync(operationContext, channel, session, readPreference).ConfigureAwait(false); } } diff --git a/src/MongoDB.Driver/Core/Operations/CreateCollectionOperation.cs b/src/MongoDB.Driver/Core/Operations/CreateCollectionOperation.cs index 31b2f7be994..d4cc0d92923 100644 --- a/src/MongoDB.Driver/Core/Operations/CreateCollectionOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/CreateCollectionOperation.cs @@ -282,7 +282,7 @@ public BsonDocument Execute(OperationContext operationContext, IWriteBinding bin using (var channel = channelSource.GetChannel(operationContext)) { EnsureServerIsValid(channel.ConnectionDescription.MaxWireVersion); - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session); return operation.Execute(operationContext, channelBinding); @@ -299,7 +299,7 @@ public async Task ExecuteAsync(OperationContext operationContext, using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) { EnsureServerIsValid(channel.ConnectionDescription.MaxWireVersion); - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session); return await operation.ExecuteAsync(operationContext, channelBinding).ConfigureAwait(false); diff --git a/src/MongoDB.Driver/Core/Operations/CreateIndexesOperation.cs b/src/MongoDB.Driver/Core/Operations/CreateIndexesOperation.cs index f4e071950ad..ef73bbda75d 100644 --- a/src/MongoDB.Driver/Core/Operations/CreateIndexesOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/CreateIndexesOperation.cs @@ -91,7 +91,7 @@ public BsonDocument Execute(OperationContext operationContext, IWriteBinding bin using (BeginOperation()) using (var channelSource = binding.GetWriteChannelSource(operationContext)) using (var channel = channelSource.GetChannel(operationContext)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session, channel.ConnectionDescription); return operation.Execute(operationContext, channelBinding); @@ -103,7 +103,7 @@ public async Task ExecuteAsync(OperationContext operationContext, using (BeginOperation()) using (var channelSource = await binding.GetWriteChannelSourceAsync(operationContext).ConfigureAwait(false)) using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session, channel.ConnectionDescription); return await operation.ExecuteAsync(operationContext, channelBinding).ConfigureAwait(false); diff --git a/src/MongoDB.Driver/Core/Operations/CreateSearchIndexesOperation.cs b/src/MongoDB.Driver/Core/Operations/CreateSearchIndexesOperation.cs index edc8ae04f77..481b6a9f3ec 100644 --- a/src/MongoDB.Driver/Core/Operations/CreateSearchIndexesOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/CreateSearchIndexesOperation.cs @@ -59,7 +59,7 @@ public BsonDocument Execute(OperationContext operationContext, IWriteBinding bin using (EventContext.BeginOperation("createSearchIndexes")) using (var channelSource = binding.GetWriteChannelSource(operationContext)) using (var channel = channelSource.GetChannel(operationContext)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(); return operation.Execute(operationContext, channelBinding); @@ -72,7 +72,7 @@ public async Task ExecuteAsync(OperationContext operationContext, using (EventContext.BeginOperation("createSearchIndexes")) using (var channelSource = await binding.GetWriteChannelSourceAsync(operationContext).ConfigureAwait(false)) using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(); return await operation.ExecuteAsync(operationContext, channelBinding).ConfigureAwait(false); diff --git a/src/MongoDB.Driver/Core/Operations/CreateViewOperation.cs b/src/MongoDB.Driver/Core/Operations/CreateViewOperation.cs index 64a5d954bc2..18375035997 100644 --- a/src/MongoDB.Driver/Core/Operations/CreateViewOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/CreateViewOperation.cs @@ -92,7 +92,7 @@ public BsonDocument Execute(OperationContext operationContext, IWriteBinding bin using (var channelSource = binding.GetWriteChannelSource(operationContext)) using (var channel = channelSource.GetChannel(operationContext)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session, channel.ConnectionDescription); return operation.Execute(operationContext, channelBinding); @@ -105,7 +105,7 @@ public async Task ExecuteAsync(OperationContext operationContext, using (var channelSource = await binding.GetWriteChannelSourceAsync(operationContext).ConfigureAwait(false)) using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session, channel.ConnectionDescription); return await operation.ExecuteAsync(operationContext, channelBinding).ConfigureAwait(false); diff --git a/src/MongoDB.Driver/Core/Operations/DropCollectionOperation.cs b/src/MongoDB.Driver/Core/Operations/DropCollectionOperation.cs index d667367575a..30d0612d324 100644 --- a/src/MongoDB.Driver/Core/Operations/DropCollectionOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/DropCollectionOperation.cs @@ -102,7 +102,7 @@ public BsonDocument Execute(OperationContext operationContext, IWriteBinding bin using (BeginOperation()) using (var channelSource = binding.GetWriteChannelSource(operationContext)) using (var channel = channelSource.GetChannel(operationContext)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session); BsonDocument result; @@ -129,7 +129,7 @@ public async Task ExecuteAsync(OperationContext operationContext, using (BeginOperation()) using (var channelSource = await binding.GetWriteChannelSourceAsync(operationContext).ConfigureAwait(false)) using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session); BsonDocument result; diff --git a/src/MongoDB.Driver/Core/Operations/DropDatabaseOperation.cs b/src/MongoDB.Driver/Core/Operations/DropDatabaseOperation.cs index 39be674d49c..778b2434293 100644 --- a/src/MongoDB.Driver/Core/Operations/DropDatabaseOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/DropDatabaseOperation.cs @@ -71,7 +71,7 @@ public BsonDocument Execute(OperationContext operationContext, IWriteBinding bin using (BeginOperation()) using (var channelSource = binding.GetWriteChannelSource(operationContext)) using (var channel = channelSource.GetChannel(operationContext)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session); return operation.Execute(operationContext, channelBinding); @@ -85,7 +85,7 @@ public async Task ExecuteAsync(OperationContext operationContext, using (BeginOperation()) using (var channelSource = await binding.GetWriteChannelSourceAsync(operationContext).ConfigureAwait(false)) using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session); return await operation.ExecuteAsync(operationContext, channelBinding).ConfigureAwait(false); diff --git a/src/MongoDB.Driver/Core/Operations/DropIndexOperation.cs b/src/MongoDB.Driver/Core/Operations/DropIndexOperation.cs index ef68ad071c5..27788bae392 100644 --- a/src/MongoDB.Driver/Core/Operations/DropIndexOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/DropIndexOperation.cs @@ -104,7 +104,7 @@ public BsonDocument Execute(OperationContext operationContext, IWriteBinding bin using (BeginOperation()) using (var channelSource = binding.GetWriteChannelSource(operationContext)) using (var channel = channelSource.GetChannel(operationContext)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session); BsonDocument result; @@ -131,7 +131,7 @@ public async Task ExecuteAsync(OperationContext operationContext, using (BeginOperation()) using (var channelSource = await binding.GetWriteChannelSourceAsync(operationContext).ConfigureAwait(false)) using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session); BsonDocument result; diff --git a/src/MongoDB.Driver/Core/Operations/DropSearchIndexOperation.cs b/src/MongoDB.Driver/Core/Operations/DropSearchIndexOperation.cs index aff890be381..5322c69b57e 100644 --- a/src/MongoDB.Driver/Core/Operations/DropSearchIndexOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/DropSearchIndexOperation.cs @@ -69,7 +69,7 @@ public BsonDocument Execute(OperationContext operationContext, IWriteBinding bin using (EventContext.BeginOperation("dropSearchIndex")) using (var channelSource = binding.GetWriteChannelSource(operationContext)) using (var channel = channelSource.GetChannel(operationContext)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(); @@ -92,7 +92,7 @@ public async Task ExecuteAsync(OperationContext operationContext, using (EventContext.BeginOperation("dropSearchIndex")) using (var channelSource = await binding.GetWriteChannelSourceAsync(operationContext).ConfigureAwait(false)) using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(); diff --git a/src/MongoDB.Driver/Core/Operations/EndTransactionOperation.cs b/src/MongoDB.Driver/Core/Operations/EndTransactionOperation.cs index 0544b2498c3..d3ca20f66c7 100644 --- a/src/MongoDB.Driver/Core/Operations/EndTransactionOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/EndTransactionOperation.cs @@ -56,7 +56,7 @@ public virtual BsonDocument Execute(OperationContext operationContext, IReadBind using (var channelSource = binding.GetReadChannelSource(operationContext)) using (var channel = channelSource.GetChannel(operationContext)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(); return operation.Execute(operationContext, channelBinding); @@ -69,7 +69,7 @@ public virtual async Task ExecuteAsync(OperationContext operationC using (var channelSource = await binding.GetReadChannelSourceAsync(operationContext).ConfigureAwait(false)) using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(); return await operation.ExecuteAsync(operationContext, channelBinding).ConfigureAwait(false); diff --git a/src/MongoDB.Driver/Core/Operations/FindAndModifyOperationBase.cs b/src/MongoDB.Driver/Core/Operations/FindAndModifyOperationBase.cs index d32198f81af..574e969ae12 100644 --- a/src/MongoDB.Driver/Core/Operations/FindAndModifyOperationBase.cs +++ b/src/MongoDB.Driver/Core/Operations/FindAndModifyOperationBase.cs @@ -122,7 +122,7 @@ public TResult ExecuteAttempt(OperationContext operationContext, RetryableWriteC var channelSource = context.ChannelSource; var channel = context.Channel; - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session, channel.ConnectionDescription, transactionNumber); using (var rawBsonDocument = operation.Execute(operationContext, channelBinding)) @@ -138,7 +138,7 @@ public async Task ExecuteAttemptAsync(OperationContext operationContext var channelSource = context.ChannelSource; var channel = context.Channel; - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session, channel.ConnectionDescription, transactionNumber); using (var rawBsonDocument = await operation.ExecuteAsync(operationContext, channelBinding).ConfigureAwait(false)) diff --git a/src/MongoDB.Driver/Core/Operations/GroupOperation.cs b/src/MongoDB.Driver/Core/Operations/GroupOperation.cs index 4c987e09300..d4d24e5bad3 100644 --- a/src/MongoDB.Driver/Core/Operations/GroupOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/GroupOperation.cs @@ -143,7 +143,7 @@ public IEnumerable Execute(OperationContext operationContext, IReadBind Ensure.IsNotNull(binding, nameof(binding)); using (var channelSource = binding.GetReadChannelSource(operationContext)) using (var channel = channelSource.GetChannel(operationContext)) - using (var channelBinding = new ChannelReadBinding(channelSource.Server, channel, binding.ReadPreference, binding.Session.Fork())) + using (var channelBinding = new ChannelReadBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.ReadPreference, binding.Session.Fork())) { var operation = CreateOperation(); return operation.Execute(operationContext, channelBinding); @@ -155,7 +155,7 @@ public async Task> ExecuteAsync(OperationContext operationC Ensure.IsNotNull(binding, nameof(binding)); using (var channelSource = await binding.GetReadChannelSourceAsync(operationContext).ConfigureAwait(false)) using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) - using (var channelBinding = new ChannelReadBinding(channelSource.Server, channel, binding.ReadPreference, binding.Session.Fork())) + using (var channelBinding = new ChannelReadBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.ReadPreference, binding.Session.Fork())) { var operation = CreateOperation(); return await operation.ExecuteAsync(operationContext, channelBinding).ConfigureAwait(false); diff --git a/src/MongoDB.Driver/Core/Operations/MapReduceOperation.cs b/src/MongoDB.Driver/Core/Operations/MapReduceOperation.cs index a68b2227278..a75555ee89b 100644 --- a/src/MongoDB.Driver/Core/Operations/MapReduceOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/MapReduceOperation.cs @@ -93,7 +93,7 @@ public IAsyncCursor Execute(OperationContext operationContext, IReadBin using (var channelSource = binding.GetReadChannelSource(operationContext)) using (var channel = channelSource.GetChannel(operationContext)) - using (var channelBinding = new ChannelReadBinding(channelSource.Server, channel, binding.ReadPreference, binding.Session.Fork())) + using (var channelBinding = new ChannelReadBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.ReadPreference, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session, channel.ConnectionDescription); var result = operation.Execute(operationContext, channelBinding); @@ -108,7 +108,7 @@ public async Task> ExecuteAsync(OperationContext operation using (var channelSource = await binding.GetReadChannelSourceAsync(operationContext).ConfigureAwait(false)) using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) - using (var channelBinding = new ChannelReadBinding(channelSource.Server, channel, binding.ReadPreference, binding.Session.Fork())) + using (var channelBinding = new ChannelReadBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.ReadPreference, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session, channel.ConnectionDescription); var result = await operation.ExecuteAsync(operationContext, channelBinding).ConfigureAwait(false); diff --git a/src/MongoDB.Driver/Core/Operations/MapReduceOutputToCollectionOperation.cs b/src/MongoDB.Driver/Core/Operations/MapReduceOutputToCollectionOperation.cs index 99e32b1adcd..fb94367a629 100644 --- a/src/MongoDB.Driver/Core/Operations/MapReduceOutputToCollectionOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/MapReduceOutputToCollectionOperation.cs @@ -175,7 +175,7 @@ public BsonDocument Execute(OperationContext operationContext, IWriteBinding bin using (var channelSource = binding.GetWriteChannelSource(operationContext)) using (var channel = channelSource.GetChannel(operationContext)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session, channel.ConnectionDescription); return operation.Execute(operationContext, channelBinding); @@ -189,7 +189,7 @@ public async Task ExecuteAsync(OperationContext operationContext, using (var channelSource = await binding.GetWriteChannelSourceAsync(operationContext).ConfigureAwait(false)) using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session, channel.ConnectionDescription); return await operation.ExecuteAsync(operationContext, channelBinding).ConfigureAwait(false); diff --git a/src/MongoDB.Driver/Core/Operations/ReadCommandOperation.cs b/src/MongoDB.Driver/Core/Operations/ReadCommandOperation.cs index 84305df6f25..711d31c37df 100644 --- a/src/MongoDB.Driver/Core/Operations/ReadCommandOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/ReadCommandOperation.cs @@ -84,12 +84,12 @@ public async Task ExecuteAsync(OperationContext operationContext public TCommandResult ExecuteAttempt(OperationContext operationContext, RetryableReadContext context, int attempt, long? transactionNumber) { - return ExecuteProtocol(context.Channel, context.Binding.Session, context.Binding.ReadPreference, operationContext.CancellationToken); + return ExecuteProtocol(operationContext, context.Channel, context.Binding.Session, context.Binding.ReadPreference); } public Task ExecuteAttemptAsync(OperationContext operationContext, RetryableReadContext context, int attempt, long? transactionNumber) { - return ExecuteProtocolAsync(context.Channel, context.Binding.Session, context.Binding.ReadPreference, operationContext.CancellationToken); + return ExecuteProtocolAsync(operationContext, context.Channel, context.Binding.Session, context.Binding.ReadPreference); } } } diff --git a/src/MongoDB.Driver/Core/Operations/RenameCollectionOperation.cs b/src/MongoDB.Driver/Core/Operations/RenameCollectionOperation.cs index aee3f2fed1c..f00a504e775 100644 --- a/src/MongoDB.Driver/Core/Operations/RenameCollectionOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/RenameCollectionOperation.cs @@ -89,7 +89,7 @@ public BsonDocument Execute(OperationContext operationContext, IWriteBinding bin using (BeginOperation()) using (var channelSource = binding.GetWriteChannelSource(operationContext)) using (var channel = channelSource.GetChannel(operationContext)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session, channel.ConnectionDescription); return operation.Execute(operationContext, channelBinding); @@ -103,7 +103,7 @@ public async Task ExecuteAsync(OperationContext operationContext, using (BeginOperation()) using (var channelSource = await binding.GetWriteChannelSourceAsync(operationContext).ConfigureAwait(false)) using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(channelBinding.Session, channel.ConnectionDescription); return await operation.ExecuteAsync(operationContext, channelBinding).ConfigureAwait(false); diff --git a/src/MongoDB.Driver/Core/Operations/RetryableWriteCommandOperationBase.cs b/src/MongoDB.Driver/Core/Operations/RetryableWriteCommandOperationBase.cs index bb3ec94e32d..bcb0e72b291 100644 --- a/src/MongoDB.Driver/Core/Operations/RetryableWriteCommandOperationBase.cs +++ b/src/MongoDB.Driver/Core/Operations/RetryableWriteCommandOperationBase.cs @@ -115,8 +115,8 @@ public virtual Task ExecuteAsync(OperationContext operationContext public BsonDocument ExecuteAttempt(OperationContext operationContext, RetryableWriteContext context, int attempt, long? transactionNumber) { var args = GetCommandArgs(context, attempt, transactionNumber); - // TODO: CSOT implement timeout in Command Execution return context.Channel.Command( + operationContext, context.ChannelSource.Session, ReadPreference.Primary, _databaseNamespace, @@ -127,15 +127,14 @@ public BsonDocument ExecuteAttempt(OperationContext operationContext, RetryableW args.PostWriteAction, args.ResponseHandling, BsonDocumentSerializer.Instance, - args.MessageEncoderSettings, - operationContext.CancellationToken); + args.MessageEncoderSettings); } public Task ExecuteAttemptAsync(OperationContext operationContext, RetryableWriteContext context, int attempt, long? transactionNumber) { var args = GetCommandArgs(context, attempt, transactionNumber); - // TODO: CSOT implement timeout in Command Execution return context.Channel.CommandAsync( + operationContext, context.ChannelSource.Session, ReadPreference.Primary, _databaseNamespace, @@ -146,8 +145,7 @@ public Task ExecuteAttemptAsync(OperationContext operationContext, args.PostWriteAction, args.ResponseHandling, BsonDocumentSerializer.Instance, - args.MessageEncoderSettings, - operationContext.CancellationToken); + args.MessageEncoderSettings); } protected abstract BsonDocument CreateCommand(ICoreSessionHandle session, int attempt, long? transactionNumber); diff --git a/src/MongoDB.Driver/Core/Operations/UpdateSearchIndexOperation.cs b/src/MongoDB.Driver/Core/Operations/UpdateSearchIndexOperation.cs index 09496dd9d3b..e243f717a1e 100644 --- a/src/MongoDB.Driver/Core/Operations/UpdateSearchIndexOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/UpdateSearchIndexOperation.cs @@ -48,7 +48,7 @@ public BsonDocument Execute(OperationContext operationContext, IWriteBinding bin using (EventContext.BeginOperation("updateSearchIndex")) using (var channelSource = binding.GetWriteChannelSource(operationContext)) using (var channel = channelSource.GetChannel(operationContext)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(); return operation.Execute(operationContext, channelBinding); @@ -60,7 +60,7 @@ public async Task ExecuteAsync(OperationContext operationContext, using (EventContext.BeginOperation("updateSearchIndex")) using (var channelSource = await binding.GetWriteChannelSourceAsync(operationContext).ConfigureAwait(false)) using (var channel = await channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, binding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, binding.Session.Fork())) { var operation = CreateOperation(); return await operation.ExecuteAsync(operationContext, channelBinding).ConfigureAwait(false); diff --git a/src/MongoDB.Driver/Core/Servers/IServer.cs b/src/MongoDB.Driver/Core/Servers/IServer.cs index da1e6f49138..505b6cbe0bd 100644 --- a/src/MongoDB.Driver/Core/Servers/IServer.cs +++ b/src/MongoDB.Driver/Core/Servers/IServer.cs @@ -15,9 +15,9 @@ using System; using System.Net; -using System.Threading; using System.Threading.Tasks; -using MongoDB.Driver.Core.Bindings; +using MongoDB.Driver.Core.Clusters; +using MongoDB.Driver.Core.Connections; namespace MongoDB.Driver.Core.Servers { @@ -25,12 +25,16 @@ internal interface IServer { event EventHandler DescriptionChanged; + IClusterClock ClusterClock { get; } ServerDescription Description { get; } EndPoint EndPoint { get; } ServerId ServerId { get; } + ServerApi ServerApi { get; } - IChannelHandle GetChannel(OperationContext operationContext); - Task GetChannelAsync(OperationContext operationContext); + IConnectionHandle GetConnection(OperationContext operationContext); + Task GetConnectionAsync(OperationContext operationContext); + void ReturnConnection(IConnectionHandle connection); + void HandleChannelException(IConnection connection, Exception exception); } internal interface IClusterableServer : IServer, IDisposable diff --git a/src/MongoDB.Driver/Core/Servers/RoundTripTimeMonitor.cs b/src/MongoDB.Driver/Core/Servers/RoundTripTimeMonitor.cs index 23a306a9dcf..656a5f1bb94 100644 --- a/src/MongoDB.Driver/Core/Servers/RoundTripTimeMonitor.cs +++ b/src/MongoDB.Driver/Core/Servers/RoundTripTimeMonitor.cs @@ -170,8 +170,8 @@ private void InitializeConnection() { // if we are cancelling, it's because the server has // been shut down and we really don't need to wait. - roundTripTimeConnection.Open(_cancellationToken); - _cancellationToken.ThrowIfCancellationRequested(); + var operationContext = new OperationContext(Timeout.InfiniteTimeSpan, _cancellationToken); + roundTripTimeConnection.Open(operationContext); } catch { diff --git a/src/MongoDB.Driver/Core/Servers/Server.cs b/src/MongoDB.Driver/Core/Servers/Server.cs index c7fe1a94b28..76f23550894 100644 --- a/src/MongoDB.Driver/Core/Servers/Server.cs +++ b/src/MongoDB.Driver/Core/Servers/Server.cs @@ -14,15 +14,10 @@ */ using System; -using System.Collections.Generic; using System.Diagnostics; using System.Net; using System.Threading; using System.Threading.Tasks; -using MongoDB.Bson; -using MongoDB.Bson.IO; -using MongoDB.Bson.Serialization; -using MongoDB.Driver.Core.Bindings; using MongoDB.Driver.Core.Clusters; using MongoDB.Driver.Core.Configuration; using MongoDB.Driver.Core.ConnectionPools; @@ -30,9 +25,6 @@ using MongoDB.Driver.Core.Events; using MongoDB.Driver.Core.Logging; using MongoDB.Driver.Core.Misc; -using MongoDB.Driver.Core.WireProtocol; -using MongoDB.Driver.Core.WireProtocol.Messages; -using MongoDB.Driver.Core.WireProtocol.Messages.Encoders; namespace MongoDB.Driver.Core.Servers { @@ -82,6 +74,7 @@ public Server( public abstract ServerDescription Description { get; } public EndPoint EndPoint => _endPoint; public bool IsInitialized => _state.Value != State.Initial; + public ServerApi ServerApi => _serverApi; public ServerId ServerId => _serverId; protected EventLogger EventLogger => _eventLogger; @@ -104,10 +97,37 @@ public void Dispose() } } + public void HandleChannelException(IConnection connection, Exception ex) + { + if (!IsOpen() || ShouldIgnoreException(ex)) + { + return; + } + + ex = GetEffectiveException(ex); + + HandleAfterHandshakeCompletesException(connection, ex); + + bool ShouldIgnoreException(Exception ex) + { + // For most connection exceptions, we are going to immediately + // invalidate the server. However, we aren't going to invalidate + // because of OperationCanceledExceptions. We trust that the + // implementations of connection don't leave themselves in a state + // where they can't be used based on user cancellation. + return ex is OperationCanceledException; + } + + Exception GetEffectiveException(Exception ex) => + ex is AggregateException aggregateException && aggregateException.InnerExceptions.Count == 1 + ? aggregateException.InnerException + : ex; + } + public void HandleExceptionOnOpen(Exception exception) => HandleBeforeHandshakeCompletesException(exception); - public IChannelHandle GetChannel(OperationContext operationContext) + public IConnectionHandle GetConnection(OperationContext operationContext) { ThrowIfNotOpen(); @@ -115,8 +135,7 @@ public IChannelHandle GetChannel(OperationContext operationContext) { Interlocked.Increment(ref _outstandingOperationsCount); - var connection = _connectionPool.AcquireConnection(operationContext); - return new ServerChannel(this, connection); + return _connectionPool.AcquireConnection(operationContext); } catch { @@ -126,15 +145,14 @@ public IChannelHandle GetChannel(OperationContext operationContext) } } - public async Task GetChannelAsync(OperationContext operationContext) + public async Task GetConnectionAsync(OperationContext operationContext) { ThrowIfNotOpen(); try { Interlocked.Increment(ref _outstandingOperationsCount); - var connection = await _connectionPool.AcquireConnectionAsync(operationContext).ConfigureAwait(false); - return new ServerChannel(this, connection); + return await _connectionPool.AcquireConnectionAsync(operationContext).ConfigureAwait(false); } catch { @@ -172,6 +190,11 @@ public void Invalidate(string reasonInvalidated, TopologyVersion responseTopolog public abstract void RequestHeartbeat(); + public void ReturnConnection(IConnectionHandle connection) + { + Interlocked.Decrement(ref _outstandingOperationsCount); + } + // protected methods protected abstract void Invalidate(string reasonInvalidated, bool clearConnectionPool, TopologyVersion responseTopologyDescription); @@ -222,33 +245,6 @@ protected bool ShouldClearConnectionPoolForChannelException(Exception ex, int ma } // private methods - private void HandleChannelException(IConnection connection, Exception ex) - { - if (!IsOpen() || ShouldIgnoreException(ex)) - { - return; - } - - ex = GetEffectiveException(ex); - - HandleAfterHandshakeCompletesException(connection, ex); - - bool ShouldIgnoreException(Exception ex) - { - // For most connection exceptions, we are going to immediately - // invalidate the server. However, we aren't going to invalidate - // because of OperationCanceledExceptions. We trust that the - // implementations of connection don't leave themselves in a state - // where they can't be used based on user cancellation. - return ex is OperationCanceledException; - } - - Exception GetEffectiveException(Exception ex) => - ex is AggregateException aggregateException && aggregateException.InnerExceptions.Count == 1 - ? aggregateException.InnerException - : ex; - } - private bool IsOpen() => _state.Value == State.Open; private void ThrowIfDisposed() @@ -275,172 +271,5 @@ private static class State public const int Open = 1; public const int Disposed = 2; } - - private sealed class ServerChannel : IChannelHandle - { - // fields - private readonly IConnectionHandle _connection; - private readonly Server _server; - - private readonly InterlockedInt32 _state; - private readonly bool _decrementOperationsCount; - - // constructors - public ServerChannel(Server server, IConnectionHandle connection, bool decrementOperationsCount = true) - { - _server = server; - _connection = connection; - - _state = new InterlockedInt32(ChannelState.Initial); - _decrementOperationsCount = decrementOperationsCount; - } - - // properties - public IConnectionHandle Connection => _connection; - - public ConnectionDescription ConnectionDescription - { - get { return _connection.Description; } - } - - // methods - public TResult Command( - ICoreSession session, - ReadPreference readPreference, - DatabaseNamespace databaseNamespace, - BsonDocument command, - IEnumerable commandPayloads, - IElementNameValidator commandValidator, - BsonDocument additionalOptions, - Action postWriteAction, - CommandResponseHandling responseHandling, - IBsonSerializer resultSerializer, - MessageEncoderSettings messageEncoderSettings, - CancellationToken cancellationToken) - { - var protocol = new CommandWireProtocol( - CreateClusterClockAdvancingCoreSession(session), - readPreference, - databaseNamespace, - command, - commandPayloads, - commandValidator, - additionalOptions, - postWriteAction, - responseHandling, - resultSerializer, - messageEncoderSettings, - _server._serverApi); - - return ExecuteProtocol(protocol, session, cancellationToken); - } - - public Task CommandAsync( - ICoreSession session, - ReadPreference readPreference, - DatabaseNamespace databaseNamespace, - BsonDocument command, - IEnumerable commandPayloads, - IElementNameValidator commandValidator, - BsonDocument additionalOptions, - Action postWriteAction, - CommandResponseHandling responseHandling, - IBsonSerializer resultSerializer, - MessageEncoderSettings messageEncoderSettings, - CancellationToken cancellationToken) - { - var protocol = new CommandWireProtocol( - CreateClusterClockAdvancingCoreSession(session), - readPreference, - databaseNamespace, - command, - commandPayloads, - commandValidator, - additionalOptions, - postWriteAction, - responseHandling, - resultSerializer, - messageEncoderSettings, - _server._serverApi); - - return ExecuteProtocolAsync(protocol, session, cancellationToken); - } - - public void Dispose() - { - if (_state.TryChange(ChannelState.Initial, ChannelState.Disposed)) - { - if (_decrementOperationsCount) - { - Interlocked.Decrement(ref _server._outstandingOperationsCount); - } - - _connection.Dispose(); - } - } - - private ICoreSession CreateClusterClockAdvancingCoreSession(ICoreSession session) - { - return new ClusterClockAdvancingCoreSession(session, _server.ClusterClock); - } - - private TResult ExecuteProtocol(IWireProtocol protocol, ICoreSession session, CancellationToken cancellationToken) - { - try - { - return protocol.Execute(_connection, cancellationToken); - } - catch (Exception ex) - { - MarkSessionDirtyIfNeeded(session, ex); - _server.HandleChannelException(_connection, ex); - throw; - } - } - - private async Task ExecuteProtocolAsync(IWireProtocol protocol, ICoreSession session, CancellationToken cancellationToken) - { - try - { - return await protocol.ExecuteAsync(_connection, cancellationToken).ConfigureAwait(false); - } - catch (Exception ex) - { - MarkSessionDirtyIfNeeded(session, ex); - _server.HandleChannelException(_connection, ex); - throw; - } - } - - public IChannelHandle Fork() - { - ThrowIfDisposed(); - - return new ServerChannel(_server, _connection.Fork(), false); - } - - private void MarkSessionDirtyIfNeeded(ICoreSession session, Exception ex) - { - if (ex is MongoConnectionException) - { - session.MarkDirty(); - } - } - - private void ThrowIfDisposed() - { - if (_state.Value == ChannelState.Disposed) - { - throw new ObjectDisposedException(GetType().Name); - } - } - - // nested types - private static class ChannelState - { - public const int Initial = 0; - public const int Disposed = 1; - } - } } } diff --git a/src/MongoDB.Driver/Core/Servers/ServerChannel.cs b/src/MongoDB.Driver/Core/Servers/ServerChannel.cs new file mode 100644 index 00000000000..22577cac374 --- /dev/null +++ b/src/MongoDB.Driver/Core/Servers/ServerChannel.cs @@ -0,0 +1,201 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using MongoDB.Bson; +using MongoDB.Bson.IO; +using MongoDB.Bson.Serialization; +using MongoDB.Driver.Core.Bindings; +using MongoDB.Driver.Core.Connections; +using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Core.WireProtocol; +using MongoDB.Driver.Core.WireProtocol.Messages; +using MongoDB.Driver.Core.WireProtocol.Messages.Encoders; + +namespace MongoDB.Driver.Core.Servers +{ + internal sealed class ServerChannel : IChannelHandle + { + // fields + private readonly IConnectionHandle _connection; + private readonly IServer _server; + private readonly TimeSpan _roundTripTime; + private readonly InterlockedInt32 _state; + private readonly bool _ownConnection; + + // constructors + public ServerChannel(IServer server, IConnectionHandle connection, TimeSpan roundTripTime, bool ownConnection = true) + { + _server = server; + _connection = connection; + _roundTripTime = roundTripTime; + _state = new InterlockedInt32(ChannelState.Initial); + _ownConnection = ownConnection; + } + + // properties + public IConnectionHandle Connection => _connection; + + public ConnectionDescription ConnectionDescription => _connection.Description; + + public TimeSpan RoundTripTimeout => _roundTripTime; + + // methods + public TResult Command( + OperationContext operationContext, + ICoreSession session, + ReadPreference readPreference, + DatabaseNamespace databaseNamespace, + BsonDocument command, + IEnumerable commandPayloads, + IElementNameValidator commandValidator, + BsonDocument additionalOptions, + Action postWriteAction, + CommandResponseHandling responseHandling, + IBsonSerializer resultSerializer, + MessageEncoderSettings messageEncoderSettings) + { + var protocol = new CommandWireProtocol( + CreateClusterClockAdvancingCoreSession(session), + readPreference, + databaseNamespace, + command, + commandPayloads, + commandValidator, + additionalOptions, + postWriteAction, + responseHandling, + resultSerializer, + messageEncoderSettings, + _server.ServerApi, + _roundTripTime); + + return ExecuteProtocol(operationContext, protocol, session); + } + + public Task CommandAsync( + OperationContext operationContext, + ICoreSession session, + ReadPreference readPreference, + DatabaseNamespace databaseNamespace, + BsonDocument command, + IEnumerable commandPayloads, + IElementNameValidator commandValidator, + BsonDocument additionalOptions, + Action postWriteAction, + CommandResponseHandling responseHandling, + IBsonSerializer resultSerializer, + MessageEncoderSettings messageEncoderSettings) + { + var protocol = new CommandWireProtocol( + CreateClusterClockAdvancingCoreSession(session), + readPreference, + databaseNamespace, + command, + commandPayloads, + commandValidator, + additionalOptions, + postWriteAction, + responseHandling, + resultSerializer, + messageEncoderSettings, + _server.ServerApi, + _roundTripTime); + + return ExecuteProtocolAsync(operationContext, protocol, session); + } + + public void Dispose() + { + if (_state.TryChange(ChannelState.Initial, ChannelState.Disposed)) + { + if (_ownConnection) + { + _server.ReturnConnection(_connection); + } + + _connection.Dispose(); + } + } + + private ICoreSession CreateClusterClockAdvancingCoreSession(ICoreSession session) + { + return new ClusterClockAdvancingCoreSession(session, _server.ClusterClock); + } + + private TResult ExecuteProtocol(OperationContext operationContext, IWireProtocol protocol, ICoreSession session) + { + try + { + return protocol.Execute(operationContext, _connection); + } + catch (Exception ex) + { + MarkSessionDirtyIfNeeded(session, ex); + _server.HandleChannelException(_connection, ex); + throw; + } + } + + private async Task ExecuteProtocolAsync(OperationContext operationContext, IWireProtocol protocol, ICoreSession session) + { + try + { + return await protocol.ExecuteAsync(operationContext, _connection).ConfigureAwait(false); + } + catch (Exception ex) + { + MarkSessionDirtyIfNeeded(session, ex); + _server.HandleChannelException(_connection, ex); + throw; + } + } + + public IChannelHandle Fork() + { + ThrowIfDisposed(); + + return new ServerChannel(_server, _connection.Fork(), _roundTripTime, false); + } + + private void MarkSessionDirtyIfNeeded(ICoreSession session, Exception ex) + { + if (ex is MongoConnectionException) + { + session.MarkDirty(); + } + } + + private void ThrowIfDisposed() + { + if (_state.Value == ChannelState.Disposed) + { + throw new ObjectDisposedException(GetType().Name); + } + } + + // nested types + private static class ChannelState + { + public const int Initial = 0; + public const int Disposed = 1; + } + } +} + + + diff --git a/src/MongoDB.Driver/Core/Servers/ServerMonitor.cs b/src/MongoDB.Driver/Core/Servers/ServerMonitor.cs index ca043166746..e2e7c64163a 100644 --- a/src/MongoDB.Driver/Core/Servers/ServerMonitor.cs +++ b/src/MongoDB.Driver/Core/Servers/ServerMonitor.cs @@ -216,7 +216,8 @@ private IConnection InitializeConnection(CancellationToken cancellationToken) // { // if we are cancelling, it's because the server has // been shut down and we really don't need to wait. - connection.Open(cancellationToken); + // TODO: CSOT: Implement operation context support for Server Discovery and Monitoring + connection.Open(new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken)); _eventLoggerSdam.LogAndPublish(new ServerHeartbeatSucceededEvent(connection.ConnectionId, stopwatch.Elapsed, false, connection.Description.HelloResult.Wrapped)); } diff --git a/src/MongoDB.Driver/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs b/src/MongoDB.Driver/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs index 159fe57c70c..ffcb5a90163 100644 --- a/src/MongoDB.Driver/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs +++ b/src/MongoDB.Driver/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs @@ -49,6 +49,7 @@ internal sealed class CommandUsingCommandMessageWireProtocol : I private readonly CommandResponseHandling _responseHandling; private readonly IBsonSerializer _resultSerializer; private readonly ServerApi _serverApi; + private readonly TimeSpan _serverRoundTripTime; private readonly ICoreSession _session; // streamable fields private bool _moreToCome = false; // MoreToCome from the previous response @@ -67,7 +68,8 @@ public CommandUsingCommandMessageWireProtocol( IBsonSerializer resultSerializer, MessageEncoderSettings messageEncoderSettings, Action postWriteAction, - ServerApi serverApi) + ServerApi serverApi, + TimeSpan serverRoundTripTime) { if (responseHandling != CommandResponseHandling.Return && responseHandling != CommandResponseHandling.NoResponseExpected && @@ -88,6 +90,7 @@ public CommandUsingCommandMessageWireProtocol( _messageEncoderSettings = messageEncoderSettings; _postWriteAction = postWriteAction; // can be null _serverApi = serverApi; // can be null + _serverRoundTripTime = serverRoundTripTime; if (messageEncoderSettings != null) { @@ -100,7 +103,7 @@ public CommandUsingCommandMessageWireProtocol( public bool MoreToCome => _moreToCome; // public methods - public TCommandResult Execute(IConnection connection, CancellationToken cancellationToken) + public TCommandResult Execute(OperationContext operationContext, IConnection connection) { try { @@ -113,19 +116,21 @@ public TCommandResult Execute(IConnection connection, CancellationToken cancella } else { - message = CreateCommandMessage(connection.Description); - message = AutoEncryptFieldsIfNecessary(message, connection, cancellationToken); + message = CreateCommandMessage(operationContext, connection.Description); + // TODO: CSOT: Propagate operationContext into Encryption + message = AutoEncryptFieldsIfNecessary(message, connection, operationContext.CancellationToken); responseTo = message.WrappedMessage.RequestId; } try { - return SendMessageAndProcessResponse(message, responseTo, connection, cancellationToken); + return SendMessageAndProcessResponse(operationContext, message, responseTo, connection); } catch (MongoCommandException commandException) when (RetryabilityHelper.IsReauthenticationRequested(commandException, _command)) { - connection.Reauthenticate(cancellationToken); - return SendMessageAndProcessResponse(message, responseTo, connection, cancellationToken); + // TODO: CSOT: support operationContext in auth + connection.Reauthenticate(operationContext.CancellationToken); + return SendMessageAndProcessResponse(operationContext, message, responseTo, connection); } } catch (Exception exception) @@ -137,7 +142,7 @@ public TCommandResult Execute(IConnection connection, CancellationToken cancella } } - public async Task ExecuteAsync(IConnection connection, CancellationToken cancellationToken) + public async Task ExecuteAsync(OperationContext operationContext, IConnection connection) { try { @@ -150,19 +155,21 @@ public async Task ExecuteAsync(IConnection connection, Cancellat } else { - message = CreateCommandMessage(connection.Description); - message = await AutoEncryptFieldsIfNecessaryAsync(message, connection, cancellationToken).ConfigureAwait(false); + message = CreateCommandMessage(operationContext, connection.Description); + // TODO: CSOT: Propagate operationContext into Encryption + message = await AutoEncryptFieldsIfNecessaryAsync(message, connection, operationContext.CancellationToken).ConfigureAwait(false); responseTo = message.WrappedMessage.RequestId; } try { - return await SendMessageAndProcessResponseAsync(message, responseTo, connection, cancellationToken).ConfigureAwait(false); + return await SendMessageAndProcessResponseAsync(operationContext, message, responseTo, connection).ConfigureAwait(false); } catch (MongoCommandException commandException) when (RetryabilityHelper.IsReauthenticationRequested(commandException, _command)) { - await connection.ReauthenticateAsync(cancellationToken).ConfigureAwait(false); - return await SendMessageAndProcessResponseAsync(message, responseTo, connection, cancellationToken).ConfigureAwait(false); + // TODO: CSOT: support operationContext in auth + await connection.ReauthenticateAsync(operationContext.CancellationToken).ConfigureAwait(false); + return await SendMessageAndProcessResponseAsync(operationContext, message, responseTo, connection).ConfigureAwait(false); } } catch (Exception exception) @@ -253,11 +260,11 @@ private async Task AutoEncryptFieldsIfNecessaryAsync(Comm } } - private CommandRequestMessage CreateCommandMessage(ConnectionDescription connectionDescription) + private CommandRequestMessage CreateCommandMessage(OperationContext operationContext, ConnectionDescription connectionDescription) { var requestId = RequestMessage.GetNextRequestId(); var responseTo = 0; - var sections = CreateSections(connectionDescription); + var sections = CreateSections(operationContext, connectionDescription); var moreToComeRequest = _responseHandling == CommandResponseHandling.NoResponseExpected; @@ -270,9 +277,9 @@ private CommandRequestMessage CreateCommandMessage(ConnectionDescription connect return new CommandRequestMessage(wrappedMessage); } - private IEnumerable CreateSections(ConnectionDescription connectionDescription) + private IEnumerable CreateSections(OperationContext operationContext, ConnectionDescription connectionDescription) { - var type0Section = CreateType0Section(connectionDescription); + var type0Section = CreateType0Section(operationContext, connectionDescription); if (_commandPayloads == null) { return new[] { type0Section }; @@ -283,7 +290,7 @@ private IEnumerable CreateSections(ConnectionDescription } } - private Type0CommandMessageSection CreateType0Section(ConnectionDescription connectionDescription) + private Type0CommandMessageSection CreateType0Section(OperationContext operationContext, ConnectionDescription connectionDescription) { var extraElements = new List(); @@ -369,6 +376,12 @@ private Type0CommandMessageSection CreateType0Section(ConnectionDe } } + if (operationContext.IsOperationTimeoutConfigured()) + { + var serverTimeout = operationContext.RemainingTimeout - _serverRoundTripTime; + AddIfNotAlreadyAdded("maxTimeMS", (int)serverTimeout.TotalMilliseconds); + } + var elementAppendingSerializer = new ElementAppendingSerializer(BsonDocumentSerializer.Instance, extraElements); return new Type0CommandMessageSection(_command, elementAppendingSerializer); @@ -526,14 +539,15 @@ private void SaveResponseInfo(CommandResponseMessage response) _moreToCome = response.WrappedMessage.MoreToCome; } - private TCommandResult SendMessageAndProcessResponse(CommandRequestMessage message, int responseTo, IConnection connection, CancellationToken cancellationToken) + private TCommandResult SendMessageAndProcessResponse(OperationContext operationContext, CommandRequestMessage message, int responseTo, IConnection connection) { var responseExpected = true; if (message != null) { try { - connection.SendMessage(message, _messageEncoderSettings, cancellationToken); + ThrowIfRemainingTimeoutLessThenRoundTripTime(operationContext); + connection.SendMessage(operationContext, message, _messageEncoderSettings); } finally { @@ -549,8 +563,9 @@ private TCommandResult SendMessageAndProcessResponse(CommandRequestMessage messa if (responseExpected) { var encoderSelector = new CommandResponseMessageEncoderSelector(); - var response = (CommandResponseMessage)connection.ReceiveMessage(responseTo, encoderSelector, _messageEncoderSettings, cancellationToken); - response = AutoDecryptFieldsIfNecessary(response, cancellationToken); + var response = (CommandResponseMessage)connection.ReceiveMessage(operationContext, responseTo, encoderSelector, _messageEncoderSettings); + // TODO: CSOT: Propagate operationContext into Encryption + response = AutoDecryptFieldsIfNecessary(response, operationContext.CancellationToken); var result = ProcessResponse(connection.ConnectionId, response.WrappedMessage); SaveResponseInfo(response); return result; @@ -561,14 +576,15 @@ private TCommandResult SendMessageAndProcessResponse(CommandRequestMessage messa } } - private async Task SendMessageAndProcessResponseAsync(CommandRequestMessage message, int responseTo, IConnection connection, CancellationToken cancellationToken) + private async Task SendMessageAndProcessResponseAsync(OperationContext operationContext, CommandRequestMessage message, int responseTo, IConnection connection) { var responseExpected = true; if (message != null) { try { - await connection.SendMessageAsync(message, _messageEncoderSettings, cancellationToken).ConfigureAwait(false); + ThrowIfRemainingTimeoutLessThenRoundTripTime(operationContext); + await connection.SendMessageAsync(operationContext, message, _messageEncoderSettings).ConfigureAwait(false); } finally { @@ -583,8 +599,9 @@ private async Task SendMessageAndProcessResponseAsync(CommandReq if (responseExpected) { var encoderSelector = new CommandResponseMessageEncoderSelector(); - var response = (CommandResponseMessage)await connection.ReceiveMessageAsync(responseTo, encoderSelector, _messageEncoderSettings, cancellationToken).ConfigureAwait(false); - response = await AutoDecryptFieldsIfNecessaryAsync(response, cancellationToken).ConfigureAwait(false); + var response = (CommandResponseMessage)await connection.ReceiveMessageAsync(operationContext, responseTo, encoderSelector, _messageEncoderSettings).ConfigureAwait(false); + // TODO: CSOT: Propagate operationContext into Encryption + response = await AutoDecryptFieldsIfNecessaryAsync(response, operationContext.CancellationToken).ConfigureAwait(false); var result = ProcessResponse(connection.ConnectionId, response.WrappedMessage); SaveResponseInfo(response); return result; @@ -608,6 +625,16 @@ private bool ShouldAddTransientTransactionError(MongoException exception) return false; } + private void ThrowIfRemainingTimeoutLessThenRoundTripTime(OperationContext operationContext) + { + if (operationContext.RemainingTimeout == Timeout.InfiniteTimeSpan || operationContext.RemainingTimeout > _serverRoundTripTime) + { + return; + } + + throw new TimeoutException(); + } + private MongoException WrapNotSupportedRetryableWriteException(MongoCommandException exception) { const string friendlyErrorMessage = diff --git a/src/MongoDB.Driver/Core/WireProtocol/CommandUsingQueryMessageWireProtocol.cs b/src/MongoDB.Driver/Core/WireProtocol/CommandUsingQueryMessageWireProtocol.cs index 6d18ada9747..fe8d408b002 100644 --- a/src/MongoDB.Driver/Core/WireProtocol/CommandUsingQueryMessageWireProtocol.cs +++ b/src/MongoDB.Driver/Core/WireProtocol/CommandUsingQueryMessageWireProtocol.cs @@ -116,11 +116,11 @@ private QueryMessage CreateMessage(ConnectionDescription connectionDescription, #pragma warning restore 618 } - public TCommandResult Execute(IConnection connection, CancellationToken cancellationToken) + public TCommandResult Execute(OperationContext operationContext, IConnection connection) { bool messageContainsSessionId; var message = CreateMessage(connection.Description, out messageContainsSessionId); - connection.SendMessage(message, _messageEncoderSettings, cancellationToken); + connection.SendMessage(operationContext, message, _messageEncoderSettings); if (messageContainsSessionId) { _session.WasUsed(); @@ -129,20 +129,20 @@ public TCommandResult Execute(IConnection connection, CancellationToken cancella switch (message.ResponseHandling) { case CommandResponseHandling.Ignore: - IgnoreResponse(connection, message, cancellationToken); + IgnoreResponse(operationContext, connection, message); return default(TCommandResult); default: var encoderSelector = new ReplyMessageEncoderSelector(RawBsonDocumentSerializer.Instance); - var reply = connection.ReceiveMessage(message.RequestId, encoderSelector, _messageEncoderSettings, cancellationToken); + var reply = connection.ReceiveMessage(operationContext, message.RequestId, encoderSelector, _messageEncoderSettings); return ProcessReply(connection.ConnectionId, (ReplyMessage)reply); } } - public async Task ExecuteAsync(IConnection connection, CancellationToken cancellationToken) + public async Task ExecuteAsync(OperationContext operationContext, IConnection connection) { bool messageContainsSessionId; var message = CreateMessage(connection.Description, out messageContainsSessionId); - await connection.SendMessageAsync(message, _messageEncoderSettings, cancellationToken).ConfigureAwait(false); + await connection.SendMessageAsync(operationContext, message, _messageEncoderSettings).ConfigureAwait(false); if (messageContainsSessionId) { _session.WasUsed(); @@ -151,11 +151,11 @@ public async Task ExecuteAsync(IConnection connection, Cancellat switch (message.ResponseHandling) { case CommandResponseHandling.Ignore: - IgnoreResponse(connection, message, cancellationToken); + IgnoreResponse(operationContext, connection, message); return default(TCommandResult); default: var encoderSelector = new ReplyMessageEncoderSelector(RawBsonDocumentSerializer.Instance); - var reply = await connection.ReceiveMessageAsync(message.RequestId, encoderSelector, _messageEncoderSettings, cancellationToken).ConfigureAwait(false); + var reply = await connection.ReceiveMessageAsync(operationContext, message.RequestId, encoderSelector, _messageEncoderSettings).ConfigureAwait(false); return ProcessReply(connection.ConnectionId, (ReplyMessage)reply); } } @@ -230,10 +230,10 @@ private IBsonSerializer CreateSizeLimitingPayloadSerializer(Type1CommandMessageS return (IBsonSerializer)constructorInfo.Invoke(new object[] { itemSerializer, itemElementNameValidator, maxBatchCount, maxItemSize, maxBatchSize }); } - private void IgnoreResponse(IConnection connection, QueryMessage message, CancellationToken cancellationToken) + private void IgnoreResponse(OperationContext operationContext, IConnection connection, QueryMessage message) { var encoderSelector = new ReplyMessageEncoderSelector(IgnoredReplySerializer.Instance); - connection.ReceiveMessageAsync(message.RequestId, encoderSelector, _messageEncoderSettings, cancellationToken).IgnoreExceptions(); + connection.ReceiveMessageAsync(operationContext, message.RequestId, encoderSelector, _messageEncoderSettings).IgnoreExceptions(); } [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2202:Do not dispose objects multiple times")] diff --git a/src/MongoDB.Driver/Core/WireProtocol/CommandWireProtocol.cs b/src/MongoDB.Driver/Core/WireProtocol/CommandWireProtocol.cs index 3844f13aec0..48e1bea9887 100644 --- a/src/MongoDB.Driver/Core/WireProtocol/CommandWireProtocol.cs +++ b/src/MongoDB.Driver/Core/WireProtocol/CommandWireProtocol.cs @@ -16,7 +16,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Threading; using System.Threading.Tasks; using MongoDB.Bson; using MongoDB.Bson.IO; @@ -45,6 +44,7 @@ internal sealed class CommandWireProtocol : IWireProtocol _resultSerializer; private readonly ServerApi _serverApi; + private readonly TimeSpan _serverRoundTripTime; private readonly ICoreSession _session; // constructors @@ -86,7 +86,8 @@ public CommandWireProtocol( commandResponseHandling, resultSerializer, messageEncoderSettings, - serverApi) + serverApi, + serverRoundTripTime: TimeSpan.Zero) { } @@ -102,7 +103,8 @@ public CommandWireProtocol( CommandResponseHandling responseHandling, IBsonSerializer resultSerializer, MessageEncoderSettings messageEncoderSettings, - ServerApi serverApi) + ServerApi serverApi, + TimeSpan serverRoundTripTime) { if (responseHandling != CommandResponseHandling.Return && responseHandling != CommandResponseHandling.NoResponseExpected && @@ -123,22 +125,23 @@ public CommandWireProtocol( _messageEncoderSettings = messageEncoderSettings; _postWriteAction = postWriteAction; // can be null _serverApi = serverApi; // can be null + _serverRoundTripTime = serverRoundTripTime; } // public properties public bool MoreToCome => _cachedWireProtocol?.MoreToCome ?? false; // public methods - public TCommandResult Execute(IConnection connection, CancellationToken cancellationToken) + public TCommandResult Execute(OperationContext operationContext, IConnection connection) { var supportedProtocol = CreateSupportedWireProtocol(connection); - return supportedProtocol.Execute(connection, cancellationToken); + return supportedProtocol.Execute(operationContext, connection); } - public Task ExecuteAsync(IConnection connection, CancellationToken cancellationToken) + public Task ExecuteAsync(OperationContext operationContext, IConnection connection) { var supportedProtocol = CreateSupportedWireProtocol(connection); - return supportedProtocol.ExecuteAsync(connection, cancellationToken); + return supportedProtocol.ExecuteAsync(operationContext, connection); } // private methods @@ -156,7 +159,8 @@ private IWireProtocol CreateCommandUsingCommandMessageWireProtoc _resultSerializer, _messageEncoderSettings, _postWriteAction, - _serverApi); + _serverApi, + _serverRoundTripTime); } private IWireProtocol CreateCommandUsingQueryMessageWireProtocol() diff --git a/src/MongoDB.Driver/Core/WireProtocol/IWireProtocol.cs b/src/MongoDB.Driver/Core/WireProtocol/IWireProtocol.cs index 025e7dd5acd..dee26e7dd87 100644 --- a/src/MongoDB.Driver/Core/WireProtocol/IWireProtocol.cs +++ b/src/MongoDB.Driver/Core/WireProtocol/IWireProtocol.cs @@ -13,7 +13,6 @@ * limitations under the License. */ -using System.Threading; using System.Threading.Tasks; using MongoDB.Driver.Core.Connections; @@ -22,14 +21,14 @@ namespace MongoDB.Driver.Core.WireProtocol internal interface IWireProtocol { bool MoreToCome { get; } - void Execute(IConnection connection, CancellationToken cancellationToken = default(CancellationToken)); - Task ExecuteAsync(IConnection connection, CancellationToken cancellationToken = default(CancellationToken)); + void Execute(OperationContext operationContext, IConnection connection); + Task ExecuteAsync(OperationContext operationContext, IConnection connection); } internal interface IWireProtocol { bool MoreToCome { get; } - TResult Execute(IConnection connection, CancellationToken cancellationToken = default(CancellationToken)); - Task ExecuteAsync(IConnection connection, CancellationToken cancellationToken = default(CancellationToken)); + TResult Execute(OperationContext operationContext, IConnection connection); + Task ExecuteAsync(OperationContext operationContext, IConnection connection); } } diff --git a/src/MongoDB.Driver/GridFS/GridFSBucket.cs b/src/MongoDB.Driver/GridFS/GridFSBucket.cs index 7bd029044b8..df0c27ae924 100644 --- a/src/MongoDB.Driver/GridFS/GridFSBucket.cs +++ b/src/MongoDB.Driver/GridFS/GridFSBucket.cs @@ -1003,8 +1003,8 @@ private IReadBindingHandle GetSingleServerReadBinding(OperationContext operation { var readPreference = _options.ReadPreference ?? _database.Settings.ReadPreference; var selector = new ReadPreferenceServerSelector(readPreference); - var server = _cluster.SelectServer(operationContext, selector); - var binding = new SingleServerReadBinding(server, readPreference, NoCoreSession.NewHandle()); + var (server, serverRoundTripTime) = _cluster.SelectServer(operationContext, selector); + var binding = new SingleServerReadBinding(server, serverRoundTripTime, readPreference, NoCoreSession.NewHandle()); return new ReadBindingHandle(binding); } @@ -1012,24 +1012,24 @@ private async Task GetSingleServerReadBindingAsync(Operation { var readPreference = _options.ReadPreference ?? _database.Settings.ReadPreference; var selector = new ReadPreferenceServerSelector(readPreference); - var server = await _cluster.SelectServerAsync(operationContext, selector).ConfigureAwait(false); - var binding = new SingleServerReadBinding(server, readPreference, NoCoreSession.NewHandle()); + var (server, serverRoundTripTime) = await _cluster.SelectServerAsync(operationContext, selector).ConfigureAwait(false); + var binding = new SingleServerReadBinding(server, serverRoundTripTime, readPreference, NoCoreSession.NewHandle()); return new ReadBindingHandle(binding); } private IReadWriteBindingHandle GetSingleServerReadWriteBinding(OperationContext operationContext) { var selector = WritableServerSelector.Instance; - var server = _cluster.SelectServer(operationContext, selector); - var binding = new SingleServerReadWriteBinding(server, NoCoreSession.NewHandle()); + var (server, serverRoundTripTime) = _cluster.SelectServer(operationContext, selector); + var binding = new SingleServerReadWriteBinding(server, serverRoundTripTime, NoCoreSession.NewHandle()); return new ReadWriteBindingHandle(binding); } private async Task GetSingleServerReadWriteBindingAsync(OperationContext operationContext) { var selector = WritableServerSelector.Instance; - var server = await _cluster.SelectServerAsync(operationContext, selector).ConfigureAwait(false); - var binding = new SingleServerReadWriteBinding(server, NoCoreSession.NewHandle()); + var (server, serverRoundTripTime) = await _cluster.SelectServerAsync(operationContext, selector).ConfigureAwait(false); + var binding = new SingleServerReadWriteBinding(server, serverRoundTripTime, NoCoreSession.NewHandle()); return new ReadWriteBindingHandle(binding); } diff --git a/src/MongoDB.Driver/OperationContext.cs b/src/MongoDB.Driver/OperationContext.cs index c0ccd67919f..5e9f4c5164e 100644 --- a/src/MongoDB.Driver/OperationContext.cs +++ b/src/MongoDB.Driver/OperationContext.cs @@ -21,11 +21,14 @@ namespace MongoDB.Driver { - internal sealed class OperationContext + internal sealed class OperationContext : IDisposable { // TODO: this static field is temporary here and will be removed in a future PRs in scope of CSOT. public static readonly OperationContext NoTimeout = new(System.Threading.Timeout.InfiniteTimeSpan, CancellationToken.None); + private CancellationTokenSource _remainingTimeoutCancellationTokenSource; + private CancellationTokenSource _combinedCancellationTokenSource; + public OperationContext(TimeSpan timeout, CancellationToken cancellationToken) : this(Stopwatch.StartNew(), timeout, cancellationToken) { @@ -61,21 +64,39 @@ public TimeSpan RemainingTimeout } } + [Obsolete("Do not use this property, unless it's needed to avoid breaking changes in public API")] + public CancellationToken CombinedCancellationToken + { + get + { + if (_combinedCancellationTokenSource != null) + { + return _combinedCancellationTokenSource.Token; + } + + if (RemainingTimeout == System.Threading.Timeout.InfiniteTimeSpan) + { + return CancellationToken; + } + + _remainingTimeoutCancellationTokenSource = new CancellationTokenSource(RemainingTimeout); + _combinedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken, _remainingTimeoutCancellationTokenSource.Token); + return _combinedCancellationTokenSource.Token; + } + } private Stopwatch Stopwatch { get; } public TimeSpan Timeout { get; } - public bool IsTimedOut() + public void Dispose() { - var remainingTimeout = RemainingTimeout; - if (remainingTimeout == System.Threading.Timeout.InfiniteTimeSpan) - { - return false; - } - - return remainingTimeout == TimeSpan.Zero; + _remainingTimeoutCancellationTokenSource?.Dispose(); + _combinedCancellationTokenSource?.Dispose(); } + public bool IsTimedOut() + => RemainingTimeout == TimeSpan.Zero; + public void ThrowIfTimedOutOrCanceled() { CancellationToken.ThrowIfCancellationRequested(); @@ -94,7 +115,7 @@ public void WaitTask(Task task) } var timeout = RemainingTimeout; - if (timeout != System.Threading.Timeout.InfiniteTimeSpan && timeout < TimeSpan.Zero) + if (timeout == TimeSpan.Zero) { throw new TimeoutException(); } @@ -127,7 +148,7 @@ public async Task WaitTaskAsync(Task task) } var timeout = RemainingTimeout; - if (timeout != System.Threading.Timeout.InfiniteTimeSpan && timeout < TimeSpan.Zero) + if (timeout == TimeSpan.Zero) { throw new TimeoutException(); } @@ -159,7 +180,7 @@ public OperationContext WithTimeout(TimeSpan timeout) return new OperationContext(timeout, CancellationToken) { - ParentContext = this + ParentContext = this, }; } } diff --git a/src/MongoDB.Driver/OperationContextExtensions.cs b/src/MongoDB.Driver/OperationContextExtensions.cs new file mode 100644 index 00000000000..e177bfbf235 --- /dev/null +++ b/src/MongoDB.Driver/OperationContextExtensions.cs @@ -0,0 +1,32 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Threading; + +namespace MongoDB.Driver +{ + internal static class OperationContextExtensions + { + public static bool IsOperationTimeoutConfigured(this OperationContext operationContext) + { + var rootContext = operationContext.GetRootOperationContext(); + return rootContext.Timeout != Timeout.InfiniteTimeSpan; + } + + private static OperationContext GetRootOperationContext(this OperationContext operationContext) + => operationContext.ParentContext == null ? operationContext : GetRootOperationContext(operationContext.ParentContext); + } +} + diff --git a/src/MongoDB.Driver/OperationExecutor.cs b/src/MongoDB.Driver/OperationExecutor.cs index 7025097de70..84ddcf287bd 100644 --- a/src/MongoDB.Driver/OperationExecutor.cs +++ b/src/MongoDB.Driver/OperationExecutor.cs @@ -50,7 +50,7 @@ public TResult ExecuteReadOperation( Ensure.IsNotNull(session, nameof(session)); ThrowIfDisposed(); - var operationContext = options.ToOperationContext(cancellationToken); + using var operationContext = options.ToOperationContext(cancellationToken); var readPreference = options.GetEffectiveReadPreference(session); using var binding = CreateReadBinding(session, readPreference, allowChannelPinning); return operation.Execute(operationContext, binding); @@ -68,7 +68,7 @@ public async Task ExecuteReadOperationAsync( Ensure.IsNotNull(session, nameof(session)); ThrowIfDisposed(); - var operationContext = options.ToOperationContext(cancellationToken); + using var operationContext = options.ToOperationContext(cancellationToken); var readPreference = options.GetEffectiveReadPreference(session); using var binding = CreateReadBinding(session, readPreference, allowChannelPinning); return await operation.ExecuteAsync(operationContext, binding).ConfigureAwait(false); @@ -86,7 +86,7 @@ public TResult ExecuteWriteOperation( Ensure.IsNotNull(session, nameof(session)); ThrowIfDisposed(); - var operationContext = options.ToOperationContext(cancellationToken); + using var operationContext = options.ToOperationContext(cancellationToken); using var binding = CreateReadWriteBinding(session, allowChannelPinning); return operation.Execute(operationContext, binding); } @@ -103,7 +103,7 @@ public async Task ExecuteWriteOperationAsync( Ensure.IsNotNull(session, nameof(session)); ThrowIfDisposed(); - var operationContext = options.ToOperationContext(cancellationToken); + using var operationContext = options.ToOperationContext(cancellationToken); using var binding = CreateReadWriteBinding(session, allowChannelPinning); return await operation.ExecuteAsync(operationContext, binding).ConfigureAwait(false); } diff --git a/tests/MongoDB.Driver.TestHelpers/Core/FailPoint.cs b/tests/MongoDB.Driver.TestHelpers/Core/FailPoint.cs index 0122eb8447b..41806068f01 100644 --- a/tests/MongoDB.Driver.TestHelpers/Core/FailPoint.cs +++ b/tests/MongoDB.Driver.TestHelpers/Core/FailPoint.cs @@ -44,12 +44,12 @@ internal sealed class FailPoint : IDisposable public static FailPoint Configure(IClusterInternal cluster, ICoreSessionHandle session, BsonDocument command, bool? withAsync = null) { var server = GetWriteableServer(cluster); - return FailPoint.Configure(server, session, command, withAsync); + return FailPoint.Configure(server.Server, server.RoundTripTime, session, command, withAsync); } - public static FailPoint Configure(IServer server, ICoreSessionHandle session, BsonDocument command, bool? withAsync = null) + public static FailPoint Configure(IServer server, TimeSpan serverRoundTripTime, ICoreSessionHandle session, BsonDocument command, bool? withAsync = null) { - var binding = new SingleServerReadWriteBinding(server, session.Fork()); + var binding = new SingleServerReadWriteBinding(server, serverRoundTripTime, session.Fork()); if (withAsync.HasValue) { MakeFailPointApplicationNameTestableIfConfigured(command, withAsync.Value); @@ -85,7 +85,7 @@ public static FailPoint ConfigureAlwaysOn(IClusterInternal cluster, ICoreSession public static string DecorateApplicationName(string applicationName, bool async) => $"{applicationName}{ApplicationNameTestableSuffix}{async}"; // private static methods - private static IServer GetWriteableServer(IClusterInternal cluster) + private static (IServer Server, TimeSpan RoundTripTime) GetWriteableServer(IClusterInternal cluster) { var selector = WritableServerSelector.Instance; return cluster.SelectServer(OperationContext.NoTimeout, selector); diff --git a/tests/MongoDB.Driver.TestHelpers/Core/MockConnection.cs b/tests/MongoDB.Driver.TestHelpers/Core/MockConnection.cs index 6ee9c5ad81a..f8298aaa62d 100644 --- a/tests/MongoDB.Driver.TestHelpers/Core/MockConnection.cs +++ b/tests/MongoDB.Driver.TestHelpers/Core/MockConnection.cs @@ -184,7 +184,7 @@ public List GetSentMessages() return _sentMessages; } - public void Open(CancellationToken cancellationToken) + public void Open(OperationContext operationContext) { _openingEventHandler?.Invoke(new ConnectionOpeningEvent(_connectionId, _connectionSettings, null)); @@ -196,7 +196,7 @@ public void Open(CancellationToken cancellationToken) _openedEventHandler?.Invoke(new ConnectionOpenedEvent(_connectionId, _connectionSettings, TimeSpan.FromTicks(1), null)); } - public Task OpenAsync(CancellationToken cancellationToken) + public Task OpenAsync(OperationContext operationContext) { _openingEventHandler?.Invoke(new ConnectionOpeningEvent(_connectionId, _connectionSettings, null)); @@ -220,24 +220,24 @@ public async Task ReauthenticateAsync(CancellationToken cancellationToken) await _replyActions.Dequeue().GetEffectiveMessageAsync().ConfigureAwait(false); } - public ResponseMessage ReceiveMessage(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public ResponseMessage ReceiveMessage(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings) { var action = _replyActions.Dequeue(); return (ResponseMessage)action.GetEffectiveMessage(); } - public async Task ReceiveMessageAsync(int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public async Task ReceiveMessageAsync(OperationContext operationContext, int responseTo, IMessageEncoderSelector encoderSelector, MessageEncoderSettings messageEncoderSettings) { var action = _replyActions.Dequeue(); return (ResponseMessage)await action.GetEffectiveMessageAsync().ConfigureAwait(false); } - public void SendMessage(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public void SendMessage(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { _sentMessages.Add(message); } - public Task SendMessageAsync(RequestMessage message, MessageEncoderSettings messageEncoderSettings, CancellationToken cancellationToken) + public Task SendMessageAsync(OperationContext operationContext, RequestMessage message, MessageEncoderSettings messageEncoderSettings) { _sentMessages.Add(message); return Task.CompletedTask; diff --git a/tests/MongoDB.Driver.Tests/AuthenticationTests.cs b/tests/MongoDB.Driver.Tests/AuthenticationTests.cs index 0ad423e16eb..ae0b8840e45 100644 --- a/tests/MongoDB.Driver.Tests/AuthenticationTests.cs +++ b/tests/MongoDB.Driver.Tests/AuthenticationTests.cs @@ -16,7 +16,6 @@ using System; using System.Linq; using System.Security.Cryptography.X509Certificates; -using System.Threading; using FluentAssertions; using MongoDB.Bson; using MongoDB.Driver.Core.Clusters.ServerSelectors; @@ -338,9 +337,9 @@ private void AssertAuthenticationSucceeds( speculativeAuthenticatationShouldSucceedIfPossible) { var serverSelector = new ReadPreferenceServerSelector(settings.ReadPreference); - var server = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, serverSelector); - var channel = server.GetChannel(OperationContext.NoTimeout); - var helloResult = channel.ConnectionDescription.HelloResult; + var (server, _) = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, serverSelector); + var channel = server.GetConnection(OperationContext.NoTimeout); + var helloResult = channel.Description.HelloResult; helloResult.SpeculativeAuthenticate.Should().NotBeNull(); } } diff --git a/tests/MongoDB.Driver.Tests/ClusterTests.cs b/tests/MongoDB.Driver.Tests/ClusterTests.cs index e1a453acefa..b849b6c697a 100644 --- a/tests/MongoDB.Driver.Tests/ClusterTests.cs +++ b/tests/MongoDB.Driver.Tests/ClusterTests.cs @@ -25,6 +25,7 @@ using MongoDB.Driver.Core.Bindings; using MongoDB.Driver.Core.Clusters; using MongoDB.Driver.Core.Clusters.ServerSelectors; +using MongoDB.Driver.Core.Connections; using MongoDB.Driver.Core.Events; using MongoDB.Driver.Core.Misc; using MongoDB.Driver.Core.Servers; @@ -87,21 +88,20 @@ public void SelectServer_loadbalancing_prose_test([Values(false, true)] bool asy var eventCapturer = CreateEventCapturer(); using (var client = CreateMongoClient(eventCapturer, applicationName)) { - var slowServer = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); - var fastServer = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new DelegateServerSelector((_, servers) => servers.Where(s => s.ServerId != slowServer.ServerId))); - - using var failPoint = FailPoint.Configure(slowServer, NoCoreSession.NewHandle(), failCommand, async); + var (slowServer, slowServerRtt) = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); + var (fastServer, _) = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new DelegateServerSelector((_, servers) => servers.Where(s => s.ServerId != slowServer.ServerId))); + using var failPoint = FailPoint.Configure(slowServer, slowServerRtt, NoCoreSession.NewHandle(), failCommand, async); var database = client.GetDatabase(_databaseName); CreateCollection(); var collection = database.GetCollection(_collectionName); // warm up connections - var channels = new ConcurrentBag(); + var channels = new ConcurrentBag(); ThreadingUtilities.ExecuteOnNewThreads(threadsCount, i => { - channels.Add(slowServer.GetChannel(OperationContext.NoTimeout)); - channels.Add(fastServer.GetChannel(OperationContext.NoTimeout)); + channels.Add(slowServer.GetConnection(OperationContext.NoTimeout)); + channels.Add(fastServer.GetConnection(OperationContext.NoTimeout)); }); foreach (var channel in channels) diff --git a/tests/MongoDB.Driver.Tests/Core/Bindings/ChannelChannelSourceTests.cs b/tests/MongoDB.Driver.Tests/Core/Bindings/ChannelChannelSourceTests.cs index 650f074169d..357006fd72f 100644 --- a/tests/MongoDB.Driver.Tests/Core/Bindings/ChannelChannelSourceTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Bindings/ChannelChannelSourceTests.cs @@ -14,9 +14,10 @@ */ using System; -using System.Reflection; +using System.Collections.Generic; using System.Threading.Tasks; using FluentAssertions; +using MongoDB.Bson.TestHelpers; using MongoDB.TestHelpers.XunitExtensions; using MongoDB.Driver.Core.Servers; using Moq; @@ -30,14 +31,16 @@ public class ChannelChannelSourceTests public void constructor_should_initialize_instance() { var server = new Mock().Object; + var roundTripTime = TimeSpan.FromSeconds(42); var channel = new Mock().Object; var session = new Mock().Object; - var result = new ChannelChannelSource(server, channel, session); + var result = new ChannelChannelSource(server, roundTripTime, channel, session); result._channel().Should().BeSameAs(channel); result._disposed().Should().BeFalse(); result.Server.Should().BeSameAs(server); + result.RoundTripTime.Should().Be(roundTripTime); result.Session.Should().BeSameAs(session); } @@ -47,19 +50,39 @@ public void constructor_should_throw_when_server_is_null() var channel = new Mock().Object; var session = new Mock().Object; - var exception = Record.Exception(() => new ChannelChannelSource(null, channel, session)); + var exception = Record.Exception(() => new ChannelChannelSource(null, TimeSpan.FromSeconds(42), channel, session)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("server"); } + [Theory] + [MemberData(nameof(InvalidRoundTripCases))] + public void constructor_should_throw_when_round_trip_time_is_invalid(TimeSpan roundTripTime) + { + var server = new Mock().Object; + var channel = new Mock().Object; + var session = new Mock().Object; + + var exception = Record.Exception(() => new ChannelChannelSource(server, roundTripTime, channel, session)); + + var e = exception.Should().BeOfType().Subject; + e.ParamName.Should().Be("roundTripTime"); + } + + public static IEnumerable InvalidRoundTripCases = + [ + [TimeSpan.Zero], + [TimeSpan.FromMilliseconds(-5)] + ]; + [Fact] public void constructor_should_throw_when_channel_is_null() { var server = new Mock().Object; var session = new Mock().Object; - var exception = Record.Exception(() => new ChannelChannelSource(server, null, session)); + var exception = Record.Exception(() => new ChannelChannelSource(server, TimeSpan.FromSeconds(42), null, session)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("channel"); @@ -71,7 +94,7 @@ public void constructor_should_throw_when_session_is_null() var server = new Mock().Object; var channel = new Mock().Object; - var exception = Record.Exception(() => new ChannelChannelSource(server, channel, null)); + var exception = Record.Exception(() => new ChannelChannelSource(server, TimeSpan.FromSeconds(42), channel, null)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("session"); @@ -183,6 +206,7 @@ private ChannelChannelSource CreateSubject(IServer server = null, IChannelHandle { return new ChannelChannelSource( server ?? new Mock().Object, + TimeSpan.FromSeconds(42), channel ?? new Mock().Object, session ?? new Mock().Object); } @@ -191,15 +215,9 @@ private ChannelChannelSource CreateSubject(IServer server = null, IChannelHandle internal static class ChannelChannelSourceReflector { public static IChannelHandle _channel(this ChannelChannelSource obj) - { - var fieldInfo = typeof(ChannelChannelSource).GetField("_channel", BindingFlags.NonPublic | BindingFlags.Instance); - return (IChannelHandle)fieldInfo.GetValue(obj); - } + => (IChannelHandle)Reflector.GetFieldValue(obj, "_channel"); public static bool _disposed(this ChannelChannelSource obj) - { - var fieldInfo = typeof(ChannelChannelSource).GetField("_disposed", BindingFlags.NonPublic | BindingFlags.Instance); - return (bool)fieldInfo.GetValue(obj); - } + => (bool)Reflector.GetFieldValue(obj, "_disposed"); } } diff --git a/tests/MongoDB.Driver.Tests/Core/Bindings/ChannelReadBindingTests.cs b/tests/MongoDB.Driver.Tests/Core/Bindings/ChannelReadBindingTests.cs index 1d6f317e611..0a3a67310d3 100644 --- a/tests/MongoDB.Driver.Tests/Core/Bindings/ChannelReadBindingTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Bindings/ChannelReadBindingTests.cs @@ -14,9 +14,10 @@ */ using System; -using System.Reflection; +using System.Collections.Generic; using System.Threading.Tasks; using FluentAssertions; +using MongoDB.Bson.TestHelpers; using MongoDB.TestHelpers.XunitExtensions; using MongoDB.Driver.Core.Servers; using Moq; @@ -30,11 +31,12 @@ public class ChannelReadBindingTests public void constructor_should_initialize_instance() { var server = new Mock().Object; + var roundTripTime = TimeSpan.FromMilliseconds(42); var channel = new Mock().Object; var readPreference = ReadPreference.Primary; var session = new Mock().Object; - var result = new ChannelReadBinding(server, channel, readPreference, session); + var result = new ChannelReadBinding(server, roundTripTime, channel, readPreference, session); result._channel().Should().BeSameAs(channel); result._disposed().Should().BeFalse(); @@ -46,24 +48,47 @@ public void constructor_should_initialize_instance() [Fact] public void constructor_should_throw_when_server_is_null() { + var roundTripTime = TimeSpan.FromMilliseconds(42); var channel = new Mock().Object; var readPreference = ReadPreference.Primary; var session = new Mock().Object; - var exception = Record.Exception(() => new ChannelReadBinding(null, channel, readPreference, session)); + var exception = Record.Exception(() => new ChannelReadBinding(null, roundTripTime, channel, readPreference, session)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("server"); } + [Theory] + [MemberData(nameof(InvalidRoundTripCases))] + public void constructor_should_throw_when_round_trip_time_is_invalid(TimeSpan roundTripTime) + { + var server = new Mock().Object; + var channel = new Mock().Object; + var readPreference = ReadPreference.Primary; + var session = new Mock().Object; + + var exception = Record.Exception(() => new ChannelReadBinding(server, roundTripTime, channel, readPreference, session)); + + var e = exception.Should().BeOfType().Subject; + e.ParamName.Should().Be("roundTripTime"); + } + + public static IEnumerable InvalidRoundTripCases = + [ + [TimeSpan.Zero], + [TimeSpan.FromMilliseconds(-5)] + ]; + [Fact] public void constructor_should_throw_when_channel_is_null() { var server = new Mock().Object; + var roundTripTime = TimeSpan.FromMilliseconds(42); var readPreference = ReadPreference.Primary; var session = new Mock().Object; - var exception = Record.Exception(() => new ChannelReadBinding(server, null, readPreference, session)); + var exception = Record.Exception(() => new ChannelReadBinding(server, roundTripTime, null, readPreference, session)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("channel"); @@ -73,10 +98,11 @@ public void constructor_should_throw_when_channel_is_null() public void constructor_should_throw_when_readPreference_is_null() { var server = new Mock().Object; + var roundTripTime = TimeSpan.FromMilliseconds(42); var channel = new Mock().Object; var session = new Mock().Object; - var exception = Record.Exception(() => new ChannelReadBinding(server, channel, null, session)); + var exception = Record.Exception(() => new ChannelReadBinding(server, roundTripTime, channel, null, session)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("readPreference"); @@ -86,10 +112,11 @@ public void constructor_should_throw_when_readPreference_is_null() public void constructor_should_throw_when_session_is_null() { var server = new Mock().Object; + var roundTripTime = TimeSpan.FromMilliseconds(42); var channel = new Mock().Object; var readPreference = ReadPreference.Primary; - var exception = Record.Exception(() => new ChannelReadBinding(server, channel, readPreference, null)); + var exception = Record.Exception(() => new ChannelReadBinding(server, roundTripTime, channel, readPreference, null)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("session"); @@ -150,9 +177,11 @@ public void Dispose_can_be_called_more_than_once() public async Task GetReadChannelSource_should_return_expected_result( [Values(false, true)] bool async) { + var server = new Mock().Object; + var roundTripTime = TimeSpan.FromMilliseconds(5); var mockChannel = new Mock(); var mockSession = new Mock(); - var subject = CreateSubject(channel: mockChannel.Object, session: mockSession.Object); + var subject = CreateSubject(server, roundTripTime, channel: mockChannel.Object, session: mockSession.Object); var forkedChannel = new Mock().Object; var forkedSession = new Mock().Object; @@ -168,6 +197,8 @@ await subject.GetReadChannelSourceAsync(OperationContext.NoTimeout) : var newSource = referenceCounted.Instance.Should().BeOfType().Subject; newSource._channel().Should().Be(forkedChannel); newSource.Session.Should().Be(forkedSession); + newSource.Server.Should().Be(server); + newSource.RoundTripTime.Should().Be(roundTripTime); } [Theory] @@ -175,7 +206,9 @@ await subject.GetReadChannelSourceAsync(OperationContext.NoTimeout) : public async Task GetReadChannelSource_should_throw_when_disposed( [Values(false, true)] bool async) { - var subject = CreateDisposedSubject(); + var subject = CreateSubject(); + subject.Dispose(); + var exception = async ? await Record.ExceptionAsync(() => subject.GetReadChannelSourceAsync(OperationContext.NoTimeout)) : Record.Exception(() => subject.GetReadChannelSource(OperationContext.NoTimeout)); @@ -185,17 +218,11 @@ await Record.ExceptionAsync(() => subject.GetReadChannelSourceAsync(OperationCon } // private methods - private ChannelReadBinding CreateDisposedSubject() - { - var subject = CreateSubject(); - subject.Dispose(); - return subject; - } - - private ChannelReadBinding CreateSubject(IServer server = null, IChannelHandle channel = null, ReadPreference readPreference = null, ICoreSessionHandle session = null) + private ChannelReadBinding CreateSubject(IServer server = null, TimeSpan? roundTripTime = null, IChannelHandle channel = null, ReadPreference readPreference = null, ICoreSessionHandle session = null) { return new ChannelReadBinding( server ?? new Mock().Object, + roundTripTime ?? TimeSpan.FromMilliseconds(42), channel ?? new Mock().Object, readPreference ?? ReadPreference.Primary, session ?? new Mock().Object); @@ -205,21 +232,12 @@ private ChannelReadBinding CreateSubject(IServer server = null, IChannelHandle c internal static class ChannelReadBindingReflector { public static IChannelHandle _channel(this ChannelReadBinding obj) - { - var fieldInfo = typeof(ChannelReadBinding).GetField("_channel", BindingFlags.NonPublic | BindingFlags.Instance); - return (IChannelHandle)fieldInfo.GetValue(obj); - } + => (IChannelHandle)Reflector.GetFieldValue(obj, "_channel"); public static bool _disposed(this ChannelReadBinding obj) - { - var fieldInfo = typeof(ChannelReadBinding).GetField("_disposed", BindingFlags.NonPublic | BindingFlags.Instance); - return (bool)fieldInfo.GetValue(obj); - } + => (bool)Reflector.GetFieldValue(obj, "_disposed"); public static IServer _server(this ChannelReadBinding obj) - { - var fieldInfo = typeof(ChannelReadBinding).GetField("_server", BindingFlags.NonPublic | BindingFlags.Instance); - return (IServer)fieldInfo.GetValue(obj); - } + => (IServer)Reflector.GetFieldValue(obj, "_server"); } } diff --git a/tests/MongoDB.Driver.Tests/Core/Bindings/ChannelReadWriteBindingTests.cs b/tests/MongoDB.Driver.Tests/Core/Bindings/ChannelReadWriteBindingTests.cs index acc8fdfa4ad..dd0588baec0 100644 --- a/tests/MongoDB.Driver.Tests/Core/Bindings/ChannelReadWriteBindingTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Bindings/ChannelReadWriteBindingTests.cs @@ -14,9 +14,10 @@ */ using System; -using System.Reflection; +using System.Collections.Generic; using System.Threading.Tasks; using FluentAssertions; +using MongoDB.Bson.TestHelpers; using MongoDB.TestHelpers.XunitExtensions; using MongoDB.Driver.Core.Servers; using Moq; @@ -30,10 +31,11 @@ public class ChannelReadWriteBindingTests public void constructor_should_initialize_instance() { var server = new Mock().Object; + var roundTripTime = TimeSpan.FromMilliseconds(42); var channel = new Mock().Object; var session = new Mock().Object; - var result = new ChannelReadWriteBinding(server, channel, session); + var result = new ChannelReadWriteBinding(server, roundTripTime, channel, session); result._channel().Should().BeSameAs(channel); result._disposed().Should().BeFalse(); @@ -44,22 +46,44 @@ public void constructor_should_initialize_instance() [Fact] public void constructor_should_throw_when_server_is_null() { + var roundTripTime = TimeSpan.FromMilliseconds(42); var channel = new Mock().Object; var session = new Mock().Object; - var exception = Record.Exception(() => new ChannelReadWriteBinding(null, channel, session)); + var exception = Record.Exception(() => new ChannelReadWriteBinding(null, roundTripTime, channel, session)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("server"); } + [Theory] + [MemberData(nameof(InvalidRoundTripCases))] + public void constructor_should_throw_when_roundTripTime_is_invalid(TimeSpan roundTripTime) + { + var server = new Mock().Object; + var channel = new Mock().Object; + var session = new Mock().Object; + + var exception = Record.Exception(() => new ChannelReadWriteBinding(server, roundTripTime, channel, session)); + + var e = exception.Should().BeOfType().Subject; + e.ParamName.Should().Be("roundTripTime"); + } + + public static IEnumerable InvalidRoundTripCases = + [ + [TimeSpan.Zero], + [TimeSpan.FromMilliseconds(-5)] + ]; + [Fact] public void constructor_should_throw_when_channel_is_null() { + var roundTripTime = TimeSpan.FromMilliseconds(42); var server = new Mock().Object; var session = new Mock().Object; - var exception = Record.Exception(() => new ChannelReadWriteBinding(server, null, session)); + var exception = Record.Exception(() => new ChannelReadWriteBinding(server, roundTripTime, null, session)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("channel"); @@ -68,10 +92,11 @@ public void constructor_should_throw_when_channel_is_null() [Fact] public void constructor_should_throw_when_session_is_null() { + var roundTripTime = TimeSpan.FromMilliseconds(42); var server = new Mock().Object; var channel = new Mock().Object; - var exception = Record.Exception(() => new ChannelReadWriteBinding(server, channel, null)); + var exception = Record.Exception(() => new ChannelReadWriteBinding(server, roundTripTime, channel, null)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("session"); @@ -156,9 +181,11 @@ await subject.GetReadChannelSourceAsync(OperationContext.NoTimeout) : public async Task GetWriteChannelSource_should_return_expected_result( [Values(false, true)] bool async) { + var server = Mock.Of(); + var roundTripTime = TimeSpan.FromMilliseconds(5); var mockChannel = new Mock(); var mockSession = new Mock(); - var subject = CreateSubject(channel: mockChannel.Object, session: mockSession.Object); + var subject = CreateSubject(server, roundTripTime, mockChannel.Object, mockSession.Object); var forkedChannel = new Mock().Object; var forkedSession = new Mock().Object; @@ -174,6 +201,8 @@ await subject.GetWriteChannelSourceAsync(OperationContext.NoTimeout) : var newSource = referenceCounted.Instance.Should().BeOfType().Subject; newSource._channel().Should().Be(forkedChannel); newSource.Session.Should().Be(forkedSession); + newSource.Server.Should().Be(server); + newSource.RoundTripTime.Should().Be(roundTripTime); } [Theory] @@ -212,10 +241,11 @@ private ChannelReadWriteBinding CreateDisposedSubject() return subject; } - private ChannelReadWriteBinding CreateSubject(IServer server = null, IChannelHandle channel = null, ICoreSessionHandle session = null) + private ChannelReadWriteBinding CreateSubject(IServer server = null, TimeSpan? roundTripTime = null, IChannelHandle channel = null, ICoreSessionHandle session = null) { return new ChannelReadWriteBinding( server ?? new Mock().Object, + roundTripTime ?? TimeSpan.FromSeconds(42), channel ?? new Mock().Object, session ?? new Mock().Object); } @@ -224,21 +254,12 @@ private ChannelReadWriteBinding CreateSubject(IServer server = null, IChannelHan internal static class ChannelReadWriteBindingReflector { public static IChannelHandle _channel(this ChannelReadWriteBinding obj) - { - var fieldInfo = typeof(ChannelReadWriteBinding).GetField("_channel", BindingFlags.NonPublic | BindingFlags.Instance); - return (IChannelHandle)fieldInfo.GetValue(obj); - } + => (IChannelHandle)Reflector.GetFieldValue(obj, "_channel"); public static bool _disposed(this ChannelReadWriteBinding obj) - { - var fieldInfo = typeof(ChannelReadWriteBinding).GetField("_disposed", BindingFlags.NonPublic | BindingFlags.Instance); - return (bool)fieldInfo.GetValue(obj); - } + => (bool)Reflector.GetFieldValue(obj, "_disposed"); public static IServer _server(this ChannelReadWriteBinding obj) - { - var fieldInfo = typeof(ChannelReadWriteBinding).GetField("_server", BindingFlags.NonPublic | BindingFlags.Instance); - return (IServer)fieldInfo.GetValue(obj); - } + => (IServer)Reflector.GetFieldValue(obj, "_server"); } } diff --git a/tests/MongoDB.Driver.Tests/Core/Bindings/ReadPreferenceBindingTests.cs b/tests/MongoDB.Driver.Tests/Core/Bindings/ReadPreferenceBindingTests.cs index c3486f158e5..ad7e79f61a6 100644 --- a/tests/MongoDB.Driver.Tests/Core/Bindings/ReadPreferenceBindingTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Bindings/ReadPreferenceBindingTests.cs @@ -15,9 +15,9 @@ using System; using System.Net; -using System.Reflection; using System.Threading.Tasks; using FluentAssertions; +using MongoDB.Bson.TestHelpers; using MongoDB.Driver.Core.Clusters; using MongoDB.Driver.Core.Clusters.ServerSelectors; using MongoDB.Driver.Core.Misc; @@ -95,7 +95,7 @@ public async Task GetReadChannelSource_should_use_a_read_preference_server_selec bool async) { var subject = new ReadPreferenceBinding(_mockCluster.Object, ReadPreference.Primary, NoCoreSession.NewHandle()); - var selectedServer = new Mock().Object; + var selectedServer = (new Mock().Object, TimeSpan.FromMilliseconds(42)); var clusterId = new ClusterId(); var endPoint = new DnsEndPoint("localhost", 27017); @@ -134,7 +134,7 @@ public async Task GetReadChannelSource_should_fork_the_session( { var mockSession = new Mock(); var subject = new ReadPreferenceBinding(_mockCluster.Object, ReadPreference.Primary, mockSession.Object); - var selectedServer = new Mock().Object; + var selectedServer = (new Mock().Object, TimeSpan.FromMilliseconds(42)); _mockCluster.Setup(m => m.SelectServer(It.IsAny(), It.IsAny())).Returns(selectedServer); _mockCluster.Setup(m => m.SelectServerAsync(It.IsAny(), It.IsAny())).Returns(Task.FromResult(selectedServer)); var forkedSession = new Mock().Object; @@ -187,9 +187,6 @@ public void Dispose_should_call_dispose_on_the_session() internal static class ReadPreferenceBindingReflector { public static IClusterInternal _cluster(this ReadPreferenceBinding obj) - { - var fieldInfo = typeof(ReadPreferenceBinding).GetField("_cluster", BindingFlags.NonPublic | BindingFlags.Instance); - return (IClusterInternal)fieldInfo.GetValue(obj); - } + => (IClusterInternal)Reflector.GetFieldValue(obj, "_cluster"); } } diff --git a/tests/MongoDB.Driver.Tests/Core/Bindings/ServerChannelSourceTests.cs b/tests/MongoDB.Driver.Tests/Core/Bindings/ServerChannelSourceTests.cs index 62bd5cfd535..48aa1d81f4e 100644 --- a/tests/MongoDB.Driver.Tests/Core/Bindings/ServerChannelSourceTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Bindings/ServerChannelSourceTests.cs @@ -14,6 +14,7 @@ */ using System; +using System.Collections.Generic; using System.Threading.Tasks; using FluentAssertions; using MongoDB.Driver.Core.Clusters; @@ -27,27 +28,43 @@ namespace MongoDB.Driver.Core.Bindings { public class ServerChannelSourceTests { - private Mock _mockServer; - - public ServerChannelSourceTests() - { - _mockServer = new Mock(); - } - [Fact] public void Constructor_should_throw_when_server_is_null() { + var roundTripTime = TimeSpan.FromMilliseconds(42); var session = new Mock().Object; - var exception = Record.Exception(() => new ServerChannelSource(null, session)); + var exception = Record.Exception(() => new ServerChannelSource(null, roundTripTime, session)); exception.Should().BeOfType() .Subject.ParamName.Should().Be("server"); } + [Theory] + [MemberData(nameof(InvalidRoundTripCases))] + public void Constructor_should_throw_when_roundTripTime_is_invalid(TimeSpan roundTripTime) + { + var server = Mock.Of(); + var session = Mock.Of(); + + var exception = Record.Exception(() => new ServerChannelSource(server, roundTripTime, session)); + + exception.Should().BeOfType() + .Subject.ParamName.Should().Be("roundTripTime"); + } + + public static IEnumerable InvalidRoundTripCases = + [ + [TimeSpan.Zero], + [TimeSpan.FromMilliseconds(-5)] + ]; + [Fact] public void Constructor_should_throw_when_session_is_null() { - var exception = Record.Exception(() => new ServerChannelSource(_mockServer.Object, null)); + var server = Mock.Of(); + var roundTripTime = TimeSpan.FromMilliseconds(42); + + var exception = Record.Exception(() => new ServerChannelSource(server, roundTripTime, null)); exception.Should().BeOfType() .Subject.ParamName.Should().Be("session"); @@ -56,13 +73,14 @@ public void Constructor_should_throw_when_session_is_null() [Fact] public void ServerDescription_should_return_description_of_server() { - var session = new Mock().Object; - var subject = new ServerChannelSource(_mockServer.Object, session); - var desc = ServerDescriptionHelper.Disconnected(new ClusterId()); - _mockServer.SetupGet(s => s.Description).Returns(desc); + var serverMock = new Mock(); + serverMock.SetupGet(s => s.Description).Returns(desc); + var roundTripTime = TimeSpan.FromMilliseconds(42); + var session = new Mock().Object; + var subject = new ServerChannelSource(serverMock.Object, roundTripTime, session); var result = subject.ServerDescription; result.Should().BeSameAs(desc); @@ -72,7 +90,7 @@ public void ServerDescription_should_return_description_of_server() public void Session_should_return_expected_result() { var session = new Mock().Object; - var subject = new ServerChannelSource(_mockServer.Object, session); + var subject = new ServerChannelSource(Mock.Of(), TimeSpan.FromMilliseconds(42), session); var result = subject.Session; @@ -86,7 +104,7 @@ public async Task GetChannel_should_throw_if_disposed( bool async) { var session = new Mock().Object; - var subject = new ServerChannelSource(_mockServer.Object, session); + var subject = new ServerChannelSource(Mock.Of(), TimeSpan.FromMilliseconds(42), session); subject.Dispose(); var exception = async ? @@ -102,20 +120,21 @@ public async Task GetChannel_should_get_connection_from_server( [Values(false, true)] bool async) { + var serverMock = new Mock(); var session = new Mock().Object; - var subject = new ServerChannelSource(_mockServer.Object, session); + var subject = new ServerChannelSource(serverMock.Object, TimeSpan.FromMilliseconds(42), session); if (async) { await subject.GetChannelAsync(OperationContext.NoTimeout); - _mockServer.Verify(s => s.GetChannelAsync(It.IsAny()), Times.Once); + serverMock.Verify(s => s.GetConnectionAsync(It.IsAny()), Times.Once); } else { subject.GetChannel(OperationContext.NoTimeout); - _mockServer.Verify(s => s.GetChannel(It.IsAny()), Times.Once); + serverMock.Verify(s => s.GetConnection(It.IsAny()), Times.Once); } } @@ -123,7 +142,7 @@ public async Task GetChannel_should_get_connection_from_server( public void Dispose_should_dispose_session() { var mockSession = new Mock(); - var subject = new ServerChannelSource(_mockServer.Object, mockSession.Object); + var subject = new ServerChannelSource(Mock.Of(), TimeSpan.FromMilliseconds(42), mockSession.Object); subject.Dispose(); diff --git a/tests/MongoDB.Driver.Tests/Core/Bindings/SingleServerReadBindingTests.cs b/tests/MongoDB.Driver.Tests/Core/Bindings/SingleServerReadBindingTests.cs index ab0fae6858c..55302d2f82d 100644 --- a/tests/MongoDB.Driver.Tests/Core/Bindings/SingleServerReadBindingTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Bindings/SingleServerReadBindingTests.cs @@ -14,11 +14,11 @@ */ using System; +using System.Collections.Generic; using System.Net; -using System.Reflection; -using System.Threading; using System.Threading.Tasks; using FluentAssertions; +using MongoDB.Bson.TestHelpers; using MongoDB.TestHelpers.XunitExtensions; using MongoDB.Driver.Core.Servers; using Moq; @@ -32,10 +32,11 @@ public class SingleServerReadBindingTests public void constructor_should_initialize_instance() { var server = new Mock().Object; + var roundTripTime = TimeSpan.FromMilliseconds(42); var readPreference = ReadPreference.Primary; var session = new Mock().Object; - var result = new SingleServerReadBinding(server, readPreference, session); + var result = new SingleServerReadBinding(server, roundTripTime, readPreference, session); result._disposed().Should().BeFalse(); result.ReadPreference.Should().BeSameAs(readPreference); @@ -46,22 +47,44 @@ public void constructor_should_initialize_instance() [Fact] public void constructor_should_throw_when_server_is_null() { + var roundTripTime = TimeSpan.FromMilliseconds(42); var readPreference = ReadPreference.Primary; var session = new Mock().Object; - var exception = Record.Exception(() => new SingleServerReadBinding(null, readPreference, session)); + var exception = Record.Exception(() => new SingleServerReadBinding(null, roundTripTime, readPreference, session)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("server"); } + [Theory] + [MemberData(nameof(InvalidRoundTripCases))] + public void constructor_should_throw_when_roundTripTime_is_invalid(TimeSpan roundTripTime) + { + var server = new Mock().Object; + var readPreference = ReadPreference.Primary; + var session = new Mock().Object; + + var exception = Record.Exception(() => new SingleServerReadBinding(server, roundTripTime, readPreference, session)); + + var e = exception.Should().BeOfType().Subject; + e.ParamName.Should().Be("roundTripTime"); + } + + public static IEnumerable InvalidRoundTripCases = + [ + [TimeSpan.Zero], + [TimeSpan.FromMilliseconds(-5)] + ]; + [Fact] public void constructor_should_throw_when_readPreference_is_null() { var server = new Mock().Object; + var roundTripTime = TimeSpan.FromMilliseconds(42); var session = new Mock().Object; - var exception = Record.Exception(() => new SingleServerReadBinding(server, null, session)); + var exception = Record.Exception(() => new SingleServerReadBinding(server, roundTripTime, null, session)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("readPreference"); @@ -71,9 +94,10 @@ public void constructor_should_throw_when_readPreference_is_null() public void constructor_should_throw_when_session_is_null() { var server = new Mock().Object; + var roundTripTime = TimeSpan.FromMilliseconds(42); var readPreference = ReadPreference.Primary; - var exception = Record.Exception(() => new SingleServerReadBinding(server, readPreference, null)); + var exception = Record.Exception(() => new SingleServerReadBinding(server, roundTripTime, readPreference, null)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("session"); @@ -130,8 +154,10 @@ public void Dispose_can_be_called_more_than_once() public async Task GetReadChannelSource_should_return_expected_result( [Values(false, true)] bool async) { + var server = CreateMockServer(); + var roundTripTime = TimeSpan.FromMilliseconds(42); var mockSession = new Mock(); - var subject = CreateSubject(session: mockSession.Object); + var subject = CreateSubject(server.Object, roundTripTime, session: mockSession.Object); var forkedSession = new Mock().Object; mockSession.Setup(m => m.Fork()).Returns(forkedSession); @@ -143,6 +169,8 @@ await subject.GetReadChannelSourceAsync(OperationContext.NoTimeout) : var referenceCounted = newHandle._reference(); var source = referenceCounted.Instance.Should().BeOfType().Subject; source.Session.Should().BeSameAs(forkedSession); + source.Server.Should().Be(server.Object); + source.RoundTripTime.Should().Be(roundTripTime); } [Theory] @@ -150,7 +178,8 @@ await subject.GetReadChannelSourceAsync(OperationContext.NoTimeout) : public async Task GetReadChannelSource_should_throw_when_disposed( [Values(false, true)] bool async) { - var subject = CreateDisposedSubject(); + var subject = CreateSubject(); + subject.Dispose(); var exception = async ? await Record.ExceptionAsync(() => subject.GetReadChannelSourceAsync(OperationContext.NoTimeout)) : @@ -161,17 +190,11 @@ await Record.ExceptionAsync(() => subject.GetReadChannelSourceAsync(OperationCon } // private methods - private SingleServerReadBinding CreateDisposedSubject() - { - var subject = CreateSubject(); - subject.Dispose(); - return subject; - } - - private SingleServerReadBinding CreateSubject(IServer server = null, ReadPreference readPreference = null, ICoreSessionHandle session = null) + private SingleServerReadBinding CreateSubject(IServer server = null, TimeSpan? roundTripTime = null, ReadPreference readPreference = null, ICoreSessionHandle session = null) { return new SingleServerReadBinding( server ?? CreateMockServer().Object, + roundTripTime ?? TimeSpan.FromMilliseconds(5), readPreference ?? ReadPreference.Primary, session ?? new Mock().Object); } @@ -191,15 +214,9 @@ private Mock CreateMockServer() internal static class SingleServerReadBindingReflector { public static bool _disposed(this SingleServerReadBinding obj) - { - var fieldInfo = typeof(SingleServerReadBinding).GetField("_disposed", BindingFlags.NonPublic | BindingFlags.Instance); - return (bool)fieldInfo.GetValue(obj); - } + => (bool)Reflector.GetFieldValue(obj, "_disposed"); public static IServer _server(this SingleServerReadBinding obj) - { - var fieldInfo = typeof(SingleServerReadBinding).GetField("_server", BindingFlags.NonPublic | BindingFlags.Instance); - return (IServer)fieldInfo.GetValue(obj); - } + => (IServer)Reflector.GetFieldValue(obj, "_server"); } } diff --git a/tests/MongoDB.Driver.Tests/Core/Bindings/SingleServerReadWriteBindingTests.cs b/tests/MongoDB.Driver.Tests/Core/Bindings/SingleServerReadWriteBindingTests.cs index 2efd25ca017..a4bc6e6e16b 100644 --- a/tests/MongoDB.Driver.Tests/Core/Bindings/SingleServerReadWriteBindingTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Bindings/SingleServerReadWriteBindingTests.cs @@ -14,9 +14,10 @@ */ using System; -using System.Reflection; +using System.Collections.Generic; using System.Threading.Tasks; using FluentAssertions; +using MongoDB.Bson.TestHelpers; using MongoDB.TestHelpers.XunitExtensions; using MongoDB.Driver.Core.Servers; using Moq; @@ -32,7 +33,7 @@ public void constructor_should_initialize_instance() var server = new Mock().Object; var session = new Mock().Object; - var result = new SingleServerReadWriteBinding(server, session); + var result = new SingleServerReadWriteBinding(server, TimeSpan.FromMilliseconds(42), session); result._disposed().Should().BeFalse(); result._server().Should().BeSameAs(server); @@ -44,18 +45,37 @@ public void constructor_should_throw_when_server_is_null() { var session = new Mock().Object; - var exception = Record.Exception(() => new SingleServerReadWriteBinding(null, session)); + var exception = Record.Exception(() => new SingleServerReadWriteBinding(null, TimeSpan.FromMilliseconds(42), session)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("server"); } + [Theory] + [MemberData(nameof(InvalidRoundTripCases))] + public void constructor_should_throw_when_roundTripTime_is_invalid(TimeSpan roundTripTime) + { + var server = new Mock().Object; + var session = new Mock().Object; + + var exception = Record.Exception(() => new SingleServerReadWriteBinding(server, roundTripTime, session)); + + var e = exception.Should().BeOfType().Subject; + e.ParamName.Should().Be("roundTripTime"); + } + + public static IEnumerable InvalidRoundTripCases = + [ + [TimeSpan.Zero], + [TimeSpan.FromMilliseconds(-5)] + ]; + [Fact] public void constructor_should_throw_when_session_is_null() { var server = new Mock().Object; - var exception = Record.Exception(() => new SingleServerReadWriteBinding(server, null)); + var exception = Record.Exception(() => new SingleServerReadWriteBinding(server, TimeSpan.FromMilliseconds(42), null)); var e = exception.Should().BeOfType().Subject; e.ParamName.Should().Be("session"); @@ -111,8 +131,10 @@ public void Dispose_can_be_called_more_than_once() public async Task GetReadChannelSource_should_return_expected_result( [Values(false, true)] bool async) { + var server = Mock.Of(); + var roundTripTime = TimeSpan.FromMilliseconds(5); var mockSession = new Mock(); - var subject = CreateSubject(session: mockSession.Object); + var subject = CreateSubject(server, roundTripTime, mockSession.Object); var forkedSession = new Mock().Object; mockSession.Setup(m => m.Fork()).Returns(forkedSession); @@ -123,6 +145,8 @@ await subject.GetReadChannelSourceAsync(OperationContext.NoTimeout) : var newHandle = result.Should().BeOfType().Subject; var referenceCounted = newHandle._reference(); var source = referenceCounted.Instance.Should().BeOfType().Subject; + source.Server.Should().Be(server); + source.RoundTripTime.Should().Be(roundTripTime); source.Session.Should().BeSameAs(forkedSession); } @@ -131,7 +155,9 @@ await subject.GetReadChannelSourceAsync(OperationContext.NoTimeout) : public async Task GetReadChannelSource_should_throw_when_disposed( [Values(false, true)] bool async) { - var subject = CreateDisposedSubject(); + var subject = CreateSubject(); + subject.Dispose(); + var exception = async ? await Record.ExceptionAsync(() => subject.GetReadChannelSourceAsync(OperationContext.NoTimeout)) : Record.Exception(() => subject.GetReadChannelSource(OperationContext.NoTimeout)); @@ -145,8 +171,10 @@ await Record.ExceptionAsync(() => subject.GetReadChannelSourceAsync(OperationCon public async Task GetWriteChannelSource_should_return_expected_result( [Values(false, true)] bool async) { + var server = Mock.Of(); + var roundTripTime = TimeSpan.FromMilliseconds(5); var mockSession = new Mock(); - var subject = CreateSubject(session: mockSession.Object); + var subject = CreateSubject(server, roundTripTime, mockSession.Object); var forkedSession = new Mock().Object; mockSession.Setup(m => m.Fork()).Returns(forkedSession); @@ -157,6 +185,8 @@ await subject.GetWriteChannelSourceAsync(OperationContext.NoTimeout) : var newHandle = result.Should().BeOfType().Subject; var referenceCounted = newHandle._reference(); var source = referenceCounted.Instance.Should().BeOfType().Subject; + source.Server.Should().Be(server); + source.RoundTripTime.Should().Be(roundTripTime); source.Session.Should().BeSameAs(forkedSession); } @@ -165,7 +195,9 @@ await subject.GetWriteChannelSourceAsync(OperationContext.NoTimeout) : public async Task GetWriteChannelSource_should_throw_when_disposed( [Values(false, true)] bool async) { - var subject = CreateDisposedSubject(); + var subject = CreateSubject(); + subject.Dispose(); + var exception = async ? await Record.ExceptionAsync(() => subject.GetReadChannelSourceAsync(OperationContext.NoTimeout)) : Record.Exception(() => subject.GetReadChannelSource(OperationContext.NoTimeout)); @@ -175,17 +207,11 @@ await Record.ExceptionAsync(() => subject.GetReadChannelSourceAsync(OperationCon } // private methods - private SingleServerReadWriteBinding CreateDisposedSubject() - { - var subject = CreateSubject(); - subject.Dispose(); - return subject; - } - - private SingleServerReadWriteBinding CreateSubject(IServer server = null, ICoreSessionHandle session = null) + private SingleServerReadWriteBinding CreateSubject(IServer server = null, TimeSpan? roundTripTime = null, ICoreSessionHandle session = null) { return new SingleServerReadWriteBinding( server ?? new Mock().Object, + roundTripTime ?? TimeSpan.FromMilliseconds(42), session ?? new Mock().Object); } } @@ -193,15 +219,9 @@ private SingleServerReadWriteBinding CreateSubject(IServer server = null, ICoreS internal static class SingleServerReadWriteBindingReflector { public static bool _disposed(this SingleServerReadWriteBinding obj) - { - var fieldInfo = typeof(SingleServerReadWriteBinding).GetField("_disposed", BindingFlags.NonPublic | BindingFlags.Instance); - return (bool)fieldInfo.GetValue(obj); - } + => (bool)Reflector.GetFieldValue(obj, "_disposed"); public static IServer _server(this SingleServerReadWriteBinding obj) - { - var fieldInfo = typeof(SingleServerReadWriteBinding).GetField("_server", BindingFlags.NonPublic | BindingFlags.Instance); - return (IServer)fieldInfo.GetValue(obj); - } + => (IServer)Reflector.GetFieldValue(obj, "_server"); } } diff --git a/tests/MongoDB.Driver.Tests/Core/Bindings/WritableServerBindingTests.cs b/tests/MongoDB.Driver.Tests/Core/Bindings/WritableServerBindingTests.cs index f16dc755894..0cecececc0e 100644 --- a/tests/MongoDB.Driver.Tests/Core/Bindings/WritableServerBindingTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Bindings/WritableServerBindingTests.cs @@ -18,6 +18,7 @@ using System.Reflection; using System.Threading.Tasks; using FluentAssertions; +using MongoDB.Bson.TestHelpers; using MongoDB.Driver.Core.Clusters; using MongoDB.Driver.Core.Clusters.ServerSelectors; using MongoDB.Driver.Core.Servers; @@ -97,7 +98,7 @@ public async Task GetReadChannelSource_should_use_a_writable_server_selector_to_ bool async) { var subject = new WritableServerBinding(_mockCluster.Object, NoCoreSession.NewHandle()); - var selectedServer = new Mock().Object; + var selectedServer = (new Mock().Object, TimeSpan.FromMilliseconds(42)); var clusterId = new ClusterId(); var endPoint = new DnsEndPoint("localhost", 27017); @@ -154,7 +155,7 @@ public async Task GetWriteChannelSourceAsync_should_use_a_writable_server_select bool async) { var subject = new WritableServerBinding(_mockCluster.Object, NoCoreSession.NewHandle()); - var selectedServer = new Mock().Object; + var selectedServer = (new Mock().Object, TimeSpan.FromMilliseconds(42)); var clusterId = new ClusterId(); var endPoint = new DnsEndPoint("localhost", 27017); @@ -194,7 +195,7 @@ public async Task GetWriteChannelSource_should_use_a_composite_server_selector_t bool async) { var subject = new WritableServerBinding(_mockCluster.Object, NoCoreSession.NewHandle()); - var selectedServer = new Mock().Object; + var selectedServer = (new Mock().Object, TimeSpan.FromMilliseconds(42)); var clusterId = new ClusterId(); var endPoint = new DnsEndPoint("localhost", 27017); @@ -235,7 +236,7 @@ public async Task GetWriteChannelSource_with_mayUseSecondary_should_pass_mayUseS bool async) { var subject = new WritableServerBinding(_mockCluster.Object, NoCoreSession.NewHandle()); - var selectedServer = new Mock().Object; + var selectedServer = (new Mock().Object, TimeSpan.FromMilliseconds(42)); var clusterId = new ClusterId(); var endPoint = new DnsEndPoint("localhost", 27017); @@ -288,9 +289,6 @@ public void Dispose_should_call_dispose_on_owned_resources() internal static class WritableServerBindingReflector { public static IClusterInternal _cluster(this WritableServerBinding obj) - { - var fieldInfo = typeof(WritableServerBinding).GetField("_cluster", BindingFlags.NonPublic | BindingFlags.Instance); - return (IClusterInternal)fieldInfo.GetValue(obj); - } + => (IClusterInternal)Reflector.GetFieldValue(obj, "_cluster"); } } diff --git a/tests/MongoDB.Driver.Tests/Core/Clusters/ClusterTests.cs b/tests/MongoDB.Driver.Tests/Core/Clusters/ClusterTests.cs index 7a3244b229d..68fe29ce044 100644 --- a/tests/MongoDB.Driver.Tests/Core/Clusters/ClusterTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Clusters/ClusterTests.cs @@ -212,13 +212,13 @@ public async Task SelectServer_should_return_second_server_if_first_cannot_be_fo var selector = new DelegateServerSelector((c, s) => s); - var result = async ? + var (server, _) = async ? await subject.SelectServerAsync(OperationContext.NoTimeout, selector) : subject.SelectServer(OperationContext.NoTimeout, selector); - result.Should().NotBeNull(); - result.EndPoint.Should().Be(endPoint2); + server.Should().NotBeNull(); + server.EndPoint.Should().Be(endPoint2); _capturedEvents.Next().Should().BeOfType(); _capturedEvents.Next().Should().BeOfType(); @@ -376,13 +376,13 @@ public async Task SelectServer_should_ignore_deprioritized_servers_if_cluster_is { _capturedEvents.Clear(); - var result = async ? + var (server, _) = async ? await subject.SelectServerAsync(OperationContext.NoTimeout, selector) : subject.SelectServer(OperationContext.NoTimeout, selector); - result.Should().NotBeNull(); + server.Should().NotBeNull(); - deprioritizedServers.Should().NotContain(d => d.EndPoint == result.Description.EndPoint); + deprioritizedServers.Should().NotContain(d => d.EndPoint == server.Description.EndPoint); _capturedEvents.Next().Should().BeOfType(); _capturedEvents.Next().Should().BeOfType(); @@ -410,13 +410,13 @@ public async Task SelectServer_should_return_deprioritized_servers_if_no_other_s var selector = new PriorityServerSelector(deprioritizedServers); _capturedEvents.Clear(); - var result = async ? + var (server, _) = async ? await subject.SelectServerAsync(OperationContext.NoTimeout, selector) : subject.SelectServer(OperationContext.NoTimeout, selector); - result.Should().NotBeNull(); + server.Should().NotBeNull(); - deprioritizedServers.Should().Contain(d => d.EndPoint == result.Description.EndPoint); + deprioritizedServers.Should().Contain(d => d.EndPoint == server.Description.EndPoint); _capturedEvents.Next().Should().BeOfType(); _capturedEvents.Next().Should().BeOfType(); @@ -491,11 +491,11 @@ public async Task SelectServer_should_apply_both_pre_and_post_server_selectors( ServerDescriptionHelper.Connected(subject.Description.ClusterId, new DnsEndPoint("localhost", 27020))); _capturedEvents.Clear(); - var result = async ? + var (server, _) = async ? await subject.SelectServerAsync(OperationContext.NoTimeout, middleSelector) : subject.SelectServer(OperationContext.NoTimeout, middleSelector); - ((DnsEndPoint)result.EndPoint).Port.Should().Be(27020); + ((DnsEndPoint)server.EndPoint).Port.Should().Be(27020); _capturedEvents.Next().Should().BeOfType(); _capturedEvents.Next().Should().BeOfType(); _capturedEvents.Any().Should().BeFalse(); @@ -526,7 +526,7 @@ public async Task SelectServer_should_call_custom_selector( if (withEligibleServers) { var selector = new DelegateServerSelector((c, s) => s); - var selectedServer = async ? + var (selectedServer, _) = async ? await subject.SelectServerAsync(OperationContext.NoTimeout, selector): subject.SelectServer(OperationContext.NoTimeout, selector); diff --git a/tests/MongoDB.Driver.Tests/Core/Clusters/LoadBalancedClusterTests.cs b/tests/MongoDB.Driver.Tests/Core/Clusters/LoadBalancedClusterTests.cs index 2c14c12f751..e4980c610ac 100644 --- a/tests/MongoDB.Driver.Tests/Core/Clusters/LoadBalancedClusterTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Clusters/LoadBalancedClusterTests.cs @@ -327,11 +327,11 @@ public async Task SelectServer_should_return_expected_server( PublishDescription(_endPoint); - var result = async ? + var (server, _) = async ? await subject.SelectServerAsync(OperationContext.NoTimeout, Mock.Of()) : subject.SelectServer(OperationContext.NoTimeout, Mock.Of()); - result.EndPoint.Should().Be(_endPoint); + server.EndPoint.Should().Be(_endPoint); } } diff --git a/tests/MongoDB.Driver.Tests/Core/ConnectionPools/ExclusiveConnectionPoolTests.cs b/tests/MongoDB.Driver.Tests/Core/ConnectionPools/ExclusiveConnectionPoolTests.cs index e467384ef70..7da9b2f2627 100644 --- a/tests/MongoDB.Driver.Tests/Core/ConnectionPools/ExclusiveConnectionPoolTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/ConnectionPools/ExclusiveConnectionPoolTests.cs @@ -312,10 +312,10 @@ public async Task AcquireConnection_should_invoke_error_handling_before_releasin .Setup(c => c.Settings) .Returns(new ConnectionSettings()); connectionMock - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Throws(exception); connectionMock - .Setup(c => c.OpenAsync(It.IsAny())) + .Setup(c => c.OpenAsync(It.IsAny())) .Throws(exception); return connectionMock.Object; @@ -582,7 +582,7 @@ public void AcquireConnection_should_timeout_when_non_sufficient_reused_connecti .Setup(c => c.Settings) .Returns(new ConnectionSettings()); connectionMock - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Callback(() => { if (establishingCount.CurrentCount > 0) @@ -593,7 +593,7 @@ public void AcquireConnection_should_timeout_when_non_sufficient_reused_connecti blockEstablishmentEvent.Wait(); }); connectionMock - .Setup(c => c.OpenAsync(It.IsAny())) + .Setup(c => c.OpenAsync(It.IsAny())) .Returns(() => { if (establishingCount.CurrentCount > 0) @@ -756,14 +756,14 @@ public void Acquire_and_release_connection_stress_test( .Setup(c => c.Settings) .Returns(new ConnectionSettings()); connectionMock - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Callback(() => { var sleepMS = random.Next(minEstablishingTime, maxEstablishingTime); Thread.Sleep(sleepMS); }); connectionMock - .Setup(c => c.OpenAsync(It.IsAny())) + .Setup(c => c.OpenAsync(It.IsAny())) .Returns(async () => { var sleepMS = random.Next(minEstablishingTime, maxEstablishingTime); @@ -970,7 +970,7 @@ public void In_use_marker_should_work_as_expected( var mockConnection = new Mock(); mockConnection.SetupGet(c => c.ConnectionId).Returns(new ConnectionId(serverId, ci)); mockConnection - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Callback(() => { if (minPoolSize == 0 || ci == 2) // ignore connection 1 created in minPoolSize logic @@ -984,7 +984,7 @@ public void In_use_marker_should_work_as_expected( }); mockConnection - .Setup(c => c.OpenAsync(It.IsAny())) + .Setup(c => c.OpenAsync(It.IsAny())) .Returns(async () => { if (minPoolSize == 0 || ci == 2) // ignore connection 1 created in minPoolSize logic @@ -1076,7 +1076,7 @@ public void Maintenance_should_call_connection_dispose_when_connection_authentic var authenticationException = new MongoAuthenticationException(connectionId, "test message"); var authenticationFailedConnection = new Mock(); authenticationFailedConnection - .Setup(c => c.Open(It.IsAny())) // an authentication exception is thrown from _connectionInitializer.InitializeConnection + .Setup(c => c.Open(It.IsAny())) // an authentication exception is thrown from _connectionInitializer.InitializeConnection // that in turn is called from OpenAsync .Throws(authenticationException); authenticationFailedConnection.SetupGet(c => c.ConnectionId).Returns(connectionId); @@ -1166,7 +1166,7 @@ public void MaxConnecting_queue_should_be_cleared_on_pool_clear( .Returns(new ConnectionSettings()); connectionMock - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Callback(() => { allEstablishing.Signal(); @@ -1174,7 +1174,7 @@ public void MaxConnecting_queue_should_be_cleared_on_pool_clear( }); connectionMock - .Setup(c => c.OpenAsync(It.IsAny())) + .Setup(c => c.OpenAsync(It.IsAny())) .Returns(() => { allEstablishing.Signal(); @@ -1424,7 +1424,7 @@ public void WaitQueue_should_throw_when_full( .Returns(new ConnectionSettings()); connectionMock - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Callback(() => { allAcquiringCountdownEvent.Signal(); @@ -1432,7 +1432,7 @@ public void WaitQueue_should_throw_when_full( }); connectionMock - .Setup(c => c.OpenAsync(It.IsAny())) + .Setup(c => c.OpenAsync(It.IsAny())) .Returns(() => { allAcquiringCountdownEvent.Signal(); @@ -1516,7 +1516,7 @@ public void WaitQueue_should_be_cleared_on_pool_clear( .Returns(new ConnectionSettings()); connectionMock - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Callback(() => { allEstablishing.Signal(); @@ -1524,7 +1524,7 @@ public void WaitQueue_should_be_cleared_on_pool_clear( }); connectionMock - .Setup(c => c.OpenAsync(It.IsAny())) + .Setup(c => c.OpenAsync(It.IsAny())) .Returns(() => { allEstablishing.Signal(); diff --git a/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs b/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs index f39bb0a97a5..4a5f5fc3da6 100644 --- a/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs @@ -103,7 +103,7 @@ public void Dispose_should_raise_the_correct_events() [Theory] [ParameterAttributeData] - public void Open_should_always_create_description_if_handshake_was_successful([Values(false, true)] bool async) + public async Task Open_should_always_create_description_if_handshake_was_successful([Values(false, true)] bool async) { var serviceId = ObjectId.GenerateNewId(); var connectionDescription = new ConnectionDescription( @@ -124,15 +124,9 @@ public void Open_should_always_create_description_if_handshake_was_successful([V .Setup(i => i.AuthenticateAsync(It.IsAny(), It.IsAny(), CancellationToken.None)) .ThrowsAsync(socketException); - Exception exception; - if (async) - { - exception = Record.Exception(() => _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult()); - } - else - { - exception = Record.Exception(() => _subject.Open(CancellationToken.None)); - } + var exception = async ? + await Record.ExceptionAsync(() => _subject.OpenAsync(OperationContext.NoTimeout)) : + Record.Exception(() => _subject.Open(OperationContext.NoTimeout)); _subject.Description.Should().Be(connectionDescription); var ex = exception.Should().BeOfType().Subject; @@ -185,11 +179,11 @@ public async Task Open_should_create_authenticators_only_once( if (async) { - await subject.OpenAsync(CancellationToken.None); + await subject.OpenAsync(OperationContext.NoTimeout); } else { - subject.Open(CancellationToken.None); + subject.Open(OperationContext.NoTimeout); } authenticatorFactoryMock.Verify(f => f.Create(), Times.Once()); @@ -206,52 +200,37 @@ ResponseMessage CreateResponseMessage() [Theory] [ParameterAttributeData] - public void Open_should_throw_an_ObjectDisposedException_if_the_connection_is_disposed( + public async Task Open_should_throw_an_ObjectDisposedException_if_the_connection_is_disposed( [Values(false, true)] bool async) { _subject.Dispose(); - Action act; - if (async) - { - act = () => _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - act = () => _subject.Open(CancellationToken.None); - } + var exception = async ? + await Record.ExceptionAsync(() => _subject.OpenAsync(OperationContext.NoTimeout)) : + Record.Exception(() => _subject.Open(OperationContext.NoTimeout)); - act.ShouldThrow(); + exception.Should().BeOfType(); } [Theory] [ParameterAttributeData] - public void Open_should_raise_the_correct_events_upon_failure( + public async Task Open_should_raise_the_correct_events_upon_failure( [Values(false, true)] bool async) { - Action act; - if (async) - { - var result = new TaskCompletionSource(); - result.SetException(new SocketException()); - _mockConnectionInitializer.Setup(i => i.SendHelloAsync(It.IsAny(), It.IsAny())) - .Returns(result.Task); - - act = () => _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - _mockConnectionInitializer.Setup(i => i.SendHello(It.IsAny(), It.IsAny())) - .Throws(); + _mockConnectionInitializer.Setup(i => i.SendHelloAsync(It.IsAny(), It.IsAny())) + .Throws(); + _mockConnectionInitializer.Setup(i => i.SendHello(It.IsAny(), It.IsAny())) + .Throws(); - act = () => _subject.Open(CancellationToken.None); - } + var exception = async ? + await Record.ExceptionAsync(() => _subject.OpenAsync(OperationContext.NoTimeout)) : + Record.Exception(() => _subject.Open(OperationContext.NoTimeout)); - act.ShouldThrow() - .WithInnerException() - .And.ConnectionId.Should().Be(_subject.ConnectionId); + exception.Should().BeOfType().Subject + .ConnectionId.Should().Be(_subject.ConnectionId); + exception.InnerException.Should().BeOfType(); _capturedEvents.Next().Should().BeOfType(); _capturedEvents.Next().Should().BeOfType(); @@ -260,17 +239,17 @@ public void Open_should_raise_the_correct_events_upon_failure( [Theory] [ParameterAttributeData] - public void Open_should_setup_the_description( + public async Task Open_should_setup_the_description( [Values(false, true)] bool async) { if (async) { - _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); + await _subject.OpenAsync(OperationContext.NoTimeout); } else { - _subject.Open(CancellationToken.None); + _subject.Open(OperationContext.NoTimeout); } _subject.Description.Should().NotBeNull(); @@ -290,32 +269,27 @@ public void Open_should_not_complete_the_second_call_until_the_first_is_complete { var task1IsBlocked = false; var completionSource = new TaskCompletionSource(); - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) - .Returns(() => { task1IsBlocked = true; return completionSource.Task.GetAwaiter().GetResult(); }); - _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, CancellationToken.None)) - .Returns(() => { task1IsBlocked = true; return completionSource.Task; }); - - Task openTask1; - if (async1) - { + _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, It.IsAny())) + .Returns(() => + { + task1IsBlocked = true; + return completionSource.Task.GetAwaiter().GetResult(); + }); + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())) + .Returns(() => + { + task1IsBlocked = true; + return completionSource.Task; + }); - openTask1 = _subject.OpenAsync(CancellationToken.None); - } - else - { - openTask1 = Task.Run(() => _subject.Open(CancellationToken.None)); - } + var openTask1 = async1 ? + _subject.OpenAsync(OperationContext.NoTimeout) : + Task.Run(() => _subject.Open(OperationContext.NoTimeout)); SpinWait.SpinUntil(() => task1IsBlocked, TimeSpan.FromSeconds(5)).Should().BeTrue(); - Task openTask2; - if (async2) - { - openTask2 = _subject.OpenAsync(CancellationToken.None); - } - else - { - openTask2 = Task.Run(() => _subject.Open(CancellationToken.None)); - } + var openTask2 = async2 ? + _subject.OpenAsync(OperationContext.NoTimeout) : + Task.Run(() => _subject.Open(OperationContext.NoTimeout)); openTask1.IsCompleted.Should().BeFalse(); openTask2.IsCompleted.Should().BeFalse(); @@ -340,11 +314,11 @@ public async Task Reauthentication_should_use_the_same_auth_context_as_in_initia if (async) { - await _subject.OpenAsync(CancellationToken.None); + await _subject.OpenAsync(OperationContext.NoTimeout); } else { - _subject.Open(CancellationToken.None); + _subject.Open(OperationContext.NoTimeout); } _subject._connectionInitializerContext().Should().Be(_connectionInitializerContextAfterAuthentication); @@ -365,7 +339,7 @@ public async Task Reauthentication_should_use_the_same_auth_context_as_in_initia [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_throw_a_FormatException_when_message_is_an_invalid_size( + public async Task ReceiveMessage_should_throw_a_FormatException_when_message_is_an_invalid_size( [Values(-1, 48000001)] int length, [Values(false, true)] @@ -376,27 +350,15 @@ public void ReceiveMessage_should_throw_a_FormatException_when_message_is_an_inv var bytes = BitConverter.GetBytes(length); stream.Write(bytes, 0, bytes.Length); stream.Seek(0, SeekOrigin.Begin); + + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())).ReturnsAsync(stream); + _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, It.IsAny())).Returns(stream); + await _subject.OpenAsync(OperationContext.NoTimeout); var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - Exception exception; - if (async) - { - _mockStreamFactory - .Setup(f => f.CreateStreamAsync(_endPoint, CancellationToken.None)) - .ReturnsAsync(stream); - _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); - exception = Record - .Exception(() => _subject.ReceiveMessageAsync(10, encoderSelector, _messageEncoderSettings, CancellationToken.None) - .GetAwaiter() - .GetResult()); - } - else - { - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) - .Returns(stream); - _subject.Open(CancellationToken.None); - exception = Record.Exception(() => _subject.ReceiveMessage(10, encoderSelector, _messageEncoderSettings, CancellationToken.None)); - } + var exception = async ? + await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)) : + Record.Exception(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)); exception.Should().BeOfType(); var e = exception.InnerException.Should().BeOfType().Subject; @@ -406,71 +368,52 @@ public void ReceiveMessage_should_throw_a_FormatException_when_message_is_an_inv [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_throw_an_ArgumentNullException_when_the_encoderSelector_is_null( + public async Task ReceiveMessage_should_throw_an_ArgumentNullException_when_the_encoderSelector_is_null( [Values(false, true)] bool async) { - IMessageEncoderSelector encoderSelector = null; - - Action act; - if (async) - { - act = () => _subject.ReceiveMessageAsync(10, encoderSelector, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - act = () => _subject.ReceiveMessage(10, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } + var exception = async ? + await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 10, null, _messageEncoderSettings)) : + Record.Exception(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 10, null, _messageEncoderSettings)); - act.ShouldThrow(); + exception.Should().BeOfType().Subject + .ParamName.Should().Be("encoderSelector"); } [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_throw_an_ObjectDisposedException_if_the_connection_is_disposed( + public async Task ReceiveMessage_should_throw_an_ObjectDisposedException_if_the_connection_is_disposed( [Values(false, true)] bool async) { var encoderSelector = new Mock().Object; _subject.Dispose(); - Action act; - if (async) - { - act = () => _subject.ReceiveMessageAsync(10, encoderSelector, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - act = () => _subject.ReceiveMessage(10, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } + var exception = async ? + await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)) : + Record.Exception(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)); - act.ShouldThrow(); + exception.Should().BeOfType(); } [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_throw_an_InvalidOperationException_if_the_connection_is_not_open( + public async Task ReceiveMessage_should_throw_an_InvalidOperationException_if_the_connection_is_not_open( [Values(false, true)] bool async) { var encoderSelector = new Mock().Object; - Action act; - if (async) - { - act = () => _subject.ReceiveMessageAsync(10, encoderSelector, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - act = () => _subject.ReceiveMessage(10, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } + var exception = async ? + await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)) : + Record.Exception(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)); - act.ShouldThrow(); + exception.Should().BeOfType(); } [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_complete_when_reply_is_already_on_the_stream( + public async Task ReceiveMessage_should_complete_when_reply_is_already_on_the_stream( [Values(false, true)] bool async) { @@ -479,27 +422,18 @@ public void ReceiveMessage_should_complete_when_reply_is_already_on_the_stream( var messageToReceive = MessageHelper.BuildReply(new BsonDocument(), BsonDocumentSerializer.Instance, responseTo: 10); MessageHelper.WriteResponsesToStream(stream, messageToReceive); - var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - - ResponseMessage received; - if (async) - { - _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, CancellationToken.None)) - .Returns(Task.FromResult(stream)); - _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); - _capturedEvents.Clear(); + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())) + .ReturnsAsync(stream); + _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, It.IsAny())) + .Returns(stream); + await _subject.OpenAsync(OperationContext.NoTimeout); + _capturedEvents.Clear(); - received = _subject.ReceiveMessageAsync(10, encoderSelector, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) - .Returns(stream); - _subject.Open(CancellationToken.None); - _capturedEvents.Clear(); + var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - received = _subject.ReceiveMessage(10, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } + var received = async ? + await _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings) : + _subject.ReceiveMessage(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings); var expected = MessageHelper.TranslateMessagesToBsonDocuments(new[] { messageToReceive }); var actual = MessageHelper.TranslateMessagesToBsonDocuments(new[] { received }); @@ -514,40 +448,31 @@ public void ReceiveMessage_should_complete_when_reply_is_already_on_the_stream( [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_complete_when_reply_is_not_already_on_the_stream( + public async Task ReceiveMessage_should_complete_when_reply_is_not_already_on_the_stream( [Values(false, true)] bool async) { using (var stream = new BlockingMemoryStream()) { - var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - - Task receiveMessageTask; - if (async) - { - _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, CancellationToken.None)) - .Returns(Task.FromResult(stream)); - _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); - _capturedEvents.Clear(); + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())) + .ReturnsAsync(stream); + _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, It.IsAny())) + .Returns(stream); + await _subject.OpenAsync(OperationContext.NoTimeout); + _capturedEvents.Clear(); - receiveMessageTask = _subject.ReceiveMessageAsync(10, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } - else - { - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) - .Returns(stream); - _subject.Open(CancellationToken.None); - _capturedEvents.Clear(); + var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - receiveMessageTask = Task.Run(() => _subject.ReceiveMessage(10, encoderSelector, _messageEncoderSettings, CancellationToken.None)); - } + var receiveMessageTask = async ? + _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings) : + Task.Run(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)); receiveMessageTask.IsCompleted.Should().BeFalse(); var messageToReceive = MessageHelper.BuildReply(new BsonDocument(), BsonDocumentSerializer.Instance, responseTo: 10); MessageHelper.WriteResponsesToStream(stream, messageToReceive); - var received = receiveMessageTask.GetAwaiter().GetResult(); + var received = await receiveMessageTask; var expected = MessageHelper.TranslateMessagesToBsonDocuments(new[] { messageToReceive }); var actual = MessageHelper.TranslateMessagesToBsonDocuments(new[] { received }); @@ -562,7 +487,7 @@ public void ReceiveMessage_should_complete_when_reply_is_not_already_on_the_stre [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_handle_out_of_order_replies( + public async Task ReceiveMessage_should_handle_out_of_order_replies( [Values(false, true)] bool async1, [Values(false, true)] @@ -570,32 +495,19 @@ public void ReceiveMessage_should_handle_out_of_order_replies( { using (var stream = new BlockingMemoryStream()) { - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) - .Returns(stream); - _subject.Open(CancellationToken.None); + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())).ReturnsAsync(stream); + await _subject.OpenAsync(OperationContext.NoTimeout); _capturedEvents.Clear(); var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - Task receivedTask10; - if (async1) - { - receivedTask10 = _subject.ReceiveMessageAsync(10, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } - else - { - receivedTask10 = Task.Run(() => _subject.ReceiveMessage(10, encoderSelector, _messageEncoderSettings, CancellationToken.None)); - } + var receivedTask10 = async1 ? + _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings) : + Task.Run(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 10, encoderSelector, _messageEncoderSettings)); - Task receivedTask11; - if (async2) - { - receivedTask11 = _subject.ReceiveMessageAsync(11, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } - else - { - receivedTask11 = Task.Run(() => _subject.ReceiveMessage(11, encoderSelector, _messageEncoderSettings, CancellationToken.None)); - } + var receivedTask11 = async2 ? + _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 11, encoderSelector, _messageEncoderSettings) : + Task.Run(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 11, encoderSelector, _messageEncoderSettings)); SpinWait.SpinUntil(() => _capturedEvents.Count >= 2, TimeSpan.FromSeconds(5)).Should().BeTrue(); @@ -603,8 +515,8 @@ public void ReceiveMessage_should_handle_out_of_order_replies( var messageToReceive11 = MessageHelper.BuildReply(new BsonDocument("_id", 11), BsonDocumentSerializer.Instance, responseTo: 11); MessageHelper.WriteResponsesToStream(stream, messageToReceive11, messageToReceive10); // out of order - var received10 = receivedTask10.GetAwaiter().GetResult(); - var received11 = receivedTask11.GetAwaiter().GetResult(); + var received10 = await receivedTask10; + var received11 = await receivedTask11; var expected = MessageHelper.TranslateMessagesToBsonDocuments(new[] { messageToReceive10, messageToReceive11 }); var actual = MessageHelper.TranslateMessagesToBsonDocuments(new[] { received10, received11 }); @@ -645,9 +557,9 @@ public async Task ReceiveMessage_should_not_produce_unobserved_task_exceptions_o tcs.SetException(new SocketException()); SetupStreamRead(mockStream, tcs); - _subject.Open(CancellationToken.None); + _subject.Open(OperationContext.NoTimeout); - var exception = await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(1, encoderSelector, _messageEncoderSettings, CancellationToken.None)); + var exception = await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 1, encoderSelector, _messageEncoderSettings)); exception.Should().BeOfType(); GC.Collect(); // Collects the unobserved tasks @@ -681,14 +593,14 @@ public async Task ReceiveMessageAsync_should_not_produce_unobserved_task_excepti var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); _mockStreamFactory - .Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) + .Setup(f => f.CreateStream(_endPoint, It.IsAny())) .Returns(mockStream.Object); var tcs = new TaskCompletionSource(); SetupStreamRead(mockStream, tcs, 50); - _subject.Open(CancellationToken.None); + _subject.Open(OperationContext.NoTimeout); - var exception = await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(1, encoderSelector, _messageEncoderSettings, CancellationToken.None)); + var exception = await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 1, encoderSelector, _messageEncoderSettings)); exception.Should().BeOfType(); exception.InnerException.Should().BeOfType(); @@ -711,7 +623,7 @@ public async Task ReceiveMessageAsync_should_not_produce_unobserved_task_excepti [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_throw_network_exception_to_all_awaiters( + public async Task ReceiveMessage_should_throw_network_exception_to_all_awaiters( [Values(false, true)] bool async1, [Values(false, true)] @@ -722,46 +634,35 @@ public void ReceiveMessage_should_throw_network_exception_to_all_awaiters( { var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) + _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, It.IsAny())) .Returns(mockStream.Object); var readTcs = new TaskCompletionSource(); SetupStreamRead(mockStream, readTcs, readTimeoutMs: Timeout.Infinite); - _subject.Open(CancellationToken.None); + _subject.Open(OperationContext.NoTimeout); _capturedEvents.Clear(); - Task task1; - if (async1) - { - task1 = _subject.ReceiveMessageAsync(1, encoderSelector, _messageEncoderSettings, It.IsAny()); - } - else - { - task1 = Task.Run(() => _subject.ReceiveMessage(1, encoderSelector, _messageEncoderSettings, CancellationToken.None)); - } + var task1 = async1 ? + _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 1, encoderSelector, _messageEncoderSettings) : + Task.Run(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 1, encoderSelector, _messageEncoderSettings)); - Task task2; - if (async2) - { - task2 = _subject.ReceiveMessageAsync(2, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } - else - { - task2 = Task.Run(() => _subject.ReceiveMessage(2, encoderSelector, _messageEncoderSettings, CancellationToken.None)); - } + var task2 = async2 ? + _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 2, encoderSelector, _messageEncoderSettings) : + Task.Run(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 2, encoderSelector, _messageEncoderSettings)); SpinWait.SpinUntil(() => _capturedEvents.Count >= 2, TimeSpan.FromSeconds(5)).Should().BeTrue(); readTcs.SetException(new SocketException()); - Func act1 = () => task1; - act1.ShouldThrow() - .WithInnerException() - .And.ConnectionId.Should().Be(_subject.ConnectionId); + var exception1 = await Record.ExceptionAsync(() => task1); + var exception2 = await Record.ExceptionAsync(() => task2); - Func act2 = () => task2; - act2.ShouldThrow() - .WithInnerException() - .And.ConnectionId.Should().Be(_subject.ConnectionId); + exception1.Should().BeOfType().Subject + .ConnectionId.Should().Be(_subject.ConnectionId); + exception1.InnerException.Should().BeOfType(); + + exception2.Should().BeOfType().Subject + .ConnectionId.Should().Be(_subject.ConnectionId); + exception2.InnerException.Should().BeOfType(); _capturedEvents.Next().Should().BeOfType(); _capturedEvents.Next().Should().BeOfType(); @@ -774,7 +675,7 @@ public void ReceiveMessage_should_throw_network_exception_to_all_awaiters( [Theory] [ParameterAttributeData] - public void ReceiveMessage_should_throw_MongoConnectionClosedException_when_connection_has_failed( + public async Task ReceiveMessage_should_throw_MongoConnectionClosedException_when_connection_has_failed( [Values(false, true)] bool async1, [Values(false, true)] @@ -783,42 +684,29 @@ public void ReceiveMessage_should_throw_MongoConnectionClosedException_when_conn var mockStream = new Mock(); using (mockStream.Object) { - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) - .Returns(mockStream.Object); + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())).ReturnsAsync(mockStream.Object); var readTcs = new TaskCompletionSource(); readTcs.SetException(new SocketException()); SetupStreamRead(mockStream, readTcs); - _subject.Open(CancellationToken.None); + await _subject.OpenAsync(OperationContext.NoTimeout); _capturedEvents.Clear(); var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - Action act1; - if (async1) - { - act1 = () => _subject.ReceiveMessageAsync(1, encoderSelector, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - act1 = () => _subject.ReceiveMessage(1, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } + var exception1 = async1 ? + await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 1, encoderSelector, _messageEncoderSettings)) : + Record.Exception(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 1, encoderSelector, _messageEncoderSettings)); - Action act2; - if (async2) - { - act2 = () => _subject.ReceiveMessageAsync(2, encoderSelector, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult(); - } - else - { - act2 = () => _subject.ReceiveMessage(2, encoderSelector, _messageEncoderSettings, CancellationToken.None); - } + var exception2 = async2 ? + await Record.ExceptionAsync(() => _subject.ReceiveMessageAsync(OperationContext.NoTimeout, 2, encoderSelector, _messageEncoderSettings)) : + Record.Exception(() => _subject.ReceiveMessage(OperationContext.NoTimeout, 2, encoderSelector, _messageEncoderSettings)); - act1.ShouldThrow() - .WithInnerException() - .And.ConnectionId.Should().Be(_subject.ConnectionId); + exception1.Should().BeOfType().Subject + .ConnectionId.Should().Be(_subject.ConnectionId); + exception1.InnerException.Should().BeOfType(); - act2.ShouldThrow() - .And.ConnectionId.Should().Be(_subject.ConnectionId); + exception2.Should().BeOfType().Subject + .ConnectionId.Should().Be(_subject.ConnectionId); _capturedEvents.Next().Should().BeOfType(); _capturedEvents.Next().Should().BeOfType(); @@ -834,8 +722,8 @@ public async Task SendMessage_should_throw_an_ArgumentNullException_if_message_i bool async) { var exception = async ? - await Record.ExceptionAsync(() => _subject.SendMessageAsync(null, _messageEncoderSettings, CancellationToken.None)) : - Record.Exception(() => _subject.SendMessage(null, _messageEncoderSettings, CancellationToken.None)); + await Record.ExceptionAsync(() => _subject.SendMessageAsync(OperationContext.NoTimeout, null, _messageEncoderSettings)) : + Record.Exception(() => _subject.SendMessage(OperationContext.NoTimeout, null, _messageEncoderSettings)); exception.Should().BeOfType(); } @@ -850,8 +738,8 @@ public async Task SendMessage_should_throw_an_ObjectDisposedException_if_the_con _subject.Dispose(); var exception = async ? - await Record.ExceptionAsync(() => _subject.SendMessageAsync(message, _messageEncoderSettings, CancellationToken.None)) : - Record.Exception(() => _subject.SendMessage(message, _messageEncoderSettings, CancellationToken.None)); + await Record.ExceptionAsync(() => _subject.SendMessageAsync(OperationContext.NoTimeout, message, _messageEncoderSettings)) : + Record.Exception(() => _subject.SendMessage(OperationContext.NoTimeout, message, _messageEncoderSettings)); exception.Should().BeOfType(); } @@ -865,39 +753,34 @@ public async Task SendMessage_should_throw_an_InvalidOperationException_if_the_c var message = MessageHelper.BuildQuery(); var exception = async ? - await Record.ExceptionAsync(() => _subject.SendMessageAsync(message, _messageEncoderSettings, CancellationToken.None)) : - Record.Exception(() => _subject.SendMessage(message, _messageEncoderSettings, CancellationToken.None)); + await Record.ExceptionAsync(() => _subject.SendMessageAsync(OperationContext.NoTimeout, message, _messageEncoderSettings)) : + Record.Exception(() => _subject.SendMessage(OperationContext.NoTimeout, message, _messageEncoderSettings)); exception.Should().BeOfType(); } [Theory] [ParameterAttributeData] - public void SendMessage_should_put_the_message_on_the_stream_and_raise_the_correct_events( + public async Task SendMessage_should_put_the_message_on_the_stream_and_raise_the_correct_events( [Values(false, true)] bool async) { using (var stream = new MemoryStream()) { + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())).ReturnsAsync(stream); + _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, It.IsAny())).Returns(stream); + _subject.OpenAsync(OperationContext.NoTimeout).GetAwaiter().GetResult(); + _capturedEvents.Clear(); + var message = MessageHelper.BuildQuery(query: new BsonDocument("x", 1)); if (async) { - _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, CancellationToken.None)) - .ReturnsAsync(stream); - _subject.OpenAsync(CancellationToken.None).GetAwaiter().GetResult(); - _capturedEvents.Clear(); - - _subject.SendMessageAsync(message, _messageEncoderSettings, CancellationToken.None).GetAwaiter().GetResult(); + await _subject.SendMessageAsync(OperationContext.NoTimeout, message, _messageEncoderSettings); } else { - _mockStreamFactory.Setup(f => f.CreateStream(_endPoint, CancellationToken.None)) - .Returns(stream); - _subject.Open(CancellationToken.None); - _capturedEvents.Clear(); - - _subject.SendMessage(message, _messageEncoderSettings, CancellationToken.None); + _subject.SendMessage(OperationContext.NoTimeout, message, _messageEncoderSettings); } var expectedRequests = MessageHelper.TranslateMessagesToBsonDocuments(new[] { message }); diff --git a/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnection_CommandEventTests.cs b/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnection_CommandEventTests.cs index a01da9bb54a..4ebb258ee9f 100644 --- a/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnection_CommandEventTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnection_CommandEventTests.cs @@ -87,9 +87,9 @@ public BinaryConnection_CommandEventTests(ITestOutputHelper output) : base(outpu new HelloResult(new BsonDocument { { "maxWireVersion", WireVersion.Server36 } })); _mockConnectionInitializer = new Mock(); - _mockConnectionInitializer.Setup(i => i.SendHelloAsync(It.IsAny(), CancellationToken.None)) + _mockConnectionInitializer.Setup(i => i.SendHelloAsync(It.IsAny(), It.IsAny())) .Returns(() => Task.FromResult(new ConnectionInitializerContext(connectionDescriptionFunc(), null))); - _mockConnectionInitializer.Setup(i => i.AuthenticateAsync(It.IsAny(), It.IsAny(), CancellationToken.None)) + _mockConnectionInitializer.Setup(i => i.AuthenticateAsync(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(() => Task.FromResult(new ConnectionInitializerContext(connectionDescriptionFunc(), null))); _subject = new BinaryConnection( @@ -102,9 +102,9 @@ public BinaryConnection_CommandEventTests(ITestOutputHelper output) : base(outpu LoggerFactory); _stream = new BlockingMemoryStream(); - _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, CancellationToken.None)) + _mockStreamFactory.Setup(f => f.CreateStreamAsync(_endPoint, It.IsAny())) .Returns(Task.FromResult(_stream)); - _subject.OpenAsync(CancellationToken.None).Wait(); + _subject.OpenAsync(OperationContext.NoTimeout).Wait(); _capturedEvents.Clear(); _operationIdDisposer = EventContext.BeginOperation(); @@ -484,14 +484,14 @@ public void Should_process_a_failed_query() private void SendMessage(RequestMessage message) { - _subject.SendMessageAsync(message, _messageEncoderSettings, CancellationToken.None).Wait(); + _subject.SendMessageAsync(OperationContext.NoTimeout, message, _messageEncoderSettings).Wait(); } private void ReceiveMessage(ReplyMessage message) { MessageHelper.WriteResponsesToStream(_stream, message); var encoderSelector = new ReplyMessageEncoderSelector(BsonDocumentSerializer.Instance); - _subject.ReceiveMessageAsync(message.ResponseTo, encoderSelector, _messageEncoderSettings, CancellationToken.None).Wait(); + _subject.ReceiveMessageAsync(OperationContext.NoTimeout, message.ResponseTo, encoderSelector, _messageEncoderSettings).Wait(); } } } diff --git a/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3173Tests.cs b/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3173Tests.cs index 8be118b9db5..f5317db9f5e 100644 --- a/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3173Tests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3173Tests.cs @@ -75,7 +75,7 @@ public void Ensure_command_network_error_before_handshake_is_correctly_handled([ // The next hello or legacy hello response will be delayed because the waiting in the mock.Callbacks cluster.Initialize(); - var selectedServer = cluster.SelectServer(OperationContext.NoTimeout, CreateWritableServerAndEndPointSelector(__endPoint1)); + var (selectedServer, _) = cluster.SelectServer(OperationContext.NoTimeout, CreateWritableServerAndEndPointSelector(__endPoint1)); initialSelectedEndpoint = selectedServer.EndPoint; initialSelectedEndpoint.Should().Be(__endPoint1); @@ -86,11 +86,11 @@ public void Ensure_command_network_error_before_handshake_is_correctly_handled([ Exception exception; if (async) { - exception = Record.Exception(() => selectedServer.GetChannelAsync(OperationContext.NoTimeout).GetAwaiter().GetResult()); + exception = Record.Exception(() => selectedServer.GetConnectionAsync(OperationContext.NoTimeout).GetAwaiter().GetResult()); } else { - exception = Record.Exception(() => selectedServer.GetChannel(OperationContext.NoTimeout)); + exception = Record.Exception(() => selectedServer.GetConnection(OperationContext.NoTimeout)); } var e = exception.Should().BeOfType().Subject; @@ -107,7 +107,7 @@ public void Ensure_command_network_error_before_handshake_is_correctly_handled([ } // ensure that a new server can be selected - selectedServer = cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); + (selectedServer, _) = cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); // ensure that the selected server is not the same as the initial selectedServer.EndPoint.Should().Be(__endPoint2); @@ -353,7 +353,7 @@ void SetupFailedConnection(Mock mockFaultyConnection) () => WaitForTaskOrTimeout(hasClusterBeenDisposed.Task, TimeSpan.FromMinutes(1), "cluster dispose") }); mockFaultyConnection - .Setup(c => c.Open(It.IsAny())) + .Setup(c => c.Open(It.IsAny())) .Callback(() => { var responseAction = faultyConnectionResponses.Dequeue(); @@ -361,7 +361,7 @@ void SetupFailedConnection(Mock mockFaultyConnection) }); mockFaultyConnection - .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns(() => { WaitForTaskOrTimeout( @@ -374,13 +374,13 @@ void SetupFailedConnection(Mock mockFaultyConnection) void SetupHealthyConnection(Mock mockHealthyConnection) { - mockHealthyConnection.Setup(c => c.Open(It.IsAny())); // no action is required - mockHealthyConnection.Setup(c => c.OpenAsync(It.IsAny())).Returns(Task.FromResult(true)); // no action is required + mockHealthyConnection.Setup(c => c.Open(It.IsAny())); // no action is required + mockHealthyConnection.Setup(c => c.OpenAsync(It.IsAny())).Returns(Task.FromResult(true)); // no action is required mockHealthyConnection - .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns(commandResponseAction); mockConnection - .Setup(c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Setup(c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .ReturnsAsync(commandResponseAction); } } diff --git a/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3302Tests.cs b/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3302Tests.cs index aa944ed0ca6..6ea4a9f5176 100644 --- a/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3302Tests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Jira/CSharp3302Tests.cs @@ -142,13 +142,13 @@ public async Task Ensure_no_deadlock_after_primary_update() server.DescriptionChanged += ProcessServerDescriptionChanged; } - var selectedServer = cluster.SelectServer(OperationContext.NoTimeout, CreateWritableServerAndEndPointSelector(__endPoint1)); + var (selectedServer, _) = cluster.SelectServer(OperationContext.NoTimeout, CreateWritableServerAndEndPointSelector(__endPoint1)); initialSelectedEndpoint = selectedServer.EndPoint; initialSelectedEndpoint.Should().Be(__endPoint1); // Change primary currentPrimaries.Add(__serverId2); - selectedServer = cluster.SelectServer(OperationContext.NoTimeout, CreateWritableServerAndEndPointSelector(__endPoint2)); + (selectedServer, _) = cluster.SelectServer(OperationContext.NoTimeout, CreateWritableServerAndEndPointSelector(__endPoint2)); selectedServer.EndPoint.Should().Be(__endPoint2); // Ensure stalling happened @@ -303,9 +303,9 @@ private void SetupServerMonitorConnection( .SetupGet(c => c.Description) .Returns(GetConnectionDescription); - mockConnection.Setup(c => c.Open(It.IsAny())); // no action is required + mockConnection.Setup(c => c.Open(It.IsAny())); // no action is required mockConnection - .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns(GetHelloResponse); ResponseMessage GetHelloResponse() diff --git a/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs b/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs index bf17fa539c6..4057b9d5bf0 100644 --- a/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs @@ -14,10 +14,7 @@ */ using System; -using System.Collections.Generic; using System.IO; -using System.Linq; -using System.Text; using System.Threading; using System.Threading.Tasks; using FluentAssertions; @@ -40,7 +37,7 @@ public async Task ReadBytesAsync_with_byte_array_should_have_expected_effect_for var stream = new MemoryStream(bytes); var destination = new byte[2]; - await stream.ReadBytesAsync(destination, 0, count, Timeout.InfiniteTimeSpan, CancellationToken.None); + await stream.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, count); destination.Should().Equal(expectedBytes); } @@ -54,7 +51,7 @@ public async Task ReadBytesAsync_with_byte_array_should_have_expected_effect_for var stream = new MemoryStream(bytes); var destination = new byte[3]; - await stream.ReadBytesAsync(destination, offset, 1, Timeout.InfiniteTimeSpan, CancellationToken.None); + await stream.ReadBytesAsync(OperationContext.NoTimeout, destination, offset, 1); destination.Should().Equal(expectedBytes); } @@ -80,7 +77,7 @@ public async Task ReadBytesAsync_with_byte_array_should_have_expected_effect_for }); var destination = new byte[3]; - await mockStream.Object.ReadBytesAsync(destination, 0, 3, Timeout.InfiniteTimeSpan, CancellationToken.None); + await mockStream.Object.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 3); destination.Should().Equal(bytes); } @@ -92,7 +89,7 @@ public void ReadBytesAsync_with_byte_array_should_throw_when_end_of_stream_is_re var destination = new byte[1]; mockStream.Setup(s => s.ReadAsync(destination, 0, 1, It.IsAny())).Returns(Task.FromResult(0)); - Func action = () => mockStream.Object.ReadBytesAsync(destination, 0, 1, Timeout.InfiniteTimeSpan, CancellationToken.None); + Func action = () => mockStream.Object.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 1); action.ShouldThrow(); } @@ -103,7 +100,7 @@ public void ReadBytesAsync_with_byte_array_should_throw_when_buffer_is_null() var stream = new Mock().Object; byte[] destination = null; - Func action = () => stream.ReadBytesAsync(destination, 0, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + Func action = () => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 0); action.ShouldThrow().And.ParamName.Should().Be("buffer"); } @@ -117,7 +114,7 @@ public void ReadBytesAsync_with_byte_array_should_throw_when_count_is_invalid(in var stream = new Mock().Object; var destination = new byte[2]; - Func action = () => stream.ReadBytesAsync(destination, offset, count, Timeout.InfiniteTimeSpan, CancellationToken.None); + Func action = () => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, offset, count); action.ShouldThrow().And.ParamName.Should().Be("count"); } @@ -131,7 +128,7 @@ public void ReadBytesAsync_with_byte_array_should_throw_when_offset_is_invalid( var stream = new Mock().Object; var destination = new byte[2]; - Func action = () => stream.ReadBytesAsync(destination, offset, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + Func action = () => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, offset, 0); action.ShouldThrow().And.ParamName.Should().Be("offset"); } @@ -142,7 +139,7 @@ public void ReadBytesAsync_with_byte_array_should_throw_when_stream_is_null() Stream stream = null; var destination = new byte[0]; - Func action = () => stream.ReadBytesAsync(destination, 0, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + Func action = () => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 0); action.ShouldThrow().And.ParamName.Should().Be("stream"); } @@ -157,7 +154,7 @@ public async Task ReadBytesAsync_with_byte_buffer_should_have_expected_effect_fo var stream = new MemoryStream(bytes); var destination = new ByteArrayBuffer(new byte[2]); - await stream.ReadBytesAsync(destination, 0, count, Timeout.InfiniteTimeSpan, CancellationToken.None); + await stream.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, count); destination.AccessBackingBytes(0).Array.Should().Equal(expectedBytes); } @@ -171,7 +168,7 @@ public async Task ReadBytesAsync_with_byte_buffer_should_have_expected_effect_fo var stream = new MemoryStream(bytes); var destination = new ByteArrayBuffer(new byte[3]); - await stream.ReadBytesAsync(destination, offset, 1, Timeout.InfiniteTimeSpan, CancellationToken.None); + await stream.ReadBytesAsync(OperationContext.NoTimeout, destination, offset, 1); destination.AccessBackingBytes(0).Array.Should().Equal(expectedBytes); } @@ -197,7 +194,7 @@ public async Task ReadBytesAsync_with_byte_buffer_should_have_expected_effect_fo return Task.FromResult(length); }); - await mockStream.Object.ReadBytesAsync(destination, 0, 3, Timeout.InfiniteTimeSpan, CancellationToken.None); + await mockStream.Object.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 3); destination.AccessBackingBytes(0).Array.Should().Equal(bytes); } @@ -209,7 +206,7 @@ public void ReadBytesAsync_with_byte_buffer_should_throw_when_end_of_stream_is_r var destination = CreateMockByteBuffer(1).Object; mockStream.Setup(s => s.ReadAsync(It.IsAny(), 0, 1, It.IsAny())).Returns(Task.FromResult(0)); - Func action = () => mockStream.Object.ReadBytesAsync(destination, 0, 1, Timeout.InfiniteTimeSpan, CancellationToken.None); + Func action = () => mockStream.Object.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 1); action.ShouldThrow(); } @@ -220,7 +217,7 @@ public void ReadBytesAsync_with_byte_buffer_should_throw_when_buffer_is_null() var stream = new Mock().Object; IByteBuffer destination = null; - Func action = () => stream.ReadBytesAsync(destination, 0, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + Func action = () => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 0); action.ShouldThrow().And.ParamName.Should().Be("buffer"); } @@ -234,7 +231,7 @@ public void ReadBytesAsync_with_byte_buffer_should_throw_when_count_is_invalid(i var stream = new Mock().Object; var destination = CreateMockByteBuffer(2).Object; - Func action = () => stream.ReadBytesAsync(destination, offset, count, Timeout.InfiniteTimeSpan, CancellationToken.None); + Func action = () => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, offset, count); action.ShouldThrow().And.ParamName.Should().Be("count"); } @@ -248,7 +245,7 @@ public void ReadBytesAsync_with_byte_buffer_should_throw_when_offset_is_invalid( var stream = new Mock().Object; var destination = CreateMockByteBuffer(2).Object; - Func action = () => stream.ReadBytesAsync(destination, offset, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + Func action = () => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, offset, 0); action.ShouldThrow().And.ParamName.Should().Be("offset"); } @@ -259,7 +256,7 @@ public void ReadBytesAsync_with_byte_buffer_should_throw_when_stream_is_null() Stream stream = null; var destination = new Mock().Object; - Func action = () => stream.ReadBytesAsync(destination, 0, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + Func action = () => stream.ReadBytesAsync(OperationContext.NoTimeout, destination, 0, 0); action.ShouldThrow().And.ParamName.Should().Be("stream"); } @@ -273,7 +270,7 @@ public async Task WriteBytesAsync_should_have_expected_effect_for_count(int coun var stream = new MemoryStream(); var source = new ByteArrayBuffer(new byte[] { 1, 2 }); - await stream.WriteBytesAsync(source, 0, count, Timeout.InfiniteTimeSpan, CancellationToken.None); + await stream.WriteBytesAsync(OperationContext.NoTimeout, source, 0, count); stream.ToArray().Should().Equal(expectedBytes); } @@ -286,7 +283,7 @@ public async Task WriteBytesAsync_should_have_expected_effect_for_offset(int off var stream = new MemoryStream(); var source = new ByteArrayBuffer(new byte[] { 1, 2, 3 }); - await stream.WriteBytesAsync(source, offset, 1, Timeout.InfiniteTimeSpan, CancellationToken.None); + await stream.WriteBytesAsync(OperationContext.NoTimeout, source, offset, 1); stream.ToArray().Should().Equal(expectedBytes); } @@ -310,7 +307,7 @@ public async Task WriteBytesAsync_should_have_expected_effect_for_partial_writes return new ArraySegment(bytes, position, length); }); - await stream.WriteBytesAsync(mockSource.Object, 0, 3, Timeout.InfiniteTimeSpan, CancellationToken.None); + await stream.WriteBytesAsync(OperationContext.NoTimeout, mockSource.Object, 0, 3); stream.ToArray().Should().Equal(bytes); } @@ -320,7 +317,7 @@ public void WriteBytesAsync_should_throw_when_buffer_is_null() { var stream = new Mock().Object; - Func action = () => stream.WriteBytesAsync(null, 0, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + Func action = () => stream.WriteBytesAsync(OperationContext.NoTimeout, null, 0, 0); action.ShouldThrow().And.ParamName.Should().Be("buffer"); } @@ -334,7 +331,7 @@ public void WriteBytesAsync_should_throw_when_count_is_invalid(int offset, int c var stream = new Mock().Object; var source = CreateMockByteBuffer(2).Object; - Func action = () => stream.WriteBytesAsync(source, offset, count, Timeout.InfiniteTimeSpan, CancellationToken.None); + Func action = () => stream.WriteBytesAsync(OperationContext.NoTimeout, source, offset, count); action.ShouldThrow().And.ParamName.Should().Be("count"); } @@ -348,7 +345,7 @@ public void WriteBytesAsync_should_throw_when_offset_is_invalid( var stream = new Mock().Object; var destination = CreateMockByteBuffer(2).Object; - Func action = () => stream.WriteBytesAsync(destination, offset, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + Func action = () => stream.WriteBytesAsync(OperationContext.NoTimeout, destination, offset, 0); action.ShouldThrow().And.ParamName.Should().Be("offset"); } @@ -359,7 +356,7 @@ public void WriteBytesAsync_should_throw_when_stream_is_null() Stream stream = null; var source = new Mock().Object; - Func action = () => stream.WriteBytesAsync(source, 0, 0, Timeout.InfiniteTimeSpan, CancellationToken.None); + Func action = () => stream.WriteBytesAsync(OperationContext.NoTimeout, source, 0, 0); action.ShouldThrow().And.ParamName.Should().Be("stream"); } diff --git a/tests/MongoDB.Driver.Tests/Core/Operations/AsyncCursorTests.cs b/tests/MongoDB.Driver.Tests/Core/Operations/AsyncCursorTests.cs index a1b8943a276..3728762b355 100644 --- a/tests/MongoDB.Driver.Tests/Core/Operations/AsyncCursorTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Operations/AsyncCursorTests.cs @@ -432,6 +432,7 @@ public void GetMore_should_use_same_session( mockChannelSource.Setup(m => m.GetChannelAsync(It.IsAny())).Returns(Task.FromResult(channel)); mockChannel .Setup(m => m.CommandAsync( + It.IsAny(), session, null, databaseNamespace, @@ -442,8 +443,7 @@ public void GetMore_should_use_same_session( null, CommandResponseHandling.Return, It.IsAny>(), - It.IsAny(), - It.IsAny())) + It.IsAny())) .Callback(() => sameSessionWasUsed = true) .Returns(Task.FromResult(secondBatch)); @@ -454,6 +454,7 @@ public void GetMore_should_use_same_session( mockChannelSource.Setup(m => m.GetChannel(It.IsAny())).Returns(channel); mockChannel .Setup(m => m.Command( + It.IsAny(), session, null, databaseNamespace, @@ -464,8 +465,7 @@ public void GetMore_should_use_same_session( null, CommandResponseHandling.Return, It.IsAny>(), - It.IsAny(), - It.IsAny())) + It.IsAny())) .Callback(() => sameSessionWasUsed = true) .Returns(secondBatch); @@ -543,6 +543,7 @@ private void SetupChannelMocks(Mock mockChannelSource, Mock c.CommandAsync( + It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), @@ -553,8 +554,7 @@ private void SetupChannelMocks(Mock mockChannelSource, Mock>(), It.IsAny(), It.IsAny>(), - It.IsAny(), - It.IsAny())) + It.IsAny())) .ReturnsAsync(() => { var bsonDocument = commandResultFunc(); @@ -570,6 +570,7 @@ private void SetupChannelMocks(Mock mockChannelSource, Mock c.Command( + It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), @@ -580,8 +581,7 @@ private void SetupChannelMocks(Mock mockChannelSource, Mock>(), It.IsAny(), It.IsAny>(), - It.IsAny(), - It.IsAny())) + It.IsAny())) .Returns(() => { var bsonDocument = commandResultFunc(); @@ -596,6 +596,7 @@ private void VerifyHowManyTimesKillCursorsCommandWasCalled(Mock { mockChannelHandle.Verify( s => s.CommandAsync( + It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), @@ -606,16 +607,14 @@ private void VerifyHowManyTimesKillCursorsCommandWasCalled(Mock It.IsAny>(), It.IsAny(), It.IsAny>(), - It.IsAny(), - It.IsAny()), + It.IsAny()), times); - - } else { mockChannelHandle.Verify( s => s.Command( + It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), @@ -626,8 +625,7 @@ private void VerifyHowManyTimesKillCursorsCommandWasCalled(Mock It.IsAny>(), It.IsAny(), It.IsAny>(), - It.IsAny(), - It.IsAny()), + It.IsAny()), times); } } @@ -694,6 +692,7 @@ private IReadOnlyList GetFirstBatchUsingFindCommand(IChannelHandle { "batchSize", batchSize } }; var result = channel.Command( + new OperationContext(Timeout.InfiniteTimeSpan, cancellationToken), _session, ReadPreference.Primary, _databaseNamespace, @@ -704,8 +703,7 @@ private IReadOnlyList GetFirstBatchUsingFindCommand(IChannelHandle null, // postWriteAction CommandResponseHandling.Return, BsonDocumentSerializer.Instance, - _messageEncoderSettings, - cancellationToken); + _messageEncoderSettings); var cursor = result["cursor"].AsBsonDocument; var firstBatch = cursor["firstBatch"].AsBsonArray.Select(i => i.AsBsonDocument).ToList(); cursorId = cursor["id"].ToInt64(); diff --git a/tests/MongoDB.Driver.Tests/Core/Operations/BulkMixedWriteOperationTests.cs b/tests/MongoDB.Driver.Tests/Core/Operations/BulkMixedWriteOperationTests.cs index c14714121a3..2e8ce56777b 100644 --- a/tests/MongoDB.Driver.Tests/Core/Operations/BulkMixedWriteOperationTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Operations/BulkMixedWriteOperationTests.cs @@ -16,7 +16,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Threading; using System.Threading.Tasks; using FluentAssertions; using MongoDB.Bson; @@ -1378,7 +1377,7 @@ public void Execute_unacknowledged_with_an_error_in_the_first_batch_and_ordered_ using (var readWriteBinding = CreateReadWriteBinding(useImplicitSession: true)) using (var channelSource = readWriteBinding.GetWriteChannelSource(OperationContext.NoTimeout)) using (var channel = channelSource.GetChannel(OperationContext.NoTimeout)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, readWriteBinding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, readWriteBinding.Session.Fork())) { var result = ExecuteOperation(subject, channelBinding, async); @@ -1425,7 +1424,7 @@ public void Execute_unacknowledged_with_an_error_in_the_first_batch_and_ordered_ using (var readWriteBinding = CreateReadWriteBinding(useImplicitSession: true)) using (var channelSource = readWriteBinding.GetWriteChannelSource(OperationContext.NoTimeout)) using (var channel = channelSource.GetChannel(OperationContext.NoTimeout)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, readWriteBinding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, readWriteBinding.Session.Fork())) { var result = ExecuteOperation(subject, channelBinding, async); result.ProcessedRequests.Should().HaveCount(5); @@ -1466,7 +1465,7 @@ public void Execute_unacknowledged_with_an_error_in_the_second_batch_and_ordered using (var readWriteBinding = CreateReadWriteBinding(useImplicitSession: true)) using (var channelSource = readWriteBinding.GetWriteChannelSource(OperationContext.NoTimeout)) using (var channel = channelSource.GetChannel(OperationContext.NoTimeout)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, readWriteBinding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, readWriteBinding.Session.Fork())) { var result = ExecuteOperation(subject, channelBinding, async); result.ProcessedRequests.Should().HaveCount(5); @@ -1507,7 +1506,7 @@ public void Execute_unacknowledged_with_an_error_in_the_second_batch_and_ordered using (var readWriteBinding = CreateReadWriteBinding(useImplicitSession: true)) using (var channelSource = readWriteBinding.GetWriteChannelSource(OperationContext.NoTimeout)) using (var channel = channelSource.GetChannel(OperationContext.NoTimeout)) - using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channel, readWriteBinding.Session.Fork())) + using (var channelBinding = new ChannelReadWriteBinding(channelSource.Server, channelSource.RoundTripTime, channel, readWriteBinding.Session.Fork())) { var result = ExecuteOperation(subject, channelBinding, async); result.ProcessedRequests.Should().HaveCount(4); diff --git a/tests/MongoDB.Driver.Tests/Core/Operations/FindOperationTests.cs b/tests/MongoDB.Driver.Tests/Core/Operations/FindOperationTests.cs index e468f7bca46..50b5c8e493e 100644 --- a/tests/MongoDB.Driver.Tests/Core/Operations/FindOperationTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Operations/FindOperationTests.cs @@ -755,6 +755,7 @@ public void CreateCursor_should_use_ns_field_instead_of_namespace_passed_in_cons var mockChannelSource = new Mock(); mockChannelSource.Setup(x => x.Server).Returns(mockServer.Object); mockChannelSource.Setup(x => x.Session).Returns(mockSession.Object); + mockChannelSource.Setup(x => x.RoundTripTime).Returns(TimeSpan.FromSeconds(42)); var cursor = subject.CreateCursor(mockChannelSource.Object, Mock.Of(), commandResult); diff --git a/tests/MongoDB.Driver.Tests/Core/Operations/ReadCommandOperationTests.cs b/tests/MongoDB.Driver.Tests/Core/Operations/ReadCommandOperationTests.cs index f866f268877..88a26354b9c 100644 --- a/tests/MongoDB.Driver.Tests/Core/Operations/ReadCommandOperationTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Operations/ReadCommandOperationTests.cs @@ -14,7 +14,6 @@ */ using System.Net; -using System.Threading; using System.Threading.Tasks; using FluentAssertions; using MongoDB.Bson; @@ -96,6 +95,7 @@ public void Execute_should_call_channel_Command_with_unwrapped_command_when_wrap { mockChannel.Verify( c => c.CommandAsync( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -106,14 +106,14 @@ public void Execute_should_call_channel_Command_with_unwrapped_command_when_wrap null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } else { mockChannel.Verify( c => c.Command( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -124,8 +124,7 @@ public void Execute_should_call_channel_Command_with_unwrapped_command_when_wrap null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } } @@ -150,6 +149,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_additi { mockChannel.Verify( c => c.CommandAsync( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -160,14 +160,14 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_additi null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } else { mockChannel.Verify( c => c.Command( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -178,8 +178,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_additi null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } } @@ -203,6 +202,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_commen { mockChannel.Verify( c => c.CommandAsync( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -213,14 +213,14 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_commen null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } else { mockChannel.Verify( c => c.Command( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -231,8 +231,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_commen null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } } @@ -257,6 +256,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_readPr { mockChannel.Verify( c => c.CommandAsync( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -267,14 +267,14 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_readPr null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } else { mockChannel.Verify( c => c.Command( + It.IsAny(), binding.Session, readPreference, subject.DatabaseNamespace, @@ -285,8 +285,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_readPr null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } } diff --git a/tests/MongoDB.Driver.Tests/Core/Operations/WriteCommandOperationTests.cs b/tests/MongoDB.Driver.Tests/Core/Operations/WriteCommandOperationTests.cs index bc44255fa2c..082943172f5 100644 --- a/tests/MongoDB.Driver.Tests/Core/Operations/WriteCommandOperationTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Operations/WriteCommandOperationTests.cs @@ -14,7 +14,6 @@ */ using System.Net; -using System.Threading; using System.Threading.Tasks; using FluentAssertions; using MongoDB.Bson; @@ -70,6 +69,7 @@ public void Execute_should_call_channel_Command_with_unwrapped_command_when_wrap { mockChannel.Verify( c => c.CommandAsync( + It.IsAny(), binding.Session, ReadPreference.Primary, subject.DatabaseNamespace, @@ -80,14 +80,14 @@ public void Execute_should_call_channel_Command_with_unwrapped_command_when_wrap null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } else { mockChannel.Verify( c => c.Command( + It.IsAny(), binding.Session, ReadPreference.Primary, subject.DatabaseNamespace, @@ -98,8 +98,7 @@ public void Execute_should_call_channel_Command_with_unwrapped_command_when_wrap null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } } @@ -122,6 +121,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_additi { mockChannel.Verify( c => c.CommandAsync( + It.IsAny(), It.IsAny(), It.IsAny(), subject.DatabaseNamespace, @@ -132,14 +132,14 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_additi null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } else { mockChannel.Verify( c => c.Command( + It.IsAny(), It.IsAny(), It.IsAny(), subject.DatabaseNamespace, @@ -150,8 +150,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_additi null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } } @@ -174,6 +173,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_commen { mockChannel.Verify( c => c.CommandAsync( + It.IsAny(), binding.Session, ReadPreference.Primary, subject.DatabaseNamespace, @@ -184,14 +184,14 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_commen null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } else { mockChannel.Verify( c => c.Command( + It.IsAny(), binding.Session, ReadPreference.Primary, subject.DatabaseNamespace, @@ -202,8 +202,7 @@ public void Execute_should_call_channel_Command_with_wrapped_command_when_commen null, // postWriteAction CommandResponseHandling.Return, subject.ResultSerializer, - subject.MessageEncoderSettings, - It.IsAny()), + subject.MessageEncoderSettings), Times.Once); } } diff --git a/tests/MongoDB.Driver.Tests/Core/Servers/LoadBalancedServerTests.cs b/tests/MongoDB.Driver.Tests/Core/Servers/LoadBalancedServerTests.cs index 16709b65bf3..07295032e52 100644 --- a/tests/MongoDB.Driver.Tests/Core/Servers/LoadBalancedServerTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Servers/LoadBalancedServerTests.cs @@ -75,28 +75,6 @@ public LoadBalancedTests(ITestOutputHelper output) : base(output) _connectionId = new ConnectionId(_subject.ServerId); } - [Theory] - [ParameterAttributeData] - public async Task ChannelFork_should_not_affect_operations_count([Values(false, true)] bool async) - { - IClusterableServer server = SetupServer(false, false); - - var channel = async ? - await server.GetChannelAsync(OperationContext.NoTimeout) : - server.GetChannel(OperationContext.NoTimeout); - - server.OutstandingOperationsCount.Should().Be(1); - - var forkedChannel = channel.Fork(); - server.OutstandingOperationsCount.Should().Be(1); - - forkedChannel.Dispose(); - server.OutstandingOperationsCount.Should().Be(1); - - channel.Dispose(); - server.OutstandingOperationsCount.Should().Be(0); - } - [Fact] public void Constructor_should_not_throw_when_serverApi_is_null() { @@ -167,7 +145,7 @@ public void Dispose_should_dispose_the_server() [Theory] [ParameterAttributeData] - public async Task GetChannel_should_clear_connection_pool_when_opening_connection_throws_MongoAuthenticationException( + public async Task GetConnection_should_clear_connection_pool_when_opening_connection_throws_MongoAuthenticationException( [Values(false, true)] bool async) { var connectionId = new ConnectionId(new ServerId(_clusterId, _endPoint)); @@ -204,8 +182,8 @@ public async Task GetChannel_should_clear_connection_pool_when_opening_connectio server.Initialize(); var exception = async ? - await Record.ExceptionAsync(() => server.GetChannelAsync(OperationContext.NoTimeout)) : - Record.Exception(() => server.GetChannel(OperationContext.NoTimeout)); + await Record.ExceptionAsync(() => server.GetConnectionAsync(OperationContext.NoTimeout)) : + Record.Exception(() => server.GetConnection(OperationContext.NoTimeout)); exception.Should().BeOfType(); mockConnectionPool.Verify(p => p.Clear(It.IsAny()), Times.Once()); @@ -213,28 +191,28 @@ await Record.ExceptionAsync(() => server.GetChannelAsync(OperationContext.NoTime [Theory] [ParameterAttributeData] - public async Task GetChannel_should_get_a_connection([Values(false, true)] bool async) + public async Task GetConnection_should_get_a_connection([Values(false, true)] bool async) { _subject.Initialize(); var channel = async ? - await _subject.GetChannelAsync(OperationContext.NoTimeout) : - _subject.GetChannel(OperationContext.NoTimeout); + await _subject.GetConnectionAsync(OperationContext.NoTimeout) : + _subject.GetConnection(OperationContext.NoTimeout); channel.Should().NotBeNull(); } [Theory] [ParameterAttributeData] - public async Task GetChannel_should_not_increase_operations_count_on_exception( + public async Task GetConnection_should_not_increase_operations_count_on_exception( [Values(false, true)] bool async, [Values(false, true)] bool connectionOpenException) { IClusterableServer server = SetupServer(connectionOpenException, !connectionOpenException); var exception = async ? - await Record.ExceptionAsync(() => _subject.GetChannelAsync(OperationContext.NoTimeout)) : - Record.Exception(() => _subject.GetChannel(OperationContext.NoTimeout)); + await Record.ExceptionAsync(() => _subject.GetConnectionAsync(OperationContext.NoTimeout)) : + Record.Exception(() => _subject.GetConnection(OperationContext.NoTimeout)); exception.Should().NotBeNull(); server.OutstandingOperationsCount.Should().Be(0); @@ -242,58 +220,58 @@ await Record.ExceptionAsync(() => _subject.GetChannelAsync(OperationContext.NoTi [Theory] [ParameterAttributeData] - public async Task GetChannel_should_set_operations_count_correctly( + public async Task GetConnection_should_set_operations_count_correctly( [Values(false, true)] bool async, [Values(0, 1, 2, 10)] int operationsCount) { IClusterableServer server = SetupServer(false, false); - var channels = new List(); + var connections = new List(); for (int i = 0; i < operationsCount; i++) { - var channel = async ? - await server.GetChannelAsync(OperationContext.NoTimeout) : - server.GetChannel(OperationContext.NoTimeout); - channels.Add(channel); + var connection = async ? + await server.GetConnectionAsync(OperationContext.NoTimeout) : + server.GetConnection(OperationContext.NoTimeout); + connections.Add(connection); } server.OutstandingOperationsCount.Should().Be(operationsCount); - foreach (var channel in channels) + foreach (var connection in connections) { - channel.Dispose(); + server.ReturnConnection(connection); server.OutstandingOperationsCount.Should().Be(--operationsCount); } } [Theory] [ParameterAttributeData] - public async Task GetChannel_should_throw_when_not_initialized( + public async Task GetConnection_should_throw_when_not_initialized( [Values(false, true)] bool async) { var exception = async ? - await Record.ExceptionAsync(() => _subject.GetChannelAsync(OperationContext.NoTimeout)) : - Record.Exception(() => _subject.GetChannel(OperationContext.NoTimeout)); + await Record.ExceptionAsync(() => _subject.GetConnectionAsync(OperationContext.NoTimeout)) : + Record.Exception(() => _subject.GetConnection(OperationContext.NoTimeout)); exception.Should().BeOfType(); } [Theory] [ParameterAttributeData] - public async Task GetChannel_should_throw_when_disposed([Values(false, true)] bool async) + public async Task GetConnection_should_throw_when_disposed([Values(false, true)] bool async) { _subject.Dispose(); var exception = async ? - await Record.ExceptionAsync(() => _subject.GetChannelAsync(OperationContext.NoTimeout)) : - Record.Exception(() => _subject.GetChannel(OperationContext.NoTimeout)); + await Record.ExceptionAsync(() => _subject.GetConnectionAsync(OperationContext.NoTimeout)) : + Record.Exception(() => _subject.GetConnection(OperationContext.NoTimeout)); exception.Should().BeOfType(); } [Theory] [ParameterAttributeData] - public async Task GetChannel_should_not_update_topology_and_clear_connection_pool_on_MongoConnectionException( + public async Task GetConnection_should_not_update_topology_and_clear_connection_pool_on_MongoConnectionException( [Values("TimedOutSocketException", "NetworkUnreachableSocketException")] string errorType, [Values(false, true)] bool async) { @@ -305,8 +283,8 @@ public async Task GetChannel_should_not_update_topology_and_clear_connection_poo var openConnectionException = new MongoConnectionException(connectionId, "Oops", new IOException("Cry", innerMostException)); var mockConnection = new Mock(); mockConnection.Setup(c => c.ConnectionId).Returns(connectionId); - mockConnection.Setup(c => c.Open(It.IsAny())).Throws(openConnectionException); - mockConnection.Setup(c => c.OpenAsync(It.IsAny())).ThrowsAsync(openConnectionException); + mockConnection.Setup(c => c.Open(It.IsAny())).Throws(openConnectionException); + mockConnection.Setup(c => c.OpenAsync(It.IsAny())).ThrowsAsync(openConnectionException); var connectionFactory = new Mock(); connectionFactory.Setup(cf => cf.CreateConnection(serverId, _endPoint)).Returns(mockConnection.Object); @@ -324,8 +302,8 @@ public async Task GetChannel_should_not_update_topology_and_clear_connection_poo subject.Initialize(); var exception = async ? - await Record.ExceptionAsync(() => subject.GetChannelAsync(OperationContext.NoTimeout)) : - Record.Exception(() => subject.GetChannel(OperationContext.NoTimeout)); + await Record.ExceptionAsync(() => subject.GetConnectionAsync(OperationContext.NoTimeout)) : + Record.Exception(() => subject.GetConnection(OperationContext.NoTimeout)); exception.Should().Be(openConnectionException); subject.Description.Type.Should().Be(ServerType.LoadBalanced); diff --git a/tests/MongoDB.Driver.Tests/Core/Servers/RoundTripTimeMonitorTests.cs b/tests/MongoDB.Driver.Tests/Core/Servers/RoundTripTimeMonitorTests.cs index 701fe502cdd..1fa4a5d2588 100644 --- a/tests/MongoDB.Driver.Tests/Core/Servers/RoundTripTimeMonitorTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Servers/RoundTripTimeMonitorTests.cs @@ -107,7 +107,7 @@ public void Round_trip_time_monitor_should_work_as_expected() }); mockConnection - .SetupSequence(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .SetupSequence(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns( () => { @@ -281,7 +281,7 @@ private ConnectionDescription CreateConnectionDescription() private RoundTripTimeMonitor CreateSubject(TimeSpan frequency, Mock mockConnection) { mockConnection - .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns(() => CreateResponseMessage()); var mockConnectionFactory = new Mock(); diff --git a/tests/MongoDB.Driver.Tests/Core/Servers/ServerTests.cs b/tests/MongoDB.Driver.Tests/Core/Servers/ServerTests.cs index cd564ab9773..29d19ba5c05 100644 --- a/tests/MongoDB.Driver.Tests/Core/Servers/ServerTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Servers/ServerTests.cs @@ -97,28 +97,6 @@ protected override void DisposeInternal() _subject.Dispose(); } - [Theory] - [ParameterAttributeData] - public async Task ChannelFork_should_not_affect_operations_count([Values(false, true)] bool async) - { - IClusterableServer server = SetupServer(false, false); - - var channel = async ? - await server.GetChannelAsync(OperationContext.NoTimeout) : - server.GetChannel(OperationContext.NoTimeout); - - server.OutstandingOperationsCount.Should().Be(1); - - var forkedChannel = channel.Fork(); - server.OutstandingOperationsCount.Should().Be(1); - - forkedChannel.Dispose(); - server.OutstandingOperationsCount.Should().Be(1); - - channel.Dispose(); - server.OutstandingOperationsCount.Should().Be(0); - } - [Fact] public void Constructor_should_not_throw_when_serverApi_is_null() { @@ -200,7 +178,7 @@ public void Dispose_should_dispose_the_server() [Theory] [ParameterAttributeData] - public async Task GetChannel_should_clear_connection_pool_when_opening_connection_throws_MongoAuthenticationException( + public async Task GetConnection_should_clear_connection_pool_when_opening_connection_throws_MongoAuthenticationException( [Values(false, true)] bool async) { var connectionId = new ConnectionId(new ServerId(_clusterId, _endPoint)); @@ -236,8 +214,8 @@ public async Task GetChannel_should_clear_connection_pool_when_opening_connectio server.Initialize(); var exception = async ? - await Record.ExceptionAsync(() => server.GetChannelAsync(OperationContext.NoTimeout)) : - Record.Exception(() => server.GetChannel(OperationContext.NoTimeout)); + await Record.ExceptionAsync(() => server.GetConnectionAsync(OperationContext.NoTimeout)) : + Record.Exception(() => server.GetConnection(OperationContext.NoTimeout)); exception.Should().BeOfType(); mockConnectionPool.Verify(p => p.Clear(It.IsAny()), Times.Once()); @@ -245,30 +223,30 @@ await Record.ExceptionAsync(() => server.GetChannelAsync(OperationContext.NoTime [Theory] [ParameterAttributeData] - public async Task GetChannel_should_get_a_connection( + public async Task GetConnection_should_get_a_connection( [Values(false, true)] bool async) { _subject.Initialize(); - var channel = async ? - await _subject.GetChannelAsync(OperationContext.NoTimeout) : - _subject.GetChannel(OperationContext.NoTimeout); + var connection = async ? + await _subject.GetConnectionAsync(OperationContext.NoTimeout) : + _subject.GetConnection(OperationContext.NoTimeout); - channel.Should().NotBeNull(); + connection.Should().NotBeNull(); } [Theory] [ParameterAttributeData] - public async Task GetChannel_should_not_increase_operations_count_on_exception( + public async Task GetConnection_should_not_increase_operations_count_on_exception( [Values(false, true)] bool async, [Values(false, true)] bool connectionOpenException) { IClusterableServer server = SetupServer(connectionOpenException, !connectionOpenException); var exception = async ? - await Record.ExceptionAsync(() => _subject.GetChannelAsync(OperationContext.NoTimeout)) : - Record.Exception(() => _subject.GetChannel(OperationContext.NoTimeout)); + await Record.ExceptionAsync(() => _subject.GetConnectionAsync(OperationContext.NoTimeout)) : + Record.Exception(() => _subject.GetConnection(OperationContext.NoTimeout)); exception.Should().NotBeNull(); server.OutstandingOperationsCount.Should().Be(0); @@ -276,60 +254,60 @@ await Record.ExceptionAsync(() => _subject.GetChannelAsync(OperationContext.NoTi [Theory] [ParameterAttributeData] - public async Task GetChannel_should_set_operations_count_correctly( + public async Task GetConnection_should_set_operations_count_correctly( [Values(false, true)] bool async, [Values(0, 1, 2, 10)] int operationsCount) { IClusterableServer server = SetupServer(false, false); - var channels = new List(); + var connections = new List(); for (int i = 0; i < operationsCount; i++) { - var channel = async ? - await server.GetChannelAsync(OperationContext.NoTimeout) : - server.GetChannel(OperationContext.NoTimeout); - channels.Add(channel); + var connection = async ? + await server.GetConnectionAsync(OperationContext.NoTimeout) : + server.GetConnection(OperationContext.NoTimeout); + connections.Add(connection); } server.OutstandingOperationsCount.Should().Be(operationsCount); - foreach (var channel in channels) + foreach (var connection in connections) { - channel.Dispose(); + server.ReturnConnection(connection); server.OutstandingOperationsCount.Should().Be(--operationsCount); } } [Theory] [ParameterAttributeData] - public async Task GetChannel_should_throw_when_not_initialized( + public async Task GetConnection_should_throw_when_not_initialized( [Values(false, true)] bool async) { var exception = async ? - await Record.ExceptionAsync(() => _subject.GetChannelAsync(OperationContext.NoTimeout)) : - Record.Exception(() => _subject.GetChannel(OperationContext.NoTimeout)); + await Record.ExceptionAsync(() => _subject.GetConnectionAsync(OperationContext.NoTimeout)) : + Record.Exception(() => _subject.GetConnection(OperationContext.NoTimeout)); exception.Should().BeOfType(); } [Theory] [ParameterAttributeData] - public async Task GetChannel_should_throw_when_disposed( + public async Task GetConnection_should_throw_when_disposed( [Values(false, true)] bool async) { _subject.Dispose(); var exception = async ? - await Record.ExceptionAsync(() => _subject.GetChannelAsync(OperationContext.NoTimeout)) : - Record.Exception(() => _subject.GetChannel(OperationContext.NoTimeout)); + await Record.ExceptionAsync(() => _subject.GetConnectionAsync(OperationContext.NoTimeout)) : + Record.Exception(() => _subject.GetConnection(OperationContext.NoTimeout)); exception.Should().BeOfType(); } [Theory] [ParameterAttributeData] - public async Task GetChannel_should_update_topology_and_clear_connection_pool_on_network_error_or_timeout( + public async Task GetConnection_should_update_topology_and_clear_connection_pool_on_network_error_or_timeout( [Values("TimedOutSocketException", "NetworkUnreachableSocketException")] string errorType, [Values(false, true)] bool async) { @@ -340,8 +318,8 @@ public async Task GetChannel_should_update_topology_and_clear_connection_pool_on var openConnectionException = new MongoConnectionException(connectionId, "Oops", new IOException("Cry", innerMostException)); var mockConnection = new Mock(); mockConnection.Setup(c => c.ConnectionId).Returns(connectionId); - mockConnection.Setup(c => c.Open(It.IsAny())).Throws(openConnectionException); - mockConnection.Setup(c => c.OpenAsync(It.IsAny())).ThrowsAsync(openConnectionException); + mockConnection.Setup(c => c.Open(It.IsAny())).Throws(openConnectionException); + mockConnection.Setup(c => c.OpenAsync(It.IsAny())).ThrowsAsync(openConnectionException); var connectionFactory = new Mock(); connectionFactory.Setup(f => f.ConnectionSettings).Returns(() => new ConnectionSettings()); @@ -368,8 +346,8 @@ public async Task GetChannel_should_update_topology_and_clear_connection_pool_on connectionPool.SetReady(); var exception = async ? - await Record.ExceptionAsync(() => subject.GetChannelAsync(OperationContext.NoTimeout)) : - Record.Exception(() => subject.GetChannel(OperationContext.NoTimeout)); + await Record.ExceptionAsync(() => subject.GetConnectionAsync(OperationContext.NoTimeout)) : + Record.Exception(() => subject.GetConnection(OperationContext.NoTimeout)); exception.Should().Be(openConnectionException); subject.Description.Type.Should().Be(ServerType.Unknown); @@ -853,8 +831,9 @@ public void Command_should_send_the_greater_of_the_session_and_cluster_cluster_t using (var cluster = CoreTestConfiguration.CreateCluster(b => b.Subscribe(eventCapturer))) using (var session = cluster.StartSession()) { - var server = (Server)cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); - using (var channel = server.GetChannel(OperationContext.NoTimeout)) + var (server, roundTripTime) = cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); + using (var channelSource = new ServerChannelSource(server, roundTripTime, session)) + using (var channel = channelSource.GetChannel(OperationContext.NoTimeout)) { session.AdvanceClusterTime(sessionClusterTime); server.ClusterClock.AdvanceClusterTime(clusterClusterTime); @@ -863,6 +842,7 @@ public void Command_should_send_the_greater_of_the_session_and_cluster_cluster_t try { channel.Command( + OperationContext.NoTimeout, session, ReadPreference.Primary, DatabaseNamespace.Admin, @@ -873,8 +853,7 @@ public void Command_should_send_the_greater_of_the_session_and_cluster_cluster_t null, // postWriteAction CommandResponseHandling.Return, BsonDocumentSerializer.Instance, - new MessageEncoderSettings(), - It.IsAny()); + new MessageEncoderSettings()); } catch (MongoCommandException ex) { @@ -900,11 +879,13 @@ public void Command_should_update_the_session_and_cluster_cluster_times() using (var cluster = CoreTestConfiguration.CreateCluster(b => b.Subscribe(eventCapturer))) using (var session = cluster.StartSession()) { - var server = (Server)cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); - using (var channel = server.GetChannel(OperationContext.NoTimeout)) + var (server, roundTripTime) = cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); + using (var channelSource = new ServerChannelSource(server, roundTripTime, session)) + using (var channel = channelSource.GetChannel(OperationContext.NoTimeout)) { var command = BsonDocument.Parse("{ ping : 1 }"); channel.Command( + OperationContext.NoTimeout, session, ReadPreference.Primary, DatabaseNamespace.Admin, @@ -915,15 +896,14 @@ public void Command_should_update_the_session_and_cluster_cluster_times() null, // postWriteAction CommandResponseHandling.Return, BsonDocumentSerializer.Instance, - new MessageEncoderSettings(), - It.IsAny()); - } + new MessageEncoderSettings()); - var commandSucceededEvent = eventCapturer.Next().Should().BeOfType().Subject; - var actualReply = commandSucceededEvent.Reply; - var actualClusterTime = actualReply["$clusterTime"].AsBsonDocument; - session.ClusterTime.Should().Be(actualClusterTime); - server.ClusterClock.ClusterTime.Should().Be(actualClusterTime); + var commandSucceededEvent = eventCapturer.Next().Should().BeOfType().Subject; + var actualReply = commandSucceededEvent.Reply; + var actualClusterTime = actualReply["$clusterTime"].AsBsonDocument; + session.ClusterTime.Should().Be(actualClusterTime); + server.ClusterClock.ClusterTime.Should().Be(actualClusterTime); + } } } @@ -943,14 +923,16 @@ public async Task Command_should_use_serverApi([Values(false, true)] bool async) using (var cluster = CoreTestConfiguration.CreateCluster(builder)) using (var session = cluster.StartSession()) { - var server = (Server)cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); - using (var channel = server.GetChannel(OperationContext.NoTimeout)) + var (server, roundTripTime) = cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); + using (var channelSource = new ServerChannelSource(server, roundTripTime, session)) + using (var channel = channelSource.GetChannel(OperationContext.NoTimeout)) { var command = BsonDocument.Parse("{ ping : 1 }"); if (async) { await channel .CommandAsync( + OperationContext.NoTimeout, session, ReadPreference.Primary, DatabaseNamespace.Admin, @@ -961,12 +943,12 @@ await channel null, // postWriteAction CommandResponseHandling.Return, BsonDocumentSerializer.Instance, - new MessageEncoderSettings(), - It.IsAny()); + new MessageEncoderSettings()); } else { channel.Command( + OperationContext.NoTimeout, session, ReadPreference.Primary, DatabaseNamespace.Admin, @@ -977,8 +959,7 @@ await channel null, // postWriteAction CommandResponseHandling.Return, BsonDocumentSerializer.Instance, - new MessageEncoderSettings(), - It.IsAny()); + new MessageEncoderSettings()); } } } diff --git a/tests/MongoDB.Driver.Tests/Core/WireProtocol/CommandWriteProtocolTests.cs b/tests/MongoDB.Driver.Tests/Core/WireProtocol/CommandWriteProtocolTests.cs index 9eef407362b..2e989358f7e 100644 --- a/tests/MongoDB.Driver.Tests/Core/WireProtocol/CommandWriteProtocolTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/WireProtocol/CommandWriteProtocolTests.cs @@ -71,13 +71,14 @@ public void Execute_should_use_cached_IWireProtocol_if_available([Values(false, responseHandling, BsonDocumentSerializer.Instance, messageEncoderSettings, - null); // serverApi + null, // serverApi + TimeSpan.FromMilliseconds(42)); var mockConnection = new Mock(); var commandResponse = MessageHelper.BuildCommandResponse(CreateRawBsonDocument(new BsonDocument("ok", 1))); var connectionId = SetupConnection(mockConnection); - var result = subject.Execute(mockConnection.Object, CancellationToken.None); + var result = subject.Execute(OperationContext.NoTimeout, mockConnection.Object); var cachedWireProtocol = subject._cachedWireProtocol(); cachedWireProtocol.Should().NotBeNull(); @@ -91,7 +92,7 @@ public void Execute_should_use_cached_IWireProtocol_if_available([Values(false, subject._responseHandling(CommandResponseHandling.Ignore); // will trigger the exception if the CommandUsingCommandMessageWireProtocol ctor will be called result = null; - var exception = Record.Exception(() => { result = subject.Execute(mockConnection.Object, CancellationToken.None); }); + var exception = Record.Exception(() => { result = subject.Execute(OperationContext.NoTimeout, mockConnection.Object); }); if (withSameConnection) { @@ -118,7 +119,7 @@ ConnectionId SetupConnection(Mock connection, ConnectionId id = nul } connection - .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), messageEncoderSettings, CancellationToken.None)) + .Setup(c => c.ReceiveMessage(OperationContext.NoTimeout, It.IsAny(), It.IsAny(), messageEncoderSettings)) .Returns(commandResponse); connection.SetupGet(c => c.ConnectionId).Returns(id); connection @@ -133,7 +134,7 @@ ConnectionId SetupConnection(Mock connection, ConnectionId id = nul [Theory] [ParameterAttributeData] - public void Execute_should_use_serverApi_with_getMoreCommand( + public async Task Execute_should_use_serverApi_with_getMoreCommand( [Values(false, true)] bool useServerApi, [Values(false, true)] bool async) { @@ -155,15 +156,16 @@ public void Execute_should_use_serverApi_with_getMoreCommand( CommandResponseHandling.Return, BsonDocumentSerializer.Instance, new MessageEncoderSettings(), - serverApi); + serverApi, + TimeSpan.FromMilliseconds(42)); if (async) { - subject.ExecuteAsync(connection, CancellationToken.None).GetAwaiter().GetResult(); + await subject.ExecuteAsync(OperationContext.NoTimeout, connection); } else { - subject.Execute(connection, CancellationToken.None); + subject.Execute(OperationContext.NoTimeout, connection); } SpinWait.SpinUntil(() => connection.GetSentMessages().Count >= 1, TimeSpan.FromSeconds(4)).Should().BeTrue(); @@ -177,7 +179,7 @@ public void Execute_should_use_serverApi_with_getMoreCommand( [Theory] [ParameterAttributeData] - public void Execute_should_use_serverApi_in_transaction( + public async Task Execute_should_use_serverApi_in_transaction( [Values(false, true)] bool useServerApi, [Values(false, true)] bool async) { @@ -199,15 +201,16 @@ public void Execute_should_use_serverApi_in_transaction( CommandResponseHandling.Return, BsonDocumentSerializer.Instance, new MessageEncoderSettings(), - serverApi); + serverApi, + TimeSpan.FromMilliseconds(42)); if (async) { - subject.ExecuteAsync(connection, CancellationToken.None).GetAwaiter().GetResult(); + await subject.ExecuteAsync(OperationContext.NoTimeout, connection); } else { - subject.Execute(connection, CancellationToken.None); + subject.Execute(OperationContext.NoTimeout, connection); } SpinWait.SpinUntil(() => connection.GetSentMessages().Count >= 1, TimeSpan.FromSeconds(4)).Should().BeTrue(); @@ -247,17 +250,18 @@ public void Execute_should_wait_for_response_when_CommandResponseHandling_is_Ret CommandResponseHandling.Return, BsonDocumentSerializer.Instance, messageEncoderSettings, - null); // serverApi + null, // serverApi + TimeSpan.FromMilliseconds(42)); var mockConnection = new Mock(); mockConnection.Setup(c => c.Settings).Returns(() => new ConnectionSettings()); var commandResponse = MessageHelper.BuildReply(CreateRawBsonDocument(new BsonDocument("ok", 1))); mockConnection - .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), messageEncoderSettings, CancellationToken.None)) + .Setup(c => c.ReceiveMessage(It.IsAny(), It.IsAny(), It.IsAny(), messageEncoderSettings)) .Returns(commandResponse); - var result = subject.Execute(mockConnection.Object, CancellationToken.None); + var result = subject.Execute(OperationContext.NoTimeout, mockConnection.Object); result.Should().Be("{ok: 1}"); } @@ -277,21 +281,22 @@ public void Execute_should_not_wait_for_response_when_CommandResponseHandling_is CommandResponseHandling.NoResponseExpected, BsonDocumentSerializer.Instance, messageEncoderSettings, - null); // serverApi + null, // serverApi + TimeSpan.FromMilliseconds(42)); var mockConnection = new Mock(); mockConnection.Setup(c => c.Settings).Returns(() => new ConnectionSettings()); - var result = subject.Execute(mockConnection.Object, CancellationToken.None); + var result = subject.Execute(OperationContext.NoTimeout, mockConnection.Object); result.Should().BeNull(); mockConnection.Verify( - c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), messageEncoderSettings, CancellationToken.None), + c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), It.IsAny(), messageEncoderSettings), Times.Once); } [Fact] - public void ExecuteAsync_should_wait_for_response_when_CommandResponseHandling_is_Return() + public async Task ExecuteAsync_should_wait_for_response_when_CommandResponseHandling_is_Return() { var messageEncoderSettings = new MessageEncoderSettings(); var subject = new CommandWireProtocol( @@ -306,22 +311,23 @@ public void ExecuteAsync_should_wait_for_response_when_CommandResponseHandling_i CommandResponseHandling.Return, BsonDocumentSerializer.Instance, messageEncoderSettings, - null); // serverApi + null, // serverApi + TimeSpan.FromMilliseconds(42)); var mockConnection = new Mock(); mockConnection.Setup(c => c.Settings).Returns(() => new ConnectionSettings()); var commandResponse = MessageHelper.BuildReply(CreateRawBsonDocument(new BsonDocument("ok", 1))); mockConnection - .Setup(c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), messageEncoderSettings, CancellationToken.None)) + .Setup(c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), It.IsAny(), messageEncoderSettings)) .Returns(Task.FromResult(commandResponse)); - var result = subject.ExecuteAsync(mockConnection.Object, CancellationToken.None).GetAwaiter().GetResult(); + var result = await subject.ExecuteAsync(OperationContext.NoTimeout, mockConnection.Object); result.Should().Be("{ok: 1}"); } [Fact] - public void ExecuteAsync_should_not_wait_for_response_when_CommandResponseHandling_is_NoResponseExpected() + public async Task ExecuteAsync_should_not_wait_for_response_when_CommandResponseHandling_is_NoResponseExpected() { var messageEncoderSettings = new MessageEncoderSettings(); var subject = new CommandWireProtocol( @@ -336,15 +342,16 @@ public void ExecuteAsync_should_not_wait_for_response_when_CommandResponseHandli CommandResponseHandling.NoResponseExpected, BsonDocumentSerializer.Instance, messageEncoderSettings, - null); // serverApi + null, // serverApi + TimeSpan.FromMilliseconds(42)); var mockConnection = new Mock(); mockConnection.Setup(c => c.Settings).Returns(() => new ConnectionSettings()); - var result = subject.ExecuteAsync(mockConnection.Object, CancellationToken.None).GetAwaiter().GetResult(); + var result = await subject.ExecuteAsync(OperationContext.NoTimeout, mockConnection.Object); result.Should().BeNull(); - mockConnection.Verify(c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), messageEncoderSettings, CancellationToken.None), Times.Once); + mockConnection.Verify(c => c.ReceiveMessageAsync(It.IsAny(), It.IsAny(), It.IsAny(), messageEncoderSettings), Times.Once); } // private methods diff --git a/tests/MongoDB.Driver.Tests/Encryption/ClientEncryptionTests.cs b/tests/MongoDB.Driver.Tests/Encryption/ClientEncryptionTests.cs index fb09f34e818..4c134535153 100644 --- a/tests/MongoDB.Driver.Tests/Encryption/ClientEncryptionTests.cs +++ b/tests/MongoDB.Driver.Tests/Encryption/ClientEncryptionTests.cs @@ -136,16 +136,16 @@ public async Task CreateEncryptedCollection_should_handle_generated_key_when_sec mockCluster.SetupGet(c => c.Description).Returns(clusterDescription); var mockServer = new Mock(); mockServer.SetupGet(s => s.Description).Returns(serverDescription); - var channel = Mock.Of(c => c.ConnectionDescription == new ConnectionDescription(new ConnectionId(serverId), new HelloResult(new BsonDocument("maxWireVersion", serverDescription.WireVersionRange.Max)))); - mockServer.Setup(s => s.GetChannel(It.IsAny())).Returns(channel); - mockServer.Setup(s => s.GetChannelAsync(It.IsAny())).ReturnsAsync(channel); + var connection = Mock.Of(c => c.Description == new ConnectionDescription(new ConnectionId(serverId), new HelloResult(new BsonDocument("maxWireVersion", serverDescription.WireVersionRange.Max)))); + mockServer.Setup(s => s.GetConnection(It.IsAny())).Returns(connection); + mockServer.Setup(s => s.GetConnectionAsync(It.IsAny())).ReturnsAsync(connection); mockCluster .Setup(m => m.SelectServer(It.IsAny(), It.IsAny())) - .Returns(mockServer.Object); + .Returns((mockServer.Object, TimeSpan.FromMilliseconds(42))); mockCluster .Setup(m => m.SelectServerAsync(It.IsAny(), It.IsAny())) - .ReturnsAsync(mockServer.Object); + .ReturnsAsync((mockServer.Object, TimeSpan.FromMilliseconds(42))); var database = Mock.Of(d => d.DatabaseNamespace == new DatabaseNamespace("db") && @@ -225,16 +225,16 @@ public async Task CreateEncryptedCollection_should_handle_various_encryptedField mockCluster.SetupGet(c => c.Description).Returns(clusterDescription); var mockServer = new Mock(); mockServer.SetupGet(s => s.Description).Returns(serverDescription); - var channel = Mock.Of(c => c.ConnectionDescription == new ConnectionDescription(new ConnectionId(serverId), new HelloResult(new BsonDocument("maxWireVersion", serverDescription.WireVersionRange.Max)))); - mockServer.Setup(s => s.GetChannel(It.IsAny())).Returns(channel); - mockServer.Setup(s => s.GetChannelAsync(It.IsAny())).ReturnsAsync(channel); + var connection = Mock.Of(c => c.Description == new ConnectionDescription(new ConnectionId(serverId), new HelloResult(new BsonDocument("maxWireVersion", serverDescription.WireVersionRange.Max)))); + mockServer.Setup(s => s.GetConnection(It.IsAny())).Returns(connection); + mockServer.Setup(s => s.GetConnectionAsync(It.IsAny())).ReturnsAsync(connection); mockCluster .Setup(m => m.SelectServer(It.IsAny(), It.IsAny())) - .Returns(mockServer.Object); + .Returns((mockServer.Object, TimeSpan.FromMilliseconds(42))); mockCluster .Setup(m => m.SelectServerAsync(It.IsAny(), It.IsAny())) - .ReturnsAsync(mockServer.Object); + .ReturnsAsync((mockServer.Object, TimeSpan.FromMilliseconds(42))); var database = Mock.Of(d => d.DatabaseNamespace == new DatabaseNamespace("db") && diff --git a/tests/MongoDB.Driver.Tests/IJsonDrivenTestRunner.cs b/tests/MongoDB.Driver.Tests/IJsonDrivenTestRunner.cs index 85fde463079..ec8c2f68978 100644 --- a/tests/MongoDB.Driver.Tests/IJsonDrivenTestRunner.cs +++ b/tests/MongoDB.Driver.Tests/IJsonDrivenTestRunner.cs @@ -27,10 +27,8 @@ namespace MongoDB.Driver.Tests internal interface IJsonDrivenTestRunner { IClusterInternal FailPointCluster { get; } - IServer FailPointServer { get; } - - void ConfigureFailPoint(IServer server, ICoreSessionHandle session, BsonDocument failCommand); - Task ConfigureFailPointAsync(IServer server, ICoreSessionHandle session, BsonDocument failCommand); + void ConfigureFailPoint(IServer server, TimeSpan serverRoundTripTime, ICoreSessionHandle session, BsonDocument failCommand); + Task ConfigureFailPointAsync(IServer server, TimeSpan serverRoundTripTime, ICoreSessionHandle session, BsonDocument failCommand); } internal sealed class JsonDrivenTestRunner : IJsonDrivenTestRunner, IDisposable @@ -49,17 +47,15 @@ public IClusterInternal FailPointCluster } } - public IServer FailPointServer => null; - - public void ConfigureFailPoint(IServer server, ICoreSessionHandle session, BsonDocument failCommand) + public void ConfigureFailPoint(IServer server, TimeSpan serverRoundTripTime, ICoreSessionHandle session, BsonDocument failCommand) { - var failPoint = FailPoint.Configure(server, session, failCommand, withAsync: false); + var failPoint = FailPoint.Configure(server, serverRoundTripTime, session, failCommand, withAsync: false); _disposables.Add(failPoint); } - public async Task ConfigureFailPointAsync(IServer server, ICoreSessionHandle session, BsonDocument failCommand) + public async Task ConfigureFailPointAsync(IServer server, TimeSpan serverRoundTripTime, ICoreSessionHandle session, BsonDocument failCommand) { - var failPoint = await Task.Run(() => FailPoint.Configure(server, session, failCommand, withAsync: true)).ConfigureAwait(false); + var failPoint = await Task.Run(() => FailPoint.Configure(server, serverRoundTripTime, session, failCommand, withAsync: true)).ConfigureAwait(false); _disposables.Add(failPoint); } diff --git a/tests/MongoDB.Driver.Tests/JsonDrivenTests/JsonDrivenConfigureFailPointTest.cs b/tests/MongoDB.Driver.Tests/JsonDrivenTests/JsonDrivenConfigureFailPointTest.cs index 389449d592d..4d8e85ee651 100644 --- a/tests/MongoDB.Driver.Tests/JsonDrivenTests/JsonDrivenConfigureFailPointTest.cs +++ b/tests/MongoDB.Driver.Tests/JsonDrivenTests/JsonDrivenConfigureFailPointTest.cs @@ -13,6 +13,7 @@ * limitations under the License. */ +using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -34,34 +35,24 @@ public JsonDrivenConfigureFailPointTest(IJsonDrivenTestRunner testRunner, Dictio protected override void CallMethod(CancellationToken cancellationToken) { - var server = GetFailPointServer(); - TestRunner.ConfigureFailPoint(server, NoCoreSession.NewHandle(), _failCommand); + var (server, serverRoundTripTime) = GetFailPointServer(); + TestRunner.ConfigureFailPoint(server, serverRoundTripTime, NoCoreSession.NewHandle(), _failCommand); } protected override async Task CallMethodAsync(CancellationToken cancellationToken) { - var server = await GetFailPointServerAsync().ConfigureAwait(false); - await TestRunner.ConfigureFailPointAsync(server, NoCoreSession.NewHandle(), _failCommand).ConfigureAwait(false); + var (server, serverRoundTripTime) = await GetFailPointServerAsync().ConfigureAwait(false); + await TestRunner.ConfigureFailPointAsync(server, serverRoundTripTime, NoCoreSession.NewHandle(), _failCommand).ConfigureAwait(false); } - protected virtual IServer GetFailPointServer() + protected virtual (IServer Server, TimeSpan RoundTripTime) GetFailPointServer() { - if (TestRunner.FailPointServer != null) - { - return TestRunner.FailPointServer; - } - var cluster = TestRunner.FailPointCluster; return cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); } - protected async virtual Task GetFailPointServerAsync() + protected async virtual Task<(IServer Server, TimeSpan RoundTripTime)> GetFailPointServerAsync() { - if (TestRunner.FailPointServer != null) - { - return TestRunner.FailPointServer; - } - var cluster = TestRunner.FailPointCluster; return await cluster.SelectServerAsync(OperationContext.NoTimeout, WritableServerSelector.Instance).ConfigureAwait(false); } diff --git a/tests/MongoDB.Driver.Tests/JsonDrivenTests/JsonDrivenTargetedFailPointTest.cs b/tests/MongoDB.Driver.Tests/JsonDrivenTests/JsonDrivenTargetedFailPointTest.cs index 8a34e83638d..8ab031cb94e 100644 --- a/tests/MongoDB.Driver.Tests/JsonDrivenTests/JsonDrivenTargetedFailPointTest.cs +++ b/tests/MongoDB.Driver.Tests/JsonDrivenTests/JsonDrivenTargetedFailPointTest.cs @@ -13,6 +13,7 @@ * limitations under the License. */ +using System; using System.Collections.Generic; using System.Net; using System.Threading.Tasks; @@ -29,14 +30,14 @@ public JsonDrivenTargetedFailPointTest(IJsonDrivenTestRunner testRunner, Diction { } - protected override IServer GetFailPointServer() + protected override (IServer Server, TimeSpan RoundTripTime) GetFailPointServer() { var pinnedServerEndpoint = GetPinnedServerEndpointAndAssertNotNull(); var pinnedServerSelector = CreateServerSelector(pinnedServerEndpoint); return TestRunner.FailPointCluster.SelectServer(OperationContext.NoTimeout, pinnedServerSelector); } - protected async override Task GetFailPointServerAsync() + protected async override Task<(IServer Server, TimeSpan RoundTripTime)> GetFailPointServerAsync() { var pinnedServerEndpoint = GetPinnedServerEndpointAndAssertNotNull(); var pinnedServerSelector = CreateServerSelector(pinnedServerEndpoint); diff --git a/tests/MongoDB.Driver.Tests/Specifications/Runner/MongoClientJsonDrivenTestRunnerBase.cs b/tests/MongoDB.Driver.Tests/Specifications/Runner/MongoClientJsonDrivenTestRunnerBase.cs index 8ad80cafbca..a3c3ee8a203 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/Runner/MongoClientJsonDrivenTestRunnerBase.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/Runner/MongoClientJsonDrivenTestRunnerBase.cs @@ -70,6 +70,7 @@ public abstract class MongoClientJsonDrivenTestRunnerBase : LoggableTestClass private IDictionary _objectMap = null; private protected IServer _failPointServer = null; + private protected TimeSpan? _failPointRoundTripTime = null; protected BsonDocument LastKnownClusterTime { get; set; } @@ -451,22 +452,22 @@ private protected FailPoint ConfigureFailPoint(BsonDocument test, IMongoClient c var settings = client.Settings.Clone(); ConfigureClientSettings(settings, test); - if (settings.DirectConnection == true) + if (settings.DirectConnection) { var serverAddress = EndPointHelper.Parse(settings.Server.ToString()); var selector = new EndPointServerSelector(serverAddress); - _failPointServer = cluster.SelectServer(OperationContext.NoTimeout, selector); + (_failPointServer, _failPointRoundTripTime) = cluster.SelectServer(OperationContext.NoTimeout, selector); } else { - _failPointServer = cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); + (_failPointServer, _failPointRoundTripTime) = cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); } var session = NoCoreSession.NewHandle(); var command = failPoint.AsBsonDocument; - return FailPoint.Configure(_failPointServer, session, command, _async); + return FailPoint.Configure(_failPointServer, _failPointRoundTripTime.Value, session, command, _async); } return null; diff --git a/tests/MongoDB.Driver.Tests/Specifications/connection-monitoring-and-pooling/ConnectionMonitoringAndPoolingTestRunner.cs b/tests/MongoDB.Driver.Tests/Specifications/connection-monitoring-and-pooling/ConnectionMonitoringAndPoolingTestRunner.cs index cad223d7c2d..07e95151e43 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/connection-monitoring-and-pooling/ConnectionMonitoringAndPoolingTestRunner.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/connection-monitoring-and-pooling/ConnectionMonitoringAndPoolingTestRunner.cs @@ -83,12 +83,6 @@ private static class Schema public readonly static string ignore = nameof(ignore); public readonly static string async = nameof(async); - public static class Operations - { - public const string runOn = nameof(runOn); - public readonly static string failPoint = nameof(failPoint); - } - public static class Intergration { public readonly static string runOn = nameof(runOn); @@ -101,12 +95,6 @@ public static class Styles public readonly static string integration = nameof(integration); } - public sealed class FailPoint - { - public readonly static string appName = nameof(appName); - public readonly static string data = nameof(data); - } - public readonly static string[] AllFields = new[] { _path, @@ -671,7 +659,7 @@ private void ParseSettings( connectionIdLocalValueProvider: connectionIdProvider)) .Subscribe(eventCapturer)); - var server = cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); + var (server, _) = cluster.SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); connectionPool = server._connectionPool(); if (test.TryGetValue(Schema.Intergration.failPoint, out var failPointDocument)) @@ -729,8 +717,8 @@ o is ServerHeartbeatSucceededEvent || eventCapturer.WaitForOrThrowIfTimeout(events => events.Any(e => e is ConnectionPoolClearedEvent), TimeSpan.FromMilliseconds(500)); } - var failPointServer = CoreTestConfiguration.Cluster.SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(server.EndPoint)); - failPoint = FailPoint.Configure(failPointServer, NoCoreSession.NewHandle(), failPointDocument.AsBsonDocument, withAsync: async); + var (failPointServer, failPointServerRoundTripTime) = CoreTestConfiguration.Cluster.SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(server.EndPoint)); + failPoint = FailPoint.Configure(failPointServer, failPointServerRoundTripTime, NoCoreSession.NewHandle(), failPointDocument.AsBsonDocument, withAsync: async); if (resetPool) { @@ -745,33 +733,6 @@ o is ServerHeartbeatSucceededEvent || return (connectionPool, failPoint, cluster, eventsFilter); } - private IConnectionPool SetupConnectionPoolMock(BsonDocument test, IEventSubscriber eventSubscriber) - { - var endPoint = new DnsEndPoint("localhost", 27017); - var serverId = new ServerId(new ClusterId(), endPoint); - ParseSettings(test, out var connectionPoolSettings, out var connectionSettings); - - var connectionFactory = new Mock(); - var exceptionHandler = new Mock(); - connectionFactory.Setup(f => f.ConnectionSettings).Returns(() => new ConnectionSettings()); - connectionFactory - .Setup(c => c.CreateConnection(serverId, endPoint)) - .Returns(() => - { - var connection = new MockConnection(serverId, connectionSettings, eventSubscriber); - return connection; - }); - var connectionPool = new ExclusiveConnectionPool( - serverId, - endPoint, - connectionPoolSettings, - connectionFactory.Object, - exceptionHandler.Object, - eventSubscriber.ToEventLogger()); - - return connectionPool; - } - private void Start(BsonDocument operation, ConcurrentDictionary tasks) { var startTarget = operation.GetValue("target").ToString(); diff --git a/tests/MongoDB.Driver.Tests/Specifications/mongodb-handshake/MongoDbHandshakeProseTests.cs b/tests/MongoDB.Driver.Tests/Specifications/mongodb-handshake/MongoDbHandshakeProseTests.cs index 8117f67bcc5..960d24bd773 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/mongodb-handshake/MongoDbHandshakeProseTests.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/mongodb-handshake/MongoDbHandshakeProseTests.cs @@ -79,11 +79,11 @@ public async Task DriverAcceptsArbitraryAuthMechanism([Values(false, true)] bool if (async) { - await subject.OpenAsync(CancellationToken.None); + await subject.OpenAsync(OperationContext.NoTimeout); } else { - subject.Open(CancellationToken.None); + subject.Open(OperationContext.NoTimeout); } subject._state().Should().Be(3); // 3 - open. diff --git a/tests/MongoDB.Driver.Tests/Specifications/retryable-reads/RetryableReadsProseTests.cs b/tests/MongoDB.Driver.Tests/Specifications/retryable-reads/RetryableReadsProseTests.cs index e95cf102944..f4cd2f9b4e8 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/retryable-reads/RetryableReadsProseTests.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/retryable-reads/RetryableReadsProseTests.cs @@ -79,8 +79,8 @@ public async Task PoolClearedError_read_retryablity_test([Values(true, false)] b .Capture() .CaptureCommandEvents("find"); - var failpointServer = DriverTestConfiguration.Client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, failPointSelector); - using var failPoint = FailPoint.Configure(failpointServer, NoCoreSession.NewHandle(), failPointCommand); + var (failpointServer, roundTripTime) = DriverTestConfiguration.Client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, failPointSelector); + using var failPoint = FailPoint.Configure(failpointServer, roundTripTime, NoCoreSession.NewHandle(), failPointCommand); using var client = CreateClient(settings, eventCapturer, heartbeatInterval); var database = client.GetDatabase(DriverTestConfiguration.DatabaseNamespace.DatabaseName); @@ -146,11 +146,11 @@ public void Sharded_cluster_retryable_reads_are_retried_on_different_mongos_if_a }, useMultipleShardRouters: true); - var failPointServer1 = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(client.Cluster.Description.Servers[0].EndPoint)); - var failPointServer2 = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(client.Cluster.Description.Servers[1].EndPoint)); + var (failPointServer1, roundTripTime1) = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(client.Cluster.Description.Servers[0].EndPoint)); + var (failPointServer2, roundTripTime2) = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(client.Cluster.Description.Servers[1].EndPoint)); - using var failPoint1 = FailPoint.Configure(failPointServer1, NoCoreSession.NewHandle(), failPointCommand); - using var failPoint2 = FailPoint.Configure(failPointServer2, NoCoreSession.NewHandle(), failPointCommand); + using var failPoint1 = FailPoint.Configure(failPointServer1, roundTripTime1, NoCoreSession.NewHandle(), failPointCommand); + using var failPoint2 = FailPoint.Configure(failPointServer2, roundTripTime2, NoCoreSession.NewHandle(), failPointCommand); var database = client.GetDatabase(DriverTestConfiguration.DatabaseNamespace.DatabaseName); var collection = database.GetCollection(DriverTestConfiguration.CollectionNamespace.CollectionName); @@ -196,9 +196,9 @@ public void Sharded_cluster_retryable_reads_are_retried_on_same_mongos_if_no_oth }, useMultipleShardRouters: false); - var failPointServer = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(client.Cluster.Description.Servers[0].EndPoint)); + var (failPointServer, roundTripTime) = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(client.Cluster.Description.Servers[0].EndPoint)); - using var failPoint = FailPoint.Configure(failPointServer, NoCoreSession.NewHandle(), failPointCommand); + using var failPoint = FailPoint.Configure(failPointServer, roundTripTime, NoCoreSession.NewHandle(), failPointCommand); var database = client.GetDatabase(DriverTestConfiguration.DatabaseNamespace.DatabaseName); var collection = database.GetCollection(DriverTestConfiguration.CollectionNamespace.CollectionName); diff --git a/tests/MongoDB.Driver.Tests/Specifications/retryable-writes/prose-tests/PoolClearRetryability.cs b/tests/MongoDB.Driver.Tests/Specifications/retryable-writes/prose-tests/PoolClearRetryability.cs index 78949576f42..193b5188fc2 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/retryable-writes/prose-tests/PoolClearRetryability.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/retryable-writes/prose-tests/PoolClearRetryability.cs @@ -82,8 +82,8 @@ public async Task PoolClearedError_write_retryablity_test([Values(false, true)] .Capture() .CaptureCommandEvents("insert"); - var failpointServer = DriverTestConfiguration.Client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, failPointSelector); - using var failPoint = FailPoint.Configure(failpointServer, NoCoreSession.NewHandle(), failPointCommand); + var (failpointServer, roundTripTime) = DriverTestConfiguration.Client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, failPointSelector); + using var failPoint = FailPoint.Configure(failpointServer, roundTripTime, NoCoreSession.NewHandle(), failPointCommand); using var client = CreateClient(settings, eventCapturer, heartbeatInterval); var database = client.GetDatabase(DriverTestConfiguration.DatabaseNamespace.DatabaseName); diff --git a/tests/MongoDB.Driver.Tests/Specifications/retryable-writes/prose-tests/RetryWriteOnOtherMongos.cs b/tests/MongoDB.Driver.Tests/Specifications/retryable-writes/prose-tests/RetryWriteOnOtherMongos.cs index c7424f38880..54fa6f9cbd3 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/retryable-writes/prose-tests/RetryWriteOnOtherMongos.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/retryable-writes/prose-tests/RetryWriteOnOtherMongos.cs @@ -61,11 +61,11 @@ public void Sharded_cluster_retryable_writes_are_retried_on_different_mongos_if_ }, useMultipleShardRouters: true); - var failPointServer1 = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(client.Cluster.Description.Servers[0].EndPoint)); - var failPointServer2 = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(client.Cluster.Description.Servers[1].EndPoint)); + var (failPointServer1, roundTripTime1) = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(client.Cluster.Description.Servers[0].EndPoint)); + var (failPointServer2, roundTripTime2) = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(client.Cluster.Description.Servers[1].EndPoint)); - using var failPoint1 = FailPoint.Configure(failPointServer1, NoCoreSession.NewHandle(), failPointCommand); - using var failPoint2 = FailPoint.Configure(failPointServer2, NoCoreSession.NewHandle(), failPointCommand); + using var failPoint1 = FailPoint.Configure(failPointServer1, roundTripTime1, NoCoreSession.NewHandle(), failPointCommand); + using var failPoint2 = FailPoint.Configure(failPointServer2, roundTripTime2, NoCoreSession.NewHandle(), failPointCommand); var database = client.GetDatabase(DriverTestConfiguration.DatabaseNamespace.DatabaseName); var collection = database.GetCollection(DriverTestConfiguration.CollectionNamespace.CollectionName); @@ -112,9 +112,9 @@ public void Sharded_cluster_retryable_writes_are_retried_on_same_mongo_if_no_oth }, useMultipleShardRouters: false); - var failPointServer = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(client.Cluster.Description.Servers[0].EndPoint)); + var (failPointServer, roundTripTime) = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(client.Cluster.Description.Servers[0].EndPoint)); - using var failPoint = FailPoint.Configure(failPointServer, NoCoreSession.NewHandle(), failPointCommand); + using var failPoint = FailPoint.Configure(failPointServer, roundTripTime, NoCoreSession.NewHandle(), failPointCommand); var database = client.GetDatabase(DriverTestConfiguration.DatabaseNamespace.DatabaseName); var collection = database.GetCollection(DriverTestConfiguration.CollectionNamespace.CollectionName); diff --git a/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs b/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs index 4d61e9c7b60..6803978739e 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs @@ -165,8 +165,8 @@ public void Monitor_sleep_at_least_minHeartbeatFrequencyMS_between_checks() settings.ApplicationName = appName; settings.ServerSelectionTimeout = TimeSpan.FromSeconds(5); - var server = DriverTestConfiguration.Client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(new DnsEndPoint(serverAddress.Host, serverAddress.Port))); - using var failPoint = FailPoint.Configure(server, NoCoreSession.NewHandle(), failPointCommand); + var (server, serverRoundTripTime) = DriverTestConfiguration.Client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(new DnsEndPoint(serverAddress.Host, serverAddress.Port))); + using var failPoint = FailPoint.Configure(server, serverRoundTripTime, NoCoreSession.NewHandle(), failPointCommand); using var client = DriverTestConfiguration.CreateMongoClient(settings); var database = client.GetDatabase(DriverTestConfiguration.DatabaseNamespace.DatabaseName); @@ -220,7 +220,7 @@ public void RoundTimeTrip_test() { // Note that the Server Description Equality rule means that ServerDescriptionChangedEvents will not be published. // So we use reflection to obtain the latest RTT instead. - var server = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); + var (server, _) = client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, WritableServerSelector.Instance); var roundTripTimeMonitor = server._monitor()._roundTripTimeMonitor(); var expectedRoundTripTime = TimeSpan.FromMilliseconds(250); var timeout = TimeSpan.FromSeconds(30); // should not be reached without a driver bug @@ -273,8 +273,8 @@ public void ConnectionPool_cleared_on_failed_hello() eventsWaitTimeout); eventCapturer.Clear(); - var failpointServer = DriverTestConfiguration.Client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(new DnsEndPoint(serverAddress.Host, serverAddress.Port))); - using var failPoint = FailPoint.Configure(failpointServer, NoCoreSession.NewHandle(), failPointCommand); + var (failpointServer, serverRoundTripTime) = DriverTestConfiguration.Client.GetClusterInternal().SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(new DnsEndPoint(serverAddress.Host, serverAddress.Port))); + using var failPoint = FailPoint.Configure(failpointServer, serverRoundTripTime, NoCoreSession.NewHandle(), failPointCommand); eventCapturer.WaitForEventOrThrowIfTimeout(eventsWaitTimeout); var events = eventCapturer.Events diff --git a/tests/MongoDB.Driver.Tests/Specifications/server-selection/InWindowTestRunner.cs b/tests/MongoDB.Driver.Tests/Specifications/server-selection/InWindowTestRunner.cs index a61e7d6658a..88f66b16d95 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/server-selection/InWindowTestRunner.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/server-selection/InWindowTestRunner.cs @@ -81,7 +81,7 @@ public void RunTestDefinition(JsonDrivenTestCase testCase) for (int i = 0; i < testData.iterations; i++) { - var selectedServer = testData.async + var (selectedServer, _) = testData.async ? cluster.SelectServerAsync(OperationContext.NoTimeout, readPreferenceSelector).GetAwaiter().GetResult() : cluster.SelectServer(OperationContext.NoTimeout, readPreferenceSelector); diff --git a/tests/MongoDB.Driver.Tests/UnifiedTestOperations/UnifiedTargetedFailPointOperation.cs b/tests/MongoDB.Driver.Tests/UnifiedTestOperations/UnifiedTargetedFailPointOperation.cs index 97463a719b0..890296da610 100644 --- a/tests/MongoDB.Driver.Tests/UnifiedTestOperations/UnifiedTargetedFailPointOperation.cs +++ b/tests/MongoDB.Driver.Tests/UnifiedTestOperations/UnifiedTargetedFailPointOperation.cs @@ -52,11 +52,11 @@ public void Execute() _entityMap.RegisterForDispose(client); var cluster = client.GetClusterInternal(); - var server = cluster.SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(pinnedServer)); + var (server, roundTripTime) = cluster.SelectServer(OperationContext.NoTimeout, new EndPointServerSelector(pinnedServer)); var session = NoCoreSession.NewHandle(); - var failPoint = FailPoint.Configure(server, session, _failPointCommand, withAsync: _async); + var failPoint = FailPoint.Configure(server, roundTripTime, session, _failPointCommand, withAsync: _async); _entityMap.RegisterForDispose(failPoint); } }