Commit fd47352: Checkpoint

SireInsectus committed Nov 27, 2023
1 parent 7ca58eb commit fd47352
Showing 10 changed files with 31 additions and 49 deletions.
5 changes: 2 additions & 3 deletions src/dbacademy/clients/dbrest/accounts_client/scim/__init__.py
@@ -1,14 +1,13 @@
 __all__ = ["AccountScimApi"]
 
-from dbacademy.clients.rest.common import ApiContainer
+from dbacademy.common import validate
+from dbacademy.clients.rest.common import ApiClient, ApiContainer
 from dbacademy.clients.dbrest.accounts_client.scim.users import AccountScimUsersApi
-from dbacademy.clients.rest.common import ApiClient
 
 
 class AccountScimApi(ApiContainer):
 
     def __init__(self, client: ApiClient, account_id: str):
-        from dbacademy.common import validate
 
         self.__client = validate(client=client).required.as_type(ApiClient)
         self.__account_id = account_id
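These __init__ bodies all lean on the fluent interface of dbacademy.common.validate, which this commit hoists from function-local imports to module level. As a rough mental model of what validate(client=client).required.as_type(ApiClient) does — a hedged sketch, not the library's actual implementation — consider:

    from typing import Any, Type, TypeVar

    T = TypeVar("T")

    class _Required:
        def __init__(self, name: str, value: Any):
            self._name, self._value = name, value

        def as_type(self, expected: Type[T]) -> T:
            # Fail fast if the value is missing or of the wrong type.
            if not isinstance(self._value, expected):
                raise AssertionError(f"Expected '{self._name}' to be {expected.__name__}, found {type(self._value).__name__}.")
            return self._value

    class _Validator:
        def __init__(self, name: str, value: Any):
            self.required = _Required(name, value)

    def validate(**kwargs: Any) -> _Validator:
        # The keyword name doubles as the parameter name in error messages.
        (name, value), = kwargs.items()
        return _Validator(name, value)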
6 changes: 0 additions & 6 deletions src/dbacademy/clients/dbrest/clusters_api/__init__.py
@@ -9,19 +9,13 @@
 class ClustersApi(ApiContainer):
 
     def __init__(self, client: ApiClient):
-        from dbacademy.common import validate
-
         self.__client = validate(client=client).required.as_type(ApiClient)
         self.base_uri = f"{self.__client.endpoint}/api/2.0/clusters"
 
     def create_from_config(self, config: ClusterConfig) -> str:
         return self.create_from_dict(config.params)
 
     def create_from_dict(self, params: Dict[str, Any]) -> str:
-        import json
-        print("-"*80)
-        print(json.dumps(params, indent=4))
-        print("-"*80)
         cluster = self.__client.api("POST", f"{self.base_uri}/create", _data=params)
         return cluster.get("cluster_id")
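The deleted import json / print block was debug scaffolding that dumped every cluster spec to stdout on create. If request tracing is still wanted, routing it through the standard logging module keeps it silenceable in production; the following is a sketch of that alternative, not code from this commit:

    import json
    import logging

    logger = logging.getLogger("dbacademy.clusters")

    def log_cluster_spec(params: dict) -> None:
        # Emit the cluster spec only when DEBUG logging is enabled,
        # instead of unconditionally printing to stdout.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Creating cluster:\n%s", json.dumps(params, indent=4))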
22 changes: 14 additions & 8 deletions src/dbacademy/clients/dbrest/clusters_api/cluster_config.py
@@ -78,7 +78,7 @@ def __init__(self, *,
                  instance_pool_id: Optional[str],
                  policy_id: Optional[str],
                  num_workers: int,
-                 autotermination_minutes: int,
+                 autotermination_minutes: Optional[int],
                  single_user_name: Optional[str],
                  availability: Optional[Union[str, Availability]],
                  spark_conf: Optional[Dict[str, str]],
@@ -91,14 +91,17 @@ def __init__(self, *,
             "cluster_name": validate(cluster_name=cluster_name).optional.str(),
             "spark_version": validate(spark_version=spark_version).required.str(),
             "num_workers": validate(num_workers=num_workers).required.int(),
-            "autotermination_minutes": validate(autotermination_minutes=autotermination_minutes).required.int(),
         }
 
         extra_params = validate(extra_params=extra_params).optional.dict(str, auto_create=True)
         spark_conf = validate(spark_conf=spark_conf).optional.dict(str, auto_create=True)
         spark_env_vars = validate(spark_env_vars=spark_env_vars).optional.dict(str, auto_create=True)
         custom_tags = validate(custom_tags=custom_tags).optional.dict(str, auto_create=True)
 
+        if autotermination_minutes is not None:
+            # Not set for job clusters
+            self.__params["autotermination_minutes"] = validate(autotermination_minutes=autotermination_minutes).required.int()
+
         if instance_pool_id is not None:
             extra_params["instance_pool_id"] = validate(instance_pool_id=instance_pool_id).optional.str()
@@ ... @@ def __init__(self, *,
         extra_params["single_user_name"] = validate(single_user_name=single_user_name).required.str()
         extra_params["data_security_mode"] = "SINGLE_USER"
 
-        extra_params["node_type_id"] = validate(node_type_id=node_type_id).required.str()
-        extra_params["driver_node_type_id"] = validate(driver_node_type_id=driver_node_type_id or node_type_id).required.str()
-
         if num_workers == 0:
+            # Don't use "local[*, 4]" because the node type might have more cores
             custom_tags["ResourceClass"] = "SingleNode"
             spark_conf["spark.master"] = "local[*]"
             spark_conf["spark.databricks.cluster.profile"] = "singleNode"
+            assert driver_node_type_id is None, f"""driver_node_type_id should be None when num_workers is zero."""
+        else:
+            # More than one worker, so define the driver_node_type_id and, if necessary, default it to the node_type_id
+            extra_params["driver_node_type_id"] = validate(driver_node_type_id=driver_node_type_id or node_type_id).required.str()
 
+        extra_params["node_type_id"] = validate(node_type_id=node_type_id).required.str()
 
         assert extra_params.get("custom_tags") is None, f"The parameter \"extra_params.custom_tags\" should not be specified directly, use \"custom_tags\" instead."
         assert extra_params.get("spark_conf") is None, f"The parameter \"extra_params.spark_conf\" should not be specified directly, use \"spark_conf\" instead."
@@ -230,7 +236,7 @@ def __init__(self, *,
                  driver_node_type_id: Optional[str] = None,
                  instance_pool_id: Optional[str] = None,
                  policy_id: Optional[str] = None,
-                 autotermination_minutes: int = 120,
+                 # autotermination_minutes: int = 120,
                  single_user_name: Optional[str] = None,
                  availability: Optional[Union[str, Availability]] = None,
                  spark_conf: Optional[Dict[str, str]] = None,
@@ ... @@ def __init__(self, *,
 
         # Parameters are validated in the call to CommonConfig
         super().__init__(cloud=cloud,
-                         cluster_name=None,  # Not allowed when uses as a job
+                         cluster_name=None,  # Not allowed when used as a job
                          spark_version=spark_version,
                          node_type_id=node_type_id,
                          driver_node_type_id=driver_node_type_id,
                          instance_pool_id=instance_pool_id,
                          policy_id=policy_id,
                          num_workers=num_workers,
-                         autotermination_minutes=autotermination_minutes,
+                         autotermination_minutes=None,  # Not allowed when used as a job
                          single_user_name=single_user_name,
                          spark_conf=spark_conf,
                          spark_env_vars=spark_env_vars,
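Taken together, these changes make autotermination_minutes optional so job clusters can omit it entirely, and restrict driver_node_type_id to multi-node clusters. A standalone sketch of the resulting parameter assembly — simplified stand-in logic, not the actual ClusterConfig API:

    from typing import Any, Dict, Optional

    def build_cluster_params(spark_version: str,
                             node_type_id: str,
                             num_workers: int,
                             autotermination_minutes: Optional[int] = None,
                             driver_node_type_id: Optional[str] = None) -> Dict[str, Any]:
        # Simplified stand-in for CommonConfig's parameter assembly.
        params: Dict[str, Any] = {
            "spark_version": spark_version,
            "node_type_id": node_type_id,
            "num_workers": num_workers,
        }
        if autotermination_minutes is not None:
            # Interactive clusters only; job clusters leave this unset.
            params["autotermination_minutes"] = autotermination_minutes

        if num_workers == 0:
            # Single-node cluster: the driver is the only node, so a separate
            # driver node type makes no sense here.
            assert driver_node_type_id is None, "driver_node_type_id should be None when num_workers is zero."
            params["custom_tags"] = {"ResourceClass": "SingleNode"}
            params["spark_conf"] = {
                "spark.master": "local[*]",
                "spark.databricks.cluster.profile": "singleNode",
            }
        else:
            # Multi-node: default the driver node type to the worker node type.
            params["driver_node_type_id"] = driver_node_type_id or node_type_id

        return params

    # For example, a job-cluster spec omits autotermination_minutes entirely:
    assert "autotermination_minutes" not in build_cluster_params("11.3.x-scala2.12", "i3.xlarge", num_workers=2)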
4 changes: 2 additions & 2 deletions src/dbacademy/clients/dbrest/ml_api/__init__.py
@@ -1,5 +1,7 @@
 __all__ = ["MlApi"]
+# Code Review: JDP on 11-27-2023
 
+from dbacademy.common import validate
 from dbacademy.clients.rest.common import ApiClient, ApiContainer
 from dbacademy.clients.dbrest.ml_api.feature_store_api import FeatureStoreApi
 from dbacademy.clients.dbrest.ml_api.mlflow_endpoints_api import MLflowEndpointsApi
@@ -9,8 +11,6 @@
 
 class MlApi(ApiContainer):
     def __init__(self, client: ApiClient):
-        from dbacademy.common import validate
-
         self.__client = validate(client=client).required.as_type(ApiClient)
 
     @property
6 changes: 2 additions & 4 deletions src/dbacademy/clients/dbrest/permissions_api/__init__.py
@@ -1,5 +1,7 @@
 __all__ = ["PermissionsApi"]
+# Code Review: JDP on 11-27-2023
 
+from dbacademy.common import validate
 from dbacademy.clients.rest.common import ApiContainer, ApiClient
 from dbacademy.clients.dbrest.permissions_api.clusters_permissions_api import ClustersPermissionsApi
 from dbacademy.clients.dbrest.permissions_api.directories_permissions_api import DirectoriesPermissionsApi
@@ ... @@
 class Authorization:
 
     def __init__(self, client: ApiClient):
-        from dbacademy.common import validate
-
         self.__client = validate(client=client).required.as_type(ApiClient)
 
     @property
@@ -26,8 +26,6 @@ def tokens(self) -> AuthTokensPermissionsApi:
 class PermissionsApi(ApiContainer):
 
     def __init__(self, client: ApiClient):
-        from dbacademy.common import validate
-
         self.__client = validate(client=client).required.as_type(ApiClient)
 
     @property
4 changes: 2 additions & 2 deletions src/dbacademy/clients/dbrest/permissions_api/sql/__init__.py
@@ -1,5 +1,7 @@
 __all__ = ["SqlPermissionsApi"]
+# Code Review: JDP on 11-27-2023
 
+from dbacademy.common import validate
 from dbacademy.clients.rest.common import ApiContainer, ApiClient
 from dbacademy.clients.dbrest.permissions_api.sql.warehouses_permissions_api import SqlWarehousesPermissionsApi
 from dbacademy.clients.dbrest.permissions_api.sql.sql_crud_permissions_api import SqlCrudPermissions
@@ -8,8 +10,6 @@
 class SqlPermissionsApi(ApiContainer):
 
     def __init__(self, client: ApiClient):
-        from dbacademy.common import validate
-
         self.__client = validate(client=client).required.as_type(ApiClient)
 
     @property
5 changes: 3 additions & 2 deletions src/dbacademy/clients/dbrest/scim_api/__init__.py
@@ -1,16 +1,17 @@
 __all__ = ["ScimApi"]
+# Code Review: JDP on 11-27-2023
 
 from typing import Dict, Any
+from dbacademy.common import validate
 from dbacademy.clients.rest.common import ApiContainer, ApiClient
 from dbacademy.clients.dbrest.scim_api.users_api import ScimUsersApi
 from dbacademy.clients.dbrest.scim_api.service_principals_api import ScimServicePrincipalsApi
 from dbacademy.clients.dbrest.scim_api.groups_api import ScimGroupsApi
 
 
 class ScimApi(ApiContainer):
-    def __init__(self, client: ApiClient):
-        from dbacademy.common import validate
 
+    def __init__(self, client: ApiClient):
         self.__client = validate(client=client).required.as_type(ApiClient)
 
     @property
8 changes: 2 additions & 6 deletions src/dbacademy/clients/dbrest/sql_api/__init__.py
@@ -1,5 +1,7 @@
 __all__ = ["SqlApi"]
+# Code Review: JDP on 11-27-2023
 
+from dbacademy.common import validate
 from dbacademy.clients.rest.common import ApiContainer, ApiClient
 from dbacademy.clients.dbrest.sql_api.config_api import SqlConfigApi
 from dbacademy.clients.dbrest.sql_api.warehouses_api import SqlWarehousesApi
@@ -10,8 +12,6 @@
 
 class SqlApi(ApiContainer):
 
     def __init__(self, client: ApiClient):
-        from dbacademy.common import validate
-
         self.__client = validate(client=client).required.as_type(ApiClient)
 
     @property
@@ -29,7 +29,3 @@ def queries(self) -> SqlQueriesApi:
     @property
     def statements(self) -> StatementsApi:
         return StatementsApi(self.__client)
-
-    # @property
-    # def permissions(self) -> SqlPermissionsApi:
-    #     return SqlPermissionsApi(self.__client)
4 changes: 2 additions & 2 deletions src/dbacademy/common/validator.py
@@ -336,8 +336,8 @@ def __validate_min_length(self, *, min_length: int = 0) -> None:
         message = f"""{E_INTERNAL} | Expected {self.__class__.__name__}.{inspect.stack()[0].function}(..)'s parameter 'min_length' to be of type int, found {type(min_length)}."""
         do_validate(passed=isinstance(min_length, int), message=message)
 
-        if min_length > 0:
-            # We cannot test the length if the value is not of type Sized.
+        if self.parameter_value is not None and min_length > 0:
+            # We cannot test the length if the value is not of type Sized, and we shouldn't test it if it is None.
             message = f"""{E_TYPE} | Expected the parameter '{self.parameter_name}' to be of type Sized, found {type(self.parameter_value)}."""
             do_validate(passed=isinstance(self.parameter_value, Sized), message=message)
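The guard now skips the length check entirely when the value is None (an optional parameter that was never supplied), rather than failing the Sized type-check on NoneType. A minimal standalone sketch of the same pattern, with hypothetical names:

    from collections.abc import Sized
    from typing import Optional

    def check_min_length(name: str, value: Optional[object], min_length: int = 0) -> None:
        # Mirrors the fixed guard: enforce the length check only when the value
        # is present and a minimum length was actually requested.
        if value is not None and min_length > 0:
            assert isinstance(value, Sized), f"Expected the parameter '{name}' to be of type Sized, found {type(value)}."
            assert len(value) >= min_length, f"Expected the parameter '{name}' to have a minimum length of {min_length}."

    check_min_length("tags", None, min_length=1)   # Passes: None skips the check entirely
    check_min_length("tags", ["a"], min_length=1)  # Passes: length requirement met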
16 changes: 2 additions & 14 deletions test/dbacademy_test/clients/dbrest/test_clusters.py
@@ -238,7 +238,7 @@ def test_default_cluster(self):
             actual_value = cluster.get(key)
             expected_value = expected.get(key)
             if expected_value != actual_value:
-                self.fail(f"{key}: \"{expected_value}\" != \"{actual_value}\"")
+                self.fail(f"""{key}: "{expected_value}" != "{actual_value}".""")
 
         self.assertTrue(len(cluster.get("aws_attributes")) > 0)
 
@@ -376,6 +376,7 @@ def test_create_driver_node_type(self):
             cluster_name="Driver Node Type A",
             spark_version="11.3.x-scala2.12",
             node_type_id="i3.xlarge",
+            # Expected to be None
             # driver_node_type_id="i3.2xlarge",
             num_workers=0,
             autotermination_minutes=10))
@@ -384,19 +385,6 @@ def test_create_driver_node_type(self):
         self.assertEqual("i3.xlarge", cluster.get("driver_node_type_id"))
         self.client.clusters.destroy_by_id(cluster_id)
 
-        cluster_id = self.client.clusters.create_from_config(ClusterConfig(
-            cloud=Cloud.AWS,
-            cluster_name="Driver Node Type B",
-            spark_version="11.3.x-scala2.12",
-            node_type_id="i3.xlarge",
-            driver_node_type_id="i3.2xlarge",
-            num_workers=0,
-            autotermination_minutes=10))
-        cluster = self.client.clusters.get_by_id(cluster_id)
-        self.assertEqual("i3.xlarge", cluster.get("node_type_id"))
-        self.assertEqual("i3.2xlarge", cluster.get("driver_node_type_id"))
-        self.client.clusters.destroy_by_id(cluster_id)
-
     def test_create_with_spark_conf(self):
 
         cluster_id = self.client.clusters.create_from_config(ClusterConfig(
