diff --git a/sopel/bot.py b/sopel/bot.py index 8f7632b5b6..cd99cfbf1c 100644 --- a/sopel/bot.py +++ b/sopel/bot.py @@ -420,22 +420,46 @@ def reply(self, message, destination=None, reply_to=None, notice=False): def call(self, func, sopel, trigger): nick = trigger.nick + current_time = time.time() if nick not in self._times: self._times[nick] = dict() - - if not trigger.admin and \ - not func.unblockable and \ - func.rate > 0 and \ - func in self._times[nick]: - timediff = time.time() - self._times[nick][func] - if timediff < func.rate: - self._times[nick][func] = time.time() - LOGGER.info( - "%s prevented from using %s in %s: %d < %d", - trigger.nick, func.__name__, trigger.sender, timediff, - func.rate - ) - return + if self.nick not in self._times: + self._times[self.nick] = dict() + if not trigger.is_privmsg and trigger.sender not in self._times: + self._times[trigger.sender] = dict() + + if not trigger.admin and not func.unblockable: + if func in self._times[nick]: + usertimediff = current_time - self._times[nick][func] + if func.rate > 0 and usertimediff < func.rate: + #self._times[nick][func] = current_time + LOGGER.info( + "%s prevented from using %s in %s due to user limit: %d < %d", + trigger.nick, func.__name__, trigger.sender, usertimediff, + func.rate + ) + return + if func in self._times[self.nick]: + globaltimediff = current_time - self._times[self.nick][func] + if func.global_rate > 0 and globaltimediff < func.global_rate: + #self._times[self.nick][func] = current_time + LOGGER.info( + "%s prevented from using %s in %s due to global limit: %d < %d", + trigger.nick, func.__name__, trigger.sender, globaltimediff, + func.global_rate + ) + return + + if not trigger.is_privmsg and func in self._times[trigger.sender]: + chantimediff = current_time - self._times[trigger.sender][func] + if func.channel_rate > 0 and chantimediff < func.channel_rate: + #self._times[trigger.sender][func] = current_time + LOGGER.info( + "%s prevented from using %s in %s due to channel limit: %d < %d", + trigger.nick, func.__name__, trigger.sender, chantimediff, + func.channel_rate + ) + return try: exit_code = func(sopel, trigger) @@ -444,7 +468,10 @@ def call(self, func, sopel, trigger): self.error(trigger) if exit_code != NOLIMIT: - self._times[nick][func] = time.time() + self._times[nick][func] = current_time + self._times[self.nick][func] = current_time + if not trigger.is_privmsg: + self._times[trigger.sender][func] = current_time def dispatch(self, pretrigger): args = pretrigger.args diff --git a/sopel/loader.py b/sopel/loader.py index 2d15c84097..82d3d964d9 100644 --- a/sopel/loader.py +++ b/sopel/loader.py @@ -159,6 +159,8 @@ def clean_callable(func, config): func.priority = getattr(func, 'priority', 'medium') func.thread = getattr(func, 'thread', True) func.rate = getattr(func, 'rate', 0) + func.channel_rate = getattr(func, 'channel_rate', 0) + func.global_rate = getattr(func, 'global_rate', 0) if not hasattr(func, 'event'): func.event = ['PRIVMSG'] diff --git a/sopel/module.py b/sopel/module.py index 5aa4101ea9..bd5d748431 100644 --- a/sopel/module.py +++ b/sopel/module.py @@ -248,19 +248,21 @@ def add_attribute(function): return add_attribute -def rate(value): - """Decorate a function to limit how often a single user may trigger it. - - If a function is given a rate of 20, a single user may only use that - function once every 20 seconds. This limit applies to each user - individually. Users on the admin list in Sopel’s configuration are exempted - from rate limits. +def rate(user=0, channel=0, server=0): + """Decorate a function to limit how often it can be triggered on a per-user + basis, in a channel, or across the server (bot). A value of zero means no + limit. If a function is given a rate of 20, that function may only be used + once every 20 seconds in the scope corresponding to the parameter. + Users on the admin list in Sopel’s configuration are exempted from rate + limits. Rate-limited functions that use scheduled future commands should import threading.Timer() instead of sched, or rate limiting will not work properly. """ def add_attribute(function): - function.rate = value + function.rate = user + function.channel_rate = channel + function.global_rate = server return function return add_attribute