From dc3a4dfafb1ee096b56c78d2506743e4012323f7 Mon Sep 17 00:00:00 2001 From: acdenisSK Date: Sat, 15 Jul 2017 23:01:12 +0200 Subject: [PATCH] Implement adding checks to buckets Don't ask about the horrendous code for this --- src/framework/buckets.rs | 15 ++--- src/framework/mod.rs | 117 ++++++++++++++++++++++++++++++++++----- 2 files changed, 108 insertions(+), 24 deletions(-) diff --git a/src/framework/buckets.rs b/src/framework/buckets.rs index 8fdaf3787c7..c4b81b64178 100644 --- a/src/framework/buckets.rs +++ b/src/framework/buckets.rs @@ -1,31 +1,26 @@ use chrono::Utc; use std::collections::HashMap; use std::default::Default; +use client::Context; +use model::{GuildId, ChannelId, UserId}; pub(crate) struct Ratelimit { pub delay: i64, pub limit: Option<(i64, i32)>, } +#[derive(Default)] pub(crate) struct MemberRatelimit { pub last_time: i64, pub set_time: i64, pub tickets: i32, } -impl Default for MemberRatelimit { - fn default() -> Self { - MemberRatelimit { - last_time: 0, - set_time: 0, - tickets: 0, - } - } -} - pub(crate) struct Bucket { pub ratelimit: Ratelimit, pub users: HashMap, + #[cfg(feature="cache")] + pub check: Option bool + 'static>>, } impl Bucket { diff --git a/src/framework/mod.rs b/src/framework/mod.rs index d68691edf7e..99d723633ce 100644 --- a/src/framework/mod.rs +++ b/src/framework/mod.rs @@ -74,7 +74,7 @@ use std::collections::HashMap; use std::default::Default; use std::sync::Arc; use ::client::Context; -use ::model::{Message, MessageId, UserId, ChannelId, ReactionType}; +use ::model::{Message, MessageId, UserId, GuildId, ChannelId, ReactionType}; use ::model::permissions::Permissions; use ::utils; use tokio_core::reactor::Handle; @@ -313,14 +313,65 @@ impl Framework { /// ``` pub fn bucket(mut self, s: S, delay: i64, time_span: i64, limit: i32) -> Self where S: Into { + feature_cache! {{ + self.buckets.insert(s.into(), Bucket { + ratelimit: Ratelimit { + delay: delay, + limit: Some((time_span, limit)), + }, + users: HashMap::new(), + check: None, + }); + } else { + self.buckets.insert(s.into(), Bucket { + ratelimit: Ratelimit { + delay: delay, + limit: Some((time_span, limit)), + }, + users: HashMap::new(), + }); + }} + + self + } + + /// Same as [`bucket`] but with a check added. + /// + /// # Examples + /// + /// ```rust + /// # use serenity::prelude::*; + /// # struct Handler; + /// # + /// # impl EventHandler for Handler {} + /// # let mut client = Client::new("token", Handler); + /// # + /// client.with_framework(|f| f + /// .complex_bucket("basic", 2, 10, 3, |_, guild_id, channel_id, user_id| { + /// // check if the guild is `123` and the channel where the command(s) was called: `456` + /// // and if the user who called the command(s) is `789` + /// // otherwise don't apply the bucket at all. + /// guild_id == 123 && channel_id == 456 && user_id == 789 + /// }) + /// .command("ping", |c| c + /// .bucket("basic") + /// .exec_str("pong!"))); + /// ``` + /// + /// [`bucket`]: #method.bucket + #[cfg(feature="cache")] + pub fn complex_bucket(mut self, s: S, delay: i64, time_span: i64, limit: i32, check: Check) -> Self + where Check: Fn(&mut Context, GuildId, ChannelId, UserId) -> bool + 'static, + S: Into { self.buckets.insert(s.into(), Bucket { ratelimit: Ratelimit { - delay: delay, + delay, limit: Some((time_span, limit)), }, users: HashMap::new(), + check: Some(Box::new(check)), }); - + self } @@ -345,13 +396,24 @@ impl Framework { /// ``` pub fn simple_bucket(mut self, s: S, delay: i64) -> Self where S: Into { - self.buckets.insert(s.into(), Bucket { - ratelimit: Ratelimit { - delay: delay, - limit: None, - }, - users: HashMap::new(), - }); + feature_cache! {{ + self.buckets.insert(s.into(), Bucket { + ratelimit: Ratelimit { + delay: delay, + limit: None, + }, + users: HashMap::new(), + check: None, + }); + } else { + self.buckets.insert(s.into(), Bucket { + ratelimit: Ratelimit { + delay: delay, + limit: None, + }, + users: HashMap::new(), + }); + }} self } @@ -452,11 +514,37 @@ impl Framework { } else if self.configuration.owners.contains(&message.author.id) { None } else { - if let Some(rate_limit) = command.bucket.clone().map(|x| self.ratelimit_time(x.as_str(), message.author.id.0)) { - if rate_limit > 0i64 { - return Some(DispatchError::RateLimited(rate_limit)); + feature_cache! {{ + if let Some(ref bucket) = command.bucket { + if let Some(ref mut bucket) = self.buckets.get_mut(bucket) { + let rate_limit = bucket.take(message.author.id.0); + match bucket.check { + Some(ref check) => { + if let Some(guild_id) = message.guild_id() { + if (check)(context, guild_id, message.channel_id, message.author.id) { + if rate_limit > 0i64 { + return Some(DispatchError::RateLimited(rate_limit)); + } + } else { + return None; + } + } + }, + None => { + if rate_limit > 0i64 { + return Some(DispatchError::RateLimited(rate_limit)); + } + }, + } + } } - } + } else { + if let Some(rate_limit) = command.bucket.clone().map(|x| self.ratelimit_time(x.as_str(), message.author.id.0)) { + if rate_limit > 0i64 { + return Some(DispatchError::RateLimited(rate_limit)); + } + } + }} if let Some(x) = command.min_args { if args < x as usize { @@ -884,6 +972,7 @@ impl Framework { self.user_info = (user_id.0, is_bot); } + #[allow(dead_code)] fn ratelimit_time(&mut self, bucket_name: &str, user_id: u64) -> i64 { self.buckets .get_mut(bucket_name)