From 48933a1bf9217c2664be18fb699feb0802537bde Mon Sep 17 00:00:00 2001 From: Brian Pugh Date: Wed, 21 Aug 2024 09:16:24 -0400 Subject: [PATCH 1/2] Use last command's help/version flags. --- cyclopts/core.py | 63 ++++++++++++++++++++------------- cyclopts/help.py | 2 +- tests/apps/test_burgery.py | 2 +- tests/test_version_parameter.py | 33 +++++++++++++++++ 4 files changed, 74 insertions(+), 26 deletions(-) create mode 100644 tests/test_version_parameter.py diff --git a/cyclopts/core.py b/cyclopts/core.py index 1899c50b..ec43387b 100644 --- a/cyclopts/core.py +++ b/cyclopts/core.py @@ -5,6 +5,7 @@ from contextlib import suppress from copy import copy from functools import partial +from itertools import chain from pathlib import Path from typing import ( TYPE_CHECKING, @@ -291,8 +292,10 @@ class App: def __attrs_post_init__(self): # Trigger the setters - self.help_flags = self._help_flags - self.version_flags = self._version_flags + func = getattr(self.default_command, "__func__", None) + if func != type(self).version_print and func != type(self).help_print: + self.help_flags = self._help_flags + self.version_flags = self._version_flags ########### # Methods # @@ -311,7 +314,7 @@ def _delete_commands(self, commands: Iterable[str], default=None): for command in commands: with suppress(KeyError): if default: - if self[command].default == self.version_print: + if self[command].default == default: del self[command] else: del self[command] @@ -324,14 +327,18 @@ def version_flags(self): def version_flags(self, value): self._version_flags = value self._delete_commands(self._version_flags, default=self.version_print) - if self._version_flags: + func = getattr(self.default_command, "__func__", None) + if self._version_flags and func != type(self).version_print: + assert isinstance(self._version_flags, tuple) self.command( - self.version_print, - name=self._version_flags, - help_flags=[], - version_flags=[], - version=self.version, - help="Display application version.", + App( + name=self._version_flags, + default_command=self.version_print, + help_flags=self.help_flags, + version_flags=self.version_flags, + version=self.version, + help="Display application version.", + ) ) @property @@ -342,14 +349,18 @@ def help_flags(self): def help_flags(self, value): self._help_flags = value self._delete_commands(self._help_flags, default=self.help_print) - if self._help_flags: + func = getattr(self.default_command, "__func__", None) + if self._help_flags and func != type(self).help_print: + assert isinstance(self._help_flags, tuple) self.command( - self.help_print, - name=self._help_flags, - help_flags=[], - version_flags=[], - version=self.version, - help="Display this message and exit.", + App( + name=self._help_flags, + default_command=self.help_print, + help_flags=self.help_flags, + version_flags=self.version_flags, + version=self.version, + help="Display this message and exit.", + ) ) @property @@ -460,8 +471,8 @@ def __iter__(self) -> Iterator[str]: def meta(self) -> "App": if self._meta is None: self._meta = type(self)( - help_flags=copy(self.help_flags), - version_flags=copy(self.version_flags), + help_flags=self.help_flags, + version_flags=self.version_flags, group_commands=copy(self.group_commands), group_arguments=copy(self.group_arguments), group_parameters=copy(self.group_parameters), @@ -571,8 +582,10 @@ def command( raise ValueError("Cannot supplied additional configuration when registering a sub-App.") else: validate_command(obj) - kwargs.setdefault("help_flags", []) - kwargs.setdefault("version_flags", []) + + kwargs.setdefault("help_flags", self.help_flags) + kwargs.setdefault("version_flags", self.version_flags) + if "group_commands" not in kwargs: kwargs["group_commands"] = copy(self.group_commands) if "group_parameters" not in kwargs: @@ -580,7 +593,9 @@ def command( if "group_arguments" not in kwargs: kwargs["group_arguments"] = copy(self.group_arguments) app = App(default_command=obj, **kwargs) # pyright: ignore - # app.name is handled below + + for flag in chain(kwargs["help_flags"], kwargs["version_flags"]): # pyright: ignore + app[flag].show = False if app._name_transform is None: app.name_transform = self.name_transform @@ -704,7 +719,7 @@ def parse_known_args( # Special flags (help/version) get intercepted by the root app. # Special flags are allows to be **anywhere** in the token stream. - for help_flag in self.help_flags: + for help_flag in command_app.help_flags: try: help_flag_index = tokens.index(help_flag) break @@ -720,7 +735,7 @@ def parse_known_args( command = meta_parent.help_print bound = cyclopts.utils.signature(command).bind(tokens, console=console) unused_tokens = [] - elif any(flag in tokens for flag in self.version_flags): + elif any(flag in tokens for flag in command_app.version_flags): # Version command = self.version_print while meta_parent := meta_parent._meta_parent: diff --git a/cyclopts/help.py b/cyclopts/help.py index b88198dc..eccdc44b 100644 --- a/cyclopts/help.py +++ b/cyclopts/help.py @@ -183,7 +183,7 @@ def format_usage( for command in command_chain: app = app[command] - if app._commands: + if any(x.show for x in app._commands.values()): usage.append("COMMAND") if app.default_command: diff --git a/tests/apps/test_burgery.py b/tests/apps/test_burgery.py index 5bf91516..083ceb40 100644 --- a/tests/apps/test_burgery.py +++ b/tests/apps/test_burgery.py @@ -143,4 +143,4 @@ def test_create_burger_3(): if __name__ == "__main__": - test_create_burger_1() + app() diff --git a/tests/test_version_parameter.py b/tests/test_version_parameter.py new file mode 100644 index 00000000..5aa4a558 --- /dev/null +++ b/tests/test_version_parameter.py @@ -0,0 +1,33 @@ +"""From issue #219.""" + +import pytest + + +@pytest.mark.parametrize( + "cmd", + [ + "foo --version 1.2.3", + "foo --version=1.2.3", + ], +) +def test_version_subapp_version_parameter(app, assert_parse_args, cmd): + @app.command(version_flags=[]) + def foo(version: str): + pass + + assert_parse_args(foo, cmd, version="1.2.3") + + +@pytest.mark.parametrize( + "cmd", + [ + "foo --help 1.2.3", + "foo --help=1.2.3", + ], +) +def test_version_subapp_help_parameter(app, assert_parse_args, cmd): + @app.command(help_flags=[]) + def foo(help: str): + pass + + assert_parse_args(foo, cmd, help="1.2.3") From 8035bccf4f850c4b408c86dadcd0036a0f554abe Mon Sep 17 00:00:00 2001 From: Brian Pugh Date: Tue, 27 Aug 2024 09:24:14 -0400 Subject: [PATCH 2/2] simpler --- cyclopts/core.py | 48 ++++++++++++++++++++++-------------------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/cyclopts/core.py b/cyclopts/core.py index ec43387b..b74c9179 100644 --- a/cyclopts/core.py +++ b/cyclopts/core.py @@ -292,10 +292,8 @@ class App: def __attrs_post_init__(self): # Trigger the setters - func = getattr(self.default_command, "__func__", None) - if func != type(self).version_print and func != type(self).help_print: - self.help_flags = self._help_flags - self.version_flags = self._version_flags + self.help_flags = self._help_flags + self.version_flags = self._version_flags ########### # Methods # @@ -327,18 +325,14 @@ def version_flags(self): def version_flags(self, value): self._version_flags = value self._delete_commands(self._version_flags, default=self.version_print) - func = getattr(self.default_command, "__func__", None) - if self._version_flags and func != type(self).version_print: - assert isinstance(self._version_flags, tuple) + if self._version_flags: self.command( - App( - name=self._version_flags, - default_command=self.version_print, - help_flags=self.help_flags, - version_flags=self.version_flags, - version=self.version, - help="Display application version.", - ) + self.version_print, + name=self._version_flags, + help_flags=[], + version_flags=[], + version=self.version, + help="Display application version.", ) @property @@ -349,18 +343,14 @@ def help_flags(self): def help_flags(self, value): self._help_flags = value self._delete_commands(self._help_flags, default=self.help_print) - func = getattr(self.default_command, "__func__", None) - if self._help_flags and func != type(self).help_print: - assert isinstance(self._help_flags, tuple) + if self._help_flags: self.command( - App( - name=self._help_flags, - default_command=self.help_print, - help_flags=self.help_flags, - version_flags=self.version_flags, - version=self.version, - help="Display this message and exit.", - ) + self.help_print, + name=self._help_flags, + help_flags=[], + version_flags=[], + version=self.version, + help="Display this message and exit.", ) @property @@ -704,6 +694,12 @@ def parse_known_args( command_chain, apps, unused_tokens = self.parse_commands(tokens) command_app = apps[-1] + # We don't want the command_app to be the version/help handler. + with suppress(IndexError): + if set(command_app.name) & set(apps[-2].help_flags + apps[-2].version_flags): # pyright: ignore + apps = apps[:-1] + command_app = apps[-1] + try: parent_app = apps[-2] except IndexError: