Skip to content

Commit

Permalink
Add global/command-specific configurable ratelimiting
Browse files Browse the repository at this point in the history
  • Loading branch information
zajrik committed Mar 15, 2017
1 parent ffd987e commit 472f247
Show file tree
Hide file tree
Showing 10 changed files with 325 additions and 27 deletions.
4 changes: 4 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ export { GuildStorageLoader } from './lib/storage/GuildStorageLoader';
export { GuildStorageRegistry } from './lib/storage/GuildStorageRegistry';
export { LocalStorage } from './lib/storage/LocalStorage';
export { Middleware } from './lib/command/middleware/Middleware';
export { RateLimiter } from './lib/command/RateLimiter';
export { RateLimit } from './lib/command/RateLimit';

export { Time } from './lib/Time';
export { Util } from './lib/Util';

export { ArgOpts } from './lib/types/ArgOpts';
Expand Down
72 changes: 72 additions & 0 deletions src/lib/Time.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import { Difference } from './types/Difference';

/**
* Extend the Date class to provide helper methods
*/
export class Time extends Date
{
public constructor() { super(); }

/**
* Return a Difference object representing the time difference between a and b
*/
public static difference(a: number, b: number): Difference
{
let difference: Difference = {};
let ms: number = a - b;
difference.ms = ms;

let days: number = Math.floor(ms / 1000 / 60 / 60 / 24);
ms -= days * 1000 * 60 * 60 * 24;
let hours: number = Math.floor(ms / 1000 / 60 / 60);
ms -= hours * 1000 * 60 * 60;
let mins: number = Math.floor(ms / 1000 / 60);
ms -= mins * 1000 * 60;
let secs: number = Math.floor(ms / 1000);

let timeString: string = '';
if (days) { difference.days = days; timeString += `${days} days${hours ? ', ' : ' '}`; }
if (hours) { difference.hours = hours; timeString += `${hours} hours${mins ? ', ' : ' '}`; }
if (mins) { difference.mins = mins; timeString += `${mins} mins${secs ? ', ' : ' '}`; }
if (secs) { difference.secs = secs; timeString += `${secs} secs`; }

// Returns the time string as '# days, # hours, # mins, # secs'
difference.toString = () => timeString.trim() || `${(ms / 1000).toFixed(2)} seconds`;

// Returns the time string as '#d #h #m #s'
difference.toSimplifiedString = () =>
timeString.replace(/ays|ours|ins|ecs| /g, '').replace(/,/g, ' ').trim();

return difference;
}

/**
* Return a Difference object (for convenience) measuring the
* duration of the given MS
*/
public static duration(time: number): Difference
{
return this.difference(time * 2, time);
}

/**
* Parse a duration shorthand string and return the duration in ms
*
* Shorthand examples: 10m, 5h, 1d
*/
public static parseShorthand(shorthand: string): number
{
let duration: number, match: RegExpMatchArray;
if (/^(?:\d+(?:\.\d+)?)[s|m|h|d]$/.test(<string> shorthand))
{
match = shorthand.match(/(\d+(?:\.\d+)?)(s|m|h|d)$/);
duration = parseFloat(match[1]);
duration = match[2] === 's'
? duration * 1000 : match[2] === 'm'
? duration * 1000 * 60 : match[2] === 'h'
? duration * 1000 * 60 * 60 : match[2] === 'd'
? duration * 1000 * 60 * 60 * 24 : null;
}
return duration;
}
}
6 changes: 6 additions & 0 deletions src/lib/bot/Bot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { Command } from '../command/Command';
import { CommandLoader } from '../command/CommandLoader';
import { CommandRegistry } from '../command/CommandRegistry';
import { CommandDispatcher } from '../command/CommandDispatcher';
import { RateLimiter } from '../command/RateLimiter';
import { MiddlewareFunction } from '../types/MiddlewareFunction';

/**
Expand All @@ -32,6 +33,7 @@ export class Bot extends Client
public disableBase: string[];
public config: any;
public _middleware: MiddlewareFunction[];
public _rateLimiter: RateLimiter;

public storage: LocalStorage;
public guildStorages: GuildStorageRegistry<string, GuildStorage>;
Expand Down Expand Up @@ -150,6 +152,10 @@ export class Bot extends Client
*/
this.disableBase = botOptions.disableBase || [];

// Create the global RateLimiter instance if a ratelimit is specified
if (botOptions.ratelimit)
this._rateLimiter = new RateLimiter(botOptions.ratelimit, true);

// Middleware function storage for the bot instance
this._middleware = [];

Expand Down
6 changes: 6 additions & 0 deletions src/lib/command/Command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { PermissionResolvable, Message } from 'discord.js';
import { Bot } from '../bot/Bot';
import { MiddlewareFunction } from '../types/MiddlewareFunction';
import { CommandInfo } from '../types/CommandInfo';
import { RateLimiter } from './RateLimiter';
import { ArgOpts } from '../types/ArgOpts';

/**
Expand All @@ -28,6 +29,7 @@ export class Command<T extends Bot>
public overloads: string;

public _classloc: string;
public _rateLimiter: RateLimiter;
public _middleware: MiddlewareFunction[];

public constructor(bot: T, info: CommandInfo = null)
Expand Down Expand Up @@ -195,6 +197,10 @@ export class Command<T extends Bot>
*/
this.overloads = info.overloads || null;

// Create the RateLimiter instance if a ratelimit is specified
if (info.ratelimit)
this._rateLimiter = new RateLimiter(info.ratelimit, false);

// Middleware function storage for the Command instance
this._middleware = [];

Expand Down
61 changes: 58 additions & 3 deletions src/lib/command/CommandDispatcher.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import { RateLimiter } from './RateLimiter';
import { PermissionResolvable, TextChannel, User } from 'discord.js';
import { MiddlewareFunction } from '../types/MiddlewareFunction';
import { Message } from '../types/Message';
import { GuildStorage } from '../storage/GuildStorage';
import { Command } from '../command/Command';
import { Bot } from '../bot/Bot';
import { RateLimit } from './RateLimit';
import { Time } from '../Time';
import now = require('performance-now');

/**
Expand All @@ -30,14 +33,19 @@ export class CommandDispatcher<T extends Bot>
if (this._bot.selfbot && message.author !== this._bot.user) return;
if (message.author.bot) return;

const dm: boolean = ['dm', 'group'].includes(message.channel.type);
const dm: boolean = message.channel.type !== 'text';
if (!dm) message.guild.storage = this._bot.guildStorages.get(message.guild);

// Check blacklist
if (this.isBlacklisted(message.author, message, dm)) return;

const [commandCalled, command, prefix, name]: [boolean, Command<T>, string, string] = this.isCommandCalled(message);
if (!commandCalled) return;

// Check ratelimits
if (!this.checkRateLimits(message, command)) return;

// Remove bot from message.mentions if only mentioned one time as a prefix
if (!(!dm && prefix === message.guild.storage.getSetting('prefix')) && prefix !== ''
&& (message.content.match(new RegExp(`<@!?${this._bot.user.id}>`, 'g')) || []).length === 1)
message.mentions.users.delete(this._bot.user.id);
Expand Down Expand Up @@ -92,7 +100,7 @@ export class CommandDispatcher<T extends Bot>
*/
private isCommandCalled(message: Message): [boolean, Command<T>, string, string]
{
const dm: boolean = ['dm', 'group'].includes(message.channel.type);
const dm: boolean = message.channel.type !== 'text';
const prefixes: string[] = [
`<@${this._bot.user.id}>`,
`<@!${this._bot.user.id}>`
Expand Down Expand Up @@ -123,7 +131,7 @@ export class CommandDispatcher<T extends Bot>
private testCommand(command: Command<T>, message: Message): boolean
{
const config: any = this._bot.config;
const dm: boolean = ['dm', 'group'].includes(message.channel.type);
const dm: boolean = message.channel.type !== 'text';
const storage: GuildStorage = !dm ? this._bot.guildStorages.get(message.guild) : null;

if (!dm && storage.settingExists('disabledGroups')
Expand All @@ -139,6 +147,53 @@ export class CommandDispatcher<T extends Bot>
return true;
}

/**
* Check either global or command-specific rate limits for the given
* message author and also notify them if they exceed ratelimits
*/
private checkRateLimiter(message: Message, command?: Command<T>): boolean
{
const rateLimiter: RateLimiter = command ? command._rateLimiter : this._bot._rateLimiter;
if (!rateLimiter) return true;

const rateLimit: RateLimit = rateLimiter.get(message);
if (!rateLimit.isLimited) return true;

if (!rateLimit.wasNotified)
{
const globalLimiter: RateLimiter = this._bot._rateLimiter;
const globalLimit: RateLimit = globalLimiter ? globalLimiter.get(message) : null;
if (globalLimit && globalLimit.isLimited && globalLimit.wasNotified) return;

rateLimit.setNotified();
if (!command) message.channel.send(
`You have tried to use too many commands and may not use any more for **${
Time.difference(rateLimit.expires, Date.now()).toString()}**.`);
else message.channel.send(
`You have tried to use this command too many times and may not use it again for **${
Time.difference(rateLimit.expires, Date.now()).toString()}**.`);
}
return false;
}

/**
* Check global and command-specific ratelimits for the user
* for the given command
*/
private checkRateLimits(message: Message, command: Command<T>): boolean
{
let passedGlobal: boolean = true;
let passedCommand: boolean = true;
let passedRateLimiters: boolean = true;
if (!this.checkRateLimiter(message)) passedGlobal = false;
if (!this.checkRateLimiter(message, command)) passedCommand = false;
if (!passedGlobal || !passedCommand) passedRateLimiters = false;
if (passedRateLimiters)
if (!(command && command._rateLimiter && !command._rateLimiter.get(message).call()) && this._bot._rateLimiter)
this._bot._rateLimiter.get(message).call();
return passedRateLimiters;
}

/**
* Compare user permissions to the command's requisites
*/
Expand Down
68 changes: 68 additions & 0 deletions src/lib/command/RateLimit.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/**
* Maintains its own call count and expiry for making sure
* things only happen a certain number of times within
* a given timeframe
*/
export class RateLimit
{
private _limit: number;
private _duration: number;
private _count: number;
private _notified: boolean;
public expires: number;
public constructor(limit: [number, number])
{
[this._limit, this._duration] = limit;
this._reset();
}

/**
* Sets this RateLimit to default values
*/
private _reset(): void
{
this._count = 0;
this.expires = 0;
this._notified = false;
}

/**
* Returns whether or not this rate limit has been capped out
* for its current expiry window while incrementing calls
* towards the rate limit cap if not currently capped
*/
public call(): boolean
{
if (this.expires < Date.now()) this._reset();
if (this._count >= this._limit) return false;
this._count++;
if (this._count === 1) this.expires = Date.now() + this._duration;
return true;
}

/**
* Return whether or not this ratelimit is currently capped out
*/
public get isLimited(): boolean
{
return (this._count >= this._limit) && (Date.now() < this.expires);
}

/**
* Flag this RateLimit as having had the user the RateLimit
* is for notified of being rate limited
*/
public setNotified(): void
{
this._notified = true;
}

/**
* Return whether or not this RateLimit was flagged after
* notifying the user of being rate limited
*/
public get wasNotified(): boolean
{
return this._notified;
}
}
71 changes: 71 additions & 0 deletions src/lib/command/RateLimiter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import { Collection } from 'discord.js';
import { RateLimit } from './RateLimit';
import { Time } from '../Time';
import { Message } from '../types/Message';

/**
* Handles assigning ratelimits to guildmembers and users
*/
export class RateLimiter
{
private _limit: [number, number];
private _global: boolean;
private _rateLimits: Collection<string, Collection<string, RateLimit>>;
private _globalLimits: Collection<string, RateLimit>;
public constructor(limit: string, global: boolean)
{
this._limit = this._parseLimit(limit);
this._global = global;

this._rateLimits = new Collection<string, Collection<string, RateLimit>>();
this._globalLimits = new Collection<string, RateLimit>();
}

/**
* Returns the RateLimit object for the message author if global
* or message member if message is in a guild
*/
public get(message: Message): RateLimit
{
if (this._isGlobal(message))
{
if (!this._globalLimits.has(message.author.id))
this._globalLimits.set(message.author.id, new RateLimit(this._limit));
return this._globalLimits.get(message.author.id);
}
else
{
if (!this._rateLimits.has(message.guild.id))
this._rateLimits.set(message.guild.id, new Collection<string, RateLimit>());

if (!this._rateLimits.get(message.guild.id).has(message.author.id))
this._rateLimits.get(message.guild.id).set(message.author.id, new RateLimit(this._limit));

return this._rateLimits.get(message.guild.id).get(message.author.id);
}
}

/**
* Parse the ratelimit from the given input string
*/
private _parseLimit(limitString: string): [number, number]
{
const limitRegex: RegExp = /^(\d+)\/(\d+)(s|m|h|d)?$/;
if (!limitRegex.test(limitString)) throw new Error(`Failed to parse a ratelimit from '${limitString}'`);
let [limit, duration, post]: (string | number)[] = limitRegex.exec(limitString).slice(1, 4);

if (post) duration = Time.parseShorthand(duration + post);
else duration = parseInt(duration);
limit = parseInt(limit);

return [limit, duration];
}

/**
* Determine whether or not to use the global rate limit collection
*/
private _isGlobal(message?: Message): boolean
{
return message ? message.channel.type !== 'text' || this._global : this._global;
}
}
Loading

0 comments on commit 472f247

Please sign in to comment.