diff --git a/src/azure-cli-core/azure/cli/core/__init__.py b/src/azure-cli-core/azure/cli/core/__init__.py index 2d750bab8fd..0cdb575ed53 100644 --- a/src/azure-cli-core/azure/cli/core/__init__.py +++ b/src/azure-cli-core/azure/cli/core/__init__.py @@ -30,6 +30,8 @@ EXCLUDED_PARAMS = ['self', 'raw', 'polling', 'custom_headers', 'operation_config', 'content_version', 'kwargs', 'client', 'no_wait'] EVENT_FAILED_EXTENSION_LOAD = 'MainLoader.OnFailedExtensionLoad' +# Extensions that will always be loaded if installed. These extensions don't expose commands but hook into CLI core. +ALWAYS_LOADED_EXTENSION_MODNAMES = ['azext_ai_examples', 'azext_ai_did_you_mean_this'] class AzCli(CLI): @@ -42,7 +44,7 @@ def __init__(self, **kwargs): register_ids_argument, register_global_subscription_argument) from azure.cli.core.cloud import get_active_cloud from azure.cli.core.commands.transform import register_global_transforms - from azure.cli.core._session import ACCOUNT, CONFIG, SESSION + from azure.cli.core._session import ACCOUNT, CONFIG, SESSION, INDEX from knack.util import ensure_dir @@ -57,6 +59,8 @@ def __init__(self, **kwargs): ACCOUNT.load(os.path.join(azure_folder, 'azureProfile.json')) CONFIG.load(os.path.join(azure_folder, 'az.json')) SESSION.load(os.path.join(azure_folder, 'az.sess'), max_age=3600) + INDEX.load(os.path.join(azure_folder, 'commandIndex.json')) + self.cloud = get_active_cloud(self) logger.debug('Current cloud config:\n%s', str(self.cloud.name)) self.local_context = AzCLILocalContext(self) @@ -148,6 +152,12 @@ def save_local_context(self, parsed_args, argument_definitions, specified_argume class MainCommandsLoader(CLICommandsLoader): + # Format string for pretty-print the command module table + header_mod = "%-20s %10s %9s %9s" % ("Extension", "Load Time", "Groups", "Commands") + item_format_string = "%-20s %10.3f %9d %9d" + header_ext = header_mod + " Directory" + item_ext_format_string = item_format_string + " %s" + def __init__(self, cli_ctx=None): super(MainCommandsLoader, self).__init__(cli_ctx) self.cmd_to_loader_map = {} @@ -160,33 +170,41 @@ def _update_command_definitions(self): loader.command_table = self.command_table loader._update_command_definitions() # pylint: disable=protected-access - # pylint: disable=too-many-statements + # pylint: disable=too-many-statements, too-many-locals def load_command_table(self, args): from importlib import import_module import pkgutil import traceback from azure.cli.core.commands import ( - _load_module_command_loader, _load_extension_command_loader, BLACKLISTED_MODS, ExtensionCommandSource) + _load_module_command_loader, _load_extension_command_loader, BLOCKED_MODS, ExtensionCommandSource) from azure.cli.core.extension import ( get_extensions, get_extension_path, get_extension_modname) - def _update_command_table_from_modules(args): + def _update_command_table_from_modules(args, command_modules=None): '''Loads command table(s) When `module_name` is specified, only commands from that module will be loaded. If the module is not found, all commands are loaded. ''' - installed_command_modules = [] - try: - mods_ns_pkg = import_module('azure.cli.command_modules') - installed_command_modules = [modname for _, modname, _ in - pkgutil.iter_modules(mods_ns_pkg.__path__) - if modname not in BLACKLISTED_MODS] - except ImportError as e: - logger.warning(e) - - logger.debug('Installed command modules %s', installed_command_modules) + + if not command_modules: + # Perform module discovery + command_modules = [] + try: + mods_ns_pkg = import_module('azure.cli.command_modules') + command_modules = [modname for _, modname, _ in + pkgutil.iter_modules(mods_ns_pkg.__path__)] + logger.debug('Discovered command modules: %s', command_modules) + except ImportError as e: + logger.warning(e) + + count = 0 cumulative_elapsed_time = 0 - for mod in [m for m in installed_command_modules if m not in BLACKLISTED_MODS]: + cumulative_group_count = 0 + cumulative_command_count = 0 + logger.debug("Loading command modules:") + logger.debug(self.header_mod) + + for mod in [m for m in command_modules if m not in BLOCKED_MODS]: try: start_time = timeit.default_timer() module_command_table, module_group_table = _load_module_command_loader(self, args, mod) @@ -194,9 +212,14 @@ def _update_command_table_from_modules(args): cmd.command_source = mod self.command_table.update(module_command_table) self.command_group_table.update(module_group_table) + elapsed_time = timeit.default_timer() - start_time - logger.debug("Loaded module '%s' in %.3f seconds.", mod, elapsed_time) + logger.debug(self.item_format_string, mod, elapsed_time, + len(module_group_table), len(module_command_table)) + count += 1 cumulative_elapsed_time += elapsed_time + cumulative_group_count += len(module_group_table) + cumulative_command_count += len(module_command_table) except Exception as ex: # pylint: disable=broad-except # Changing this error message requires updating CI script that checks for failed # module loading. @@ -205,11 +228,12 @@ def _update_command_table_from_modules(args): telemetry.set_exception(exception=ex, fault_type='module-load-error-' + mod, summary='Error loading module: {}'.format(mod)) logger.debug(traceback.format_exc()) - logger.debug("Loaded all modules in %.3f seconds. " - "(note: there's always an overhead with the first module loaded)", - cumulative_elapsed_time) + # Summary line + logger.debug(self.item_format_string, + "Total ({})".format(count), cumulative_elapsed_time, + cumulative_group_count, cumulative_command_count) - def _update_command_table_from_extensions(ext_suppressions): + def _update_command_table_from_extensions(ext_suppressions, extension_modname=None): from azure.cli.core.extension.operations import check_version_compatibility @@ -224,11 +248,33 @@ def _handle_extension_suppressions(extensions): filtered_extensions.append(ext) return filtered_extensions + def _filter_modname(extensions): + # Extension's name may not be the same as its modname. eg. name: virtual-wan, modname: azext_vwan + filtered_extensions = [] + extension_modname.extend(ALWAYS_LOADED_EXTENSION_MODNAMES) + for ext in extensions: + ext_name = ext.name + ext_dir = ext.path or get_extension_path(ext.name) + ext_mod = get_extension_modname(ext_name, ext_dir=ext_dir) + # Filter the extensions according to the index + if ext_mod in extension_modname: + filtered_extensions.append(ext) + return filtered_extensions + extensions = get_extensions() if extensions: - logger.debug("Found %s extensions: %s", len(extensions), [e.name for e in extensions]) + if extension_modname: + extensions = _filter_modname(extensions) allowed_extensions = _handle_extension_suppressions(extensions) module_commands = set(self.command_table.keys()) + + count = 0 + cumulative_elapsed_time = 0 + cumulative_group_count = 0 + cumulative_command_count = 0 + logger.debug("Loading extensions:") + logger.debug(self.header_ext) + for ext in allowed_extensions: try: check_version_compatibility(ext.get_metadata()) @@ -238,7 +284,6 @@ def _handle_extension_suppressions(extensions): continue ext_name = ext.name ext_dir = ext.path or get_extension_path(ext_name) - logger.debug("Extensions directory: '%s'", ext_dir) sys.path.append(ext_dir) try: ext_mod = get_extension_modname(ext_name, ext_dir=ext_dir) @@ -258,13 +303,24 @@ def _handle_extension_suppressions(extensions): self.command_table.update(extension_command_table) self.command_group_table.update(extension_group_table) + elapsed_time = timeit.default_timer() - start_time - logger.debug("Loaded extension '%s' in %.3f seconds.", ext_name, elapsed_time) + logger.debug(self.item_ext_format_string, ext_name, elapsed_time, + len(extension_group_table), len(extension_command_table), + ext_dir) + count += 1 + cumulative_elapsed_time += elapsed_time + cumulative_group_count += len(extension_group_table) + cumulative_command_count += len(extension_command_table) except Exception as ex: # pylint: disable=broad-except self.cli_ctx.raise_event(EVENT_FAILED_EXTENSION_LOAD, extension_name=ext_name) logger.warning("Unable to load extension '%s: %s'. Use --debug for more information.", ext_name, ex) logger.debug(traceback.format_exc()) + # Summary line + logger.debug(self.item_ext_format_string, + "Total ({})".format(count), cumulative_elapsed_time, + cumulative_group_count, cumulative_command_count, "") def _wrap_suppress_extension_func(func, ext): """ Wrapper method to handle centralization of log messages for extension filters """ @@ -295,15 +351,60 @@ def _get_extension_suppressions(mod_loaders): res.append(sup) return res + def _roughly_parse_command(args): + # Roughly parse the command part: --name vm1 + # Similar to knack.invocation.CommandInvoker._rudimentary_get_command, but we don't need to bother with + # positional args + nouns = [] + for arg in args: + if arg and arg[0] != '-': + nouns.append(arg) + else: + break + return ' '.join(nouns).lower() + + # Clear the tables to make this method idempotent + self.command_group_table.clear() + self.command_table.clear() + + command_index = None + # Set fallback=False to turn off command index in case of regression + use_command_index = self.cli_ctx.config.getboolean('core', 'use_command_index', fallback=True) + if use_command_index: + command_index = CommandIndex(self.cli_ctx) + index_result = command_index.get(args) + if index_result: + index_modules, index_extensions = index_result + if index_modules: + _update_command_table_from_modules(args, index_modules) + if index_extensions: + # The index won't contain suppressed extensions + _update_command_table_from_extensions([], index_extensions) + + logger.debug("Loaded %d groups, %d commands.", len(self.command_group_table), len(self.command_table)) + # The index may be outdated. Make sure the command appears in the loaded command table + command_str = _roughly_parse_command(args) + if command_str in self.command_table or command_str in self.command_group_table: + logger.debug("Found a match in the command table for '%s'", command_str) + return self.command_table + + logger.debug("Could not find a match in the command table for '%s'. The index may be outdated", + command_str) + else: + logger.debug("No module found from index for '%s'", args) + + # No module found from the index. Load all command modules and extensions + logger.debug("Loading all modules and extensions") _update_command_table_from_modules(args) - try: - ext_suppressions = _get_extension_suppressions(self.loaders) - # We always load extensions even if the appropriate module has been loaded - # as an extension could override the commands already loaded. - _update_command_table_from_extensions(ext_suppressions) - except Exception: # pylint: disable=broad-except - logger.warning("Unable to load extensions. Use --debug for more information.") - logger.debug(traceback.format_exc()) + + ext_suppressions = _get_extension_suppressions(self.loaders) + # We always load extensions even if the appropriate module has been loaded + # as an extension could override the commands already loaded. + _update_command_table_from_extensions(ext_suppressions) + logger.debug("Loaded %d groups, %d commands.", len(self.command_group_table), len(self.command_table)) + + if use_command_index: + command_index.update(self.command_table) return self.command_table @@ -348,6 +449,111 @@ def load_arguments(self, command=None): loader._update_command_definitions() # pylint: disable=protected-access +class CommandIndex: + + _COMMAND_INDEX = 'commandIndex' + _COMMAND_INDEX_VERSION = 'version' + _COMMAND_INDEX_CLOUD_PROFILE = 'cloudProfile' + + def __init__(self, cli_ctx=None): + """Class to manage command index. + + :param cli_ctx: Only needed when `get` or `update` is called. + """ + from azure.cli.core._session import INDEX + self.INDEX = INDEX + if cli_ctx: + self.version = __version__ + self.cloud_profile = cli_ctx.cloud.profile + + def get(self, args): + """Get the corresponding module and extension list of a command. + + :param args: command arguments, like ['network', 'vnet', 'create', '-h'] + :return: a tuple containing a list of modules and a list of extensions. + """ + # If the command index version or cloud profile doesn't match those of the current command, + # invalidate the command index. + index_version = self.INDEX[self._COMMAND_INDEX_VERSION] + cloud_profile = self.INDEX[self._COMMAND_INDEX_CLOUD_PROFILE] + if not (index_version and index_version == self.version and + cloud_profile and cloud_profile == self.cloud_profile): + logger.debug("Command index version or cloud profile is invalid or doesn't match the current command.") + self.invalidate() + return None + + # Make sure the top-level command is provided, like `az version`. + # Skip command index for `az` or `az --help`. + if not args or args[0].startswith('-'): + return None + + # Get the top-level command, like `network` in `network vnet create -h` + top_command = args[0] + index = self.INDEX[self._COMMAND_INDEX] + # Check the command index for (command: [module]) mapping, like + # "network": ["azure.cli.command_modules.natgateway", "azure.cli.command_modules.network", "azext_firewall"] + index_modules_extensions = index.get(top_command) + + if index_modules_extensions: + # This list contains both built-in modules and extensions + index_builtin_modules = [] + index_extensions = [] + # Found modules from index + logger.debug("Modules found from index for '%s': %s", top_command, index_modules_extensions) + command_module_prefix = 'azure.cli.command_modules.' + for m in index_modules_extensions: + if m.startswith(command_module_prefix): + # The top-level command is from a command module + index_builtin_modules.append(m[len(command_module_prefix):]) + elif m.startswith('azext_'): + # The top-level command is from an extension + index_extensions.append(m) + else: + logger.warning("Unrecognized module: %s", m) + return index_builtin_modules, index_extensions + + return None + + def update(self, command_table): + """Update the command index according to the given command table. + + :param command_table: The command table built by azure.cli.core.MainCommandsLoader.load_command_table + """ + start_time = timeit.default_timer() + self.INDEX[self._COMMAND_INDEX_VERSION] = __version__ + self.INDEX[self._COMMAND_INDEX_CLOUD_PROFILE] = self.cloud_profile + from collections import defaultdict + index = defaultdict(list) + + # self.cli_ctx.invocation.commands_loader.command_table doesn't exist in DummyCli due to the lack of invocation + for command_name, command in command_table.items(): + # Get the top-level name: create + top_command = command_name.split()[0] + # Get module name, like azure.cli.command_modules.vm, azext_webapp + module_name = command.loader.__module__ + if module_name not in index[top_command]: + index[top_command].append(module_name) + elapsed_time = timeit.default_timer() - start_time + self.INDEX[self._COMMAND_INDEX] = index + logger.debug("Updated command index in %.3f seconds.", elapsed_time) + + def invalidate(self): + """Invalidate the command index. + + This function MUST be called when installing or updating extensions. Otherwise, when an extension + 1. overrides a built-in command, or + 2. extends an existing command group, + the command or command group will only be loaded from the command modules as per the stale command index, + making the newly installed extension be ignored. + + This function can be called when removing extensions. + """ + self.INDEX[self._COMMAND_INDEX_VERSION] = "" + self.INDEX[self._COMMAND_INDEX_CLOUD_PROFILE] = "" + self.INDEX[self._COMMAND_INDEX] = {} + logger.debug("Command index has been invalidated.") + + class ModExtensionSuppress(object): # pylint: disable=too-few-public-methods def __init__(self, mod_name, suppress_extension_name, suppress_up_to_version, reason=None, recommend_remove=False, diff --git a/src/azure-cli-core/azure/cli/core/_session.py b/src/azure-cli-core/azure/cli/core/_session.py index afd406cbd9d..0df29ceb768 100644 --- a/src/azure-cli-core/azure/cli/core/_session.py +++ b/src/azure-cli-core/azure/cli/core/_session.py @@ -105,6 +105,9 @@ def __len__(self): # SESSION provides read-write session variables SESSION = Session() +# INDEX contains {top-level command: [command_modules and extensions]} mapping index +INDEX = Session() + # VERSIONS provides local versions and pypi versions. # DO NOT USE it to get the current version of azure-cli, # it could be lagged behind and can be used to check whether diff --git a/src/azure-cli-core/azure/cli/core/commands/__init__.py b/src/azure-cli-core/azure/cli/core/commands/__init__.py index e7def95839a..9286637258f 100644 --- a/src/azure-cli-core/azure/cli/core/commands/__init__.py +++ b/src/azure-cli-core/azure/cli/core/commands/__init__.py @@ -21,7 +21,7 @@ # pylint: disable=unused-import from azure.cli.core.commands.constants import ( - BLACKLISTED_MODS, DEFAULT_QUERY_TIME_RANGE, CLI_COMMON_KWARGS, CLI_COMMAND_KWARGS, CLI_PARAM_KWARGS, + BLOCKED_MODS, DEFAULT_QUERY_TIME_RANGE, CLI_COMMON_KWARGS, CLI_COMMAND_KWARGS, CLI_PARAM_KWARGS, CLI_POSITIONAL_PARAM_KWARGS, CONFIRM_PARAM_NAME) from azure.cli.core.commands.parameters import ( AzArgumentContext, patch_arg_make_required, patch_arg_make_optional) diff --git a/src/azure-cli-core/azure/cli/core/commands/constants.py b/src/azure-cli-core/azure/cli/core/commands/constants.py index 7284e793524..1886a0c884b 100644 --- a/src/azure-cli-core/azure/cli/core/commands/constants.py +++ b/src/azure-cli-core/azure/cli/core/commands/constants.py @@ -30,7 +30,7 @@ # 1 hour in milliseconds DEFAULT_QUERY_TIME_RANGE = 3600000 -BLACKLISTED_MODS = ['context', 'shell', 'documentdb', 'component'] +BLOCKED_MODS = ['context', 'shell', 'documentdb', 'component'] SURVEY_PROMPT = 'Please let us know how we are doing: https://aka.ms/azureclihats' SURVEY_PROMPT_COLOR = Fore.YELLOW + Style.BRIGHT + 'Please let us know how we are doing: ' + Fore.BLUE + \ diff --git a/src/azure-cli-core/azure/cli/core/extension/operations.py b/src/azure-cli-core/azure/cli/core/extension/operations.py index 53612cdd632..8a929e24d63 100644 --- a/src/azure-cli-core/azure/cli/core/extension/operations.py +++ b/src/azure-cli-core/azure/cli/core/extension/operations.py @@ -18,6 +18,7 @@ import requests from pkg_resources import parse_version +from azure.cli.core import CommandIndex from azure.cli.core.util import CLIError, reload_module from azure.cli.core.extension import (extension_exists, build_extension_path, get_extensions, get_extension_modname, get_extension, ext_compat_with_cli, @@ -188,8 +189,8 @@ def _augment_telemetry_with_ext_info(extension_name, ext=None): def check_version_compatibility(azext_metadata): is_compatible, cli_core_version, min_required, max_required = ext_compat_with_cli(azext_metadata) - logger.debug("Extension compatibility result: is_compatible=%s cli_core_version=%s min_required=%s " - "max_required=%s", is_compatible, cli_core_version, min_required, max_required) + # logger.debug("Extension compatibility result: is_compatible=%s cli_core_version=%s min_required=%s " + # "max_required=%s", is_compatible, cli_core_version, min_required, max_required) if not is_compatible: min_max_msg_fmt = "The '{}' extension is not compatible with this version of the CLI.\n" \ "You have CLI core version {} and this extension " \ @@ -244,6 +245,7 @@ def add_extension(cmd, source=None, extension_name=None, index_url=None, yes=Non "Please use with discretion.", extension_name) elif extension_name and ext.preview: logger.warning("The installed extension '%s' is in preview.", extension_name) + CommandIndex().invalidate() except ExtensionNotInstalledException: pass @@ -263,6 +265,7 @@ def log_err(func, path, exc_info): # We call this just before we remove the extension so we can get the metadata before it is gone _augment_telemetry_with_ext_info(extension_name, ext) shutil.rmtree(ext.path, onerror=log_err) + CommandIndex().invalidate() except ExtensionNotInstalledException as e: raise CLIError(e) @@ -315,6 +318,7 @@ def update_extension(cmd, extension_name, index_url=None, pip_extra_index_urls=N logger.debug('Copying %s to %s', backup_dir, extension_path) shutil.copytree(backup_dir, extension_path) raise CLIError('Failed to update. Rolled {} back to {}.'.format(extension_name, cur_version)) + CommandIndex().invalidate() except ExtensionNotInstalledException as e: raise CLIError(e) diff --git a/src/azure-cli-core/azure/cli/core/tests/test_command_registration.py b/src/azure-cli-core/azure/cli/core/tests/test_command_registration.py index bc254591cb4..b7cb1a8a650 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_command_registration.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_command_registration.py @@ -146,10 +146,14 @@ def _mock_import_lib(_): return mock_obj def _mock_iter_modules(_): - return [(None, __name__, None)] + return [(None, "hello", None), + (None, "extra", None)] def _mock_extension_modname(ext_name, ext_dir): - return ext_name + if ext_name.endswith('.ExtCommandsLoader'): + return "azext_hello1" + if ext_name.endswith('.Ext2CommandsLoader'): + return "azext_hello2" def _mock_get_extensions(): MockExtension = namedtuple('Extension', ['name', 'preview', 'experimental', 'path', 'get_metadata']) @@ -163,7 +167,18 @@ class TestCommandsLoader(AzCommandsLoader): def load_command_table(self, args): super(TestCommandsLoader, self).load_command_table(args) with self.command_group('hello', operations_tmpl='{}#TestCommandRegistration.{{}}'.format(__name__)) as g: - g.command('world', 'sample_vm_get') + g.command('mod-only', 'sample_vm_get') + g.command('overridden', 'sample_vm_get') + self.__module__ = "azure.cli.command_modules.hello" + return self.command_table + + class Test2CommandsLoader(AzCommandsLoader): + # An extra group that is not loaded if a module is found from the index + def load_command_table(self, args): + super(Test2CommandsLoader, self).load_command_table(args) + with self.command_group('extra', operations_tmpl='{}#TestCommandRegistration.{{}}'.format(__name__)) as g: + g.command('unused', 'sample_vm_get') + self.__module__ = "azure.cli.command_modules.extra" return self.command_table # A command from an extension @@ -172,7 +187,8 @@ class ExtCommandsLoader(AzCommandsLoader): def load_command_table(self, args): super(ExtCommandsLoader, self).load_command_table(args) with self.command_group('hello', operations_tmpl='{}#TestCommandRegistration.{{}}'.format(__name__)) as g: - g.command('noodle', 'sample_vm_get') + g.command('ext-only', 'sample_vm_get') + self.__module__ = "azext_hello1" return self.command_table # A command from an extension that overrides the original command @@ -181,22 +197,26 @@ class Ext2CommandsLoader(AzCommandsLoader): def load_command_table(self, args): super(Ext2CommandsLoader, self).load_command_table(args) with self.command_group('hello', operations_tmpl='{}#TestCommandRegistration.{{}}'.format(__name__)) as g: - g.command('world', 'sample_vm_get') + g.command('overridden', 'sample_vm_get') + self.__module__ = "azext_hello2" return self.command_table if prefix == 'azure.cli.command_modules.': - command_loaders = {'TestCommandsLoader': TestCommandsLoader} + command_loaders = {'hello': TestCommandsLoader, 'extra': Test2CommandsLoader} else: - command_loaders = {'ExtCommandsLoader': ExtCommandsLoader, 'Ext2CommandsLoader': Ext2CommandsLoader} + command_loaders = {'azext_hello1': ExtCommandsLoader, 'azext_hello2': Ext2CommandsLoader} module_command_table = {} - for _, loader_cls in command_loaders.items(): + for mod_name, loader_cls in command_loaders.items(): + # If name is provided, only load the named module + if name and name != mod_name: + continue command_loader = loader_cls(cli_ctx=loader.cli_ctx) command_table = command_loader.load_command_table(args) if command_table: module_command_table.update(command_table) loader.loaders.append(command_loader) # this will be used later by the load_arguments method - return module_command_table, {} + return module_command_table, command_loader.command_group_table @mock.patch('importlib.import_module', _mock_import_lib) @mock.patch('pkgutil.iter_modules', _mock_iter_modules) @@ -205,20 +225,157 @@ def load_command_table(self, args): @mock.patch('azure.cli.core.extension.get_extensions', _mock_get_extensions) def test_register_command_from_extension(self): - from azure.cli.core.commands import _load_command_loader cli = DummyCli() - main_loader = MainCommandsLoader(cli) - cli.loader = main_loader + loader = cli.commands_loader + + cmd_tbl = loader.load_command_table(None) + hello_mod_only_cmd = cmd_tbl['hello mod-only'] + hello_ext_only_cmd = cmd_tbl['hello ext-only'] + hello_overridden_cmd = cmd_tbl['hello overridden'] - cmd_tbl = cli.loader.load_command_table(None) - ext1 = cmd_tbl['hello noodle'] - ext2 = cmd_tbl['hello world'] + self.assertEqual(hello_mod_only_cmd.command_source, 'hello') + self.assertEqual(hello_mod_only_cmd.loader.__module__, 'azure.cli.command_modules.hello') - self.assertTrue(isinstance(ext1.command_source, ExtensionCommandSource)) - self.assertFalse(ext1.command_source.overrides_command) + self.assertTrue(isinstance(hello_ext_only_cmd.command_source, ExtensionCommandSource)) + self.assertFalse(hello_ext_only_cmd.command_source.overrides_command) - self.assertTrue(isinstance(ext2.command_source, ExtensionCommandSource)) - self.assertTrue(ext2.command_source.overrides_command) + self.assertTrue(isinstance(hello_overridden_cmd.command_source, ExtensionCommandSource)) + self.assertTrue(hello_overridden_cmd.command_source.overrides_command) + + @mock.patch.dict("os.environ", {"AZURE_CORE_USE_COMMAND_INDEX": "True"}) + @mock.patch('importlib.import_module', _mock_import_lib) + @mock.patch('pkgutil.iter_modules', _mock_iter_modules) + @mock.patch('azure.cli.core.commands._load_command_loader', _mock_load_command_loader) + @mock.patch('azure.cli.core.extension.get_extension_modname', _mock_extension_modname) + @mock.patch('azure.cli.core.extension.get_extensions', _mock_get_extensions) + def test_command_index(self): + + from azure.cli.core._session import INDEX + from azure.cli.core import CommandIndex, __version__ + + cli = DummyCli() + loader = cli.commands_loader + command_index = CommandIndex(cli) + + expected_command_index = {'hello': ['azure.cli.command_modules.hello', 'azext_hello2', 'azext_hello1'], + 'extra': ['azure.cli.command_modules.extra']} + expected_command_table = ['hello mod-only', 'hello overridden', 'extra unused', 'hello ext-only'] + + def _set_index(dict_): + INDEX[CommandIndex._COMMAND_INDEX] = dict_ + + def _check_index(): + self.assertEqual(INDEX[CommandIndex._COMMAND_INDEX_VERSION], __version__) + self.assertEqual(INDEX[CommandIndex._COMMAND_INDEX_CLOUD_PROFILE], cli.cloud.profile) + self.assertDictEqual(INDEX[CommandIndex._COMMAND_INDEX], expected_command_index) + + # Clear the command index + _set_index({}) + self.assertFalse(INDEX[CommandIndex._COMMAND_INDEX]) + loader.load_command_table(None) + # Test command index is built for None args + _check_index() + + # Test command index is built when `args` is provided + _set_index({}) + loader.load_command_table(["hello", "mod-only"]) + _check_index() + + # Test rebuild command index if no module found + _set_index({"network": ["azure.cli.command_modules.network"]}) + loader.load_command_table(["hello", "mod-only"]) + _check_index() + + with mock.patch("azure.cli.core.__version__", "2.7.0"), mock.patch.object(cli.cloud, "profile", "2019-03-01-hybrid"): + def update_and_check_index(): + loader.load_command_table(["hello", "mod-only"]) + self.assertEqual(INDEX[CommandIndex._COMMAND_INDEX_VERSION], "2.7.0") + self.assertEqual(INDEX[CommandIndex._COMMAND_INDEX_CLOUD_PROFILE], "2019-03-01-hybrid") + self.assertDictEqual(INDEX[CommandIndex._COMMAND_INDEX], expected_command_index) + + # Test rebuild command index if version is not present + del INDEX[CommandIndex._COMMAND_INDEX_VERSION] + del INDEX[CommandIndex._COMMAND_INDEX] + update_and_check_index() + + # Test rebuild command index if version is not valid + INDEX[CommandIndex._COMMAND_INDEX_VERSION] = "" + _set_index({}) + update_and_check_index() + + # Test rebuild command index if version is outdated + INDEX[CommandIndex._COMMAND_INDEX_VERSION] = "2.6.0" + _set_index({}) + update_and_check_index() + + # Test rebuild command index if profile is outdated + INDEX[CommandIndex._COMMAND_INDEX_CLOUD_PROFILE] = "2017-03-09-profile" + _set_index({}) + update_and_check_index() + + # Test rebuild command index if modules are found but outdated + # This only happens in dev environment. For users, the version check logic prevents it + _set_index({"hello": ["azure.cli.command_modules.extra"]}) + loader.load_command_table(["hello", "mod-only"]) + _check_index() + + # Test irrelevant commands are not loaded + _set_index(expected_command_index) + cmd_tbl = loader.load_command_table(["hello", "mod-only"]) + self.assertEqual(['hello mod-only', 'hello overridden', 'hello ext-only'], list(cmd_tbl.keys())) + + # Full scenario test 1: Installing an extension 'azext_hello1' that extends 'hello' group + outdated_command_index = {'hello': ['azure.cli.command_modules.hello'], + 'extra': ['azure.cli.command_modules.extra']} + _set_index(outdated_command_index) + + # Command for an outdated group + cmd_tbl = loader.load_command_table(["hello", "-h"]) + # Index is not updated, and only built-in commands are loaded + _set_index(outdated_command_index) + self.assertListEqual(list(cmd_tbl), ['hello mod-only', 'hello overridden']) + + # Command index is explicitly invalidated by azure.cli.core.extension.operations.add_extension + command_index.invalidate() + + cmd_tbl = loader.load_command_table(["hello", "-h"]) + # Index is updated, and new commands are loaded + _check_index() + self.assertListEqual(list(cmd_tbl), expected_command_table) + + # Full scenario test 2: Installing extension 'azext_hello2' that overrides existing command 'hello overridden' + outdated_command_index = {'hello': ['azure.cli.command_modules.hello', 'azext_hello1'], + 'extra': ['azure.cli.command_modules.extra']} + _set_index(outdated_command_index) + # Command for an overridden command + cmd_tbl = loader.load_command_table(["hello", "overridden"]) + # Index is not updated + self.assertEqual(INDEX[CommandIndex._COMMAND_INDEX], outdated_command_index) + # With the old command index, 'hello overridden' is loaded from the build-in module + hello_overridden_cmd = cmd_tbl['hello overridden'] + self.assertEqual(hello_overridden_cmd.command_source, 'hello') + self.assertListEqual(list(cmd_tbl), ['hello mod-only', 'hello overridden', 'hello ext-only']) + + # Command index is explicitly invalidated by azure.cli.core.extension.operations.add_extension + command_index.invalidate() + + # Command index is updated, and 'hello overridden' is loaded from the new extension + cmd_tbl = loader.load_command_table(["hello", "overridden"]) + hello_overridden_cmd = cmd_tbl['hello overridden'] + self.assertTrue(isinstance(hello_overridden_cmd.command_source, ExtensionCommandSource)) + _check_index() + self.assertListEqual(list(cmd_tbl), expected_command_table) + + # Call again with the new command index. Irrelevant commands are not loaded + cmd_tbl = loader.load_command_table(["hello", "overridden"]) + hello_overridden_cmd = cmd_tbl['hello overridden'] + self.assertTrue(isinstance(hello_overridden_cmd.command_source, ExtensionCommandSource)) + _check_index() + self.assertListEqual(list(cmd_tbl), ['hello mod-only', 'hello overridden', 'hello ext-only']) + + del INDEX[CommandIndex._COMMAND_INDEX_VERSION] + del INDEX[CommandIndex._COMMAND_INDEX_CLOUD_PROFILE] + del INDEX[CommandIndex._COMMAND_INDEX] def test_argument_with_overrides(self):