diff --git a/src/azure-cli/azure/cli/command_modules/acr/_constants.py b/src/azure-cli/azure/cli/command_modules/acr/_constants.py index f2cd5c078f0..7f6d69c033c 100644 --- a/src/azure-cli/azure/cli/command_modules/acr/_constants.py +++ b/src/azure-cli/azure/cli/command_modules/acr/_constants.py @@ -27,6 +27,9 @@ ACR_RUN_DEFAULT_TIMEOUT_IN_SEC = 60 * 60 # 60 minutes +ALLOWED_TASK_FILE_TYPES = ('.yaml', '.yml', '.toml', '.json', '.sh', '.bash', '.zsh', '.ps1', + '.ps', '.cmd', '.bat', '.ts', '.js', '.php', '.py', '.rb', '.lua') + def get_classic_sku(cmd): SkuName = cmd.get_models('SkuName') diff --git a/src/azure-cli/azure/cli/command_modules/acr/_utils.py b/src/azure-cli/azure/cli/command_modules/acr/_utils.py index 152065e4960..e78b48b7db1 100644 --- a/src/azure-cli/azure/cli/command_modules/acr/_utils.py +++ b/src/azure-cli/azure/cli/command_modules/acr/_utils.py @@ -23,10 +23,13 @@ get_valid_os, get_valid_architecture, get_valid_variant, - ACR_NULL_CONTEXT + ACR_NULL_CONTEXT, + ALLOWED_TASK_FILE_TYPES ) from ._client_factory import cf_acr_registries +from ._validators import validate_docker_file_path + from ._archive_utils import upload_source_code, check_remote_source_code logger = get_logger(__name__) @@ -405,7 +408,12 @@ def get_task_id_from_task_name(cli_ctx, resource_group, registry_name, task_name ) -def prepare_source_location(cmd, source_location, client_registries, registry_name, resource_group_name): +def prepare_source_location(cmd, + source_location, + client_registries, + registry_name, + resource_group_name, + docker_file_path=None): if not source_location or source_location.lower() == ACR_NULL_CONTEXT: source_location = None elif os.path.exists(source_location): @@ -413,13 +421,34 @@ def prepare_source_location(cmd, source_location, client_registries, registry_na raise CLIError( "Source location should be a local directory path or remote URL.") + # NOTE: If docker_file_path is not specified, the default is Dockerfile in source_location. + # Otherwise, it's based on current working directory. + if docker_file_path: + if docker_file_path.endswith(ALLOWED_TASK_FILE_TYPES) or docker_file_path == "-": + docker_file_path = "" + else: + validate_docker_file_path(os.path.join(source_location, docker_file_path)) + else: + docker_file_path = os.path.join(source_location, "Dockerfile") + if os.path.isfile(docker_file_path): + logger.info("'--file or -f' is not provided. '%s' is used.", docker_file_path) + else: + docker_file_path = "" + tar_file_path = os.path.join(tempfile.gettempdir( ), 'cli_source_archive_{}.tar.gz'.format(uuid.uuid4().hex)) try: + if docker_file_path: + # NOTE: os.path.basename is unable to parse "\" in the file path + docker_file_name = os.path.basename( + docker_file_path.replace("\\", "/")) + else: + docker_file_name = "" + source_location = upload_source_code( cmd, client_registries, registry_name, resource_group_name, - source_location, tar_file_path, "", "") + source_location, tar_file_path, docker_file_path, docker_file_name) except Exception as err: raise CLIError(err) finally: diff --git a/src/azure-cli/azure/cli/command_modules/acr/_validators.py b/src/azure-cli/azure/cli/command_modules/acr/_validators.py index f2e26265f29..bb9f8fd3ab4 100644 --- a/src/azure-cli/azure/cli/command_modules/acr/_validators.py +++ b/src/azure-cli/azure/cli/command_modules/acr/_validators.py @@ -6,7 +6,7 @@ import os from knack.util import CLIError from knack.log import get_logger -from azure.cli.core.azclierror import InvalidArgumentValueError +from azure.cli.core.azclierror import FileOperationError, InvalidArgumentValueError BAD_REPO_FQDN = "The positional parameter 'repo_id' must be a fully qualified repository specifier such"\ " as 'MyRegistry.azurecr.io/hello-world'." @@ -140,3 +140,8 @@ def validate_repository(namespace): if ':' in namespace.repository: raise InvalidArgumentValueError("Parameter 'name' refers to a repository and" " should not include a tag or digest.") + + +def validate_docker_file_path(docker_file_path): + if not os.path.isfile(docker_file_path): + raise FileOperationError("Unable to find '{}'.".format(docker_file_path)) diff --git a/src/azure-cli/azure/cli/command_modules/acr/task.py b/src/azure-cli/azure/cli/command_modules/acr/task.py index 88142c3a599..cc7ca39bdd2 100644 --- a/src/azure-cli/azure/cli/command_modules/acr/task.py +++ b/src/azure-cli/azure/cli/command_modules/acr/task.py @@ -26,7 +26,8 @@ from ._constants import ( ACR_NULL_CONTEXT, ACR_TASK_QUICKTASK, - ACR_RUN_DEFAULT_TIMEOUT_IN_SEC + ACR_RUN_DEFAULT_TIMEOUT_IN_SEC, + ALLOWED_TASK_FILE_TYPES ) logger = get_logger(__name__) @@ -37,8 +38,6 @@ IDENTITY_LOCAL_ID = '[system]' IDENTITY_GLOBAL_REMOVE = '[all]' DEFAULT_CPU = 2 -ALLOWED_TASK_FILE_TYPES = ('.yaml', '.yml', '.toml', '.json', '.sh', '.bash', '.zsh', '.ps1', - '.ps', '.cmd', '.bat', '.ts', '.js', '.php', '.py', '.rb', '.lua') def acr_task_create(cmd, # pylint: disable=too-many-locals @@ -877,7 +876,14 @@ def acr_task_run(cmd, # pylint: disable=too-many-locals update_trigger_token = base64.b64encode(update_trigger_token.encode()).decode() task_id = get_task_id_from_task_name(cmd.cli_ctx, resource_group_name, registry_name, task_name) - context_path = prepare_source_location(cmd, context_path, client_registries, registry_name, resource_group_name) + context_path = prepare_source_location( + cmd, + context_path, + client_registries, + registry_name, + resource_group_name, + file + ) timeout = None task_details = get_task_details_by_name(cmd.cli_ctx, resource_group_name, registry_name, task_name)