Skip to content

Commit

Permalink
[k8s-extension] Update extension CLI to v1.6.0 (Azure#7204)
Browse files Browse the repository at this point in the history
  • Loading branch information
bavneetsingh16 authored Jan 24, 2024
1 parent f9e12b7 commit 426b159
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 66 deletions.
6 changes: 6 additions & 0 deletions src/k8s-extension/HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
Release History
===============

1.6.0
++++++++++++++++++
* AAD related changes in dataprotection aks ext CLI
* microsoft.azuremonitor.containers: Make containerlogv2 as default as true and remove region dependency for ARC
* microsoft.workloadiam: Refactor subcommand invocation

1.5.3
++++++++++++++++++
* Add WorkloadIAM extension support and tests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -673,37 +673,23 @@ def _ensure_container_insights_dcr_for_monitoring(cmd, subscription_id, cluster_
for region_data in json_response["value"]:
region_names_to_id[region_data["displayName"]] = region_data["name"]

# check if region supports DCR and DCR-A
for _ in range(3):
try:
feature_check_url = cmd.cli_ctx.cloud.endpoints.resource_manager + f"/subscriptions/{subscription_id}/providers/Microsoft.Insights?api-version=2020-10-01"
r = send_raw_request(cmd.cli_ctx, "GET", feature_check_url)
error = None
break
except AzCLIError as e:
error = e
else:
raise error

json_response = json.loads(r.text)
for resource in json_response["resourceTypes"]:
if (resource["resourceType"].lower() == "datacollectionrules"):
region_ids = map(lambda x: region_names_to_id[x], resource["locations"]) # dcr supported regions
if (workspace_region not in region_ids):
raise ClientRequestError(f"Data Collection Rules are not supported for LA workspace region {workspace_region}")
if (resource["resourceType"].lower() == "datacollectionruleassociations"):
region_ids = map(lambda x: region_names_to_id[x], resource["locations"]) # dcr-a supported regions
if (cluster_region not in region_ids):
raise ClientRequestError(f"Data Collection Rule Associations are not supported for cluster region {cluster_region}")

dcr_url = cmd.cli_ctx.cloud.endpoints.resource_manager + f"{dcr_resource_id}?api-version={DCR_API_VERSION}"
# get existing tags on the container insights extension DCR if the customer added any
existing_tags = get_existing_container_insights_extension_dcr_tags(cmd, dcr_url)
streams = ["Microsoft-ContainerInsights-Group-Default"]
if extensionSettings is not None and 'dataCollectionSettings' in extensionSettings.keys():
if extensionSettings is None:
extensionSettings = {}
if 'dataCollectionSettings' in extensionSettings.keys():
dataCollectionSettings = extensionSettings["dataCollectionSettings"]
dataCollectionSettings.setdefault("enableContainerLogV2", True)
if dataCollectionSettings is not None and 'streams' in dataCollectionSettings.keys():
streams = dataCollectionSettings["streams"]
else:
# If data_collection_settings is None, set default dataCollectionSettings
dataCollectionSettings = {
"enableContainerLogV2": True
}
extensionSettings["dataCollectionSettings"] = dataCollectionSettings

# create the DCR
dcr_creation_body = json.dumps(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,27 @@ def __init__(self):
self.BACKUP_STORAGE_ACCOUNT_SUBSCRIPTION = "configuration.backupStorageLocation.config.subscriptionId"
self.RESOURCE_LIMIT_CPU = "resources.limits.cpu"
self.RESOURCE_LIMIT_MEMORY = "resources.limits.memory"
self.BACKUP_STORAGE_ACCOUNT_USE_AAD = "configuration.backupStorageLocation.config.useAAD"
self.BACKUP_STORAGE_ACCOUNT_STORAGE_ACCOUNT_URI = "configuration.backupStorageLocation.config.storageAccountURI"

self.blob_container = "blobContainer"
self.storage_account = "storageAccount"
self.storage_account_resource_group = "storageAccountResourceGroup"
self.storage_account_subsciption = "storageAccountSubscriptionId"
self.cpu_limit = "cpuLimit"
self.memory_limit = "memoryLimit"
self.use_aad = "useAAD"
self.storage_account_uri = "storageAccountURI"

self.configuration_mapping = {
self.blob_container.lower(): self.BACKUP_STORAGE_ACCOUNT_CONTAINER,
self.storage_account.lower(): self.BACKUP_STORAGE_ACCOUNT_NAME,
self.storage_account_resource_group.lower(): self.BACKUP_STORAGE_ACCOUNT_RESOURCE_GROUP,
self.storage_account_subsciption.lower(): self.BACKUP_STORAGE_ACCOUNT_SUBSCRIPTION,
self.cpu_limit.lower(): self.RESOURCE_LIMIT_CPU,
self.memory_limit.lower(): self.RESOURCE_LIMIT_MEMORY
self.memory_limit.lower(): self.RESOURCE_LIMIT_MEMORY,
self.use_aad.lower(): self.BACKUP_STORAGE_ACCOUNT_USE_AAD,
self.storage_account_uri.lower(): self.BACKUP_STORAGE_ACCOUNT_STORAGE_ACCOUNT_URI
}

self.bsl_configuration_settings = [
Expand Down Expand Up @@ -99,6 +105,15 @@ def Create(

configuration_settings[self.TENANT_ID] = tenant_id

if configuration_settings.get(self.BACKUP_STORAGE_ACCOUNT_USE_AAD) is None:
logger.warning("useAAD flag is not specified. Setting it to 'true'. Please provide extension MSI Storage Blob Data Contributor role to the storage account.")
configuration_settings[self.BACKUP_STORAGE_ACCOUNT_USE_AAD] = "true"

if configuration_settings.get(self.BACKUP_STORAGE_ACCOUNT_STORAGE_ACCOUNT_URI) is None:
logger.warning("storageAccountURI is not populated. Setting it to the storage account URI of provided storage account")
configuration_settings[self.BACKUP_STORAGE_ACCOUNT_STORAGE_ACCOUNT_URI] = self.__get_storage_account_uri(cmd.cli_ctx, configuration_settings)
logger.warning(f"storageAccountURI: {configuration_settings[self.BACKUP_STORAGE_ACCOUNT_STORAGE_ACCOUNT_URI]}")

if release_train is None:
release_train = 'stable'

Expand Down Expand Up @@ -128,11 +143,25 @@ def Update(
if configuration_settings is None:
configuration_settings = {}

bsl_specified = False
if len(configuration_settings) > 0:
bsl_specified = self.__is_bsl_specified(configuration_settings)
self.__validate_and_map_config(configuration_settings, validate_bsl=bsl_specified)
if bsl_specified:
self.__validate_backup_storage_account(cmd.cli_ctx, resource_group_name, cluster_name, configuration_settings)
# this step is for brownfield migrating to AAD
if configuration_settings.get(self.BACKUP_STORAGE_ACCOUNT_USE_AAD) is not None and configuration_settings.get(self.BACKUP_STORAGE_ACCOUNT_USE_AAD).lower() == "true":
logger.warning("useAAD flag is set to true. Please provide extension MSI Storage Blob Data Contributor role to the storage account.")

if configuration_settings.get(self.BACKUP_STORAGE_ACCOUNT_STORAGE_ACCOUNT_URI) is None:
# SA details provided in user inputs, but did not provide SA URI.
logger.warning("storageAccountURI is not populated. Setting it to the storage account URI of provided storage account")
if bsl_specified:
configuration_settings[self.BACKUP_STORAGE_ACCOUNT_STORAGE_ACCOUNT_URI] = self.__get_storage_account_uri(cmd.cli_ctx, configuration_settings)
# SA details not provided in user input, SA Uri not provided in user input, and also not populated in the original extension, we populate it.
elif not bsl_specified and original_extension.configuration_settings.get(self.BACKUP_STORAGE_ACCOUNT_STORAGE_ACCOUNT_URI) is None:
configuration_settings[self.BACKUP_STORAGE_ACCOUNT_STORAGE_ACCOUNT_URI] = self.__get_storage_account_uri(cmd.cli_ctx, original_extension.configuration_settings)
logger.warning(f"storageAccountURI: {configuration_settings[self.BACKUP_STORAGE_ACCOUNT_STORAGE_ACCOUNT_URI]}")

return PatchExtension(
auto_upgrade_minor_version=True,
Expand Down Expand Up @@ -169,12 +198,7 @@ def __validate_backup_storage_account(self, cli_ctx, resource_group_name, cluste
- Existance of the storage account
- Cluster and storage account are in the same location
"""
sa_subscription_id = configuration_settings[self.BACKUP_STORAGE_ACCOUNT_SUBSCRIPTION]
storage_account_client = cf_storage(cli_ctx, sa_subscription_id).storage_accounts

storage_account = storage_account_client.get_properties(
configuration_settings[self.BACKUP_STORAGE_ACCOUNT_RESOURCE_GROUP],
configuration_settings[self.BACKUP_STORAGE_ACCOUNT_NAME])
storage_account = self.__get_storage_account(cli_ctx, configuration_settings)

cluster_subscription_id = get_subscription_id(cli_ctx)
managed_clusters_client = cf_managed_clusters(cli_ctx, cluster_subscription_id)
Expand All @@ -186,6 +210,23 @@ def __validate_backup_storage_account(self, cli_ctx, resource_group_name, cluste
error_message = f"The Kubernetes managed cluster '{cluster_name} ({managed_cluster.location})' and the backup storage account '{configuration_settings[self.BACKUP_STORAGE_ACCOUNT_NAME]} ({storage_account.location})' are not in the same location. Please make sure that the cluster and the storage account are in the same location."
raise SystemExit(logger.error(error_message))

def __get_storage_account(self, cli_ctx, configuration_settings):
"""Get the storage account properties"""
from azure.cli.core.commands.client_factory import get_mgmt_service_client
from azure.mgmt.storage import StorageManagementClient

sa_subscription_id = configuration_settings[self.BACKUP_STORAGE_ACCOUNT_SUBSCRIPTION]
storage_account_client = get_mgmt_service_client(cli_ctx, StorageManagementClient, subscription_id=sa_subscription_id)

return storage_account_client.storage_accounts.get_properties(
configuration_settings[self.BACKUP_STORAGE_ACCOUNT_RESOURCE_GROUP],
configuration_settings[self.BACKUP_STORAGE_ACCOUNT_NAME])

def __get_storage_account_uri(self, cli_ctx, configuration_settings):
"""Get the storage account blob endpoint"""
storage_account = self.__get_storage_account(cli_ctx, configuration_settings)
return storage_account.primary_endpoints.blob

def __is_bsl_specified(self, configuration_settings):
"""Check if the backup storage account is specified in the input"""
input_configuration_keys = [key.lower() for key in configuration_settings]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import subprocess
import os

from knack.log import get_logger
from knack.util import CLIError

from azure.cli.core import get_default_cli
from azure.cli.core.azclierror import InvalidArgumentValueError

from ..vendored_sdks.models import (Extension, Scope, ScopeCluster)
Expand Down Expand Up @@ -44,6 +45,15 @@ def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_t
# TODO - Set this to 'stable' when the extension is ready
release_train = 'preview'

# The name is used as a base to generate Kubernetes labels for config maps, pods, etc, and
# their names are limited to 63 characters (RFC-1123). Instead of calculating the exact
# number of characters that we can allow users to specify, it's better to restrict that even
# more so that we have extra space in case the name of a resource changes and it pushes the
# total string length over the limit.
if len(name) > 20:
raise InvalidArgumentValueError(
f"Name '{name}' is too long, it must be 20 characters long max.")

scope = scope.lower()
if scope is None:
scope = 'cluster'
Expand Down Expand Up @@ -117,36 +127,24 @@ def get_join_token(self, trust_domain, local_authority):
Invoke the az command to obtain a join token.
"""

logger.info("Getting a join token from the control plane")
logger.debug("Getting a join token from the control plane")

# Invoke az workload-iam command to obtain the join token
cmd = [
"az", "workload-iam", "local-authority", "attestation-method", "create",
"workload-iam", "local-authority", "attestation-method", "create",
"--td", trust_domain,
"--la", local_authority,
"--type", "joinTokenAttestationMethod",
"--query", "singleUseToken",
"--dn", "myJoinToken",
]
cmd_str = " ".join(cmd)

try:
# Note: We can't use get_default_cli() here because its invoke() method
# always prints the console output, which we want to avoid.
result = subprocess.run(cmd, capture_output=True, shell=True)
except Exception as e:
logger.error(f"Error while generating a join token: {cmd_str}")
raise e

if result.returncode != 0:
raise CLIError(f"Failed to generate a join token (exit code {result.returncode}): {cmd_str}")

try:
# Strip double quotes from the output
command_output = result.stdout.decode("utf-8")
token = command_output.strip("\r\n").strip("\"")
except Exception as e:
logger.error(f"Failed to parse output of join token command: {cmd_str}")
raise e

cli = get_default_cli()
cli.invoke(cmd, out_file=open(os.devnull, 'w')) # Don't print output
if cli.result.result:
token = cli.result.result
elif cli.result.error:
cmd_str = "az " + " ".join(cmd)
raise CLIError(f"Error while generating a join token. Command: {cmd_str}")

return token
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,27 @@

class TestWorkloadIAM(unittest.TestCase):

def test_workload_iam_create_with_instance_name_too_long(self):
"""
Test that the checks fail when the user provides an instance name that is too long.
"""

instance_name = "workload-iam-extra-long-instance-name"

with self.assertRaises(InvalidArgumentValueError) as context:
workload_iam = WorkloadIAM()
workload_iam.Create(cmd=None, client=None, resource_group_name=None,
cluster_name=None, name=instance_name, cluster_type=None, cluster_rp=None,
extension_type=None, scope='cluster', auto_upgrade_minor_version=None,
release_train='dev', version='0.1.0', target_namespace=None,
release_namespace=None, configuration_settings=None,
configuration_protected_settings=None, configuration_settings_file=None,
configuration_protected_settings_file=None, plan_name=None, plan_publisher=None,
plan_product=None)

self.assertEqual(str(context.exception),
f"Name '{instance_name}' is too long, it must be 20 characters long max.")

def test_workload_iam_create_without_join_token_success(self):
"""
Test that, when the user doesn't provide a join token, the Create() method calls
Expand Down Expand Up @@ -271,13 +292,20 @@ def test_workload_iam_get_join_token_with_valid_argument_success(self):
mock_local_authority_name = 'any_local_authority_name'
mock_join_token = 'any_join_token'

class MockResult():
class MockCLI():
def __init__(self):
self.returncode = 0
self.stdout = ('\"' + mock_join_token + '\"').encode('utf-8')
pass

def invoke(self, cmd, out_file):
class MockResult():
def __init__(self):
self.result = mock_join_token
self.error = None

self.result = MockResult()

with patch('azext_k8s_extension.partner_extensions.WorkloadIAM.subprocess.run',
return_value=MockResult()):
with patch('azext_k8s_extension.partner_extensions.WorkloadIAM.get_default_cli',
return_value=MockCLI()):
# Test & assert
workload_iam = WorkloadIAM()
join_token = workload_iam.get_join_token(mock_trust_domain_name, mock_local_authority_name)
Expand All @@ -294,7 +322,6 @@ def test_workload_iam_get_join_token_with_bad_exit_code(self):
mock_trust_domain_name = 'any_trust_domain_name.com'
mock_local_authority_name = 'any_local_authority_name'
mock_join_token = 'any_join_token'
mock_exit_code = 1

cmd = [
"az", "workload-iam", "local-authority", "attestation-method", "create",
Expand All @@ -305,15 +332,23 @@ def test_workload_iam_get_join_token_with_bad_exit_code(self):
"--dn", "myJoinToken",
]

class MockResult():
class MockCLI():
def __init__(self):
self.returncode = mock_exit_code
pass

def invoke(self, cmd, out_file):
class MockResult():
def __init__(self):
self.result = None
self.error = True

self.result = MockResult()

with patch('azext_k8s_extension.partner_extensions.WorkloadIAM.subprocess.run',
return_value=MockResult()):
with patch('azext_k8s_extension.partner_extensions.WorkloadIAM.get_default_cli',
return_value=MockCLI()):
# Test & assert
workload_iam = WorkloadIAM()
cmd_str = " ".join(cmd)
self.assertRaisesRegex(CLIError,
f"Failed to generate a join token \(exit code {mock_exit_code}\): {cmd_str}",
f"Error while generating a join token. Command: {cmd_str}",
workload_iam.get_join_token, mock_trust_domain_name, mock_local_authority_name)
2 changes: 1 addition & 1 deletion src/k8s-extension/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# TODO: Add any additional SDK dependencies here
DEPENDENCIES = []

VERSION = "1.5.3"
VERSION = "1.6.0"

with open("README.rst", "r", encoding="utf-8") as f:
README = f.read()
Expand Down

0 comments on commit 426b159

Please sign in to comment.