diff --git a/Cargo.lock b/Cargo.lock
index 7c6ca445b..dbfeea364 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -369,31 +369,6 @@ version = "0.13.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
 
-[[package]]
-name = "bb8"
-version = "0.7.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dae93eccab998c4b8703e3a6bbaa1714c38e445ebacb4bede25d0408521e293c"
-dependencies = [
- "async-trait",
- "futures-channel",
- "futures-util",
- "parking_lot 0.11.1",
- "tokio 1.0.2",
-]
-
-[[package]]
-name = "bb8-postgres"
-version = "0.7.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "61fdf56d52b2cca401d2380407e5c35d3d25d3560224ecf74d6e4ca13e51239b"
-dependencies = [
- "async-trait",
- "bb8",
- "tokio 1.0.2",
- "tokio-postgres",
-]
-
 [[package]]
 name = "bitflags"
 version = "1.2.1"
@@ -3512,8 +3487,6 @@ dependencies = [
  "adapter",
  "async-std",
  "async-trait",
- "bb8",
- "bb8-postgres",
  "chrono",
  "clap",
  "dashmap",
@@ -4132,7 +4105,6 @@ dependencies = [
  "memchr",
  "mio 0.7.7",
  "num_cpus",
- "parking_lot 0.11.1",
  "pin-project-lite 0.2.4",
  "tokio-macros 1.0.0",
 ]
diff --git a/primitives/src/util/tests/prep_db.rs b/primitives/src/util/tests/prep_db.rs
index 68b18687d..48cbd5124 100644
--- a/primitives/src/util/tests/prep_db.rs
+++ b/primitives/src/util/tests/prep_db.rs
@@ -253,6 +253,10 @@ pub mod postgres {
         .host(&POSTGRES_HOST)
         .port(*POSTGRES_PORT);
 
+    if let Some(db) = POSTGRES_DB.as_ref() {
+        config.dbname(db);
+    }
+
     let mgr_config = ManagerConfig {
         recycling_method: RecyclingMethod::Fast,
     };
diff --git a/sentry/Cargo.toml b/sentry/Cargo.toml
index 968daf394..46eae8f46 100644
--- a/sentry/Cargo.toml
+++ b/sentry/Cargo.toml
@@ -23,8 +23,9 @@ hyper = { version = "0.14", features = ["stream", "runtime", "http1", "server"] }
 regex = "1"
 # Database
 redis = { version = "0.19", features = ["aio", "tokio-comp"] }
-bb8 = "0.7"
-bb8-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_json-1"] }
+deadpool = "0.7.0"
+deadpool-postgres = "0.7.0"
+tokio-postgres = { version = "0.7.0", features = ["with-chrono-0_4", "with-serde_json-1"] }
 
 # Migrations
 migrant_lib = { version = "^0.32", features = ["d-postgres"] }
@@ -37,10 +38,4 @@ serde_urlencoded = "^0.7"
 # Other
 lazy_static = "1.4.0"
 thiserror = "^1.0"
-tokio-postgres = { version = "0.7.0", features = ["with-chrono-0_4", "with-serde_json-1"] }
-
-[dev-dependencies]
-# todo: Replace `bb8` once we update all places.
-deadpool = "0.7.0"
-deadpool-postgres = "0.7.0"
 once_cell = "1.5.2"
diff --git a/sentry/src/db.rs b/sentry/src/db.rs
index e0a7dc628..81a336788 100644
--- a/sentry/src/db.rs
+++ b/sentry/src/db.rs
@@ -1,7 +1,7 @@
-use bb8::Pool;
-use bb8_postgres::{tokio_postgres::NoTls, PostgresConnectionManager};
-use redis::{aio::MultiplexedConnection, RedisError};
+use deadpool_postgres::{Manager, ManagerConfig, RecyclingMethod};
+use redis::aio::MultiplexedConnection;
 use std::env;
+use tokio_postgres::NoTls;
 
 use lazy_static::lazy_static;
 
@@ -15,7 +15,12 @@ pub use self::channel::*;
 pub use self::event_aggregate::*;
 pub use self::validator_message::*;
 
-pub type DbPool = Pool<PostgresConnectionManager<NoTls>>;
+// Re-export the Postgres PoolError for easier usage
+pub use deadpool_postgres::PoolError;
+// Re-export the redis RedisError for easier usage
+pub use redis::RedisError;
+
+pub type DbPool = deadpool_postgres::Pool;
 
 lazy_static! {
     static ref POSTGRES_USER: String =
@@ -29,6 +34,20 @@ lazy_static! {
         .parse()
         .unwrap();
     static ref POSTGRES_DB: Option<String> = env::var("POSTGRES_DB").ok();
+    static ref POSTGRES_CONFIG: tokio_postgres::Config = {
+        let mut config = tokio_postgres::Config::new();
+
+        config
+            .user(POSTGRES_USER.as_str())
+            .password(POSTGRES_PASSWORD.as_str())
+            .host(POSTGRES_HOST.as_str())
+            .port(*POSTGRES_PORT);
+        if let Some(db) = POSTGRES_DB.as_ref() {
+            config.dbname(db);
+        }
+
+        config
+    };
 }
 
 pub async fn redis_connection(url: &str) -> Result<MultiplexedConnection, RedisError> {
@@ -37,20 +56,14 @@ pub async fn redis_connection(url: &str) -> Result<MultiplexedConnection, RedisError> {
-pub async fn postgres_connection() -> Result<DbPool, bb8_postgres::tokio_postgres::Error> {
-    let mut config = bb8_postgres::tokio_postgres::Config::new();
+pub async fn postgres_connection(max_size: usize) -> DbPool {
+    let mgr_config = ManagerConfig {
+        recycling_method: RecyclingMethod::Verified,
+    };
 
-    config
-        .user(POSTGRES_USER.as_str())
-        .password(POSTGRES_PASSWORD.as_str())
-        .host(POSTGRES_HOST.as_str())
-        .port(*POSTGRES_PORT);
-    if let Some(db) = POSTGRES_DB.as_ref() {
-        config.dbname(db);
-    }
-    let pg_mgr = PostgresConnectionManager::new(config, NoTls);
+    let manager = Manager::from_config(POSTGRES_CONFIG.clone(), NoTls, mgr_config);
 
-    Pool::builder().build(pg_mgr).await
+    DbPool::new(manager, max_size)
 }
 
 pub async fn setup_migrations(environment: &str) {
@@ -81,7 +94,7 @@ pub async fn setup_migrations(environment: &str) {
     }
 
     // NOTE: Make sure to update list of migrations for the tests as well!
-    // `postgres_pool::MIGRATIONS`
+    // `tests_postgres::MIGRATIONS`
     let mut migrations = vec![make_migration!("20190806011140_initial-tables")];
 
     if environment == "development" {
@@ -125,140 +138,154 @@ pub async fn setup_migrations(environment: &str) {
 }
 
 #[cfg(test)]
-pub mod postgres_pool {
+pub mod tests_postgres {
     use std::{
         ops::{Deref, DerefMut},
-        sync::{
-            atomic::{AtomicUsize, Ordering},
-            Arc,
-        },
+        sync::atomic::{AtomicUsize, Ordering},
     };
 
     use deadpool::managed::{Manager as ManagerTrait, RecycleResult};
-    use deadpool_postgres::ClientWrapper;
-    use once_cell::sync::Lazy;
-    use tokio_postgres::{
-        tls::{MakeTlsConnect, TlsConnect},
-        Client, Error, SimpleQueryMessage, Socket,
-    };
+    use deadpool_postgres::ManagerConfig;
+    use tokio_postgres::{NoTls, SimpleQueryMessage};
 
     use async_trait::async_trait;
 
-    use super::{POSTGRES_DB, POSTGRES_HOST, POSTGRES_PASSWORD, POSTGRES_PORT, POSTGRES_USER};
+    use super::{DbPool, PoolError};
 
-    pub type Pool = deadpool::managed::Pool<Schema, Error>;
+    pub type Pool = deadpool::managed::Pool<Database, PoolError>;
 
-    /// we must have a duplication of the migration because of how migrant is handling migratoins
-    /// we need to separately setup test migrations
+    /// We must duplicate the migrations here because of how `migrant` handles migrations:
+    /// we need to set up the test migrations separately.
     pub static MIGRATIONS: &[&str] = &["20190806011140_initial-tables"];
 
-    pub static TESTS_POOL: Lazy<Pool> = Lazy::new(|| {
-        use deadpool_postgres::{ManagerConfig, RecyclingMethod};
-        use tokio_postgres::tls::NoTls;
-        let mut config = bb8_postgres::tokio_postgres::Config::new();
+    pub fn test_postgres_connection(base_config: tokio_postgres::Config) -> Pool {
+        let manager_config = ManagerConfig {
+            recycling_method: deadpool_postgres::RecyclingMethod::Fast,
+        };
+        let manager = Manager::new(base_config, manager_config);
 
-        config
-            .user(POSTGRES_USER.as_str())
-            .password(POSTGRES_PASSWORD.as_str())
-            .host(POSTGRES_HOST.as_str())
-            .port(*POSTGRES_PORT);
-        if let Some(db) = POSTGRES_DB.as_ref() {
-            config.dbname(db);
-        }
+        Pool::new(manager, 15)
+    }
 
-        let deadpool_manager = deadpool_postgres::Manager::from_config(
-            config,
-            NoTls,
-            ManagerConfig {
-                recycling_method: RecyclingMethod::Verified,
-            },
-        );
-
-        Pool::new(
-            Manager {
-                postgres_manager: Arc::new(deadpool_manager),
-                index: AtomicUsize::new(0),
-            },
-            15,
-        )
-    });
-
-    /// A Scheme is used to isolate test runs from each other
-    /// we need to know the name of the schema we've created.
-    /// This will allow us the drop the schema when we are recycling the connection
-    pub struct Schema {
-        /// The schema name that will be created by the pool `CREATE SCHEMA`
-        /// This schema will be set as the connection `search_path` (`SET SCHEMA` for short)
+    /// A Database is used to isolate test runs from each other;
+    /// we need to know the name of the database we've created,
+    /// which allows us to drop it when we recycle the connection.
+    pub struct Database {
+        /// The database name that will be created by the pool `CREATE DATABASE`.
+        /// This database will be set on the configuration level of the underlying connection Pool for tests.
         pub name: String,
-        pub client: ClientWrapper,
+        pub pool: deadpool_postgres::Pool,
     }
 
-    impl Deref for Schema {
-        type Target = tokio_postgres::Client;
+    impl Deref for Database {
+        type Target = deadpool_postgres::Pool;
 
-        fn deref(&self) -> &tokio_postgres::Client {
-            &self.client
+        fn deref(&self) -> &deadpool_postgres::Pool {
+            &self.pool
         }
     }
 
-    impl DerefMut for Schema {
-        fn deref_mut(&mut self) -> &mut tokio_postgres::Client {
-            &mut self.client
+    impl DerefMut for Database {
+        fn deref_mut(&mut self) -> &mut deadpool_postgres::Pool {
+            &mut self.pool
        }
    }
 
-    struct Manager<T: MakeTlsConnect<Socket> + Send + Sync> {
-        postgres_manager: Arc<deadpool_postgres::Manager<T>>,
+    /// The base Pool and Config are used to create a new DATABASE and, later on,
+    /// the actual test Pool, with default options set for each connection to that DATABASE.
+    pub struct Manager {
+        base_config: tokio_postgres::Config,
+        base_pool: deadpool_postgres::Pool,
+        manager_config: ManagerConfig,
         index: AtomicUsize,
     }
 
-    #[async_trait]
-    impl<T> ManagerTrait<Schema, Error> for Manager<T>
-    where
-        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
-        T::Stream: Sync + Send,
-        T::TlsConnect: Sync + Send,
-        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
-    {
-        async fn create(&self) -> Result<Schema, Error> {
-            let client = self.postgres_manager.create().await?;
-
-            let conn_index = self.index.fetch_add(1, Ordering::SeqCst);
-            let schema_name = format!("test_{}", conn_index);
-
-            // 1. Drop the schema if it exists - if a test failed before, the schema wouldn't have been removed
-            // 2. Create schema
-            // 3. Set the `search_path` (SET SCHEMA) - this way we don't have to define schema on queries or table creation
-
-            let queries = format!(
-                "DROP SCHEMA IF EXISTS {0} CASCADE; CREATE SCHEMA {0}; SET SESSION SCHEMA '{0}';",
-                schema_name
-            );
+    impl Manager {
+        pub fn new(base_config: tokio_postgres::Config, manager_config: ManagerConfig) -> Self {
+            // We need a base Pool with a temporary connection in order to create the test databases for the real test Pool
+            let base_manager = deadpool_postgres::Manager::from_config(
+                base_config.clone(),
+                NoTls,
+                manager_config.clone(),
+            );
+            let base_pool = deadpool_postgres::Pool::new(base_manager, 15);
 
-            let result = client.simple_query(&queries).await?;
+            Self::new_with_pool(base_pool, base_config, manager_config)
+        }
 
-            assert_eq!(3, result.len());
-            assert!(matches!(result[0], SimpleQueryMessage::CommandComplete(..)));
-            assert!(matches!(result[1], SimpleQueryMessage::CommandComplete(..)));
-            assert!(matches!(result[2], SimpleQueryMessage::CommandComplete(..)));
+        pub fn new_with_pool(
+            base_pool: deadpool_postgres::Pool,
+            base_config: tokio_postgres::Config,
+            manager_config: ManagerConfig,
+        ) -> Self {
+            Self {
+                base_config,
+                base_pool,
+                manager_config,
+                index: AtomicUsize::new(0),
+            }
+        }
+    }
 
-            Ok(Schema {
-                name: schema_name,
-                client,
+    #[async_trait]
+    impl ManagerTrait<Database, PoolError> for Manager {
+        async fn create(&self) -> Result<Database, PoolError> {
+            let pool_index = self.index.fetch_add(1, Ordering::SeqCst);
+            let db_name = format!("test_{}", pool_index);
+
+            // 1. Drop the database if it exists - if a test failed before, the database wouldn't have been removed
+            // 2. Create database
+            let drop_db = format!("DROP DATABASE IF EXISTS {0} WITH (FORCE);", db_name);
+            let created_db = format!("CREATE DATABASE {0};", db_name);
+            let temp_client = self.base_pool.get().await?;
+
+            let drop_db_result = temp_client.simple_query(drop_db.as_str()).await?;
+            assert_eq!(1, drop_db_result.len());
+            assert!(matches!(
+                drop_db_result[0],
+                SimpleQueryMessage::CommandComplete(..)
+            ));
+
+            let create_db_result = temp_client.simple_query(created_db.as_str()).await?;
+            assert_eq!(1, create_db_result.len());
+            assert!(matches!(
+                create_db_result[0],
+                SimpleQueryMessage::CommandComplete(..)
+            ));
+
+            let mut config = self.base_config.clone();
+            // set the database in the configuration of the inner Pool (used by the tests)
+            config.dbname(&db_name);
+
+            let manager =
+                deadpool_postgres::Manager::from_config(config, NoTls, self.manager_config.clone());
+            let pool = deadpool_postgres::Pool::new(manager, 15);
+
+            Ok(Database {
+                name: db_name,
+                pool,
             })
         }
 
-        async fn recycle(&self, schema: &mut Schema) -> RecycleResult<Error> {
-            let queries = format!("DROP SCHEMA {0} CASCADE;", schema.name);
-            let result = schema.simple_query(&queries).await?;
-            assert_eq!(2, result.len());
+        async fn recycle(&self, database: &mut Database) -> RecycleResult<PoolError> {
+            let queries = format!("DROP DATABASE {0} WITH (FORCE);", database.name);
+            let result = self
+                .base_pool
+                .get()
+                .await?
+                .simple_query(&queries)
+                .await
+                .map_err(PoolError::Backend)?;
+            assert_eq!(1, result.len());
             assert!(matches!(result[0], SimpleQueryMessage::CommandComplete(..)));
-            assert!(matches!(result[1], SimpleQueryMessage::CommandComplete(..)));
 
-            self.postgres_manager.recycle(&mut schema.client).await
+            Ok(())
         }
     }
 
-    pub async fn setup_test_migrations(client: &Client) -> Result<(), Error> {
+    pub async fn setup_test_migrations(pool: DbPool) -> Result<(), PoolError> {
+        let client = pool.get().await?;
+
         let full_query: String = MIGRATIONS
             .iter()
             .map(|migration| {
@@ -278,7 +305,7 @@ pub mod postgres_pool {
             })
             .collect();
 
-        client.batch_execute(&full_query).await
+        Ok(client.batch_execute(&full_query).await?)
     }
 }
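For context, a test would consume the `tests_postgres` machinery above roughly as follows. This is a minimal sketch, assuming it is written inside a `db` submodule so that the crate-private `POSTGRES_CONFIG` and the `PoolError` re-export are reachable; the test body itself is illustrative and not part of this change:

    use crate::db::{tests_postgres::test_postgres_connection, PoolError, POSTGRES_CONFIG};

    #[tokio::test]
    async fn uses_an_isolated_database() -> Result<(), PoolError> {
        // Every `get()` on the test Pool creates a fresh `test_{index}` database;
        // recycling the object drops that database again.
        let database = test_postgres_connection(POSTGRES_CONFIG.clone())
            .get()
            .await
            .expect("Should get a Database");

        // `Database` derefs to its inner `deadpool_postgres::Pool` for that database
        let client = database.get().await?;
        client.simple_query("SELECT 1").await?;

        Ok(())
    }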
diff --git a/sentry/src/db/analytics.rs b/sentry/src/db/analytics.rs
index 2d6abdba9..206e282dd 100644
--- a/sentry/src/db/analytics.rs
+++ b/sentry/src/db/analytics.rs
@@ -1,16 +1,15 @@
-use crate::db::DbPool;
-use crate::epoch;
-use crate::Auth;
-use bb8::RunError;
-use bb8_postgres::tokio_postgres::types::ToSql;
+use crate::{epoch, Auth};
 use chrono::Utc;
-use primitives::analytics::{AnalyticsData, AnalyticsQuery, ANALYTICS_QUERY_LIMIT};
-use primitives::sentry::{AdvancedAnalyticsResponse, ChannelReport, PublisherReport};
-use primitives::{ChannelId, ValidatorId};
-use redis::aio::MultiplexedConnection;
-use redis::cmd;
+use primitives::{
+    analytics::{AnalyticsData, AnalyticsQuery, ANALYTICS_QUERY_LIMIT},
+    sentry::{AdvancedAnalyticsResponse, ChannelReport, PublisherReport},
+    ChannelId, ValidatorId,
+};
+use redis::{aio::MultiplexedConnection, cmd};
 use std::collections::HashMap;
-use std::error::Error;
+use tokio_postgres::types::ToSql;
+
+use super::{DbPool, PoolError};
 
 pub enum AnalyticsType {
     Advertiser { auth: Auth },
@@ -21,13 +20,13 @@ pub enum AnalyticsType {
 pub async fn advertiser_channel_ids(
     pool: &DbPool,
     creator: &ValidatorId,
-) -> Result<Vec<ChannelId>, RunError<bb8_postgres::tokio_postgres::Error>> {
-    let connection = pool.get().await?;
+) -> Result<Vec<ChannelId>, PoolError> {
+    let client = pool.get().await?;
 
-    let stmt = connection
+    let stmt = client
         .prepare("SELECT id FROM channels WHERE creator = $1")
         .await?;
-    let rows = connection.query(&stmt, &[creator]).await?;
+    let rows = client.query(&stmt, &[creator]).await?;
 
     let channel_ids: Vec<ChannelId> = rows.iter().map(ChannelId::from).collect();
     Ok(channel_ids)
@@ -47,7 +46,9 @@ pub async fn get_analytics(
     analytics_type: AnalyticsType,
     segment_by_channel: bool,
     channel_id: Option<&ChannelId>,
-) -> Result<Vec<AnalyticsData>, RunError<bb8_postgres::tokio_postgres::Error>> {
+) -> Result<Vec<AnalyticsData>, PoolError> {
+    let client = pool.get().await?;
+
     // converts metric to column
     let metric = metric_to_column(&query.metric);
 
@@ -115,11 +116,9 @@ pub async fn get_analytics(
         applied_limit,
     );
 
-    let connection = pool.get().await?;
-
     // execute query
-    let stmt = connection.prepare(&sql_query).await?;
-    let rows = connection.query(&stmt, &params).await?;
+    let stmt = client.prepare(&sql_query).await?;
+    let rows = client.query(&stmt, &params).await?;
 
     let analytics: Vec<AnalyticsData> = rows.iter().map(AnalyticsData::from).collect();
 
@@ -144,11 +143,11 @@ fn get_time_frame(timeframe: &str) -> (i64, i64) {
 async fn stat_pair(
     mut conn: MultiplexedConnection,
     key: &str,
-) -> Result<HashMap<String, f64>, Box<dyn Error>> {
+) -> Result<HashMap<String, f64>, Box<dyn std::error::Error>> {
     let data = cmd("ZRANGE")
         .arg(key)
-        .arg(0 as u64)
-        .arg(-1 as i64)
+        .arg(0_u64)
+        .arg(-1_i64)
         .arg("WITHSCORES")
         .query_async::<_, Vec<String>>(&mut conn)
         .await?;
@@ -169,7 +168,7 @@ pub async fn get_advanced_reports(
     event_type: &str,
     publisher: &ValidatorId,
     channel_ids: &[ChannelId],
-) -> Result<AdvancedAnalyticsResponse, Box<dyn Error>> {
+) -> Result<AdvancedAnalyticsResponse, Box<dyn std::error::Error>> {
     let publisher_reports = [
         PublisherReport::AdUnit,
         PublisherReport::AdSlot,
diff --git a/sentry/src/db/channel.rs b/sentry/src/db/channel.rs
index cf19db04b..6ec72eadf 100644
--- a/sentry/src/db/channel.rs
+++ b/sentry/src/db/channel.rs
@@ -1,21 +1,20 @@
-use crate::db::DbPool;
-use bb8::RunError;
 use chrono::Utc;
-use primitives::validator::MessageTypes;
-use primitives::{targeting::Rules, Channel, ChannelId, ValidatorId};
+use primitives::{targeting::Rules, validator::MessageTypes, Channel, ChannelId, ValidatorId};
 use std::str::FromStr;
 
 pub use list_channels::list_channels;
 
+use super::{DbPool, PoolError};
+
 pub async fn get_channel_by_id(
     pool: &DbPool,
     id: &ChannelId,
-) -> Result<Option<Channel>, RunError<bb8_postgres::tokio_postgres::Error>> {
-    let connection = pool.get().await?;
+) -> Result<Option<Channel>, PoolError> {
+    let client = pool.get().await?;
 
-    let select = connection.prepare("SELECT id, creator, deposit_asset, deposit_amount, valid_until, targeting_rules, spec, exhausted FROM channels WHERE id = $1 LIMIT 1").await?;
+    let select = client.prepare("SELECT id, creator, deposit_asset, deposit_amount, valid_until, targeting_rules, spec, exhausted FROM channels WHERE id = $1 LIMIT 1").await?;
 
-    let results = connection.query(&select, &[&id]).await?;
+    let results = client.query(&select, &[&id]).await?;
 
     Ok(results.get(0).map(Channel::from))
 }
@@ -24,28 +23,25 @@ pub async fn get_channel_by_id_and_validator(
     pool: &DbPool,
     id: &ChannelId,
     validator_id: &ValidatorId,
-) -> Result<Option<Channel>, RunError<bb8_postgres::tokio_postgres::Error>> {
-    let connection = pool.get().await?;
+) -> Result<Option<Channel>, PoolError> {
+    let client = pool.get().await?;
 
     let validator = serde_json::Value::from_str(&format!(r#"[{{"id": "{}"}}]"#, validator_id))
         .expect("Not a valid json");
     let query = "SELECT id, creator, deposit_asset, deposit_amount, valid_until, targeting_rules, spec, exhausted FROM channels WHERE id = $1 AND spec->'validators' @> $2 LIMIT 1";
-    let select = connection.prepare(query).await?;
+    let select = client.prepare(query).await?;
 
-    let results = connection.query(&select, &[&id, &validator]).await?;
+    let results = client.query(&select, &[&id, &validator]).await?;
 
     Ok(results.get(0).map(Channel::from))
 }
 
-pub async fn insert_channel(
-    pool: &DbPool,
-    channel: &Channel,
-) -> Result<bool, RunError<bb8_postgres::tokio_postgres::Error>> {
-    let connection = pool.get().await?;
+pub async fn insert_channel(pool: &DbPool, channel: &Channel) -> Result<bool, PoolError> {
+    let client = pool.get().await?;
 
-    let stmt = connection.prepare("INSERT INTO channels (id, creator, deposit_asset, deposit_amount, valid_until, targeting_rules, spec, exhausted) values ($1, $2, $3, $4, $5, $6, $7, $8)").await?;
+    let stmt = client.prepare("INSERT INTO channels (id, creator, deposit_asset, deposit_amount, valid_until, targeting_rules, spec, exhausted) values ($1, $2, $3, $4, $5, $6, $7, $8)").await?;
 
-    let row = connection
+    let row = client
         .execute(
             &stmt,
             &[
@@ -69,13 +65,13 @@ pub async fn update_targeting_rules(
     pool: &DbPool,
     channel_id: &ChannelId,
     targeting_rules: &Rules,
-) -> Result<bool, RunError<bb8_postgres::tokio_postgres::Error>> {
-    let connection = pool.get().await?;
+) -> Result<bool, PoolError> {
+    let client = pool.get().await?;
 
-    let stmt = connection
+    let stmt = client
         .prepare("UPDATE channels SET targeting_rules=$1 WHERE id=$2")
         .await?;
-    let row = connection
+    let row = client
         .execute(&stmt, &[&targeting_rules, &channel_id])
         .await?;
 
@@ -88,12 +84,12 @@ pub async fn insert_validator_messages(
     channel: &Channel,
     from: &ValidatorId,
     validator_message: &MessageTypes,
-) -> Result<bool, RunError<bb8_postgres::tokio_postgres::Error>> {
-    let connection = pool.get().await?;
+) -> Result<bool, PoolError> {
+    let client = pool.get().await?;
 
-    let stmt = connection.prepare("INSERT INTO validator_messages (channel_id, \"from\", msg, received) values ($1, $2, $3, $4)").await?;
+    let stmt = client.prepare("INSERT INTO validator_messages (channel_id, \"from\", msg, received) values ($1, $2, $3, $4)").await?;
 
-    let row = connection
+    let row = client
         .execute(
             &stmt,
             &[&channel.id, &from, &validator_message, &Utc::now()],
         )
@@ -108,35 +104,34 @@ pub async fn update_exhausted_channel(
     pool: &DbPool,
     channel: &Channel,
     index: u32,
-) -> Result<bool, RunError<bb8_postgres::tokio_postgres::Error>> {
-    let connection = pool.get().await?;
+) -> Result<bool, PoolError> {
+    let client = pool.get().await?;
 
-    let stmt = connection
+    let stmt = client
         .prepare("UPDATE channels SET exhausted[$1] = true WHERE id = $2")
         .await?;
     // WARNING: By default PostgreSQL uses a one-based numbering convention for arrays, that is, an array of n elements starts with array[1] and ends with array[n].
     // this is why we add +1 to the index
-    let row = connection
-        .execute(&stmt, &[&(index + 1), &channel.id])
-        .await?;
+    let row = client.execute(&stmt, &[&(index + 1), &channel.id]).await?;
 
     let updated = row == 1;
     Ok(updated)
 }
 
 mod list_channels {
-    use crate::db::DbPool;
-    use bb8::RunError;
-    use bb8_postgres::tokio_postgres::types::{accepts, FromSql, ToSql, Type};
     use chrono::{DateTime, Utc};
-    use primitives::sentry::ChannelListResponse;
-    use primitives::{Channel, ValidatorId};
-    use std::error::Error;
+    use primitives::{sentry::ChannelListResponse, Channel, ValidatorId};
     use std::str::FromStr;
+    use tokio_postgres::types::{accepts, FromSql, ToSql, Type};
+
+    use crate::db::{DbPool, PoolError};
 
     struct TotalCount(pub u64);
     impl<'a> FromSql<'a> for TotalCount {
-        fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
+        fn from_sql(
+            ty: &Type,
+            raw: &'a [u8],
+        ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
             let str_slice = <&str as FromSql>::from_sql(ty, raw)?;
 
             Ok(Self(u64::from_str(str_slice)?))
@@ -153,7 +148,9 @@ mod list_channels {
         creator: &Option<String>,
         validator: &Option<ValidatorId>,
         valid_until_ge: &DateTime<Utc>,
-    ) -> Result<ChannelListResponse, RunError<bb8_postgres::tokio_postgres::Error>> {
+    ) -> Result<ChannelListResponse, PoolError> {
+        let client = pool.get().await?;
+
         let validator = validator.as_ref().map(|validator_id| {
             serde_json::Value::from_str(&format!(r#"[{{"id": "{}"}}]"#, validator_id))
                 .expect("Not a valid json")
         });
 
        let (where_clauses, params) =
            channel_list_query_params(creator, validator.as_ref(), valid_until_ge);
        let total_count_params = (where_clauses.clone(), params.clone());
 
-        let connection = pool.get().await?;
-
         // To understand why we use Order by, see Postgres Documentation: https://www.postgresql.org/docs/8.1/queries-limit.html
         let statement = format!("SELECT id, creator, deposit_asset, deposit_amount, valid_until, targeting_rules, spec, exhausted FROM channels WHERE {} ORDER BY spec->>'created' DESC LIMIT {} OFFSET {}", where_clauses.join(" AND "), limit, skip);
-        let stmt = connection.prepare(&statement).await?;
+        let stmt = client.prepare(&statement).await?;
 
-        let rows = connection.query(&stmt, params.as_slice()).await?;
+        let rows = client.query(&stmt, params.as_slice()).await?;
         let channels = rows.iter().map(Channel::from).collect();
 
         let total_count =
-            list_channels_total_count(&pool, (&total_count_params.0, total_count_params.1)).await?;
+            list_channels_total_count(pool, (&total_count_params.0, total_count_params.1)).await?;
 
         // fast ceil for total_pages
         let total_pages = if total_count == 0 {
@@ -192,15 +187,15 @@ mod list_channels {
     async fn list_channels_total_count<'a>(
         pool: &DbPool,
         (where_clauses, params): (&'a [String], Vec<&'a (dyn ToSql + Sync)>),
-    ) -> Result<u64, RunError<bb8_postgres::tokio_postgres::Error>> {
-        let connection = pool.get().await?;
+    ) -> Result<u64, PoolError> {
+        let client = pool.get().await?;
 
         let statement = format!(
             "SELECT COUNT(id)::varchar FROM channels WHERE {}",
             where_clauses.join(" AND ")
         );
-        let stmt = connection.prepare(&statement).await?;
-        let row = connection.query_one(&stmt, params.as_slice()).await?;
+        let stmt = client.prepare(&statement).await?;
+        let row = client.query_one(&stmt, params.as_slice()).await?;
 
         Ok(row.get::<_, TotalCount>(0).0)
     }
diff --git a/sentry/src/db/event_aggregate.rs b/sentry/src/db/event_aggregate.rs
index 4cb2d92fa..73b73d138 100644
--- a/sentry/src/db/event_aggregate.rs
+++ b/sentry/src/db/event_aggregate.rs
@@ -1,10 +1,3 @@
-use crate::db::DbPool;
-use bb8::RunError;
-use bb8_postgres::tokio_postgres::{
-    binary_copy::BinaryCopyInWriter,
-    types::{ToSql, Type},
-    Error,
-};
 use chrono::{DateTime, Utc};
 use futures::pin_mut;
 use primitives::{
@@ -13,15 +6,21 @@ use primitives::{
     Address, BigNum, Channel, ChannelId, ValidatorId,
 };
 use std::{convert::TryFrom, ops::Add};
+use tokio_postgres::{
+    binary_copy::BinaryCopyInWriter,
+    types::{ToSql, Type},
+};
+
+use super::{DbPool, PoolError};
 
 pub async fn latest_approve_state(
     pool: &DbPool,
     channel: &Channel,
-) -> Result<Option<MessageResponse<ApproveState>>, RunError<bb8_postgres::tokio_postgres::Error>> {
-    let connection = pool.get().await?;
+) -> Result<Option<MessageResponse<ApproveState>>, PoolError> {
+    let client = pool.get().await?;
 
-    let select = connection.prepare("SELECT \"from\", msg, received FROM validator_messages WHERE channel_id = $1 AND \"from\" = $2 AND msg ->> 'type' = 'ApproveState' ORDER BY received DESC LIMIT 1").await?;
-    let rows = connection
+    let select = client.prepare("SELECT \"from\", msg, received FROM validator_messages WHERE channel_id = $1 AND \"from\" = $2 AND msg ->> 'type' = 'ApproveState' ORDER BY received DESC LIMIT 1").await?;
+    let rows = client
         .query(
             &select,
             &[&channel.id, &channel.spec.validators.follower().id],
         )
         .await?;
 
     rows.get(0)
         .map(MessageResponse::<ApproveState>::try_from)
         .transpose()
-        .map_err(RunError::User)
+        .map_err(PoolError::Backend)
 }
 
 pub async fn latest_new_state(
     pool: &DbPool,
     channel: &Channel,
     state_root: &str,
-) -> Result<Option<MessageResponse<NewState>>, RunError<bb8_postgres::tokio_postgres::Error>> {
-    let connection = pool.get().await?;
+) -> Result<Option<MessageResponse<NewState>>, PoolError> {
+    let client = pool.get().await?;
 
-    let select = connection.prepare("SELECT \"from\", msg, received FROM validator_messages WHERE channel_id = $1 AND \"from\" = $2 AND msg ->> 'type' = 'NewState' AND msg->> 'stateRoot' = $3 ORDER BY received DESC LIMIT 1").await?;
-    let rows = connection
+    let select = client.prepare("SELECT \"from\", msg, received FROM validator_messages WHERE channel_id = $1 AND \"from\" = $2 AND msg ->> 'type' = 'NewState' AND msg->> 'stateRoot' = $3 ORDER BY received DESC LIMIT 1").await?;
+    let rows = client
         .query(
             &select,
             &[
                 &channel.id,
                 &channel.spec.validators.leader().id,
                 &state_root,
             ],
         )
        .await?;
 
     rows.get(0)
         .map(MessageResponse::<NewState>::try_from)
         .transpose()
-        .map_err(RunError::User)
+        .map_err(PoolError::Backend)
 }
 
 pub async fn latest_heartbeats(
     pool: &DbPool,
     channel_id: &ChannelId,
     validator_id: &ValidatorId,
-) -> Result<Vec<MessageResponse<Heartbeat>>, RunError<bb8_postgres::tokio_postgres::Error>> {
-    let connection = pool.get().await?;
+) -> Result<Vec<MessageResponse<Heartbeat>>, PoolError> {
+    let client = pool.get().await?;
 
-    let select = connection.prepare("SELECT \"from\", msg, received FROM validator_messages WHERE channel_id = $1 AND \"from\" = $2 AND msg ->> 'type' = 'Heartbeat' ORDER BY received DESC LIMIT 2").await?;
-    let rows = connection
-        .query(&select, &[&channel_id, &validator_id])
-        .await?;
+    let select = client.prepare("SELECT \"from\", msg, received FROM validator_messages WHERE channel_id = $1 AND \"from\" = $2 AND msg ->> 'type' = 'Heartbeat' ORDER BY received DESC LIMIT 2").await?;
+    let rows = client.query(&select, &[&channel_id, &validator_id]).await?;
 
     rows.iter()
         .map(MessageResponse::<Heartbeat>::try_from)
         .collect::<Result<_, _>>()
-        .map_err(RunError::User)
+        .map_err(PoolError::Backend)
 }
 
 pub async fn list_event_aggregates(
@@ -83,7 +80,9 @@ pub async fn list_event_aggregates(
     pool: &DbPool,
     channel_id: &ChannelId,
     limit: u32,
     from: &Option<ValidatorId>,
     after: &Option<DateTime<Utc>>,
-) -> Result<Vec<EventAggregate>, RunError<bb8_postgres::tokio_postgres::Error>> {
+) -> Result<Vec<EventAggregate>, PoolError> {
+    let client = pool.get().await?;
+
     let (mut where_clauses, mut params) = (vec![], Vec::<&(dyn ToSql + Sync)>::new());
     let id = channel_id.to_string();
     params.push(&id);
@@ -102,8 +101,6 @@ pub async fn list_event_aggregates(
         where_clauses.push(format!("created > ${}", params.len()));
     }
 
-    let connection = pool.get().await?;
-
     let where_clause = if !where_clauses.is_empty() {
         where_clauses.join(" AND ").to_string()
     } else {
@@ -135,8 +132,8 @@ pub async fn list_event_aggregates(
     ) SELECT channel_id, created, jsonb_object_agg(event_type , data) as events FROM aggregates GROUP BY channel_id, created
     ", where_clause, limit);
 
-    let stmt = connection.prepare(&statement).await?;
-    let rows = connection.query(&stmt, params.as_slice()).await?;
+    let stmt = client.prepare(&statement).await?;
+    let rows = client.query(&stmt, params.as_slice()).await?;
 
     let event_aggregates = rows.iter().map(EventAggregate::from).collect();
 
@@ -156,7 +153,7 @@ pub async fn insert_event_aggregate(
     pool: &DbPool,
     channel_id: &ChannelId,
     event: &EventAggregate,
-) -> Result<bool, RunError<bb8_postgres::tokio_postgres::Error>> {
+) -> Result<bool, PoolError> {
     let mut data: Vec<EventData> = Vec::new();
 
     for (event_type, aggr) in &event.events {
@@ -189,10 +186,10 @@ pub async fn insert_event_aggregate(
         }
     }
 
-    let connection = pool.get().await?;
+    let client = pool.get().await?;
 
-    let mut err: Option<Error> = None;
-    let sink = connection.copy_in("COPY event_aggregates(channel_id, created, event_type, count, payout, earner) FROM STDIN BINARY").await?;
+    let mut err: Option<tokio_postgres::Error> = None;
+    let sink = client.copy_in("COPY event_aggregates(channel_id, created, event_type, count, payout, earner) FROM STDIN BINARY").await?;
 
     let created = Utc::now(); // time discrepancy
 
@@ -227,7 +224,7 @@ pub async fn insert_event_aggregate(
     }
 
     match err {
-        Some(e) => Err(bb8::RunError::from(e)),
+        Some(e) => Err(PoolError::Backend(e)),
         None => {
             writer.finish().await?;
             Ok(true)
diff --git a/sentry/src/db/spendable.rs b/sentry/src/db/spendable.rs
index dd3b83f45..293358fb1 100644
--- a/sentry/src/db/spendable.rs
+++ b/sentry/src/db/spendable.rs
@@ -1,13 +1,15 @@
 use std::convert::TryFrom;
 
 use primitives::{spender::Spendable, Address, ChannelId};
-use tokio_postgres::{Client, Error};
+
+use super::{DbPool, PoolError};
 
 /// ```text
 /// INSERT INTO spendable (spender, channel_id, channel, total, still_on_create2)
 /// values ('0xce07CbB7e054514D590a0262C93070D838bFBA2e', '0x061d5e2a67d0a9a10f1c732bca12a676d83f79663a396f7d87b3e30b9b411088', '{}', 10.00000000, 2.00000000);
 /// ```
-pub async fn insert_spendable(client: &Client, spendable: &Spendable) -> Result<bool, Error> {
+pub async fn insert_spendable(pool: DbPool, spendable: &Spendable) -> Result<bool, PoolError> {
+    let client = pool.get().await?;
     let stmt = client.prepare("INSERT INTO spendable (spender, channel_id, channel, total, still_on_create2) values ($1, $2, $3, $4, $5)").await?;
 
     let row = client
@@ -32,15 +34,16 @@ pub async fn insert_spendable(client: &Client, spendable: &Spendable) -> Result<
 /// WHERE spender = $1 AND channel_id = $2
 /// ```
 pub async fn fetch_spendable(
-    client: &Client,
+    pool: DbPool,
     spender: &Address,
     channel_id: &ChannelId,
-) -> Result<Spendable, Error> {
+) -> Result<Spendable, PoolError> {
+    let client = pool.get().await?;
     let statement = client.prepare("SELECT spender, channel_id, channel, total, still_on_create2 FROM spendable WHERE spender = $1 AND channel_id = $2").await?;
 
     let row = client.query_one(&statement, &[spender, channel_id]).await?;
 
-    Spendable::try_from(row)
+    Ok(Spendable::try_from(row)?)
 }
 
 #[cfg(test)]
@@ -51,15 +54,22 @@ mod test {
     use primitives::{
         UnifiedNum,
     };
 
-    use crate::db::postgres_pool::{setup_test_migrations, TESTS_POOL};
+    use crate::db::{
+        tests_postgres::{setup_test_migrations, test_postgres_connection},
+        POSTGRES_CONFIG,
+    };
 
     use super::*;
 
     #[tokio::test]
     async fn it_inserts_and_fetches_spendable() {
-        let test_client = TESTS_POOL.get().await.unwrap();
+        let test_pool = test_postgres_connection(POSTGRES_CONFIG.clone())
+            .get()
+            .await
+            .unwrap();
 
-        setup_test_migrations(&test_client)
+        setup_test_migrations(test_pool.clone())
             .await
             .expect("Migrations should succeed");
 
         let spendable = Spendable {
                 still_on_create2: UnifiedNum::from(500_000),
             },
         };
-        let is_inserted = insert_spendable(&test_client, &spendable)
+        let is_inserted = insert_spendable(test_pool.clone(), &spendable)
             .await
             .expect("Should succeed");
 
         assert!(is_inserted);
 
-        let fetched_spendable =
-            fetch_spendable(&test_client, &spendable.spender, &spendable.channel.id())
-                .await
-                .expect("Should fetch successfully");
+        let fetched_spendable = fetch_spendable(
+            test_pool.clone(),
+            &spendable.spender,
+            &spendable.channel.id(),
+        )
+        .await
+        .expect("Should fetch successfully");
 
         assert_eq!(spendable, fetched_spendable);
     }
diff --git a/sentry/src/db/validator_message.rs b/sentry/src/db/validator_message.rs
index ecc4c467a..af2742091 100644
--- a/sentry/src/db/validator_message.rs
+++ b/sentry/src/db/validator_message.rs
@@ -1,7 +1,7 @@
-use crate::db::DbPool;
-use bb8::RunError;
-use bb8_postgres::tokio_postgres::types::ToSql;
 use primitives::{sentry::ValidatorMessage, ChannelId, ValidatorId};
+use tokio_postgres::types::ToSql;
+
+use super::{DbPool, PoolError};
 
 pub async fn get_validator_messages(
     pool: &DbPool,
@@ -9,7 +9,9 @@ pub async fn get_validator_messages(
     validator_id: &Option<ValidatorId>,
     message_types: &[String],
     limit: u64,
-) -> Result<Vec<ValidatorMessage>, RunError<bb8_postgres::tokio_postgres::Error>> {
+) -> Result<Vec<ValidatorMessage>, PoolError> {
+    let client = pool.get().await?;
+
     let mut where_clauses: Vec<String> = vec!["channel_id = $1".to_string()];
     let mut params: Vec<&(dyn ToSql + Sync)> = vec![&channel_id];
 
@@ -20,15 +22,13 @@ pub async fn get_validator_messages(
 
     add_message_types_params(&mut where_clauses, &mut params, message_types);
 
-    let connection = pool.get().await?;
-
     let statement = format!(
         r#"SELECT "from", msg, received FROM validator_messages WHERE {} ORDER BY received DESC LIMIT {}"#,
         where_clauses.join(" AND "),
         limit
     );
-    let select = connection.prepare(&statement).await?;
-    let results = connection.query(&select, params.as_slice()).await?;
+    let select = client.prepare(&statement).await?;
+    let results = client.query(&select, params.as_slice()).await?;
 
     let messages = results.iter().map(ValidatorMessage::from).collect();
 
     Ok(messages)
diff --git a/sentry/src/event_aggregator.rs b/sentry/src/event_aggregator.rs
index f2a9e4306..da5c8a339 100644
--- a/sentry/src/event_aggregator.rs
+++ b/sentry/src/event_aggregator.rs
@@ -1,17 +1,14 @@
-use crate::access::check_access;
-use crate::access::Error as AccessError;
-use crate::db::event_aggregate::insert_event_aggregate;
-use crate::db::DbPool;
-use crate::db::{get_channel_by_id, update_targeting_rules};
 //
 // TODO: AIP#61 Event Aggregator should be replaced with the Spender aggregator & Event Analytics
 //
 // use crate::event_reducer;
 // use crate::payout::get_payout;
-use crate::Application;
-use crate::ResponseError;
-use crate::Session;
-use crate::{analytics_recorder, Auth};
+use crate::{
+    access::{check_access, Error as AccessError},
+    analytics_recorder,
+    db::{event_aggregate::insert_event_aggregate, get_channel_by_id, update_targeting_rules},
+    Application, Auth, DbPool, ResponseError, Session,
+};
 use async_std::sync::RwLock;
 use chrono::Utc;
 use lazy_static::lazy_static;
@@ -50,11 +47,11 @@ pub fn new_aggr(channel_id: &ChannelId) -> EventAggregate {
     }
 }
 
-async fn store(db: &DbPool, channel_id: &ChannelId, logger: &Logger, recorder: Recorder) {
+async fn store(pool: &DbPool, channel_id: &ChannelId, logger: &Logger, recorder: Recorder) {
     let mut channel_recorder = recorder.write().await;
     let record: Option<&Record> = channel_recorder.get(channel_id);
     if let Some(data) = record {
-        if let Err(e) = insert_event_aggregate(&db, &channel_id, &data.aggregate).await {
+        if let Err(e) = insert_event_aggregate(&pool, &channel_id, &data.aggregate).await {
             error!(&logger, "{}", e; "module" => "event_aggregator", "in" => "store");
         } else {
             // reset aggr record
@@ -105,6 +102,8 @@ impl EventAggregator {
         // the channel events to database
         if aggr_throttle > 0 {
             let recorder = recorder.clone();
+            let dbpool = dbpool.clone();
+
             tokio::spawn(async move {
                 loop {
                     // break loop if the
@@ -155,7 +154,7 @@ impl EventAggregator {
         });
 
         if let Some(new_rules) = new_targeting_rules {
-            update_targeting_rules(&app.pool, &channel_id, &new_rules).await?;
+            update_targeting_rules(&dbpool.clone(), &channel_id, &new_rules).await?;
         }
 
         //
@@ -202,7 +201,7 @@ impl EventAggregator {
         drop(channel_recorder);
 
         if aggr_throttle == 0 {
-            store(&app.pool, &channel_id, &app.logger, recorder.clone()).await;
+            store(&dbpool, &channel_id, &app.logger, recorder.clone()).await;
         }
 
         Ok(())
diff --git a/sentry/src/event_reducer.rs b/sentry/src/event_reducer.rs
index cc7b5a8ee..4a1ba5c68 100644
--- a/sentry/src/event_reducer.rs
+++ b/sentry/src/event_reducer.rs
@@ -6,7 +6,7 @@ use primitives::{
 //
 // TODO: AIP#61 remove `allow(dead_code)` and see what should be changed for Spender Aggregate
 //
-#[allow(dead_code)]
+#[allow(dead_code, clippy::unnecessary_wraps)]
 pub(crate) fn reduce(
     channel: &Channel,
     initial_aggr: &mut EventAggregate,
diff --git a/sentry/src/lib.rs b/sentry/src/lib.rs
index 7bf1378cd..403df405a 100644
--- a/sentry/src/lib.rs
+++ b/sentry/src/lib.rs
@@ -329,10 +329,9 @@ where
         ResponseError::BadRequest("Bad Request: try again later".into())
     }
 }
-
-impl Into<Response<Body>> for ResponseError {
-    fn into(self) -> Response<Body> {
-        map_response_error(self)
+impl From<ResponseError> for Response<Body> {
+    fn from(response_error: ResponseError) -> Self {
+        map_response_error(response_error)
     }
 }
diff --git a/sentry/src/main.rs b/sentry/src/main.rs
index 9f0e1f228..f8963f562 100644
--- a/sentry/src/main.rs
+++ b/sentry/src/main.rs
@@ -114,7 +114,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
     info!(&logger, "Checking connection and applying migrations...");
     // Check connection and setup migrations before setting up Postgres
     setup_migrations(&environment).await;
-    let postgres = postgres_connection().await?;
+    let postgres = postgres_connection(42).await;
 
     match adapter {
         AdapterTypes::EthereumAdapter(adapter) => {
diff --git a/sentry/src/middleware.rs b/sentry/src/middleware.rs
index cd8c4a382..f0a652d54 100644
--- a/sentry/src/middleware.rs
+++ b/sentry/src/middleware.rs
@@ -36,10 +36,10 @@ impl<A: Adapter + 'static> Chain<A> {
     }
 
     /// Applies chained middlewares in the order they were chained
-    pub async fn apply<'a>(
+    pub async fn apply(
         &self,
         mut request: Request<Body>,
-        application: &'a Application<A>,
+        application: &Application<A>,
     ) -> Result<Request<Body>, ResponseError> {
         for middleware in self.0.iter() {
             request = middleware.call(request, application).await?;
diff --git a/sentry/src/routes/channel.rs b/sentry/src/routes/channel.rs
index f57cd9a80..9c485d78f 100644
--- a/sentry/src/routes/channel.rs
+++ b/sentry/src/routes/channel.rs
@@ -1,11 +1,9 @@
-use crate::db::event_aggregate::{latest_approve_state, latest_heartbeats, latest_new_state};
 use crate::db::{
+    event_aggregate::{latest_approve_state, latest_heartbeats, latest_new_state},
     get_channel_by_id, insert_channel, insert_validator_messages, list_channels,
-    update_exhausted_channel,
+    update_exhausted_channel, PoolError,
 };
 use crate::{success_response, Application, Auth, ResponseError, RouteParams, Session};
-use bb8::RunError;
-use bb8_postgres::tokio_postgres::error;
 use futures::future::try_join_all;
 use hex::FromHex;
 use hyper::{Body, Request, Response};
@@ -20,6 +18,7 @@ use primitives::{
 };
 use slog::error;
 use std::collections::HashMap;
+use tokio_postgres::error::SqlState;
 
 pub async fn channel_status(
     req: Request<Body>,
@@ -59,10 +58,13 @@ pub async fn create_channel(
     match insert_channel(&app.pool, &channel).await {
         Err(error) => {
             error!(&app.logger, "{}", &error; "module" => "create_channel");
+
             match error {
-                RunError::User(e) if e.code() == Some(&error::SqlState::UNIQUE_VIOLATION) => Err(
-                    ResponseError::Conflict("channel already exists".to_string()),
-                ),
+                PoolError::Backend(error) if error.code() == Some(&SqlState::UNIQUE_VIOLATION) => {
+                    Err(ResponseError::Conflict(
+                        "channel already exists".to_string(),
+                    ))
+                }
                 _ => Err(error_response),
             }
         }
diff --git a/sentry/src/spender.rs b/sentry/src/spender.rs
index b1dab92ce..2e7f3d20c 100644
--- a/sentry/src/spender.rs
+++ b/sentry/src/spender.rs
@@ -41,9 +41,9 @@ pub mod fee {
         // should never overflow
         let fee_payout = payout
             .checked_mul(&validator.fee)
-            .ok_or(DomainError::InvalidArgument(
-                "payout calculation overflow".to_string(),
-            ))?
+            .ok_or_else(|| {
+                DomainError::InvalidArgument("payout calculation overflow".to_string())
+            })?
             .div_floor(&PRO_MILLE);
 
         Some(fee_payout)
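Downstream of this change, handler code follows the same pattern against the runtime pool: check a client out of the deadpool-based `DbPool` and bubble up the re-exported `PoolError`, instead of threading `bb8::RunError` through. A minimal sketch of the new calling convention (the `channel_count` helper and its query are illustrative only, not part of this diff):

    use crate::db::{DbPool, PoolError};

    /// Hypothetical helper: counts `channels` rows via the deadpool-based `DbPool`.
    async fn channel_count(pool: &DbPool) -> Result<i64, PoolError> {
        // `get()` checks a pooled connection out of the deadpool Pool;
        // `PoolError` covers both checkout failures and Postgres errors (`PoolError::Backend`).
        let client = pool.get().await?;

        let stmt = client.prepare("SELECT COUNT(*) FROM channels").await?;
        let row = client.query_one(&stmt, &[]).await?;

        Ok(row.get(0))
    }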