diff --git a/qubesadmin/tools/qvm_check.py b/qubesadmin/tools/qvm_check.py index bf7fc9da..3c56671e 100644 --- a/qubesadmin/tools/qvm_check.py +++ b/qubesadmin/tools/qvm_check.py @@ -27,26 +27,49 @@ import qubesadmin.tools import qubesadmin.vm -class QvmCheckArgumentParser(qubesadmin.tools.QubesArgumentParser): - """Collecting error message(s) on invalid domain(s) instead of aborting""" - def __init__(self, description, vmname_nargs): - super().__init__(description=description, vmname_nargs=vmname_nargs) - self._invalid_domains = [] - - def error(self, message): - if message.startswith('no such domain: '): - self._invalid_domains.append(message[17:-1]) +class QvmCheckVmNameAction(qubesadmin.tools.VmNameAction): + """ Action for parsing one or multiple valid/invalid VMNAMEs """ + + def __init__(self, option_strings, nargs=1, dest='vmnames', help=None, + **kwargs): + # pylint: disable=redefined-builtin + super().__init__(option_strings, dest=dest, help=help, + nargs=nargs, **kwargs) + + def parse_qubes_app(self, parent_parser, namespace): + # pylint: disable=arguments-renamed + assert hasattr(namespace, 'app') + setattr(namespace, 'domains', []) + setattr(namespace, 'invalid_domains', []) + app = namespace.app + if hasattr(namespace, 'all_domains') and namespace.all_domains: + namespace.domains = [ + vm + for vm in app.domains + if not vm.klass == 'AdminVM' and + vm.name not in namespace.exclude + ] else: - super().error(message) + if hasattr(namespace, 'exclude') and namespace.exclude: + parent_parser.error('--exclude can only be used with --all') + + for vm_name in getattr(namespace, self.dest): + try: + namespace.domains += [app.domains[vm_name]] + except KeyError: + namespace.invalid_domains += [vm_name] - def parse_args(self, *args, **kwargs): - parse_args = super().parse_args(*args, **kwargs) - self._invalid_domains.sort() - parse_args.invalid_domains = self._invalid_domains - return parse_args + +class QvmCheckArgumentParser(qubesadmin.tools.QubesArgumentParser): + """ Extended argument parser for qvm-check to collect invalid domains """ + def __init__(self, description): + super().__init__(description=description, vmname_nargs=None) + vm_name_group = qubesadmin.tools.VmNameGroup( + self, required='+', vm_action=QvmCheckVmNameAction) + self._mutually_exclusive_groups.append(vm_name_group) -parser = QvmCheckArgumentParser(description=__doc__, vmname_nargs='+') +parser = QvmCheckArgumentParser(description=__doc__) parser.add_argument("--running", action="store_true", dest="running", default=False, help="Determine if (any of given) VM is running") @@ -92,6 +115,7 @@ def main(args=None, app=None): """Main function of qvm-check tool""" args = parser.parse_args(args, app=app) domains = args.domains + invalid_domains = set(args.invalid_domains) return_code = 0 log = args.app.log @@ -121,9 +145,9 @@ def main(args=None, app=None): elif args.verbose: print_msg(log, domains, ["exists"]) - if args.invalid_domains: + if invalid_domains: if args.verbose: - for vm in args.invalid_domains: + for vm in invalid_domains: log.warning("{!s}: {!s}".format(vm, 'non-existent!')) return_code = 1