diff --git a/src/conn/mod.rs b/src/conn/mod.rs index b4a29712..30ccc81f 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -30,20 +30,22 @@ use std::{ use crate::{ conn::{pool::Pool, stmt_cache::StmtCache}, - connection_like::{streamless::Streamless, ConnectionLike, StmtCacheResult}, + connection_like::{ConnectionLike, StmtCacheResult}, consts::{self, CapabilityFlags}, error::*, io::Stream, local_infile_handler::LocalInfileHandler, opts::Opts, - queryable::{query_result, BinaryProtocol, Queryable, TextProtocol}, + queryable::{ + query_result::QueryResult, transaction::TxStatus, BinaryProtocol, Queryable, TextProtocol, + }, Column, OptsBuilder, }; pub mod pool; pub mod stmt_cache; -/// Helper that asynchronously disconnects connection on the default tokio executor. +/// Helper that asynchronously disconnects the givent connection on the default tokio executor. fn disconnect(mut conn: Conn) { let disconnected = conn.inner.disconnected; @@ -87,7 +89,7 @@ struct ConnInner { last_ok_packet: Option>, pool: Option, has_result: Option, - in_transaction: bool, + tx_status: TxStatus, opts: Opts, last_io: Instant, wait_timeout: Duration, @@ -106,7 +108,7 @@ impl fmt::Debug for ConnInner { .field("server version", &self.version) .field("pool", &self.pool) .field("has result", &self.has_result.is_some()) - .field("in transaction", &self.in_transaction) + .field("tx_status", &self.tx_status) .field("stream", &self.stream) .field("options", &self.opts) .finish() @@ -126,7 +128,7 @@ impl ConnInner { id: 0, has_result: None, pool: None, - in_transaction: false, + tx_status: TxStatus::None, last_io: Instant::now(), wait_timeout: Duration::from_secs(0), stmt_cache: StmtCache::new(opts.get_stmt_cache_size()), @@ -146,6 +148,11 @@ pub struct Conn { } impl Conn { + /// Returns connection identifier. + pub fn connection_id(&self) -> u32 { + self.inner.id + } + /// Returns the ID generated by a query (usually `INSERT`) on a table with a column having the /// `AUTO_INCREMENT` attribute. Returns `None` if there was no previous query on the connection /// or if the query did not update an AUTO_INCREMENT value. @@ -159,8 +166,23 @@ impl Conn { self.get_affected_rows() } - async fn close(self) -> Result<()> { - self.cleanup().await?.disconnect().await + fn take_stream(&mut self) -> Stream { + self.inner.stream.take().unwrap() + } + + /// Disconnects this connection from server. + pub async fn disconnect(mut self) -> Result<()> { + self.on_disconnect(); + self.write_command_data(crate::consts::Command::COM_QUIT, &[]) + .await?; + let stream = self.take_stream(); + stream.close().await?; + Ok(()) + } + + async fn close(mut self) -> Result<()> { + self = self.cleanup().await?; + self.disconnect().await } fn is_secure(&self) -> bool { @@ -173,10 +195,7 @@ impl Conn { /// Hacky way to move connection through &mut. `self` becomes unusable. fn take(&mut self) -> Conn { - let inner = mem::replace(&mut *self.inner, ConnInner::empty(Default::default())); - Conn { - inner: Box::new(inner), - } + mem::replace(self, Conn::empty(Default::default())) } fn empty(opts: Opts) -> Self { @@ -185,31 +204,32 @@ impl Conn { } } - fn setup_stream(mut self) -> Result { - if let Some(stream) = self.inner.stream.take() { + /// Set `io::Stream` options as defined in the `Opts` of the connection. + /// + /// Requires that self.inner.stream is Some + fn setup_stream(&mut self) -> Result<()> { + debug_assert!(self.inner.stream.is_some()); + if let Some(stream) = self.inner.stream.as_mut() { stream.set_keepalive_ms(self.inner.opts.get_tcp_keepalive())?; stream.set_tcp_nodelay(self.inner.opts.get_tcp_nodelay())?; - self.inner.stream = Some(stream); - Ok(self) - } else { - unreachable!(); } + Ok(()) } - async fn handle_handshake(self) -> Result { - let (mut conn, packet) = self.read_packet().await?; + async fn handle_handshake(&mut self) -> Result<()> { + let packet = self.read_packet().await?; let handshake = parse_handshake_packet(&*packet)?; - conn.inner.nonce = { + self.inner.nonce = { let mut nonce = Vec::from(handshake.scramble_1_ref()); nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..])); nonce }; - conn.inner.capabilities = handshake.capabilities() & conn.inner.opts.get_capabilities(); - conn.inner.version = handshake.server_version_parsed().unwrap_or((0, 0, 0)); - conn.inner.id = handshake.connection_id(); - conn.inner.status = handshake.status_flags(); - conn.inner.auth_plugin = match handshake.auth_plugin() { + self.inner.capabilities = handshake.capabilities() & self.inner.opts.get_capabilities(); + self.inner.version = handshake.server_version_parsed().unwrap_or((0, 0, 0)); + self.inner.id = handshake.connection_id(); + self.inner.status = handshake.status_flags(); + self.inner.auth_plugin = match handshake.auth_plugin() { Some(AuthPlugin::MysqlNativePassword) => AuthPlugin::MysqlNativePassword, Some(AuthPlugin::CachingSha2Password) => AuthPlugin::CachingSha2Password, Some(AuthPlugin::Other(ref name)) => { @@ -218,10 +238,10 @@ impl Conn { } None => AuthPlugin::MysqlNativePassword, }; - Ok(conn) + Ok(()) } - async fn switch_to_ssl_if_needed(self) -> Result { + async fn switch_to_ssl_if_needed(&mut self) -> Result<()> { if self .inner .opts @@ -229,22 +249,22 @@ impl Conn { .contains(CapabilityFlags::CLIENT_SSL) { let ssl_request = SslRequest::new(self.inner.capabilities); - let conn = self.write_packet(ssl_request.as_ref()).await?; + self.write_packet(ssl_request.as_ref()).await?; + let conn = self; let ssl_opts = conn .get_opts() .get_ssl_opts() .cloned() .expect("unreachable"); let domain = conn.get_opts().get_ip_or_hostname().into(); - let (streamless, stream) = conn.take_stream(); - let stream = stream.make_secure(domain, ssl_opts).await?; - Ok(streamless.return_stream(stream)) + conn.stream_mut().make_secure(domain, ssl_opts).await?; + Ok(()) } else { - Ok(self) + Ok(()) } } - async fn do_handshake_response(self) -> Result { + async fn do_handshake_response(&mut self) -> Result<()> { let auth_data = self .inner .auth_plugin @@ -260,13 +280,14 @@ impl Conn { &Default::default(), // TODO: Add support ); - self.write_packet(handshake_response.as_ref()).await + self.write_packet(handshake_response.as_ref()).await?; + Ok(()) } async fn perform_auth_switch( - mut self, + &mut self, auth_switch_request: AuthSwitchRequest<'_>, - ) -> Result { + ) -> Result<()> { if !self.inner.auth_switched { self.inner.auth_switched = true; self.inner.nonce = auth_switch_request.plugin_data().into(); @@ -276,19 +297,27 @@ impl Conn { .auth_plugin .gen_data(self.inner.opts.get_pass(), &*self.inner.nonce) .unwrap_or_else(Vec::new); - self.write_packet(plugin_data).await?.continue_auth().await + self.write_packet(plugin_data).await?; + self.continue_auth().await?; + Ok(()) } else { unreachable!("auth_switched flag should be checked by caller") } } - fn continue_auth(self) -> Pin> + Send>> { + fn continue_auth(&mut self) -> Pin> + Send + '_>> { // NOTE: we need to box this since it may recurse // see https://github.com/rust-lang/rust/issues/46415#issuecomment-528099782 Box::pin(async move { match self.inner.auth_plugin { - AuthPlugin::MysqlNativePassword => self.continue_mysql_native_password_auth().await, - AuthPlugin::CachingSha2Password => self.continue_caching_sha2_password_auth().await, + AuthPlugin::MysqlNativePassword => { + self.continue_mysql_native_password_auth().await?; + Ok(()) + } + AuthPlugin::CachingSha2Password => { + self.continue_caching_sha2_password_auth().await?; + Ok(()) + } AuthPlugin::Other(ref name) => Err(DriverError::UnknownAuthPlugin { name: String::from_utf8_lossy(name.as_ref()).to_string(), })?, @@ -296,7 +325,7 @@ impl Conn { }) } - fn switch_to_compression(mut self) -> Result { + fn switch_to_compression(&mut self) -> Result<()> { if self .get_capabilities() .contains(CapabilityFlags::CLIENT_COMPRESS) @@ -307,23 +336,23 @@ impl Conn { } } } - Ok(self) + Ok(()) } - async fn continue_caching_sha2_password_auth(self) -> Result { - let (conn, packet) = self.read_packet().await?; + async fn continue_caching_sha2_password_auth(&mut self) -> Result<()> { + let packet = self.read_packet().await?; match packet.get(0) { Some(0x00) => { // ok packet for empty password - Ok(conn) + Ok(()) } Some(0x01) => match packet.get(1) { Some(0x03) => { // auth ok - conn.drop_packet().await + self.drop_packet().await } Some(0x04) => { - let mut pass = conn + let mut pass = self .inner .opts .get_pass() @@ -331,28 +360,30 @@ impl Conn { .unwrap_or_default(); pass.push(0); - let conn = if conn.is_secure() { - conn.write_packet(&*pass).await? + if self.is_secure() { + self.write_packet(&*pass).await?; } else { - let conn = conn.write_packet(&[0x02][..]).await?; - let (conn, packet) = conn.read_packet().await?; + self.write_packet(&[0x02][..]).await?; + let packet = self.read_packet().await?; let key = &packet[1..]; for (i, byte) in pass.iter_mut().enumerate() { - *byte ^= conn.inner.nonce[i % conn.inner.nonce.len()]; + *byte ^= self.inner.nonce[i % self.inner.nonce.len()]; } let encrypted_pass = crypto::encrypt(&*pass, key); - conn.write_packet(&*encrypted_pass).await? + self.write_packet(&*encrypted_pass).await?; }; - conn.drop_packet().await + self.drop_packet().await?; + Ok(()) } _ => Err(DriverError::UnexpectedPacket { payload: packet.into(), } .into()), }, - Some(0xfe) if !conn.inner.auth_switched => { + Some(0xfe) if !self.inner.auth_switched => { let auth_switch_request = parse_auth_switch_request(&*packet)?.into_owned(); - conn.perform_auth_switch(auth_switch_request).await + self.perform_auth_switch(auth_switch_request).await?; + Ok(()) } _ => Err(DriverError::UnexpectedPacket { payload: packet.into(), @@ -361,130 +392,119 @@ impl Conn { } } - async fn continue_mysql_native_password_auth(self) -> Result { - let (this, packet) = self.read_packet().await?; + async fn continue_mysql_native_password_auth(&mut self) -> Result<()> { + let packet = self.read_packet().await?; match packet.get(0) { - Some(0x00) => Ok(this), - Some(0xfe) if !this.inner.auth_switched => { + Some(0x00) => Ok(()), + Some(0xfe) if !self.inner.auth_switched => { let auth_switch_request = parse_auth_switch_request(packet.as_ref())?.into_owned(); - this.perform_auth_switch(auth_switch_request).await + self.perform_auth_switch(auth_switch_request).await?; + Ok(()) } _ => Err(DriverError::UnexpectedPacket { payload: packet }.into()), } } - async fn drop_packet(self) -> Result { - Ok(self.read_packet().await?.0) + async fn drop_packet(&mut self) -> Result<()> { + self.read_packet().await?; + Ok(()) } - async fn run_init_commands(self) -> Result { + async fn run_init_commands(&mut self) -> Result<()> { let mut init: Vec<_> = self.inner.opts.get_init().iter().cloned().collect(); - let mut conn = self; while let Some(query) = init.pop() { - conn = conn.drop_query(query).await?; + self.drop_query(query).await?; } - Ok(conn) + + Ok(()) } - /// Returns future that resolves to [`Conn`]. - pub async fn new>(opts: T) -> Result { + /// Returns a future that resolves to [`Conn`]. + pub fn new>(opts: T) -> crate::BoxFuture<'static, Conn> { let opts = opts.into(); - let mut conn = Conn::empty(opts.clone()); + Box::pin(async move { + let mut conn = Conn::empty(opts.clone()); - let stream = if let Some(path) = opts.get_socket() { - Stream::connect_socket(path.to_owned()).await? - } else { - Stream::connect_tcp(opts.get_hostport_or_url()).await? - }; + let stream = if let Some(path) = opts.get_socket() { + Stream::connect_socket(path.to_owned()).await? + } else { + Stream::connect_tcp(opts.get_hostport_or_url()).await? + }; - conn.inner.stream = Some(stream); - conn.setup_stream()? - .handle_handshake() - .await? - .switch_to_ssl_if_needed() - .await? - .do_handshake_response() - .await? - .continue_auth() - .await? - .switch_to_compression()? - .read_socket() - .await? - .reconnect_via_socket_if_needed() - .await? - .read_max_allowed_packet() - .await? - .read_wait_timeout() - .await? - .run_init_commands() - .await + conn.inner.stream = Some(stream); + conn.setup_stream()?; + conn.handle_handshake().await?; + conn.switch_to_ssl_if_needed().await?; + conn.do_handshake_response().await?; + conn.continue_auth().await?; + conn.switch_to_compression()?; + conn.read_socket().await?; + conn.reconnect_via_socket_if_needed().await?; + conn.read_max_allowed_packet().await?; + conn.read_wait_timeout().await?; + conn.run_init_commands().await?; + Ok(conn) + }) } - /// Returns future that resolves to [`Conn`]. + /// Returns a future that resolves to [`Conn`]. pub async fn from_url>(url: T) -> Result { Conn::new(Opts::from_str(url.as_ref())?).await } - /// Will try to connect via socket using socket address in `self.inner.socket`. - /// - /// Returns new connection on success or self on error. + /// Will try to reconnect via socket using socket address in `self.inner.socket`. /// /// Won't try to reconnect if socket connection is already enforced in [`Opts`]. - fn reconnect_via_socket_if_needed(self) -> Pin> + Send>> { - // NOTE: we need to box this since it may recurse - // see https://github.com/rust-lang/rust/issues/46415#issuecomment-528099782 - Box::pin(async move { - if let Some(socket) = self.inner.socket.as_ref() { - let opts = self.inner.opts.clone(); - if opts.get_socket().is_none() { - let mut builder = OptsBuilder::from_opts(opts); - builder.socket(Some(&**socket)); - match Conn::new(builder).await { - Ok(conn) => { - // tidy up the old connection - self.close().await?; - return Ok(conn); - } - Err(_) => return Ok(self), + async fn reconnect_via_socket_if_needed(&mut self) -> Result<()> { + if let Some(socket) = self.inner.socket.as_ref() { + let opts = self.inner.opts.clone(); + if opts.get_socket().is_none() { + let mut builder = OptsBuilder::from_opts(opts); + builder.socket(Some(&**socket)); + match Conn::new(builder).await { + Ok(conn) => { + let old_conn = std::mem::replace(self, conn); + // tidy up the old connection + old_conn.close().await?; } + Err(_) => (), } } - Ok(self) - }) + } + Ok(()) } - /// Returns future that resolves to [`Conn`] with socket address stored in it. + /// Reads and stores socket address inside the connection. /// /// Do nothing if socket address is already in [`Opts`] or if `prefer_socket` is `false`. - async fn read_socket(self) -> Result { + async fn read_socket(&mut self) -> Result<()> { if self.inner.opts.get_prefer_socket() && self.inner.socket.is_none() { - let (mut this, row_opt) = self.first("SELECT @@socket").await?; - this.inner.socket = row_opt.unwrap_or((None,)).0; - Ok(this) - } else { - Ok(self) + let row_opt = self.first("SELECT @@socket").await?; + self.inner.socket = row_opt.unwrap_or((None,)).0; } + Ok(()) } - /// Returns future that resolves to [`Conn`] with `max_allowed_packet` stored in it. - async fn read_max_allowed_packet(self) -> Result { - let (mut this, row_opt): (Self, _) = self.first("SELECT @@max_allowed_packet").await?; - if let Some(stream) = this.inner.stream.as_mut() { + /// Reads and stores `max_allowed_packet` in the connection. + async fn read_max_allowed_packet(&mut self) -> Result<()> { + let row_opt = self.first("SELECT @@max_allowed_packet").await?; + if let Some(stream) = self.inner.stream.as_mut() { stream.set_max_allowed_packet(row_opt.unwrap_or((DEFAULT_MAX_ALLOWED_PACKET,)).0); } - Ok(this) + Ok(()) } - /// Returns future that resolves to [`Conn`] with `wait_timeout` stored in it. - async fn read_wait_timeout(self) -> Result { - let (mut this, row_opt) = self.first("SELECT @@wait_timeout").await?; + /// Reads and stores `wait_timeout` in the connection. + async fn read_wait_timeout(&mut self) -> Result<()> { + let row_opt = self.first("SELECT @@wait_timeout").await?; let wait_timeout_secs = row_opt.unwrap_or((28800,)).0; - this.inner.wait_timeout = Duration::from_secs(wait_timeout_secs); - Ok(this) + self.inner.wait_timeout = Duration::from_secs(wait_timeout_secs); + Ok(()) } - /// Returns true if time since last io exceeds wait_timeout (or conn_ttl if specified in opts). + /// Returns true if time since last IO exceeds `wait_timeout` + /// (or `conn_ttl` if specified in opts). fn expired(&self) -> bool { let ttl = self .inner @@ -494,82 +514,81 @@ impl Conn { self.idling() > ttl } - /// Returns duration since last io. + /// Returns duration since last IO. fn idling(&self) -> Duration { self.inner.last_io.elapsed() } - /// Returns future that resolves to a [`Conn`] with `COM_RESET_CONNECTION` executed on it. - pub async fn reset(self) -> Result { + /// Executes `COM_RESET_CONNECTION` on `self`. + /// + /// If server version is older than 5.7.2, then it'll reconnect. + pub async fn reset(&mut self) -> Result<()> { let pool = self.inner.pool.clone(); - let mut conn = if self.inner.version > (5, 7, 2) { + + if self.inner.version > (5, 7, 2) { self.write_command_data(consts::Command::COM_RESET_CONNECTION, &[]) - .await? - .read_packet() - .await? - .0 + .await?; + self.read_packet().await?; } else { let opts = self.inner.opts.clone(); + let old_conn = std::mem::replace(self, Conn::new(opts).await?); // tidy up the old connection - self.close().await?; - Conn::new(opts).await? + old_conn.close().await?; }; - conn.inner.stmt_cache.clear(); - conn.inner.pool = pool; - Ok(conn) + self.inner.stmt_cache.clear(); + self.inner.pool = pool; + Ok(()) } - async fn rollback_transaction(mut self) -> Result { - assert!(self.inner.in_transaction); - self.inner.in_transaction = false; + /// Requires that `self.inner.tx_status != TxStatus::None` + async fn rollback_transaction(&mut self) -> Result<()> { + debug_assert_ne!(self.inner.tx_status, TxStatus::None); + self.inner.tx_status = TxStatus::None; self.drop_query("ROLLBACK").await } - async fn drop_result(mut self) -> Result { + async fn drop_result(&mut self) -> Result<()> { match self.inner.has_result.take() { Some(PendingResult::Text(columns)) => { - query_result::assemble::<_, TextProtocol>(self, Some(columns), None) + QueryResult::<'_, _, TextProtocol>::new(self, Some(columns), None) .drop_result() - .await + .await?; + Ok(()) } Some(PendingResult::Binary(columns, cached)) => { - query_result::assemble::<_, BinaryProtocol>(self, Some(columns), Some(cached)) + QueryResult::<'_, _, BinaryProtocol>::new(self, Some(columns), Some(cached)) .drop_result() - .await + .await?; + Ok(()) } Some(PendingResult::Empty) => { - query_result::assemble::<_, TextProtocol>(self, None, None) + QueryResult::<'_, _, TextProtocol>::new(self, None, None) .drop_result() - .await + .await?; + Ok(()) } - None => Ok(self), + None => Ok(()), } } - fn cleanup(self) -> Pin> + Send>> { - // NOTE: we need to box this since it may recurse - // see https://github.com/rust-lang/rust/issues/46415#issuecomment-528099782 - Box::pin(async move { + async fn cleanup(mut self) -> Result { + loop { if self.inner.has_result.is_some() { - self.drop_result().await?.cleanup().await - } else if self.inner.in_transaction { - self.rollback_transaction().await?.cleanup().await + self.drop_result().await?; + } else if self.inner.tx_status != TxStatus::None { + self.rollback_transaction().await?; } else { - Ok(self) + break; } - }) + } + Ok(self) } } impl ConnectionLike for Conn { - fn take_stream(mut self) -> (Streamless, Stream) { - let stream = self.inner.stream.take().expect("Logic error: stream taken"); - (Streamless::new(self), stream) - } - - fn return_stream(&mut self, stream: Stream) { - self.inner.stream = Some(stream); + fn stream_mut(&mut self) -> &mut Stream { + self.inner.stream.as_mut().expect("Logic error: stream") } fn stmt_cache_ref(&self) -> &StmtCache { @@ -592,8 +611,8 @@ impl ConnectionLike for Conn { self.inner.capabilities } - fn get_in_transaction(&self) -> bool { - self.inner.in_transaction + fn get_tx_status(&self) -> TxStatus { + self.inner.tx_status } fn get_last_insert_id(&self) -> Option { @@ -647,8 +666,8 @@ impl ConnectionLike for Conn { self.inner.last_ok_packet = ok_packet; } - fn set_in_transaction(&mut self, in_transaction: bool) { - self.inner.in_transaction = in_transaction; + fn set_tx_status(&mut self, tx_status: TxStatus) { + self.inner.tx_status = tx_status; } fn set_pending_result(&mut self, meta: Option) { @@ -699,12 +718,14 @@ mod test { // no database name opts.db_name(None::); - let conn: Conn = Conn::new(opts.clone()).await?.ping().await?; + let mut conn: Conn = Conn::new(opts.clone()).await?; + conn.ping().await?; conn.disconnect().await?; // empty database name opts.db_name(Some("")); - let conn: Conn = Conn::new(opts).await?.ping().await?; + let mut conn: Conn = Conn::new(opts).await?; + conn.ping().await?; conn.disconnect().await?; Ok(()) @@ -712,35 +733,34 @@ mod test { #[tokio::test] async fn should_connect() -> super::Result<()> { - let conn: Conn = Conn::new(get_opts()).await?.ping().await?; + let mut conn: Conn = Conn::new(get_opts()).await?; + conn.ping().await?; - let (mut conn, plugins): (Conn, _) = conn - .query("SHOW PLUGINS") - .await? + let result = conn.query("SHOW PLUGINS").await?; + let plugins = result .map_and_drop(|mut row| row.take::("Name").unwrap()) .await?; // Should connect with any combination of supported plugin and empty-nonempty password. let variants = vec![ - ("caching_sha2_password", 2, "non-empty"), - ("caching_sha2_password", 2, ""), - ("mysql_native_password", 0, "non-empty"), - ("mysql_native_password", 0, ""), + ("caching_sha2_password", 2_u8, "non-empty"), + ("caching_sha2_password", 2_u8, ""), + ("mysql_native_password", 0_u8, "non-empty"), + ("mysql_native_password", 0_u8, ""), ] .into_iter() .filter(|variant| plugins.iter().any(|p| p == variant.0)); for (plug, val, pass) in variants { let query = format!("CREATE USER 'test_user'@'%' IDENTIFIED WITH {}", plug); - conn = conn.drop_query(query).await.unwrap(); + conn.drop_query(query).await.unwrap(); - conn = if (8, 0, 11) <= conn.inner.version && conn.inner.version <= (9, 0, 0) { + if (8, 0, 11) <= conn.inner.version && conn.inner.version <= (9, 0, 0) { conn.drop_query(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass)) .await - .unwrap() + .unwrap(); } else { - conn = conn - .drop_query(format!("SET old_passwords = {}", val)) + conn.drop_query(format!("SET old_passwords = {}", val)) .await .unwrap(); conn.drop_query(format!( @@ -748,7 +768,7 @@ mod test { pass )) .await - .unwrap() + .unwrap(); }; let mut opts = get_opts(); @@ -757,7 +777,7 @@ mod test { .db_name(None::); let result = Conn::new(opts).await; - conn = conn.drop_query("DROP USER 'test_user'@'%'").await.unwrap(); + conn.drop_query("DROP USER 'test_user'@'%'").await.unwrap(); result?.disconnect().await?; } @@ -788,8 +808,8 @@ mod test { async fn should_execute_init_queries_on_new_connection() -> super::Result<()> { let mut opts_builder = OptsBuilder::from_opts(get_opts()); opts_builder.init(vec!["SET @a = 42", "SET @b = 'foo'"]); - let (conn, result) = Conn::new(opts_builder) - .await? + let mut conn = Conn::new(opts_builder).await?; + let result = conn .query("SELECT @a, @b") .await? .collect_and_drop::<(u8, String)>() @@ -801,10 +821,10 @@ mod test { #[tokio::test] async fn should_reset_the_connection() -> super::Result<()> { - let conn = Conn::new(get_opts()).await?; - let conn = conn.drop_exec("SELECT ?", (1,)).await?; - let conn = conn.reset().await?; - let conn = conn.drop_exec("SELECT ?", (1,)).await?; + let mut conn = Conn::new(get_opts()).await?; + conn.drop_exec("SELECT ?", (1_u8,)).await?; + conn.reset().await?; + conn.drop_exec("SELECT ?", (1_u8,)).await?; conn.disconnect().await?; Ok(()) } @@ -813,16 +833,16 @@ mod test { async fn should_not_cache_statements_if_stmt_cache_size_is_zero() -> super::Result<()> { let mut opts = OptsBuilder::from_opts(get_opts()); opts.stmt_cache_size(0); - let conn = Conn::new(opts).await?; - let conn = conn.drop_exec("DO ?", (1,)).await?; - let stmt = conn.prepare("DO 2").await?; - let (stmt, _) = stmt.first::<_, (crate::Value,)>(()).await?; - let (stmt, _) = stmt.first::<_, (crate::Value,)>(()).await?; - let conn = stmt.close().await?; - let conn = conn.prep_exec("DO 3", ()).await?.drop_result().await?; - let conn = conn.batch_exec("DO 4", vec![(), ()]).await?; - let (conn, _) = conn.first_exec::<_, _, (u8,)>("DO 5", ()).await?; - let (conn, row) = conn + let mut conn = Conn::new(opts).await?; + conn.drop_exec("DO ?", (1_u8,)).await?; + let mut stmt = conn.prepare("DO 2").await?; + stmt.first::<_, crate::Value>(()).await?; + stmt.first::<_, crate::Value>(()).await?; + stmt.close().await?; + conn.prep_exec("DO 3", ()).await?.drop_result().await?; + conn.batch_exec("DO 4", vec![(), ()]).await?; + conn.first_exec::<_, _, (u8,)>("DO 5", ()).await?; + let row = conn .first("SHOW SESSION STATUS LIKE 'Com_stmt_close';") .await?; assert_eq!(from_row::<(String, usize)>(row.unwrap()).1, 5); @@ -836,25 +856,16 @@ mod test { let mut opts = OptsBuilder::from_opts(get_opts()); opts.stmt_cache_size(3); - let conn = Conn::new(opts) - .await? - .drop_exec("DO 1", ()) - .await? - .drop_exec("DO 2", ()) - .await? - .drop_exec("DO 3", ()) - .await? - .drop_exec("DO 1", ()) - .await? - .drop_exec("DO 4", ()) - .await? - .drop_exec("DO 3", ()) - .await? - .drop_exec("DO 5", ()) - .await? - .drop_exec("DO 6", ()) - .await?; - let (conn, row_opt) = conn + let mut conn = Conn::new(opts).await?; + conn.drop_exec("DO 1", ()).await?; + conn.drop_exec("DO 2", ()).await?; + conn.drop_exec("DO 3", ()).await?; + conn.drop_exec("DO 1", ()).await?; + conn.drop_exec("DO 4", ()).await?; + conn.drop_exec("DO 3", ()).await?; + conn.drop_exec("DO 5", ()).await?; + conn.drop_exec("DO 6", ()).await?; + let row_opt = conn .first("SHOW SESSION STATUS LIKE 'Com_stmt_close';") .await?; let (_, count): (String, usize) = row_opt.unwrap(); @@ -874,40 +885,37 @@ mod test { let long_string = ::std::iter::repeat('A') .take(18 * 1024 * 1024) .collect::(); - let conn = Conn::new(get_opts()).await?; + let mut conn = Conn::new(get_opts()).await?; let result = conn .query(format!(r"SELECT '{}', 231", long_string)) .await?; - let (conn, result) = result + let result = result .reduce_and_drop(vec![], move |mut acc, row| { acc.push(from_row(row)); acc }) .await?; conn.disconnect().await?; - assert_eq!((long_string, 231), result[0]); + assert_eq!((long_string, 231_u8), result[0]); Ok(()) } #[tokio::test] async fn should_drop_query() -> super::Result<()> { - let conn = Conn::new(get_opts()).await?; - let (conn, result) = conn - .drop_query("CREATE TEMPORARY TABLE tmp (id int DEFAULT 10, name text)") - .await? - .drop_query("INSERT INTO tmp VALUES (1, 'foo')") - .await? - .first::<_, (u8,)>("SELECT COUNT(*) FROM tmp") + let mut conn = Conn::new(get_opts()).await?; + conn.drop_query("CREATE TEMPORARY TABLE tmp (id int DEFAULT 10, name text)") .await?; + conn.drop_query("INSERT INTO tmp VALUES (1, 'foo')").await?; + let result = conn.first::<_, (u8,)>("SELECT COUNT(*) FROM tmp").await?; conn.disconnect().await?; - assert_eq!(result, Some((1,))); + assert_eq!(result, Some((1_u8,))); Ok(()) } #[tokio::test] async fn should_try_collect() -> super::Result<()> { - let conn = Conn::new(get_opts()).await?; - let result = conn + let mut conn = Conn::new(get_opts()).await?; + let mut result = conn .query( r"SELECT 'hello', 123 UNION ALL @@ -917,19 +925,19 @@ mod test { ", ) .await?; - let (result, mut rows) = result.try_collect::<(String, u8)>().await?; + let mut rows = result.try_collect::<(String, u8)>().await?; assert!(rows.pop().unwrap().is_ok()); assert!(rows.pop().unwrap().is_err()); assert!(rows.pop().unwrap().is_ok()); - let conn = result.drop_result().await?; + result.drop_result().await?; conn.disconnect().await?; Ok(()) } #[tokio::test] async fn should_try_collect_and_drop() -> super::Result<()> { - let conn = Conn::new(get_opts()).await?; - let (conn, mut rows) = conn + let mut conn = Conn::new(get_opts()).await?; + let mut rows = conn .query( r"SELECT 'hello', 123 UNION ALL @@ -951,8 +959,8 @@ mod test { #[tokio::test] async fn should_handle_mutliresult_set() -> super::Result<()> { - let conn = Conn::new(get_opts()).await?; - let result = conn + let mut conn = Conn::new(get_opts()).await?; + let mut result = conn .query( r"SELECT 'hello', 123 UNION ALL @@ -961,8 +969,8 @@ mod test { ", ) .await?; - let (result, rows_1) = result.collect::<(String, u8)>().await?; - let (conn, rows_2) = result.collect_and_drop().await?; + let rows_1 = result.collect::<(String, u8)>().await?; + let rows_2 = result.collect_and_drop().await?; conn.disconnect().await?; assert_eq!((String::from("hello"), 123), rows_1[0]); @@ -973,8 +981,8 @@ mod test { #[tokio::test] async fn should_map_resultset() -> super::Result<()> { - let conn = Conn::new(get_opts()).await?; - let result = conn + let mut conn = Conn::new(get_opts()).await?; + let mut result = conn .query( r" SELECT 'hello', 123 @@ -985,8 +993,8 @@ mod test { ) .await?; - let (result, rows_1) = result.map(|row| from_row::<(String, u8)>(row)).await?; - let (conn, rows_2) = result.map_and_drop(from_row).await?; + let rows_1 = result.map(|row| from_row::<(String, u8)>(row)).await?; + let rows_2 = result.map_and_drop(from_row).await?; conn.disconnect().await?; assert_eq!((String::from("hello"), 123), rows_1[0]); @@ -997,8 +1005,8 @@ mod test { #[tokio::test] async fn should_reduce_resultset() -> super::Result<()> { - let conn = Conn::new(get_opts()).await?; - let result = conn + let mut conn = Conn::new(get_opts()).await?; + let mut result = conn .query( r"SELECT 5 UNION ALL @@ -1006,13 +1014,13 @@ mod test { SELECT 7;", ) .await?; - let (result, reduced) = result + let reduced = result .reduce(0, |mut acc, row| { acc += from_row::(row); acc }) .await?; - let (conn, rows_2) = result.collect_and_drop::().await?; + let rows_2 = result.collect_and_drop::().await?; conn.disconnect().await?; assert_eq!(11, reduced); assert_eq!(7, rows_2[0]); @@ -1030,31 +1038,29 @@ mod test { UPDATE time_zone SET Time_zone_id = 1 WHERE Time_zone_id = 1; SELECT 4;"; - let c = Conn::new(get_opts()).await?; - let c = c - .drop_query("CREATE TEMPORARY TABLE time_zone (Time_zone_id INT)") + let mut c = Conn::new(get_opts()).await?; + c.drop_query("CREATE TEMPORARY TABLE time_zone (Time_zone_id INT)") .await .unwrap(); - let t = c.start_transaction(TransactionOptions::new()).await?; - let t = t.drop_query(QUERY).await?; + let mut t = c.start_transaction(TransactionOptions::new()).await?; + t.drop_query(QUERY).await?; let r = t.query(QUERY).await?; - let (t, out) = r.collect_and_drop::().await?; + let out = r.collect_and_drop::().await?; assert_eq!(vec![1], out); let r = t.query(QUERY).await?; - let t = r - .for_each_and_drop(|x| assert_eq!(from_row::(x), 1)) + r.for_each_and_drop(|x| assert_eq!(from_row::(x), 1)) .await?; let r = t.query(QUERY).await?; - let (t, out) = r.map_and_drop(|row| from_row::(row)).await?; + let out = r.map_and_drop(|row| from_row::(row)).await?; assert_eq!(vec![1], out); let r = t.query(QUERY).await?; - let (t, out) = r + let out = r .reduce_and_drop(0u8, |acc, x| acc + from_row::(x)) .await?; assert_eq!(1, out); - let t = t.query(QUERY).await?.drop_result().await?; - let c = t.commit().await?; - let (c, result) = c.first_exec::<_, _, u8>("SELECT 1", ()).await?; + t.query(QUERY).await?.drop_result().await?; + t.commit().await?; + let result = c.first_exec::<_, _, u8>("SELECT 1", ()).await?; c.disconnect().await?; assert_eq!(result, Some(1)); Ok(()) @@ -1069,8 +1075,8 @@ mod test { let acc = Arc::new(AtomicUsize::new(0)); - let conn = Conn::new(get_opts()).await?; - let result = conn + let mut conn = Conn::new(get_opts()).await?; + let mut result = conn .query( r"SELECT 2 UNION ALL @@ -1078,7 +1084,7 @@ mod test { SELECT 5;", ) .await?; - let result = result + result .for_each({ let acc = acc.clone(); move |row| { @@ -1086,7 +1092,7 @@ mod test { } }) .await?; - let conn = result + result .for_each_and_drop({ let acc = acc.clone(); move |row| { @@ -1101,23 +1107,13 @@ mod test { #[tokio::test] async fn should_prepare_statement() -> super::Result<()> { - Conn::new(get_opts()) - .await? - .prepare(r"SELECT ?") - .await? - .close() - .await? - .disconnect() - .await?; + let mut conn = Conn::new(get_opts()).await?; + conn.prepare(r"SELECT ?").await?.close().await?; + conn.disconnect().await?; - Conn::new(get_opts()) - .await? - .prepare(r"SELECT :foo") - .await? - .close() - .await? - .disconnect() - .await?; + let mut conn = Conn::new(get_opts()).await?; + conn.prepare(r"SELECT :foo").await?.close().await?; + conn.disconnect().await?; Ok(()) } @@ -1126,33 +1122,34 @@ mod test { let long_string = ::std::iter::repeat('A') .take(18 * 1024 * 1024) .collect::(); - let conn = Conn::new(get_opts()).await?; - let stmt = conn.prepare(r"SELECT ?").await?; + let mut conn = Conn::new(get_opts()).await?; + let mut stmt = conn.prepare(r"SELECT ?").await?; let result = stmt.execute((&long_string,)).await?; - let (stmt, mut mapped) = result + let mut mapped = result .map_and_drop(|row| from_row::<(String,)>(row)) .await?; assert_eq!(mapped.len(), 1); assert_eq!(mapped.pop(), Some((long_string,))); - let result = stmt.execute((42,)).await?; - let (stmt, collected) = result.collect_and_drop::<(u8,)>().await?; + let result = stmt.execute((42_u8,)).await?; + let collected = result.collect_and_drop::<(u8,)>().await?; assert_eq!(collected, vec![(42u8,)]); - let result = stmt.execute((8,)).await?; - let (stmt, reduced) = result + let result = stmt.execute((8_u8,)).await?; + let reduced = result .reduce_and_drop(2, |mut acc, row| { acc += from_row::(row); acc }) .await?; - stmt.close().await?.disconnect().await?; + stmt.close().await?; + conn.disconnect().await?; assert_eq!(reduced, 10); - let conn = Conn::new(get_opts()).await?; - let stmt = conn.prepare(r"SELECT :foo, :bar, :foo, 3").await?; + let mut conn = Conn::new(get_opts()).await?; + let mut stmt = conn.prepare(r"SELECT :foo, :bar, :foo, 3").await?; let result = stmt .execute(params! { "foo" => "quux", "bar" => "baz" }) .await?; - let (stmt, mut mapped) = result + let mut mapped = result .map_and_drop(|row| from_row::<(String, String, String, u8)>(row)) .await?; assert_eq!(mapped.len(), 1); @@ -1161,27 +1158,28 @@ mod test { Some(("quux".into(), "baz".into(), "quux".into(), 3)) ); let result = stmt.execute(params! { "foo" => 2, "bar" => 3 }).await?; - let (stmt, collected) = result.collect_and_drop::<(u8, u8, u8, u8)>().await?; + let collected = result.collect_and_drop::<(u8, u8, u8, u8)>().await?; assert_eq!(collected, vec![(2, 3, 2, 3)]); let result = stmt.execute(params! { "foo" => 2, "bar" => 3 }).await?; - let (stmt, reduced) = result + let reduced = result .reduce_and_drop(0, |acc, row| { let (a, b, c, d): (u8, u8, u8, u8) = from_row(row); acc + a + b + c + d }) .await?; - stmt.close().await?.disconnect().await?; + stmt.close().await?; + conn.disconnect().await?; assert_eq!(reduced, 10); Ok(()) } #[tokio::test] async fn should_prep_exec_statement() -> super::Result<()> { - let conn = Conn::new(get_opts()).await?; + let mut conn = Conn::new(get_opts()).await?; let result = conn .prep_exec(r"SELECT :a, :b, :a", params! { "a" => 2, "b" => 3 }) .await?; - let (conn, output) = result + let output = result .map_and_drop(|row| { let (a, b, c): (u8, u8, u8) = from_row(row); a * b * c @@ -1194,35 +1192,33 @@ mod test { #[tokio::test] async fn should_first_exec_statement() -> super::Result<()> { - let conn = Conn::new(get_opts()).await?; - let (conn, output): (_, Option<(u8,)>) = conn + let mut conn = Conn::new(get_opts()).await?; + let output = conn .first_exec( r"SELECT :a UNION ALL SELECT :b", params! { "a" => 2, "b" => 3 }, ) .await?; conn.disconnect().await?; - assert_eq!(output.unwrap(), (2u8,)); + assert_eq!(output, Some(2u8)); Ok(()) } #[tokio::test] async fn issue_107() -> super::Result<()> { - let conn = Conn::new(get_opts()).await?; - let conn = conn - .drop_query( - r"CREATE TEMPORARY TABLE mysql.issue ( + let mut conn = Conn::new(get_opts()).await?; + conn.drop_query( + r"CREATE TEMPORARY TABLE mysql.issue ( a BIGINT(20) UNSIGNED, b VARBINARY(16), c BINARY(32), d BIGINT(20) UNSIGNED, e BINARY(32) )", - ) - .await?; - let conn = conn - .drop_query( - r"INSERT INTO mysql.issue VALUES ( + ) + .await?; + conn.drop_query( + r"INSERT INTO mysql.issue VALUES ( 0, 0xC066F966B0860000, 0x7939DA98E524C5F969FC2DE8D905FD9501EBC6F20001B0A9C941E0BE6D50CF44, @@ -1235,14 +1231,14 @@ mod test { 0, '' )", - ) - .await?; + ) + .await?; let q = "SELECT b, c, d, e FROM mysql.issue"; let result = conn.query(q).await?; - let (conn, loaded_structs) = result - .map_and_drop(|row| crate::from_row::<(Vec, Vec, u64, Vec)>(dbg!(row))) + let loaded_structs = result + .map_and_drop(|row| crate::from_row::<(Vec, Vec, u64, Vec)>(row)) .await?; conn.disconnect().await?; @@ -1254,29 +1250,36 @@ mod test { #[tokio::test] async fn should_run_transactions() -> super::Result<()> { - let conn = Conn::new(get_opts()).await?; - let conn = conn - .drop_query("CREATE TEMPORARY TABLE tmp (id INT, name TEXT)") + let mut conn = Conn::new(get_opts()).await?; + conn.drop_query("CREATE TEMPORARY TABLE tmp (id INT, name TEXT)") .await?; - let transaction = conn.start_transaction(Default::default()).await?; - let conn = transaction + let mut transaction = conn.start_transaction(Default::default()).await?; + transaction .drop_query("INSERT INTO tmp VALUES (1, 'foo'), (2, 'bar')") - .await? - .commit() .await?; - let (conn, output_opt) = conn.first("SELECT COUNT(*) FROM tmp").await?; + transaction.commit().await?; + let output_opt = conn.first("SELECT COUNT(*) FROM tmp").await?; assert_eq!(output_opt, Some((2u8,))); - let transaction = conn.start_transaction(Default::default()).await?; - let transaction = transaction + let mut transaction = conn.start_transaction(Default::default()).await?; + transaction .drop_query("INSERT INTO tmp VALUES (3, 'baz'), (4, 'quux')") .await?; - let (t, output_opt) = transaction + let output_opt = transaction .first_exec("SELECT COUNT(*) FROM tmp", ()) .await?; assert_eq!(output_opt, Some((4u8,))); - let conn = t.rollback().await?; - let (conn, output_opt) = conn.first("SELECT COUNT(*) FROM tmp").await?; + transaction.rollback().await?; + let output_opt = conn.first("SELECT COUNT(*) FROM tmp").await?; + assert_eq!(output_opt, Some((2u8,))); + + let mut transaction = conn.start_transaction(Default::default()).await?; + transaction + .drop_query("INSERT INTO tmp VALUES (3, 'baz')") + .await?; + drop(transaction); // implicit rollback + let output_opt = conn.first("SELECT COUNT(*) FROM tmp").await?; assert_eq!(output_opt, Some((2u8,))); + conn.disconnect().await?; Ok(()) } @@ -1292,9 +1295,8 @@ mod test { let mut opts = OptsBuilder::from_opts(get_opts()); opts.local_infile_handler(Some(WhiteListFsLocalInfileHandler::new(&[file_name][..]))); - let conn = Conn::new(opts).await.unwrap(); - let conn = conn - .drop_query("CREATE TEMPORARY TABLE tmp (a TEXT);") + let mut conn = Conn::new(opts).await.unwrap(); + conn.drop_query("CREATE TEMPORARY TABLE tmp (a TEXT);") .await .unwrap(); @@ -1302,14 +1304,14 @@ mod test { let _ = file.write(b"AAAAAA\n"); let _ = file.write(b"BBBBBB\n"); let _ = file.write(b"CCCCCC\n"); - let conn = match conn - .drop_query(dbg!(format!( + match conn + .drop_query(format!( r#"LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;"#, file_name.display() - ))) + )) .await { - Ok(conn) => conn, + Ok(_) => (), Err(super::Error::Server(ref err)) if err.code == 1148 => { // The used command is not allowed with this MySQL version return Ok(()); @@ -1321,7 +1323,7 @@ mod test { } e @ Err(_) => e.unwrap(), }; - let (conn, result) = conn + let result = conn .prep_exec("SELECT * FROM tmp;", ()) .await .unwrap() diff --git a/src/conn/pool/futures/disconnect_pool.rs b/src/conn/pool/futures/disconnect_pool.rs index 67421242..8af5f7c6 100644 --- a/src/conn/pool/futures/disconnect_pool.rs +++ b/src/conn/pool/futures/disconnect_pool.rs @@ -19,17 +19,20 @@ use crate::{ use std::sync::{atomic, Arc}; -/// Future that disconnects this pool from server and resolves to `()`. +/// Future that disconnects this pool from a server and resolves to `()`. /// -/// Active connections taken from this pool should be disconnected manually. -/// Also all pending and new `GetConn`'s will resolve to error. +/// +/// **Note:** This Future won't resolve until all active connections, taken from it, +/// are dropped or disonnected. Also all pending and new `GetConn`'s will resolve to error. pub struct DisconnectPool { pool_inner: Arc, } -pub fn new(pool: Pool) -> DisconnectPool { - DisconnectPool { - pool_inner: pool.inner, +impl DisconnectPool { + pub(crate) fn new(pool: Pool) -> Self { + Self { + pool_inner: pool.inner, + } } } diff --git a/src/conn/pool/futures/get_conn.rs b/src/conn/pool/futures/get_conn.rs index 9b668125..3b9fd771 100644 --- a/src/conn/pool/futures/get_conn.rs +++ b/src/conn/pool/futures/get_conn.rs @@ -23,7 +23,7 @@ pub(crate) enum GetConnInner { New, Done(Option), // TODO: one day this should be an existential - Connecting(BoxFuture), + Connecting(BoxFuture<'static, Conn>), } impl GetConnInner { diff --git a/src/conn/pool/futures/mod.rs b/src/conn/pool/futures/mod.rs index 050d4f31..24b36233 100644 --- a/src/conn/pool/futures/mod.rs +++ b/src/conn/pool/futures/mod.rs @@ -8,7 +8,7 @@ pub(super) use self::get_conn::GetConnInner; pub use self::{ - disconnect_pool::{new as new_disconnect_pool, DisconnectPool}, + disconnect_pool::DisconnectPool, get_conn::{new as new_get_conn, GetConn}, }; diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index d0bba8e3..7aa6294c 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -22,10 +22,7 @@ use crate::{ conn::{pool::futures::*, Conn}, error::*, opts::{Opts, PoolOptions}, - queryable::{ - transaction::{Transaction, TransactionOptions}, - Queryable, - }, + queryable::transaction::TxStatus, }; mod recycler; @@ -97,7 +94,7 @@ pub struct Inner { #[derive(Clone)] /// Asynchronous pool of MySql connections. /// -/// Note that you will probably want to await `Pool::disconnect` before dropping the runtime, as +/// Note that you will probably want to await [`Pool::disconnect`] before dropping the runtime, as /// otherwise you may end up with a number of connections that are not cleanly terminated. pub struct Pool { opts: Opts, @@ -112,7 +109,7 @@ impl fmt::Debug for Pool { } impl Pool { - /// Creates new pool of connections. + /// Creates a new pool of connections. pub fn new>(opts: O) -> Pool { let opts = opts.into(); let pool_options = opts.get_pool_options().clone(); @@ -133,29 +130,21 @@ impl Pool { } } - /// Creates new pool of connections. + /// Creates a new pool of connections. pub fn from_url>(url: T) -> Result { let opts = Opts::from_str(url.as_ref())?; Ok(Pool::new(opts)) } - /// Returns future that resolves to `Conn`. + /// Returns a future that resolves to [`Conn`]. pub fn get_conn(&self) -> GetConn { new_get_conn(self) } - /// Shortcut for `get_conn` followed by `start_transaction`. - pub async fn start_transaction( - &self, - options: TransactionOptions, - ) -> Result> { - Queryable::start_transaction(self.get_conn().await?, options).await - } - - /// Returns future that disconnects this pool from server and resolves to `()`. + /// Returns a future that disconnects this pool from the server and resolves to `()`. /// - /// Active connections taken from this pool should be disconnected manually. - /// Also all pending and new `GetConn`'s will resolve to error. + /// **Note:** This Future won't resolve until all active connections, taken from it, + /// are dropped or disonnected. Also all pending and new `GetConn`'s will resolve to error. pub fn disconnect(self) -> DisconnectPool { let was_closed = self.inner.close.swap(true, atomic::Ordering::AcqRel); if !was_closed { @@ -166,7 +155,7 @@ impl Pool { let _ = self.drop.send(None).is_ok(); } - new_disconnect_pool(self) + DisconnectPool::new(self) } /// A way to return connection taken from a pool. @@ -178,7 +167,7 @@ impl Pool { if conn.inner.stream.is_some() && !conn.inner.disconnected && !conn.expired() - && !conn.inner.in_transaction + && conn.inner.tx_status == TxStatus::None && conn.inner.has_result.is_none() && !self.inner.close.load(atomic::Ordering::Acquire) { @@ -277,6 +266,10 @@ impl Pool { impl Drop for Conn { fn drop(&mut self) { + if std::thread::panicking() { + return; + } + if let Some(mut pool) = self.inner.pool.take() { pool.return_conn(self.take()); } else if self.inner.stream.is_some() && !self.inner.disconnected { @@ -377,19 +370,20 @@ mod test { .await? .drop_query("CREATE TABLE IF NOT EXISTS tmp(id int)") .await?; - let _ = pool + let mut conn = pool.get_conn().await?; + let mut tx = conn .start_transaction(TransactionOptions::default()) - .await? - .batch_exec("INSERT INTO tmp (id) VALUES (?)", vec![(1,), (2,)]) - .await? - .prep_exec("SELECT * FROM tmp", ()) .await?; + tx.batch_exec("INSERT INTO tmp (id) VALUES (?)", vec![(1_u8,), (2_u8,)]) + .await?; + tx.prep_exec("SELECT * FROM tmp", ()).await?; + drop(tx); + drop(conn); let row_opt = pool .get_conn() .await? .first("SELECT COUNT(*) FROM tmp") - .await? - .1; + .await?; assert_eq!(row_opt, Some((0u8,))); pool.get_conn().await?.drop_query("DROP TABLE tmp").await?; pool.disconnect().await?; @@ -397,11 +391,10 @@ mod test { } #[tokio::test] - async fn aa_should_hold_bounds2() -> super::Result<()> { - use std::cmp::min; - + async fn should_check_inactive_connection_ttl() -> super::Result<()> { const POOL_MIN: usize = 5; const POOL_MAX: usize = 10; + const INACTIVE_CONNECTION_TTL: Duration = Duration::from_millis(500); const TTL_CHECK_INTERVAL: Duration = Duration::from_secs(1); @@ -417,6 +410,44 @@ mod test { let pool_clone = pool.clone(); let conns = (0..POOL_MAX).map(|_| pool.get_conn()).collect::>(); + let conns = try_join_all(conns).await?; + + assert_eq!(ex_field!(pool_clone, exist), POOL_MAX); + drop(conns); + + // wait for a bit to let the connections be reclaimed + tokio::time::delay_for(std::time::Duration::from_millis(100)).await; + + // check that connections are still in the pool because of inactive_connection_ttl + assert_eq!(ex_field!(pool_clone, available).len(), POOL_MAX); + + // then, wait for ttl_check_interval + tokio::time::delay_for(TTL_CHECK_INTERVAL).await; + + // check that we have the expected number of connections + assert_eq!(ex_field!(pool_clone, available).len(), POOL_MIN); + + Ok(()) + } + + #[tokio::test] + async fn aa_should_hold_bounds2() -> super::Result<()> { + use std::cmp::min; + + const POOL_MIN: usize = 5; + const POOL_MAX: usize = 10; + + let constraints = PoolConstraints::new(POOL_MIN, POOL_MAX).unwrap(); + let pool_options = PoolOptions::with_constraints(constraints); + + // Clean + let mut opts = get_opts(); + opts.pool_options(pool_options); + + let pool = Pool::new(opts); + let pool_clone = pool.clone(); + let conns = (0..POOL_MAX).map(|_| pool.get_conn()).collect::>(); + let mut conns = try_join_all(conns).await?; // we want to continuously drop connections @@ -444,15 +475,6 @@ mod test { let idle = min(dropped, POOL_MIN); let expected = conns.len() + idle; - if dropped > POOL_MIN { - // check that connection is still in the pool because of inactive_connection_ttl - let have = ex_field!(pool_clone, exist); - assert_eq!(have, expected + 1); - - // then, wait for ttl_check_interval - tokio::time::delay_for(TTL_CHECK_INTERVAL + Duration::from_millis(50)).await; - } - // check that we have the expected number of connections let have = ex_field!(pool_clone, exist); assert_eq!(have, expected); @@ -507,61 +529,34 @@ mod test { async fn zz_should_check_wait_timeout_on_get_conn() -> super::Result<()> { let pool = Pool::new(get_opts()); - let conn = pool.get_conn().await?; - let (conn, wait_timeout_orig) = conn.first::<_, usize>("SELECT @@wait_timeout").await?; - conn.drop_query("SET GLOBAL wait_timeout = 3") - .await? - .disconnect() - .await?; + let mut conn = pool.get_conn().await?; + let wait_timeout_orig = conn.first::<_, usize>("SELECT @@wait_timeout").await?; + conn.drop_query("SET GLOBAL wait_timeout = 3").await?; + conn.disconnect().await?; - let conn = pool.get_conn().await?; - let (conn, wait_timeout) = conn.first::<_, usize>("SELECT @@wait_timeout").await?; - let (_, id1) = conn.first::<_, usize>("SELECT CONNECTION_ID()").await?; + let mut conn = pool.get_conn().await?; + let wait_timeout = conn.first::<_, usize>("SELECT @@wait_timeout").await?; + let id1 = conn.first::<_, usize>("SELECT CONNECTION_ID()").await?; + drop(conn); assert_eq!(wait_timeout, Some(3)); assert_eq!(ex_field!(pool, exist), 1); tokio::time::delay_for(std::time::Duration::from_secs(6)).await; - let conn = pool.get_conn().await?; - let (conn, id2) = conn.first::<_, usize>("SELECT CONNECTION_ID()").await?; + let mut conn = pool.get_conn().await?; + let id2 = conn.first::<_, usize>("SELECT CONNECTION_ID()").await?; assert_eq!(ex_field!(pool, exist), 1); assert_ne!(id1, id2); conn.drop_exec("SET GLOBAL wait_timeout = ?", (wait_timeout_orig,)) .await?; + drop(conn); - pool.disconnect().await - } - - /* - #[test] - fn should_hold_bounds_on_get_conn_drop() { - let pool = Pool::new(format!("{}?pool_min=1&pool_max=2", get_opts())); - let mut runtime = tokio::runtime::Runtime::new().unwrap(); + pool.disconnect().await?; - // This test is a bit more intricate: we need to poll the connection future once to get the - // pool to set it up, then drop it and make sure that the `exist` count is updated. - // - // We wrap all of it in a lazy future to get us into the tokio context that deals with - // setting up tasks. There might be a better way to do this but I don't remember right - // now. Besides, std::future is just around the corner making this obsolete. - // - // It depends on implementation details of GetConn, but that should be fine. - runtime - .block_on(future::lazy(move || { - let mut conn = pool.get_conn(); - assert_eq!(pool.inner.exist.load(atomic::Ordering::SeqCst), 0); - let result = conn.poll().expect("successful first poll"); - assert!(result.is_not_ready(), "not ready after first poll"); - assert_eq!(pool.inner.exist.load(atomic::Ordering::SeqCst), 1); - drop(conn); - assert_eq!(pool.inner.exist.load(atomic::Ordering::SeqCst), 0); - Ok::<(), ()>(()) - })) - .unwrap(); + Ok(()) } - */ #[tokio::test] async fn droptest() -> super::Result<()> { @@ -607,7 +602,7 @@ mod test { let (tx, rx) = tokio::sync::oneshot::channel(); rt.block_on(async move { let pool = Pool::new(get_opts()); - let c = pool.get_conn().await.unwrap(); + let mut c = pool.get_conn().await.unwrap(); tokio::spawn(async move { let _ = rx.await; let _ = c.drop_query("SELECT 1").await; @@ -627,7 +622,7 @@ mod test { let (tx, rx) = tokio::sync::oneshot::channel(); let jh = rt.spawn(async move { let pool = Pool::new(get_opts()); - let c = pool.get_conn().await.unwrap(); + let mut c = pool.get_conn().await.unwrap(); tokio::spawn(async move { let _ = rx.await; let _ = c.drop_query("SELECT 1").await; @@ -639,26 +634,6 @@ mod test { } } - /* - #[test] - #[ignore] - fn should_not_panic_if_dropped_without_tokio_runtime() { - // NOTE: this test does not work anymore, since the runtime won't be idle until either - // - // - all Pools and Conns are dropped; OR - // - Pool::disconnect is called; OR - // - Runtime::shutdown_now is called - // - // none of these are true in this test, which is why it's been ignored - let pool = Pool::new(get_opts()); - run(collect( - (0..10).map(|_| pool.get_conn()).collect::>(), - )) - .unwrap(); - // pool will drop here - } - */ - #[cfg(feature = "nightly")] mod bench { use futures_util::{future::FutureExt, try_future::TryFutureExt}; diff --git a/src/conn/pool/recycler.rs b/src/conn/pool/recycler.rs index 439e628a..98887995 100644 --- a/src/conn/pool/recycler.rs +++ b/src/conn/pool/recycler.rs @@ -18,14 +18,14 @@ use std::{ }; use super::{IdlingConn, Inner}; -use crate::{BoxFuture, Conn, PoolOptions}; +use crate::{queryable::transaction::TxStatus, BoxFuture, Conn, PoolOptions}; use tokio::sync::mpsc::UnboundedReceiver; -pub struct Recycler { +pub(crate) struct Recycler { inner: Arc, - discard: FuturesUnordered>, + discard: FuturesUnordered>, discarded: usize, - cleaning: FuturesUnordered>, + cleaning: FuturesUnordered>, // Option so that we have a way to send a "I didn't make a Conn after all" signal dropped: mpsc::UnboundedReceiver>, @@ -63,7 +63,9 @@ impl Future for Recycler { if $conn.inner.stream.is_none() || $conn.inner.disconnected { // drop unestablished connection $self.discard.push(Box::pin(::futures_util::future::ok(()))); - } else if $conn.inner.in_transaction || $conn.inner.has_result.is_some() { + } else if $conn.inner.tx_status != TxStatus::None + || $conn.inner.has_result.is_some() + { $self.cleaning.push(Box::pin($conn.cleanup())); } else if $conn.expired() || close { $self.discard.push(Box::pin($conn.close())); @@ -73,7 +75,7 @@ impl Future for Recycler { drop(exchange); $self.discard.push(Box::pin($conn.close())); } else { - exchange.available.push_back($conn.into()); + exchange.available.push_back(dbg!($conn.into())); if let Some(w) = exchange.waiting.pop_front() { w.wake(); } diff --git a/src/conn/pool/ttl_check_inerval.rs b/src/conn/pool/ttl_check_inerval.rs index 1eebdb32..71e62e3e 100644 --- a/src/conn/pool/ttl_check_inerval.rs +++ b/src/conn/pool/ttl_check_inerval.rs @@ -19,7 +19,7 @@ use std::{ }; use super::Inner; -use crate::{prelude::Queryable, PoolOptions}; +use crate::PoolOptions; use futures_core::task::{Context, Poll}; use std::pin::Pin; @@ -29,7 +29,7 @@ use std::pin::Pin; /// * overflows min bound of the pool; /// * idles longer then `inactive_connection_ttl`. #[pin_project] -pub struct TtlCheckInterval { +pub(crate) struct TtlCheckInterval { inner: Arc, #[pin] interval: StreamFuture, diff --git a/src/conn/stmt_cache.rs b/src/conn/stmt_cache.rs index b608f517..3a1023b9 100644 --- a/src/conn/stmt_cache.rs +++ b/src/conn/stmt_cache.rs @@ -12,7 +12,7 @@ use twox_hash::XxHash; use std::collections::vec_deque::Iter; use std::{ borrow::Borrow, - collections::{hash_map::IntoIter, HashMap, VecDeque}, + collections::{HashMap, VecDeque}, hash::{BuildHasherDefault, Hash}, }; @@ -26,7 +26,7 @@ pub struct StmtCache { } impl StmtCache { - pub fn new(cap: usize) -> StmtCache { + pub(crate) fn new(cap: usize) -> StmtCache { StmtCache { cap, map: Default::default(), @@ -34,7 +34,7 @@ impl StmtCache { } } - pub fn get(&mut self, key: &T) -> Option<&InnerStmt> + pub(crate) fn get(&mut self, key: &T) -> Option<&InnerStmt> where String: Borrow, String: PartialEq, @@ -53,7 +53,7 @@ impl StmtCache { } } - pub fn put(&mut self, key: String, value: InnerStmt) -> Option { + pub(crate) fn put(&mut self, key: String, value: InnerStmt) -> Option { self.map.insert(key.clone(), value); self.order.push_back(key); if self.order.len() > self.cap { @@ -65,21 +65,13 @@ impl StmtCache { } } - pub fn clear(&mut self) { + pub(crate) fn clear(&mut self) { self.map.clear(); self.order.clear(); } #[cfg(test)] - pub fn iter<'a>(&'a self) -> Iter<'a, String> { + pub(crate) fn iter<'a>(&'a self) -> Iter<'a, String> { self.order.iter() } - - pub fn into_iter(self) -> IntoIter { - self.map.into_iter() - } - - pub fn get_cap(&self) -> usize { - self.cap - } } diff --git a/src/connection_like/mod.rs b/src/connection_like/mod.rs index 0cfff6d3..6923ef83 100644 --- a/src/connection_like/mod.rs +++ b/src/connection_like/mod.rs @@ -6,7 +6,6 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. -use crate::conn::PendingResult; use futures_util::future::ok; use mysql_common::{ io::ReadMysqlExt, @@ -17,40 +16,21 @@ use tokio::prelude::*; use std::{borrow::Cow, sync::Arc}; use crate::{ - conn::{named_params::parse_named_params, stmt_cache::StmtCache}, - connection_like::{read_packet::ReadPacket, streamless::Streamless, write_packet::WritePacket}, + conn::{named_params::parse_named_params, stmt_cache::StmtCache, PendingResult}, + connection_like::{ + read_packet::{ReadPacket, ReadPackets}, + write_packet2::WritePacket2, + }, consts::{CapabilityFlags, Command, StatusFlags}, error::*, io, local_infile_handler::LocalInfileHandler, - queryable::{ - query_result::{self, QueryResult}, - stmt::InnerStmt, - Protocol, - }, + queryable::{query_result::QueryResult, stmt::InnerStmt, transaction::TxStatus, Protocol}, BoxFuture, Opts, }; pub mod read_packet; -pub mod streamless { - use super::ConnectionLike; - use crate::io::Stream; - - #[derive(Debug)] - pub struct Streamless(T); - - impl Streamless { - pub fn new(x: T) -> Streamless { - Streamless(x) - } - - pub fn return_stream(mut self, stream: Stream) -> T { - self.0.return_stream(stream); - self.0 - } - } -} -pub mod write_packet; +pub mod write_packet2; #[derive(Debug, Clone, Copy)] pub enum StmtCacheResult { @@ -58,134 +38,13 @@ pub enum StmtCacheResult { NotCached(u32), } -pub trait ConnectionLikeWrapper { - type ConnLike: ConnectionLike; - - fn take_stream(self) -> (Streamless, io::Stream) - where - Self: Sized; - fn return_stream(&mut self, stream: io::Stream) -> (); - fn conn_like_ref(&self) -> &Self::ConnLike; - - fn conn_like_mut(&mut self) -> &mut Self::ConnLike; -} - -impl ConnectionLike for T -where - T: ConnectionLikeWrapper, - T: Send, - U: ConnectionLike + 'static, -{ - fn take_stream(self) -> (Streamless, io::Stream) - where - Self: Sized, - { - ::take_stream(self) - } - - fn return_stream(&mut self, stream: io::Stream) { - ::return_stream(self, stream) - } - - fn stmt_cache_ref(&self) -> &StmtCache { - self.conn_like_ref().stmt_cache_ref() - } - - fn stmt_cache_mut(&mut self) -> &mut StmtCache { - self.conn_like_mut().stmt_cache_mut() - } - - fn get_affected_rows(&self) -> u64 { - self.conn_like_ref().get_affected_rows() - } - - fn get_capabilities(&self) -> CapabilityFlags { - self.conn_like_ref().get_capabilities() - } - - fn get_in_transaction(&self) -> bool { - self.conn_like_ref().get_in_transaction() - } - - fn get_last_insert_id(&self) -> Option { - self.conn_like_ref().get_last_insert_id() - } - - fn get_info(&self) -> Cow<'_, str> { - self.conn_like_ref().get_info() - } - - fn get_warnings(&self) -> u16 { - self.conn_like_ref().get_warnings() - } - - fn get_local_infile_handler(&self) -> Option> { - self.conn_like_ref().get_local_infile_handler() - } - - fn get_max_allowed_packet(&self) -> usize { - self.conn_like_ref().get_max_allowed_packet() - } - - fn get_opts(&self) -> &Opts { - self.conn_like_ref().get_opts() - } - - fn get_pending_result(&self) -> Option<&PendingResult> { - self.conn_like_ref().get_pending_result() - } - - fn get_server_version(&self) -> (u16, u16, u16) { - self.conn_like_ref().get_server_version() - } - - fn get_status(&self) -> StatusFlags { - self.conn_like_ref().get_status() - } - - fn set_last_ok_packet(&mut self, ok_packet: Option>) { - self.conn_like_mut().set_last_ok_packet(ok_packet); - } - - fn set_in_transaction(&mut self, in_transaction: bool) { - self.conn_like_mut().set_in_transaction(in_transaction); - } - - fn set_pending_result(&mut self, meta: Option) { - self.conn_like_mut().set_pending_result(meta); - } - - fn set_status(&mut self, status: StatusFlags) { - self.conn_like_mut().set_status(status); - } - - fn reset_seq_id(&mut self) { - self.conn_like_mut().reset_seq_id(); - } - - fn sync_seq_id(&mut self) { - self.conn_like_mut().sync_seq_id(); - } - - fn touch(&mut self) { - self.conn_like_mut().touch(); - } - - fn on_disconnect(&mut self) { - self.conn_like_mut().on_disconnect(); - } -} - -pub trait ConnectionLike: Send { - fn take_stream(self) -> (Streamless, io::Stream) - where - Self: Sized; - fn return_stream(&mut self, stream: io::Stream) -> (); +pub trait ConnectionLike: Send + Sized { + fn stream_mut(&mut self) -> &mut io::Stream; fn stmt_cache_ref(&self) -> &StmtCache; fn stmt_cache_mut(&mut self) -> &mut StmtCache; fn get_affected_rows(&self) -> u64; fn get_capabilities(&self) -> CapabilityFlags; - fn get_in_transaction(&self) -> bool; + fn get_tx_status(&self) -> TxStatus; fn get_last_insert_id(&self) -> Option; fn get_info(&self) -> Cow<'_, str>; fn get_warnings(&self) -> u16; @@ -196,7 +55,7 @@ pub trait ConnectionLike: Send { fn get_server_version(&self) -> (u16, u16, u16); fn get_status(&self) -> StatusFlags; fn set_last_ok_packet(&mut self, ok_packet: Option>); - fn set_in_transaction(&mut self, in_transaction: bool); + fn set_tx_status(&mut self, tx_statux: TxStatus); fn set_pending_result(&mut self, meta: Option); fn set_status(&mut self, status: StatusFlags); fn reset_seq_id(&mut self); @@ -204,21 +63,26 @@ pub trait ConnectionLike: Send { fn touch(&mut self) -> (); fn on_disconnect(&mut self); - fn cache_stmt(mut self, query: String, stmt: &InnerStmt) -> BoxFuture<(Self, StmtCacheResult)> + fn cache_stmt<'a>( + &'a mut self, + query: String, + stmt: &InnerStmt, + ) -> BoxFuture<'a, StmtCacheResult> where - Self: Sized + 'static, + Self: Sized, { if self.get_opts().get_stmt_cache_size() > 0 { if let Some(old_stmt) = self.stmt_cache_mut().put(query, stmt.clone()) { - let f = self.close_stmt(old_stmt.statement_id); - Box::pin(async move { Ok((f.await?, StmtCacheResult::Cached)) }) + Box::pin(async move { + self.close_stmt(old_stmt.statement_id).await?; + Ok(StmtCacheResult::Cached) + }) } else { - Box::pin(futures_util::future::ok((self, StmtCacheResult::Cached))) + Box::pin(futures_util::future::ok(StmtCacheResult::Cached)) } } else { - Box::pin(futures_util::future::ok(( - self, - StmtCacheResult::NotCached(stmt.statement_id), + Box::pin(futures_util::future::ok(StmtCacheResult::NotCached( + stmt.statement_id, ))) } } @@ -227,55 +91,33 @@ pub trait ConnectionLike: Send { self.stmt_cache_mut().get(query) } - /// Returns future that reads packet from a server end resolves to `(Self, Packet)`. - fn read_packet(self) -> ReadPacket - where - Self: Sized + 'static, - { + fn read_packet<'a>(&'a mut self) -> ReadPacket<'a, Self> { ReadPacket::new(self) } - /// Returns future that reads packets from a server and resolves to `(Self, Vec)`. - fn read_packets(self, n: usize) -> BoxFuture<(Self, Vec>)> - where - Self: Sized + 'static, - { - if n == 0 { - return Box::pin(ok((self, Vec::new()))); - } - Box::pin(async move { - let mut acc = Vec::new(); - let mut conn_like = self; - for _ in 0..n { - let (cl, packet) = conn_like.read_packet().await?; - conn_like = cl; - acc.push(packet); - } - Ok((conn_like, acc)) - }) + /// Returns future that reads packets from a server. + fn read_packets<'a>(&'a mut self, n: usize) -> ReadPackets<'a, Self> { + ReadPackets::new(self, n) } - fn prepare_stmt(mut self, query: Q) -> BoxFuture<(Self, InnerStmt, StmtCacheResult)> + fn prepare_stmt<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, (InnerStmt, StmtCacheResult)> where Q: AsRef, - Self: Sized + 'static, + Self: Sized, { match parse_named_params(query.as_ref()) { Ok((named_params, query)) => { let query = query.into_owned(); if let Some(mut inner_stmt) = self.get_cached_stmt(&query).map(Clone::clone) { inner_stmt.named_params = named_params.clone(); - Box::pin(ok((self, inner_stmt, StmtCacheResult::Cached))) + Box::pin(ok((inner_stmt, StmtCacheResult::Cached))) } else { Box::pin(async move { - let (this, packet) = self - .write_command_data(Command::COM_STMT_PREPARE, &*query) - .await? - .read_packet() + self.write_command_data(Command::COM_STMT_PREPARE, &*query) .await?; + let packet = self.read_packet().await?; let mut inner_stmt = InnerStmt::new(&*packet, named_params)?; - let (mut this, packets) = - this.read_packets(inner_stmt.num_params as usize).await?; + let packets = self.read_packets(inner_stmt.num_params as usize).await?; if !packets.is_empty() { let params = packets .into_iter() @@ -285,16 +127,15 @@ pub trait ConnectionLike: Send { } if inner_stmt.num_params > 0 { - if !this + if !self .get_capabilities() .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) { - this = this.read_packet().await?.0; + self.read_packet().await?; } } - let (mut this, packets) = - this.read_packets(inner_stmt.num_columns as usize).await?; + let packets = self.read_packets(inner_stmt.num_columns as usize).await?; if !packets.is_empty() { let columns = packets .into_iter() @@ -304,16 +145,16 @@ pub trait ConnectionLike: Send { } if inner_stmt.num_columns > 0 { - if !this + if !self .get_capabilities() .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) { - this = this.read_packet().await?.0; + self.read_packet().await?; } } - let (this, stmt_cache_result) = this.cache_stmt(query, &inner_stmt).await?; - Ok((this, inner_stmt, stmt_cache_result)) + let stmt_cache_result = self.cache_stmt(query, &inner_stmt).await?; + Ok((inner_stmt, stmt_cache_result)) }) } } @@ -321,53 +162,46 @@ pub trait ConnectionLike: Send { } } - fn close_stmt(self, statement_id: u32) -> WritePacket - where - Self: Sized + 'static, - { + fn close_stmt<'a>(&'a mut self, statement_id: u32) -> WritePacket2<'a, Self> { self.write_command_raw(ComStmtClose::new(statement_id).into()) } - /// Returns future that reads result set from a server and resolves to `QueryResult`. - fn read_result_set

(self, cached: Option) -> BoxFuture> + /// Returns future that reads result set from a server. + fn read_result_set<'a, P>( + &'a mut self, + cached: Option, + ) -> BoxFuture<'a, QueryResult<'a, Self, P>> where - Self: Sized + 'static, + Self: Sized, P: Protocol, - P: Send + 'static, { Box::pin(async move { - let (this, packet) = self.read_packet().await?; + let packet = self.read_packet().await?; match packet.get(0) { - Some(0x00) => Ok(query_result::new(this, None, cached)), - Some(0xFB) => handle_local_infile(this, &*packet, cached).await, - _ => handle_result_set(this, &*packet, cached).await, + Some(0x00) => Ok(QueryResult::new(self, None, cached)), + Some(0xFB) => handle_local_infile(self, &*packet, cached).await, + _ => handle_result_set(self, &*packet, cached).await, } }) } - /// Returns future that writes packet to a server end resolves to `Self`. - fn write_packet(self, data: T) -> WritePacket + fn write_packet(&mut self, data: T) -> WritePacket2<'_, Self> where - T: Into>, // TODO: Switch to `AsRef + 'static`? - Self: Sized + 'static, + T: Into>, { - WritePacket::new(self, data) + WritePacket2::new(self, data.into()) } - /// Returns future that sends full command body to a server and resolves to `Self`. - fn write_command_raw(mut self, body: Vec) -> WritePacket - where - Self: Sized + 'static, - { + /// Returns future that sends full command body to a server. + fn write_command_raw<'a>(&'a mut self, body: Vec) -> WritePacket2<'a, Self> { assert!(body.len() > 0); self.reset_seq_id(); self.write_packet(body) } - /// Returns future that writes command to a server and resolves to `Self`. - fn write_command_data(self, cmd: Command, cmd_data: T) -> WritePacket + /// Returns future that writes command to a server. + fn write_command_data(&mut self, cmd: Command, cmd_data: T) -> WritePacket2<'_, Self> where - Self: Sized + 'static, T: AsRef<[u8]>, { let cmd_data = cmd_data.as_ref(); @@ -379,15 +213,14 @@ pub trait ConnectionLike: Send { } /// Will handle local infile packet. -async fn handle_local_infile( - mut this: T, +async fn handle_local_infile<'a, T: ?Sized, P>( + this: &'a mut T, packet: &[u8], cached: Option, -) -> Result> +) -> Result> where - P: Protocol + 'static, - T: ConnectionLike, - T: Send + Sized + 'static, + P: Protocol, + T: ConnectionLike + Sized, { let local_infile = parse_local_infile_packet(&*packet)?; let (local_infile, handler) = match this.get_local_infile_handler() { @@ -399,31 +232,29 @@ where let mut buf = [0; 4096]; loop { let read = reader.read(&mut buf[..]).await?; - this = this.write_packet(&buf[..read]).await?; + this.write_packet(&buf[..read]).await?; if read == 0 { break; } } - let (this, _) = this.read_packet().await?; - Ok(query_result::new(this, None, cached)) + this.read_packet().await?; + Ok(QueryResult::new(this, None, cached)) } /// Will handle result set packet. -async fn handle_result_set( - this: T, +async fn handle_result_set<'a, T: Sized, P>( + this: &'a mut T, mut packet: &[u8], cached: Option, -) -> Result> +) -> Result> where P: Protocol, - P: Send + 'static, T: ConnectionLike, - T: Send + Sized + 'static, { let column_count = packet.read_lenenc_int()?; - let (mut this, packets) = this.read_packets(column_count as usize).await?; + let packets = this.read_packets(column_count as usize).await?; let columns = packets .into_iter() .map(|packet| column_from_payload(packet).map_err(Error::from)) @@ -433,7 +264,7 @@ where .get_capabilities() .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) { - this = this.read_packet().await?.0; + this.read_packet().await?; } if column_count > 0 { @@ -444,9 +275,9 @@ where } None => this.set_pending_result(Some(PendingResult::Text(columns.clone()))), } - Ok(query_result::new(this, Some(columns), cached)) + Ok(QueryResult::new(this, Some(columns), cached)) } else { this.set_pending_result(Some(PendingResult::Empty)); - Ok(query_result::new(this, None, cached)) + Ok(QueryResult::new(this, None, cached)) } } diff --git a/src/connection_like/read_packet.rs b/src/connection_like/read_packet.rs index 58991eba..cdafbcbe 100644 --- a/src/connection_like/read_packet.rs +++ b/src/connection_like/read_packet.rs @@ -6,71 +6,100 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. -use futures_core::ready; -use futures_util::stream::{StreamExt, StreamFuture}; +use futures_core::{ready, stream::Stream}; use mysql_common::packets::{parse_err_packet, parse_ok_packet, OkPacketKind}; -use pin_project::pin_project; + use std::{ future::Future, + mem, pin::Pin, task::{Context, Poll}, }; -use crate::{ - connection_like::{streamless::Streamless, ConnectionLike}, - consts::StatusFlags, - error::*, - io, -}; +use crate::{connection_like::ConnectionLike, consts::StatusFlags, error::*}; -#[pin_project] -pub struct ReadPacket { - conn_like: Option>, - #[pin] - fut: StreamFuture, +/// Reads some number of packets. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct ReadPackets<'a, T: ?Sized> { + conn_like: &'a mut T, + n: usize, + packets: Vec>, } -impl ReadPacket { - pub fn new(conn_like: T) -> Self { - let (incomplete_conn, stream) = conn_like.take_stream(); - ReadPacket { - conn_like: Some(incomplete_conn), - fut: stream.into_future(), +impl<'a, T: ?Sized> ReadPackets<'a, T> { + pub(crate) fn new(conn_like: &'a mut T, n: usize) -> Self { + Self { + conn_like, + n, + packets: Vec::with_capacity(n), } } } -impl Future for ReadPacket { - type Output = Result<(T, Vec)>; +impl<'a, T: ConnectionLike> Future for ReadPackets<'a, T> { + type Output = Result>>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - let (packet_opt, stream) = ready!(this.fut.poll(cx)); - let packet_opt = packet_opt.transpose()?; - let mut conn_like = this.conn_like.take().unwrap().return_stream(stream); - match packet_opt { - Some(packet) => { - let kind = if conn_like.get_pending_result().is_some() { - OkPacketKind::ResultSetTerminator - } else { - OkPacketKind::Other - }; - if let Ok(ok_packet) = parse_ok_packet(&*packet, conn_like.get_capabilities(), kind) - { - conn_like.set_status(ok_packet.status_flags()); - conn_like.set_last_ok_packet(Some(ok_packet.into_owned())); - } else if let Ok(err_packet) = - parse_err_packet(&*packet, conn_like.get_capabilities()) - { - conn_like.set_status(StatusFlags::empty()); - conn_like.set_last_ok_packet(None); - return Err(err_packet.into()).into(); - } + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + if self.n > 0 { + let packet_opt = + ready!(Pin::new(self.conn_like.stream_mut()).poll_next(cx)).transpose()?; + match packet_opt { + Some(packet) => { + let kind = if self.conn_like.get_pending_result().is_some() { + OkPacketKind::ResultSetTerminator + } else { + OkPacketKind::Other + }; + + if let Ok(ok_packet) = + parse_ok_packet(&*packet, self.conn_like.get_capabilities(), kind) + { + self.conn_like.set_status(ok_packet.status_flags()); + self.conn_like + .set_last_ok_packet(Some(ok_packet.into_owned())); + } else if let Ok(err_packet) = + parse_err_packet(&*packet, self.conn_like.get_capabilities()) + { + self.conn_like.set_status(StatusFlags::empty()); + self.conn_like.set_last_ok_packet(None); + return Err(err_packet.into()).into(); + } - conn_like.touch(); - Poll::Ready(Ok((conn_like, packet))) + self.conn_like.touch(); + self.packets.push(packet); + self.n -= 1; + continue; + } + None => { + return Poll::Ready(Err(DriverError::ConnectionClosed.into())); + } + } + } else { + return Poll::Ready(Ok(mem::replace(&mut self.packets, Vec::new()))); } - None => Poll::Ready(Err(DriverError::ConnectionClosed.into())), } } } + +pub struct ReadPacket<'a, T: ?Sized> { + inner: ReadPackets<'a, T>, +} + +impl<'a, T: ?Sized> ReadPacket<'a, T> { + pub(crate) fn new(conn_like: &'a mut T) -> Self { + Self { + inner: ReadPackets::new(conn_like, 1), + } + } +} + +impl<'a, T: ConnectionLike> Future for ReadPacket<'a, T> { + type Output = Result>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut packets = ready!(Pin::new(&mut self.inner).poll(cx))?; + Poll::Ready(Ok(packets.pop().unwrap())) + } +} diff --git a/src/connection_like/write_packet.rs b/src/connection_like/write_packet.rs deleted file mode 100644 index 67e5befc..00000000 --- a/src/connection_like/write_packet.rs +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) 2017 Anatoly Ikorsky -// -// Licensed under the Apache License, Version 2.0 -// or the MIT -// license , at your -// option. All files in the project carrying such notice may not be copied, -// modified, or distributed except according to those terms. - -use futures_core::ready; -use pin_project::pin_project; -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use crate::{ - connection_like::{streamless::Streamless, ConnectionLike}, - error::*, - io, -}; - -#[pin_project] -pub struct WritePacket { - conn_like: Option>, - #[pin] - fut: io::futures::WritePacket, -} - -impl WritePacket { - pub fn new>>(conn_like: T, data: U) -> WritePacket { - let (incomplete_conn, stream) = conn_like.take_stream(); - WritePacket { - conn_like: Some(incomplete_conn), - fut: stream.write_packet(data.into()), - } - } -} - -impl Future for WritePacket { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - let stream = ready!(this.fut.poll(cx))?; - let mut conn_like = this.conn_like.take().unwrap().return_stream(stream); - conn_like.touch(); - Poll::Ready(Ok(conn_like)) - } -} diff --git a/src/connection_like/write_packet2.rs b/src/connection_like/write_packet2.rs new file mode 100644 index 00000000..9e6d8e40 --- /dev/null +++ b/src/connection_like/write_packet2.rs @@ -0,0 +1,76 @@ +// Copyright (c) 2017 Anatoly Ikorsky +// +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , at your +// option. All files in the project carrying such notice may not be copied, +// modified, or distributed except according to those terms. + +use futures_core::ready; +use futures_sink::Sink; + +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use crate::{connection_like::ConnectionLike, error::*}; + +pub struct WritePacket2<'a, T: ?Sized> { + conn_like: &'a mut T, + data: Option>, +} + +impl<'a, T: ?Sized> WritePacket2<'a, T> { + pub(crate) fn new(conn_like: &'a mut T, data: Vec) -> WritePacket2<'a, T> { + Self { + conn_like, + data: Some(data), + } + } +} + +impl<'a, T> Future for WritePacket2<'a, T> +where + T: ConnectionLike, +{ + type Output = Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.data.is_some() { + let codec = Pin::new( + self.conn_like + .stream_mut() + .codec + .as_mut() + .expect("must be here"), + ); + ready!(codec.poll_ready(cx))?; + } + + if let Some(data) = self.data.take() { + let codec = Pin::new( + self.conn_like + .stream_mut() + .codec + .as_mut() + .expect("must be here"), + ); + // to get here, stream must be ready + codec.start_send(data)?; + } + + let codec = Pin::new( + self.conn_like + .stream_mut() + .codec + .as_mut() + .expect("must be here"), + ); + + ready!(codec.poll_flush(cx)).map_err(Error::from)?; + + Poll::Ready(Ok(())) + } +} diff --git a/src/io/async_tls.rs b/src/io/async_tls.rs deleted file mode 100644 index def5f755..00000000 --- a/src/io/async_tls.rs +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) 2016 Anatoly Ikorsky -// -// Licensed under the Apache License, Version 2.0 -// or the MIT -// license , at your -// option. All files in the project carrying such notice may not be copied, -// modified, or distributed except according to those terms. - -//! This module implements async TLS streams and source -//! of this module is mostly copyed from tokio_tls crate. - -use bytes::buf::BufMut; -use native_tls::{Error, TlsConnector}; -use pin_project::pin_project; -use std::{ - pin::Pin, - task::{Context, Poll}, -}; -use tokio::{io::Error as IoError, prelude::*}; -use tokio_tls::{self}; - -use std::mem::MaybeUninit; - -/// A wrapper around an underlying raw stream which implements the TLS or SSL -/// protocol. -/// -/// A `TlsStream` represents a handshake that has been completed successfully -/// and both the server and the client are ready for receiving and sending -/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written -/// to a `TlsStream` are encrypted when passing through to `S`. -#[pin_project] -#[derive(Debug)] -pub struct TlsStream { - #[pin] - inner: tokio_tls::TlsStream, -} - -impl TlsStream { - /// Get access to the internal `tokio_tls::TlsStream` stream which also - /// transitively allows access to `S`. - pub fn get_ref(&self) -> &tokio_tls::TlsStream { - &self.inner - } -} - -impl AsyncRead for TlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut [u8], - ) -> Poll> { - self.project().inner.poll_read(cx, buf) - } - - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } - - fn poll_read_buf( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut B, - ) -> Poll> - where - B: BufMut, - { - self.project().inner.poll_read_buf(cx, buf) - } -} - -impl AsyncWrite for TlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &[u8], - ) -> Poll> { - self.project().inner.poll_write(cx, buf) - } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.project().inner.poll_flush(cx) - } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.project().inner.poll_shutdown(cx) - } -} - -/// Connects the provided stream with this connector, assuming the provided -/// domain. -/// -/// This function will internally call `TlsConnector::connect` to connect -/// the stream and returns a future representing the resolution of the -/// connection operation. The returned future will resolve to either -/// `TlsStream` or `Error` depending if it's successful or not. -/// -/// This is typically used for clients who have already established, for -/// example, a TCP connection to a remote server. That stream is then -/// provided here to perform the client half of a connection to a -/// TLS-powered server. -pub async fn connect_async( - connector: &TlsConnector, - domain: &str, - stream: S, -) -> Result, Error> -where - S: AsyncRead + AsyncWrite + Unpin, -{ - let connector = tokio_tls::TlsConnector::from(connector.clone()); - let connect = connector.connect(domain, stream); - Ok(TlsStream { - inner: connect.await?, - }) -} diff --git a/src/io/futures/connecting_tcp_stream.rs b/src/io/futures/connecting_tcp_stream.rs deleted file mode 100644 index 49406efc..00000000 --- a/src/io/futures/connecting_tcp_stream.rs +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2016 Anatoly Ikorsky -// -// Licensed under the Apache License, Version 2.0 -// or the MIT -// license , at your -// option. All files in the project carrying such notice may not be copied, -// modified, or distributed except according to those terms. - -use futures_util::stream::{FuturesUnordered, StreamExt}; -use tokio::net::TcpStream; -use tokio_util::codec::Framed; - -use std::{io, net::ToSocketAddrs}; - -use crate::{ - error::*, - io::{PacketCodec, Stream}, -}; - -pub async fn new(addr: S) -> Result -where - S: ToSocketAddrs, -{ - match addr.to_socket_addrs() { - Ok(addresses) => { - let mut streams = FuturesUnordered::new(); - - for address in addresses { - streams.push(TcpStream::connect(address)); - } - - let mut err = None; - while let Some(stream) = streams.next().await { - match stream { - Err(e) => { - err = Some(e); - } - Ok(stream) => { - return Ok(Stream { - closed: false, - codec: Box::new(Framed::new(stream.into(), PacketCodec::default())) - .into(), - }); - } - } - } - - if let Some(e) = err { - Err(e.into()) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidInput, - "could not resolve to any address", - ) - .into()) - } - } - Err(err) => Err(err.into()), - } -} diff --git a/src/io/futures/mod.rs b/src/io/futures/mod.rs deleted file mode 100644 index ae9bcd07..00000000 --- a/src/io/futures/mod.rs +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) 2016 Anatoly Ikorsky -// -// Licensed under the Apache License, Version 2.0 -// or the MIT -// license , at your -// option. All files in the project carrying such notice may not be copied, -// modified, or distributed except according to those terms. - -pub use self::{ - connecting_tcp_stream::new as new_connecting_tcp_stream, - write_packet::{new as new_write_packet, WritePacket}, -}; - -mod connecting_tcp_stream; -mod write_packet; diff --git a/src/io/futures/write_packet.rs b/src/io/futures/write_packet.rs deleted file mode 100644 index b5b44ed5..00000000 --- a/src/io/futures/write_packet.rs +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) 2016 Anatoly Ikorsky -// -// Licensed under the Apache License, Version 2.0 -// or the MIT -// license , at your -// option. All files in the project carrying such notice may not be copied, -// modified, or distributed except according to those terms. - -use futures_core::ready; -use futures_sink::Sink; -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use crate::{error::*, io::Stream}; - -/// Future that writes packet to a [`Stream`] and resolves to a pair of [`Stream`] and MySql's sequence -/// id. -pub struct WritePacket { - data: Option>, - stream: Option, -} - -pub fn new(stream: Stream, data: Vec) -> WritePacket { - WritePacket { - data: Some(data), - stream: Some(stream), - } -} - -impl Future for WritePacket { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.data.is_some() { - ready!(Pin::new(self.stream.as_mut().unwrap().codec.as_mut().unwrap()).poll_ready(cx))?; - } - - if let Some(data) = self.data.take() { - // to get here, stream must be ready - Pin::new(self.stream.as_mut().unwrap().codec.as_mut().unwrap()).start_send(data)?; - } - - ready!(Pin::new(self.stream.as_mut().unwrap().codec.as_mut().unwrap()).poll_flush(cx)) - .map_err(Error::from)?; - - Poll::Ready(Ok(self.stream.take().unwrap())) - } -} diff --git a/src/io/mod.rs b/src/io/mod.rs index b0d99e6b..45af959a 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -8,6 +8,7 @@ use bytes::{BufMut, BytesMut}; use futures_core::{ready, stream}; +use futures_util::stream::{FuturesUnordered, StreamExt}; use mysql_common::proto::codec::PacketCodec as PacketCodecInner; use native_tls::{Certificate, Identity, TlsConnector}; use pin_project::{pin_project, project}; @@ -27,17 +28,8 @@ use std::{ time::Duration, }; -use crate::{ - error::*, - io::{ - futures::{new_connecting_tcp_stream, new_write_packet, WritePacket}, - socket::Socket, - }, - opts::SslOpts, -}; +use crate::{error::*, io::socket::Socket, opts::SslOpts}; -mod async_tls; -pub mod futures; mod socket; #[derive(Debug, Default)] @@ -77,9 +69,9 @@ impl Encoder for PacketCodec { #[pin_project] #[derive(Debug)] -pub enum Endpoint { - Plain(#[pin] TcpStream), - Secure(#[pin] self::async_tls::TlsStream), +pub(crate) enum Endpoint { + Plain(Option), + Secure(#[pin] tokio_tls::TlsStream), Socket(#[pin] Socket), } @@ -95,8 +87,9 @@ impl Endpoint { pub fn set_keepalive_ms(&self, ms: Option) -> Result<()> { let ms = ms.map(|val| Duration::from_millis(u64::from(val))); match *self { - Endpoint::Plain(ref stream) => stream.set_keepalive(ms)?, - Endpoint::Secure(ref stream) => stream.get_ref().get_ref().set_keepalive(ms)?, + Endpoint::Plain(Some(ref stream)) => stream.set_keepalive(ms)?, + Endpoint::Plain(None) => unreachable!(), + Endpoint::Secure(ref stream) => stream.get_ref().set_keepalive(ms)?, Endpoint::Socket(_) => (/* inapplicable */), } Ok(()) @@ -104,17 +97,18 @@ impl Endpoint { pub fn set_tcp_nodelay(&self, val: bool) -> Result<()> { match *self { - Endpoint::Plain(ref stream) => stream.set_nodelay(val)?, - Endpoint::Secure(ref stream) => stream.get_ref().get_ref().set_nodelay(val)?, + Endpoint::Plain(Some(ref stream)) => stream.set_nodelay(val)?, + Endpoint::Plain(None) => unreachable!(), + Endpoint::Secure(ref stream) => stream.get_ref().set_nodelay(val)?, Endpoint::Socket(_) => (/* inapplicable */), } Ok(()) } - pub async fn make_secure(self, domain: String, ssl_opts: SslOpts) -> Result { + pub async fn make_secure(&mut self, domain: String, ssl_opts: SslOpts) -> Result<()> { if let Endpoint::Socket(_) = self { // inapplicable - return Ok(self); + return Ok(()); } let mut builder = TlsConnector::builder(); @@ -135,20 +129,24 @@ impl Endpoint { } builder.danger_accept_invalid_hostnames(ssl_opts.skip_domain_validation()); builder.danger_accept_invalid_certs(ssl_opts.accept_invalid_certs()); - let tls_connector = builder.build()?; - let tls_stream = match self { + let tls_connector: tokio_tls::TlsConnector = builder.build()?.into(); + + *self = match self { Endpoint::Plain(stream) => { - self::async_tls::connect_async(&tls_connector, &*domain, stream).await? + let stream = stream.take().unwrap(); + let tls_stream = tls_connector.connect(&*domain, stream).await?; + Endpoint::Secure(tls_stream) } Endpoint::Secure(_) | Endpoint::Socket(_) => unreachable!(), }; - Ok(Endpoint::Secure(tls_stream)) + + Ok(()) } } impl From for Endpoint { fn from(stream: TcpStream) -> Self { - Endpoint::Plain(stream) + Endpoint::Plain(Some(stream)) } } @@ -158,8 +156,8 @@ impl From for Endpoint { } } -impl From> for Endpoint { - fn from(stream: self::async_tls::TlsStream) -> Self { +impl From> for Endpoint { + fn from(stream: tokio_tls::TlsStream) -> Self { Endpoint::Secure(stream) } } @@ -173,7 +171,9 @@ impl AsyncRead for Endpoint { ) -> Poll> { #[project] match self.project() { - Endpoint::Plain(stream) => stream.poll_read(cx, buf), + Endpoint::Plain(ref mut stream) => { + Pin::new(stream.as_mut().unwrap()).poll_read(cx, buf) + } Endpoint::Secure(stream) => stream.poll_read(cx, buf), Endpoint::Socket(stream) => stream.poll_read(cx, buf), } @@ -181,7 +181,8 @@ impl AsyncRead for Endpoint { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { match self { - Endpoint::Plain(stream) => stream.prepare_uninitialized_buffer(buf), + Endpoint::Plain(Some(stream)) => stream.prepare_uninitialized_buffer(buf), + Endpoint::Plain(None) => unreachable!(), Endpoint::Secure(stream) => stream.prepare_uninitialized_buffer(buf), Endpoint::Socket(stream) => stream.prepare_uninitialized_buffer(buf), } @@ -198,7 +199,9 @@ impl AsyncRead for Endpoint { { #[project] match self.project() { - Endpoint::Plain(stream) => stream.poll_read_buf(cx, buf), + Endpoint::Plain(ref mut stream) => { + Pin::new(stream.as_mut().unwrap()).poll_read_buf(cx, buf) + } Endpoint::Secure(stream) => stream.poll_read_buf(cx, buf), Endpoint::Socket(stream) => stream.poll_read_buf(cx, buf), } @@ -214,7 +217,9 @@ impl AsyncWrite for Endpoint { ) -> Poll> { #[project] match self.project() { - Endpoint::Plain(stream) => stream.poll_write(cx, buf), + Endpoint::Plain(ref mut stream) => { + Pin::new(stream.as_mut().unwrap()).poll_write(cx, buf) + } Endpoint::Secure(stream) => stream.poll_write(cx, buf), Endpoint::Socket(stream) => stream.poll_write(cx, buf), } @@ -224,7 +229,7 @@ impl AsyncWrite for Endpoint { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { #[project] match self.project() { - Endpoint::Plain(stream) => stream.poll_flush(cx), + Endpoint::Plain(ref mut stream) => Pin::new(stream.as_mut().unwrap()).poll_flush(cx), Endpoint::Secure(stream) => stream.poll_flush(cx), Endpoint::Socket(stream) => stream.poll_flush(cx), } @@ -237,17 +242,17 @@ impl AsyncWrite for Endpoint { ) -> Poll> { #[project] match self.project() { - Endpoint::Plain(stream) => stream.poll_shutdown(cx), + Endpoint::Plain(ref mut stream) => Pin::new(stream.as_mut().unwrap()).poll_shutdown(cx), Endpoint::Secure(stream) => stream.poll_shutdown(cx), Endpoint::Socket(stream) => stream.poll_shutdown(cx), } } } -/// Stream connected to MySql server. +/// A Stream, connected to MySql server. pub struct Stream { closed: bool, - codec: Option>>, + pub(crate) codec: Option>>, } impl fmt::Debug for Stream { @@ -270,67 +275,98 @@ impl Stream { } } - pub async fn connect_tcp(addr: S) -> Result + pub(crate) async fn connect_tcp(addr: S) -> Result where S: ToSocketAddrs, { - new_connecting_tcp_stream(addr).await + match addr.to_socket_addrs() { + Ok(addresses) => { + let mut streams = FuturesUnordered::new(); + + for address in addresses { + streams.push(TcpStream::connect(address)); + } + + let mut err = None; + while let Some(stream) = streams.next().await { + match stream { + Err(e) => { + err = Some(e); + } + Ok(stream) => { + return Ok(Stream { + closed: false, + codec: Box::new(Framed::new(stream.into(), PacketCodec::default())) + .into(), + }); + } + } + } + + if let Some(e) = err { + Err(e.into()) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any address", + ) + .into()) + } + } + Err(err) => Err(err.into()), + } } - pub async fn connect_socket>(path: P) -> Result { + pub(crate) async fn connect_socket>(path: P) -> Result { Ok(Stream::new(Socket::new(path).await?)) } - pub fn write_packet(self, data: Vec) -> WritePacket { - new_write_packet(self, data) - } - - pub fn set_keepalive_ms(&self, ms: Option) -> Result<()> { + pub(crate) fn set_keepalive_ms(&self, ms: Option) -> Result<()> { self.codec.as_ref().unwrap().get_ref().set_keepalive_ms(ms) } - pub fn set_tcp_nodelay(&self, val: bool) -> Result<()> { + pub(crate) fn set_tcp_nodelay(&self, val: bool) -> Result<()> { self.codec.as_ref().unwrap().get_ref().set_tcp_nodelay(val) } - pub async fn make_secure(mut self, domain: String, ssl_opts: SslOpts) -> Result { + pub(crate) async fn make_secure(&mut self, domain: String, ssl_opts: SslOpts) -> Result<()> { let codec = self.codec.take().unwrap(); - let FramedParts { io, codec, .. } = codec.into_parts(); - let endpoint = io.make_secure(domain, ssl_opts).await?; - let codec = Framed::new(endpoint, codec); + let FramedParts { mut io, codec, .. } = codec.into_parts(); + io.make_secure(domain, ssl_opts).await?; + let codec = Framed::new(io, codec); self.codec = Some(Box::new(codec)); - Ok(self) + Ok(()) } - pub fn is_secure(&self) -> bool { + pub(crate) fn is_secure(&self) -> bool { self.codec.as_ref().unwrap().get_ref().is_secure() } - pub fn reset_seq_id(&mut self) { + pub(crate) fn reset_seq_id(&mut self) { if let Some(codec) = self.codec.as_mut() { codec.codec_mut().reset_seq_id(); } } - pub fn sync_seq_id(&mut self) { + pub(crate) fn sync_seq_id(&mut self) { if let Some(codec) = self.codec.as_mut() { codec.codec_mut().sync_seq_id(); } } - pub fn set_max_allowed_packet(&mut self, max_allowed_packet: usize) { + pub(crate) fn set_max_allowed_packet(&mut self, max_allowed_packet: usize) { if let Some(codec) = self.codec.as_mut() { codec.codec_mut().max_allowed_packet = max_allowed_packet; } } - pub fn compress(&mut self, level: crate::Compression) { + pub(crate) fn compress(&mut self, level: crate::Compression) { if let Some(codec) = self.codec.as_mut() { codec.codec_mut().compress(level); } } - pub async fn close(mut self) -> Result<()> { + pub(crate) async fn close(mut self) -> Result<()> { self.closed = true; if let Some(mut codec) = self.codec { use futures_sink::Sink; diff --git a/src/io/socket.rs b/src/io/socket.rs index bbaa281d..bd5efb4e 100644 --- a/src/io/socket.rs +++ b/src/io/socket.rs @@ -19,7 +19,7 @@ use std::{io, mem::MaybeUninit, path::Path}; /// Unix domain socket connection on unix, or named pipe connection on windows. #[pin_project] #[derive(Debug)] -pub struct Socket { +pub(crate) struct Socket { #[pin] #[cfg(unix)] inner: tokio::net::UnixStream, diff --git a/src/lib.rs b/src/lib.rs index 6e8c5a21..8fe08b80 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ //! ### Example //! //! ```rust +//! # use mysql_async::test_misc::get_opts; //! use mysql_async::prelude::*; //! # use std::env; //! @@ -42,21 +43,13 @@ //! let payments_clone = payments.clone(); //! //! let database_url = /* ... */ -//! # if let Ok(url) = env::var("DATABASE_URL") { -//! # let opts = mysql_async::Opts::from_url(&url).expect("DATABASE_URL invalid"); -//! # if opts.get_db_name().expect("a database name is required").is_empty() { -//! # panic!("database name is empty"); -//! # } -//! # url -//! # } else { -//! # "mysql://root:password@127.0.0.1:3307/mysql".to_string() -//! # }; +//! # get_opts(); //! //! let pool = mysql_async::Pool::new(database_url); -//! let conn = pool.get_conn().await?; +//! let mut conn = pool.get_conn().await?; //! //! // Create temporary table -//! let conn = conn.drop_query( +//! conn.drop_query( //! r"CREATE TEMPORARY TABLE payment ( //! customer_id int not null, //! amount int not null, @@ -73,14 +66,17 @@ //! } //! }); //! -//! let conn = conn.batch_exec(r"INSERT INTO payment (customer_id, amount, account_name) -//! VALUES (:customer_id, :amount, :account_name)", params).await?; +//! conn.batch_exec( +//! r"INSERT INTO payment (customer_id, amount, account_name) +//! VALUES (:customer_id, :amount, :account_name)", +//! params, +//! ).await?; //! //! // Load payments from database. //! let result = conn.prep_exec("SELECT customer_id, amount, account_name FROM payment", ()).await?; //! //! // Collect payments -//! let (_ /* conn */, loaded_payments) = result.map_and_drop(|row| { +//! let loaded_payments = result.map_and_drop(|row| { //! let (customer_id, amount, account_name) = mysql_async::from_row(row); //! Payment { //! customer_id: customer_id, @@ -89,9 +85,11 @@ //! } //! }).await?; //! -//! // The destructor of a connection will return it to the pool, -//! // but pool should be disconnected explicitly because it's -//! // an asynchronous procedure. +//! // We must drop the connection before disconnecting the pool. +//! drop(conn); +//! +//! // Pool must be disconnected explicitly because it's +//! // an asynchronous operation. //! pool.disconnect().await?; //! //! assert_eq!(loaded_payments, payments); @@ -120,9 +118,8 @@ mod local_infile_handler; mod opts; mod queryable; -pub type BoxFuture = ::std::pin::Pin< - Box> + Send + 'static>, ->; +pub type BoxFuture<'a, T> = + std::pin::Pin> + Send + 'a>>; #[doc(inline)] pub use self::conn::Conn; diff --git a/src/local_infile_handler/builtin.rs b/src/local_infile_handler/builtin.rs index a660265b..2638f13e 100644 --- a/src/local_infile_handler/builtin.rs +++ b/src/local_infile_handler/builtin.rs @@ -12,7 +12,7 @@ use std::{collections::HashSet, path::PathBuf, str::from_utf8}; use crate::local_infile_handler::LocalInfileHandler; -/// Handles local infile requests from filesystem using explicit path white list. +/// Handles local infile requests from filesystem using explicit whitelist of paths. /// /// Example usage: /// diff --git a/src/local_infile_handler/mod.rs b/src/local_infile_handler/mod.rs index 5095eea9..27b93eb4 100644 --- a/src/local_infile_handler/mod.rs +++ b/src/local_infile_handler/mod.rs @@ -15,18 +15,19 @@ pub mod builtin; /// Trait used to handle local infile requests. /// -/// Be aware of security issues with [LOAD DATA LOCAL](https://dev.mysql.com/doc/refman/8.0/en/load-data-local.html). +/// Be aware of security issues with [LOAD DATA LOCAL][1]. /// Using [`crate::WhiteListFsLocalInfileHandler`] is advised. /// /// Simple handler example: /// /// ```rust -/// # use mysql_async::prelude::*; +/// # use mysql_async::{prelude::*, test_misc::get_opts}; /// # use tokio::prelude::*; /// # use std::env; /// # #[tokio::main] /// # async fn main() -> Result<(), mysql_async::error::Error> { /// # +/// /// This example hanlder will return contained bytes in response to a local infile request. /// struct ExampleHandler(&'static [u8]); /// /// impl LocalInfileHandler for ExampleHandler { @@ -36,25 +37,17 @@ pub mod builtin; /// } /// } /// -/// # let database_url: String = if let Ok(url) = env::var("DATABASE_URL") { -/// # let opts = mysql_async::Opts::from_url(&url).expect("DATABASE_URL invalid"); -/// # if opts.get_db_name().expect("a database name is required").is_empty() { -/// # panic!("database name is empty"); -/// # } -/// # url -/// # } else { -/// # "mysql://root:password@127.0.0.1:3307/mysql".into() -/// # }; +/// # let database_url = get_opts(); /// -/// let mut opts = mysql_async::OptsBuilder::from_opts(&*database_url); +/// let mut opts = mysql_async::OptsBuilder::from_opts(database_url); /// opts.local_infile_handler(Some(ExampleHandler(b"foobar"))); /// /// let pool = mysql_async::Pool::new(opts); /// -/// let conn = pool.get_conn().await?; -/// let conn = conn.drop_query("CREATE TEMPORARY TABLE tmp (a TEXT);").await?; -/// let conn = match conn.drop_query("LOAD DATA LOCAL INFILE 'baz' INTO TABLE tmp;").await { -/// Ok(conn) => conn, +/// let mut conn = pool.get_conn().await?; +/// conn.drop_query("CREATE TEMPORARY TABLE tmp (a TEXT);").await?; +/// match conn.drop_query("LOAD DATA LOCAL INFILE 'baz' INTO TABLE tmp;").await { +/// Ok(()) => (), /// Err(mysql_async::error::Error::Server(ref err)) if err.code == 1148 => { /// // The used command is not allowed with this MySQL version /// return Ok(()); @@ -67,17 +60,19 @@ pub mod builtin; /// e@Err(_) => e.unwrap(), /// }; /// let result = conn.prep_exec("SELECT * FROM tmp;", ()).await?; -/// let (_ /* conn */, result) = result.map_and_drop(|row| { +/// let result = result.map_and_drop(|row| { /// mysql_async::from_row::<(String,)>(row).0 /// }).await?; /// /// assert_eq!(result.len(), 1); /// assert_eq!(result[0], "foobar"); +/// drop(conn); /// pool.disconnect().await?; /// # Ok(()) /// # } /// ``` /// +/// [1]: https://dev.mysql.com/doc/refman/8.0/en/load-data-local.html pub trait LocalInfileHandler: Sync + Send { /// `file_name` is the file name in `LOAD DATA LOCAL INFILE '' INTO TABLE ...;` /// query. diff --git a/src/opts.rs b/src/opts.rs index 92bab723..070d377c 100644 --- a/src/opts.rs +++ b/src/opts.rs @@ -37,7 +37,7 @@ const DEFAULT_PORT: u16 = 3306; /// Represents information about a host and port combination that can be converted /// into socket addresses using to_socket_addrs. #[derive(Clone, Eq, PartialEq, Debug)] -pub enum HostPortOrUrl { +pub(crate) enum HostPortOrUrl { HostPort(String, u16), Url(Url), } @@ -190,7 +190,16 @@ pub struct PoolOptions { } impl PoolOptions { - /// Creates [`PoolOptions`]. + /// Creates the default [`PoolOptions`] with the given constraints. + pub const fn with_constraints(constraints: PoolConstraints) -> Self { + Self { + constraints, + inactive_connection_ttl: DEFAULT_INACTIVE_CONNECTION_TTL, + ttl_check_interval: DEFAULT_TTL_CHECK_INTERVAL, + } + } + + /// Creates a [`PoolOptions`]. pub const fn new( constraints: PoolConstraints, inactive_connection_ttl: Duration, @@ -203,34 +212,26 @@ impl PoolOptions { } } - /// Creates default [`PoolOptions`] with given constraints. - pub const fn with_constraints(constraints: PoolConstraints) -> Self { - Self { - constraints, - inactive_connection_ttl: DEFAULT_INACTIVE_CONNECTION_TTL, - ttl_check_interval: DEFAULT_TTL_CHECK_INTERVAL, - } - } - /// Sets pool constraints. pub fn set_constraints(&mut self, constraints: PoolConstraints) { self.constraints = constraints; } - /// Returns `constrains` value. + /// Returns pool constraints. pub fn constraints(&self) -> PoolConstraints { self.constraints } - /// Pool will recycle inactive connection if it outside of the lower bound of a pool - /// and if it is idling longer than this value (defaults to [`DEFAULT_INACTIVE_CONNECTION_TTL`]). + /// Pool will recycle inactive connection if it is outside of the lower bound of the pool + /// and if it is idling longer than this value (defaults to + /// [`DEFAULT_INACTIVE_CONNECTION_TTL`]). /// /// Note that it may, actually, idle longer because of [`PoolOptions::ttl_check_interval`]. pub fn set_inactive_connection_ttl(&mut self, ttl: Duration) { self.inactive_connection_ttl = ttl; } - /// Returns `inactive_connection_ttl` value. + /// Returns a `inactive_connection_ttl` value. pub fn inactive_connection_ttl(&self) -> Duration { self.inactive_connection_ttl } @@ -247,7 +248,7 @@ impl PoolOptions { } } - /// Returns `ttl_check_interval` value. + /// Returns a `ttl_check_interval` value. pub fn ttl_check_interval(&self) -> Duration { self.ttl_check_interval } @@ -259,7 +260,7 @@ impl PoolOptions { /// Active bound is either: /// * `min` bound of the pool constraints, if this [`PoolOptions`] defines /// `inactive_connection_ttl` to be `0`. This means, that pool will hold no more than `min` - /// number of idling connection and other connection will be immediately disconnected. + /// number of idling connections and other connections will be immediately disconnected. /// * `max` bound of the pool constraints, if this [`PoolOptions`] defines /// `inactive_connection_ttl` to be non-zero. This means, that pool will hold up to `max` /// number of idling connections and this number will be eventually reduced to `min` @@ -284,7 +285,7 @@ impl Default for PoolOptions { } #[derive(Clone, Eq, PartialEq, Default, Debug)] -pub struct InnerOpts { +pub(crate) struct InnerOpts { mysql_opts: MysqlOpts, address: HostPortOrUrl, } @@ -293,7 +294,7 @@ pub struct InnerOpts { /// /// Build one with [`OptsBuilder`]. #[derive(Clone, Eq, PartialEq, Debug)] -pub struct MysqlOpts { +pub(crate) struct MysqlOpts { /// User (defaults to `None`). user: Option, @@ -318,7 +319,7 @@ pub struct MysqlOpts { /// Connection pool options (defaults to [`PoolOptions::default`]). pool_options: PoolOptions, - /// Pool will close connection if time since last IO exceeds this number of seconds + /// Pool will close a connection if time since last IO exceeds this number of seconds /// (defaults to `wait_timeout`). conn_ttl: Option, @@ -402,7 +403,7 @@ impl Opts { self.inner.address.get_ip_or_hostname() } - pub fn get_hostport_or_url(&self) -> &HostPortOrUrl { + pub(crate) fn get_hostport_or_url(&self) -> &HostPortOrUrl { &self.inner.address } diff --git a/src/queryable/mod.rs b/src/queryable/mod.rs index 27a67c91..236ee2b5 100644 --- a/src/queryable/mod.rs +++ b/src/queryable/mod.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use self::{ query_result::QueryResult, stmt::Stmt, - transaction::{Transaction, TransactionOptions}, + transaction::{Transaction, TransactionOptions, TxStatus}, }; use crate::{ connection_like::ConnectionLike, consts::Command, error::*, prelude::FromRow, BoxFuture, @@ -66,139 +66,171 @@ impl Protocol for BinaryProtocol { } } -/// Represents something queryable like connection or transaction. -pub trait Queryable: ConnectionLike +/// The only purpose of this function at the moment is to rollback a transaction in cases, +/// where `Transaction` is dropped without an explicit call to `commit` or `rollback`. +async fn cleanup(queryable: &mut T) -> Result<()> { + if queryable.get_tx_status() == TxStatus::RequiresRollback { + queryable.set_tx_status(TxStatus::None); + queryable.drop_query("ROLLBACK").await?; + } + Ok(()) +} + +/// Represents something queryable, e.g. connection or transaction. +pub trait Queryable: crate::prelude::ConnectionLike where - Self: Sized + 'static, + Self: Sized, { - /// Returns future that resolves to `Conn` if `COM_PING` executed successfully. - fn ping(self) -> BoxFuture { + /// Returns a future, that executes `COM_PING`. + fn ping(&mut self) -> BoxFuture<'_, ()> { Box::pin(async move { - Ok(self - .write_command_data(Command::COM_PING, &[]) - .await? - .read_packet() - .await? - .0) + cleanup(self).await?; + self.write_command_data(Command::COM_PING, &[]).await?; + self.read_packet().await?; + Ok(()) }) } - /// Returns future, that disconnects this connection from a server. - fn disconnect(mut self) -> BoxFuture<()> { - self.on_disconnect(); - let f = self.write_command_data(Command::COM_QUIT, &[]); + /// Returns a future that performs the given query. + fn query<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, QueryResult<'a, Self, TextProtocol>> + where + Q: AsRef + Sync + Send + 'static, + { Box::pin(async move { - let (_, stream) = f.await?.take_stream(); - stream.close().await?; - Ok(()) + cleanup(self).await?; + self.write_command_data(Command::COM_QUERY, query.as_ref().as_bytes()) + .await?; + self.read_result_set(None).await }) } - /// Returns future that performs `query`. - fn query>(self, query: Q) -> BoxFuture> { - let f = self.write_command_data(Command::COM_QUERY, query.as_ref().as_bytes()); - Box::pin(async move { f.await?.read_result_set(None).await }) - } - - /// Returns future that resolves to a first row of result of a `query` execution (if any). + /// Returns a future that executes the given query and returns the first row (if any). /// /// Returned future will call `R::from_row(row)` internally. - fn first(self, query: Q) -> BoxFuture<(Self, Option)> + fn first<'a, Q, R>(&'a mut self, query: Q) -> BoxFuture<'a, Option> where - Q: AsRef, + Q: AsRef + Sync + Send + 'static, R: FromRow, { - let f = self.query(query); Box::pin(async move { - let (this, mut rows) = f.await?.collect_and_drop::().await?; + let result = self.query(query).await?; + let mut rows = result.collect_and_drop::().await?; if rows.len() > 1 { - Ok((this, Some(FromRow::from_row(rows.swap_remove(0))))) + Ok(Some(FromRow::from_row(rows.swap_remove(0)))) } else { - Ok((this, rows.pop().map(FromRow::from_row))) + Ok(rows.pop().map(FromRow::from_row)) } }) } - /// Returns future that performs query. Result will be dropped. - fn drop_query>(self, query: Q) -> BoxFuture { - let f = self.query(query); - Box::pin(async move { f.await?.drop_result().await }) + /// Returns a future that performs the given query. Result will be dropped. + fn drop_query<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, ()> + where + Q: AsRef + Sync + Send + 'static, + { + Box::pin(async move { + let result = self.query(query).await?; + result.drop_result().await?; + Ok(()) + }) } - /// Returns future that prepares statement. - fn prepare>(self, query: Q) -> BoxFuture> { - let f = self.prepare_stmt(query); + /// Returns a future that prepares the given statement. + fn prepare<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Stmt<'a, Self>> + where + Q: AsRef + Send + 'static, + { Box::pin(async move { - let (this, inner_stmt, stmt_cache_result) = f.await?; - Ok(stmt::new(this, inner_stmt, stmt_cache_result)) + cleanup(self).await?; + let f = self.prepare_stmt(query); + let (inner_stmt, stmt_cache_result) = f.await?; + Ok(Stmt::new(self, inner_stmt, stmt_cache_result)) }) } - /// Returns future that prepares and executes statement in one pass. - fn prep_exec(self, query: Q, params: P) -> BoxFuture> + /// Returns a future that prepares and executes the given statement in one pass. + fn prep_exec<'a, Q, P>( + &'a mut self, + query: Q, + params: P, + ) -> BoxFuture<'a, QueryResult<'a, Self, BinaryProtocol>> where - Q: AsRef, + Q: AsRef + Send + 'static, P: Into, { let params: Params = params.into(); - let f = self.prepare(query); Box::pin(async move { - let result = f.await?.execute(params).await?; - let (stmt, columns, _) = query_result::disassemble(result); - let (conn_like, cached) = stmt.unwrap(); - Ok(query_result::assemble(conn_like, columns, cached)) + let mut stmt = self.prepare(query).await?; + let result = stmt.execute(params).await?; + let (stmt, columns, _) = result.disassemble(); + let cached = stmt.cached.clone(); + Ok(QueryResult::new(self, columns, cached)) }) } - /// Returns future that resolves to a first row of result of a statement execution (if any). + /// Returns a future that prepares and executes the given statement, + /// and resolves to the first row (if any). /// /// Returned future will call `R::from_row(row)` internally. - fn first_exec(self, query: Q, params: P) -> BoxFuture<(Self, Option)> + fn first_exec(&mut self, query: Q, params: P) -> BoxFuture<'_, Option> where - Q: AsRef, + Q: AsRef + Sync + Send + 'static, P: Into, R: FromRow, { - let f = self.prep_exec(query, params); + let params = params.into(); Box::pin(async move { - let (this, mut rows) = f.await?.collect_and_drop::().await?; + let mut rows = self + .prep_exec(query, params) + .await? + .collect_and_drop::() + .await?; if rows.len() > 1 { - Ok((this, Some(FromRow::from_row(rows.swap_remove(0))))) + Ok(Some(FromRow::from_row(rows.swap_remove(0)))) } else { - Ok((this, rows.pop().map(FromRow::from_row))) + Ok(rows.pop().map(FromRow::from_row)) } }) } - /// Returns future that prepares and executes statement. Result will be dropped. - fn drop_exec(self, query: Q, params: P) -> BoxFuture + /// Returns a future that prepares and executes the given statement. Result will be dropped. + fn drop_exec(&mut self, query: Q, params: P) -> BoxFuture<'_, ()> where - Q: AsRef, + Q: AsRef + Send + 'static, P: Into, { let f = self.prep_exec(query, params); Box::pin(async move { f.await?.drop_result().await }) } - /// Returns future that prepares statement and performs batch execution. - /// Results will be dropped. - fn batch_exec(self, query: Q, params_iter: I) -> BoxFuture + /// Returns a future that prepares the given statement and performs batch execution using + /// the given params. Results will be dropped. + fn batch_exec(&mut self, query: Q, params_iter: I) -> BoxFuture<'_, ()> where - Q: AsRef, - I: IntoIterator + Send + 'static, + Q: AsRef + Sync + Send + 'static, + I: IntoIterator, I::IntoIter: Send + 'static, Params: From

, - P: Send + 'static, { - let f = self.prepare(query); - Box::pin(async move { f.await?.batch(params_iter).await?.close().await }) + let params_iter = params_iter.into_iter(); + Box::pin(async move { + let mut stmt = self.prepare(query).await?; + stmt.batch(params_iter).await?; + stmt.close().await + }) } - /// Returns future that starts transaction. - fn start_transaction(self, options: TransactionOptions) -> BoxFuture> { - Box::pin(transaction::new(self, options)) + /// Returns a future that starts a transaction. + fn start_transaction<'a>( + &'a mut self, + options: TransactionOptions, + ) -> BoxFuture<'a, Transaction<'a, Self>> { + Box::pin(async move { + cleanup(self).await?; + Transaction::new(self, options).await + }) } } impl Queryable for Conn {} -impl Queryable for Transaction {} +impl<'a, T: Queryable + crate::prelude::ConnectionLike> Queryable for Transaction<'a, T> {} diff --git a/src/queryable/query_result/mod.rs b/src/queryable/query_result/mod.rs index 7716792f..c84b959c 100644 --- a/src/queryable/query_result/mod.rs +++ b/src/queryable/query_result/mod.rs @@ -6,216 +6,175 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. -use futures_util::future::Either; use mysql_common::row::convert::FromRowError; use std::{borrow::Cow, marker::PhantomData, result::Result as StdResult, sync::Arc}; use self::QueryResultInner::*; use crate::{ - connection_like::{ - streamless::Streamless, ConnectionLike, ConnectionLikeWrapper, StmtCacheResult, - }, + connection_like::StmtCacheResult, consts::StatusFlags, error::*, - io, - prelude::FromRow, - queryable::Protocol, + prelude::{ConnectionLike, FromRow, Protocol}, Column, Row, }; -pub fn new( - conn_like: T, - columns: Option>>, - cached: Option, -) -> QueryResult -where - T: ConnectionLike + Sized + 'static, - P: Protocol, - P: Send + 'static, -{ - QueryResult::new(conn_like, columns, cached) +enum QueryResultInner { + Empty(Option), + WithRows(Arc>, Option), } -pub fn disassemble( - query_result: QueryResult, -) -> (T, Option>>, Option) { - match query_result { - QueryResult(Empty(Some(Either::Left(conn_like)), cached, _)) => (conn_like, None, cached), - QueryResult(WithRows(Some(Either::Left(conn_like)), columns, cached, _)) => { - (conn_like, Some(columns), cached) +impl QueryResultInner { + fn new(columns: Option>>, cached: Option) -> Self { + match columns { + Some(columns) => WithRows(columns, cached), + None => Empty(cached), } - _ => unreachable!(), } -} -pub fn assemble( - conn_like: T, - columns: Option>>, - cached: Option, -) -> QueryResult -where - T: ConnectionLike + Sized + 'static, - P: Protocol + 'static, -{ - match columns { - Some(columns) => QueryResult(WithRows( - Some(Either::Left(conn_like)), - columns, - cached, - PhantomData, - )), - None => QueryResult(Empty(Some(Either::Left(conn_like)), cached, PhantomData)), + fn columns(&self) -> Option<&Arc>> { + match self { + WithRows(columns, _) => Some(columns), + Empty(_) => None, + } + } + + fn cached(&self) -> Option { + match *self { + WithRows(_, cached) | Empty(cached) => cached, + } } -} -enum QueryResultInner { - Empty( - Option>>, - Option, - PhantomData

, - ), - WithRows( - Option>>, - Arc>, - Option, - PhantomData

, - ), + fn make_empty(&mut self) { + *self = match *self { + WithRows(_, cached) | Empty(cached) => Empty(cached), + } + } } /// Result of a query or statement execution. -pub struct QueryResult(QueryResultInner); +pub struct QueryResult<'a, T: ?Sized, P> { + conn_like: &'a mut T, + inner: QueryResultInner, + __phantom: PhantomData

, +} -impl QueryResult +impl<'a, T: ?Sized, P> QueryResult<'a, T, P> where P: Protocol, - P: Send + 'static, T: ConnectionLike, - T: Sized + Send + 'static, { - fn into_empty(mut self) -> Self { - self.set_pending_result(None); - match self { - QueryResult(WithRows(conn_like, _, cached, _)) => { - QueryResult(Empty(conn_like, cached, PhantomData)) - } - x => x, + pub(crate) fn new( + conn_like: &'a mut T, + columns: Option>>, + cached: Option, + ) -> QueryResult<'a, T, P> { + QueryResult { + conn_like, + inner: QueryResultInner::new(columns, cached), + __phantom: PhantomData, } } - fn into_inner(self) -> (T, Option) { - match self { - QueryResult(Empty(conn_like, cached, _)) - | QueryResult(WithRows(conn_like, _, cached, _)) => match conn_like { - Some(Either::Left(conn_like)) => (conn_like, cached), - _ => unreachable!(), - }, + pub(crate) fn disassemble( + self, + ) -> (&'a mut T, Option>>, Option) { + match self.inner { + WithRows(columns, cached) => (self.conn_like, Some(columns), cached), + Empty(cached) => (self.conn_like, None, cached), } } - async fn get_row_raw(self) -> Result<(Self, Option>)> { + fn make_empty(&mut self) { + self.conn_like.set_pending_result(None); + self.inner.make_empty(); + } + + async fn get_row_raw(&mut self) -> Result>> { if self.is_empty() { - return Ok((self, None)); + return Ok(None); } - let (mut this, packet) = self.read_packet().await?; - if P::is_last_result_set_packet(&this, &packet) { - if this.more_results_exists() { - this.sync_seq_id(); - let (inner, cached) = this.into_inner(); - let this = inner.read_result_set(cached).await?; - Ok((this, None)) + let packet: Vec = self.conn_like.read_packet().await?; + + if P::is_last_result_set_packet(&*self.conn_like, &packet) { + if self.more_results_exists() { + self.conn_like.sync_seq_id(); + let cached = self.inner.cached(); + let next_set = self.conn_like.read_result_set::

(cached).await?; + self.inner = next_set.inner; + Ok(None) } else { - Ok((this.into_empty(), None)) + self.make_empty(); + Ok(None) } } else { - Ok((this, Some(packet))) + Ok(Some(packet)) } } - async fn get_row(self) -> Result<(Self, Option)> { - let (this, packet) = self.get_row_raw().await?; + /// Returns next row, if any. + /// + /// Requires that `self.inner` matches `WithRows(..)`. + async fn get_row(&mut self) -> Result> { + let packet = self.get_row_raw().await?; if let Some(packet) = packet { - if let QueryResult(WithRows(_, ref columns, ..)) = this { - let row = P::read_result_set_row(&packet, columns.clone())?; - Ok((this, Some(row))) - } else { - unreachable!() - } + let columns = self.inner.columns().expect("must be here"); + let row = P::read_result_set_row(&packet, columns.clone())?; + Ok(Some(row)) } else { - Ok((this, None)) - } - } - - fn new( - conn_like: T, - columns: Option>>, - cached: Option, - ) -> QueryResult { - match columns { - Some(columns) => QueryResult(WithRows( - Some(Either::Left(conn_like)), - columns, - cached, - PhantomData, - )), - None => QueryResult(Empty(Some(Either::Left(conn_like)), cached, PhantomData)), + Ok(None) } } /// Last insert id, if any. pub fn last_insert_id(&self) -> Option { - self.get_last_insert_id() + self.conn_like.get_last_insert_id() } - /// Number of affected rows, reported by the server, or `0`. + /// Number of affected rows, as reported by the server, or `0`. pub fn affected_rows(&self) -> u64 { - self.get_affected_rows() + self.conn_like.get_affected_rows() } - /// Text information, reported by the server, or an empty string + /// Text information, as reported by the server, or an empty string. pub fn info(&self) -> Cow<'_, str> { - self.get_info() + self.conn_like.get_info() } - /// Number of warnings, reported by the server, or `0`. + /// Number of warnings, as reported by the server, or `0`. pub fn warnings(&self) -> u16 { - self.get_warnings() + self.conn_like.get_warnings() } /// `true` if there is no more rows nor result sets in this query. /// /// One could use it to check if there is more than one result set in this query result. pub fn is_empty(&self) -> bool { - match *self { - QueryResult(Empty(..)) => !self.more_results_exists(), - _ => false, - } + !self.has_rows() } - /// Returns `true` if the SERVER_MORE_RESULTS_EXISTS flag is contained in status flags + /// Returns `true` if the `SERVER_MORE_RESULTS_EXISTS` flag is contained in status flags /// of the connection. fn more_results_exists(&self) -> bool { - self.get_status() + self.conn_like + .get_status() .contains(StatusFlags::SERVER_MORE_RESULTS_EXISTS) } - /// `true` if rows may exists for this query result. + /// Returns `true` if this query result may contain rows. /// /// If `false` then there is no rows possible (for example UPDATE query). fn has_rows(&self) -> bool { - match *self { - QueryResult(Empty(..)) => false, - _ => true, - } + matches!(self.inner, WithRows(..)) } - /// Returns future that collects result set of this query result. + /// Returns a future that collects result set of this query result. /// /// It is parametrized by `R` and internally calls `R::from_row(Row)` on each row. /// - /// It will stop collecting on result set boundary. This means that you should call `collect` - /// as many times as result sets in your query result. For example query + /// It will collect rows up to a neares result set boundary. This means that you should call + /// `collect` as many times as result sets in your query result. For example query /// `SELECT 'foo'; SELECT 'foo', 'bar';` will produce `QueryResult` with two result sets in it. /// One can use `QueryResult::is_empty` to make sure that there is no more result sets. /// @@ -224,7 +183,7 @@ where /// It'll panic if any row isn't convertible to `R` (i.e. programmer error or unknown schema). /// * In case of programmer error see [`FromRow`] docs; /// * In case of unknown schema use [`QueryResult::try_collect`]. - pub async fn collect(self) -> Result<(Self, Vec)> + pub async fn collect(&mut self) -> Result> where R: FromRow, R: Send + 'static, @@ -236,11 +195,11 @@ where .await } - /// Returns future that collects result set of this query. + /// Returns a future that collects result set of this query result. /// - /// It works the same way as [`QueryResult::collect`] but won't panic - /// if row isn't convertible to `R`. - pub async fn try_collect(self) -> Result<(Self, Vec>)> + /// It works the same way as [`QueryResult::collect`] but won't panic if row isn't convertible + /// to `R`. + pub async fn try_collect(&mut self) -> Result>> where R: FromRow, R: Send + 'static, @@ -252,254 +211,169 @@ where .await } - /// Returns future that collects result set of a query result and drops everything else. - /// It will resolve to a pair of wrapped [`crate::prelude::Queryable`] and collected result set. + /// Returns a future that collects the current result set of this query result and drops + /// everything else. /// /// # Panic /// /// It'll panic if any row isn't convertible to `R` (i.e. programmer error or unknown schema). /// * In case of programmer error see `FromRow` docs; /// * In case of unknown schema use [`QueryResult::try_collect`]. - pub async fn collect_and_drop(self) -> Result<(T, Vec)> + pub async fn collect_and_drop(mut self) -> Result> where R: FromRow, R: Send + 'static, { - let (this, output) = self.collect().await?; - let conn = this.drop_result().await?; - Ok((conn, output)) + let output = self.collect::().await?; + self.drop_result().await?; + Ok(output) } - /// Returns future that collects result set of a query result and drops everything else. - /// It will resolve to a pair of wrapped [`crate::prelude::Queryable`] and collected result set. + /// Returns a future that collects the current result set of this query result and drops + /// everything else. /// - /// It works the same way as [`QueryResult::collect_and_drop`] but won't panic - /// if row isn't convertible to `R`. - pub async fn try_collect_and_drop(self) -> Result<(T, Vec>)> + /// It works the same way as [`QueryResult::collect_and_drop`] but won't panic if row isn't + /// convertible to `R`. + pub async fn try_collect_and_drop(mut self) -> Result>> where R: FromRow, R: Send + 'static, { - let (this, output) = self.try_collect().await?; - let conn = this.drop_result().await?; - Ok((conn, output)) + let output = self.try_collect().await?; + self.drop_result().await?; + Ok(output) } - /// Returns future that will execute `fun` on every row of current result set. + /// Returns a future that will execute `fun` on every row of the current result set. /// - /// It will stop on result set boundary (see `QueryResult::collect` docs). - pub async fn for_each(self, mut fun: F) -> Result + /// It will stop on the nearest result set boundary (see `QueryResult::collect` docs). + pub async fn for_each(&mut self, mut fun: F) -> Result<()> where F: FnMut(Row), { if self.is_empty() { - Ok(self) + Ok(()) } else { - let mut qr = self; loop { - let (qr_, row) = qr.get_row().await?; - qr = qr_; + let row = self.get_row().await?; if let Some(row) = row { fun(row); } else { - break Ok(qr); + break Ok(()); } } } } - /// Returns future that will execute `fun` on every row of current result set and drop - /// everything else. It will resolve to a wrapped `Queryable`. - pub async fn for_each_and_drop(self, fun: F) -> Result + /// Returns a future that will execute `fun` on every row of the current result set and drop + /// everything else. + pub async fn for_each_and_drop(mut self, fun: F) -> Result<()> where F: FnMut(Row), { - self.for_each(fun).await?.drop_result().await + self.for_each(fun).await?; + self.drop_result().await?; + Ok(()) } - /// Returns future that will map every row of current result set to `U` using `fun`. + /// Returns a future that will map every row of the current result set to `U` using `fun`. /// - /// It will stop on result set boundary (see `QueryResult::collect` docs). - pub async fn map(self, mut fun: F) -> Result<(Self, Vec)> + /// It will stop on the nearest result set boundary (see `QueryResult::collect` docs). + pub async fn map(&mut self, mut fun: F) -> Result> where F: FnMut(Row) -> U, { if self.is_empty() { - Ok((self, Vec::new())) + Ok(Vec::new()) } else { - let mut qr = self; let mut rows = Vec::new(); loop { - let (qr_, row) = qr.get_row().await?; - qr = qr_; + let row = self.get_row().await?; if let Some(row) = row { rows.push(fun(row)); } else { - break Ok((qr, rows)); + break Ok(rows); } } } } - /// Returns future that will map every row of current result set to `U` using `fun` and drop - /// everything else. It will resolve to a pair of wrapped `Queryable` and mapped result set. - pub async fn map_and_drop(self, fun: F) -> Result<(T, Vec)> + /// Returns a future that will map every row of the current result set to `U` using `fun` + /// and drop everything else. + pub async fn map_and_drop(mut self, fun: F) -> Result> where F: FnMut(Row) -> U, { - let (this, rows) = self.map(fun).await?; - let this = this.drop_result().await?; - Ok((this, rows)) + let rows = self.map(fun).await?; + self.drop_result().await?; + Ok(rows) } - /// Returns future that will reduce rows of current result set to `U` using `fun`. + /// Returns a future that will reduce rows of the current result set to `U` using `fun`. /// - /// It will stop on result set boundary (see `QueryResult::collect` docs). - pub async fn reduce(self, init: U, mut fun: F) -> Result<(Self, U)> + /// It will stop on the nearest result set boundary (see `QueryResult::collect` docs). + pub async fn reduce(&mut self, init: U, mut fun: F) -> Result where F: FnMut(U, Row) -> U, { if self.is_empty() { - Ok((self, init)) + Ok(init) } else { - let mut qr = self; let mut acc = init; loop { - let (qr_, row) = qr.get_row().await?; - qr = qr_; + let row = self.get_row().await?; if let Some(row) = row { acc = fun(acc, row); } else { - break Ok((qr, acc)); + break Ok(acc); } } } } - /// Returns future that will reduce rows of current result set to `U` using `fun` and drop - /// everything else. It will resolve to a pair of wrapped `Queryable` and `U`. - pub async fn reduce_and_drop(self, init: U, fun: F) -> Result<(T, U)> + /// Returns a future that will reduce rows of the current result set to `U` using `fun` and drop + /// everything else. + pub async fn reduce_and_drop(mut self, init: U, fun: F) -> Result where F: FnMut(U, Row) -> U, { - let (this, acc) = self.reduce(init, fun).await?; - let this = this.drop_result().await?; - Ok((this, acc)) + let acc = self.reduce(init, fun).await?; + self.drop_result().await?; + Ok(acc) } - /// Returns future that will drop this query result end resolve to a wrapped `Queryable`. - pub async fn drop_result(self) -> Result { - let mut this = self; - let (conn_like, cached) = loop { - if !this.has_rows() { - if this.more_results_exists() { - let (inner, cached) = this.into_inner(); - this = inner.read_result_set(cached).await?; + /// Returns a future that will drop this query result. + pub async fn drop_result(mut self) -> Result<()> { + let cached = loop { + if !self.has_rows() { + if self.more_results_exists() { + let (inner, _, cached) = self.disassemble(); + self = inner.read_result_set(cached).await?; } else { - break this.into_inner(); + break self.inner.cached(); } } else { - let (this_, _) = this.get_row_raw().await?; - this = this_; + self.get_row_raw().await?; } }; if let Some(StmtCacheResult::NotCached(statement_id)) = cached { - conn_like.close_stmt(statement_id).await - } else { - Ok(conn_like) + self.conn_like.close_stmt(statement_id).await?; } + + Ok(()) } - /// Returns reference to columns in this query result. + /// Returns a reference to a columns list of this query result. pub fn columns_ref(&self) -> &[Column] { - match self.0 { - QueryResultInner::Empty(..) => { - static EMPTY: &'static [Column] = &[]; - EMPTY - } - QueryResultInner::WithRows(_, ref columns, ..) => &**columns, - } + self.inner + .columns() + .map(|columns| &***columns) + .unwrap_or_default() } - /// Returns copy of columns of this query result. + /// Returns a copy of a columns list of this query result. pub fn columns(&self) -> Option>> { - match self.0 { - QueryResultInner::Empty(..) => None, - QueryResultInner::WithRows(_, ref columns, ..) => Some(columns.clone()), - } - } -} - -impl ConnectionLikeWrapper for QueryResult { - type ConnLike = T; - - fn take_stream(self) -> (Streamless, io::Stream) - where - Self: Sized, - { - match self { - QueryResult(Empty(conn_like, cached, _)) => match conn_like { - Some(Either::Left(conn_like)) => { - let (streamless, stream) = conn_like.take_stream(); - let self_streamless = Streamless::new(QueryResult(Empty( - Some(Either::Right(streamless)), - cached, - PhantomData, - ))); - (self_streamless, stream) - } - Some(Either::Right(..)) => panic!("Logic error: stream taken"), - None => unreachable!(), - }, - QueryResult(WithRows(conn_like, columns, cached, _)) => match conn_like { - Some(Either::Left(conn_like)) => { - let (streamless, stream) = conn_like.take_stream(); - let self_streamless = Streamless::new(QueryResult(WithRows( - Some(Either::Right(streamless)), - columns, - cached, - PhantomData, - ))); - (self_streamless, stream) - } - Some(Either::Right(..)) => panic!("Logic error: stream taken"), - None => unreachable!(), - }, - } - } - - fn return_stream(&mut self, stream: io::Stream) { - match *self { - QueryResult(Empty(ref mut conn_like, ..)) - | QueryResult(WithRows(ref mut conn_like, ..)) => match conn_like.take() { - Some(Either::Left(..)) => panic!("Logic error: stream exists"), - Some(Either::Right(streamless)) => { - *conn_like = Some(Either::Left(streamless.return_stream(stream))); - } - None => unreachable!(), - }, - } - } - - fn conn_like_ref(&self) -> &Self::ConnLike { - match *self { - QueryResult(Empty(ref conn_like, ..)) | QueryResult(WithRows(ref conn_like, ..)) => { - match *conn_like { - Some(Either::Left(ref conn_like)) => conn_like, - _ => unreachable!(), - } - } - } - } - - fn conn_like_mut(&mut self) -> &mut Self::ConnLike { - match *self { - QueryResult(Empty(ref mut conn_like, ..)) - | QueryResult(WithRows(ref mut conn_like, ..)) => match *conn_like { - Some(Either::Left(ref mut conn_like)) => conn_like, - _ => unreachable!(), - }, - } + self.inner.columns().cloned() } } diff --git a/src/queryable/stmt.rs b/src/queryable/stmt.rs index ce433aff..7b660ae5 100644 --- a/src/queryable/stmt.rs +++ b/src/queryable/stmt.rs @@ -6,17 +6,12 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. -use futures_util::future::Either; - use crate::{ - connection_like::{ - streamless::Streamless, ConnectionLike, ConnectionLikeWrapper, StmtCacheResult, - }, + connection_like::{ConnectionLike, StmtCacheResult}, error::*, - io, prelude::FromRow, - queryable::{query_result::QueryResult, BinaryProtocol}, - Column, Params, Row, + queryable::{query_result::QueryResult, transaction::TxStatus, BinaryProtocol}, + Column, Params, Value::{self}, }; use mysql_common::{ @@ -38,8 +33,7 @@ pub struct InnerStmt { } impl InnerStmt { - // TODO: Consume payload? - pub fn new(payload: &[u8], named_params: Option>) -> Result { + pub(crate) fn new(payload: &[u8], named_params: Option>) -> Result { let packet = parse_stmt_packet(payload)?; Ok(InnerStmt { @@ -56,40 +50,37 @@ impl InnerStmt { /// Prepared statement. #[derive(Debug)] -pub struct Stmt { - conn_like: Option>>, +pub struct Stmt<'a, T> { + conn_like: &'a mut T, inner: InnerStmt, /// None => In use elsewhere /// Some(Cached) => Should not be closed /// Some(NotCached(_)) => Should be closed - cached: Option, + pub(crate) cached: Option, } -pub fn new(conn_like: T, inner: InnerStmt, cached: StmtCacheResult) -> Stmt +impl<'a, T> Stmt<'a, T> where - T: ConnectionLike + Sized + 'static, + T: crate::prelude::ConnectionLike, { - Stmt::new(conn_like, inner, cached) -} - -impl Stmt -where - T: ConnectionLike + Sized + 'static, -{ - fn new(conn_like: T, inner: InnerStmt, cached: StmtCacheResult) -> Stmt { + pub(crate) fn new( + conn_like: &'a mut T, + inner: InnerStmt, + cached: StmtCacheResult, + ) -> Stmt<'a, T> { Stmt { - conn_like: Some(Either::Left(conn_like)), + conn_like, inner, cached: Some(cached), } } - /// Returns statement identifier. + /// Returns an identifier of the statement. pub fn id(&self) -> u32 { self.inner.statement_id } - /// Returns statement columns. + /// Returns a list of statement columns. /// /// ```rust /// # use mysql_async::test_misc::get_opts; @@ -99,7 +90,7 @@ where /// #[tokio::main] /// async fn main() -> Result<(), mysql_async::error::Error> { /// let pool = Pool::new(get_opts()); - /// let conn = pool.get_conn().await?; + /// let mut conn = pool.get_conn().await?; /// /// let stmt = conn.prepare("SELECT 'foo', CAST(42 AS UNSIGNED)").await?; /// @@ -120,7 +111,7 @@ where .unwrap_or_default() } - /// Returns statement parameters. + /// Returns a list of statement parameters. /// /// ```rust /// # use mysql_async::test_misc::get_opts; @@ -130,7 +121,7 @@ where /// #[tokio::main] /// async fn main() -> Result<(), mysql_async::error::Error> { /// let pool = Pool::new(get_opts()); - /// let conn = pool.get_conn().await?; + /// let mut conn = pool.get_conn().await?; /// /// let stmt = conn.prepare("SELECT ?, ?").await?; /// @@ -147,9 +138,7 @@ where .unwrap_or_default() } - async fn send_long_data(self, params: Vec) -> Result { - let mut this = self; - + async fn send_long_data(&mut self, params: Vec) -> Result<()> { for (i, value) in params.into_iter().enumerate() { if let Value::Bytes(bytes) = value { let chunks = bytes.chunks(MAX_PAYLOAD_LEN - 6); @@ -159,16 +148,19 @@ where None }); for chunk in chunks { - let com = ComStmtSendLongData::new(this.inner.statement_id, i, chunk); - this = this.write_command_raw(com.into()).await?; + let com = ComStmtSendLongData::new(self.inner.statement_id, i, chunk); + self.write_command_raw(com.into()).await?; } } } - Ok(this) + Ok(()) } - async fn execute_positional(self, params: U) -> Result> + async fn execute_positional( + &mut self, + params: U, + ) -> Result, BinaryProtocol>> where U: ::std::ops::Deref, U: IntoIterator, @@ -186,19 +178,18 @@ where let (body, as_long_data) = ComStmtExecuteRequestBuilder::new(self.inner.statement_id).build(&*params); - let this = if as_long_data { + if as_long_data { self.send_long_data(params).await? - } else { - self - }; + } - this.write_command_raw(body) - .await? - .read_result_set(None) - .await + self.write_command_raw(body).await?; + self.read_result_set(None).await } - async fn execute_named(self, params: Params) -> Result> { + async fn execute_named( + &mut self, + params: Params, + ) -> Result, BinaryProtocol>> { if self.inner.named_params.is_none() { let error = DriverError::NamedParamsForPositionalQuery.into(); return Err(error); @@ -216,7 +207,7 @@ where } } - async fn execute_empty(self) -> Result> { + async fn execute_empty(&mut self) -> Result, BinaryProtocol>> { if self.inner.num_params > 0 { let error = DriverError::StmtParamsMismatch { required: self.inner.num_params, @@ -227,12 +218,15 @@ where } let (body, _) = ComStmtExecuteRequestBuilder::new(self.inner.statement_id).build(&[]); - let this = self.write_command_raw(body).await?; - this.read_result_set(None).await + self.write_command_raw(body).await?; + self.read_result_set(None).await } - /// See `Queryable::execute` - pub async fn execute

(self, params: P) -> Result> + /// See [`Queryable::execute`]. + pub async fn execute

( + &mut self, + params: P, + ) -> Result, BinaryProtocol>> where P: Into, { @@ -244,111 +238,119 @@ where } } - /// See `Queryable::first` - pub async fn first(self, params: P) -> Result<(Self, Option)> + /// See [`Queryable::first`]. + pub async fn first(&mut self, params: P) -> Result> where P: Into + 'static, R: FromRow, { let result = self.execute(params).await?; - let (this, mut rows) = result.collect_and_drop::().await?; - if rows.len() > 1 { - Ok((this, Some(FromRow::from_row(rows.swap_remove(0))))) + let mut rows = result.collect_and_drop::().await?; + if rows.len() > 0 { + Ok(Some(FromRow::from_row(rows.swap_remove(0)))) } else { - Ok((this, rows.pop().map(FromRow::from_row))) + Ok(None) } } - /// See `Queryable::batch` - pub async fn batch(self, params_iter: I) -> Result + /// See [`Queryable::batch`]. + pub async fn batch(&mut self, params_iter: I) -> Result<()> where I: IntoIterator, - I::IntoIter: Send + 'static, Params: From

, - P: 'static, { let mut params_iter = params_iter.into_iter().map(Params::from); - let mut this = self; loop { match params_iter.next() { Some(params) => { - this = this.execute(params).await?.drop_result().await?; + let result = self.execute(params).await?; + result.drop_result().await?; } - None => break Ok(this), + None => break Ok(()), } } } - /// This will close statement (if it's not in the cache) and resolve to a wrapped queryable. - pub async fn close(mut self) -> Result { + /// This will close statement (if it's not in the cache). + pub async fn close(mut self) -> Result<()> { let cached = self.cached.take(); - match self.conn_like { - Some(Either::Left(conn_like)) => { - if let Some(StmtCacheResult::NotCached(stmt_id)) = cached { - conn_like.close_stmt(stmt_id).await - } else { - Ok(conn_like) - } - } - _ => unreachable!(), - } - } - - pub(crate) fn unwrap(mut self) -> (T, Option) { - match self.conn_like { - Some(Either::Left(conn_like)) => (conn_like, self.cached.take()), - _ => unreachable!(), + if let Some(StmtCacheResult::NotCached(stmt_id)) = cached { + self.conn_like.close_stmt(stmt_id).await?; } + Ok(()) } } -impl ConnectionLikeWrapper for Stmt { - type ConnLike = T; - - fn take_stream(self) -> (Streamless, io::Stream) - where - Self: Sized, - { - let Stmt { - conn_like, - inner, - cached, - } = self; - match conn_like { - Some(Either::Left(conn_like)) => { - let (streamless, stream) = conn_like.take_stream(); - let this = Stmt { - conn_like: Some(Either::Right(streamless)), - inner, - cached, - }; - (Streamless::new(this), stream) - } - _ => unreachable!(), - } +impl<'a, T: ConnectionLike> ConnectionLike for Stmt<'a, T> { + fn stream_mut(&mut self) -> &mut crate::io::Stream { + self.conn_like.stream_mut() } - - fn return_stream(&mut self, stream: io::Stream) { - let conn_like = self.conn_like.take().unwrap(); - match conn_like { - Either::Right(streamless) => { - self.conn_like = Some(Either::Left(streamless.return_stream(stream))); - } - _ => unreachable!(), - } + fn stmt_cache_ref(&self) -> &crate::conn::stmt_cache::StmtCache { + self.conn_like.stmt_cache_ref() } - - fn conn_like_ref(&self) -> &Self::ConnLike { - match self.conn_like { - Some(Either::Left(ref conn_like)) => conn_like, - _ => unreachable!(), - } + fn stmt_cache_mut(&mut self) -> &mut crate::conn::stmt_cache::StmtCache { + self.conn_like.stmt_cache_mut() } - - fn conn_like_mut(&mut self) -> &mut Self::ConnLike { - match self.conn_like { - Some(Either::Left(ref mut conn_like)) => conn_like, - _ => unreachable!(), - } + fn get_affected_rows(&self) -> u64 { + self.conn_like.get_affected_rows() + } + fn get_capabilities(&self) -> crate::consts::CapabilityFlags { + self.conn_like.get_capabilities() + } + fn get_tx_status(&self) -> TxStatus { + self.conn_like.get_tx_status() + } + fn get_last_insert_id(&self) -> Option { + self.conn_like.get_last_insert_id() + } + fn get_info(&self) -> std::borrow::Cow<'_, str> { + self.conn_like.get_info() + } + fn get_warnings(&self) -> u16 { + self.conn_like.get_warnings() + } + fn get_local_infile_handler( + &self, + ) -> Option> { + self.conn_like.get_local_infile_handler() + } + fn get_max_allowed_packet(&self) -> usize { + self.conn_like.get_max_allowed_packet() + } + fn get_opts(&self) -> &crate::Opts { + self.conn_like.get_opts() + } + fn get_pending_result(&self) -> Option<&crate::conn::PendingResult> { + self.conn_like.get_pending_result() + } + fn get_server_version(&self) -> (u16, u16, u16) { + self.conn_like.get_server_version() + } + fn get_status(&self) -> crate::consts::StatusFlags { + self.conn_like.get_status() + } + fn set_last_ok_packet(&mut self, ok_packet: Option>) { + self.conn_like.set_last_ok_packet(ok_packet) + } + fn set_tx_status(&mut self, tx_status: TxStatus) { + self.conn_like.set_tx_status(tx_status) + } + fn set_pending_result(&mut self, meta: Option) { + self.conn_like.set_pending_result(meta) + } + fn set_status(&mut self, status: crate::consts::StatusFlags) { + self.conn_like.set_status(status) + } + fn reset_seq_id(&mut self) { + self.conn_like.reset_seq_id() + } + fn sync_seq_id(&mut self) { + self.conn_like.sync_seq_id() + } + fn touch(&mut self) -> () { + self.conn_like.touch() + } + fn on_disconnect(&mut self) { + self.conn_like.on_disconnect() } } diff --git a/src/queryable/transaction.rs b/src/queryable/transaction.rs index d09c5806..bc13bcca 100644 --- a/src/queryable/transaction.rs +++ b/src/queryable/transaction.rs @@ -6,18 +6,23 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. -use futures_util::future::Either; - use std::fmt; -use crate::{ - connection_like::{streamless::Streamless, ConnectionLike, ConnectionLikeWrapper}, - error::*, - io, - queryable::Queryable, -}; +use crate::{connection_like::ConnectionLike, error::*, queryable::Queryable}; + +/// Transaction status. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[repr(u8)] +pub enum TxStatus { + /// Connection is in transaction at the moment. + InTransaction, + /// `Transaction` was dropped without explicit call to `commit` or `rollback`. + RequiresRollback, + /// Connection is not in transaction at the moment. + None, +} -/// Options for transaction +/// Transaction options. #[derive(Eq, PartialEq, Debug, Hash, Clone, Default)] pub struct TransactionOptions { consistent_snapshot: bool, @@ -26,15 +31,18 @@ pub struct TransactionOptions { } impl TransactionOptions { + /// Creates a default instance. pub fn new() -> TransactionOptions { TransactionOptions::default() } + /// See [`TransactionOptions::consistent_snapshot`]. pub fn set_consistent_snapshot(&mut self, value: bool) -> &mut Self { self.consistent_snapshot = value; self } + /// See [`TransactionOptions::isolation_level`]. pub fn set_isolation_level(&mut self, value: T) -> &mut Self where T: Into>, @@ -43,6 +51,7 @@ impl TransactionOptions { self } + /// See [`TransactionOptions::readonly`]. pub fn set_readonly(&mut self, value: T) -> &mut Self where T: Into>, @@ -51,14 +60,20 @@ impl TransactionOptions { self } + /// If true, then `START TRANSACTION WITH CONSISTENT SNAPSHOT` will be performed. + /// Defaults to `false`. pub fn consistent_snapshot(&self) -> bool { self.consistent_snapshot } + /// If not `None`, then `SET TRANSACTION ISOLATION LEVEL ..` will be performed. + /// Defaults to `None`. pub fn isolation_level(&self) -> Option { self.isolation_level } + /// If not `None`, then `SET TRANSACTION READ ONLY|WRITE` will be performed. + /// Defaults to `None`. pub fn readonly(&self) -> Option { self.readonly } @@ -86,27 +101,23 @@ impl fmt::Display for IsolationLevel { /// This struct represents MySql transaction. /// -/// `Transaction` it's a sugar for `START TRANSACTION`, `ROLLBACK` and `COMMIT` queries, so one +/// `Transaction` is just a sugar for `START TRANSACTION`, `ROLLBACK` and `COMMIT` queries, so one /// should note that it is easy to mess things up calling this queries manually. Also you will get /// `NestedTransaction` error if you call `transaction.start_transaction(_)`. -pub struct Transaction(Option>>); +pub struct Transaction<'a, T: ConnectionLike>(&'a mut T); -pub async fn new(conn_like: T, options: TransactionOptions) -> Result> -where - T: Queryable + ConnectionLike, -{ - Transaction::new(conn_like, options).await -} - -impl Transaction { - async fn new(mut conn_like: T, options: TransactionOptions) -> Result> { +impl<'a, T: Queryable + ConnectionLike> Transaction<'a, T> { + pub(crate) async fn new( + conn_like: &'a mut T, + options: TransactionOptions, + ) -> Result> { let TransactionOptions { consistent_snapshot, isolation_level, readonly, } = options; - if conn_like.get_in_transaction() { + if conn_like.get_tx_status() != TxStatus::None { return Err(DriverError::NestedTransaction.into()); } @@ -116,18 +127,18 @@ impl Transaction { if let Some(isolation_level) = isolation_level { let query = format!("SET TRANSACTION ISOLATION LEVEL {}", isolation_level); - conn_like = conn_like.drop_query(query).await?; + conn_like.drop_query(query).await?; } if let Some(readonly) = readonly { if readonly { - conn_like = conn_like.drop_query("SET TRANSACTION READ ONLY").await?; + conn_like.drop_query("SET TRANSACTION READ ONLY").await?; } else { - conn_like = conn_like.drop_query("SET TRANSACTION READ WRITE").await?; + conn_like.drop_query("SET TRANSACTION READ WRITE").await?; } } - conn_like = if consistent_snapshot { + if consistent_snapshot { conn_like .drop_query("START TRANSACTION WITH CONSISTENT SNAPSHOT") .await? @@ -135,71 +146,105 @@ impl Transaction { conn_like.drop_query("START TRANSACTION").await? }; - conn_like.set_in_transaction(true); - Ok(Transaction(Some(Either::Left(conn_like)))) - } - - fn unwrap(self) -> T { - match self { - Transaction(Some(Either::Left(conn_like))) => conn_like, - _ => unreachable!(), - } + conn_like.set_tx_status(TxStatus::InTransaction); + Ok(Transaction(conn_like)) } - /// Returns future that will perform `COMMIT` query and resolve to a wrapped `Queryable`. - pub async fn commit(self) -> Result { - let mut this = self.drop_query("COMMIT").await?; - this.set_in_transaction(false); - Ok(this.unwrap()) + /// Performs `COMMIT` query. + pub async fn commit(mut self) -> Result<()> { + let result = self.0.query("COMMIT").await?; + result.drop_result().await?; + self.set_tx_status(TxStatus::None); + Ok(()) } - /// Returns future that will perform `ROLLBACK` query and resolve to a wrapped `Queryable`. - pub async fn rollback(self) -> Result { - let mut this = self.drop_query("ROLLBACK").await?; - this.set_in_transaction(false); - Ok(this.unwrap()) + /// Performs `ROLLBACK` query. + pub async fn rollback(mut self) -> Result<()> { + let result = self.0.query("ROLLBACK").await?; + result.drop_result().await?; + self.set_tx_status(TxStatus::None); + Ok(()) } } -impl ConnectionLikeWrapper for Transaction { - type ConnLike = T; - - fn take_stream(self) -> (Streamless, io::Stream) - where - Self: Sized, - { - let Transaction(conn_like) = self; - match conn_like { - Some(Either::Left(conn_like)) => { - let (streamless, stream) = conn_like.take_stream(); - let this = Transaction(Some(Either::Right(streamless))); - (Streamless::new(this), stream) - } - _ => unreachable!(), +impl Drop for Transaction<'_, T> { + fn drop(&mut self) { + if self.get_tx_status() == TxStatus::InTransaction { + self.set_tx_status(TxStatus::RequiresRollback); } } +} - fn return_stream(&mut self, stream: io::Stream) { - let conn_like = self.0.take().unwrap(); - match conn_like { - Either::Right(streamless) => { - self.0 = Some(Either::Left(streamless.return_stream(stream))); - } - _ => unreachable!(), - } +impl<'a, T: ConnectionLike> ConnectionLike for Transaction<'a, T> { + fn stream_mut(&mut self) -> &mut crate::io::Stream { + self.0.stream_mut() } - - fn conn_like_ref(&self) -> &Self::ConnLike { - match self.0 { - Some(Either::Left(ref conn_like)) => conn_like, - _ => unreachable!(), - } + fn stmt_cache_ref(&self) -> &crate::conn::stmt_cache::StmtCache { + self.0.stmt_cache_ref() } - - fn conn_like_mut(&mut self) -> &mut Self::ConnLike { - match self.0 { - Some(Either::Left(ref mut conn_like)) => conn_like, - _ => unreachable!(), - } + fn stmt_cache_mut(&mut self) -> &mut crate::conn::stmt_cache::StmtCache { + self.0.stmt_cache_mut() + } + fn get_affected_rows(&self) -> u64 { + self.0.get_affected_rows() + } + fn get_capabilities(&self) -> crate::consts::CapabilityFlags { + self.0.get_capabilities() + } + fn get_tx_status(&self) -> TxStatus { + self.0.get_tx_status() + } + fn get_last_insert_id(&self) -> Option { + self.0.get_last_insert_id() + } + fn get_info(&self) -> std::borrow::Cow<'_, str> { + self.0.get_info() + } + fn get_warnings(&self) -> u16 { + self.0.get_warnings() + } + fn get_local_infile_handler( + &self, + ) -> Option> { + self.0.get_local_infile_handler() + } + fn get_max_allowed_packet(&self) -> usize { + self.0.get_max_allowed_packet() + } + fn get_opts(&self) -> &crate::Opts { + self.0.get_opts() + } + fn get_pending_result(&self) -> Option<&crate::conn::PendingResult> { + self.0.get_pending_result() + } + fn get_server_version(&self) -> (u16, u16, u16) { + self.0.get_server_version() + } + fn get_status(&self) -> crate::consts::StatusFlags { + self.0.get_status() + } + fn set_last_ok_packet(&mut self, ok_packet: Option>) { + self.0.set_last_ok_packet(ok_packet) + } + fn set_tx_status(&mut self, tx_status: TxStatus) { + self.0.set_tx_status(tx_status) + } + fn set_pending_result(&mut self, meta: Option) { + self.0.set_pending_result(meta) + } + fn set_status(&mut self, status: crate::consts::StatusFlags) { + self.0.set_status(status) + } + fn reset_seq_id(&mut self) { + self.0.reset_seq_id() + } + fn sync_seq_id(&mut self) { + self.0.sync_seq_id() + } + fn touch(&mut self) -> () { + self.0.touch() + } + fn on_disconnect(&mut self) { + self.0.on_disconnect() } } diff --git a/tests/generic.rs b/tests/generic.rs index bee1fa5b..9a50ae5b 100644 --- a/tests/generic.rs +++ b/tests/generic.rs @@ -26,16 +26,20 @@ fn get_url() -> String { } } -pub async fn get_all_results(result: QueryResult) -> Result> +pub async fn get_all_results<'a, TupleType, T, P>( + mut result: QueryResult<'a, T, P>, +) -> Result> where TupleType: FromRow + Send + 'static, P: Protocol + Send + 'static, T: ConnectionLike + Sized + Send + 'static, { - Ok(result.collect().await?.1) + Ok(result.collect().await?) } -pub async fn get_single_result(result: QueryResult) -> Result +pub async fn get_single_result<'a, TupleType, T, P>( + result: QueryResult<'a, T, P>, +) -> Result where TupleType: FromRow + Send + 'static, P: Protocol + Send + 'static, @@ -52,11 +56,12 @@ where #[tokio::test] async fn use_generic_code() { let pool = Pool::new(Opts::from_url(&*get_url()).unwrap()); - let conn = pool.get_conn().await.unwrap(); + let mut conn = pool.get_conn().await.unwrap(); let result = conn.query("SELECT 1, 2, 3").await.unwrap(); let result = get_single_result::<(u8, u8, u8), _, _>(result) .await .unwrap(); + drop(conn); pool.disconnect().await.unwrap(); assert_eq!(result, (1, 2, 3)); }