From f80d7c764bbad964d2ae07b07d415de5c2ed3764 Mon Sep 17 00:00:00 2001 From: Clari Date: Wed, 24 Jan 2024 21:49:28 -0600 Subject: [PATCH] Initial fork commit; squashing previous commits for cleanliness jsk root: Add backticks around jishaku version jsk invite: Translate permissions input to lowercase and replace "server" with "guild" for ease of use jsk python: Add support for a list of files or embeds Change ReplResponseReactor(ctx.message) to ReplResponseReactor(ctx) so that errors and results are compatible with ContextEditor --- jishaku/__main__.py | 1 - jishaku/exception_handling.py | 36 ++++--- jishaku/features/filesystem.py | 2 +- jishaku/features/invocation.py | 33 +++--- jishaku/features/management.py | 2 + jishaku/features/python.py | 170 +++++++++++++++++++++---------- jishaku/features/root_command.py | 17 +++- jishaku/features/shell.py | 54 ++++++---- jishaku/features/sql.py | 8 +- jishaku/flags.py | 38 ++++--- 10 files changed, 231 insertions(+), 130 deletions(-) diff --git a/jishaku/__main__.py b/jishaku/__main__.py index 9bd555e4..cc32c200 100644 --- a/jishaku/__main__.py +++ b/jishaku/__main__.py @@ -170,6 +170,5 @@ def prefix(bot: commands.Bot, _: discord.Message) -> typing.List[str]: asyncio.run(entry(bot, token)) - if __name__ == '__main__': entrypoint() # pylint: disable=no-value-for-parameter diff --git a/jishaku/exception_handling.py b/jishaku/exception_handling.py index 83b5da2a..0fd8654a 100644 --- a/jishaku/exception_handling.py +++ b/jishaku/exception_handling.py @@ -21,6 +21,7 @@ from discord.ext import commands from jishaku.flags import Flags +from jishaku.types import ContextA async def send_traceback( @@ -49,9 +50,9 @@ async def send_traceback( message = None for page in paginator.pages: - if isinstance(destination, discord.Message): + if isinstance(destination, (discord.Message, commands.Context)): message = await destination.reply(page) - else: + elif isinstance(destination, discord.abc.Messageable): message = await destination.send(page) return message @@ -61,7 +62,9 @@ async def send_traceback( P = typing.ParamSpec('P') -async def do_after_sleep(delay: float, coro: typing.Callable[P, typing.Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: +async def do_after_sleep( + delay: float, coro: typing.Callable[P, typing.Awaitable[T]], *args: P.args, **kwargs: P.kwargs +) -> T: """ Performs an action after a set amount of time. @@ -99,11 +102,11 @@ class ReplResponseReactor: # pylint: disable=too-few-public-methods """ Extension of the ReactionProcedureTimer that absorbs errors, sending tracebacks. """ + __slots__ = ('ctx', 'message', 'loop', 'handle', 'raised') - __slots__ = ('message', 'loop', 'handle', 'raised') - - def __init__(self, message: discord.Message, loop: typing.Optional[asyncio.BaseEventLoop] = None): - self.message = message + def __init__(self, ctx: ContextA, loop: typing.Optional[asyncio.BaseEventLoop] = None): + self.ctx = ctx + self.message = ctx.message self.loop = loop or asyncio.get_event_loop() self.handle = None self.raised = False @@ -131,30 +134,33 @@ async def __aexit__( if isinstance(exc_val, (SyntaxError, asyncio.TimeoutError, subprocess.TimeoutExpired)): # short traceback, send to channel - destination = Flags.traceback_destination(self.message) or self.message.channel + destination = Flags.traceback_destination(self.ctx) or self.ctx - if destination != self.message.channel: + if destination != self.ctx: await attempt_add_reaction( self.message, # timed out is alarm clock # syntax error is single exclamation mark - "\N{HEAVY EXCLAMATION MARK SYMBOL}" if isinstance(exc_val, SyntaxError) else "\N{ALARM CLOCK}" + "\N{HEAVY EXCLAMATION MARK SYMBOL}" if isinstance(exc_val, SyntaxError) else "\N{ALARM CLOCK}", ) await send_traceback( - self.message if destination == self.message.channel else destination, - 0, exc_type, exc_val, exc_tb + destination, + 0, + exc_type, + exc_val, + exc_tb, ) else: - destination = Flags.traceback_destination(self.message) or self.message.author + destination = Flags.traceback_destination(self.ctx) or self.message.author - if destination != self.message.channel: + if destination != self.ctx: # other error, double exclamation mark await attempt_add_reaction(self.message, "\N{DOUBLE EXCLAMATION MARK}") # this traceback likely needs more info, so increase verbosity, and DM it instead. await send_traceback( - self.message if destination == self.message.channel else destination, + destination, 8, exc_type, exc_val, exc_tb ) diff --git a/jishaku/features/filesystem.py b/jishaku/features/filesystem.py index 669d631a..83ba19a9 100644 --- a/jishaku/features/filesystem.py +++ b/jishaku/features/filesystem.py @@ -103,7 +103,7 @@ async def jsk_curl(self, ctx: ContextA, url: str): # remove embed maskers if present url = url.lstrip("<").rstrip(">") - async with ReplResponseReactor(ctx.message): + async with ReplResponseReactor(ctx): async with aiohttp.ClientSession() as session: async with session.get(url) as response: data = await response.read() diff --git a/jishaku/features/invocation.py b/jishaku/features/invocation.py index 8e772923..ea1f5e42 100644 --- a/jishaku/features/invocation.py +++ b/jishaku/features/invocation.py @@ -40,7 +40,7 @@ class SlimUserConverter(UserIDConverter): # pylint: disable=too-few-public-meth async def convert(self, ctx: ContextA, argument: str) -> typing.Union[discord.Member, discord.User]: """Converter method""" - match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) # type: ignore + match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument) # type: ignore if match is not None: user_id = int(match.group(1)) @@ -63,7 +63,7 @@ class SlimChannelConverter(ChannelIDConverter): # pylint: disable=too-few-publi async def convert(self, ctx: ContextA, argument: str) -> discord.TextChannel: """Converter method""" - match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) + match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument) if match is not None: channel_id = int(match.group(1)) @@ -80,7 +80,7 @@ class SlimThreadConverter(ThreadIDConverter): # pylint: disable=too-few-public- async def convert(self, ctx: ContextA, argument: str) -> discord.Thread: """Converter method""" - match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) + match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument) if match is not None: thread_id = int(match.group(1)) @@ -98,13 +98,7 @@ class InvocationFeature(Feature): OVERRIDE_SIGNATURE = typing.Union[SlimUserConverter, SlimChannelConverter, SlimThreadConverter] @Feature.Command(parent="jsk", name="override", aliases=["execute", "exec", "override!", "execute!", "exec!"]) - async def jsk_override( - self, - ctx: ContextT, - overrides: commands.Greedy[OVERRIDE_SIGNATURE], - *, - command_string: str - ): + async def jsk_override(self, ctx: ContextT, overrides: commands.Greedy[OVERRIDE_SIGNATURE], *, command_string: str): """ Run a command with a different user, channel, or thread, optionally bypassing checks and cooldowns. @@ -114,7 +108,7 @@ async def jsk_override( kwargs: typing.Dict[str, typing.Any] = {} if ctx.prefix: - kwargs["content"] = ctx.prefix + command_string.lstrip('/') + kwargs["content"] = ctx.prefix + command_string.lstrip("/") else: await ctx.send("Reparsing requires a prefix") return @@ -142,12 +136,12 @@ async def jsk_override( if alt_ctx.command is None: if alt_ctx.invoked_with is None: - await ctx.send('This bot has been hard-configured to ignore this user.') + await ctx.send("This bot has been hard-configured to ignore this user.") return await ctx.send(f'Command "{alt_ctx.invoked_with}" is not found') return - if ctx.invoked_with and ctx.invoked_with.endswith('!'): + if ctx.invoked_with and ctx.invoked_with.endswith("!"): await alt_ctx.command.reinvoke(alt_ctx) return @@ -193,7 +187,7 @@ async def jsk_debug(self, ctx: ContextT, *, command_string: str): start = time.perf_counter() - async with ReplResponseReactor(ctx.message): + async with ReplResponseReactor(ctx): with self.submit(ctx): await alt_ctx.command.invoke(alt_ctx) @@ -223,17 +217,14 @@ async def jsk_source(self, ctx: ContextA, *, command_name: str): pass # getsourcelines for some reason returns WITH line endings - source_text = ''.join(source_lines) + source_text = "".join(source_lines) if use_file_check(ctx, len(source_text)): # File "full content" preview limit - await ctx.send(file=discord.File( - filename=filename, - fp=io.BytesIO(source_text.encode('utf-8')) - )) + await ctx.send(file=discord.File(filename=filename, fp=io.BytesIO(source_text.encode("utf-8")))) else: - paginator = WrappedPaginator(prefix='```py', suffix='```', max_size=1980) + paginator = WrappedPaginator(prefix="```py", suffix="```", max_size=1980) - paginator.add_line(source_text.replace('```', '``\N{zero width space}`')) + paginator.add_line(source_text.replace("```", "``\N{zero width space}`")) interface = PaginatorInterface(ctx.bot, paginator, owner=ctx.author) await interface.send_to(ctx) diff --git a/jishaku/features/management.py b/jishaku/features/management.py index 0796426b..ac317580 100644 --- a/jishaku/features/management.py +++ b/jishaku/features/management.py @@ -128,6 +128,8 @@ async def jsk_invite(self, ctx: ContextA, *perms: str): permissions = discord.Permissions() for perm in perms: + perm = perm.lower() + perm.replace("server", "guild") if perm not in dict(permissions): raise commands.BadArgument(f"Invalid permission: {perm}") diff --git a/jishaku/features/python.py b/jishaku/features/python.py index b7a1e367..871f0305 100644 --- a/jishaku/features/python.py +++ b/jishaku/features/python.py @@ -63,7 +63,7 @@ def scope(self): return Scope() @Feature.Command(parent="jsk", name="retain") - async def jsk_retain(self, ctx: ContextA, *, toggle: bool = None): # type: ignore + async def jsk_retain(self, ctx: ContextA, *, toggle: bool | None = None): # type: ignore """ Turn variable retention for REPL on or off. @@ -82,15 +82,21 @@ async def jsk_retain(self, ctx: ContextA, *, toggle: bool = None): # type: igno self.retain = True self._scope = Scope() - return await ctx.send("Variable retention is ON. Future REPL sessions will retain their scope.") + return await ctx.send( + "Variable retention is ON. Future REPL sessions will retain their scope." + ) if not self.retain: return await ctx.send("Variable retention is already set to OFF.") self.retain = False - return await ctx.send("Variable retention is OFF. Future REPL sessions will dispose their scope when done.") + return await ctx.send( + "Variable retention is OFF. Future REPL sessions will dispose their scope when done." + ) - async def jsk_python_result_handling(self, ctx: ContextA, result: typing.Any): # pylint: disable=too-many-return-statements + async def jsk_python_result_handling( + self, ctx: ContextA, result: typing.Any + ): # pylint: disable=too-many-return-statements """ Determines what is done with a result when it comes out of jsk py. This allows you to override how this is done without having to rewrite the command itself. @@ -103,9 +109,21 @@ async def jsk_python_result_handling(self, ctx: ContextA, result: typing.Any): if isinstance(result, discord.File): return await ctx.send(file=result) + if isinstance(result, typing.Iterable) and all( + isinstance(obj, discord.File) for obj in result + ): + return await ctx.send(files=result) + if isinstance(result, discord.Embed): return await ctx.send(embed=result) + if ( + isinstance(result, typing.Iterable) + and all(isinstance(obj, discord.Embed) for obj in result) + and discord.__version__[0] == "2" + ): + return await ctx.send(embeds=result) + if isinstance(result, PaginatorInterface): return await result.send_to(ctx) @@ -115,15 +133,14 @@ async def jsk_python_result_handling(self, ctx: ContextA, result: typing.Any): # Eventually the below handling should probably be put somewhere else if len(result) <= 2000: - if result.strip() == '': + if result.strip() == "": result = "\u200b" if self.bot.http.token: result = result.replace(self.bot.http.token, "[token omitted]") return await ctx.send( - result, - allowed_mentions=discord.AllowedMentions.none() + result, allowed_mentions=discord.AllowedMentions.none() ) if use_file_check(ctx, len(result)): # File "full content" preview limit @@ -132,21 +149,24 @@ async def jsk_python_result_handling(self, ctx: ContextA, result: typing.Any): # Since this avoids escape issues and is more intuitive than pagination for # long results, it will now be prioritized over PaginatorInterface if the # resultant content is below the filesize threshold - return await ctx.send(file=discord.File( - filename="output.py", - fp=io.BytesIO(result.encode('utf-8')) - )) + return await ctx.send( + file=discord.File( + filename="output.py", fp=io.BytesIO(result.encode("utf-8")) + ) + ) # inconsistency here, results get wrapped in codeblocks when they are too large # but don't if they're not. probably not that bad, but noting for later review - paginator = WrappedPaginator(prefix='```py', suffix='```', max_size=1980) + paginator = WrappedPaginator(prefix="```py", suffix="```", max_size=1980) paginator.add_line(result) interface = PaginatorInterface(ctx.bot, paginator, owner=ctx.author) return await interface.send_to(ctx) - def jsk_python_get_convertables(self, ctx: ContextA) -> typing.Tuple[typing.Dict[str, typing.Any], typing.Dict[str, str]]: + def jsk_python_get_convertables( + self, ctx: ContextA + ) -> typing.Tuple[typing.Dict[str, typing.Any], typing.Dict[str, str]]: """ Gets the arg dict and convertables for this scope. @@ -158,7 +178,7 @@ def jsk_python_get_convertables(self, ctx: ContextA) -> typing.Tuple[typing.Dict arg_dict["_"] = self.last_result convertables: typing.Dict[str, str] = {} - if getattr(ctx, 'interaction', None) is None: + if getattr(ctx, "interaction", None) is None: for index, user in enumerate(ctx.message.mentions): arg_dict[f"__user_mention_{index}"] = user convertables[user.mention] = f"__user_mention_{index}" @@ -186,9 +206,14 @@ async def jsk_python(self, ctx: ContextA, *, argument: codeblock_converter): # scope = self.scope try: - async with ReplResponseReactor(ctx.message): + async with ReplResponseReactor(ctx): with self.submit(ctx): - executor = AsyncCodeExecutor(argument.content, scope, arg_dict=arg_dict, convertables=convertables) + executor = AsyncCodeExecutor( + argument.content, + scope, + arg_dict=arg_dict, + convertables=convertables, + ) async for send, result in AsyncSender(executor): # type: ignore send: typing.Callable[..., None] result: typing.Any @@ -203,7 +228,11 @@ async def jsk_python(self, ctx: ContextA, *, argument: codeblock_converter): # finally: scope.clear_intersection(arg_dict) - @Feature.Command(parent="jsk", name="py_inspect", aliases=["pyi", "python_inspect", "pythoninspect"]) + @Feature.Command( + parent="jsk", + name="py_inspect", + aliases=["pyi", "python_inspect", "pythoninspect"], + ) async def jsk_python_inspect(self, ctx: ContextA, *, argument: codeblock_converter): # type: ignore """ Evaluation of Python code with inspect information. @@ -216,9 +245,14 @@ async def jsk_python_inspect(self, ctx: ContextA, *, argument: codeblock_convert scope = self.scope try: - async with ReplResponseReactor(ctx.message): + async with ReplResponseReactor(ctx): with self.submit(ctx): - executor = AsyncCodeExecutor(argument.content, scope, arg_dict=arg_dict, convertables=convertables) + executor = AsyncCodeExecutor( + argument.content, + scope, + arg_dict=arg_dict, + convertables=convertables, + ) async for send, result in AsyncSender(executor): # type: ignore send: typing.Callable[..., None] result: typing.Any @@ -228,7 +262,9 @@ async def jsk_python_inspect(self, ctx: ContextA, *, argument: codeblock_convert header = repr(result).replace("``", "`\u200b`") if self.bot.http.token: - header = header.replace(self.bot.http.token, "[token omitted]") + header = header.replace( + self.bot.http.token, "[token omitted]" + ) if len(header) > 485: header = header[0:482] + "..." @@ -238,29 +274,40 @@ async def jsk_python_inspect(self, ctx: ContextA, *, argument: codeblock_convert for name, res in all_inspections(result): lines.append(f"{name:16.16} :: {res}") - docstring = (inspect.getdoc(result) or '').strip() + docstring = (inspect.getdoc(result) or "").strip() if docstring: lines.append(f"\n=== Help ===\n\n{docstring}") text = "\n".join(lines) - if use_file_check(ctx, len(text)): # File "full content" preview limit - send(await ctx.send(file=discord.File( - filename="inspection.prolog", - fp=io.BytesIO(text.encode('utf-8')) - ))) + if use_file_check( + ctx, len(text) + ): # File "full content" preview limit + send( + await ctx.send( + file=discord.File( + filename="inspection.prolog", + fp=io.BytesIO(text.encode("utf-8")), + ) + ) + ) else: - paginator = WrappedPaginator(prefix="```prolog", max_size=1980) + paginator = WrappedPaginator( + prefix="```prolog", max_size=1980 + ) paginator.add_line(text) - interface = PaginatorInterface(ctx.bot, paginator, owner=ctx.author) + interface = PaginatorInterface( + ctx.bot, paginator, owner=ctx.author + ) send(await interface.send_to(ctx)) finally: scope.clear_intersection(arg_dict) if line_profiler is not None: + @Feature.Command(parent="jsk", name="timeit") async def jsk_timeit(self, ctx: ContextA, *, argument: codeblock_converter): # type: ignore """ @@ -274,22 +321,28 @@ async def jsk_timeit(self, ctx: ContextA, *, argument: codeblock_converter): # scope = self.scope try: - async with ReplResponseReactor(ctx.message): + async with ReplResponseReactor(ctx): with self.submit(ctx): executor = AsyncCodeExecutor( - argument.content, scope, + argument.content, + scope, arg_dict=arg_dict, convertables=convertables, - auto_return=False + auto_return=False, ) overall_start = time.perf_counter() count: int = 0 timings: typing.List[float] = [] ioless_timings: typing.List[float] = [] - line_timings: typing.Dict[int, typing.List[float]] = collections.defaultdict(list) - - while count < 10_000 and (time.perf_counter() - overall_start) < 30.0: + line_timings: typing.Dict[ + int, typing.List[float] + ] = collections.defaultdict(list) + + while ( + count < 10_000 + and (time.perf_counter() - overall_start) < 30.0 + ): profile = line_profiler.LineProfiler() # type: ignore profile.add_function(executor.function) # type: ignore @@ -305,9 +358,13 @@ async def jsk_timeit(self, ctx: ContextA, *, argument: codeblock_converter): # self.last_result = result - send(await self.jsk_python_result_handling(ctx, result)) - # Reduces likelihood of hardblocking - await asyncio.sleep(0.001) + send( + await self.jsk_python_result_handling( + ctx, result + ) + ) + # Reduces likelyhood of hardblocking + await asyncio.sleep(0.01) end = time.perf_counter() finally: @@ -331,7 +388,9 @@ async def jsk_timeit(self, ctx: ContextA, *, argument: codeblock_converter): # execution_time = format_stddev(timings) active_time = format_stddev(ioless_timings) - max_line_time = max(max(timing) for timing in line_timings.values()) + max_line_time = max( + max(timing) for timing in line_timings.values() + ) linecache = executor.create_linecache() lines: typing.List[str] = [] @@ -348,16 +407,18 @@ async def jsk_timeit(self, ctx: ContextA, *, argument: codeblock_converter): # lines.append('\u001b[0m' + color + line if Flags.use_ansi(ctx) else line) await ctx.send( - content="\n".join([ - f"Executed {count} times", - f"Actual execution time: {execution_time}", - f"Active (non-waiting) time: {active_time}", - "**Delay will be added by async setup, use only for relative measurements**", - ]), + content="\n".join( + [ + f"Executed {count} times", + f"Actual execution time: {execution_time}", + f"Active (non-waiting) time: {active_time}", + "**Delay will be added by async setup, use only for relative measurements**", + ] + ), file=discord.File( filename="lines.ansi", - fp=io.BytesIO(''.join(lines).encode('utf-8')) - ) + fp=io.BytesIO("".join(lines).encode("utf-8")), + ), ) finally: @@ -374,16 +435,17 @@ async def jsk_disassemble(self, ctx: ContextA, *, argument: codeblock_converter) arg_dict = get_var_dict_from_ctx(ctx, Flags.SCOPE_PREFIX) - async with ReplResponseReactor(ctx.message): + async with ReplResponseReactor(ctx): text = "\n".join(disassemble(argument.content, arg_dict=arg_dict)) if use_file_check(ctx, len(text)): # File "full content" preview limit - await ctx.send(file=discord.File( - filename="dis.py", - fp=io.BytesIO(text.encode('utf-8')) - )) + await ctx.send( + file=discord.File( + filename="dis.py", fp=io.BytesIO(text.encode("utf-8")) + ) + ) else: - paginator = WrappedPaginator(prefix='```py', max_size=1980) + paginator = WrappedPaginator(prefix="```py", max_size=1980) paginator.add_line(text) @@ -399,7 +461,7 @@ async def jsk_ast(self, ctx: ContextA, *, argument: codeblock_converter): # typ if typing.TYPE_CHECKING: argument: Codeblock = argument # type: ignore - async with ReplResponseReactor(ctx.message): + async with ReplResponseReactor(ctx): text = create_tree(argument.content, use_ansi=Flags.use_ansi(ctx)) await ctx.send(file=discord.File( @@ -421,7 +483,7 @@ async def jsk_specialist(self, ctx: ContextA, *, argument: codeblock_converter): scope = self.scope try: - async with ReplResponseReactor(ctx.message): + async with ReplResponseReactor(ctx): with self.submit(ctx): executor = AsyncCodeExecutor(argument.content, scope, arg_dict=arg_dict, convertables=convertables) async for send, result in AsyncSender(executor): # type: ignore diff --git a/jishaku/features/root_command.py b/jishaku/features/root_command.py index e179e75a..48d89e0f 100644 --- a/jishaku/features/root_command.py +++ b/jishaku/features/root_command.py @@ -11,6 +11,7 @@ """ +import math import sys import typing from importlib.metadata import distribution, packages_distributions @@ -31,6 +32,20 @@ psutil = None +def natural_size(size_in_bytes: int): + """ + Converts a number of bytes to an appropriately-scaled unit + E.g.: + 1024 -> 1.00 KiB + 12345678 -> 11.77 MiB + """ + units = ('B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB') + + power = int(math.log(size_in_bytes, 1024)) + + return f"{size_in_bytes / (1024 ** power):.2f} {units[power]}" + + class RootCommand(Feature): """ Feature containing the root jsk command @@ -65,7 +80,7 @@ async def jsk(self, ctx: ContextA): dist_version = f'unknown `{discord.__version__}`' summary = [ - f"Jishaku v{package_version('jishaku')}, {dist_version}, " + f"Jishaku `v{package_version('jishaku')}`, {dist_version}, " f"`Python {sys.version}` on `{sys.platform}`".replace("\n", ""), f"Module was loaded , " f"cog was loaded .", diff --git a/jishaku/features/shell.py b/jishaku/features/shell.py index c5311294..cec0eb25 100644 --- a/jishaku/features/shell.py +++ b/jishaku/features/shell.py @@ -29,7 +29,7 @@ from jishaku.shell import ShellReader from jishaku.types import ContextA -SCAFFOLD_FOLDER = pathlib.Path(__file__).parent / 'scaffolds' +SCAFFOLD_FOLDER = pathlib.Path(__file__).parent / "scaffolds" @contextlib.contextmanager @@ -49,17 +49,17 @@ def scaffold(name: str, **kwargs: typing.Any): temp = pathlib.Path(temp) for item in source.glob("**/*"): - if '__pycache__' in str(item): + if "__pycache__" in str(item): continue if item.is_file(): - with open(item, 'r', encoding='utf-8') as fp: + with open(item, "r", encoding="utf-8") as fp: content = fp.read() target = temp / item.relative_to(source) target.parent.mkdir(parents=True, exist_ok=True) - with open(target, 'w', encoding='utf-8') as fp: + with open(target, "w", encoding="utf-8") as fp: fp.write(content.format(**kwargs)) yield temp @@ -70,7 +70,11 @@ class ShellFeature(Feature): Feature containing the shell-related commands """ - @Feature.Command(parent="jsk", name="shell", aliases=["bash", "sh", "powershell", "ps1", "ps", "cmd", "terminal"]) + @Feature.Command( + parent="jsk", + name="shell", + aliases=["bash", "sh", "powershell", "ps1", "ps", "cmd", "terminal"], + ) async def jsk_shell(self, ctx: ContextA, *, argument: codeblock_converter): # type: ignore """ Executes statements in the system shell. @@ -82,9 +86,11 @@ async def jsk_shell(self, ctx: ContextA, *, argument: codeblock_converter): # t if typing.TYPE_CHECKING: argument: Codeblock = argument # type: ignore - async with ReplResponseReactor(ctx.message): + async with ReplResponseReactor(ctx): with self.submit(ctx): - with ShellReader(argument.content, escape_ansi=not Flags.use_ansi(ctx)) as reader: + with ShellReader( + argument.content, escape_ansi=not Flags.use_ansi(ctx) + ) as reader: prefix = "```" + reader.highlight paginator = WrappedPaginator(prefix=prefix, max_size=1975) @@ -124,10 +130,10 @@ async def jsk_pip(self, ctx: commands.Context, *, argument: codeblock_converter) location = pathlib.Path(sys.prefix) for test in ( - location / 'bin' / 'pip', - location / 'bin' / 'pip3', - location / 'Scripts' / 'pip.exe', - location / 'Scripts' / 'pip3.exe', + location / "bin" / "pip", + location / "bin" / "pip3", + location / "Scripts" / "pip.exe", + location / "Scripts" / "pip3.exe", ): if test.exists() and test.is_file(): executable = str(test) @@ -135,7 +141,8 @@ async def jsk_pip(self, ctx: commands.Context, *, argument: codeblock_converter) return await ctx.invoke(self.jsk_shell, argument=Codeblock(argument.language, f"{executable} {argument.content}")) # type: ignore - if shutil.which('node') and shutil.which('npm'): + if shutil.which("node") and shutil.which("npm"): + @Feature.Command(parent="jsk", name="node") async def jsk_node(self, ctx: commands.Context, *, argument: codeblock_converter): # type: ignore """ @@ -145,12 +152,16 @@ async def jsk_node(self, ctx: commands.Context, *, argument: codeblock_converter if typing.TYPE_CHECKING: argument: Codeblock = argument # type: ignore - requirements = ''.join(f"npm install {match} && " for match in re.findall('// jsk require: (.+)', argument.content)) + requirements = "".join( + f"npm install {match} && " + for match in re.findall("// jsk require: (.+)", argument.content) + ) - with scaffold('npm', content=argument.content) as directory: + with scaffold("npm", content=argument.content) as directory: return await ctx.invoke(self.jsk_shell, argument=Codeblock("js", f"cd {directory} && {requirements}npm run main")) # type: ignore - if shutil.which('pyright'): + if shutil.which("pyright"): + @Feature.Command(parent="jsk", name="pyright") async def jsk_pyright(self, ctx: commands.Context, *, argument: codeblock_converter): # type: ignore """ @@ -160,10 +171,11 @@ async def jsk_pyright(self, ctx: commands.Context, *, argument: codeblock_conver if typing.TYPE_CHECKING: argument: Codeblock = argument # type: ignore - with scaffold('pyright', content=argument.content) as directory: + with scaffold("pyright", content=argument.content) as directory: return await ctx.invoke(self.jsk_shell, argument=Codeblock("js", f"cd {directory} && pyright main.py")) # type: ignore - if shutil.which('rustc') and shutil.which('cargo'): + if shutil.which("rustc") and shutil.which("cargo"): + @Feature.Command(parent="jsk", name="rustc") async def jsk_rustc(self, ctx: commands.Context, *, argument: codeblock_converter): # type: ignore """ @@ -173,7 +185,11 @@ async def jsk_rustc(self, ctx: commands.Context, *, argument: codeblock_converte if typing.TYPE_CHECKING: argument: Codeblock = argument # type: ignore - requirements = '\n'.join(re.findall('// jsk require: (.+)', argument.content)) + requirements = "\n".join( + re.findall("// jsk require: (.+)", argument.content) + ) - with scaffold('cargo', content=argument.content, requirements=requirements) as directory: + with scaffold( + "cargo", content=argument.content, requirements=requirements + ) as directory: return await ctx.invoke(self.jsk_shell, argument=Codeblock("rust", f"cd {directory} && cargo run")) # type: ignore diff --git a/jishaku/features/sql.py b/jishaku/features/sql.py index 4470101b..f45c99ad 100644 --- a/jishaku/features/sql.py +++ b/jishaku/features/sql.py @@ -375,7 +375,7 @@ async def jsk_sql_fetchrow(self, ctx: ContextA, *, query: str): output = None async with adapter_shim.use(): - async with ReplResponseReactor(ctx.message): + async with ReplResponseReactor(ctx): with self.submit(ctx): output = await adapter_shim.fetchrow(query) @@ -412,7 +412,7 @@ async def jsk_sql_fetch(self, ctx: ContextA, *, query: str): output = None async with adapter_shim.use(): - async with ReplResponseReactor(ctx.message): + async with ReplResponseReactor(ctx): with self.submit(ctx): output = await adapter_shim.fetch(query) @@ -463,7 +463,7 @@ async def jsk_sql_execute(self, ctx: ContextA, *, query: str): output = None async with adapter_shim.use(): - async with ReplResponseReactor(ctx.message): + async with ReplResponseReactor(ctx): with self.submit(ctx): output = await adapter_shim.execute(query) @@ -485,7 +485,7 @@ async def jsk_sql_schema(self, ctx: ContextA, *, query: typing.Optional[str] = N output = None async with adapter_shim.use(): - async with ReplResponseReactor(ctx.message): + async with ReplResponseReactor(ctx): with self.submit(ctx): output = await adapter_shim.table_summary(query) diff --git a/jishaku/flags.py b/jishaku/flags.py index 92b0cc6c..e119ffd1 100644 --- a/jishaku/flags.py +++ b/jishaku/flags.py @@ -24,7 +24,7 @@ DISABLED_SYMBOLS = ("false", "f", "no", "n", "off", "0") -FlagHandler = typing.Optional[typing.Callable[['FlagMeta'], typing.Any]] +FlagHandler = typing.Optional[typing.Callable[["FlagMeta"], typing.Any]] @dataclasses.dataclass @@ -39,7 +39,9 @@ class Flag: handler: FlagHandler = None override: typing.Any = None - def resolve_raw(self, flags: 'FlagMeta'): # pylint: disable=too-many-return-statements + def resolve_raw( + self, flags: "FlagMeta" + ): # pylint: disable=too-many-return-statements """ Receive the intrinsic value for this flag, before optionally being processed by the handler. """ @@ -69,7 +71,7 @@ def resolve_raw(self, flags: 'FlagMeta'): # pylint: disable=too-many-return-sta return self.flag_type() - def resolve(self, flags: 'FlagMeta'): + def resolve(self, flags: "FlagMeta"): """ Resolve this flag. Only for internal use. Applies the handler when there is one. @@ -93,11 +95,11 @@ def __new__( cls, name: str, base: typing.Tuple[typing.Type[typing.Any]], - attrs: typing.Dict[str, typing.Any] + attrs: typing.Dict[str, typing.Any], ): - attrs['flag_map'] = {} + attrs["flag_map"] = {} - for flag_name, flag_type in attrs['__annotations__'].items(): + for flag_name, flag_type in attrs["__annotations__"].items(): default: typing.Union[ FlagHandler, typing.Tuple[ @@ -110,14 +112,14 @@ def __new__( if isinstance(default, tuple): default, handler = default - attrs['flag_map'][flag_name] = Flag(flag_name, flag_type, default, handler) + attrs["flag_map"][flag_name] = Flag(flag_name, flag_type, default, handler) return super(FlagMeta, cls).__new__(cls, name, base, attrs) def __getattr__(cls, name: str): cls.flag_map: typing.Dict[str, Flag] - if hasattr(cls, 'flag_map') and name in cls.flag_map: + if hasattr(cls, "flag_map") and name in cls.flag_map: return cls.flag_map[name].resolve(cls) return super().__getattribute__(name) @@ -127,7 +129,9 @@ def __setattr__(cls, name: str, value: typing.Any): flag = cls.flag_map[name] if not isinstance(value, flag.flag_type): - raise ValueError(f"Attempted to set flag {name} to type {type(value).__name__} (should be {flag.flag_type.__name__})") + raise ValueError( + f"Attempted to set flag {name} to type {type(value).__name__} (should be {flag.flag_type.__name__})" + ) flag.override = value else: @@ -155,7 +159,7 @@ class Flags(metaclass=FlagMeta): # pylint: disable=too-few-public-methods # The scope prefix, i.e. the prefix that appears before Jishaku's builtin variables in REPL sessions. # It is recommended that you set this programatically. - SCOPE_PREFIX: str = lambda flags: '' if flags.NO_UNDERSCORE else '_' # type: ignore + SCOPE_PREFIX: str = lambda flags: "" if flags.NO_UNDERSCORE else "_" # type: ignore # Flag to indicate whether to always use paginators over relying on Discord's file preview FORCE_PAGINATOR: bool @@ -168,17 +172,19 @@ class Flags(metaclass=FlagMeta): # pylint: disable=too-few-public-methods ALWAYS_DM_TRACEBACK: bool @classmethod - def traceback_destination(cls, message: discord.Message) -> typing.Optional[discord.abc.Messageable]: + def traceback_destination( + cls, ctx: ContextA + ) -> typing.Optional[discord.abc.Messageable]: """ Determine what 'default' location to send tracebacks to When None, the caller should decide """ if cls.ALWAYS_DM_TRACEBACK: - return message.author + return ctx.author if cls.NO_DM_TRACEBACK: - return message.channel + return ctx # Otherwise let the caller decide return None @@ -205,4 +211,8 @@ def use_ansi(cls, ctx: ContextA) -> bool: if cls.USE_ANSI_ALWAYS: return True - return not ctx.author.is_on_mobile() if isinstance(ctx.author, discord.Member) and ctx.bot.intents.presences else True + return ( + not ctx.author.is_on_mobile() + if isinstance(ctx.author, discord.Member) and ctx.bot.intents.presences + else True + )