diff --git a/src/client/executor.rs b/src/client/executor.rs index 54a5ae17b..b171b8d8f 100644 --- a/src/client/executor.rs +++ b/src/client/executor.rs @@ -36,8 +36,12 @@ impl Client { /// /// Server selection will performed using the criteria specified on the operation, if any, and /// an implicit session will be created if the operation and write concern are compatible with - /// sessions. - pub(crate) async fn execute_operation(&self, op: T) -> Result { + /// sessions and an explicit session is not provided. + pub(crate) async fn execute_operation( + &self, + op: T, + session: impl Into>, + ) -> Result { // TODO RUST-9: allow unacknowledged write concerns if !op.is_acknowledged() { return Err(ErrorKind::ArgumentError { @@ -45,9 +49,14 @@ impl Client { } .into()); } - let mut implicit_session = self.start_implicit_session(&op).await?; - self.execute_operation_with_retry(op, implicit_session.as_mut()) - .await + match session.into() { + Some(session) => self.execute_operation_with_retry(op, Some(session)).await, + None => { + let mut implicit_session = self.start_implicit_session(&op).await?; + self.execute_operation_with_retry(op, implicit_session.as_mut()) + .await + } + } } /// Execute the given operation, returning the implicit session created for it if one was. @@ -63,16 +72,6 @@ impl Client { .map(|result| (result, implicit_session)) } - /// Execute the given operation with the given session. - /// Server selection will performed using the criteria specified on the operation, if any. - pub(crate) async fn execute_operation_with_session( - &self, - op: T, - session: &mut ClientSession, - ) -> Result { - self.execute_operation_with_retry(op, Some(session)).await - } - /// Selects a server and executes the given operation on it, optionally using a provided /// session. Retries the operation upon failure if retryability is supported. 
async fn execute_operation_with_retry( @@ -324,7 +323,7 @@ impl Client { SessionSupportStatus::Supported { logical_session_timeout, } if op.supports_sessions() && op.is_acknowledged() => Ok(Some( - self.start_implicit_session_with_timeout(logical_session_timeout) + self.start_session_with_timeout(logical_session_timeout, None, true) .await, )), _ => Ok(None), @@ -334,7 +333,7 @@ impl Client { /// Gets whether the topology supports sessions, and if so, returns the topology's logical /// session timeout. If it has yet to be determined if the topology supports sessions, this /// method will perform a server selection that will force that determination to be made. - async fn get_session_support_status(&self) -> Result { + pub(crate) async fn get_session_support_status(&self) -> Result { let initial_status = self.inner.topology.session_support_status().await; // Need to guarantee that we're connected to at least one server that can determine if diff --git a/src/client/mod.rs b/src/client/mod.rs index f25bd0f4f..7ab59c705 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,7 +1,7 @@ pub mod auth; mod executor; pub mod options; -mod session; +pub mod session; use std::{sync::Arc, time::Duration}; @@ -23,10 +23,12 @@ use crate::{ ListDatabasesOptions, ReadPreference, SelectionCriteria, + SessionOptions, }, sdam::{SelectedServer, SessionSupportStatus, Topology}, + ClientSession, }; -pub(crate) use session::{ClientSession, ClusterTime, SESSIONS_UNSUPPORTED_COMMANDS}; +pub(crate) use session::{ClusterTime, SESSIONS_UNSUPPORTED_COMMANDS}; use session::{ServerSession, ServerSessionPool}; const DEFAULT_SERVER_SELECTION_TIMEOUT: Duration = Duration::from_secs(30); @@ -161,7 +163,7 @@ impl Client { options: impl Into>, ) -> Result> { let op = ListDatabases::new(filter.into(), false, options.into()); - self.execute_operation(op).await + self.execute_operation(op, None).await } /// Gets the names of the databases present in the cluster the Client is connected to. 
@@ -171,7 +173,7 @@ impl Client { options: impl Into>, ) -> Result> { let op = ListDatabases::new(filter.into(), true, options.into()); - match self.execute_operation(op).await { + match self.execute_operation(op, None).await { Ok(databases) => databases .into_iter() .map(|doc| { @@ -189,6 +191,18 @@ impl Client { } } + /// Starts a new `ClientSession`. + pub async fn start_session(&self, options: Option) -> Result { + match self.get_session_support_status().await? { + SessionSupportStatus::Supported { + logical_session_timeout, + } => Ok(self + .start_session_with_timeout(logical_session_timeout, options, false) + .await), + _ => Err(ErrorKind::SessionsNotSupported.into()), + } + } + /// Check in a server session to the server session pool. /// If the session is expired or dirty, or the topology no longer supports sessions, the session /// will be discarded. @@ -210,16 +224,20 @@ impl Client { /// This method will attempt to re-use server sessions from the pool which are not about to /// expire according to the provided logical session timeout. If no such sessions are /// available, a new one will be created. - pub(crate) async fn start_implicit_session_with_timeout( + pub(crate) async fn start_session_with_timeout( &self, logical_session_timeout: Duration, + options: Option, + is_implicit: bool, ) -> ClientSession { - ClientSession::new_implicit( + ClientSession::new( self.inner .session_pool .check_out(logical_session_timeout) .await, self.clone(), + options, + is_implicit, ) } diff --git a/src/client/options/mod.rs b/src/client/options/mod.rs index b79db6b66..26e2b10a5 100644 --- a/src/client/options/mod.rs +++ b/src/client/options/mod.rs @@ -2057,3 +2057,8 @@ mod tests { ); } } + +/// Contains the options that can be used to create a new +/// [`ClientSession`](../struct.ClientSession.html). 
+#[derive(Clone, Debug, Deserialize, TypedBuilder)] +pub struct SessionOptions {} diff --git a/src/client/session/cluster_time.rs b/src/client/session/cluster_time.rs index 980bba410..899a886fa 100644 --- a/src/client/session/cluster_time.rs +++ b/src/client/session/cluster_time.rs @@ -10,7 +10,7 @@ use crate::bson::{Document, Timestamp}; #[derive(Debug, Deserialize, Clone, Serialize, Derivative)] #[derivative(PartialEq, Eq)] #[serde(rename_all = "camelCase")] -pub(crate) struct ClusterTime { +pub struct ClusterTime { cluster_time: Timestamp, #[derivative(PartialEq = "ignore")] diff --git a/src/client/session/mod.rs b/src/client/session/mod.rs index 141475b71..d2a3ab971 100644 --- a/src/client/session/mod.rs +++ b/src/client/session/mod.rs @@ -13,6 +13,7 @@ use uuid::Uuid; use crate::{ bson::{doc, spec::BinarySubtype, Binary, Bson, Document}, + options::SessionOptions, Client, RUNTIME, }; @@ -28,29 +29,44 @@ lazy_static! { }; } -/// Session to be used with client operations. This acts as a handle to a server session. -/// This keeps the details of how server sessions are pooled opaque to users. +/// A MongoDB client session. This struct represents a logical session used for ordering sequential +/// operations. To create a `ClientSession`, call `start_session` on a `Client`. +/// +/// `ClientSession` instances are not thread safe or fork safe. They can only be used by one thread +/// or process at a time. #[derive(Debug)] -pub(crate) struct ClientSession { +pub struct ClientSession { cluster_time: Option, server_session: ServerSession, client: Client, is_implicit: bool, + options: Option, } impl ClientSession { /// Creates a new `ClientSession` wrapping the provided server session. 
- pub(crate) fn new_implicit(server_session: ServerSession, client: Client) -> Self { + pub(crate) fn new( + server_session: ServerSession, + client: Client, + options: Option, + is_implicit: bool, + ) -> Self { Self { client, server_session, cluster_time: None, - is_implicit: true, + is_implicit, + options, } } + /// The client used to create this session. + pub fn client(&self) -> Client { + self.client.clone() + } + /// The id of this session. - pub(crate) fn id(&self) -> &Document { + pub fn id(&self) -> &Document { &self.server_session.id } @@ -61,13 +77,18 @@ impl ClientSession { /// The highest seen cluster time this session has seen so far. /// This will be `None` if this session has not been used in an operation yet. - pub(crate) fn cluster_time(&self) -> Option<&ClusterTime> { + pub fn cluster_time(&self) -> Option<&ClusterTime> { self.cluster_time.as_ref() } + /// The options used to create this session. + pub fn options(&self) -> Option<&SessionOptions> { + self.options.as_ref() + } + /// Set the cluster time to the provided one if it is greater than this session's highest seen /// cluster time or if this session's cluster time is `None`. - pub(crate) fn advance_cluster_time(&mut self, to: &ClusterTime) { + pub fn advance_cluster_time(&mut self, to: &ClusterTime) { if self.cluster_time().map(|ct| ct < to).unwrap_or(true) { self.cluster_time = Some(to.clone()); } @@ -89,6 +110,12 @@ impl ClientSession { self.server_session.txn_number += 1; self.server_session.txn_number } + + /// Whether this session is dirty. + #[cfg(test)] + pub(crate) fn is_dirty(&self) -> bool { + self.server_session.dirty + } } impl Drop for ClientSession { @@ -109,7 +136,7 @@ impl Drop for ClientSession { /// Client side abstraction of a server session. These are pooled and may be associated with /// multiple `ClientSession`s over the course of their lifetime. 
-#[derive(Debug)] +#[derive(Clone, Debug)] pub(crate) struct ServerSession { /// The id of the server session to which this corresponds. id: Document, diff --git a/src/client/session/test.rs b/src/client/session/test.rs index 0caadd3d7..c9437f9b4 100644 --- a/src/client/session/test.rs +++ b/src/client/session/test.rs @@ -202,10 +202,8 @@ async fn pool_is_lifo() { return; } - let timeout = Duration::from_secs(60 * 60); - - let a = client.start_implicit_session_with_timeout(timeout).await; - let b = client.start_implicit_session_with_timeout(timeout).await; + let a = client.start_session(None).await.unwrap(); + let b = client.start_session(None).await.unwrap(); let a_id = a.id().clone(); let b_id = b.id().clone(); @@ -218,10 +216,10 @@ async fn pool_is_lifo() { drop(b); RUNTIME.delay_for(Duration::from_millis(250)).await; - let s1 = client.start_implicit_session_with_timeout(timeout).await; + let s1 = client.start_session(None).await.unwrap(); assert_eq!(s1.id(), &b_id); - let s2 = client.start_implicit_session_with_timeout(timeout).await; + let s2 = client.start_session(None).await.unwrap(); assert_eq!(s2.id(), &a_id); } diff --git a/src/cmap/conn/command.rs b/src/cmap/conn/command.rs index 301fc2d59..769326f86 100644 --- a/src/cmap/conn/command.rs +++ b/src/cmap/conn/command.rs @@ -4,10 +4,11 @@ use super::wire::Message; use crate::{ bson::{Bson, Document}, bson_util, - client::{options::ServerApi, ClientSession, ClusterTime}, + client::{options::ServerApi, ClusterTime}, error::{CommandError, ErrorKind, Result}, options::StreamAddress, selection_criteria::ReadPreference, + ClientSession, }; /// `Command` is a driver side abstraction of a server command containing all the information diff --git a/src/coll/mod.rs b/src/coll/mod.rs index 15418b253..1862389b2 100644 --- a/src/coll/mod.rs +++ b/src/coll/mod.rs @@ -15,6 +15,7 @@ use self::options::*; use crate::{ bson::{doc, ser, to_document, Bson, Document}, bson_util, + client::session::ClientSession, 
concern::{ReadConcern, WriteConcern}, error::{convert_bulk_errors, BulkWriteError, BulkWriteFailure, ErrorKind, Result}, operation::{ @@ -34,6 +35,7 @@ use crate::{ Client, Cursor, Database, + SessionCursor, }; /// Maximum size in bytes of an insert batch. @@ -178,13 +180,31 @@ where self.inner.write_concern.as_ref() } - /// Drops the collection, deleting all data and indexes stored in it. - pub async fn drop(&self, options: impl Into>) -> Result<()> { + async fn drop_common( + &self, + options: impl Into>, + session: impl Into>, + ) -> Result<()> { let mut options = options.into(); resolve_options!(self, options, [write_concern]); let drop = DropCollection::new(self.namespace(), options); - self.client().execute_operation(drop).await + self.client().execute_operation(drop, session.into()).await + } + + /// Drops the collection, deleting all data and indexes stored in it. + pub async fn drop(&self, options: impl Into>) -> Result<()> { + self.drop_common(options, None).await + } + + /// Drops the collection, deleting all data and indexes stored in it using the provided + /// `ClientSession`. + pub async fn drop_with_session( + &self, + options: impl Into>, + session: &mut ClientSession, + ) -> Result<()> { + self.drop_common(options, session).await } /// Runs an aggregation operation. @@ -211,6 +231,31 @@ where .map(|(spec, session)| Cursor::new(client.clone(), spec, session)) } + /// Runs an aggregation operation using the provided `ClientSession`. + /// + /// See the documentation [here](https://docs.mongodb.com/manual/aggregation/) for more + /// information on aggregations. 
+ pub async fn aggregate_with_session( + &self, + pipeline: impl IntoIterator, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + let mut options = options.into(); + resolve_options!( + self, + options, + [read_concern, write_concern, selection_criteria] + ); + + let aggregate = Aggregate::new(self.namespace(), pipeline, options); + let client = self.client(); + client + .execute_operation(aggregate, session) + .await + .map(|result| SessionCursor::new(client.clone(), result)) + } + /// Estimates the number of documents in the collection using collection metadata. pub async fn estimated_document_count( &self, @@ -220,7 +265,19 @@ where resolve_options!(self, options, [read_concern, selection_criteria]); let op = Count::new(self.namespace(), options); - self.client().execute_operation(op).await + self.client().execute_operation(op, None).await + } + + async fn count_documents_common( + &self, + filter: impl Into>, + options: impl Into>, + session: impl Into>, + ) -> Result { + let options = options.into(); + let filter = filter.into(); + let op = CountDocuments::new(self.namespace(), filter, options); + self.client().execute_operation(op, session).await + } /// Gets the number of documents matching `filter`. @@ -232,10 +289,33 @@ where filter: impl Into>, options: impl Into>, ) -> Result { - let options = options.into(); - let filter = filter.into(); - let op = CountDocuments::new(self.namespace(), filter, options); - self.client().execute_operation(op).await + self.count_documents_common(filter, options, None).await + } + + /// Gets the number of documents matching `filter` using the provided `ClientSession`. + /// + /// Note that using [`Collection::estimated_document_count`](#method.estimated_document_count) + /// is recommended instead of this method in most cases. 
+ pub async fn count_documents_with_session( + &self, + filter: impl Into>, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + self.count_documents_common(filter, options, session).await + } + + async fn delete_many_common( + &self, + query: Document, + options: impl Into>, + session: impl Into>, + ) -> Result { + let mut options = options.into(); + resolve_options!(self, options, [write_concern]); + + let delete = Delete::new(self.namespace(), query, None, options); + self.client().execute_operation(delete, session).await } /// Deletes all documents stored in the collection matching `query`. @@ -243,12 +323,32 @@ where &self, query: Document, options: impl Into>, + ) -> Result { + self.delete_many_common(query, options, None).await + } + + /// Deletes all documents stored in the collection matching `query` using the provided + /// `ClientSession`. + pub async fn delete_many_with_session( + &self, + query: Document, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + self.delete_many_common(query, options, session).await + } + + async fn delete_one_common( + &self, + query: Document, + options: impl Into>, + session: impl Into>, ) -> Result { let mut options = options.into(); resolve_options!(self, options, [write_concern]); - let delete = Delete::new(self.namespace(), query, None, options); - self.client().execute_operation(delete).await + let delete = Delete::new(self.namespace(), query, Some(1), options); + self.client().execute_operation(delete, session).await } /// Deletes up to one document found matching `query`. 
@@ -262,19 +362,30 @@ where query: Document, options: impl Into>, ) -> Result { - let mut options = options.into(); - resolve_options!(self, options, [write_concern]); + self.delete_one_common(query, options, None).await + } - let delete = Delete::new(self.namespace(), query, Some(1), options); - self.client().execute_operation(delete).await + /// Deletes up to one document found matching `query` using the provided `ClientSession`. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. + pub async fn delete_one_with_session( + &self, + query: Document, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + self.delete_one_common(query, options, session).await } - /// Finds the distinct values of the field specified by `field_name` across the collection. - pub async fn distinct( + async fn distinct_common( &self, field_name: &str, filter: impl Into>, options: impl Into>, + session: impl Into>, ) -> Result> { let mut options = options.into(); resolve_options!(self, options, [read_concern, selection_criteria]); @@ -285,7 +396,31 @@ where filter.into(), options, ); - self.client().execute_operation(op).await + self.client().execute_operation(op, session).await + } + + /// Finds the distinct values of the field specified by `field_name` across the collection. + pub async fn distinct( + &self, + field_name: &str, + filter: impl Into>, + options: impl Into>, + ) -> Result> { + self.distinct_common(field_name, filter, options, None) + .await + } + + /// Finds the distinct values of the field specified by `field_name` across the collection using + /// the provided `ClientSession`. 
+ pub async fn distinct_with_session( + &self, + field_name: &str, + filter: impl Into>, + options: impl Into>, + session: &mut ClientSession, + ) -> Result> { + self.distinct_common(field_name, filter, options, session) + .await } /// Finds the documents in the collection matching `filter`. @@ -303,21 +438,68 @@ where .map(|(result, session)| Cursor::new(client.clone(), result, session)) } + /// Finds the documents in the collection matching `filter` using the provided `ClientSession`. + pub async fn find_with_session( + &self, + filter: impl Into>, + options: impl Into>, + session: &mut ClientSession, + ) -> Result> { + let find = Find::new(self.namespace(), filter.into(), options.into()); + let client = self.client(); + + client + .execute_operation(find, session) + .await + .map(|result| SessionCursor::new(client.clone(), result)) + } + /// Finds a single document in the collection matching `filter`. pub async fn find_one( &self, filter: impl Into>, options: impl Into>, ) -> Result> { - let mut options: FindOptions = options + let options: FindOptions = options .into() .map(Into::into) .unwrap_or_else(Default::default); - options.limit = Some(-1); let mut cursor = self.find(filter, Some(options)).await?; cursor.next().await.transpose() } + /// Finds a single document in the collection matching `filter` using the provided + /// `ClientSession`. 
+ pub async fn find_one_with_session( + &self, + filter: impl Into>, + options: impl Into>, + session: &mut ClientSession, + ) -> Result> { + let options: FindOptions = options + .into() + .map(Into::into) + .unwrap_or_else(Default::default); + let mut cursor = self + .find_with_session(filter, Some(options), session) + .await?; + let mut cursor = cursor.with_session(session); + cursor.next().await.transpose() + } + + async fn find_one_and_delete_common( + &self, + filter: Document, + options: impl Into>, + session: impl Into>, + ) -> Result> { + let mut options = options.into(); + resolve_options!(self, options, [write_concern]); + + let op = FindAndModify::::with_delete(self.namespace(), filter, options); + self.client().execute_operation(op, session).await + } + /// Atomically finds up to one document in the collection matching `filter` and deletes it. /// /// This operation will retry once upon failure if the connection and encountered error support @@ -329,11 +511,40 @@ where filter: Document, options: impl Into>, ) -> Result> { + self.find_one_and_delete_common(filter, options, None).await + } + + /// Atomically finds up to one document in the collection matching `filter` and deletes it using + /// the provided `ClientSession`. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. 
+ pub async fn find_one_and_delete_with_session( + &self, + filter: Document, + options: impl Into>, + session: &mut ClientSession, + ) -> Result> { + self.find_one_and_delete_common(filter, options, session) + .await + } + + async fn find_one_and_replace_common( + &self, + filter: Document, + replacement: T, + options: impl Into>, + session: impl Into>, + ) -> Result> { + let replacement = to_document(&replacement)?; + let mut options = options.into(); resolve_options!(self, options, [write_concern]); - let op = FindAndModify::::with_delete(self.namespace(), filter, options); - self.client().execute_operation(op).await + let op = FindAndModify::::with_replace(self.namespace(), filter, replacement, options)?; + self.client().execute_operation(op, session).await } /// Atomically finds up to one document in the collection matching `filter` and replaces it with @@ -349,13 +560,41 @@ where replacement: T, options: impl Into>, ) -> Result> { - let replacement = to_document(&replacement)?; + self.find_one_and_replace_common(filter, replacement, options, None) + .await + } + /// Atomically finds up to one document in the collection matching `filter` and replaces it with + /// `replacement` using the provided `ClientSession`. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. 
+ pub async fn find_one_and_replace_with_session( + &self, + filter: Document, + replacement: T, + options: impl Into>, + session: &mut ClientSession, + ) -> Result> { + self.find_one_and_replace_common(filter, replacement, options, session) + .await + } + + async fn find_one_and_update_common( + &self, + filter: Document, + update: impl Into, + options: impl Into>, + session: impl Into>, + ) -> Result> { + let update = update.into(); let mut options = options.into(); resolve_options!(self, options, [write_concern]); - let op = FindAndModify::::with_replace(self.namespace(), filter, replacement, options)?; - self.client().execute_operation(op).await + let op = FindAndModify::::with_update(self.namespace(), filter, update, options)?; + self.client().execute_operation(op, session).await } /// Atomically finds up to one document in the collection matching `filter` and updates it. @@ -373,24 +612,35 @@ where update: impl Into, options: impl Into>, ) -> Result> { - let update = update.into(); - let mut options = options.into(); - resolve_options!(self, options, [write_concern]); - - let op = FindAndModify::::with_update(self.namespace(), filter, update, options)?; - self.client().execute_operation(op).await + self.find_one_and_update_common(filter, update, options, None) + .await } - /// Inserts the data in `docs` into the collection. + /// Atomically finds up to one document in the collection matching `filter` and updates it using + /// the provided `ClientSession`. Both `Document` and `Vec` implement + /// `Into`, so either can be passed in place of constructing the enum + /// case. Note: pipeline updates are only supported in MongoDB 4.2+. /// /// This operation will retry once upon failure if the connection and encountered error support /// retryability. See the documentation /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on /// retryable writes. 
- pub async fn insert_many( + pub async fn find_one_and_update_with_session( + &self, + filter: Document, + update: impl Into, + options: impl Into>, + session: &mut ClientSession, + ) -> Result> { + self.find_one_and_update_common(filter, update, options, session) + .await + } + + async fn insert_many_common( &self, docs: impl IntoIterator, options: impl Into>, + mut session: Option<&mut ClientSession>, ) -> Result { let docs: ser::Result> = docs .into_iter() @@ -425,7 +675,11 @@ where n_attempted += current_batch_size; let insert = Insert::new(self.namespace(), current_batch, options.clone()); - match self.client().execute_operation(insert).await { + match self + .client() + .execute_operation(insert, session.as_deref_mut()) + .await + { Ok(result) => { if cumulative_failure.is_none() { let cumulative_result = @@ -472,16 +726,40 @@ where } } - /// Inserts `doc` into the collection. + /// Inserts the data in `docs` into the collection. /// /// This operation will retry once upon failure if the connection and encountered error support /// retryability. See the documentation /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on /// retryable writes. - pub async fn insert_one( + pub async fn insert_many( + &self, + docs: impl IntoIterator, + options: impl Into>, + ) -> Result { + self.insert_many_common(docs, options, None).await + } + + /// Inserts the data in `docs` into the collection using the provided `ClientSession`. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. 
+ pub async fn insert_many_with_session( + &self, + docs: impl IntoIterator, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + self.insert_many_common(docs, options, Some(session)).await + } + + async fn insert_one_common( &self, doc: T, options: impl Into>, + session: impl Into>, ) -> Result { let doc = to_document(&doc)?; @@ -494,23 +772,47 @@ where options.map(InsertManyOptions::from_insert_one_options), ); self.client() - .execute_operation(insert) + .execute_operation(insert, session) .await .map(InsertOneResult::from_insert_many_result) .map_err(convert_bulk_errors) } - /// Replaces up to one document matching `query` in the collection with `replacement`. + /// Inserts `doc` into the collection. /// /// This operation will retry once upon failure if the connection and encountered error support /// retryability. See the documentation /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on /// retryable writes. - pub async fn replace_one( + pub async fn insert_one( + &self, + doc: T, + options: impl Into>, + ) -> Result { + self.insert_one_common(doc, options, None).await + } + + /// Inserts `doc` into the collection using the provided `ClientSession`. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. 
+ pub async fn insert_one_with_session( + &self, + doc: T, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + self.insert_one_common(doc, options, session).await + } + + async fn replace_one_common( &self, query: Document, replacement: T, options: impl Into>, + session: impl Into>, ) -> Result { let replacement = to_document(&replacement)?; bson_util::replacement_document_check(&replacement)?; @@ -525,20 +827,49 @@ where false, options.map(UpdateOptions::from_replace_options), ); - self.client().execute_operation(update).await + self.client().execute_operation(update, session).await } - /// Updates all documents matching `query` in the collection. + /// Replaces up to one document matching `query` in the collection with `replacement`. /// - /// Both `Document` and `Vec` implement `Into`, so either can be - /// passed in place of constructing the enum case. Note: pipeline updates are only supported - /// in MongoDB 4.2+. See the official MongoDB - /// [documentation](https://docs.mongodb.com/manual/reference/command/update/#behavior) for more information on specifying updates. - pub async fn update_many( + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. + pub async fn replace_one( + &self, + query: Document, + replacement: T, + options: impl Into>, + ) -> Result { + self.replace_one_common(query, replacement, options, None) + .await + } + + /// Replaces up to one document matching `query` in the collection with `replacement` using the + /// provided `ClientSession`. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. 
+ pub async fn replace_one_with_session( + &self, + query: Document, + replacement: T, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + self.replace_one_common(query, replacement, options, session) + .await + } + + async fn update_many_common( &self, query: Document, update: impl Into, options: impl Into>, + session: impl Into>, ) -> Result { let update = update.into(); let mut options = options.into(); @@ -550,7 +881,53 @@ where resolve_options!(self, options, [write_concern]); let update = Update::new(self.namespace(), query, update, true, options); - self.client().execute_operation(update).await + self.client().execute_operation(update, session).await + } + + /// Updates all documents matching `query` in the collection. + /// + /// Both `Document` and `Vec` implement `Into`, so either can be + /// passed in place of constructing the enum case. Note: pipeline updates are only supported + /// in MongoDB 4.2+. See the official MongoDB + /// [documentation](https://docs.mongodb.com/manual/reference/command/update/#behavior) for more information on specifying updates. + pub async fn update_many( + &self, + query: Document, + update: impl Into, + options: impl Into>, + ) -> Result { + self.update_many_common(query, update, options, None).await + } + + /// Updates all documents matching `query` in the collection using the provided `ClientSession`. + /// + /// Both `Document` and `Vec` implement `Into`, so either can be + /// passed in place of constructing the enum case. Note: pipeline updates are only supported + /// in MongoDB 4.2+. See the official MongoDB + /// [documentation](https://docs.mongodb.com/manual/reference/command/update/#behavior) for more information on specifying updates. 
+ pub async fn update_many_with_session( + &self, + query: Document, + update: impl Into, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + self.update_many_common(query, update, options, session) + .await + } + + async fn update_one_common( + &self, + query: Document, + update: impl Into, + options: impl Into>, + session: impl Into>, + ) -> Result { + let mut options = options.into(); + resolve_options!(self, options, [write_concern]); + + let update = Update::new(self.namespace(), query, update.into(), false, options); + self.client().execute_operation(update, session).await } /// Updates up to one document matching `query` in the collection. @@ -570,11 +947,30 @@ where update: impl Into, options: impl Into>, ) -> Result { - let mut options = options.into(); - resolve_options!(self, options, [write_concern]); + self.update_one_common(query, update, options, None).await + } - let update = Update::new(self.namespace(), query, update.into(), false, options); - self.client().execute_operation(update).await + /// Updates up to one document matching `query` in the collection using the provided + /// `ClientSession`. + /// + /// Both `Document` and `Vec` implement `Into`, so either can be + /// passed in place of constructing the enum case. Note: pipeline updates are only supported + /// in MongoDB 4.2+. See the official MongoDB + /// [documentation](https://docs.mongodb.com/manual/reference/command/update/#behavior) for more information on specifying updates. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. 
+ pub async fn update_one_with_session( + &self, + query: Document, + update: impl Into, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + self.update_one_common(query, update, options, session) + .await } /// Kill the server side cursor that id corresponds to. diff --git a/src/coll/options.rs b/src/coll/options.rs index 4e2300bea..d90bde2fb 100644 --- a/src/coll/options.rs +++ b/src/coll/options.rs @@ -800,7 +800,7 @@ impl From for FindOptions { skip: options.skip, batch_size: None, cursor_type: None, - limit: None, + limit: Some(-1), max_await_time: None, no_cursor_timeout: None, sort: options.sort, diff --git a/src/cursor/mod.rs b/src/cursor/mod.rs index 10ce58019..d019d0b6d 100644 --- a/src/cursor/mod.rs +++ b/src/cursor/mod.rs @@ -1,7 +1,5 @@ mod common; -// TODO: RUST-52 use this -#[allow(dead_code)] -mod session; +pub(crate) mod session; use std::{ pin::Pin, @@ -13,11 +11,11 @@ use serde::de::DeserializeOwned; use crate::{ bson::{from_document, Document}, - client::ClientSession, error::{Error, Result}, operation::GetMore, results::GetMoreResult, Client, + ClientSession, RUNTIME, }; pub(crate) use common::{CursorInformation, CursorSpecification}; @@ -211,14 +209,8 @@ impl GetMoreProvider for ImplicitSessionGetMoreProvider { Self::Idle(mut session) => { let future = Box::pin(async move { let get_more = GetMore::new(info); - let get_more_result = match session { - Some(ref mut session) => { - client - .execute_operation_with_session(get_more, session) - .await - } - None => client.execute_operation(get_more).await, - }; + let get_more_result = + client.execute_operation(get_more, session.as_mut()).await; ImplicitSessionGetMoreResult { get_more_result, session, diff --git a/src/cursor/session.rs b/src/cursor/session.rs index bc70c84d1..eddf170ab 100644 --- a/src/cursor/session.rs +++ b/src/cursor/session.rs @@ -1,30 +1,61 @@ -use std::collections::VecDeque; +use std::{ + collections::VecDeque, + pin::Pin, + task::{Context, Poll}, +}; -use 
futures::future::BoxFuture; +use futures::{future::BoxFuture, Stream}; +use serde::de::DeserializeOwned; use super::common::{CursorInformation, GenericCursor, GetMoreProvider, GetMoreProviderResult}; use crate::{ - bson::Document, - client::ClientSession, + bson::{from_document, Document}, cursor::CursorSpecification, error::{Error, Result}, operation::GetMore, results::GetMoreResult, Client, + ClientSession, RUNTIME, }; -/// A cursor that was started with a session and must be iterated using one. +/// A `SessionCursor` is a cursor that was created with a `ClientSession` and must be iterated using +/// one. To iterate, retrieve a `SessionCursorHandle` using `SessionCursor::with_session`: +/// +/// ```rust +/// # use futures::stream::StreamExt; +/// # use mongodb::{Client, error::Result, ClientSession, SessionCursor}; +/// # +/// # async fn do_stuff() -> Result<()> { +/// # let client = Client::with_uri_str("mongodb://example.com").await?; +/// # let mut session = client.start_session(None).await?; +/// # let coll = client.database("foo").collection("bar"); +/// # let mut cursor = coll.find_with_session(None, None, &mut session).await?; +/// # +/// while let Some(doc) = cursor.with_session(&mut session).next().await { +/// println!("{}", doc?) 
+/// } +/// # +/// # Ok(()) +/// # } +/// ``` #[derive(Debug)] -pub(crate) struct SessionCursor { +pub struct SessionCursor +where + T: DeserializeOwned + Unpin, +{ exhausted: bool, client: Client, info: CursorInformation, buffer: VecDeque, + _phantom: std::marker::PhantomData, } -impl SessionCursor { - fn new(client: Client, spec: CursorSpecification) -> Self { +impl SessionCursor +where + T: DeserializeOwned + Unpin, +{ + pub(crate) fn new(client: Client, spec: CursorSpecification) -> Self { let exhausted = spec.id() == 0; Self { @@ -32,13 +63,16 @@ impl SessionCursor { client, info: spec.info, buffer: spec.initial_buffer, + _phantom: Default::default(), } } - fn with_session<'session>( + /// Retrieves a `SessionCursorHandle` to iterate this cursor. The session provided must be the + /// same session used to create the cursor. + pub fn with_session<'session>( &mut self, session: &'session mut ClientSession, - ) -> SessionCursorHandle<'_, 'session> { + ) -> SessionCursorHandle<'_, 'session, T> { let get_more_provider = ExplicitSessionGetMoreProvider::new(session); // Pass the buffer into this cursor handle for iteration. @@ -58,7 +92,10 @@ impl SessionCursor { } } -impl Drop for SessionCursor { +impl Drop for SessionCursor +where + T: DeserializeOwned + Unpin, +{ fn drop(&mut self) { if self.exhausted { return; @@ -79,15 +116,38 @@ impl Drop for SessionCursor { type ExplicitSessionCursor<'session> = GenericCursor>; /// A handle that borrows a `ClientSession` temporarily for executing getMores or iterating through -/// the current buffer. +/// the current buffer of a `SessionCursor`. /// -/// This updates the buffer of the parent cursor when dropped. -struct SessionCursorHandle<'cursor, 'session> { - session_cursor: &'cursor mut SessionCursor, +/// This updates the buffer of the parent `SessionCursor` when dropped. 
+pub struct SessionCursorHandle<'cursor, 'session, T = Document> +where + T: DeserializeOwned + Unpin, +{ + session_cursor: &'cursor mut SessionCursor, generic_cursor: ExplicitSessionCursor<'session>, } -impl<'cursor, 'session> Drop for SessionCursorHandle<'cursor, 'session> { +impl<'cursor, 'session, T> Stream for SessionCursorHandle<'cursor, 'session, T> +where + T: DeserializeOwned + Unpin, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let next = Pin::new(&mut self.generic_cursor).poll_next(cx); + match next { + Poll::Ready(opt) => Poll::Ready( + opt.map(|result| result.and_then(|doc| from_document(doc).map_err(Into::into))), + ), + Poll::Pending => Poll::Pending, + } + } +} + +impl<'cursor, 'session, T> Drop for SessionCursorHandle<'cursor, 'session, T> +where + T: DeserializeOwned + Unpin, +{ fn drop(&mut self) { // Update the parent cursor's state based on any iteration performed on this handle. self.session_cursor.buffer = self.generic_cursor.take_buffer(); @@ -137,7 +197,7 @@ impl<'session> GetMoreProvider for ExplicitSessionGetMoreProvider<'session> { let future = Box::pin(async move { let get_more = GetMore::new(info); let get_more_result = client - .execute_operation_with_session(get_more, session.reference) + .execute_operation(get_more, Some(&mut *session.reference)) .await; ExecutionResult { get_more_result, diff --git a/src/db/mod.rs b/src/db/mod.rs index c02686cf5..b768be653 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -9,7 +9,7 @@ use crate::{ bson::{Bson, Document}, concern::{ReadConcern, WriteConcern}, cursor::Cursor, - error::{ErrorKind, Result}, + error::{Error, ErrorKind, Result}, operation::{Aggregate, Create, DropDatabase, ListCollections, RunCommand}, options::{ AggregateOptions, @@ -21,8 +21,10 @@ use crate::{ }, selection_criteria::SelectionCriteria, Client, + ClientSession, Collection, Namespace, + SessionCursor, }; /// `Database` is the client-side abstraction of a MongoDB 
database. It can be used to perform @@ -176,13 +178,33 @@ impl Database { Collection::new(self.clone(), name, Some(options)) } - /// Drops the database, deleting all data, collections, and indexes stored in it. - pub async fn drop(&self, options: impl Into>) -> Result<()> { + async fn drop_common( + &self, + options: impl Into>, + session: impl Into>, + ) -> Result<()> { let mut options = options.into(); resolve_options!(self, options, [write_concern]); let drop_database = DropDatabase::new(self.name().to_string(), options); - self.client().execute_operation(drop_database).await + self.client() + .execute_operation(drop_database, session) + .await + } + + /// Drops the database, deleting all data, collections, and indexes stored in it. + pub async fn drop(&self, options: impl Into>) -> Result<()> { + self.drop_common(options, None).await + } + + /// Drops the database, deleting all data, collections, and indexes stored in it using the + /// provided `ClientSession`. + pub async fn drop_with_session( + &self, + options: impl Into>, + session: &mut ClientSession, + ) -> Result<()> { + self.drop_common(options, session).await } /// Gets information about each of the collections in the database. The cursor will yield a @@ -204,19 +226,31 @@ impl Database { .map(|(spec, session)| Cursor::new(self.client().clone(), spec, session)) } - /// Gets the names of the collections in the database. - pub async fn list_collection_names( + /// Gets information about each of the collections in the database using the provided + /// `ClientSession`. The cursor will yield a document pertaining to each collection in the + /// database. 
+ pub async fn list_collections_with_session( &self, filter: impl Into>, - ) -> Result> { - let list_collections = - ListCollections::new(self.name().to_string(), filter.into(), true, None); - let cursor: Cursor = self - .client() - .execute_cursor_operation(list_collections) + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + let list_collections = ListCollections::new( + self.name().to_string(), + filter.into(), + false, + options.into(), + ); + self.client() + .execute_operation(list_collections, session) .await - .map(|(spec, session)| Cursor::new(self.client().clone(), spec, session))?; + .map(|spec| SessionCursor::new(self.client().clone(), spec)) + } + async fn list_collection_names_common( + &self, + cursor: impl TryStreamExt, + ) -> Result> { cursor .and_then(|doc| match doc.get("name").and_then(Bson::as_str) { Some(name) => futures::future::ok(name.into()), @@ -232,14 +266,45 @@ impl Database { .await } - /// Creates a new collection in the database with the given `name` and `options`. - /// - /// Note that MongoDB creates collections implicitly when data is inserted, so this method is - /// not needed if no special options are required. - pub async fn create_collection( + /// Gets the names of the collections in the database. + pub async fn list_collection_names( + &self, + filter: impl Into>, + ) -> Result> { + let list_collections = + ListCollections::new(self.name().to_string(), filter.into(), true, None); + let cursor: Cursor = self + .client() + .execute_cursor_operation(list_collections) + .await + .map(|(spec, session)| Cursor::new(self.client().clone(), spec, session))?; + + self.list_collection_names_common(cursor).await + } + + /// Gets the names of the collections in the database using the provided `ClientSession`. 
+ pub async fn list_collection_names_with_session( + &self, + filter: impl Into>, + session: &mut ClientSession, + ) -> Result> { + let list_collections = + ListCollections::new(self.name().to_string(), filter.into(), true, None); + let mut cursor: SessionCursor = self + .client() + .execute_operation(list_collections, Some(&mut *session)) + .await + .map(|spec| SessionCursor::new(self.client().clone(), spec))?; + + self.list_collection_names_common(cursor.with_session(session)) + .await + } + + async fn create_collection_common( &self, name: &str, options: impl Into>, + session: impl Into>, ) -> Result<()> { let mut options = options.into(); resolve_options!(self, options, [write_concern]); @@ -251,7 +316,43 @@ impl Database { }, options, ); - self.client().execute_operation(create).await + self.client().execute_operation(create, session).await + } + + /// Creates a new collection in the database with the given `name` and `options`. + /// + /// Note that MongoDB creates collections implicitly when data is inserted, so this method is + /// not needed if no special options are required. + pub async fn create_collection( + &self, + name: &str, + options: impl Into>, + ) -> Result<()> { + self.create_collection_common(name, options, None).await + } + + /// Creates a new collection in the database with the given `name` and `options` using the + /// provided `ClientSession`. + /// + /// Note that MongoDB creates collections implicitly when data is inserted, so this method is + /// not needed if no special options are required. 
+ pub async fn create_collection_with_session( + &self, + name: &str, + options: impl Into>, + session: &mut ClientSession, + ) -> Result<()> { + self.create_collection_common(name, options, session).await + } + + async fn run_command_common( + &self, + command: Document, + selection_criteria: impl Into>, + session: impl Into>, + ) -> Result { + let operation = RunCommand::new(self.name().into(), command, selection_criteria.into())?; + self.client().execute_operation(operation, session).await } /// Runs a database-level command. @@ -264,8 +365,23 @@ impl Database { command: Document, selection_criteria: impl Into>, ) -> Result { - let operation = RunCommand::new(self.name().into(), command, selection_criteria.into())?; - self.client().execute_operation(operation).await + self.run_command_common(command, selection_criteria, None) + .await + } + + /// Runs a database-level command using the provided `ClientSession`. + /// + /// Note that no inspection is done on `doc`, so the command will not use the database's default + /// read concern or write concern. If specific read concern or write concern is desired, it must + /// be specified manually. + pub async fn run_command_with_session( + &self, + command: Document, + selection_criteria: impl Into>, + session: &mut ClientSession, + ) -> Result { + self.run_command_common(command, selection_criteria, session) + .await } /// Runs an aggregation operation. @@ -291,4 +407,29 @@ impl Database { .await .map(|(spec, session)| Cursor::new(client.clone(), spec, session)) } + + /// Runs an aggregation operation with the provided `ClientSession`. + /// + /// See the documentation [here](https://docs.mongodb.com/manual/aggregation/) for more + /// information on aggregations. 
+ pub async fn aggregate_with_session( + &self, + pipeline: impl IntoIterator, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + let mut options = options.into(); + resolve_options!( + self, + options, + [read_concern, write_concern, selection_criteria] + ); + + let aggregate = Aggregate::new(self.name().to_string(), pipeline, options); + let client = self.client(); + client + .execute_operation(aggregate, session) + .await + .map(|spec| SessionCursor::new(client.clone(), spec)) + } } diff --git a/src/db/options.rs b/src/db/options.rs index 9f38765d3..e06c8c05b 100644 --- a/src/db/options.rs +++ b/src/db/options.rs @@ -12,7 +12,7 @@ use crate::{ /// These are the valid options for creating a [`Database`](../struct.Database.html) with /// [`Client::database_with_options`](../struct.Client.html#method.database_with_options). -#[derive(Clone, Debug, Default, TypedBuilder)] +#[derive(Clone, Debug, Default, Deserialize, TypedBuilder)] #[non_exhaustive] pub struct DatabaseOptions { /// The default read preference for operations. diff --git a/src/error.rs b/src/error.rs index 670031859..3c77c9b6b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -321,6 +321,10 @@ pub enum ErrorKind { #[non_exhaustive] SrvLookupError { message: String }, + /// The Client does not support sessions. + #[error("Attempted to start a session on a deployment that does not support sessions")] + SessionsNotSupported, + #[error("{0}")] RustlsConfig(#[from] rustls::TLSError), diff --git a/src/lib.rs b/src/lib.rs index 66ce94fab..4556e6823 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -72,7 +72,6 @@ #![warn(missing_docs)] #![warn(missing_crate_level_docs)] - #![cfg_attr( feature = "cargo-clippy", allow( @@ -135,7 +134,7 @@ define_if_single_runtime_enabled! { pub use crate::{ client::Client, coll::Collection, - cursor::Cursor, + cursor::{Cursor, session::{SessionCursor, SessionCursorHandle}}, db::Database, }; @@ -143,10 +142,12 @@ define_if_single_runtime_enabled! 
{ pub(crate) use crate::{ client::Client, coll::Collection, - cursor::Cursor, + cursor::{Cursor, session::{SessionCursor, SessionCursorHandle}}, db::Database, }; + pub use client::session::ClientSession; + pub use coll::Namespace; } diff --git a/src/operation/insert/mod.rs b/src/operation/insert/mod.rs index dff0ebbe7..a5f0c60e5 100644 --- a/src/operation/insert/mod.rs +++ b/src/operation/insert/mod.rs @@ -6,12 +6,12 @@ use std::collections::HashMap; use crate::{ bson::{doc, Document}, bson_util, - client::ClientSession, cmap::{Command, CommandResponse, StreamDescription}, error::{ErrorKind, Result}, operation::{append_options, Operation, Retryability, WriteResponseBody}, options::{InsertManyOptions, WriteConcern}, results::InsertManyResult, + ClientSession, Namespace, }; diff --git a/src/sdam/description/topology/mod.rs b/src/sdam/description/topology/mod.rs index 94a859b9d..dc2e9cb8b 100644 --- a/src/sdam/description/topology/mod.rs +++ b/src/sdam/description/topology/mod.rs @@ -288,16 +288,6 @@ impl TopologyDescription { return; } - if server_description.server_type == ServerType::Standalone { - self.session_support_status = SessionSupportStatus::Unsupported { - logical_session_timeout: server_description - .logical_session_timeout() - .ok() - .flatten(), - }; - return; - } - match server_description.logical_session_timeout().ok().flatten() { Some(timeout) => match self.session_support_status { SessionSupportStatus::Supported { diff --git a/src/sync/client.rs b/src/sync/client.rs index bc3a4b320..c5ad8e9c7 100644 --- a/src/sync/client.rs +++ b/src/sync/client.rs @@ -3,8 +3,15 @@ use crate::{ bson::Document, concern::{ReadConcern, WriteConcern}, error::Result, - options::{ClientOptions, DatabaseOptions, ListDatabasesOptions, SelectionCriteria}, + options::{ + ClientOptions, + DatabaseOptions, + ListDatabasesOptions, + SelectionCriteria, + SessionOptions, + }, Client as AsyncClient, + ClientSession, RUNTIME, }; @@ -127,4 +134,9 @@ impl Client { 
.list_database_names(filter.into(), options.into()), ) } + + /// Starts a new `ClientSession`. + pub fn start_session(&self, options: Option) -> Result { + RUNTIME.block_on(self.async_client.start_session(options)) + } } diff --git a/src/sync/coll.rs b/src/sync/coll.rs index d2f35f269..df19a0067 100644 --- a/src/sync/coll.rs +++ b/src/sync/coll.rs @@ -5,7 +5,7 @@ use std::{ use serde::{de::DeserializeOwned, Serialize}; -use super::Cursor; +use super::{Cursor, SessionCursor}; use crate::{ bson::{Bson, Document}, error::Result, @@ -31,6 +31,7 @@ use crate::{ WriteConcern, }, results::{DeleteResult, InsertManyResult, InsertOneResult, UpdateResult}, + ClientSession, Collection as AsyncCollection, Namespace, RUNTIME, @@ -131,6 +132,19 @@ where RUNTIME.block_on(self.async_collection.drop(options.into())) } + /// Drops the collection, deleting all data, users, and indexes stored in it using the provided + /// `ClientSession`. + pub fn drop_with_session( + &self, + options: impl Into>, + session: &mut ClientSession, + ) -> Result<()> { + RUNTIME.block_on( + self.async_collection + .drop_with_session(options.into(), session), + ) + } + /// Runs an aggregation operation. /// /// See the documentation [here](https://docs.mongodb.com/manual/aggregation/) for more @@ -146,6 +160,26 @@ where .map(Cursor::new) } + /// Runs an aggregation operation using the provided `ClientSession`. + /// + /// See the documentation [here](https://docs.mongodb.com/manual/aggregation/) for more + /// information on aggregations. + pub fn aggregate_with_session( + &self, + pipeline: impl IntoIterator, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + let pipeline: Vec = pipeline.into_iter().collect(); + RUNTIME + .block_on(self.async_collection.aggregate_with_session( + pipeline, + options.into(), + session, + )) + .map(SessionCursor::new) + } + /// Estimates the number of documents in the collection using collection metadata. 
pub fn estimated_document_count( &self, @@ -172,6 +206,23 @@ where ) } + /// Gets the number of documents matching `filter` using the provided `ClientSession`. + /// + /// Note that using [`Collection::estimated_document_count`](#method.estimated_document_count) + /// is recommended instead of this method is most cases. + pub fn count_documents_with_session( + &self, + filter: impl Into>, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + RUNTIME.block_on(self.async_collection.count_documents_with_session( + filter.into(), + options.into(), + session, + )) + } + /// Deletes all documents stored in the collection matching `query`. pub fn delete_many( &self, @@ -181,6 +232,21 @@ where RUNTIME.block_on(self.async_collection.delete_many(query, options.into())) } + /// Deletes all documents stored in the collection matching `query` using the provided + /// `ClientSession`. + pub fn delete_many_with_session( + &self, + query: Document, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + RUNTIME.block_on(self.async_collection.delete_many_with_session( + query, + options.into(), + session, + )) + } + /// Deletes up to one document found matching `query`. /// /// This operation will retry once upon failure if the connection and encountered error support @@ -195,6 +261,25 @@ where RUNTIME.block_on(self.async_collection.delete_one(query, options.into())) } + /// Deletes up to one document found matching `query` using the provided `ClientSession`. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. 
+ pub fn delete_one_with_session( + &self, + query: Document, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + RUNTIME.block_on(self.async_collection.delete_one_with_session( + query, + options.into(), + session, + )) + } + /// Finds the distinct values of the field specified by `field_name` across the collection. pub fn distinct( &self, @@ -208,6 +293,23 @@ where ) } + /// Finds the distinct values of the field specified by `field_name` across the collection using + /// the provided `ClientSession`. + pub fn distinct_with_session( + &self, + field_name: &str, + filter: impl Into>, + options: impl Into>, + session: &mut ClientSession, + ) -> Result> { + RUNTIME.block_on(self.async_collection.distinct_with_session( + field_name, + filter.into(), + options.into(), + session, + )) + } + /// Finds the documents in the collection matching `filter`. pub fn find( &self, @@ -219,6 +321,22 @@ where .map(Cursor::new) } + /// Finds the documents in the collection matching `filter` using the provided `ClientSession`. + pub fn find_with_session( + &self, + filter: impl Into>, + options: impl Into>, + session: &mut ClientSession, + ) -> Result> { + RUNTIME + .block_on(self.async_collection.find_with_session( + filter.into(), + options.into(), + session, + )) + .map(SessionCursor::new) + } + /// Finds a single document in the collection matching `filter`. pub fn find_one( &self, @@ -231,6 +349,21 @@ where ) } + /// Finds a single document in the collection matching `filter` using the provided + /// `ClientSession`. + pub fn find_one_with_session( + &self, + filter: impl Into>, + options: impl Into>, + session: &mut ClientSession, + ) -> Result> { + RUNTIME.block_on(self.async_collection.find_one_with_session( + filter.into(), + options.into(), + session, + )) + } + /// Atomically finds up to one document in the collection matching `filter` and deletes it. 
/// /// This operation will retry once upon failure if the connection and encountered error support @@ -248,6 +381,26 @@ where ) } + /// Atomically finds up to one document in the collection matching `filter` and deletes it using + /// the provided `ClientSession`. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. + pub fn find_one_and_delete_with_session( + &self, + filter: Document, + options: impl Into>, + session: &mut ClientSession, + ) -> Result> { + RUNTIME.block_on(self.async_collection.find_one_and_delete_with_session( + filter, + options.into(), + session, + )) + } + /// Atomically finds up to one document in the collection matching `filter` and replaces it with /// `replacement`. /// @@ -268,6 +421,28 @@ where )) } + /// Atomically finds up to one document in the collection matching `filter` and replaces it with + /// `replacement` using the provided `ClientSession`. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. + pub fn find_one_and_replace_with_session( + &self, + filter: Document, + replacement: T, + options: impl Into>, + session: &mut ClientSession, + ) -> Result> { + RUNTIME.block_on(self.async_collection.find_one_and_replace_with_session( + filter, + replacement, + options.into(), + session, + )) + } + /// Atomically finds up to one document in the collection matching `filter` and updates it. /// Both `Document` and `Vec` implement `Into`, so either can be /// passed in place of constructing the enum case. 
Note: pipeline updates are only supported @@ -290,6 +465,30 @@ where )) } + /// Atomically finds up to one document in the collection matching `filter` and updates it using + /// the provided `ClientSession`. Both `Document` and `Vec` implement + /// `Into`, so either can be passed in place of constructing the enum + /// case. Note: pipeline updates are only supported in MongoDB 4.2+. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. + pub fn find_one_and_update_with_session( + &self, + filter: Document, + update: impl Into, + options: impl Into>, + session: &mut ClientSession, + ) -> Result> { + RUNTIME.block_on(self.async_collection.find_one_and_update_with_session( + filter, + update.into(), + options.into(), + session, + )) + } + /// Inserts the documents in `docs` into the collection. /// /// This operation will retry once upon failure if the connection and encountered error support @@ -305,6 +504,26 @@ where RUNTIME.block_on(self.async_collection.insert_many(docs, options.into())) } + /// Inserts the documents in `docs` into the collection using the provided `ClientSession`. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. + pub fn insert_many_with_session( + &self, + docs: impl IntoIterator, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + let docs: Vec = docs.into_iter().collect(); + RUNTIME.block_on(self.async_collection.insert_many_with_session( + docs, + options.into(), + session, + )) + } + /// Inserts `doc` into the collection. 
/// /// This operation will retry once upon failure if the connection and encountered error support @@ -319,6 +538,25 @@ where RUNTIME.block_on(self.async_collection.insert_one(doc, options.into())) } + /// Inserts `doc` into the collection using the provided `ClientSession`. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. + pub fn insert_one_with_session( + &self, + doc: T, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + RUNTIME.block_on(self.async_collection.insert_one_with_session( + doc, + options.into(), + session, + )) + } + /// Replaces up to one document matching `query` in the collection with `replacement`. /// /// This operation will retry once upon failure if the connection and encountered error support @@ -337,6 +575,28 @@ where ) } + /// Replaces up to one document matching `query` in the collection with `replacement` using the + /// provided `ClientSession`. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. + pub fn replace_one_with_session( + &self, + query: Document, + replacement: T, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + RUNTIME.block_on(self.async_collection.replace_one_with_session( + query, + replacement, + options.into(), + session, + )) + } + /// Updates all documents matching `query` in the collection. /// /// Both `Document` and `Vec` implement `Into`, so either can be @@ -355,6 +615,27 @@ where ) } + /// Updates all documents matching `query` in the collection using the provided `ClientSession`. 
+ /// + /// Both `Document` and `Vec` implement `Into`, so either can be + /// passed in place of constructing the enum case. Note: pipeline updates are only supported + /// in MongoDB 4.2+. See the official MongoDB + /// [documentation](https://docs.mongodb.com/manual/reference/command/update/#behavior) for more information on specifying updates. + pub fn update_many_with_session( + &self, + query: Document, + update: impl Into, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + RUNTIME.block_on(self.async_collection.update_many_with_session( + query, + update.into(), + options.into(), + session, + )) + } + /// Updates up to one document matching `query` in the collection. /// /// Both `Document` and `Vec` implement `Into`, so either can be @@ -377,4 +658,31 @@ where .update_one(query, update.into(), options.into()), ) } + + /// Updates up to one document matching `query` in the collection using the provided + /// `ClientSession`. + /// + /// Both `Document` and `Vec` implement `Into`, so either can be + /// passed in place of constructing the enum case. Note: pipeline updates are only supported + /// in MongoDB 4.2+. See the official MongoDB + /// [documentation](https://docs.mongodb.com/manual/reference/command/update/#behavior) for more information on specifying updates. + /// + /// This operation will retry once upon failure if the connection and encountered error support + /// retryability. See the documentation + /// [here](https://docs.mongodb.com/manual/core/retryable-writes/) for more information on + /// retryable writes. 
+ pub fn update_one_with_session( + &self, + query: Document, + update: impl Into, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + RUNTIME.block_on(self.async_collection.update_one_with_session( + query, + update.into(), + options.into(), + session, + )) + } } diff --git a/src/sync/cursor.rs b/src/sync/cursor.rs index d511d7fa0..5e1218df1 100644 --- a/src/sync/cursor.rs +++ b/src/sync/cursor.rs @@ -1,7 +1,15 @@ use futures::StreamExt; use serde::de::DeserializeOwned; -use crate::{bson::Document, error::Result, Cursor as AsyncCursor, RUNTIME}; +use crate::{ + bson::Document, + error::Result, + ClientSession, + Cursor as AsyncCursor, + SessionCursor as AsyncSessionCursor, + SessionCursorHandle as AsyncSessionCursorHandle, + RUNTIME, +}; /// A `Cursor` streams the result of a query. When a query is made, a `Cursor` will be returned with /// the first batch of results from the server; the documents will be returned as the `Cursor` is @@ -87,3 +95,72 @@ where RUNTIME.block_on(self.async_cursor.next()) } } + +/// A `SessionCursor` is a cursor that was created with a `ClientSession` must be iterated using +/// one. To iterate, retrieve a `SessionCursorHandle` using `SessionCursor::with_session`: +/// +/// ```rust +/// # use mongodb::{sync::Client, error::Result}; +/// # +/// # fn do_stuff() -> Result<()> { +/// # let client = Client::with_uri_str("mongodb://example.com")?; +/// # let mut session = client.start_session(None)?; +/// # let coll = client.database("foo").collection("bar"); +/// # let mut cursor = coll.find_with_session(None, None, &mut session)?; +/// # +/// for doc in cursor.with_session(&mut session) { +/// println!("{}", doc?) 
+/// } +/// # +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +pub struct SessionCursor +where + T: DeserializeOwned + Unpin + Send, +{ + async_cursor: AsyncSessionCursor, +} + +impl SessionCursor +where + T: DeserializeOwned + Unpin + Send, +{ + pub(crate) fn new(async_cursor: AsyncSessionCursor) -> Self { + Self { async_cursor } + } + + /// Retrieves a `SessionCursorHandle` to iterate this cursor. The session provided must be the + /// same session used to create the cursor. + pub fn with_session<'session>( + &mut self, + session: &'session mut ClientSession, + ) -> SessionCursorHandle<'_, 'session, T> { + SessionCursorHandle { + async_handle: self.async_cursor.with_session(session), + } + } +} + +/// A handle that borrows a `ClientSession` temporarily for executing getMores or iterating through +/// the current buffer of a `SessionCursor`. +/// +/// This updates the buffer of the parent `SessionCursor` when dropped. +pub struct SessionCursorHandle<'cursor, 'session, T = Document> +where + T: DeserializeOwned + Unpin + Send, +{ + async_handle: AsyncSessionCursorHandle<'cursor, 'session, T>, +} + +impl Iterator for SessionCursorHandle<'_, '_, T> +where + T: DeserializeOwned + Unpin + Send, +{ + type Item = Result; + + fn next(&mut self) -> Option { + RUNTIME.block_on(self.async_handle.next()) + } +} diff --git a/src/sync/db.rs b/src/sync/db.rs index c67705980..336f00deb 100644 --- a/src/sync/db.rs +++ b/src/sync/db.rs @@ -5,9 +5,10 @@ use std::{ use serde::{de::DeserializeOwned, Serialize}; -use super::{Collection, Cursor}; +use super::{Collection, Cursor, SessionCursor}; use crate::{ bson::Document, + client::session::ClientSession, error::Result, options::{ AggregateOptions, @@ -142,6 +143,19 @@ impl Database { RUNTIME.block_on(self.async_database.drop(options.into())) } + /// Drops the database, deleting all data, collections, users, and indexes stored in it using + /// the provided `ClientSession`. 
+ pub fn drop_with_session( + &self, + options: impl Into>, + session: &mut ClientSession, + ) -> Result<()> { + RUNTIME.block_on( + self.async_database + .drop_with_session(options.into(), session), + ) + } + /// Gets information about each of the collections in the database. The cursor will yield a /// document pertaining to each collection in the database. pub fn list_collections( @@ -157,6 +171,24 @@ impl Database { .map(Cursor::new) } + /// Gets information about each of the collections in the database using the provided + /// `ClientSession`. The cursor will yield a document pertaining to each collection in the + /// database. + pub fn list_collections_with_session( + &self, + filter: impl Into>, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + RUNTIME + .block_on(self.async_database.list_collections_with_session( + filter.into(), + options.into(), + session, + )) + .map(SessionCursor::new) + } + /// Gets the names of the collections in the database. pub fn list_collection_names( &self, @@ -165,6 +197,18 @@ impl Database { RUNTIME.block_on(self.async_database.list_collection_names(filter.into())) } + /// Gets the names of the collections in the database using the provided `ClientSession`. + pub fn list_collection_names_with_session( + &self, + filter: impl Into>, + session: &mut ClientSession, + ) -> Result> { + RUNTIME.block_on( + self.async_database + .list_collection_names_with_session(filter.into(), session), + ) + } + /// Creates a new collection in the database with the given `name` and `options`. /// /// Note that MongoDB creates collections implicitly when data is inserted, so this method is @@ -177,6 +221,24 @@ impl Database { RUNTIME.block_on(self.async_database.create_collection(name, options.into())) } + /// Creates a new collection in the database with the given `name` and `options` using the + /// provided `ClientSession`. 
+ /// + /// Note that MongoDB creates collections implicitly when data is inserted, so this method is + /// not needed if no special options are required. + pub fn create_collection_with_session( + &self, + name: &str, + options: impl Into>, + session: &mut ClientSession, + ) -> Result<()> { + RUNTIME.block_on(self.async_database.create_collection_with_session( + name, + options.into(), + session, + )) + } + /// Runs a database-level command. /// /// Note that no inspection is done on `doc`, so the command will not use the database's default @@ -193,6 +255,24 @@ impl Database { ) } + /// Runs a database-level command using the provided `ClientSession`. + /// + /// Note that no inspection is done on `doc`, so the command will not use the database's default + /// read concern or write concern. If specific read concern or write concern is desired, it must + /// be specified manually. + pub fn run_command_with_session( + &self, + command: Document, + selection_criteria: impl Into>, + session: &mut ClientSession, + ) -> Result { + RUNTIME.block_on(self.async_database.run_command_with_session( + command, + selection_criteria.into(), + session, + )) + } + /// Runs an aggregation operation. /// /// See the documentation [here](https://docs.mongodb.com/manual/aggregation/) for more @@ -207,4 +287,23 @@ impl Database { .block_on(self.async_database.aggregate(pipeline, options.into())) .map(Cursor::new) } + + /// Runs an aggregation operation using the provided `ClientSession`. + /// + /// See the documentation [here](https://docs.mongodb.com/manual/aggregation/) for more + /// information on aggregations. 
+ pub fn aggregate_with_session( + &self, + pipeline: impl IntoIterator, + options: impl Into>, + session: &mut ClientSession, + ) -> Result { + let pipeline: Vec = pipeline.into_iter().collect(); + RUNTIME + .block_on( + self.async_database + .aggregate_with_session(pipeline, options.into(), session), + ) + .map(SessionCursor::new) + } } diff --git a/src/sync/mod.rs b/src/sync/mod.rs index ca6a4c2dc..fd7350d91 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -10,5 +10,5 @@ mod test; pub use client::Client; pub use coll::Collection; -pub use cursor::Cursor; +pub use cursor::{Cursor, SessionCursor}; pub use db::Database; diff --git a/src/test/mod.rs b/src/test/mod.rs index c1e67e570..0be6d661b 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -10,15 +10,7 @@ mod spec; mod util; pub(crate) use self::{ - spec::{ - run_spec_test, - run_v2_test, - AnyTestOperation, - OperationObject, - RunOn, - TestEvent, - Topology, - }, + spec::{run_spec_test, RunOn, Topology}, util::{ assert_matches, CmapEvent, diff --git a/src/test/spec/crud_v2.rs b/src/test/spec/crud_v2.rs index fd0c7b85f..3b1de7e97 100644 --- a/src/test/spec/crud_v2.rs +++ b/src/test/spec/crud_v2.rs @@ -1,6 +1,8 @@ use tokio::sync::RwLockWriteGuard; -use crate::test::{run_spec_test, run_v2_test, LOCK}; +use crate::test::{run_spec_test, LOCK}; + +use super::run_v2_test; #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] diff --git a/src/test/spec/mod.rs b/src/test/spec/mod.rs index d2b114f70..49fb022c6 100644 --- a/src/test/spec/mod.rs +++ b/src/test/spec/mod.rs @@ -11,8 +11,9 @@ mod ocsp; mod read_write_concern; mod retryable_reads; mod retryable_writes; -mod runner; +mod sessions; mod unified_runner; +mod v2_runner; use std::{ convert::TryFrom, @@ -23,16 +24,8 @@ use std::{ }; pub use self::{ - runner::{ - run_v2_test, - AnyTestOperation, - OperationObject, - RunOn, - TestData, - TestEvent, - TestFile, - }, unified_runner::Topology, + 
v2_runner::{operation::Operation, run_v2_test, test_file::RunOn}, }; use serde::de::DeserializeOwned; diff --git a/src/test/spec/retryable_reads.rs b/src/test/spec/retryable_reads.rs index 7ff37e61a..83bd26292 100644 --- a/src/test/spec/retryable_reads.rs +++ b/src/test/spec/retryable_reads.rs @@ -1,6 +1,8 @@ use tokio::sync::RwLockWriteGuard; -use crate::test::{run_spec_test, run_v2_test, LOCK}; +use crate::test::{run_spec_test, LOCK}; + +use super::run_v2_test; #[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))] #[cfg_attr(feature = "async-std-runtime", async_std::test)] diff --git a/src/test/spec/retryable_writes/mod.rs b/src/test/spec/retryable_writes/mod.rs index 2f3cc5f4b..8eae651c3 100644 --- a/src/test/spec/retryable_writes/mod.rs +++ b/src/test/spec/retryable_writes/mod.rs @@ -16,7 +16,7 @@ use crate::{ test::{ assert_matches, run_spec_test, - util::get_db_name, + util::get_default_name, EventClient, TestClient, CLIENT_OPTIONS, @@ -54,7 +54,7 @@ async fn run_spec_tests() { } } - let db_name = get_db_name(&test_case.description); + let db_name = get_default_name(&test_case.description); let coll_name = "coll"; let write_concern = WriteConcern::builder().w(Acknowledgment::Majority).build(); @@ -81,14 +81,10 @@ async fn run_spec_tests() { .unwrap(); } - let result = client - .run_collection_operation( - &test_case.operation, - &db_name, - &coll_name, - Some(options.clone()), - ) - .await; + let coll = client + .database(&db_name) + .collection_with_options(&coll_name, options.clone()); + let result = test_case.operation.execute_on_collection(&coll, None).await; if let Some(error) = test_case.outcome.error { assert_eq!( diff --git a/src/test/spec/retryable_writes/test_file.rs b/src/test/spec/retryable_writes/test_file.rs index f8c42ee56..17335607b 100644 --- a/src/test/spec/retryable_writes/test_file.rs +++ b/src/test/spec/retryable_writes/test_file.rs @@ -1,6 +1,6 @@ use serde::Deserialize; -use super::super::{AnyTestOperation, 
RunOn}; +use super::super::{Operation, RunOn}; use crate::{ bson::{Bson, Document}, options::ClientOptions, @@ -21,7 +21,7 @@ pub struct TestCase { pub client_options: Option, pub use_multiple_mongoses: Option, pub fail_point: Option, - pub operation: AnyTestOperation, + pub operation: Operation, pub outcome: Outcome, } diff --git a/src/test/spec/runner/mod.rs b/src/test/spec/runner/mod.rs deleted file mode 100644 index a83b39ddc..000000000 --- a/src/test/spec/runner/mod.rs +++ /dev/null @@ -1,226 +0,0 @@ -mod operation; -mod test_event; -mod test_file; - -use std::time::Duration; - -use crate::{ - bson::doc, - concern::{Acknowledgment, WriteConcern}, - operation::RunCommand, - options::CollectionOptions, - test::{ - assert_matches, - util::{get_db_name, EventClient}, - CLIENT_OPTIONS, - }, -}; - -pub use self::{ - operation::AnyTestOperation, - test_event::TestEvent, - test_file::{OperationObject, RunOn, TestCase, TestData, TestFile}, -}; - -const SKIPPED_OPERATIONS: &[&str] = &[ - "bulkWrite", - "count", - "download", - "download_by_name", - "listCollectionObjects", - "listDatabaseObjects", - "listIndexNames", - "listIndexes", - "mapReduce", - "watch", -]; - -pub async fn run_v2_test(test_file: TestFile) { - for test_case in test_file.tests { - let has_skipped_op = test_case - .operations - .iter() - .any(|op| SKIPPED_OPERATIONS.contains(&op.name.as_str())); - if has_skipped_op { - continue; - } - - if let Some(skip_reason) = test_case.skip_reason { - println!("Skipping {}: {}", test_case.description, skip_reason); - continue; - } - - println!("executing {}", test_case.description); - - let options = test_case.client_options.map(|mut opts| { - opts.hosts = CLIENT_OPTIONS.hosts.clone(); - opts - }); - let client = EventClient::with_additional_options( - options, - Some(Duration::from_millis(50)), - test_case.use_multiple_mongoses, - None, - true, - ) - .await; - - if let Some(ref run_on) = test_file.run_on { - let can_run_on = run_on.iter().any(|run_on| 
run_on.can_run_on(&client)); - if !can_run_on { - println!("Skipping {}", test_case.description); - continue; - } - } - - let db_name = match test_file.database_name { - Some(ref db_name) => db_name.clone(), - None => get_db_name(&test_case.description), - }; - - let coll_name = match test_file.collection_name { - Some(ref coll_name) => coll_name.clone(), - None => "coll".to_string(), - }; - - if test_case - .description - .contains("Aggregate with $listLocalSessions") - { - // TODO DRIVERS-1230: This test currently fails on 3.6 standalones because the session - // does not attach to the server ping. When the driver is updated to send implicit - // sessions to standalones, this test should be unskipped. - let req = semver::VersionReq::parse("<= 3.6").unwrap(); - if req.matches(&client.server_version.as_ref().unwrap()) && client.is_standalone() { - continue; - } - start_session(&client, &db_name).await; - } - - if let Some(ref data) = test_file.data { - match data { - TestData::Single(data) => { - if !data.is_empty() { - let coll = if client.is_replica_set() || client.is_sharded() { - let write_concern = - WriteConcern::builder().w(Acknowledgment::Majority).build(); - let options = CollectionOptions::builder() - .write_concern(write_concern) - .build(); - client - .init_db_and_coll_with_options(&db_name, &coll_name, options) - .await - } else { - client.init_db_and_coll(&db_name, &coll_name).await - }; - coll.insert_many(data.clone(), None) - .await - .expect(&test_case.description); - } - } - TestData::Many(_) => panic!("{}: invalid data format", &test_case.description), - } - } - - if let Some(ref fail_point) = test_case.fail_point { - client - .database("admin") - .run_command(fail_point.clone(), None) - .await - .unwrap(); - } - - let mut events: Vec = Vec::new(); - for operation in test_case.operations { - let result = match operation.object { - Some(OperationObject::Client) => client.run_client_operation(&operation).await, - Some(OperationObject::Database) => 
{ - client.run_database_operation(&operation, &db_name).await - } - Some(OperationObject::Collection) | None => { - client - .run_collection_operation( - &operation, - &db_name, - &coll_name, - operation.collection_options.clone(), - ) - .await - } - Some(OperationObject::GridfsBucket) => { - panic!("unsupported operation: {}", operation.name) - } - }; - let mut operation_events: Vec = client - .get_command_started_events(operation.command_names()) - .into_iter() - .map(Into::into) - .collect(); - - if let Some(error) = operation.error { - assert_eq!( - result.is_err(), - error, - "{}: expected error: {}, got {:?}", - test_case.description, - error, - result - ); - } - - if let Some(expected_result) = operation.result { - let description = &test_case.description; - let result = result - .unwrap() - .unwrap_or_else(|| panic!("{:?}: operation should succeed", description)); - assert_matches(&result, &expected_result, Some(description)); - } - - events.append(&mut operation_events); - } - - if let Some(expectations) = test_case.expectations { - assert!( - events.len() >= expectations.len(), - "{}", - test_case.description - ); - for (actual_event, expected_event) in events.iter().zip(expectations.iter()) { - assert_matches( - actual_event, - expected_event, - Some(test_case.description.as_str()), - ); - } - } - - if let Some(outcome) = test_case.outcome { - assert!(outcome.matches_actual(db_name, coll_name, &client).await); - } - - if test_case.fail_point.is_some() { - client - .database("admin") - .run_command( - doc! { - "configureFailPoint": "failCommand", - "mode": "off" - }, - None, - ) - .await - .unwrap(); - } - } - - async fn start_session(client: &EventClient, db_name: &str) { - let mut session = client - .start_implicit_session_with_timeout(Duration::from_secs(60 * 60)) - .await; - let op = RunCommand::new(db_name.to_string(), doc! 
{ "ping": 1 }, None).unwrap(); - client - .execute_operation_with_session(op, &mut session) - .await - .unwrap(); - } -} diff --git a/src/test/spec/runner/operation.rs b/src/test/spec/runner/operation.rs deleted file mode 100644 index 15a5cb199..000000000 --- a/src/test/spec/runner/operation.rs +++ /dev/null @@ -1,845 +0,0 @@ -use std::{collections::HashMap, fmt::Debug, ops::Deref}; - -use async_trait::async_trait; -use futures::stream::TryStreamExt; -use serde::{ - de::{self, Deserializer}, - Deserialize, -}; - -use crate::{ - bson::{doc, Bson, Deserializer as BsonDeserializer, Document}, - error::Result, - options::{ - AggregateOptions, - CollectionOptions, - CountOptions, - DeleteOptions, - DistinctOptions, - EstimatedDocumentCountOptions, - FindOneAndDeleteOptions, - FindOneAndReplaceOptions, - FindOneAndUpdateOptions, - FindOneOptions, - FindOptions, - InsertManyOptions, - InsertOneOptions, - ListCollectionsOptions, - ListDatabasesOptions, - ReplaceOptions, - UpdateModifications, - UpdateOptions, - }, - test::{util::EventClient, OperationObject}, - Collection, - Database, -}; - -#[async_trait] -pub trait TestOperation: Debug { - /// The command names to monitor as part of this test. - fn command_names(&self) -> &[&str]; - - // The linked issue causes a warning that cannot be suppressed when providing a default - // implementation for these functions. 
- // - async fn execute_on_collection(&self, collection: &Collection) -> Result>; - - async fn execute_on_client(&self, client: &EventClient) -> Result>; - - async fn execute_on_database(&self, database: &Database) -> Result>; -} - -#[derive(Debug)] -pub struct AnyTestOperation { - operation: Box, - pub name: String, - pub object: Option, - pub result: Option, - pub error: Option, - pub collection_options: Option, -} - -impl<'de> Deserialize<'de> for AnyTestOperation { - fn deserialize>(deserializer: D) -> std::result::Result { - #[derive(Debug, Deserialize)] - #[serde(rename_all = "camelCase")] - struct OperationDefinition { - name: String, - #[serde(default = "default_arguments")] - arguments: Bson, - object: Option, - result: Option, - error: Option, - collection_options: Option, - } - - let definition = OperationDefinition::deserialize(deserializer)?; - let boxed_op = match definition.name.as_str() { - "insertOne" => InsertOne::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box), - "insertMany" => InsertMany::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box), - "updateOne" => UpdateOne::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box), - "updateMany" => UpdateMany::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box), - "deleteMany" => DeleteMany::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box), - "deleteOne" => DeleteOne::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box), - "find" => Find::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box), - "aggregate" => Aggregate::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box), - "distinct" => Distinct::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box), - 
"countDocuments" => { - CountDocuments::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box) - } - "estimatedDocumentCount" => { - EstimatedDocumentCount::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box) - } - "findOne" => FindOne::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box), - "listDatabases" => { - ListDatabases::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box) - } - "listDatabaseNames" => { - ListDatabaseNames::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box) - } - "listCollections" => { - ListCollections::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box) - } - "listCollectionNames" => { - ListCollectionNames::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box) - } - "replaceOne" => ReplaceOne::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box), - "findOneAndUpdate" => { - FindOneAndUpdate::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box) - } - "findOneAndReplace" => { - FindOneAndReplace::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box) - } - "findOneAndDelete" => { - FindOneAndDelete::deserialize(BsonDeserializer::new(definition.arguments)) - .map(|op| Box::new(op) as Box) - } - _ => Ok(Box::new(UnimplementedOperation) as Box), - } - .map_err(|e| de::Error::custom(format!("{}", e)))?; - - Ok(AnyTestOperation { - operation: boxed_op, - name: definition.name, - object: definition.object, - result: definition.result, - error: definition.error, - collection_options: definition.collection_options, - }) - } -} - -fn default_arguments() -> Bson { - Bson::Document(doc! 
{}) -} - -impl Deref for AnyTestOperation { - type Target = Box; - - fn deref(&self) -> &Box { - &self.operation - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct DeleteMany { - filter: Document, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for DeleteMany { - fn command_names(&self) -> &[&str] { - &["delete"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let result = collection - .delete_many(self.filter.clone(), self.options.clone()) - .await?; - let result = bson::to_bson(&result)?; - Ok(Some(result)) - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct DeleteOne { - filter: Document, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for DeleteOne { - fn command_names(&self) -> &[&str] { - &["delete"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let result = collection - .delete_one(self.filter.clone(), self.options.clone()) - .await?; - let result = bson::to_bson(&result)?; - Ok(Some(result)) - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Default, Deserialize)] -pub(super) struct Find { - filter: Option, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for Find { - fn command_names(&self) -> &[&str] { - &["find", "getMore"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let cursor = collection - .find(self.filter.clone(), self.options.clone()) - .await?; - let result = cursor.try_collect::>().await?; - Ok(Some(Bson::from(result))) - } - - async fn execute_on_client(&self, 
_client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct InsertMany { - documents: Vec, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for InsertMany { - fn command_names(&self) -> &[&str] { - &["insert"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let result = collection - .insert_many(self.documents.clone(), self.options.clone()) - .await?; - let ids: HashMap = result - .inserted_ids - .into_iter() - .map(|(k, v)| (k.to_string(), v)) - .collect(); - let ids = bson::to_bson(&ids)?; - Ok(Some(Bson::from(doc! { "insertedIds": ids }))) - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct InsertOne { - document: Document, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for InsertOne { - fn command_names(&self) -> &[&str] { - &["insert"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let result = collection - .insert_one(self.document.clone(), self.options.clone()) - .await?; - let result = bson::to_bson(&result)?; - Ok(Some(result)) - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct UpdateMany { - filter: Document, - update: UpdateModifications, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for UpdateMany { - fn command_names(&self) -> &[&str] { - &["update"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - 
let result = collection - .update_many( - self.filter.clone(), - self.update.clone(), - self.options.clone(), - ) - .await?; - let result = bson::to_bson(&result)?; - Ok(Some(result)) - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct UpdateOne { - filter: Document, - update: UpdateModifications, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for UpdateOne { - fn command_names(&self) -> &[&str] { - &["update"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let result = collection - .update_one( - self.filter.clone(), - self.update.clone(), - self.options.clone(), - ) - .await?; - let result = bson::to_bson(&result)?; - Ok(Some(result)) - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub(super) struct Aggregate { - pipeline: Vec, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for Aggregate { - fn command_names(&self) -> &[&str] { - &["aggregate"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let cursor = collection - .aggregate(self.pipeline.clone(), self.options.clone()) - .await?; - let result = cursor.try_collect::>().await?; - Ok(Some(Bson::from(result))) - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, database: &Database) -> Result> { - let cursor = database - .aggregate(self.pipeline.clone(), self.options.clone()) - .await?; - let result = cursor.try_collect::>().await?; - Ok(Some(Bson::from(result))) - } -} 
- -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub(super) struct Distinct { - field_name: String, - filter: Option, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for Distinct { - fn command_names(&self) -> &[&str] { - &["distinct"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let result = collection - .distinct(&self.field_name, self.filter.clone(), self.options.clone()) - .await?; - Ok(Some(Bson::Array(result))) - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct CountDocuments { - filter: Document, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for CountDocuments { - fn command_names(&self) -> &[&str] { - &["aggregate"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let result = collection - .count_documents(self.filter.clone(), self.options.clone()) - .await?; - Ok(Some(Bson::from(result))) - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct EstimatedDocumentCount { - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for EstimatedDocumentCount { - fn command_names(&self) -> &[&str] { - &["count"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let result = collection - .estimated_document_count(self.options.clone()) - .await?; - Ok(Some(Bson::from(result))) - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - 
unimplemented!() - } -} - -#[derive(Debug, Default, Deserialize)] -pub(super) struct FindOne { - filter: Option, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for FindOne { - fn command_names(&self) -> &[&str] { - &["find"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let result = collection - .find_one(self.filter.clone(), self.options.clone()) - .await?; - match result { - Some(result) => Ok(Some(Bson::from(result))), - None => Ok(None), - } - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct ListDatabases { - filter: Option, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for ListDatabases { - fn command_names(&self) -> &[&str] { - &["listDatabases"] - } - - async fn execute_on_collection(&self, _collection: &Collection) -> Result> { - unimplemented!() - } - - async fn execute_on_client(&self, client: &EventClient) -> Result> { - let result = client - .list_databases(self.filter.clone(), self.options.clone()) - .await?; - let result: Vec = result.iter().map(Bson::from).collect(); - Ok(Some(Bson::Array(result))) - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct ListDatabaseNames { - filter: Option, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for ListDatabaseNames { - fn command_names(&self) -> &[&str] { - &["listDatabases"] - } - - async fn execute_on_collection(&self, _collection: &Collection) -> Result> { - unimplemented!() - } - - async fn execute_on_client(&self, client: &EventClient) -> Result> { - let result = client - .list_database_names(self.filter.clone(), self.options.clone()) - .await?; - let result: Vec = 
result.iter().map(|s| Bson::String(s.to_string())).collect(); - Ok(Some(Bson::Array(result))) - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct ListCollections { - filter: Option, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for ListCollections { - fn command_names(&self) -> &[&str] { - &["listCollections"] - } - - async fn execute_on_collection(&self, _collection: &Collection) -> Result> { - unimplemented!() - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, database: &Database) -> Result> { - let cursor = database - .list_collections(self.filter.clone(), self.options.clone()) - .await?; - let result = cursor.try_collect::>().await?; - Ok(Some(Bson::from(result))) - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct ListCollectionNames { - filter: Option, -} - -#[async_trait] -impl TestOperation for ListCollectionNames { - fn command_names(&self) -> &[&str] { - &["listCollections"] - } - - async fn execute_on_collection(&self, _collection: &Collection) -> Result> { - unimplemented!() - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, database: &Database) -> Result> { - let result = database.list_collection_names(self.filter.clone()).await?; - let result: Vec = result.iter().map(|s| Bson::String(s.to_string())).collect(); - Ok(Some(Bson::from(result))) - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct ReplaceOne { - filter: Document, - replacement: Document, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for ReplaceOne { - fn command_names(&self) -> &[&str] { - &["update"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let result = collection - .replace_one( - 
self.filter.clone(), - self.replacement.clone(), - self.options.clone(), - ) - .await?; - let result = bson::to_bson(&result)?; - Ok(Some(result)) - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct FindOneAndUpdate { - filter: Document, - update: UpdateModifications, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for FindOneAndUpdate { - fn command_names(&self) -> &[&str] { - &["findAndModify"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let result = collection - .find_one_and_update( - self.filter.clone(), - self.update.clone(), - self.options.clone(), - ) - .await?; - let result = bson::to_bson(&result)?; - Ok(Some(result)) - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct FindOneAndReplace { - filter: Document, - replacement: Document, - #[serde(flatten)] - options: Option, -} - -#[async_trait] -impl TestOperation for FindOneAndReplace { - fn command_names(&self) -> &[&str] { - &["findAndModify"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let result = collection - .find_one_and_replace( - self.filter.clone(), - self.replacement.clone(), - self.options.clone(), - ) - .await?; - let result = bson::to_bson(&result)?; - Ok(Some(result)) - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct FindOneAndDelete { - filter: Document, - #[serde(flatten)] - 
options: Option, -} - -#[async_trait] -impl TestOperation for FindOneAndDelete { - fn command_names(&self) -> &[&str] { - &["findAndModify"] - } - - async fn execute_on_collection(&self, collection: &Collection) -> Result> { - let result = collection - .find_one_and_delete(self.filter.clone(), self.options.clone()) - .await?; - let result = bson::to_bson(&result)?; - Ok(Some(result)) - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -#[derive(Debug, Deserialize)] -pub(super) struct UnimplementedOperation; - -#[async_trait] -impl TestOperation for UnimplementedOperation { - fn command_names(&self) -> &[&str] { - unimplemented!() - } - - async fn execute_on_collection(&self, _collection: &Collection) -> Result> { - unimplemented!() - } - - async fn execute_on_client(&self, _client: &EventClient) -> Result> { - unimplemented!() - } - - async fn execute_on_database(&self, _database: &Database) -> Result> { - unimplemented!() - } -} - -impl EventClient { - pub async fn run_database_operation( - &self, - operation: &AnyTestOperation, - database_name: &str, - ) -> Result> { - operation - .execute_on_database(&self.database(database_name)) - .await - } - - pub async fn run_collection_operation( - &self, - operation: &AnyTestOperation, - db_name: &str, - coll_name: &str, - collection_options: Option, - ) -> Result> { - let coll = match collection_options { - Some(options) => self.get_coll_with_options(&db_name, &coll_name, options), - None => self.get_coll(&db_name, &coll_name), - }; - operation.execute_on_collection(&coll).await - } - - pub async fn run_client_operation(&self, operation: &AnyTestOperation) -> Result> { - operation.execute_on_client(self).await - } -} diff --git a/src/test/spec/runner/test_event.rs b/src/test/spec/runner/test_event.rs deleted file mode 100644 index 39b3afb11..000000000 --- 
a/src/test/spec/runner/test_event.rs +++ /dev/null @@ -1,51 +0,0 @@ -use crate::{bson::Document, event::command::CommandStartedEvent, test::Matchable}; -use serde::Deserialize; - -#[derive(Debug, Deserialize, PartialEq)] -#[serde(rename_all = "snake_case")] -pub enum TestEvent { - CommandStartedEvent { - command_name: Option, - database_name: Option, - command: Document, - }, -} - -impl Matchable for TestEvent { - fn content_matches(&self, actual: &TestEvent) -> bool { - match (self, actual) { - ( - TestEvent::CommandStartedEvent { - command_name: actual_command_name, - database_name: actual_database_name, - command: actual_command, - }, - TestEvent::CommandStartedEvent { - command_name: expected_command_name, - database_name: expected_database_name, - command: expected_command, - }, - ) => { - if expected_command_name.is_some() && actual_command_name != expected_command_name { - return false; - } - if expected_database_name.is_some() - && actual_database_name != expected_database_name - { - return false; - } - actual_command.matches(expected_command) - } - } - } -} - -impl From for TestEvent { - fn from(event: CommandStartedEvent) -> Self { - TestEvent::CommandStartedEvent { - command_name: Some(event.command_name), - database_name: Some(event.db), - command: event.command, - } - } -} diff --git a/src/test/spec/sessions.rs b/src/test/spec/sessions.rs new file mode 100644 index 000000000..1a8ad3219 --- /dev/null +++ b/src/test/spec/sessions.rs @@ -0,0 +1,12 @@ +use tokio::sync::RwLockWriteGuard; + +use crate::test::{run_spec_test, LOCK}; + +use super::run_v2_test; + +#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn run() { + let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await; + run_spec_test(&["sessions"], run_v2_test).await; +} diff --git a/src/test/spec/v2_runner/mod.rs b/src/test/spec/v2_runner/mod.rs new file mode 100644 index 000000000..f2a4fb593 --- 
/dev/null +++ b/src/test/spec/v2_runner/mod.rs @@ -0,0 +1,304 @@ +pub mod operation; +pub mod test_event; +pub mod test_file; + +use std::{ops::Deref, time::Duration}; + +use semver::VersionReq; + +use crate::{ + bson::doc, + coll::options::DropCollectionOptions, + concern::{Acknowledgment, WriteConcern}, + options::{CreateCollectionOptions, InsertManyOptions}, + test::{assert_matches, util::get_default_name, EventClient, TestClient}, + RUNTIME, +}; + +use operation::{OperationObject, OperationResult}; +use test_event::CommandStartedEvent; +use test_file::{TestData, TestFile}; + +const SKIPPED_OPERATIONS: &[&str] = &[ + "bulkWrite", + "count", + "download", + "download_by_name", + "listCollectionObjects", + "listDatabaseObjects", + "listIndexNames", + "listIndexes", + "mapReduce", + "watch", +]; + +pub async fn run_v2_test(test_file: TestFile) { + let client = TestClient::new().await; + + if let Some(requirements) = test_file.run_on { + let can_run_on = requirements.iter().any(|run_on| run_on.can_run_on(&client)); + if !can_run_on { + println!("Client topology not compatible with test"); + return; + } + } + + for test in test_file.tests { + if test + .operations + .iter() + .any(|operation| SKIPPED_OPERATIONS.contains(&operation.name.as_str())) + { + continue; + } + + if let Some(skip_reason) = test.skip_reason { + println!("skipping {}: {}", test.description, skip_reason); + continue; + } + + match client + .database("admin") + .run_command(doc! 
{ "killAllSessions": [] }, None) + .await + { + Ok(_) => {} + Err(err) => match err.kind.code_and_message() { + Some((11601, _)) => {} + _ => panic!("{}: killAllSessions failed", test.description), + }, + } + + let db_name = test_file + .database_name + .clone() + .unwrap_or_else(|| get_default_name(&test.description)); + let coll_name = test_file + .collection_name + .clone() + .unwrap_or_else(|| get_default_name(&test.description)); + + let coll = client.database(&db_name).collection(&coll_name); + let options = DropCollectionOptions::builder() + .write_concern(majority_write_concern()) + .build(); + let req = VersionReq::parse(">=4.7").unwrap(); + if !(db_name.as_str() == "admin" + && client.is_sharded() + && req.matches(client.server_version.as_ref().unwrap())) + { + coll.drop(options).await.unwrap(); + } + + let options = CreateCollectionOptions::builder() + .write_concern(majority_write_concern()) + .build(); + client + .database(&db_name) + .create_collection(&coll_name, options) + .await + .unwrap(); + + if let Some(data) = &test_file.data { + match data { + TestData::Single(data) => { + if !data.is_empty() { + let options = InsertManyOptions::builder() + .write_concern(majority_write_concern()) + .build(); + coll.insert_many(data.clone(), options).await.unwrap(); + } + } + TestData::Many(_) => panic!("{}: invalid data format", &test.description), + } + } + + let client = EventClient::with_additional_options( + test.client_options.clone(), + None, + test.use_multiple_mongoses, + None, + false, + ) + .await; + + let _fp_guard = match test.fail_point { + Some(fail_point) => Some(fail_point.enable(client.deref(), None).await.unwrap()), + None => None, + }; + + let options = match test.session_options { + Some(ref options) => options.get("session0").cloned(), + None => None, + }; + let mut session0 = Some(client.start_session(options).await.unwrap()); + let session0_lsid = session0.as_ref().unwrap().id().clone(); + + let options = match test.session_options { + 
Some(ref options) => options.get("session1").cloned(), + None => None, + }; + let mut session1 = Some(client.start_session(options).await.unwrap()); + let session1_lsid = session1.as_ref().unwrap().id().clone(); + + for operation in test.operations { + let db = match &operation.database_options { + Some(options) => client.database_with_options(&db_name, options.clone()), + None => client.database(&db_name), + }; + let coll = match &operation.collection_options { + Some(options) => db.collection_with_options(&coll_name, options.clone()), + None => db.collection(&coll_name), + }; + + let session = match operation.session.as_deref() { + Some("session0") => session0.as_mut(), + Some("session1") => session1.as_mut(), + Some(other) => panic!("unknown session name: {}", other), + None => None, + }; + + let result = match operation.object { + Some(OperationObject::Collection) | None => { + let result = operation.execute_on_collection(&coll, session).await; + // This test (in src/test/spec/json/sessions/server-support.json) runs two + // operations with implicit sessions in sequence and then checks to see if they + // used the same lsid. We delay for one second to ensure that the + // implicit session used in the first operation is returned to the pool before + // the second operation is executed. 
+ if test.description == "Server supports implicit sessions" { + RUNTIME.delay_for(Duration::from_secs(1)).await; + } + result + } + Some(OperationObject::Database) => { + operation.execute_on_database(&db, session).await + } + Some(OperationObject::Client) => operation.execute_on_client(&client).await, + Some(OperationObject::Session0) => { + if operation.name == "endSession" { + let session = session0.take(); + drop(session); + RUNTIME.delay_for(Duration::from_secs(1)).await; + } else { + operation + .execute_on_session(session0.as_ref().unwrap()) + .await; + } + continue; + } + Some(OperationObject::Session1) => { + if operation.name == "endSession" { + let session = session1.take(); + drop(session); + RUNTIME.delay_for(Duration::from_secs(1)).await; + } else { + operation + .execute_on_session(session1.as_ref().unwrap()) + .await; + } + continue; + } + Some(OperationObject::TestRunner) => { + match operation.name.as_str() { + "assertDifferentLsidOnLastTwoCommands" => { + assert_different_lsid_on_last_two_commands(&client) + } + "assertSameLsidOnLastTwoCommands" => { + assert_same_lsid_on_last_two_commands(&client) + } + "assertSessionDirty" => { + assert!(session.unwrap().is_dirty()) + } + "assertSessionNotDirty" => { + assert!(!session.unwrap().is_dirty()) + } + other => panic!("unknown operation: {}", other), + } + continue; + } + Some(OperationObject::GridfsBucket) => { + panic!("unsupported operation: {}", operation.name) + } + }; + + if let Some(error) = operation.error { + assert_eq!(error, result.is_err(), "{}", &test.description); + } + + if let Some(expected_result) = operation.result { + match expected_result { + OperationResult::Success(expected) => { + let result = result.unwrap().unwrap(); + assert_matches(&result, &expected, Some(&test.description)); + } + OperationResult::Error(operation_error) => { + let error = result.unwrap_err(); + if let Some(error_contains) = operation_error.error_contains { + let (_, message) = 
error.kind.code_and_message().unwrap(); + assert!(message.contains(&error_contains)); + } + if let Some(error_code_name) = operation_error.error_code_name { + let code_name = error.kind.code_name().unwrap(); + assert_eq!(error_code_name, code_name); + } + if let Some(error_labels_contain) = operation_error.error_labels_contain { + let labels = error.labels().to_vec(); + error_labels_contain + .iter() + .for_each(|label| assert!(labels.contains(label))); + } + if let Some(error_labels_omit) = operation_error.error_labels_omit { + let labels = error.labels().to_vec(); + error_labels_omit + .iter() + .for_each(|label| assert!(!labels.contains(label))); + } + } + } + } + } + + drop(session0); + drop(session1); + + if let Some(expectations) = test.expectations { + let events: Vec = client + .get_all_command_started_events() + .into_iter() + .map(Into::into) + .collect(); + + assert!(events.len() >= expectations.len(), "{}", test.description); + for (actual_event, expected_event) in events.iter().zip(expectations.iter()) { + assert!(actual_event.matches_expected( + expected_event, + &session0_lsid, + &session1_lsid + )); + } + } + + if let Some(outcome) = test.outcome { + assert!(outcome.matches_actual(db_name, coll_name, &client).await); + } + } +} + +fn majority_write_concern() -> WriteConcern { + WriteConcern::builder().w(Acknowledgment::Majority).build() +} + +fn assert_different_lsid_on_last_two_commands(client: &EventClient) { + let events = client.get_all_command_started_events(); + let lsid1 = events[events.len() - 1].command.get("lsid").unwrap(); + let lsid2 = events[events.len() - 2].command.get("lsid").unwrap(); + assert_ne!(lsid1, lsid2); +} + +fn assert_same_lsid_on_last_two_commands(client: &EventClient) { + let events = client.get_all_command_started_events(); + let lsid1 = events[events.len() - 1].command.get("lsid").unwrap(); + let lsid2 = events[events.len() - 2].command.get("lsid").unwrap(); + assert_eq!(lsid1, lsid2); +} diff --git 
a/src/test/spec/v2_runner/operation.rs b/src/test/spec/v2_runner/operation.rs new file mode 100644 index 000000000..5f88bc465 --- /dev/null +++ b/src/test/spec/v2_runner/operation.rs @@ -0,0 +1,1269 @@ +use std::{collections::HashMap, fmt::Debug, ops::Deref}; + +use async_trait::async_trait; +use futures::stream::TryStreamExt; +use serde::{de::Deserializer, Deserialize}; + +use crate::{ + bson::{doc, Bson, Deserializer as BsonDeserializer, Document}, + client::session::ClientSession, + coll::options::CollectionOptions, + db::options::DatabaseOptions, + error::Result, + options::{ + AggregateOptions, + CountOptions, + DeleteOptions, + DistinctOptions, + EstimatedDocumentCountOptions, + FindOneAndDeleteOptions, + FindOneAndReplaceOptions, + FindOneAndUpdateOptions, + FindOneOptions, + FindOptions, + InsertManyOptions, + InsertOneOptions, + ListCollectionsOptions, + ListDatabasesOptions, + ReplaceOptions, + UpdateModifications, + UpdateOptions, + }, + test::EventClient, + Collection, + Database, +}; + +// The linked issue causes a warning that cannot be suppressed when providing a default +// implementation for the async functions contained in this trait. 
+// +#[async_trait] +pub trait TestOperation: Debug { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result>; + + async fn execute_on_database( + &self, + database: &Database, + session: Option<&mut ClientSession>, + ) -> Result>; + + async fn execute_on_client(&self, client: &EventClient) -> Result>; + + async fn execute_on_session(&self, session: &ClientSession); +} + +#[derive(Debug)] +pub struct Operation { + operation: Box, + pub name: String, + pub object: Option, + pub collection_options: Option, + pub database_options: Option, + pub error: Option, + pub result: Option, + pub session: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub enum OperationObject { + Database, + Collection, + Client, + Session0, + Session1, + #[serde(rename = "gridfsbucket")] + GridfsBucket, + TestRunner, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +pub enum OperationResult { + Error(OperationError), + Success(Bson), +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub struct OperationError { + pub error_contains: Option, + pub error_code_name: Option, + pub error_labels_contain: Option>, + pub error_labels_omit: Option>, +} + +impl<'de> Deserialize<'de> for Operation { + fn deserialize>(deserializer: D) -> std::result::Result { + #[derive(Debug, Deserialize)] + #[serde(rename_all = "camelCase", deny_unknown_fields)] + struct OperationDefinition { + pub name: String, + pub object: Option, + pub collection_options: Option, + pub database_options: Option, + #[serde(default = "default_arguments")] + pub arguments: Document, + pub error: Option, + pub result: Option, + } + + fn default_arguments() -> Document { + doc! 
{} + } + + let mut definition = OperationDefinition::deserialize(deserializer)?; + let session = definition + .arguments + .remove("session") + .map(|session| session.as_str().unwrap().to_string()); + + let boxed_op = match definition.name.as_str() { + "insertOne" => { + InsertOne::deserialize(BsonDeserializer::new(Bson::Document(definition.arguments))) + .map(|op| Box::new(op) as Box) + } + "insertMany" => { + InsertMany::deserialize(BsonDeserializer::new(Bson::Document(definition.arguments))) + .map(|op| Box::new(op) as Box) + } + "updateOne" => { + UpdateOne::deserialize(BsonDeserializer::new(Bson::Document(definition.arguments))) + .map(|op| Box::new(op) as Box) + } + "updateMany" => { + UpdateMany::deserialize(BsonDeserializer::new(Bson::Document(definition.arguments))) + .map(|op| Box::new(op) as Box) + } + "deleteMany" => { + DeleteMany::deserialize(BsonDeserializer::new(Bson::Document(definition.arguments))) + .map(|op| Box::new(op) as Box) + } + "deleteOne" => { + DeleteOne::deserialize(BsonDeserializer::new(Bson::Document(definition.arguments))) + .map(|op| Box::new(op) as Box) + } + "find" => { + Find::deserialize(BsonDeserializer::new(Bson::Document(definition.arguments))) + .map(|op| Box::new(op) as Box) + } + "aggregate" => { + Aggregate::deserialize(BsonDeserializer::new(Bson::Document(definition.arguments))) + .map(|op| Box::new(op) as Box) + } + "distinct" => { + Distinct::deserialize(BsonDeserializer::new(Bson::Document(definition.arguments))) + .map(|op| Box::new(op) as Box) + } + "countDocuments" => CountDocuments::deserialize(BsonDeserializer::new(Bson::Document( + definition.arguments, + ))) + .map(|op| Box::new(op) as Box), + "estimatedDocumentCount" => EstimatedDocumentCount::deserialize(BsonDeserializer::new( + Bson::Document(definition.arguments), + )) + .map(|op| Box::new(op) as Box), + "findOne" => { + FindOne::deserialize(BsonDeserializer::new(Bson::Document(definition.arguments))) + .map(|op| Box::new(op) as Box) + } + 
"listCollections" => ListCollections::deserialize(BsonDeserializer::new( + Bson::Document(definition.arguments), + )) + .map(|op| Box::new(op) as Box), + "listCollectionNames" => ListCollectionNames::deserialize(BsonDeserializer::new( + Bson::Document(definition.arguments), + )) + .map(|op| Box::new(op) as Box), + "replaceOne" => { + ReplaceOne::deserialize(BsonDeserializer::new(Bson::Document(definition.arguments))) + .map(|op| Box::new(op) as Box) + } + "findOneAndUpdate" => FindOneAndUpdate::deserialize(BsonDeserializer::new( + Bson::Document(definition.arguments), + )) + .map(|op| Box::new(op) as Box), + "findOneAndReplace" => FindOneAndReplace::deserialize(BsonDeserializer::new( + Bson::Document(definition.arguments), + )) + .map(|op| Box::new(op) as Box), + "findOneAndDelete" => FindOneAndDelete::deserialize(BsonDeserializer::new( + Bson::Document(definition.arguments), + )) + .map(|op| Box::new(op) as Box), + "listDatabases" => ListDatabases::deserialize(BsonDeserializer::new(Bson::Document( + definition.arguments, + ))) + .map(|op| Box::new(op) as Box), + "listDatabaseNames" => ListDatabaseNames::deserialize(BsonDeserializer::new( + Bson::Document(definition.arguments), + )) + .map(|op| Box::new(op) as Box), + _ => Ok(Box::new(UnimplementedOperation) as Box), + } + .map_err(|e| serde::de::Error::custom(format!("{}", e)))?; + + Ok(Operation { + operation: boxed_op, + name: definition.name, + object: definition.object, + collection_options: definition.collection_options, + database_options: definition.database_options, + error: definition.error, + result: definition.result, + session, + }) + } +} + +impl Deref for Operation { + type Target = Box; + + fn deref(&self) -> &Box { + &self.operation + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct DeleteMany { + filter: Document, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for DeleteMany { + async fn execute_on_collection( + &self, + collection: &Collection, + session: 
Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + collection + .delete_many_with_session(self.filter.clone(), self.options.clone(), session) + .await? + } + None => { + collection + .delete_many(self.filter.clone(), self.options.clone()) + .await? + } + }; + let result = bson::to_bson(&result)?; + Ok(Some(result)) + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct DeleteOne { + filter: Document, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for DeleteOne { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + collection + .delete_one_with_session(self.filter.clone(), self.options.clone(), session) + .await? + } + None => { + collection + .delete_one(self.filter.clone(), self.options.clone()) + .await? 
+ } + }; + let result = bson::to_bson(&result)?; + Ok(Some(result)) + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Default, Deserialize)] +pub(super) struct Find { + filter: Option, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for Find { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + let mut cursor = collection + .find_with_session(self.filter.clone(), self.options.clone(), session) + .await?; + cursor + .with_session(session) + .try_collect::>() + .await? + } + None => { + let cursor = collection + .find(self.filter.clone(), self.options.clone()) + .await?; + cursor.try_collect::>().await? + } + }; + Ok(Some(Bson::from(result))) + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct InsertMany { + documents: Vec, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for InsertMany { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + collection + .insert_many_with_session(self.documents.clone(), self.options.clone(), session) + .await? 
+ } + None => { + collection + .insert_many(self.documents.clone(), self.options.clone()) + .await? + } + }; + let ids: HashMap = result + .inserted_ids + .into_iter() + .map(|(k, v)| (k.to_string(), v)) + .collect(); + let ids = bson::to_bson(&ids)?; + Ok(Some(Bson::from(doc! { "insertedIds": ids }))) + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct InsertOne { + document: Document, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for InsertOne { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + collection + .insert_one_with_session(self.document.clone(), self.options.clone(), session) + .await? + } + None => { + collection + .insert_one(self.document.clone(), self.options.clone()) + .await? 
+ } + }; + let result = bson::to_bson(&result)?; + Ok(Some(result)) + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct UpdateMany { + filter: Document, + update: UpdateModifications, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for UpdateMany { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + collection + .update_many_with_session( + self.filter.clone(), + self.update.clone(), + self.options.clone(), + session, + ) + .await? + } + None => { + collection + .update_many( + self.filter.clone(), + self.update.clone(), + self.options.clone(), + ) + .await? + } + }; + let result = bson::to_bson(&result)?; + Ok(Some(result)) + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct UpdateOne { + filter: Document, + update: UpdateModifications, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for UpdateOne { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + collection + .update_one_with_session( + self.filter.clone(), + self.update.clone(), + self.options.clone(), + session, + ) + .await? 
+ } + None => { + collection + .update_one( + self.filter.clone(), + self.update.clone(), + self.options.clone(), + ) + .await? + } + }; + let result = bson::to_bson(&result)?; + Ok(Some(result)) + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(super) struct Aggregate { + pipeline: Vec, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for Aggregate { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + let mut cursor = collection + .aggregate_with_session(self.pipeline.clone(), self.options.clone(), session) + .await?; + cursor + .with_session(session) + .try_collect::>() + .await? + } + None => { + let cursor = collection + .aggregate(self.pipeline.clone(), self.options.clone()) + .await?; + cursor.try_collect::>().await? + } + }; + Ok(Some(Bson::from(result))) + } + + async fn execute_on_database( + &self, + database: &Database, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + let mut cursor = database + .aggregate_with_session(self.pipeline.clone(), self.options.clone(), session) + .await?; + cursor + .with_session(session) + .try_collect::>() + .await? + } + None => { + let cursor = database + .aggregate(self.pipeline.clone(), self.options.clone()) + .await?; + cursor.try_collect::>().await? 
+ } + }; + + Ok(Some(Bson::from(result))) + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(super) struct Distinct { + field_name: String, + filter: Option, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for Distinct { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + collection + .distinct_with_session( + &self.field_name, + self.filter.clone(), + self.options.clone(), + session, + ) + .await? + } + None => { + collection + .distinct(&self.field_name, self.filter.clone(), self.options.clone()) + .await? + } + }; + Ok(Some(Bson::Array(result))) + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct CountDocuments { + filter: Document, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for CountDocuments { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + collection + .count_documents_with_session( + self.filter.clone(), + self.options.clone(), + session, + ) + .await? + } + None => { + collection + .count_documents(self.filter.clone(), self.options.clone()) + .await? 
+ } + }; + Ok(Some(Bson::from(result))) + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct EstimatedDocumentCount { + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for EstimatedDocumentCount { + async fn execute_on_collection( + &self, + collection: &Collection, + _session: Option<&mut ClientSession>, + ) -> Result> { + let result = collection + .estimated_document_count(self.options.clone()) + .await?; + Ok(Some(Bson::from(result))) + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Default, Deserialize)] +pub(super) struct FindOne { + filter: Option, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for FindOne { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + collection + .find_one_with_session(self.filter.clone(), self.options.clone(), session) + .await? + } + None => { + collection + .find_one(self.filter.clone(), self.options.clone()) + .await? 
+ } + }; + match result { + Some(result) => Ok(Some(Bson::from(result))), + None => Ok(None), + } + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct ListCollections { + filter: Option, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for ListCollections { + async fn execute_on_collection( + &self, + _collection: &Collection, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_database( + &self, + database: &Database, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + let mut cursor = database + .list_collections_with_session( + self.filter.clone(), + self.options.clone(), + session, + ) + .await?; + cursor + .with_session(session) + .try_collect::>() + .await? + } + None => { + let cursor = database + .list_collections(self.filter.clone(), self.options.clone()) + .await?; + cursor.try_collect::>().await? 
+ } + }; + Ok(Some(Bson::from(result))) + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct ListCollectionNames { + filter: Option, +} + +#[async_trait] +impl TestOperation for ListCollectionNames { + async fn execute_on_collection( + &self, + _collection: &Collection, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_database( + &self, + database: &Database, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + database + .list_collection_names_with_session(self.filter.clone(), session) + .await? + } + None => database.list_collection_names(self.filter.clone()).await?, + }; + let result: Vec = result.iter().map(|s| Bson::String(s.to_string())).collect(); + Ok(Some(Bson::from(result))) + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct ReplaceOne { + filter: Document, + replacement: Document, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for ReplaceOne { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + collection + .replace_one_with_session( + self.filter.clone(), + self.replacement.clone(), + self.options.clone(), + session, + ) + .await? + } + None => { + collection + .replace_one( + self.filter.clone(), + self.replacement.clone(), + self.options.clone(), + ) + .await? 
+ } + }; + let result = bson::to_bson(&result)?; + Ok(Some(result)) + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct FindOneAndUpdate { + filter: Document, + update: UpdateModifications, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for FindOneAndUpdate { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + collection + .find_one_and_update_with_session( + self.filter.clone(), + self.update.clone(), + self.options.clone(), + session, + ) + .await? + } + None => { + collection + .find_one_and_update( + self.filter.clone(), + self.update.clone(), + self.options.clone(), + ) + .await? 
+ } + }; + let result = bson::to_bson(&result)?; + Ok(Some(result)) + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct FindOneAndReplace { + filter: Document, + replacement: Document, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for FindOneAndReplace { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + collection + .find_one_and_replace_with_session( + self.filter.clone(), + self.replacement.clone(), + self.options.clone(), + session, + ) + .await? + } + None => { + collection + .find_one_and_replace( + self.filter.clone(), + self.replacement.clone(), + self.options.clone(), + ) + .await? 
+ } + }; + let result = bson::to_bson(&result)?; + Ok(Some(result)) + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct FindOneAndDelete { + filter: Document, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for FindOneAndDelete { + async fn execute_on_collection( + &self, + collection: &Collection, + session: Option<&mut ClientSession>, + ) -> Result> { + let result = match session { + Some(session) => { + collection + .find_one_and_delete_with_session( + self.filter.clone(), + self.options.clone(), + session, + ) + .await? + } + None => { + collection + .find_one_and_delete(self.filter.clone(), self.options.clone()) + .await? 
+ } + }; + let result = bson::to_bson(&result)?; + Ok(Some(result)) + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct ListDatabases { + filter: Option, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for ListDatabases { + async fn execute_on_collection( + &self, + _collection: &Collection, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, client: &EventClient) -> Result> { + let result = client + .list_databases(self.filter.clone(), self.options.clone()) + .await?; + let result: Vec = result.iter().map(Bson::from).collect(); + Ok(Some(Bson::Array(result))) + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct ListDatabaseNames { + filter: Option, + #[serde(flatten)] + options: Option, +} + +#[async_trait] +impl TestOperation for ListDatabaseNames { + async fn execute_on_collection( + &self, + _collection: &Collection, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, client: &EventClient) -> Result> { + let result = client + .list_database_names(self.filter.clone(), self.options.clone()) + .await?; + let result: Vec = result.iter().map(|s| Bson::String(s.to_string())).collect(); + 
Ok(Some(Bson::Array(result))) + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} + +#[derive(Debug, Deserialize)] +pub(super) struct UnimplementedOperation; + +#[async_trait] +impl TestOperation for UnimplementedOperation { + async fn execute_on_collection( + &self, + _collection: &Collection, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_database( + &self, + _database: &Database, + _session: Option<&mut ClientSession>, + ) -> Result> { + unimplemented!() + } + + async fn execute_on_client(&self, _client: &EventClient) -> Result> { + unimplemented!() + } + + async fn execute_on_session(&self, _session: &ClientSession) { + unimplemented!() + } +} diff --git a/src/test/spec/v2_runner/test_event.rs b/src/test/spec/v2_runner/test_event.rs new file mode 100644 index 000000000..f410a7f78 --- /dev/null +++ b/src/test/spec/v2_runner/test_event.rs @@ -0,0 +1,50 @@ +use crate::{bson::Document, event, test::Matchable}; +use bson::Bson; +use serde::Deserialize; + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct CommandStartedEvent { + command_name: Option, + database_name: Option, + command: Document, +} + +impl CommandStartedEvent { + pub fn matches_expected( + &self, + expected: &CommandStartedEvent, + session0_lsid: &Document, + session1_lsid: &Document, + ) -> bool { + if expected.command_name.is_some() && self.command_name != expected.command_name { + return false; + } + if expected.database_name.is_some() && self.database_name != expected.database_name { + return false; + } + let mut expected = expected.command.clone(); + if let Some(Bson::String(session)) = expected.remove("lsid") { + match session.as_str() { + "session0" => { + expected.insert("lsid", session0_lsid.clone()); + } + "session1" => { + expected.insert("lsid", session1_lsid.clone()); + } + other => panic!("unknown session name: {}", other), + } + } + 
self.command.content_matches(&expected) + } +} + +impl From for CommandStartedEvent { + fn from(event: event::command::CommandStartedEvent) -> Self { + CommandStartedEvent { + command_name: Some(event.command_name), + database_name: Some(event.db), + command: event.command, + } + } +} diff --git a/src/test/spec/runner/test_file.rs b/src/test/spec/v2_runner/test_file.rs similarity index 65% rename from src/test/spec/runner/test_file.rs rename to src/test/spec/v2_runner/test_file.rs index df54e1d16..9d01d4a6e 100644 --- a/src/test/spec/runner/test_file.rs +++ b/src/test/spec/v2_runner/test_file.rs @@ -1,16 +1,21 @@ use std::collections::HashMap; -use futures::stream::TryStreamExt; +use bson::{doc, from_document}; +use futures::TryStreamExt; use semver::VersionReq; -use serde::Deserialize; +use serde::{Deserialize, Deserializer}; use crate::{ - bson::{doc, Document}, - options::{ClientOptions, FindOptions}, - test::{util::EventClient, AnyTestOperation, TestEvent}, + bson::Document, + client::options::ClientOptions, + options::{FindOptions, SessionOptions}, + test::{EventClient, FailPoint, TestClient}, }; -#[derive(Debug, Deserialize)] +use super::{operation::Operation, test_event::CommandStartedEvent}; + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] pub struct TestFile { #[serde(rename = "runOn")] pub run_on: Option>, @@ -18,14 +23,7 @@ pub struct TestFile { pub collection_name: Option, pub bucket_name: Option, pub data: Option, - pub tests: Vec, -} - -#[derive(Debug, Deserialize)] -#[serde(untagged)] -pub enum TestData { - Single(Vec), - Many(HashMap>), + pub tests: Vec, } #[derive(Debug, Deserialize)] @@ -37,7 +35,7 @@ pub struct RunOn { } impl RunOn { - pub fn can_run_on(&self, client: &EventClient) -> bool { + pub fn can_run_on(&self, client: &TestClient) -> bool { if let Some(ref min_version) = self.min_server_version { let req = VersionReq::parse(&format!(">= {}", &min_version)).unwrap(); if !req.matches(&client.server_version.as_ref().unwrap()) { @@ 
-51,7 +49,7 @@ impl RunOn { } } if let Some(ref topology) = self.topology { - if !topology.contains(&client.topology()) { + if !topology.contains(&client.topology_string()) { return false; } } @@ -60,16 +58,25 @@ impl RunOn { } #[derive(Debug, Deserialize)] +#[serde(untagged)] +pub enum TestData { + Single(Vec), + Many(HashMap>), +} + +#[derive(Deserialize)] #[serde(rename_all = "camelCase")] -pub struct TestCase { +pub struct Test { pub description: String, - pub client_options: Option, - pub use_multiple_mongoses: Option, pub skip_reason: Option, - pub fail_point: Option, - pub operations: Vec, + pub use_multiple_mongoses: Option, + pub client_options: Option, + pub fail_point: Option, + pub session_options: Option>, + pub operations: Vec, + #[serde(default, deserialize_with = "deserialize_command_started_events")] + pub expectations: Option>, pub outcome: Option, - pub expectations: Option>, } #[derive(Debug, Deserialize)] @@ -107,11 +114,19 @@ pub struct CollectionOutcome { pub data: Vec, } -#[derive(Debug, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum OperationObject { - Database, - Collection, - Client, - GridfsBucket, +fn deserialize_command_started_events<'de, D>( + deserializer: D, +) -> std::result::Result>, D::Error> +where + D: Deserializer<'de>, +{ + let docs = Vec::::deserialize(deserializer)?; + Ok(Some( + docs.iter() + .map(|doc| { + let event = doc.get_document("command_started_event").unwrap(); + from_document(event.clone()).unwrap() + }) + .collect(), + )) } diff --git a/src/test/util/event.rs b/src/test/util/event.rs index f768d5e2b..9c9068119 100644 --- a/src/test/util/event.rs +++ b/src/test/util/event.rs @@ -225,13 +225,13 @@ impl EventClient { } pub async fn with_additional_options( - options: Option, + options: impl Into>, heartbeat_freq: Option, use_multiple_mongoses: Option, event_handler: impl Into>, collect_server_info: bool, ) -> Self { - let mut options = match options { + let mut options = match options.into() { 
Some(mut options) => { options.merge(CLIENT_OPTIONS.clone()); options @@ -308,17 +308,7 @@ impl EventClient { panic!("could not find event for {} command", command_name); } - pub fn topology(&self) -> String { - if self.client.is_sharded() { - String::from("sharded") - } else if self.client.is_replica_set() { - String::from("replicaset") - } else { - String::from("single") - } - } - - /// Gets all of the command started events for a specified command names. + /// Gets all of the command started events for the specified command names. pub fn get_command_started_events(&self, command_names: &[&str]) -> Vec { let events = self.handler.command_events.read().unwrap(); events @@ -336,6 +326,22 @@ impl EventClient { .collect() } + /// Gets all command started events, excluding configureFailPoint events. + pub fn get_all_command_started_events(&self) -> Vec { + let events = self.handler.command_events.read().unwrap(); + events + .iter() + .filter_map(|event| match event { + CommandEvent::CommandStartedEvent(event) + if event.command_name != "configureFailPoint" => + { + Some(event.clone()) + } + _ => None, + }) + .collect() + } + /// Gets a list of all of the events of the requested event types that occurred on this client. /// Ignores any event with a name in the ignore list. Also ignores all configureFailPoint /// events. 
diff --git a/src/test/util/failpoint.rs b/src/test/util/failpoint.rs index df4acaf30..5aeccff23 100644 --- a/src/test/util/failpoint.rs +++ b/src/test/util/failpoint.rs @@ -43,7 +43,7 @@ impl FailPoint { FailPoint { command } } - pub(super) async fn enable( + pub async fn enable( self, client: &TestClient, criteria: impl Into>, diff --git a/src/test/util/matchable.rs b/src/test/util/matchable.rs index bd8dff7b9..f0c4b7aef 100644 --- a/src/test/util/matchable.rs +++ b/src/test/util/matchable.rs @@ -62,12 +62,17 @@ impl Matchable for Document { if k == "upsertedCount" { continue; } - if let Some(actual_v) = self.get(k) { - if !actual_v.matches(v) { - return false; + match self.get(k) { + Some(actual_v) => { + if !actual_v.matches(v) { + return false; + } + } + None => { + if v != &Bson::Null { + return false; + } } - } else { - return false; } } true diff --git a/src/test/util/mod.rs b/src/test/util/mod.rs index f865624a3..241c2519c 100644 --- a/src/test/util/mod.rs +++ b/src/test/util/mod.rs @@ -73,7 +73,7 @@ impl TestClient { // To avoid populating the session pool with leftover implicit sessions, we check out a // session here and immediately mark it as dirty, then use it with any operations we need. let mut session = client - .start_implicit_session_with_timeout(Duration::from_secs(60 * 60)) + .start_session_with_timeout(Duration::from_secs(60 * 60), None, true) .await; session.mark_dirty(); @@ -81,7 +81,7 @@ impl TestClient { let server_info = bson::from_bson(Bson::Document( client - .execute_operation_with_session(is_master, &mut session) + .execute_operation(is_master, &mut session) .await .unwrap(), )) @@ -95,7 +95,7 @@ impl TestClient { RunCommand::new("test".into(), doc! { "buildInfo": 1 }, None).unwrap(); let response = client - .execute_operation_with_session(build_info, &mut session) + .execute_operation(build_info, &mut session) .await .unwrap(); @@ -109,10 +109,7 @@ impl TestClient { // The command above may fail due to insufficient permissions. 
In that case, the unified // test runner will skip any tests with a serverParameters runOnRequirement as the check // will fail. - if let Ok(response) = client - .execute_operation_with_session(get_parameters, &mut session) - .await - { + if let Ok(response) = client.execute_operation(get_parameters, &mut session).await { server_parameters = Some(Bson::Document(response)); } } @@ -339,6 +336,16 @@ impl TestClient { Topology::Single } } + + pub fn topology_string(&self) -> String { + if self.is_sharded() { + "sharded".to_string() + } else if self.is_replica_set() { + "replicaset".to_string() + } else { + "single".to_string() + } + } } pub async fn drop_collection(coll: &Collection) @@ -385,7 +392,7 @@ pub struct IsMasterCommandResponse { pub primary: Option, } -pub fn get_db_name(description: &str) -> String { +pub fn get_default_name(description: &str) -> String { let mut db_name = description.replace('$', "%").replace(' ', "_"); // database names must have fewer than 64 characters db_name.truncate(63);