Skip to content

Commit

Permalink
replace password authentication with oauth
Browse files Browse the repository at this point in the history
  • Loading branch information
DaanRademaker committed Jul 2, 2024
1 parent eb92b11 commit acf24cc
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 59 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ Deploy the stack to your AWS account:
```shell
cdk deploy
```
Note: The AWS SSM parameters `/databricks/deploy/user`, `/databricks/deploy/password`, and `/databricks/account-id` are required for the deployment to succeed.
Note: The AWS SSM parameters `/databricks/deploy/client-id`, `/databricks/deploy/client-secret`, and `/databricks/account-id` are required for the deployment to succeed.

- `databricks/deploy/client-id` is the client-id of service principal that is account admin and workspace admin
- `databricks/deploy/client-secret` is the client-service of service principal that is account admin and workspace admin
- `/databricks/account-id` is the id of your databricks account

See also the simple-workspace and multi-stack examples in [examples](examples)

Expand Down
10 changes: 6 additions & 4 deletions aws-lambda/src/databricks_cdk/resources/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import requests
from pydantic import BaseModel

from databricks_cdk.utils import CnfResponse, get_auth, get_request, post_request
from databricks_cdk.utils import CnfResponse, get_authorization_headers, get_request, post_request

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,8 +86,11 @@ def get_cluster_by_name(cluster_name: str, workspace_url: str):
def get_cluster_by_id(cluster_id: str, workspace_url: str) -> Optional[dict]:
"""Getting cluster based on name"""
body = {"cluster_id": cluster_id}
auth = get_auth()
resp = requests.get(f"{get_cluster_url(workspace_url)}/get", json=body, headers={}, auth=auth)
resp = requests.get(
f"{get_cluster_url(workspace_url)}/get",
json=body,
headers=get_authorization_headers(),
)
if resp.status_code == 400 and "does not exist" in resp.text:
return None
resp.raise_for_status()
Expand All @@ -100,7 +103,6 @@ def create_or_update_cluster(properties: ClusterProperties, physical_resource_id
if physical_resource_id is not None:
current = get_cluster_by_id(physical_resource_id, properties.workspace_url)
if current is None:

# Json data
body = properties.cluster.dict()
response = post_request(f"{get_cluster_url(properties.workspace_url)}/create", body=body)
Expand Down
9 changes: 6 additions & 3 deletions aws-lambda/src/databricks_cdk/resources/jobs/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import requests
from pydantic import BaseModel

from databricks_cdk.utils import CnfResponse, get_auth, post_request
from databricks_cdk.utils import CnfResponse, get_authorization_headers, post_request

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -175,8 +175,11 @@ def get_job_url(workspace_url: str):

def get_job_by_id(job_id: str, workspace_url: str):
body = {"job_id": job_id}
auth = get_auth()
resp = requests.get(f"{get_job_url(workspace_url)}/get", json=body, headers={}, auth=auth)
resp = requests.get(
f"{get_job_url(workspace_url)}/get",
json=body,
headers=get_authorization_headers(),
)
if resp.status_code == 400 and "does not exist" in resp.text:
return None
resp.raise_for_status()
Expand Down
15 changes: 2 additions & 13 deletions aws-lambda/src/databricks_cdk/resources/scim/user.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
from typing import Optional

import requests
from pydantic import BaseModel

from databricks_cdk.utils import CnfResponse, get_auth, get_request, post_request
from databricks_cdk.utils import CnfResponse, delete_request, get_request, post_request

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -41,7 +40,6 @@ def create_or_update_user(properties: UserProperties) -> UserResponse:

current = get_user_by_user_name(properties.user_name, properties.workspace_url)
if current is None:

# Json data
body = {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
Expand All @@ -58,16 +56,7 @@ def delete_user(properties: UserProperties, physical_resource_id: str) -> CnfRes
"""Deletes user at databricks"""
current = get_user_by_user_name(properties.user_name, properties.workspace_url)
if current is not None:
auth = get_auth()
if auth.username != properties.user_name:
resp = requests.delete(
f"{get_user_url(properties.workspace_url)}/{current['id']}",
headers={},
auth=auth,
)
resp.raise_for_status()
else:
logger.warning("Can't remove deploy user")
delete_request(f"{get_user_url(properties.workspace_url)}/{current['id']}")
else:
logger.warning("Already removed")
return CnfResponse(physical_resource_id=physical_resource_id)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
from typing import Dict, Optional

import requests.exceptions
from pydantic import BaseModel, Field

from databricks_cdk.utils import CnfResponse, delete_request, get_request, patch_request, post_request
Expand Down
74 changes: 56 additions & 18 deletions aws-lambda/src/databricks_cdk/utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,38 @@
import logging
import os
from functools import cache
from typing import Any, Dict, Optional

import boto3
from databricks.sdk import AccountClient, WorkspaceClient
from databricks.sdk.core import Config
from pydantic import BaseModel
from requests import request
from requests.auth import HTTPBasicAuth
from requests.exceptions import HTTPError
from tenacity import retry, retry_if_exception, retry_if_exception_type, stop_after_attempt, wait_exponential

logger = logging.getLogger(__name__)


USER_PARAM = os.environ.get("USER_PARAM", "/databricks/deploy/user")
PASS_PARAM = os.environ.get("PASS_PARAM", "/databricks/deploy/password")
class UnsupportedAuthMethodError(Exception):
pass


ACCOUNT_PARAM = os.environ.get("ACCOUNT_PARAM", "/databricks/account-id")
CLIENT_SECRET_PARAM = os.environ.get("CLIENT_SECRET_PARAM", "/databricks/deploy/client-secret")

PASS_PARAM = os.environ.get("PASS_PARAM")

# Make sure password based authentication is not used anymore after 10th of July 2024
if PASS_PARAM is not None:
raise UnsupportedAuthMethodError(
"Password based authentication is not supported from 10th of July 2024.",
"Please set client_id and client_secret of a service principal instead.",
"Service principal can be created in account settings in Databricks."
"Needs account admin and admin permissions on workspace.",
)

CLIENT_ID_PARAM = os.environ.get("CLIENT_ID_PARAM", "/databricks/deploy/client-id")
ACCOUNTS_BASE_URL = os.environ.get("BASE_URL", "https://accounts.cloud.databricks.com")


Expand All @@ -33,6 +48,26 @@ def get_param(name: str, required: bool = False):
return result


@cache
def get_authentication_config() -> Config:
"""
This config can be used to authenticate with databricks using
requests library. Not needed when using WorkspaceClient or AccountClient.
Config is cached to avoid multiple calls to get_param
"""
return Config(
host=ACCOUNTS_BASE_URL,
client_id=get_client_id(),
client_secret=get_client_secret(),
account_id=get_account_id(),
)


def get_authorization_headers() -> Dict[str, str]:
"""Get authorization headers"""
return get_authentication_config().authenticate()


class CnfResponse(BaseModel):
physical_resource_id: str

Expand All @@ -42,19 +77,12 @@ def get_account_id() -> str:
return get_param(ACCOUNT_PARAM, required=True)


def get_deploy_user() -> str:
return get_param(USER_PARAM, required=True)
def get_client_secret() -> str:
return get_param(CLIENT_SECRET_PARAM, required=True)


def get_password() -> str:
return get_param(PASS_PARAM, required=True)


def get_auth() -> HTTPBasicAuth:
"""Get auth from param store"""
user = get_deploy_user()
password = get_param(PASS_PARAM, required=True)
return HTTPBasicAuth(user, password)
def get_client_id() -> str:
return get_param(CLIENT_ID_PARAM, required=True)


@retry(
Expand All @@ -80,8 +108,13 @@ def _do_request(
:raises ValueError: If provided method is not supported
:return: Response data
"""
auth = get_auth()
resp = request(method=method, url=url, json=body, params=params, auth=auth)
resp = request(
method=method,
url=url,
json=body,
params=params,
headers=get_authorization_headers(),
)

# If the response was successful, no Exception will be raised
if resp.status_code >= 400:
Expand Down Expand Up @@ -142,7 +175,7 @@ def get_workspace_client(workspace_url: str, config: Optional[Config] = None) ->
if config:
return WorkspaceClient(config=config)

return WorkspaceClient(username=get_deploy_user(), password=get_password(), host=workspace_url)
return WorkspaceClient(client_id=get_client_id(), client_secret=get_client_secret(), host=workspace_url)


def get_account_client(
Expand All @@ -156,4 +189,9 @@ def get_account_client(
if config:
return AccountClient(config=config)

return AccountClient(username=get_deploy_user(), password=get_password(), host=host, account_id=get_account_id())
return AccountClient(
client_id=get_client_id(),
client_secret=get_client_secret(),
host=host,
account_id=get_account_id(),
)
23 changes: 19 additions & 4 deletions aws-lambda/tests/resources/tokens/test_token.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from unittest.mock import patch

import src.databricks_cdk.resources.tokens.token
from src.databricks_cdk.resources.tokens.token import (
TokenInfo,
TokenProperties,
Expand Down Expand Up @@ -37,7 +36,12 @@ def test_create_token_not_exist(

patched__create_token.return_value = {
"token_value": "some_value",
"token_info": {"token_id": "some_id", "creation_time": 1234, "expiry_time": 1234, "comment": "some test token"},
"token_info": {
"token_id": "some_id",
"creation_time": 1234,
"expiry_time": 1234,
"comment": "some test token",
},
}
patched_get_existing_tokens.return_value = []

Expand Down Expand Up @@ -103,7 +107,14 @@ def test_create_token_already_exist(
@patch("src.databricks_cdk.resources.tokens.token.get_request")
def test_get_existing_token(patched_get_request):
patched_get_request.return_value = {
"token_infos": [{"token_id": "test", "creation_time": 1, "expiry_time": 2, "comment": "test_comment"}]
"token_infos": [
{
"token_id": "test",
"creation_time": 1,
"expiry_time": 2,
"comment": "test_comment",
}
]
}

token_list = get_existing_tokens(token_url="https://test.cloud.databricks.com/api/2.0/token")
Expand All @@ -118,7 +129,11 @@ def test_get_existing_token(patched_get_request):

@patch("src.databricks_cdk.resources.tokens.token.post_request")
def test__create_token(patched_post_request):
_create_token("https://test.cloud.databricks.com/api/2.0/token", comment="test_comment", lifetime_seconds=1)
_create_token(
"https://test.cloud.databricks.com/api/2.0/token",
comment="test_comment",
lifetime_seconds=1,
)

assert patched_post_request.call_args.kwargs == {"body": {"comment": "test_comment", "lifetime_seconds": 1}}

Expand Down
Loading

0 comments on commit acf24cc

Please sign in to comment.