diff --git a/.gitignore b/.gitignore index 89445832f..5742340f7 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,4 @@ target .idea .old Migrant.toml -node_modules/ +node_modules/ \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 96eae59aa..deadfa1f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3668,6 +3668,7 @@ dependencies = [ "serde_json", "serde_urlencoded 0.6.1", "slog", + "thiserror", "tokio 0.2.23", ] diff --git a/primitives/src/sentry.rs b/primitives/src/sentry.rs index 07117799b..f04c7cc3e 100644 --- a/primitives/src/sentry.rs +++ b/primitives/src/sentry.rs @@ -1,3 +1,4 @@ +use crate::targeting::Rules; use crate::validator::MessageTypes; use crate::{BigNum, Channel, ChannelId, ValidatorId}; use chrono::{DateTime, Utc}; @@ -52,19 +53,8 @@ pub enum Event { ad_slot: Option, referrer: Option, }, - ImpressionWithCommission { - earners: Vec, - }, - /// only the creator can send this event - UpdateImpressionPrice { - price: BigNum, - }, - /// only the creator can send this event - Pay { - outputs: HashMap, - }, /// only the creator can send this event - PauseChannel, + UpdateTargeting { targeting_rules: Rules }, /// only the creator can send this event Close, } @@ -84,10 +74,7 @@ impl fmt::Display for Event { match *self { Event::Impression { .. } => write!(f, "IMPRESSION"), Event::Click { .. } => write!(f, "CLICK"), - Event::ImpressionWithCommission { .. } => write!(f, "IMPRESSION_WITH_COMMMISION"), - Event::UpdateImpressionPrice { .. } => write!(f, "UPDATE_IMPRESSION_PRICE"), - Event::Pay { .. } => write!(f, "PAY"), - Event::PauseChannel => write!(f, "PAUSE_CHANNEL"), + Event::UpdateTargeting { .. } => write!(f, "UPDATE_TARGETING"), Event::Close => write!(f, "CLOSE"), } } diff --git a/sentry/Cargo.toml b/sentry/Cargo.toml index d958e43a0..48d09e941 100644 --- a/sentry/Cargo.toml +++ b/sentry/Cargo.toml @@ -35,3 +35,4 @@ serde_json = "^1.0" serde_urlencoded = "0.6.1" # Other lazy_static = "1.4.0" +thiserror = "^1.0" diff --git a/sentry/src/access.rs b/sentry/src/access.rs index 0c4cbb252..70568af03 100644 --- a/sentry/src/access.rs +++ b/sentry/src/access.rs @@ -7,34 +7,26 @@ use primitives::event_submission::{RateLimit, Rule}; use primitives::sentry::Event; use primitives::Channel; use std::cmp::PartialEq; -use std::error; -use std::fmt; +use thiserror::Error; -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Error)] pub enum Error { + #[error("only creator can close channel")] OnlyCreatorCanCloseChannel, + #[error("only creator can update targeting rules")] + OnlyCreatorCanUpdateTargetingRules, + #[error("channel is expired")] ChannelIsExpired, + #[error("channel is in withdraw period")] ChannelIsInWithdrawPeriod, + #[error("event submission restricted")] ForbiddenReferrer, + #[error("{0}")] RulesError(String), + #[error("unauthenticated")] UnAuthenticated, } -impl error::Error for Error {} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Error::OnlyCreatorCanCloseChannel => write!(f, "only creator can create channel"), - Error::ChannelIsExpired => write!(f, "channel is expired"), - Error::ChannelIsInWithdrawPeriod => write!(f, "channel is in withdraw period"), - Error::ForbiddenReferrer => write!(f, "event submission restricted"), - Error::RulesError(error) => write!(f, "{}", error), - Error::UnAuthenticated => write!(f, "unauthenticated"), - } - } -} - // @TODO: Make pub(crate) pub async fn check_access( redis: &MultiplexedConnection, @@ -45,19 +37,21 @@ pub async fn check_access( events: &[Event], ) -> Result<(), Error> { let is_close_event = |e: &Event| matches!(e, Event::Close); + let is_update_targeting_event = |e: &Event| matches!(e, Event::UpdateTargeting { .. }); let has_close_event = events.iter().all(is_close_event); + let has_update_targeting_event = events.iter().all(is_update_targeting_event); let current_time = Utc::now(); let is_in_withdraw_period = current_time > channel.spec.withdraw_period_start; - if has_close_event && is_in_withdraw_period { - return Ok(()); - } - if current_time > channel.valid_until { return Err(Error::ChannelIsExpired); } + if has_close_event && is_in_withdraw_period { + return Ok(()); + } + let (is_creator, auth_uid) = match auth { Some(auth) => (auth.uid == channel.creator, auth.uid.to_string()), None => (false, Default::default()), @@ -68,11 +62,20 @@ pub async fn check_access( return Ok(()); } + if has_update_targeting_event && is_creator { + return Ok(()); + } + // Only the creator can send a CLOSE if !is_creator && events.iter().any(is_close_event) { return Err(Error::OnlyCreatorCanCloseChannel); } + // Only the creator can send a UPDATE_TARGETING + if !is_creator && events.iter().any(is_update_targeting_event) { + return Err(Error::OnlyCreatorCanUpdateTargetingRules); + } + if is_in_withdraw_period { return Err(Error::ChannelIsInWithdrawPeriod); } @@ -166,7 +169,6 @@ async fn apply_rule( } let seconds = rate_limit.time_frame.as_secs_f32().ceil(); - redis::cmd("SETEX") .arg(&key) .arg(seconds as i32) @@ -207,9 +209,11 @@ fn forbidden_country(session: &Session) -> bool { mod test { use std::time::Duration; + use chrono::TimeZone; use primitives::config::configuration; use primitives::event_submission::{RateLimit, Rule}; use primitives::sentry::Event; + use primitives::targeting::Rules; use primitives::util::tests::prep_db::{DUMMY_CHANNEL, IDS}; use primitives::{Channel, Config, EventSubmission}; @@ -251,6 +255,18 @@ mod test { .collect() } + fn get_close_events(count: i8) -> Vec { + (0..count).map(|_| Event::Close).collect() + } + + fn get_update_targeting_events(count: i8) -> Vec { + (0..count) + .map(|_| Event::UpdateTargeting { + targeting_rules: Rules::new(), + }) + .collect() + } + #[tokio::test] async fn session_uid_rate_limit() { let (config, redis) = setup().await; @@ -358,4 +374,472 @@ mod test { .await; assert_eq!(Ok(()), response); } + + #[tokio::test] + #[ignore] + async fn check_access_past_channel_valid_until() { + let (config, redis) = setup().await; + + let auth = Auth { + era: 0, + uid: IDS["follower"], + }; + + let session = Session { + ip: Default::default(), + referrer_header: None, + country: None, + os: None, + }; + + let rule = Rule { + uids: None, + rate_limit: Some(RateLimit { + limit_type: "ip".to_string(), + time_frame: Duration::from_millis(1), + }), + }; + let mut channel = get_channel(rule); + channel.valid_until = Utc.ymd(1970, 1, 1).and_hms(12, 00, 9); + + let err_response = check_access( + &redis, + &session, + Some(&auth), + &config.ip_rate_limit, + &channel, + &get_impression_events(2), + ) + .await; + + assert_eq!(Err(Error::ChannelIsExpired), err_response); + } + + #[tokio::test] + #[ignore] + async fn check_access_close_event_in_withdraw_period() { + let (config, redis) = setup().await; + + let auth = Auth { + era: 0, + uid: IDS["follower"], + }; + + let session = Session { + ip: Default::default(), + referrer_header: None, + country: None, + os: None, + }; + + let rule = Rule { + uids: None, + rate_limit: Some(RateLimit { + limit_type: "ip".to_string(), + time_frame: Duration::from_millis(1), + }), + }; + let mut channel = get_channel(rule); + channel.spec.withdraw_period_start = Utc + .ymd(1970, 1, 1) + .and_hms(12, 0, 9); + + let ok_response = check_access( + &redis, + &session, + Some(&auth), + &config.ip_rate_limit, + &channel, + &get_close_events(1), + ) + .await; + + assert_eq!(Ok(()), ok_response); + } + + #[tokio::test] + #[ignore] + async fn check_access_close_event_and_is_creator() { + let (config, redis) = setup().await; + + let auth = Auth { + era: 0, + uid: IDS["follower"], + }; + + let session = Session { + ip: Default::default(), + referrer_header: None, + country: None, + os: None, + }; + + let rule = Rule { + uids: None, + rate_limit: Some(RateLimit { + limit_type: "ip".to_string(), + time_frame: Duration::from_millis(1), + }), + }; + let mut channel = get_channel(rule); + channel.creator = IDS["follower"]; + + let ok_response = check_access( + &redis, + &session, + Some(&auth), + &config.ip_rate_limit, + &channel, + &get_close_events(1), + ) + .await; + + assert_eq!(Ok(()), ok_response); + } + + #[tokio::test] + #[ignore] + async fn check_access_update_targeting_event_and_is_creator() { + let (config, redis) = setup().await; + + let auth = Auth { + era: 0, + uid: IDS["follower"], + }; + + let session = Session { + ip: Default::default(), + referrer_header: None, + country: None, + os: None, + }; + + let rule = Rule { + uids: None, + rate_limit: Some(RateLimit { + limit_type: "ip".to_string(), + time_frame: Duration::from_millis(1), + }), + }; + let mut channel = get_channel(rule); + channel.creator = IDS["follower"]; + + let ok_response = check_access( + &redis, + &session, + Some(&auth), + &config.ip_rate_limit, + &channel, + &get_update_targeting_events(1), + ) + .await; + + assert_eq!(Ok(()), ok_response); + } + + #[tokio::test] + #[ignore] + async fn not_creator_and_there_are_close_events() { + let (config, redis) = setup().await; + + let auth = Auth { + era: 0, + uid: IDS["follower"], + }; + + let session = Session { + ip: Default::default(), + referrer_header: None, + country: None, + os: None, + }; + + let rule = Rule { + uids: None, + rate_limit: Some(RateLimit { + limit_type: "ip".to_string(), + time_frame: Duration::from_millis(1), + }), + }; + let mut channel = get_channel(rule); + channel.creator = IDS["leader"]; + let mixed_events = vec![ + Event::Impression { + publisher: IDS["publisher2"], + ad_unit: None, + ad_slot: None, + referrer: None, + }, + Event::Close, + Event::UpdateTargeting { + targeting_rules: Rules::new(), + }, + ]; + let err_response = check_access( + &redis, + &session, + Some(&auth), + &config.ip_rate_limit, + &channel, + &mixed_events, + ) + .await; + + assert_eq!(Err(Error::OnlyCreatorCanCloseChannel), err_response); + } + + #[tokio::test] + #[ignore] + async fn not_creator_and_there_are_update_targeting_events() { + let (config, redis) = setup().await; + + let auth = Auth { + era: 0, + uid: IDS["follower"], + }; + + let session = Session { + ip: Default::default(), + referrer_header: None, + country: None, + os: None, + }; + + let rule = Rule { + uids: None, + rate_limit: Some(RateLimit { + limit_type: "ip".to_string(), + time_frame: Duration::from_millis(1), + }), + }; + let mut channel = get_channel(rule); + channel.creator = IDS["leader"]; + let mixed_events = vec![ + Event::Impression { + publisher: IDS["publisher2"], + ad_unit: None, + ad_slot: None, + referrer: None, + }, + Event::UpdateTargeting { + targeting_rules: Rules::new(), + }, + ]; + let err_response = check_access( + &redis, + &session, + Some(&auth), + &config.ip_rate_limit, + &channel, + &mixed_events, + ) + .await; + + assert_eq!(Err(Error::OnlyCreatorCanUpdateTargetingRules), err_response); + } + + #[tokio::test] + #[ignore] + async fn in_withdraw_period_no_close_events() { + let (config, redis) = setup().await; + + let auth = Auth { + era: 0, + uid: IDS["follower"], + }; + + let session = Session { + ip: Default::default(), + referrer_header: None, + country: None, + os: None, + }; + + let rule = Rule { + uids: None, + rate_limit: Some(RateLimit { + limit_type: "ip".to_string(), + time_frame: Duration::from_millis(1), + }), + }; + let mut channel = get_channel(rule); + channel.spec.withdraw_period_start = Utc + .ymd(1970, 1, 1) + .and_hms(12, 0, 9); + + let err_response = check_access( + &redis, + &session, + Some(&auth), + &config.ip_rate_limit, + &channel, + &get_impression_events(2), + ) + .await; + + assert_eq!(Err(Error::ChannelIsInWithdrawPeriod), err_response); + } + + #[tokio::test] + #[ignore] + async fn with_forbidden_country() { + let (config, redis) = setup().await; + + let auth = Auth { + era: 0, + uid: IDS["follower"], + }; + + let session = Session { + ip: Default::default(), + referrer_header: None, + country: Some("XX".into()), + os: None, + }; + + let rule = Rule { + uids: None, + rate_limit: Some(RateLimit { + limit_type: "ip".to_string(), + time_frame: Duration::from_millis(1), + }), + }; + let channel = get_channel(rule); + + let err_response = check_access( + &redis, + &session, + Some(&auth), + &config.ip_rate_limit, + &channel, + &get_impression_events(2), + ) + .await; + + assert_eq!(Err(Error::ForbiddenReferrer), err_response); + } + + #[tokio::test] + #[ignore] + async fn with_forbidden_referrer() { + let (config, redis) = setup().await; + + let auth = Auth { + era: 0, + uid: IDS["follower"], + }; + + let session = Session { + ip: Default::default(), + referrer_header: Some("http://127.0.0.1".into()), + country: None, + os: None, + }; + + let rule = Rule { + uids: None, + rate_limit: Some(RateLimit { + limit_type: "ip".to_string(), + time_frame: Duration::from_millis(1), + }), + }; + let channel = get_channel(rule); + + let err_response = check_access( + &redis, + &session, + Some(&auth), + &config.ip_rate_limit, + &channel, + &get_impression_events(2), + ) + .await; + + assert_eq!(Err(Error::ForbiddenReferrer), err_response); + } + + #[tokio::test] + #[ignore] + async fn no_rate_limit() { + let (config, redis) = setup().await; + + let auth = Auth { + era: 0, + uid: IDS["follower"], + }; + + let session = Session { + ip: Default::default(), + referrer_header: None, + country: None, + os: None, + }; + + let rule = Rule { + uids: None, + rate_limit: None, + }; + let channel = get_channel(rule); + + let ok_response = check_access( + &redis, + &session, + Some(&auth), + &config.ip_rate_limit, + &channel, + &get_impression_events(1), + ) + .await; + + assert_eq!(Ok(()), ok_response); + } + + #[tokio::test] + #[ignore] + async fn applied_rules() { + let (config, redis) = setup().await; + + let auth = Auth { + era: 0, + uid: IDS["follower"], + }; + + let session = Session { + ip: Default::default(), + referrer_header: None, + country: None, + os: None, + }; + + let rule = Rule { + uids: None, + rate_limit: Some(RateLimit { + limit_type: "ip".to_string(), + time_frame: Duration::from_millis(60_000), + }), + }; + let channel = get_channel(rule); + + let ok_response = check_access( + &redis, + &session, + Some(&auth), + &config.ip_rate_limit, + &channel, + &get_impression_events(1), + ) + .await; + + assert_eq!(Ok(()), ok_response); + let key = "adexRateLimit:061d5e2a67d0a9a10f1c732bca12a676d83f79663a396f7d87b3e30b9b411088:" + .to_string(); + let value = "1".to_string(); + + let value_in_redis = redis::cmd("GET") + .arg(&key) + .query_async::<_, String>(&mut redis.clone()) + .await + .expect("should exist in redis"); + assert_eq!(&value, &value_in_redis); + } } diff --git a/sentry/src/db/channel.rs b/sentry/src/db/channel.rs index fe0e7a7d5..0722df45a 100644 --- a/sentry/src/db/channel.rs +++ b/sentry/src/db/channel.rs @@ -2,7 +2,7 @@ use crate::db::DbPool; use bb8::RunError; use chrono::Utc; use primitives::validator::MessageTypes; -use primitives::{Channel, ChannelId, ValidatorId}; +use primitives::{targeting::Rules, Channel, ChannelId, ValidatorId}; use std::str::FromStr; pub use list_channels::list_channels; @@ -72,6 +72,32 @@ pub async fn insert_channel( .await } +pub async fn update_targeting_rules( + pool: &DbPool, + channel_id: &ChannelId, + targeting_rules: &Rules, +) -> Result> { + pool.run(move |connection| async move { + match connection + .prepare("UPDATE channels SET targeting_rules=$1 WHERE id=$2") + .await + { + Ok(stmt) => match connection + .execute(&stmt, &[&targeting_rules, &channel_id]) + .await + { + Ok(row) => { + let updated = row == 1; + Ok((updated, connection)) + } + Err(e) => Err((e, connection)), + }, + Err(e) => Err((e, connection)), + } + }) + .await +} + pub async fn insert_validator_messages( pool: &DbPool, channel: &Channel, diff --git a/sentry/src/event_aggregator.rs b/sentry/src/event_aggregator.rs index 42084cdbd..8079627cd 100644 --- a/sentry/src/event_aggregator.rs +++ b/sentry/src/event_aggregator.rs @@ -1,8 +1,8 @@ use crate::access::check_access; use crate::access::Error as AccessError; use crate::db::event_aggregate::insert_event_aggregate; -use crate::db::get_channel_by_id; use crate::db::DbPool; +use crate::db::{get_channel_by_id, update_targeting_rules}; use crate::event_reducer; use crate::Application; use crate::ResponseError; @@ -138,11 +138,23 @@ impl EventAggregator { AccessError::OnlyCreatorCanCloseChannel | AccessError::ForbiddenReferrer => { ResponseError::Forbidden(e.to_string()) } + AccessError::OnlyCreatorCanUpdateTargetingRules => { + ResponseError::Forbidden(e.to_string()) + } AccessError::RulesError(error) => ResponseError::TooManyRequests(error), AccessError::UnAuthenticated => ResponseError::Unauthorized, _ => ResponseError::BadRequest(e.to_string()), })?; + let new_targeting_rules = events.iter().find_map(|ev| match ev { + Event::UpdateTargeting { targeting_rules } => Some(targeting_rules), + _ => None, + }); + + if let Some(new_rules) = new_targeting_rules { + update_targeting_rules(&app.pool, &channel_id, &new_rules).await?; + } + events.iter().for_each(|ev| { match event_reducer::reduce( &app.logger,