Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace password authentication with oauth #1069

Merged
merged 1 commit into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ Deploy the stack to your AWS account:
```shell
cdk deploy
```
Note: The AWS SSM parameters `/databricks/deploy/user`, `/databricks/deploy/password`, and `/databricks/account-id` are required for the deployment to succeed.
Note: The AWS SSM parameters `/databricks/deploy/client-id`, `/databricks/deploy/client-secret`, and `/databricks/account-id` are required for the deployment to succeed.

- `/databricks/deploy/client-id` is the client-id of service principal that is account admin and workspace admin
- `/databricks/deploy/client-secret` is the client-secret of service principal that is account admin and workspace admin
- `/databricks/account-id` is the id of your databricks account

See also the simple-workspace and multi-stack examples in [examples](examples)

Expand Down
10 changes: 6 additions & 4 deletions aws-lambda/src/databricks_cdk/resources/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import requests
from pydantic import BaseModel

from databricks_cdk.utils import CnfResponse, get_auth, get_request, post_request
from databricks_cdk.utils import CnfResponse, get_authorization_headers, get_request, post_request

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,8 +86,11 @@ def get_cluster_by_name(cluster_name: str, workspace_url: str):
def get_cluster_by_id(cluster_id: str, workspace_url: str) -> Optional[dict]:
"""Getting cluster based on name"""
body = {"cluster_id": cluster_id}
auth = get_auth()
resp = requests.get(f"{get_cluster_url(workspace_url)}/get", json=body, headers={}, auth=auth)
resp = requests.get(
f"{get_cluster_url(workspace_url)}/get",
json=body,
headers=get_authorization_headers(),
)
if resp.status_code == 400 and "does not exist" in resp.text:
return None
resp.raise_for_status()
Expand All @@ -100,7 +103,6 @@ def create_or_update_cluster(properties: ClusterProperties, physical_resource_id
if physical_resource_id is not None:
current = get_cluster_by_id(physical_resource_id, properties.workspace_url)
if current is None:

# Json data
body = properties.cluster.dict()
response = post_request(f"{get_cluster_url(properties.workspace_url)}/create", body=body)
Expand Down
9 changes: 6 additions & 3 deletions aws-lambda/src/databricks_cdk/resources/jobs/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import requests
from pydantic import BaseModel

from databricks_cdk.utils import CnfResponse, get_auth, post_request
from databricks_cdk.utils import CnfResponse, get_authorization_headers, post_request

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -175,8 +175,11 @@ def get_job_url(workspace_url: str):

def get_job_by_id(job_id: str, workspace_url: str):
body = {"job_id": job_id}
auth = get_auth()
resp = requests.get(f"{get_job_url(workspace_url)}/get", json=body, headers={}, auth=auth)
resp = requests.get(
f"{get_job_url(workspace_url)}/get",
json=body,
headers=get_authorization_headers(),
)
if resp.status_code == 400 and "does not exist" in resp.text:
return None
resp.raise_for_status()
Expand Down
15 changes: 2 additions & 13 deletions aws-lambda/src/databricks_cdk/resources/scim/user.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
from typing import Optional

import requests
from pydantic import BaseModel

from databricks_cdk.utils import CnfResponse, get_auth, get_request, post_request
from databricks_cdk.utils import CnfResponse, delete_request, get_request, post_request

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -41,7 +40,6 @@ def create_or_update_user(properties: UserProperties) -> UserResponse:

current = get_user_by_user_name(properties.user_name, properties.workspace_url)
if current is None:

# Json data
body = {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
Expand All @@ -58,16 +56,7 @@ def delete_user(properties: UserProperties, physical_resource_id: str) -> CnfRes
"""Deletes user at databricks"""
current = get_user_by_user_name(properties.user_name, properties.workspace_url)
if current is not None:
auth = get_auth()
if auth.username != properties.user_name:
resp = requests.delete(
f"{get_user_url(properties.workspace_url)}/{current['id']}",
headers={},
auth=auth,
)
resp.raise_for_status()
else:
logger.warning("Can't remove deploy user")
delete_request(f"{get_user_url(properties.workspace_url)}/{current['id']}")
else:
logger.warning("Already removed")
return CnfResponse(physical_resource_id=physical_resource_id)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
from typing import Dict, Optional

import requests.exceptions
from pydantic import BaseModel, Field

from databricks_cdk.utils import CnfResponse, delete_request, get_request, patch_request, post_request
Expand Down
72 changes: 54 additions & 18 deletions aws-lambda/src/databricks_cdk/utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,36 @@
import logging
import os
from functools import lru_cache
from typing import Any, Dict, Optional

import boto3
from databricks.sdk import AccountClient, WorkspaceClient
from databricks.sdk.core import Config
from pydantic import BaseModel
from requests import request
from requests.auth import HTTPBasicAuth
from requests.exceptions import HTTPError
from tenacity import retry, retry_if_exception, retry_if_exception_type, stop_after_attempt, wait_exponential

logger = logging.getLogger(__name__)


USER_PARAM = os.environ.get("USER_PARAM", "/databricks/deploy/user")
PASS_PARAM = os.environ.get("PASS_PARAM", "/databricks/deploy/password")
class UnsupportedAuthMethodError(Exception):
pass


ACCOUNT_PARAM = os.environ.get("ACCOUNT_PARAM", "/databricks/account-id")
CLIENT_SECRET_PARAM = os.environ.get("CLIENT_SECRET_PARAM", "/databricks/deploy/client-secret")

# Make sure password based authentication is not used anymore after 10th of July 2024
if CLIENT_SECRET_PARAM is None:
raise UnsupportedAuthMethodError(
"Password based authentication is not supported from 10th of July 2024.",
"Please set client_id and client_secret of a service principal instead.",
"Service principal can be created in account settings in Databricks."
"Needs account admin and admin permissions on workspace.",
)

CLIENT_ID_PARAM = os.environ.get("CLIENT_ID_PARAM", "/databricks/deploy/client-id")
ACCOUNTS_BASE_URL = os.environ.get("BASE_URL", "https://accounts.cloud.databricks.com")


Expand All @@ -33,6 +46,26 @@ def get_param(name: str, required: bool = False):
return result


@lru_cache(maxsize=1)
def get_authentication_config() -> Config:
"""
This config can be used to authenticate with databricks using
requests library. Not needed when using WorkspaceClient or AccountClient.
Config is cached to avoid multiple calls to get_param
"""
return Config(
host=ACCOUNTS_BASE_URL,
client_id=get_client_id(),
client_secret=get_client_secret(),
account_id=get_account_id(),
)


def get_authorization_headers() -> Dict[str, str]:
"""Get authorization headers"""
return get_authentication_config().authenticate()


class CnfResponse(BaseModel):
physical_resource_id: str

Expand All @@ -42,19 +75,12 @@ def get_account_id() -> str:
return get_param(ACCOUNT_PARAM, required=True)


def get_deploy_user() -> str:
return get_param(USER_PARAM, required=True)

def get_client_secret() -> str:
return get_param(CLIENT_SECRET_PARAM, required=True)

def get_password() -> str:
return get_param(PASS_PARAM, required=True)


def get_auth() -> HTTPBasicAuth:
"""Get auth from param store"""
user = get_deploy_user()
password = get_param(PASS_PARAM, required=True)
return HTTPBasicAuth(user, password)
def get_client_id() -> str:
return get_param(CLIENT_ID_PARAM, required=True)


@retry(
Expand All @@ -80,8 +106,13 @@ def _do_request(
:raises ValueError: If provided method is not supported
:return: Response data
"""
auth = get_auth()
resp = request(method=method, url=url, json=body, params=params, auth=auth)
resp = request(
method=method,
url=url,
json=body,
params=params,
headers=get_authorization_headers(),
)

# If the response was successful, no Exception will be raised
if resp.status_code >= 400:
Expand Down Expand Up @@ -142,7 +173,7 @@ def get_workspace_client(workspace_url: str, config: Optional[Config] = None) ->
if config:
return WorkspaceClient(config=config)

return WorkspaceClient(username=get_deploy_user(), password=get_password(), host=workspace_url)
return WorkspaceClient(client_id=get_client_id(), client_secret=get_client_secret(), host=workspace_url)


def get_account_client(
Expand All @@ -156,4 +187,9 @@ def get_account_client(
if config:
return AccountClient(config=config)

return AccountClient(username=get_deploy_user(), password=get_password(), host=host, account_id=get_account_id())
return AccountClient(
client_id=get_client_id(),
client_secret=get_client_secret(),
host=host,
account_id=get_account_id(),
)
23 changes: 19 additions & 4 deletions aws-lambda/tests/resources/tokens/test_token.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from unittest.mock import patch

import src.databricks_cdk.resources.tokens.token
from src.databricks_cdk.resources.tokens.token import (
TokenInfo,
TokenProperties,
Expand Down Expand Up @@ -37,7 +36,12 @@ def test_create_token_not_exist(

patched__create_token.return_value = {
"token_value": "some_value",
"token_info": {"token_id": "some_id", "creation_time": 1234, "expiry_time": 1234, "comment": "some test token"},
"token_info": {
"token_id": "some_id",
"creation_time": 1234,
"expiry_time": 1234,
"comment": "some test token",
},
}
patched_get_existing_tokens.return_value = []

Expand Down Expand Up @@ -103,7 +107,14 @@ def test_create_token_already_exist(
@patch("src.databricks_cdk.resources.tokens.token.get_request")
def test_get_existing_token(patched_get_request):
patched_get_request.return_value = {
"token_infos": [{"token_id": "test", "creation_time": 1, "expiry_time": 2, "comment": "test_comment"}]
"token_infos": [
{
"token_id": "test",
"creation_time": 1,
"expiry_time": 2,
"comment": "test_comment",
}
]
}

token_list = get_existing_tokens(token_url="https://test.cloud.databricks.com/api/2.0/token")
Expand All @@ -118,7 +129,11 @@ def test_get_existing_token(patched_get_request):

@patch("src.databricks_cdk.resources.tokens.token.post_request")
def test__create_token(patched_post_request):
_create_token("https://test.cloud.databricks.com/api/2.0/token", comment="test_comment", lifetime_seconds=1)
_create_token(
"https://test.cloud.databricks.com/api/2.0/token",
comment="test_comment",
lifetime_seconds=1,
)

assert patched_post_request.call_args.kwargs == {"body": {"comment": "test_comment", "lifetime_seconds": 1}}

Expand Down
Loading
Loading