Skip to content

Commit

Permalink
Implement adding checks to buckets
Browse files Browse the repository at this point in the history
Don't ask about the horrendous code for this
  • Loading branch information
arqunis committed Jul 15, 2017
1 parent 4ce2ddf commit dc3a4df
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 24 deletions.
15 changes: 5 additions & 10 deletions src/framework/buckets.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,26 @@
use chrono::Utc;
use std::collections::HashMap;
use std::default::Default;
use client::Context;
use model::{GuildId, ChannelId, UserId};

pub(crate) struct Ratelimit {
pub delay: i64,
pub limit: Option<(i64, i32)>,
}

#[derive(Default)]
pub(crate) struct MemberRatelimit {
pub last_time: i64,
pub set_time: i64,
pub tickets: i32,
}

impl Default for MemberRatelimit {
fn default() -> Self {
MemberRatelimit {
last_time: 0,
set_time: 0,
tickets: 0,
}
}
}

pub(crate) struct Bucket {
pub ratelimit: Ratelimit,
pub users: HashMap<u64, MemberRatelimit>,
#[cfg(feature="cache")]
pub check: Option<Box<Fn(&mut Context, GuildId, ChannelId, UserId) -> bool + 'static>>,
}

impl Bucket {
Expand Down
117 changes: 103 additions & 14 deletions src/framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ use std::collections::HashMap;
use std::default::Default;
use std::sync::Arc;
use ::client::Context;
use ::model::{Message, MessageId, UserId, ChannelId, ReactionType};
use ::model::{Message, MessageId, UserId, GuildId, ChannelId, ReactionType};
use ::model::permissions::Permissions;
use ::utils;
use tokio_core::reactor::Handle;
Expand Down Expand Up @@ -313,14 +313,65 @@ impl Framework {
/// ```
pub fn bucket<S>(mut self, s: S, delay: i64, time_span: i64, limit: i32) -> Self
where S: Into<String> {
feature_cache! {{
self.buckets.insert(s.into(), Bucket {
ratelimit: Ratelimit {
delay: delay,
limit: Some((time_span, limit)),
},
users: HashMap::new(),
check: None,
});
} else {
self.buckets.insert(s.into(), Bucket {
ratelimit: Ratelimit {
delay: delay,
limit: Some((time_span, limit)),
},
users: HashMap::new(),
});
}}

self
}

/// Same as [`bucket`] but with a check added.
///
/// # Examples
///
/// ```rust
/// # use serenity::prelude::*;
/// # struct Handler;
/// #
/// # impl EventHandler for Handler {}
/// # let mut client = Client::new("token", Handler);
/// #
/// client.with_framework(|f| f
/// .complex_bucket("basic", 2, 10, 3, |_, guild_id, channel_id, user_id| {
/// // check if the guild is `123` and the channel where the command(s) was called: `456`
/// // and if the user who called the command(s) is `789`
/// // otherwise don't apply the bucket at all.
/// guild_id == 123 && channel_id == 456 && user_id == 789
/// })
/// .command("ping", |c| c
/// .bucket("basic")
/// .exec_str("pong!")));
/// ```
///
/// [`bucket`]: #method.bucket
#[cfg(feature="cache")]
pub fn complex_bucket<S, Check>(mut self, s: S, delay: i64, time_span: i64, limit: i32, check: Check) -> Self
where Check: Fn(&mut Context, GuildId, ChannelId, UserId) -> bool + 'static,
S: Into<String> {
self.buckets.insert(s.into(), Bucket {
ratelimit: Ratelimit {
delay: delay,
delay,
limit: Some((time_span, limit)),
},
users: HashMap::new(),
check: Some(Box::new(check)),
});

self
}

Expand All @@ -345,13 +396,24 @@ impl Framework {
/// ```
pub fn simple_bucket<S>(mut self, s: S, delay: i64) -> Self
where S: Into<String> {
self.buckets.insert(s.into(), Bucket {
ratelimit: Ratelimit {
delay: delay,
limit: None,
},
users: HashMap::new(),
});
feature_cache! {{
self.buckets.insert(s.into(), Bucket {
ratelimit: Ratelimit {
delay: delay,
limit: None,
},
users: HashMap::new(),
check: None,
});
} else {
self.buckets.insert(s.into(), Bucket {
ratelimit: Ratelimit {
delay: delay,
limit: None,
},
users: HashMap::new(),
});
}}

self
}
Expand Down Expand Up @@ -452,11 +514,37 @@ impl Framework {
} else if self.configuration.owners.contains(&message.author.id) {
None
} else {
if let Some(rate_limit) = command.bucket.clone().map(|x| self.ratelimit_time(x.as_str(), message.author.id.0)) {
if rate_limit > 0i64 {
return Some(DispatchError::RateLimited(rate_limit));
feature_cache! {{
if let Some(ref bucket) = command.bucket {
if let Some(ref mut bucket) = self.buckets.get_mut(bucket) {
let rate_limit = bucket.take(message.author.id.0);
match bucket.check {
Some(ref check) => {
if let Some(guild_id) = message.guild_id() {
if (check)(context, guild_id, message.channel_id, message.author.id) {
if rate_limit > 0i64 {
return Some(DispatchError::RateLimited(rate_limit));
}
} else {
return None;
}
}
},
None => {
if rate_limit > 0i64 {
return Some(DispatchError::RateLimited(rate_limit));
}
},
}
}
}
}
} else {
if let Some(rate_limit) = command.bucket.clone().map(|x| self.ratelimit_time(x.as_str(), message.author.id.0)) {
if rate_limit > 0i64 {
return Some(DispatchError::RateLimited(rate_limit));
}
}
}}

if let Some(x) = command.min_args {
if args < x as usize {
Expand Down Expand Up @@ -884,6 +972,7 @@ impl Framework {
self.user_info = (user_id.0, is_bot);
}

#[allow(dead_code)]
fn ratelimit_time(&mut self, bucket_name: &str, user_id: u64) -> i64 {
self.buckets
.get_mut(bucket_name)
Expand Down

0 comments on commit dc3a4df

Please sign in to comment.