diff --git a/Makefile b/Makefile index 7578bd058..98e76cd01 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ lint-style: flake8 sopel/ test/ lint-type: - mypy --check-untyped-defs sopel + mypy --check-untyped-defs --disallow-incomplete-defs sopel .PHONY: test test_norecord test_novcr vcr_rerecord test: diff --git a/NEWS b/NEWS index fe7647ed5..add20ec6e 100644 --- a/NEWS +++ b/NEWS @@ -2,6 +2,52 @@ This file is used to auto-generate the "Changelog" section of Sopel's website. When adding new entries, follow the style guide in NEWS.spec.md to avoid causing problems with the site build. +Changes between 8.0.0 and 8.0.1 +=============================== + +Plugin changes +-------------- + +* find: + * Fixed double-bold formatting [[#2589][]] + * Support escaping backslashes [[#2589][]] + +Core changes +------------ + +* Use distribution name to query version of entry-point plugins [[#2594][]] +* Added plugin version number in `sopel-plugins show` output [[#2638][]] +* Fixed loading folder-style plugins with relative imports [[#2633][]] +* Fixed rate-limiting behavior for rules without a rate limit [[#2629][]] +* `config.types.ChoiceAttribute` logs invalid values for debugging [[#2624][]] +* Also remove null (`\x00`) in `irc.utils.safe()` function [[#2620][]] + +Housekeeping changes +-------------------- + +* Document advanced tip about arbitrarily scheduling code [[#2617][]] +* Include `versionadded` notes for more methods in `irc.AbstractBot` [[#2642][]] +* Start moving from `typing.Optional` to the `| None` convention [[#2642][]] +* Minor updates to keep up with type-checking ecosystem [[#2614][], [#2628][]] +* Start checking for incomplete type annotations [[#2616][]] +* Added tests to `find` plugin [[#2589][]] +* Fixed slowdown in `@plugin.example` tests with `repeat` enabled [[#2630][]] + +[#2589]: https://github.com/sopel-irc/sopel/pull/2589 +[#2594]: https://github.com/sopel-irc/sopel/pull/2594 +[#2614]: https://github.com/sopel-irc/sopel/pull/2614 +[#2616]: https://github.com/sopel-irc/sopel/pull/2616 +[#2617]: https://github.com/sopel-irc/sopel/pull/2617 +[#2620]: https://github.com/sopel-irc/sopel/pull/2620 +[#2624]: https://github.com/sopel-irc/sopel/pull/2624 +[#2628]: https://github.com/sopel-irc/sopel/pull/2628 +[#2629]: https://github.com/sopel-irc/sopel/pull/2629 +[#2630]: https://github.com/sopel-irc/sopel/pull/2630 +[#2633]: https://github.com/sopel-irc/sopel/pull/2633 +[#2638]: https://github.com/sopel-irc/sopel/pull/2638 +[#2642]: https://github.com/sopel-irc/sopel/pull/2642 + + Changes between 7.1.9 and 8.0.0 =============================== diff --git a/docs/source/plugin/advanced.rst b/docs/source/plugin/advanced.rst index 0c2997e82..d07e77c2b 100644 --- a/docs/source/plugin/advanced.rst +++ b/docs/source/plugin/advanced.rst @@ -15,6 +15,45 @@ If something is not in here, feel free to ask about it on our IRC channel, or maybe open an issue with the solution if you devise one yourself. +Running a function on a schedule +================================ + +Sopel provides the :func:`@plugin.interval ` decorator +to run plugin callables periodically, but plugin developers semi-frequently ask +how to run a function at the same time every day/week. + +Integrating this kind of feature into Sopel's plugin API is trickier than one +might think, and it's actually simpler to have plugins just use a library like +`schedule`__ directly:: + + import schedule + + from sopel import plugin + + + def scheduled_message(bot): + bot.say("This is the scheduled message.", "#channelname") + + + def setup(bot): + # schedule the message at midnight every day + schedule.every().day.at('00:00').do(scheduled_message, bot=bot) + + + @plugin.interval(60) + def run_schedule(bot): + schedule.run_pending() + +As long as the ``bot`` is passed as an argument, the scheduled function can +access config settings or any other attributes/properties it needs. + +Multiple plugins all setting up their own checks with ``interval`` naturally +creates *some* overhead, but it shouldn't be significant compared to all the +other things happening inside a Sopel bot with numerous plugins. + +.. __: https://pypi.org/project/schedule/ + + Restricting commands to certain channels ======================================== diff --git a/pyproject.toml b/pyproject.toml index 2e1ccbce6..41793601e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ namespaces = false [project] name = "sopel" -version = "8.0.0" +version = "8.0.1" description = "Simple and extensible IRC bot" maintainers = [ { name="dgw" }, diff --git a/sopel/bot.py b/sopel/bot.py index cc908f247..d52e1707c 100644 --- a/sopel/bot.py +++ b/sopel/bot.py @@ -17,7 +17,9 @@ from types import MappingProxyType from typing import ( Any, + Callable, Optional, + Sequence, TYPE_CHECKING, TypeVar, Union, @@ -36,6 +38,8 @@ if TYPE_CHECKING: from collections.abc import Iterable, Mapping + + from sopel.plugins.handlers import AbstractPluginHandler from sopel.trigger import PreTrigger @@ -182,7 +186,6 @@ def hostmask(self) -> Optional[str]: :return: the bot's current hostmask if the bot is connected and in a least one channel; ``None`` otherwise - :rtype: Optional[str] """ if not self.users or self.nick not in self.users: # bot must be connected and in at least one channel @@ -198,11 +201,11 @@ def plugins(self) -> Mapping[str, plugins.handlers.AbstractPluginHandler]: """ return MappingProxyType(self._plugins) - def has_channel_privilege(self, channel, privilege) -> bool: + def has_channel_privilege(self, channel: str, privilege: int) -> bool: """Tell if the bot has a ``privilege`` level or above in a ``channel``. - :param str channel: a channel the bot is in - :param int privilege: privilege level to check + :param channel: a channel the bot is in + :param privilege: privilege level to check :raise ValueError: when the channel is unknown This method checks the bot's privilege level in a channel, i.e. if it @@ -339,10 +342,10 @@ def post_setup(self) -> None: # plugins management - def reload_plugin(self, name) -> None: + def reload_plugin(self, name: str) -> None: """Reload a plugin. - :param str name: name of the plugin to reload + :param name: name of the plugin to reload :raise plugins.exceptions.PluginNotRegistered: when there is no ``name`` plugin registered @@ -391,22 +394,24 @@ def reload_plugins(self) -> None: # TODO: deprecate both add_plugin and remove_plugin; see #2425 - def add_plugin(self, plugin, callables, jobs, shutdowns, urls) -> None: + def add_plugin( + self, + plugin: AbstractPluginHandler, + callables: Sequence[Callable], + jobs: Sequence[Callable], + shutdowns: Sequence[Callable], + urls: Sequence[Callable], + ) -> None: """Add a loaded plugin to the bot's registry. :param plugin: loaded plugin to add - :type plugin: :class:`sopel.plugins.handlers.AbstractPluginHandler` :param callables: an iterable of callables from the ``plugin`` - :type callables: :term:`iterable` :param jobs: an iterable of functions from the ``plugin`` that are periodically invoked - :type jobs: :term:`iterable` :param shutdowns: an iterable of functions from the ``plugin`` that should be called on shutdown - :type shutdowns: :term:`iterable` :param urls: an iterable of functions from the ``plugin`` to call when matched against a URL - :type urls: :term:`iterable` """ self._plugins[plugin.name] = plugin self.register_callables(callables) @@ -414,22 +419,24 @@ def add_plugin(self, plugin, callables, jobs, shutdowns, urls) -> None: self.register_shutdowns(shutdowns) self.register_urls(urls) - def remove_plugin(self, plugin, callables, jobs, shutdowns, urls) -> None: + def remove_plugin( + self, + plugin: AbstractPluginHandler, + callables: Sequence[Callable], + jobs: Sequence[Callable], + shutdowns: Sequence[Callable], + urls: Sequence[Callable], + ) -> None: """Remove a loaded plugin from the bot's registry. :param plugin: loaded plugin to remove - :type plugin: :class:`sopel.plugins.handlers.AbstractPluginHandler` :param callables: an iterable of callables from the ``plugin`` - :type callables: :term:`iterable` :param jobs: an iterable of functions from the ``plugin`` that are periodically invoked - :type jobs: :term:`iterable` :param shutdowns: an iterable of functions from the ``plugin`` that should be called on shutdown - :type shutdowns: :term:`iterable` :param urls: an iterable of functions from the ``plugin`` to call when matched against a URL - :type urls: :term:`iterable` """ name = plugin.name if not self.has_plugin(name): @@ -595,30 +602,26 @@ def rate_limit_info( if trigger.admin or rule.is_unblockable(): return False, None + nick = trigger.nick is_channel = trigger.sender and not trigger.sender.is_nick() channel = trigger.sender if is_channel else None at_time = trigger.time - - user_metrics = rule.get_user_metrics(trigger.nick) - channel_metrics = rule.get_channel_metrics(channel) - global_metrics = rule.get_global_metrics() - - if user_metrics.is_limited(at_time - rule.user_rate_limit): + if rule.is_user_rate_limited(nick, at_time): template = rule.user_rate_template rate_limit_type = "user" rate_limit = rule.user_rate_limit - metrics = user_metrics - elif is_channel and channel_metrics.is_limited(at_time - rule.channel_rate_limit): + metrics = rule.get_user_metrics(nick) + elif channel and rule.is_channel_rate_limited(channel, at_time): template = rule.channel_rate_template rate_limit_type = "channel" rate_limit = rule.channel_rate_limit - metrics = channel_metrics - elif global_metrics.is_limited(at_time - rule.global_rate_limit): + metrics = rule.get_channel_metrics(channel) + elif rule.is_global_rate_limited(at_time): template = rule.global_rate_template rate_limit_type = "global" rate_limit = rule.global_rate_limit - metrics = global_metrics + metrics = rule.get_global_metrics() else: return False, None @@ -993,12 +996,11 @@ def on_scheduler_error( self, scheduler: plugin_jobs.Scheduler, exc: BaseException, - ): + ) -> None: """Called when the Job Scheduler fails. :param scheduler: the job scheduler that errored - :type scheduler: :class:`sopel.plugins.jobs.Scheduler` - :param Exception exc: the raised exception + :param exc: the raised exception .. seealso:: @@ -1011,14 +1013,12 @@ def on_job_error( scheduler: plugin_jobs.Scheduler, job: tools_jobs.Job, exc: BaseException, - ): + ) -> None: """Called when a job from the Job Scheduler fails. :param scheduler: the job scheduler responsible for the errored ``job`` - :type scheduler: :class:`sopel.plugins.jobs.Scheduler` :param job: the Job that errored - :type job: :class:`sopel.tools.jobs.Job` - :param Exception exc: the raised exception + :param exc: the raised exception .. seealso:: @@ -1030,13 +1030,11 @@ def error( self, trigger: Optional[Trigger] = None, exception: Optional[BaseException] = None, - ): + ) -> None: """Called internally when a plugin causes an error. - :param trigger: the ``Trigger``\\ing line (if available) - :type trigger: :class:`sopel.trigger.Trigger` - :param Exception exception: the exception raised by the error (if - available) + :param trigger: the IRC line that caused the error (if available) + :param exception: the exception raised by the error (if available) """ message = 'Unexpected error' if exception: @@ -1056,7 +1054,7 @@ def error( def _host_blocked(self, host: str) -> bool: """Check if a hostname is blocked. - :param str host: the hostname to check + :param host: the hostname to check """ bad_masks = self.config.core.host_blocks for bad_mask in bad_masks: @@ -1071,7 +1069,7 @@ def _host_blocked(self, host: str) -> bool: def _nick_blocked(self, nick: str) -> bool: """Check if a nickname is blocked. - :param str nick: the nickname to check + :param nick: the nickname to check """ bad_nicks = self.config.core.nick_blocks for bad_nick in bad_nicks: diff --git a/sopel/builtins/calc.py b/sopel/builtins/calc.py index 50816e657..95791d5f7 100644 --- a/sopel/builtins/calc.py +++ b/sopel/builtins/calc.py @@ -42,8 +42,7 @@ def c(bot, trigger): # Account for the silly non-Anglophones and their silly radix point. eqn = trigger.group(2).replace(',', '.') try: - result = eval_equation(eqn) - result = "{:.10g}".format(result) + result = "{:.10g}".format(eval_equation(eqn)) except eval_equation.Error as err: bot.reply("Can't process expression: {}".format(str(err))) return diff --git a/sopel/builtins/dice.py b/sopel/builtins/dice.py index cc8f88bbd..f1eb00553 100644 --- a/sopel/builtins/dice.py +++ b/sopel/builtins/dice.py @@ -244,7 +244,7 @@ def _roll_dice(dice_match: re.Match[str]) -> DicePouch: @plugin.example(".roll 2d10+3", user_help=True) @plugin.example(".roll 1d6", user_help=True) @plugin.output_prefix('[dice] ') -def roll(bot: SopelWrapper, trigger: Trigger): +def roll(bot: SopelWrapper, trigger: Trigger) -> None: """Rolls dice and reports the result. The dice roll follows this format: XdY[vZ][+N][#COMMENT] diff --git a/sopel/builtins/find.py b/sopel/builtins/find.py index db6528252..67721d593 100644 --- a/sopel/builtins/find.py +++ b/sopel/builtins/find.py @@ -121,11 +121,11 @@ def kick_cleanup(bot, trigger): [:,]\s+)? # Followed by optional colon/comma and whitespace s(?P/) # The literal s and a separator / as group 2 (?P # Group 3 is the thing to find - (?:\\/|[^/])+ # One or more non-slashes or escaped slashes + (?:\\\\|\\/|[^/])+ # One or more non-slashes or escaped slashes ) / # The separator again (?P # Group 4 is what to replace with - (?:\\/|[^/])* # One or more non-slashes or escaped slashes + (?:\\\\|\\/|[^/])* # One or more non-slashes or escaped slashes ) (?:/ # Optional separator followed by group 5 (flags) (?P\S+) @@ -136,11 +136,11 @@ def kick_cleanup(bot, trigger): [:,]\s+)? # Followed by optional colon/comma and whitespace s(?P\|) # The literal s and a separator | as group 2 (?P # Group 3 is the thing to find - (?:\\\||[^|])+ # One or more non-pipe or escaped pipe + (?:\\\\|\\\||[^|])+ # One or more non-pipe or escaped pipe ) \| # The separator again (?P # Group 4 is what to replace with - (?:\\\||[^|])* # One or more non-pipe or escaped pipe + (?:\\\\|\\\||[^|])* # One or more non-pipe or escaped pipe ) (?:\| # Optional separator followed by group 5 (flags) (?P\S+) @@ -161,14 +161,16 @@ def findandreplace(bot, trigger): return sep = trigger.group('sep') - old = trigger.group('old').replace('\\%s' % sep, sep) + escape_sequence_pattern = re.compile(r'\\[\\%s]' % sep) + + old = escape_sequence_pattern.sub(decode_escape, trigger.group('old')) new = trigger.group('new') me = False # /me command flags = trigger.group('flags') or '' # only clean/format the new string if it's non-empty if new: - new = bold(new.replace('\\%s' % sep, sep)) + new = escape_sequence_pattern.sub(decode_escape, new) # If g flag is given, replace all. Otherwise, replace once. if 'g' in flags: @@ -181,39 +183,49 @@ def findandreplace(bot, trigger): if 'i' in flags: regex = re.compile(re.escape(old), re.U | re.I) - def repl(s): - return re.sub(regex, new, s, count == 1) + def repl(line, subst): + return re.sub(regex, subst, line, count == 1) else: - def repl(s): - return s.replace(old, new, count) + def repl(line, subst): + return line.replace(old, subst, count) # Look back through the user's lines in the channel until you find a line # where the replacement works - new_phrase = None + new_line = new_display = None for line in history: if line.startswith("\x01ACTION"): me = True # /me command line = line[8:] else: me = False - replaced = repl(line) + replaced = repl(line, new) if replaced != line: # we are done - new_phrase = replaced + new_line = replaced + new_display = repl(line, bold(new)) break - if not new_phrase: + if not new_line: return # Didn't find anything # Save the new "edited" message. action = (me and '\x01ACTION ') or '' # If /me message, prepend \x01ACTION - history.appendleft(action + new_phrase) # history is in most-recent-first order + history.appendleft(action + new_line) # history is in most-recent-first order # output if not me: - new_phrase = 'meant to say: %s' % new_phrase + new_display = 'meant to say: %s' % new_display if trigger.group(1): - phrase = '%s thinks %s %s' % (trigger.nick, rnick, new_phrase) + msg = '%s thinks %s %s' % (trigger.nick, rnick, new_display) else: - phrase = '%s %s' % (trigger.nick, new_phrase) + msg = '%s %s' % (trigger.nick, new_display) + + bot.say(msg) + - bot.say(phrase) +def decode_escape(match): + print("Substituting %s" % match.group(0)) + return { + r'\\': '\\', + r'\|': '|', + r'\/': '/', + }[match.group(0)] diff --git a/sopel/builtins/safety.py b/sopel/builtins/safety.py index 0522f54bf..bf47e7d6f 100644 --- a/sopel/builtins/safety.py +++ b/sopel/builtins/safety.py @@ -56,7 +56,7 @@ class SafetySection(types.StaticSection): """Optional hosts-file formatted domain blocklist to use instead of StevenBlack's.""" -def configure(settings: Config): +def configure(settings: Config) -> None: """ | name | example | purpose | | ---- | ------- | ------- | @@ -90,7 +90,7 @@ def configure(settings: Config): ) -def setup(bot: Sopel): +def setup(bot: Sopel) -> None: bot.settings.define_section("safety", SafetySection) if bot.settings.safety.default_mode is None: @@ -166,7 +166,7 @@ def download_domain_list(bot: Sopel, path: str) -> bool: return True -def update_local_cache(bot: Sopel, init: bool = False): +def update_local_cache(bot: Sopel, init: bool = False) -> None: """Download the current malware domain list and load it into memory. :param init: Load the file even if it's unchanged @@ -202,7 +202,7 @@ def update_local_cache(bot: Sopel, init: bool = False): bot.memory[SAFETY_CACHE_LOCAL_KEY] = unsafe_domains -def shutdown(bot: Sopel): +def shutdown(bot: Sopel) -> None: bot.memory.pop(SAFETY_CACHE_KEY, None) bot.memory.pop(SAFETY_CACHE_LOCAL_KEY, None) bot.memory.pop(SAFETY_CACHE_LOCK_KEY, None) @@ -211,7 +211,7 @@ def shutdown(bot: Sopel): @plugin.rule(r'(?u).*(https?://\S+).*') @plugin.priority('high') @plugin.output_prefix(PLUGIN_OUTPUT_PREFIX) -def url_handler(bot: SopelWrapper, trigger: Trigger): +def url_handler(bot: SopelWrapper, trigger: Trigger) -> None: """Checks for malicious URLs.""" mode = bot.db.get_channel_value( trigger.sender, @@ -365,7 +365,7 @@ def virustotal_lookup( @plugin.example(".virustotal https://malware.wicar.org/") @plugin.example(".virustotal hxxps://malware.wicar.org/") @plugin.output_prefix("[safety][VirusTotal] ") -def vt_command(bot: SopelWrapper, trigger: Trigger): +def vt_command(bot: SopelWrapper, trigger: Trigger) -> None: """Look up VT results on demand.""" if not bot.settings.safety.vt_api_key: bot.reply("Sorry, I don't have a VirusTotal API key configured.") @@ -421,7 +421,7 @@ def vt_command(bot: SopelWrapper, trigger: Trigger): @plugin.command('safety') @plugin.example(".safety on") @plugin.output_prefix(PLUGIN_OUTPUT_PREFIX) -def toggle_safety(bot: SopelWrapper, trigger: Trigger): +def toggle_safety(bot: SopelWrapper, trigger: Trigger) -> None: """Set safety setting for channel.""" if not trigger.admin and bot.channels[trigger.sender].privileges[trigger.nick] < plugin.OP: bot.reply('Only channel operators can change safety settings') @@ -455,7 +455,7 @@ def toggle_safety(bot: SopelWrapper, trigger: Trigger): # Clean the cache every day # Code above also calls this if there are too many cache entries @plugin.interval(24 * 60 * 60) -def _clean_cache(bot: Sopel): +def _clean_cache(bot: Sopel) -> None: """Cleans up old entries in URL safety cache.""" update_local_cache(bot) diff --git a/sopel/builtins/units.py b/sopel/builtins/units.py index 02316339a..31b61e634 100644 --- a/sopel/builtins/units.py +++ b/sopel/builtins/units.py @@ -9,10 +9,16 @@ from __future__ import annotations import re +from typing import Pattern, TYPE_CHECKING from sopel import plugin +if TYPE_CHECKING: + from sopel.bot import SopelWrapper + from sopel.trigger import Trigger + + PLUGIN_OUTPUT_PREFIX = '[units] ' find_temp = re.compile(r'(-?[0-9]*\.?[0-9]*)[ °]*(K|C|F)', re.IGNORECASE) @@ -20,23 +26,23 @@ find_mass = re.compile(r'([0-9]*\.?[0-9]*)[ ]*(lb|lbm|pound[s]?|ounce|oz|(?:kilo|)gram(?:me|)[s]?|[k]?g)', re.IGNORECASE) -def f_to_c(temp): +def f_to_c(temp: float) -> float: return (float(temp) - 32) * 5 / 9 -def c_to_k(temp): +def c_to_k(temp: float) -> float: return temp + 273.15 -def c_to_f(temp): +def c_to_f(temp: float) -> float: return (9.0 / 5.0 * temp + 32) -def k_to_c(temp): +def k_to_c(temp: float) -> float: return temp - 273.15 -def _extract_source(pattern, trigger) -> tuple[str, ...]: +def _extract_source(pattern: Pattern, trigger: Trigger) -> tuple[str, ...]: match = pattern.match(trigger.group(2)) if match: return match.groups() @@ -49,7 +55,7 @@ def _extract_source(pattern, trigger) -> tuple[str, ...]: @plugin.example('.temp 100C', '100.00°C = 212.00°F = 373.15K') @plugin.example('.temp 100K', '-173.15°C = -279.67°F = 100.00K') @plugin.output_prefix(PLUGIN_OUTPUT_PREFIX) -def temperature(bot, trigger): +def temperature(bot: SopelWrapper, trigger: Trigger) -> int | None: """Convert temperatures""" try: source = _extract_source(find_temp, trigger) @@ -71,7 +77,7 @@ def temperature(bot, trigger): if kelvin <= 0: bot.reply("Physically impossible temperature.") - return + return None bot.say("{:.2f}°C = {:.2f}°F = {:.2f}K".format( celsius, @@ -79,6 +85,8 @@ def temperature(bot, trigger): kelvin, )) + return None + @plugin.command('length', 'distance') @plugin.example('.distance 3m', '3.00m = 9 feet, 10.11 inches') @@ -92,7 +100,7 @@ def temperature(bot, trigger): @plugin.example('.length 3 au', '448793612.10km = 278867421.71 miles') @plugin.example('.length 3 parsec', '92570329129020.20km = 57520535754731.61 miles') @plugin.output_prefix(PLUGIN_OUTPUT_PREFIX) -def distance(bot, trigger): +def distance(bot: SopelWrapper, trigger: Trigger) -> int | None: """Convert distances""" try: source = _extract_source(find_length, trigger) @@ -160,10 +168,12 @@ def distance(bot, trigger): bot.say('{} = {}'.format(metric_part, stupid_part)) + return None + @plugin.command('weight', 'mass') @plugin.output_prefix(PLUGIN_OUTPUT_PREFIX) -def mass(bot, trigger): +def mass(bot: SopelWrapper, trigger: Trigger) -> int | None: """Convert mass""" try: source = _extract_source(find_mass, trigger) @@ -199,3 +209,5 @@ def mass(bot, trigger): stupid_part = '{:.2f} oz'.format(ounce) bot.say('{} = {}'.format(metric_part, stupid_part)) + + return None diff --git a/sopel/builtins/url.py b/sopel/builtins/url.py index 7383e0211..941dae7ab 100644 --- a/sopel/builtins/url.py +++ b/sopel/builtins/url.py @@ -76,7 +76,7 @@ class UrlSection(types.StaticSection): """Enable requests to private and local network IP addresses""" -def configure(config: Config): +def configure(config: Config) -> None: """ | name | example | purpose | | ---- | ------- | ------- | @@ -111,7 +111,7 @@ def configure(config: Config): ) -def setup(bot: Sopel): +def setup(bot: Sopel) -> None: bot.config.define_section('url', UrlSection) if bot.config.url.exclude: @@ -140,7 +140,7 @@ def setup(bot: Sopel): bot.memory['shortened_urls'] = tools.SopelMemory() -def shutdown(bot: Sopel): +def shutdown(bot: Sopel) -> None: # Unset `url_exclude` and `last_seen_url`, but not `shortened_urls`; # clearing `shortened_urls` will increase API calls. Leaving it in memory # should not lead to unexpected behavior. @@ -151,7 +151,7 @@ def shutdown(bot: Sopel): pass -def _user_can_change_excludes(bot: SopelWrapper, trigger: Trigger): +def _user_can_change_excludes(bot: SopelWrapper, trigger: Trigger) -> bool: if trigger.admin: return True @@ -169,7 +169,7 @@ def _user_can_change_excludes(bot: SopelWrapper, trigger: Trigger): @plugin.example('.urlpexclude example\\.com/\\w+', user_help=True) @plugin.example('.urlexclude example.com/path', user_help=True) @plugin.output_prefix('[url] ') -def url_ban(bot: SopelWrapper, trigger: Trigger): +def url_ban(bot: SopelWrapper, trigger: Trigger) -> None: """Exclude a URL from auto title. Use ``urlpexclude`` to exclude a pattern instead of a URL. @@ -220,7 +220,7 @@ def url_ban(bot: SopelWrapper, trigger: Trigger): @plugin.example('.urlpallow example\\.com/\\w+', user_help=True) @plugin.example('.urlallow example.com/path', user_help=True) @plugin.output_prefix('[url] ') -def url_unban(bot: SopelWrapper, trigger: Trigger): +def url_unban(bot: SopelWrapper, trigger: Trigger) -> None: """Allow a URL for auto title. Use ``urlpallow`` to allow a pattern instead of a URL. @@ -273,7 +273,7 @@ def url_unban(bot: SopelWrapper, trigger: Trigger): 'Google | www.google.com', online=True, vcr=True) @plugin.output_prefix('[url] ') -def title_command(bot: SopelWrapper, trigger: Trigger): +def title_command(bot: SopelWrapper, trigger: Trigger) -> None: """ Show the title or URL information for the given URL, or the last URL seen in this channel. @@ -313,7 +313,7 @@ def title_command(bot: SopelWrapper, trigger: Trigger): @plugin.rule(r'(?u).*(https?://\S+).*') @plugin.output_prefix('[url] ') -def title_auto(bot: SopelWrapper, trigger: Trigger): +def title_auto(bot: SopelWrapper, trigger: Trigger) -> None: """ Automatically show titles for URLs. For shortened URLs/redirects, find where the URL redirects to and show the title for that. @@ -437,7 +437,10 @@ def process_urls( except ValueError: # Extra try/except here in case the DNS resolution fails, see #2348 try: - ips = [ip_address(ip) for ip in dns.resolver.resolve(parsed_url.hostname)] + ips = [ + ip_address(ip.to_text()) + for ip in dns.resolver.resolve(parsed_url.hostname) + ] except Exception as exc: LOGGER.debug( "Cannot resolve hostname %s, ignoring URL %s" @@ -472,7 +475,11 @@ def process_urls( yield URLInfo(url, title, parsed_url.hostname, tinyurl, False) -def check_callbacks(bot: SopelWrapper, url: str, use_excludes: bool = True) -> bool: +def check_callbacks( + bot: SopelWrapper, + url: str, + use_excludes: bool = True, +) -> bool: """Check if ``url`` is excluded or matches any URL callback patterns. :param bot: Sopel instance diff --git a/sopel/cli/plugins.py b/sopel/cli/plugins.py index 01d09dbc5..78a07b3cc 100644 --- a/sopel/cli/plugins.py +++ b/sopel/cli/plugins.py @@ -283,6 +283,7 @@ def handle_show(options): }) print('Plugin:', description['name']) + print('Version:', description['version'] or 'unknown') print('Status:', description['status']) print('Type:', description['type']) print('Source:', description['source']) diff --git a/sopel/config/types.py b/sopel/config/types.py index 73f7a8b33..569995eca 100644 --- a/sopel/config/types.py +++ b/sopel/config/types.py @@ -649,7 +649,10 @@ def parse(self, value): if value in self.choices: return value else: - raise ValueError('Value must be in {}'.format(self.choices)) + raise ValueError( + '{!r} is not one of the valid choices: {}' + .format(value, ', '.join(self.choices)) + ) def serialize(self, value): """Make sure ``value`` is valid and safe to write in the config file. @@ -662,7 +665,10 @@ def serialize(self, value): if value in self.choices: return value else: - raise ValueError('Value must be in {}'.format(self.choices)) + raise ValueError( + '{!r} is not one of the valid choices: {}' + .format(value, ', '.join(self.choices)) + ) class FilenameAttribute(BaseValidated): diff --git a/sopel/coretasks.py b/sopel/coretasks.py index f10ebdbe2..93b11dbd2 100644 --- a/sopel/coretasks.py +++ b/sopel/coretasks.py @@ -161,7 +161,7 @@ def _handle_sasl_capability( CAP_SASL = plugin.capability('sasl', handler=_handle_sasl_capability) -def setup(bot: Sopel): +def setup(bot: Sopel) -> None: """Set up the coretasks plugin. The setup phase is used to activate the throttle feature to prevent a flood @@ -1261,7 +1261,7 @@ def _make_sasl_plain_token(account, password): @plugin.thread(False) @plugin.unblockable @plugin.priority('medium') -def sasl_success(bot: SopelWrapper, trigger: Trigger): +def sasl_success(bot: SopelWrapper, trigger: Trigger) -> None: """Resume capability negotiation on successful SASL auth.""" LOGGER.info("Successful SASL Auth.") bot.resume_capability_negotiation(CAP_SASL.cap_req, 'coretasks') @@ -1514,7 +1514,7 @@ def _record_who( away: Optional[bool] = None, is_bot: Optional[bool] = None, modes: Optional[str] = None, -): +) -> None: nick = bot.make_identifier(nick) channel = bot.make_identifier(channel) if nick not in bot.users: diff --git a/sopel/db.py b/sopel/db.py index 8deaa8035..101300351 100644 --- a/sopel/db.py +++ b/sopel/db.py @@ -34,6 +34,8 @@ if typing.TYPE_CHECKING: from collections.abc import Iterable + from sopel.config import Config + LOGGER = logging.getLogger(__name__) @@ -142,7 +144,7 @@ class SopelDB: def __init__( self, - config, + config: Config, identifier_factory: IdentifierFactory = Identifier, ) -> None: self.make_identifier: IdentifierFactory = identifier_factory @@ -628,7 +630,7 @@ def forget_nick_group(self, nick: str) -> None: def delete_nick_group(self, nick: str) -> None: # pragma: nocover self.forget_nick_group(nick) - def merge_nick_groups(self, first_nick: str, second_nick: str): + def merge_nick_groups(self, first_nick: str, second_nick: str) -> None: """Merge two nick groups. :param first_nick: one nick in the first group to merge @@ -788,7 +790,7 @@ def get_channel_value( channel: str, key: str, default: typing.Optional[typing.Any] = None, - ): + ) -> typing.Any: """Get a value from the key-value store for ``channel``. :param channel: the channel whose values to access @@ -980,8 +982,8 @@ def get_nick_or_channel_value( self, name: str, key: str, - default=None - ) -> typing.Optional[typing.Any]: + default: typing.Any | None = None + ) -> typing.Any | None: """Get a value from the key-value store for ``name``. :param name: nick or channel whose values to access diff --git a/sopel/irc/__init__.py b/sopel/irc/__init__.py index e2a3b1fa8..181940db9 100644 --- a/sopel/irc/__init__.py +++ b/sopel/irc/__init__.py @@ -39,7 +39,6 @@ import time from typing import ( Any, - Optional, TYPE_CHECKING, ) @@ -71,20 +70,23 @@ def __init__(self, settings: Config): self._name: str = settings.core.name self._isupport = ISupport() self._capabilities = Capabilities() - self._myinfo: Optional[MyInfo] = None + self._myinfo: MyInfo | None = None self._nick: identifiers.Identifier = self.make_identifier( settings.core.nick) self.backend: AbstractIRCBackend = UninitializedBackend(self) """IRC Connection Backend.""" self._connection_registered = threading.Event() - """Flag stating whether the IRC Connection is registered yet.""" + """Flag stating whether the IRC connection is registered yet.""" self.settings = settings - """Bot settings.""" + """The bot's settings. + + .. versionadded:: 7.0 + """ # internal machinery self.sending = threading.RLock() - self.last_error_timestamp: Optional[datetime] = None + self.last_error_timestamp: datetime | None = None self.error_count = 0 self.stack: dict[identifiers.Identifier, dict[str, Any]] = {} self.hasquit = False @@ -136,7 +138,10 @@ def config(self) -> Config: @property def capabilities(self) -> Capabilities: - """Capabilities negotiated with the server.""" + """Capabilities negotiated with the server. + + .. versionadded:: 8.0 + """ return self._capabilities @property @@ -174,7 +179,7 @@ def enabled_capabilities(self) -> set[str]: warning_in='8.1', removed_in='9.0', ) - def server_capabilities(self) -> dict[str, Optional[str]]: + def server_capabilities(self) -> dict[str, str | None]: """A dict mapping supported IRCv3 capabilities to their options. For example, if the server specifies the capability ``sasl=EXTERNAL``, @@ -203,7 +208,7 @@ def server_capabilities(self) -> dict[str, Optional[str]]: def isupport(self) -> ISupport: """Features advertised by the server. - :type: :class:`~.isupport.ISupport` instance + .. versionadded:: 7.0 """ return self._isupport @@ -211,8 +216,6 @@ def isupport(self) -> ISupport: def myinfo(self) -> MyInfo: """Server/network information. - :type: :class:`~.utils.MyInfo` instance - .. versionadded:: 7.0 """ if self._myinfo is None: @@ -221,7 +224,7 @@ def myinfo(self) -> MyInfo: @property @abc.abstractmethod - def hostmask(self) -> Optional[str]: + def hostmask(self) -> str | None: """The bot's hostmask.""" # Utility @@ -293,6 +296,8 @@ def safe_text_length(self, recipient: str) -> int: can be sent using ``PRIVMSG`` or ``NOTICE`` by subtracting the size required by the server to convey the bot's message. + .. versionadded:: 8.0 + .. seealso:: This method is useful when sending a message using :meth:`say`, @@ -337,13 +342,11 @@ def get_irc_backend( self, host: str, port: int, - source_address: Optional[tuple[str, int]], + source_address: tuple[str, int] | None, ) -> AbstractIRCBackend: """Set up the IRC backend based on the bot's settings. :return: the initialized IRC backend object - :rtype: an object implementing the interface of - :class:`~sopel.irc.abstract_backends.AbstractIRCBackend` """ timeout = int(self.settings.core.timeout) ping_interval = int(self.settings.core.timeout_ping_interval) @@ -369,8 +372,8 @@ def get_irc_backend( def run(self, host: str, port: int = 6667) -> None: """Connect to IRC server and run the bot forever. - :param str host: the IRC server hostname - :param int port: the IRC server port + :param host: the IRC server hostname + :param port: the IRC server port """ source_address = ((self.settings.core.bind_host, 0) if self.settings.core.bind_host else None) @@ -412,7 +415,7 @@ def on_connect(self) -> None: def on_message(self, message: str) -> None: """Handle an incoming IRC message. - :param str message: the received raw IRC message + :param message: the received raw IRC message """ if self.backend is None: raise RuntimeError(ERR_BACKEND_NOT_INITIALIZED) @@ -443,7 +446,7 @@ def on_message(self, message: str) -> None: def on_message_sent(self, raw: str) -> None: """Handle any message sent through the connection. - :param str raw: raw text message sent through the connection + :param raw: raw text message sent through the connection When a message is sent through the IRC connection, the bot will log the raw message. If necessary, it will also simulate the @@ -519,13 +522,17 @@ def rebuild_nick(self) -> None: This method exists to update the casemapping rules for the :class:`~sopel.tools.identifiers.Identifier` that represents the bot's nick, e.g. after ISUPPORT info is received. + + .. versionadded:: 8.0 """ self._nick = self.make_identifier(str(self._nick)) def change_current_nick(self, new_nick: str) -> None: """Change the current nick without configuration modification. - :param str new_nick: new nick to be used by the bot + :param new_nick: new nick to be used by the bot + + .. versionadded:: 7.1 """ if self.backend is None: raise RuntimeError(ERR_BACKEND_NOT_INITIALIZED) @@ -548,11 +555,10 @@ def _shutdown(self) -> None: # Features @abc.abstractmethod - def dispatch(self, pretrigger: trigger.PreTrigger): + def dispatch(self, pretrigger: trigger.PreTrigger) -> None: """Handle running the appropriate callables for an incoming message. :param pretrigger: Sopel PreTrigger object - :type pretrigger: :class:`sopel.trigger.PreTrigger` .. important:: This method **MUST** be implemented by concrete subclasses. @@ -561,8 +567,8 @@ def dispatch(self, pretrigger: trigger.PreTrigger): def log_raw(self, line: str, prefix: str) -> None: """Log raw line to the raw log. - :param str line: the raw line - :param str prefix: additional information to prepend to the log line + :param line: the raw line + :param prefix: additional information to prepend to the log line The ``prefix`` is usually either ``>>`` for an outgoing ``line`` or ``<<`` for a received one. @@ -572,7 +578,7 @@ def log_raw(self, line: str, prefix: str) -> None: logger = logging.getLogger('sopel.raw') logger.info("%s\t%r", prefix, line) - def write(self, args: Iterable[str], text: Optional[str] = None) -> None: + def write(self, args: Iterable[str], text: str | None = None) -> None: """Send a command to the server. :param args: an iterable of strings, which will be joined by spaces @@ -611,19 +617,19 @@ def write(self, args: Iterable[str], text: Optional[str] = None) -> None: def action(self, text: str, dest: str) -> None: """Send a CTCP ACTION PRIVMSG to a user or channel. - :param str text: the text to send in the CTCP ACTION - :param str dest: the destination of the CTCP ACTION + :param text: the text to send in the CTCP ACTION + :param dest: the destination of the CTCP ACTION The same loop detection and length restrictions apply as with :func:`say`, though automatic message splitting is not available. """ self.say('\001ACTION {}\001'.format(text), dest) - def join(self, channel: str, password: Optional[str] = None) -> None: + def join(self, channel: str, password: str | None = None) -> None: """Join a ``channel``. - :param str channel: the channel to join - :param str password: an optional channel password + :param channel: the channel to join + :param password: an optional channel password If ``channel`` contains a space, and no ``password`` is given, the space is assumed to split the argument into the channel to join and its @@ -639,7 +645,7 @@ def kick( self, nick: str, channel: str, - text: Optional[str] = None, + text: str | None = None, ) -> None: """Kick a ``nick`` from a ``channel``. @@ -667,7 +673,7 @@ def notice(self, text: str, dest: str) -> None: self.backend.send_notice(dest, text) - def part(self, channel: str, msg: Optional[str] = None) -> None: + def part(self, channel: str, msg: str | None = None) -> None: """Leave a channel. :param channel: the channel to leave @@ -678,7 +684,7 @@ def part(self, channel: str, msg: Optional[str] = None) -> None: self.backend.send_part(channel, reason=msg) - def quit(self, message: Optional[str] = None) -> None: + def quit(self, message: str | None = None) -> None: """Disconnect from IRC and close the bot. :param message: optional QUIT message to send (e.g. "Bye!") @@ -697,7 +703,7 @@ def quit(self, message: Optional[str] = None) -> None: # problematic because whomever called quit might still want to do # something before the main thread quits. - def restart(self, message: Optional[str] = None) -> None: + def restart(self, message: str | None = None) -> None: """Disconnect from IRC and restart the bot. :param message: optional QUIT message to send (e.g. "Be right back!") diff --git a/sopel/irc/backends.py b/sopel/irc/backends.py index 5c5bf9e9d..afa0a80b9 100644 --- a/sopel/irc/backends.py +++ b/sopel/irc/backends.py @@ -20,7 +20,7 @@ import socket import ssl import threading -from typing import Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING from .abstract_backends import AbstractIRCBackend @@ -158,7 +158,7 @@ def __init__( ca_certs: Optional[str] = None, ssl_ciphers: Optional[list[str]] = None, ssl_minimum_version: ssl.TLSVersion = ssl.TLSVersion.TLSv1_2, - **kwargs, + **kwargs: Any, ): super().__init__(bot) # connection parameters @@ -379,7 +379,7 @@ def get_connection_kwargs(self) -> dict: } async def _connect_to_server( - self, **connection_kwargs + self, **connection_kwargs: Any, ) -> tuple[ Optional[asyncio.StreamReader], Optional[asyncio.StreamWriter], diff --git a/sopel/irc/modes.py b/sopel/irc/modes.py index 694eb0d45..fb1ec74ac 100755 --- a/sopel/irc/modes.py +++ b/sopel/irc/modes.py @@ -100,7 +100,7 @@ class ModeException(Exception): class ModeTypeUnknown(ModeException): """Exception when a mode's type is unknown or cannot be determined.""" - def __init__(self, mode) -> None: + def __init__(self, mode: str) -> None: super().__init__('Unknown type for mode %s' % mode) diff --git a/sopel/irc/utils.py b/sopel/irc/utils.py index 12ba09461..a812f9687 100644 --- a/sopel/irc/utils.py +++ b/sopel/irc/utils.py @@ -17,27 +17,32 @@ from sopel.lifecycle import deprecated -def safe(string): - """Remove newlines from a string. +def safe(string: str) -> str: + """Remove disallowed bytes from a string, and ensure Unicode. - :param str string: input text to process - :return: the string without newlines - :rtype: str + :param string: input text to process + :return: the string as Unicode without characters prohibited in IRC messages :raises TypeError: when ``string`` is ``None`` - This function removes newlines from a string and always returns a unicode - string (``str``), but doesn't strip or alter it in any other way:: + This function removes newlines and null bytes from a string. It will always + return a Unicode ``str``, even if given non-Unicode input, but doesn't strip + or alter the string in any other way:: - >>> safe('some text\\r\\n') + >>> safe('some \\x00text\\r\\n') 'some text' - This is useful to ensure a string can be used in a IRC message. + This is useful to ensure a string can be used in a IRC message. Parameters + can **never** contain NUL, CR, or LF octets, per :rfc:`2812#section-2.3.1`. .. versionchanged:: 7.1 This function now raises a :exc:`TypeError` instead of an unpredictable behaviour when given ``None``. + .. versionchanged:: 8.0.1 + + Also remove NUL (``\\x00``) in addition to CR/LF. + """ if string is None: raise TypeError('safe function requires a string, not NoneType') @@ -45,6 +50,7 @@ def safe(string): string = string.decode("utf8") string = string.replace('\n', '') string = string.replace('\r', '') + string = string.replace('\x00', '') return string diff --git a/sopel/lifecycle.py b/sopel/lifecycle.py index fbd897bad..27f6565d9 100644 --- a/sopel/lifecycle.py +++ b/sopel/lifecycle.py @@ -13,7 +13,7 @@ import inspect import logging import traceback -from typing import Callable, Optional +from typing import Callable from packaging.version import parse as parse_version @@ -21,13 +21,13 @@ def deprecated( - reason: Optional[str] = None, - version: Optional[str] = None, - removed_in: Optional[str] = None, - warning_in: Optional[str] = None, + reason: str | Callable | None = None, + version: str | None = None, + removed_in: str | None = None, + warning_in: str | None = None, stack_frame: int = -1, - func: Optional[Callable] = None, -): + func: Callable | None = None, +) -> Callable: """Decorator to mark deprecated functions in Sopel's API :param reason: optional text added to the deprecation warning diff --git a/sopel/plugin.py b/sopel/plugin.py index ba4dfc72b..17d1f8cd4 100644 --- a/sopel/plugin.py +++ b/sopel/plugin.py @@ -1193,10 +1193,6 @@ def rate( @rate(10, 10, 2, message='Sorry {nick}, you hit the {rate_limit_type} rate limit!') - Rate-limited functions that use scheduled future commands should import - :class:`threading.Timer` instead of :mod:`sched`, or rate limiting will - not work properly. - .. versionchanged:: 8.0 Optional keyword argument ``message`` was added in Sopel 8. diff --git a/sopel/plugins/handlers.py b/sopel/plugins/handlers.py index 8a07cb3ad..eee1c056d 100644 --- a/sopel/plugins/handlers.py +++ b/sopel/plugins/handlers.py @@ -48,6 +48,7 @@ import importlib.util import inspect import itertools +import logging import os import sys from typing import Optional, TYPE_CHECKING, TypedDict @@ -61,6 +62,9 @@ from types import ModuleType +LOGGER = logging.getLogger(__name__) + + class PluginMetaDescription(TypedDict): """Meta description of a plugin, as a dictionary. @@ -97,6 +101,13 @@ class AbstractPluginHandler(abc.ABC): on shutdown (either upon exiting Sopel or unloading that plugin). """ + name: str + """Plugin identifier. + + The name of a plugin identifies this plugin: when Sopel loads a plugin, + it will store its information under that identifier. + """ + @abc.abstractmethod def load(self): """Load the plugin. @@ -471,7 +482,7 @@ def __init__(self, filename): spec = importlib.util.spec_from_file_location( name, os.path.join(filename, '__init__.py'), - submodule_search_locations=filename, + submodule_search_locations=[filename], ) else: raise exceptions.PluginError('Invalid Sopel plugin: %s' % filename) @@ -487,9 +498,9 @@ def __init__(self, filename): def _load(self): module = importlib.util.module_from_spec(self.module_spec) - sys.modules[self.name] = module if not self.module_spec.loader: raise exceptions.PluginError('Could not determine loader for plugin: %s' % self.filename) + sys.modules[self.name] = module self.module_spec.loader.exec_module(module) return module @@ -613,14 +624,20 @@ def get_version(self) -> Optional[str]: if ( version is None - and hasattr(self.module, "__package__") - and self.module.__package__ is not None + and hasattr(self.entry_point, "dist") + and hasattr(self.entry_point.dist, "name") ): + dist_name = self.entry_point.dist.name try: - version = importlib.metadata.version(self.module.__package__) - except ValueError: - # package name is probably empty-string; just give up - pass + version = importlib.metadata.version(dist_name) + except (ValueError, importlib.metadata.PackageNotFoundError): + LOGGER.warning("Cannot determine version of %r", dist_name) + except Exception: + LOGGER.warning( + "Unexpected error occurred while checking the version of %r", + dist_name, + exc_info=True, + ) return version diff --git a/sopel/plugins/rules.py b/sopel/plugins/rules.py index 76832b837..6276f27da 100644 --- a/sopel/plugins/rules.py +++ b/sopel/plugins/rules.py @@ -26,6 +26,7 @@ import threading from typing import ( Any, + Callable, Optional, Type, TYPE_CHECKING, @@ -39,7 +40,11 @@ if TYPE_CHECKING: from collections.abc import Generator, Iterable + + from sopel.bot import Sopel + from sopel.config import Config from sopel.tools.identifiers import Identifier + from sopel.trigger import PreTrigger __all__ = [ @@ -541,14 +546,16 @@ class AbstractRule(abc.ABC): """ @classmethod @abc.abstractmethod - def from_callable(cls: Type[TypedRule], settings, handler) -> TypedRule: + def from_callable( + cls: Type[TypedRule], + settings: Config, + handler: Callable, + ) -> TypedRule: """Instantiate a rule object from ``settings`` and ``handler``. :param settings: Sopel's settings - :type settings: :class:`sopel.config.Config` - :param callable handler: a function-based rule handler + :param handler: a function-based rule handler :return: an instance of this class created from the ``handler`` - :rtype: :class:`AbstractRule` Sopel's function-based rule handlers are simple callables, decorated with :mod:`sopel.plugin`'s decorators to add attributes, such as rate @@ -580,8 +587,6 @@ def priority_scale(self): def get_plugin_name(self) -> str: """Get the rule's plugin name. - :rtype: str - The rule's plugin name will be used in various places to select, register, unregister, and manipulate the rule based on its plugin, which is referenced by its name. @@ -591,8 +596,6 @@ def get_plugin_name(self) -> str: def get_rule_label(self) -> str: """Get the rule's label. - :rtype: str - A rule can have a label, which can identify the rule by string, the same way a plugin can be identified by its name. This label can be used to select, register, unregister, and manipulate the rule based on its @@ -603,8 +606,6 @@ def get_rule_label(self) -> str: def get_usages(self) -> tuple: """Get the rule's usage examples. - :rtype: tuple - A rule can have usage examples, i.e. a list of examples showing how the rule can be used, or in what context it can be triggered. """ @@ -613,8 +614,6 @@ def get_usages(self) -> tuple: def get_test_parameters(self) -> tuple: """Get parameters for automated tests. - :rtype: tuple - A rule can have automated tests attached to it, and this method must return the test parameters: @@ -633,8 +632,6 @@ def get_test_parameters(self) -> tuple: def get_doc(self) -> str: """Get the rule's documentation. - :rtype: str - A rule's documentation is a short text that can be displayed to a user on IRC upon asking for help about this rule. The equivalent of Python docstrings, but for IRC rules. @@ -644,8 +641,6 @@ def get_doc(self) -> str: def get_priority(self) -> str: """Get the rule's priority. - :rtype: str - A rule can have a priority, based on the three pre-defined priorities used by Sopel: ``PRIORITY_HIGH``, ``PRIORITY_MEDIUM``, and ``PRIORITY_LOW``. @@ -662,8 +657,6 @@ def get_priority(self) -> str: def get_output_prefix(self) -> str: """Get the rule's output prefix. - :rtype: str - .. seealso:: See the :class:`sopel.bot.SopelWrapper` class for more information @@ -671,13 +664,11 @@ def get_output_prefix(self) -> str: """ @abc.abstractmethod - def match(self, bot, pretrigger) -> Iterable: + def match(self, bot: Sopel, pretrigger: PreTrigger) -> Iterable: """Match a pretrigger according to the rule. :param bot: Sopel instance - :type bot: :class:`sopel.bot.Sopel` :param pretrigger: line to match - :type pretrigger: :class:`sopel.trigger.PreTrigger` This method must return a list of `match objects`__. @@ -685,12 +676,11 @@ def match(self, bot, pretrigger) -> Iterable: """ @abc.abstractmethod - def match_event(self, event) -> bool: + def match_event(self, event: str) -> bool: """Tell if the rule matches this ``event``. - :param str event: potential matching event + :param event: potential matching event :return: ``True`` when ``event`` matches the rule, ``False`` otherwise - :rtype: bool """ @abc.abstractmethod @@ -775,40 +765,49 @@ def global_rate_limit(self) -> datetime.timedelta: def is_user_rate_limited( self, nick: Identifier, - at_time: Optional[datetime.datetime] = None, + at_time: datetime.datetime, ) -> bool: """Tell when the rule reached the ``nick``'s rate limit. :param nick: the nick associated with this check - :param at_time: optional aware datetime for the rate limit check; - if not given, ``utcnow`` will be used + :param at_time: aware datetime for the rate limit check :return: ``True`` when the rule reached the limit, ``False`` otherwise. + + .. versionchanged:: 8.0.1 + + Parameter ``at_time`` is now required. + """ @abc.abstractmethod def is_channel_rate_limited( self, channel: Identifier, - at_time: Optional[datetime.datetime] = None, + at_time: datetime.datetime, ) -> bool: """Tell when the rule reached the ``channel``'s rate limit. :param channel: the channel associated with this check - :param at_time: optional aware datetime for the rate limit check; - if not given, ``utcnow`` will be used + :param at_time: aware datetime for the rate limit check :return: ``True`` when the rule reached the limit, ``False`` otherwise. + + .. versionchanged:: 8.0.1 + + Parameter ``at_time`` is now required. + """ @abc.abstractmethod - def is_global_rate_limited( - self, - at_time: Optional[datetime.datetime] = None, - ) -> bool: + def is_global_rate_limited(self, at_time: datetime.datetime) -> bool: """Tell when the rule reached the global rate limit. - :param at_time: optional aware datetime for the rate limit check; - if not given, ``utcnow`` will be used + :param at_time: aware datetime for the rate limit check :return: ``True`` when the rule reached the limit, ``False`` otherwise. + + .. versionchanged:: 8.0.1 + + Parameter ``at_time`` is now required. + """ @property @@ -845,7 +844,7 @@ def global_rate_template(self) -> Optional[str]: """ @abc.abstractmethod - def parse(self, text) -> Generator: + def parse(self, text: str) -> Generator: """Parse ``text`` and yield matches. :param str text: text to parse by the rule @@ -1046,7 +1045,7 @@ def __init__(self, self._handler = handler # filters - self._events = events or ['PRIVMSG'] + self._events: list[str] = events or ['PRIVMSG'] self._ctcp = ctcp or [] self._allow_bots = bool(allow_bots) self._allow_echo = bool(allow_echo) @@ -1171,10 +1170,10 @@ def parse(self, text): if result: yield result - def match_event(self, event) -> bool: + def match_event(self, event: str | None) -> bool: return bool(event and event in self._events) - def match_ctcp(self, command: Optional[str]) -> bool: + def match_ctcp(self, command: str | None) -> bool: if not self._ctcp: return True @@ -1219,29 +1218,29 @@ def global_rate_limit(self) -> datetime.timedelta: def is_user_rate_limited( self, nick: Identifier, - at_time: Optional[datetime.datetime] = None, + at_time: datetime.datetime, ) -> bool: - if at_time is None: - at_time = datetime.datetime.now(datetime.timezone.utc) + if self._user_rate_limit <= 0: + return False + metrics = self.get_user_metrics(nick) return metrics.is_limited(at_time - self.user_rate_limit) def is_channel_rate_limited( self, channel: Identifier, - at_time: Optional[datetime.datetime] = None, + at_time: datetime.datetime, ) -> bool: - if at_time is None: - at_time = datetime.datetime.now(datetime.timezone.utc) + if self._channel_rate_limit <= 0: + return False + metrics = self.get_channel_metrics(channel) return metrics.is_limited(at_time - self.channel_rate_limit) - def is_global_rate_limited( - self, - at_time: Optional[datetime.datetime] = None, - ) -> bool: - if at_time is None: - at_time = datetime.datetime.now(datetime.timezone.utc) + def is_global_rate_limited(self, at_time: datetime.datetime) -> bool: + if self._global_rate_limit <= 0: + return False + metrics = self.get_global_metrics() return metrics.is_limited(at_time - self.global_rate_limit) diff --git a/sopel/tests/pytest_plugin.py b/sopel/tests/pytest_plugin.py index 16a58c07a..2208a0d1c 100644 --- a/sopel/tests/pytest_plugin.py +++ b/sopel/tests/pytest_plugin.py @@ -18,6 +18,8 @@ nick = {name} owner = {owner} admin = {admin} +# avoid wasting cycles in time.sleep() during `repeat`ed tests +flood_max_wait = 0 """ diff --git a/sopel/tools/calculation.py b/sopel/tools/calculation.py index a1e85b653..b32068705 100644 --- a/sopel/tools/calculation.py +++ b/sopel/tools/calculation.py @@ -42,7 +42,11 @@ def __init__( self.binary_ops = bin_ops or {} self.unary_ops = unary_ops or {} - def __call__(self, expression_str: str, timeout: float = 5.0): + def __call__( + self, + expression_str: str, + timeout: float = 5.0, + ) -> int | float: """Evaluate a Python expression and return the result. :param expression_str: the expression to evaluate @@ -56,7 +60,7 @@ def __call__(self, expression_str: str, timeout: float = 5.0): ast_expression = ast.parse(expression_str, mode='eval') return self._eval_node(ast_expression.body, time.time() + timeout) - def _eval_node(self, node: ast.AST, timeout: float): + def _eval_node(self, node: ast.AST, timeout: float) -> float: """Recursively evaluate the given :class:`ast.Node `. :param node: the AST node to evaluate @@ -116,7 +120,7 @@ def _eval_node(self, node: ast.AST, timeout: float): ) -def guarded_mul(left: float, right: float): +def guarded_mul(left: float, right: float) -> float: """Multiply two values, guarding against overly large inputs. :param left: the left operand @@ -139,7 +143,7 @@ def guarded_mul(left: float, right: float): return operator.mul(left, right) -def pow_complexity(num: int, exp: int): +def pow_complexity(num: int, exp: int) -> float: """Estimate the worst case time :func:`pow` takes to calculate. :param num: base @@ -205,7 +209,7 @@ def pow_complexity(num: int, exp: int): return exp ** 1.590 * num.bit_length() ** 1.73 / 36864057619.3 -def guarded_pow(num: float, exp: float): +def guarded_pow(num: float, exp: float) -> float: """Raise a number to a power, guarding against overly large inputs. :param num: base @@ -255,7 +259,11 @@ def __init__(self): unary_ops=self.__unary_ops ) - def __call__(self, expression_str: str, timeout: float = 5.0): + def __call__( + self, + expression_str: str, + timeout: float = 5.0, + ) -> float: result = ExpressionEvaluator.__call__(self, expression_str, timeout) # This wrapper is here so additional sanity checks could be done diff --git a/sopel/tools/identifiers.py b/sopel/tools/identifiers.py index 6b0937a99..9297be111 100644 --- a/sopel/tools/identifiers.py +++ b/sopel/tools/identifiers.py @@ -97,9 +97,8 @@ def rfc1459_strict_lower(text: str) -> str: class Identifier(str): """A ``str`` subclass which acts appropriately for IRC identifiers. - :param str identifier: IRC identifier + :param identifier: IRC identifier :param casemapping: a casemapping function (optional keyword argument) - :type casemapping: Callable[[:class:`str`], :class:`str`] When used as normal ``str`` objects, case will be preserved. However, when comparing two Identifier objects, or comparing an Identifier @@ -162,12 +161,11 @@ def lower(self) -> str: return self.casemapping(self) @staticmethod - def _lower(identifier: str): + def _lower(identifier: str) -> str: """Convert an identifier to lowercase per :rfc:`2812`. - :param str identifier: the identifier (nickname or channel) to convert + :param identifier: the identifier (nickname or channel) to convert :return: RFC 2812-compliant lowercase version of ``identifier`` - :rtype: str :meta public: @@ -186,12 +184,11 @@ def _lower(identifier: str): return rfc1459_lower(identifier) @staticmethod - def _lower_swapped(identifier: str): + def _lower_swapped(identifier: str) -> str: """Backward-compatible version of :meth:`_lower`. :param identifier: the identifier (nickname or channel) to convert :return: RFC 2812-non-compliant lowercase version of ``identifier`` - :rtype: str This is what the old :meth:`_lower` function did before Sopel 7.0. It maps ``{}``, ``[]``, ``|``, ``\\``, ``^``, and ``~`` incorrectly. diff --git a/sopel/tools/memories.py b/sopel/tools/memories.py index 1fb2b0b2a..c23f53e6a 100644 --- a/sopel/tools/memories.py +++ b/sopel/tools/memories.py @@ -8,7 +8,7 @@ from collections import defaultdict import threading -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING, Union from typing_extensions import override @@ -18,7 +18,10 @@ from collections.abc import Iterable, Mapping from typing import Tuple - MemoryConstructorInput = Union[Mapping[str, Any], Iterable[Tuple[str, Any]]] + MemoryConstructorInput = Union[ + Mapping[str, Any], + Iterable[Tuple[str, Any]], + ] class _NO_DEFAULT: @@ -176,7 +179,7 @@ def setup(bot): """ def __init__( self, - *args, + *args: MemoryConstructorInput, identifier_factory: IdentifierFactory = Identifier, ) -> None: if len(args) > 1: @@ -193,7 +196,7 @@ def __init__( else: super().__init__() - def _make_key(self, key: Optional[str]) -> Optional[Identifier]: + def _make_key(self, key: str | None) -> Identifier | None: if key is None: return None return self.make_identifier(key) @@ -221,19 +224,19 @@ def _convert_keys( # return converted input data return ((self.make_identifier(k), v) for k, v in data) - def __getitem__(self, key: Optional[str]): + def __getitem__(self, key: str | None) -> Any: return super().__getitem__(self._make_key(key)) - def __contains__(self, key): + def __contains__(self, key: Any) -> Any: return super().__contains__(self._make_key(key)) - def __setitem__(self, key: Optional[str], value): + def __setitem__(self, key: str | None, value: Any) -> None: super().__setitem__(self._make_key(key), value) - def setdefault(self, key: str, default=None): + def setdefault(self, key: str, default: Any = None) -> Any: return super().setdefault(self._make_key(key), default) - def __delitem__(self, key: str): + def __delitem__(self, key: str) -> None: super().__delitem__(self._make_key(key)) def copy(self): @@ -243,7 +246,7 @@ def copy(self): """ return type(self)(self, identifier_factory=self.make_identifier) - def get(self, key: str, default=_NO_DEFAULT): + def get(self, key: str, default: Any = _NO_DEFAULT) -> Any: """Get the value of ``key`` from this ``SopelIdentifierMemory``. Takes an optional ``default`` value, just like :meth:`dict.get`. @@ -252,7 +255,7 @@ def get(self, key: str, default=_NO_DEFAULT): return super().get(self._make_key(key)) return super().get(self._make_key(key), default) - def pop(self, key: str, default=_NO_DEFAULT): + def pop(self, key: str, default: Any = _NO_DEFAULT) -> Any: """Pop the value of ``key`` from this ``SopelIdentifierMemory``. Takes an optional ``default`` value, just like :meth:`dict.pop`. diff --git a/test/builtins/test_builtins_find.py b/test/builtins/test_builtins_find.py new file mode 100644 index 000000000..f7daaa1f2 --- /dev/null +++ b/test/builtins/test_builtins_find.py @@ -0,0 +1,107 @@ +"""Tests for Sopel's ``find`` plugin""" +from __future__ import annotations + +import pytest + +from sopel.formatting import bold +from sopel.tests import rawlist + + +TMP_CONFIG = """ +[core] +owner = Admin +nick = Sopel +enable = + find +host = irc.libera.chat +""" + + +@pytest.fixture +def bot(botfactory, configfactory): + settings = configfactory('default.ini', TMP_CONFIG) + return botfactory.preloaded(settings, ['find']) + + +@pytest.fixture +def irc(bot, ircfactory): + return ircfactory(bot) + + +@pytest.fixture +def user(userfactory): + return userfactory('User') + + +@pytest.fixture +def other_user(userfactory): + return userfactory('other_user') + + +@pytest.fixture +def channel(): + return '#testing' + + +REPLACES_THAT_WORK = ( + ("A simple line.", r"s/line/message/", f"A simple {bold('message')}."), + ("An escaped / line.", r"s/\//slash/", f"An escaped {bold('slash')} line."), + ("A piped line.", r"s|line|replacement|", f"A piped {bold('replacement')}."), + ("An escaped | line.", r"s|\||pipe|", f"An escaped {bold('pipe')} line."), + ("An escaped \\ line.", r"s/\\/backslash/", f"An escaped {bold('backslash')} line."), + ("abABab", r"s/b/c/g", "abABab".replace('b', bold('c'))), # g (global) flag + ("ABabAB", r"s/b/c/i", f"A{bold('c')}abAB"), # i (case-insensitive) flag + ("ABabAB", r"s/b/c/ig", f"A{bold('c')}a{bold('c')}A{bold('c')}"), # both flags +) + + +@pytest.mark.parametrize('original, command, result', REPLACES_THAT_WORK) +def test_valid_replacements(bot, irc, user, channel, original, command, result): + """Verify that basic replacement functionality works.""" + irc.channel_joined(channel, [user.nick]) + + irc.say(user, channel, original) + irc.say(user, channel, command) + + assert len(bot.backend.message_sent) == 1, ( + "The bot should respond with exactly one line.") + assert bot.backend.message_sent == rawlist( + "PRIVMSG %s :%s meant to say: %s" % (channel, user.nick, result), + ) + + +def test_multiple_users(bot, irc, user, other_user, channel): + """Verify that correcting another user's line works.""" + irc.channel_joined(channel, [user.nick, other_user.nick]) + + irc.say(other_user, channel, 'Some weather we got yesterday') + irc.say(user, channel, '%s: s/yester/to/' % other_user.nick) + + assert len(bot.backend.message_sent) == 1, ( + "The bot should respond with exactly one line.") + assert bot.backend.message_sent == rawlist( + "PRIVMSG %s :%s thinks %s meant to say: %s" % ( + channel, user.nick, other_user.nick, + f"Some weather we got {bold('to')}day", + ), + ) + + +def test_replace_the_replacement(bot, irc, user, channel): + """Verify replacing text that was already replaced.""" + irc.channel_joined(channel, [user.nick]) + + irc.say(user, channel, 'spam') + irc.say(user, channel, 's/spam/eggs/') + irc.say(user, channel, 's/eggs/bacon/') + + assert len(bot.backend.message_sent) == 2, ( + "The bot should respond twice.") + assert bot.backend.message_sent == rawlist( + "PRIVMSG %s :%s meant to say: %s" % ( + channel, user.nick, bold('eggs'), + ), + "PRIVMSG %s :%s meant to say: %s" % ( + channel, user.nick, bold('bacon'), + ), + ) diff --git a/test/irc/test_irc_utils.py b/test/irc/test_irc_utils.py index 94eab6298..47f2e3d9a 100644 --- a/test/irc/test_irc_utils.py +++ b/test/irc/test_irc_utils.py @@ -1,22 +1,28 @@ """Tests for core ``sopel.irc.utils``""" from __future__ import annotations +from itertools import permutations + import pytest from sopel.irc import utils -def test_safe(): +@pytest.mark.parametrize('s1, s2, s3', permutations(('\n', '\r', '\x00'))) +def test_safe(s1, s2, s3): text = 'some text' - assert utils.safe(text + '\r\n') == text - assert utils.safe(text + '\n') == text - assert utils.safe(text + '\r') == text - assert utils.safe('\r\n' + text) == text - assert utils.safe('\n' + text) == text - assert utils.safe('\r' + text) == text - assert utils.safe('some \r\ntext') == text - assert utils.safe('some \ntext') == text - assert utils.safe('some \rtext') == text + seq = ''.join((s1, s2, s3)) + + assert utils.safe(text + seq) == text + assert utils.safe(seq + text) == text + assert utils.safe('some ' + seq + 'text') == text + assert utils.safe( + s1 + + 'some ' + + s2 + + 'text' + + s3 + ) == text def test_safe_empty(): @@ -24,20 +30,23 @@ def test_safe_empty(): assert utils.safe(text) == text -def test_safe_null(): +def test_safe_none(): with pytest.raises(TypeError): utils.safe(None) -def test_safe_bytes(): +@pytest.mark.parametrize('b1, b2, b3', permutations((b'\n', b'\r', b'\x00'))) +def test_safe_bytes(b1, b2, b3): text = b'some text' - assert utils.safe(text) == text.decode('utf-8') - assert utils.safe(text + b'\r\n') == text.decode('utf-8') - assert utils.safe(text + b'\n') == text.decode('utf-8') - assert utils.safe(text + b'\r') == text.decode('utf-8') - assert utils.safe(b'\r\n' + text) == text.decode('utf-8') - assert utils.safe(b'\n' + text) == text.decode('utf-8') - assert utils.safe(b'\r' + text) == text.decode('utf-8') - assert utils.safe(b'some \r\ntext') == text.decode('utf-8') - assert utils.safe(b'some \ntext') == text.decode('utf-8') - assert utils.safe(b'some \rtext') == text.decode('utf-8') + seq = b''.join((b1, b2, b3)) + + assert utils.safe(text + seq) == text.decode('utf-8') + assert utils.safe(seq + text) == text.decode('utf-8') + assert utils.safe(b'some ' + seq + b'text') == text.decode('utf-8') + assert utils.safe( + b1 + + b'some ' + + b2 + + b'text' + + b3 + ) == text.decode('utf-8') diff --git a/test/plugins/test_plugins_handlers.py b/test/plugins/test_plugins_handlers.py index f34035369..67a2d6034 100644 --- a/test/plugins/test_plugins_handlers.py +++ b/test/plugins/test_plugins_handlers.py @@ -100,3 +100,63 @@ def test_get_label_entrypoint(plugin_tmpfile): assert meta['label'] == 'plugin label' assert meta['type'] == handlers.EntryPointPlugin.PLUGIN_TYPE assert meta['source'] == 'test_plugin = file_mod' + + +MOCK_PARENT_MODULE = """ +from sopel import plugin + +from .sub import foo + + +@plugin.command('mock') +def mock(bot, trigger): + bot.say(foo) +""" + +MOCK_SUB_MODULE = """ +foo = 'bar baz' +""" + + +@pytest.fixture +def plugin_folder(tmp_path): + root = tmp_path / 'test_folder_plugin' + root.mkdir() + + parent = root / '__init__.py' + with open(parent, 'w') as f: + f.write(MOCK_PARENT_MODULE) + + submodule = root / 'sub.py' + with open(submodule, 'w') as f: + f.write(MOCK_SUB_MODULE) + + return str(root) + + +def test_folder_plugin_imports(plugin_folder): + """Ensure submodule imports work as expected in folder plugins. + + Regression test for https://github.com/sopel-irc/sopel/issues/2619 + """ + handler = handlers.PyFilePlugin(plugin_folder) + handler.load() + assert handler.module.foo == 'bar baz' + + +def test_get_version_entrypoint_package_does_not_match(plugin_tmpfile): + # See gh-2593, wherein an entrypoint plugin whose project/package names + # are not equal raised an exception that propagated too far + distrib_dir = os.path.dirname(plugin_tmpfile.strpath) + sys.path.append(distrib_dir) + + try: + entry_point = importlib.metadata.EntryPoint( + 'test_plugin', 'file_mod', 'sopel.plugins') + plugin = handlers.EntryPointPlugin(entry_point) + plugin.load() + plugin.module.__package__ = "FAKEFAKEFAKE" + # Under gh-2593, this call raises a PackageNotFound error + assert plugin.get_version() is None + finally: + sys.path.remove(distrib_dir) diff --git a/test/plugins/test_plugins_rules.py b/test/plugins/test_plugins_rules.py index d252439ca..1ff854dbb 100644 --- a/test/plugins/test_plugins_rules.py +++ b/test/plugins/test_plugins_rules.py @@ -1566,14 +1566,16 @@ def handler(bot, trigger): global_rate_limit=20, channel_rate_limit=20, ) - assert rule.is_user_rate_limited(mocktrigger.nick) is False - assert rule.is_channel_rate_limited(mocktrigger.sender) is False - assert rule.is_global_rate_limited() is False + at_time = datetime.datetime.now(datetime.timezone.utc) + assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False + assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False + assert rule.is_global_rate_limited(at_time) is False rule.execute(mockbot, mocktrigger) - assert rule.is_user_rate_limited(mocktrigger.nick) is True - assert rule.is_channel_rate_limited(mocktrigger.sender) is True - assert rule.is_global_rate_limited() is True + at_time = datetime.datetime.now(datetime.timezone.utc) + assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is True + assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is True + assert rule.is_global_rate_limited(at_time) is True def test_rule_rate_limit_no_limit(mockbot, triggerfactory): @@ -1592,14 +1594,16 @@ def handler(bot, trigger): global_rate_limit=0, channel_rate_limit=0, ) - assert rule.is_user_rate_limited(mocktrigger.nick) is False - assert rule.is_channel_rate_limited(mocktrigger.sender) is False - assert rule.is_global_rate_limited() is False + at_time = datetime.datetime.now(datetime.timezone.utc) + assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False + assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False + assert rule.is_global_rate_limited(at_time) is False rule.execute(mockbot, mocktrigger) - assert rule.is_user_rate_limited(mocktrigger.nick) is False - assert rule.is_channel_rate_limited(mocktrigger.sender) is False - assert rule.is_global_rate_limited() is False + at_time = datetime.datetime.now(datetime.timezone.utc) + assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False + assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False + assert rule.is_global_rate_limited(at_time) is False def test_rule_rate_limit_ignore_rate_limit(mockbot, triggerfactory): @@ -1619,14 +1623,16 @@ def handler(bot, trigger): channel_rate_limit=20, threaded=False, # make sure there is no race-condition here ) - assert rule.is_user_rate_limited(mocktrigger.nick) is False - assert rule.is_channel_rate_limited(mocktrigger.sender) is False - assert rule.is_global_rate_limited() is False + at_time = datetime.datetime.now(datetime.timezone.utc) + assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False + assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False + assert rule.is_global_rate_limited(at_time) is False rule.execute(mockbot, mocktrigger) - assert rule.is_user_rate_limited(mocktrigger.nick) is False - assert rule.is_channel_rate_limited(mocktrigger.sender) is False - assert rule.is_global_rate_limited() is False + at_time = datetime.datetime.now(datetime.timezone.utc) + assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False + assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False + assert rule.is_global_rate_limited(at_time) is False def test_rule_rate_limit_messages(mockbot, triggerfactory): diff --git a/test/test_bot.py b/test/test_bot.py index 964878bee..73904e810 100644 --- a/test/test_bot.py +++ b/test/test_bot.py @@ -15,7 +15,9 @@ if typing.TYPE_CHECKING: from sopel.config import Config - from sopel.tests.factories import BotFactory, IRCFactory, UserFactory + from sopel.tests.factories import ( + BotFactory, ConfigFactory, IRCFactory, TriggerFactory, UserFactory, + ) from sopel.tests.mocks import MockIRCServer @@ -81,17 +83,17 @@ def ignored(): @pytest.fixture -def tmpconfig(configfactory): +def tmpconfig(configfactory: ConfigFactory) -> Config: return configfactory('test.cfg', TMP_CONFIG) @pytest.fixture -def mockbot(tmpconfig, botfactory): +def mockbot(tmpconfig: Config, botfactory: BotFactory) -> bot.Sopel: return botfactory(tmpconfig) @pytest.fixture -def mockplugin(tmpdir): +def mockplugin(tmpdir) -> plugins.handlers.PyFilePlugin: root = tmpdir.mkdir('loader_mods') mod_file = root.join('mockplugin.py') mod_file.write(MOCK_MODULE_CONTENT) @@ -676,7 +678,7 @@ def url_callback_http(bot, trigger, match): # call_rule @pytest.fixture -def match_hello_rule(mockbot, triggerfactory): +def match_hello_rule(mockbot: bot.Sopel, triggerfactory: TriggerFactory): """Helper for generating matches to each `Rule` in the following tests""" def _factory(rule_hello): # trigger @@ -694,7 +696,25 @@ def _factory(rule_hello): return _factory -def test_call_rule(mockbot, match_hello_rule): +@pytest.fixture +def multimatch_hello_rule(mockbot: bot.Sopel, triggerfactory: TriggerFactory): + def _factory(rule_hello): + # trigger + line = ':Test!test@example.com PRIVMSG #channel :hello hello hello' + + trigger = triggerfactory(mockbot, line) + pretrigger = trigger._pretrigger + + for match in rule_hello.match(mockbot, pretrigger): + wrapper = bot.SopelWrapper(mockbot, trigger) + yield match, trigger, wrapper + return _factory + + +def test_call_rule( + mockbot: bot.Sopel, + match_hello_rule: typing.Callable, +) -> None: # setup items = [] @@ -721,9 +741,10 @@ def testrule(bot, trigger): assert items == [1] # assert the rule is not rate limited - assert not rule_hello.is_user_rate_limited(Identifier('Test')) - assert not rule_hello.is_channel_rate_limited('#channel') - assert not rule_hello.is_global_rate_limited() + at_time = datetime.now(timezone.utc) + assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time) + assert not rule_hello.is_channel_rate_limited('#channel', at_time) + assert not rule_hello.is_global_rate_limited(at_time) match, rule_trigger, wrapper = match_hello_rule(rule_hello) @@ -738,6 +759,36 @@ def testrule(bot, trigger): assert items == [1, 1] +def test_call_rule_multiple_matches( + mockbot: bot.Sopel, + multimatch_hello_rule: typing.Callable, +) -> None: + # setup + items = [] + + def testrule(bot, trigger): + bot.say('hi') + items.append(1) + return "Return Value" + + find_hello = rules.FindRule( + [re.compile(r'(hi|hello|hey|sup)')], + plugin='testplugin', + label='testrule', + handler=testrule) + + for match, rule_trigger, wrapper in multimatch_hello_rule(find_hello): + mockbot.call_rule(find_hello, wrapper, rule_trigger) + + # assert the rule has been executed three times now + assert mockbot.backend.message_sent == rawlist( + 'PRIVMSG #channel :hi', + 'PRIVMSG #channel :hi', + 'PRIVMSG #channel :hi', + ) + assert items == [1, 1, 1] + + def test_call_rule_rate_limited_user(mockbot, match_hello_rule): items = [] @@ -767,9 +818,10 @@ def testrule(bot, trigger): assert items == [1] # assert the rule is now rate limited - assert rule_hello.is_user_rate_limited(Identifier('Test')) - assert not rule_hello.is_channel_rate_limited('#channel') - assert not rule_hello.is_global_rate_limited() + at_time = datetime.now(timezone.utc) + assert rule_hello.is_user_rate_limited(Identifier('Test'), at_time) + assert not rule_hello.is_channel_rate_limited('#channel', at_time) + assert not rule_hello.is_global_rate_limited(at_time) match, rule_trigger, wrapper = match_hello_rule(rule_hello) @@ -852,9 +904,10 @@ def testrule(bot, trigger): assert items == [1] # assert the rule is now rate limited - assert not rule_hello.is_user_rate_limited(Identifier('Test')) - assert rule_hello.is_channel_rate_limited('#channel') - assert not rule_hello.is_global_rate_limited() + at_time = datetime.now(timezone.utc) + assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time) + assert rule_hello.is_channel_rate_limited('#channel', at_time) + assert not rule_hello.is_global_rate_limited(at_time) match, rule_trigger, wrapper = match_hello_rule(rule_hello) @@ -897,9 +950,10 @@ def testrule(bot, trigger): assert items == [1] # assert the rule is now rate limited - assert not rule_hello.is_user_rate_limited(Identifier('Test')) - assert rule_hello.is_channel_rate_limited('#channel') - assert not rule_hello.is_global_rate_limited() + at_time = datetime.now(timezone.utc) + assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time) + assert rule_hello.is_channel_rate_limited('#channel', at_time) + assert not rule_hello.is_global_rate_limited(at_time) match, rule_trigger, wrapper = match_hello_rule(rule_hello) @@ -942,9 +996,10 @@ def testrule(bot, trigger): assert items == [1] # assert the rule is now rate limited - assert not rule_hello.is_user_rate_limited(Identifier('Test')) - assert not rule_hello.is_channel_rate_limited('#channel') - assert rule_hello.is_global_rate_limited() + at_time = datetime.now(timezone.utc) + assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time) + assert not rule_hello.is_channel_rate_limited('#channel', at_time) + assert rule_hello.is_global_rate_limited(at_time) match, rule_trigger, wrapper = match_hello_rule(rule_hello) @@ -987,9 +1042,10 @@ def testrule(bot, trigger): assert items == [1] # assert the rule is now rate limited - assert not rule_hello.is_user_rate_limited(Identifier('Test')) - assert not rule_hello.is_channel_rate_limited('#channel') - assert rule_hello.is_global_rate_limited() + at_time = datetime.now(timezone.utc) + assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time) + assert not rule_hello.is_channel_rate_limited('#channel', at_time) + assert rule_hello.is_global_rate_limited(at_time) match, rule_trigger, wrapper = match_hello_rule(rule_hello)