diff --git a/src/dbacademy/clients/dbrest/accounts_client/scim/__init__.py b/src/dbacademy/clients/dbrest/accounts_client/scim/__init__.py index 5179b31c..f4395379 100644 --- a/src/dbacademy/clients/dbrest/accounts_client/scim/__init__.py +++ b/src/dbacademy/clients/dbrest/accounts_client/scim/__init__.py @@ -1,14 +1,13 @@ __all__ = ["AccountScimApi"] -from dbacademy.clients.rest.common import ApiContainer +from dbacademy.common import validate +from dbacademy.clients.rest.common import ApiClient, ApiContainer from dbacademy.clients.dbrest.accounts_client.scim.users import AccountScimUsersApi -from dbacademy.clients.rest.common import ApiClient class AccountScimApi(ApiContainer): def __init__(self, client: ApiClient, account_id: str): - from dbacademy.common import validate self.__client = validate(client=client).required.as_type(ApiClient) self.__account_id = account_id diff --git a/src/dbacademy/clients/dbrest/clusters_api/__init__.py b/src/dbacademy/clients/dbrest/clusters_api/__init__.py index f1479dde..fcf1afa1 100644 --- a/src/dbacademy/clients/dbrest/clusters_api/__init__.py +++ b/src/dbacademy/clients/dbrest/clusters_api/__init__.py @@ -9,8 +9,6 @@ class ClustersApi(ApiContainer): def __init__(self, client: ApiClient): - from dbacademy.common import validate - self.__client = validate(client=client).required.as_type(ApiClient) self.base_uri = f"{self.__client.endpoint}/api/2.0/clusters" @@ -18,10 +16,6 @@ def create_from_config(self, config: ClusterConfig) -> str: return self.create_from_dict(config.params) def create_from_dict(self, params: Dict[str, Any]) -> str: - import json - print("-"*80) - print(json.dumps(params, indent=4)) - print("-"*80) cluster = self.__client.api("POST", f"{self.base_uri}/create", _data=params) return cluster.get("cluster_id") diff --git a/src/dbacademy/clients/dbrest/clusters_api/cluster_config.py b/src/dbacademy/clients/dbrest/clusters_api/cluster_config.py index 6f1eae55..6dd864b5 100644 --- 
a/src/dbacademy/clients/dbrest/clusters_api/cluster_config.py +++ b/src/dbacademy/clients/dbrest/clusters_api/cluster_config.py @@ -78,7 +78,7 @@ def __init__(self, *, instance_pool_id: Optional[str], policy_id: Optional[str], num_workers: int, - autotermination_minutes: int, + autotermination_minutes: Optional[int], single_user_name: Optional[str], availability: Optional[Union[str, Availability]], spark_conf: Optional[Dict[str, str]], @@ -91,7 +91,6 @@ def __init__(self, *, "cluster_name": validate(cluster_name=cluster_name).optional.str(), "spark_version": validate(spark_version=spark_version).required.str(), "num_workers": validate(num_workers=num_workers).required.int(), - "autotermination_minutes": validate(autotermination_minutes=autotermination_minutes).required.int(), } extra_params = validate(extra_params=extra_params).optional.dict(str, auto_create=True) @@ -99,6 +98,10 @@ def __init__(self, *, spark_env_vars = validate(spark_env_vars=spark_env_vars).optional.dict(str, auto_create=True) custom_tags = validate(custom_tags=custom_tags).optional.dict(str, auto_create=True) + if autotermination_minutes is not None: + # Not set for job clusters + self.__params["autotermination_minutes"] = validate(autotermination_minutes=autotermination_minutes).required.int() + if instance_pool_id is not None: extra_params["instance_pool_id"]: validate(instance_pool_id=instance_pool_id).optional.str() @@ -109,14 +112,17 @@ def __init__(self, *, extra_params["single_user_name"] = validate(single_user_name=single_user_name).required.str() extra_params["data_security_mode"] = "SINGLE_USER" - extra_params["node_type_id"] = validate(node_type_id=node_type_id).required.str() - extra_params["driver_node_type_id"] = validate(driver_node_type_id=driver_node_type_id or node_type_id).required.str() - if num_workers == 0: # Don't use "local[*, 4] because the node type might have more cores custom_tags["ResourceClass"] = "SingleNode" spark_conf["spark.master"] = "local[*]" 
spark_conf["spark.databricks.cluster.profile"] = "singleNode" + assert driver_node_type_id is None, f"""driver_node_type_id should be None when num_workers is zero.""" + else: + # At least one worker so define the driver_node_type_id and if necessary, default it to the node_type_id + extra_params["driver_node_type_id"] = validate(driver_node_type_id=driver_node_type_id or node_type_id).required.str() + + extra_params["node_type_id"] = validate(node_type_id=node_type_id).required.str() assert extra_params.get("custom_tags") is None, f"The parameter \"extra_params.custom_tags\" should not be specified directly, use \"custom_tags\" instead." assert extra_params.get("spark_conf") is None, f"The parameter \"extra_params.spark_conf\" should not be specified directly, use \"spark_conf\" instead." @@ -230,7 +236,7 @@ def __init__(self, *, driver_node_type_id: Optional[str] = None, instance_pool_id: Optional[str] = None, policy_id: Optional[str] = None, - autotermination_minutes: int = 120, + # autotermination_minutes: int = 120, single_user_name: Optional[str] = None, availability: Optional[Union[str, Availability]] = None, spark_conf: Optional[Dict[str, str]] = None, @@ -241,14 +247,14 @@ def __init__(self, *, # Parameters are validated in the call to CommonConfig super().__init__(cloud=cloud, - cluster_name=None, # Not allowed when uses as a job + cluster_name=None, # Not allowed when used as a job spark_version=spark_version, node_type_id=node_type_id, driver_node_type_id=driver_node_type_id, instance_pool_id=instance_pool_id, policy_id=policy_id, num_workers=num_workers, - autotermination_minutes=autotermination_minutes, + autotermination_minutes=None, # Not allowed when used as a job single_user_name=single_user_name, spark_conf=spark_conf, spark_env_vars=spark_env_vars, diff --git a/src/dbacademy/clients/dbrest/ml_api/__init__.py b/src/dbacademy/clients/dbrest/ml_api/__init__.py index fd013f46..11085fd5 100644 --- a/src/dbacademy/clients/dbrest/ml_api/__init__.py
+++ b/src/dbacademy/clients/dbrest/ml_api/__init__.py @@ -1,5 +1,7 @@ __all__ = ["MlApi"] +# Code Review: JDP on 11-27-2023 +from dbacademy.common import validate from dbacademy.clients.rest.common import ApiClient, ApiContainer from dbacademy.clients.dbrest.ml_api.feature_store_api import FeatureStoreApi from dbacademy.clients.dbrest.ml_api.mlflow_endpoints_api import MLflowEndpointsApi @@ -9,8 +11,6 @@ class MlApi(ApiContainer): def __init__(self, client: ApiClient): - from dbacademy.common import validate - self.__client = validate(client=client).required.as_type(ApiClient) @property diff --git a/src/dbacademy/clients/dbrest/permissions_api/__init__.py b/src/dbacademy/clients/dbrest/permissions_api/__init__.py index 31c60f84..339afb1c 100644 --- a/src/dbacademy/clients/dbrest/permissions_api/__init__.py +++ b/src/dbacademy/clients/dbrest/permissions_api/__init__.py @@ -1,5 +1,7 @@ __all__ = ["PermissionsApi"] +# Code Review: JDP on 11-27-2023 +from dbacademy.common import validate from dbacademy.clients.rest.common import ApiContainer, ApiClient from dbacademy.clients.dbrest.permissions_api.clusters_permissions_api import ClustersPermissionsApi from dbacademy.clients.dbrest.permissions_api.directories_permissions_api import DirectoriesPermissionsApi @@ -14,8 +16,6 @@ class Authorization: def __init__(self, client: ApiClient): - from dbacademy.common import validate - self.__client = validate(client=client).required.as_type(ApiClient) @property @@ -26,8 +26,6 @@ def tokens(self) -> AuthTokensPermissionsApi: class PermissionsApi(ApiContainer): def __init__(self, client: ApiClient): - from dbacademy.common import validate - self.__client = validate(client=client).required.as_type(ApiClient) @property diff --git a/src/dbacademy/clients/dbrest/permissions_api/sql/__init__.py b/src/dbacademy/clients/dbrest/permissions_api/sql/__init__.py index 5465d947..9cfe13b4 100644 --- a/src/dbacademy/clients/dbrest/permissions_api/sql/__init__.py +++ 
b/src/dbacademy/clients/dbrest/permissions_api/sql/__init__.py @@ -1,5 +1,7 @@ __all__ = ["SqlPermissionsApi"] +# Code Review: JDP on 11-27-2023 +from dbacademy.common import validate from dbacademy.clients.rest.common import ApiContainer, ApiClient from dbacademy.clients.dbrest.permissions_api.sql.warehouses_permissions_api import SqlWarehousesPermissionsApi from dbacademy.clients.dbrest.permissions_api.sql.sql_crud_permissions_api import SqlCrudPermissions @@ -8,8 +10,6 @@ class SqlPermissionsApi(ApiContainer): def __init__(self, client: ApiClient): - from dbacademy.common import validate - self.__client = validate(client=client).required.as_type(ApiClient) @property diff --git a/src/dbacademy/clients/dbrest/scim_api/__init__.py b/src/dbacademy/clients/dbrest/scim_api/__init__.py index 58f6b4f8..c303e686 100644 --- a/src/dbacademy/clients/dbrest/scim_api/__init__.py +++ b/src/dbacademy/clients/dbrest/scim_api/__init__.py @@ -1,6 +1,8 @@ __all__ = ["ScimApi"] +# Code Review: JDP on 11-27-2023 from typing import Dict, Any +from dbacademy.common import validate from dbacademy.clients.rest.common import ApiContainer, ApiClient from dbacademy.clients.dbrest.scim_api.users_api import ScimUsersApi from dbacademy.clients.dbrest.scim_api.service_principals_api import ScimServicePrincipalsApi @@ -8,9 +10,8 @@ class ScimApi(ApiContainer): - def __init__(self, client: ApiClient): - from dbacademy.common import validate + def __init__(self, client: ApiClient): self.__client = validate(client=client).required.as_type(ApiClient) @property diff --git a/src/dbacademy/clients/dbrest/sql_api/__init__.py b/src/dbacademy/clients/dbrest/sql_api/__init__.py index d5d898a0..34fa9545 100644 --- a/src/dbacademy/clients/dbrest/sql_api/__init__.py +++ b/src/dbacademy/clients/dbrest/sql_api/__init__.py @@ -1,5 +1,7 @@ __all__ = ["SqlApi"] +# Code Review: JDP on 11-27-2023 +from dbacademy.common import validate from dbacademy.clients.rest.common import ApiContainer, ApiClient from 
dbacademy.clients.dbrest.sql_api.config_api import SqlConfigApi from dbacademy.clients.dbrest.sql_api.warehouses_api import SqlWarehousesApi @@ -10,8 +12,6 @@ class SqlApi(ApiContainer): def __init__(self, client: ApiClient): - from dbacademy.common import validate - self.__client = validate(client=client).required.as_type(ApiClient) @property @@ -29,7 +29,3 @@ def queries(self) -> SqlQueriesApi: @property def statements(self) -> StatementsApi: return StatementsApi(self.__client) - - # @property - # def permissions(self) -> SqlPermissionsApi: - # return SqlPermissionsApi(self.__client) diff --git a/src/dbacademy/common/validator.py b/src/dbacademy/common/validator.py index 4bcfd282..bc9dac35 100644 --- a/src/dbacademy/common/validator.py +++ b/src/dbacademy/common/validator.py @@ -336,8 +336,8 @@ def __validate_min_length(self, *, min_length: int = 0) -> None: message = f"""{E_INTERNAL} | Expected {self.__class__.__name__}.{inspect.stack()[0].function}(..)'s parameter 'min_length' to be of type int, found {type(min_length)}.""" do_validate(passed=isinstance(min_length, int), message=message) - if min_length > 0: - # We cannot test the length if the value is not of type Sized. + if self.parameter_value is not None and min_length > 0: + # We cannot test the length if the value is not of type Sized, and we shouldn't test it if it is None. 
message = f"""{E_TYPE} | Expected the parameter '{self.parameter_name}' to be of type Sized, found {type(self.parameter_value)}.""" do_validate(passed=isinstance(self.parameter_value, Sized), message=message) diff --git a/test/dbacademy_test/clients/dbrest/test_clusters.py b/test/dbacademy_test/clients/dbrest/test_clusters.py index d76c271c..1ca04d6f 100644 --- a/test/dbacademy_test/clients/dbrest/test_clusters.py +++ b/test/dbacademy_test/clients/dbrest/test_clusters.py @@ -238,7 +238,7 @@ def test_default_cluster(self): actual_value = cluster.get(key) expected_value = expected.get(key) if expected_value != actual_value: - self.fail(f"{key}: \"{expected_value}\" != \"{actual_value}\"") + self.fail(f"""{key}: "{expected_value}" != "{actual_value}".""") self.assertTrue(len(cluster.get("aws_attributes")) > 0) @@ -376,6 +376,7 @@ def test_create_driver_node_type(self): cluster_name="Driver Node Type A", spark_version="11.3.x-scala2.12", node_type_id="i3.xlarge", + # Expected to be None # driver_node_type_id="i3.2xlarge", num_workers=0, autotermination_minutes=10)) @@ -384,19 +385,6 @@ def test_create_driver_node_type(self): self.assertEqual("i3.xlarge", cluster.get("driver_node_type_id")) self.client.clusters.destroy_by_id(cluster_id) - cluster_id = self.client.clusters.create_from_config(ClusterConfig( - cloud=Cloud.AWS, - cluster_name="Driver Node Type B", - spark_version="11.3.x-scala2.12", - node_type_id="i3.xlarge", - driver_node_type_id="i3.2xlarge", - num_workers=0, - autotermination_minutes=10)) - cluster = self.client.clusters.get_by_id(cluster_id) - self.assertEqual("i3.xlarge", cluster.get("node_type_id")) - self.assertEqual("i3.2xlarge", cluster.get("driver_node_type_id")) - self.client.clusters.destroy_by_id(cluster_id) - def test_create_with_spark_conf(self): cluster_id = self.client.clusters.create_from_config(ClusterConfig(