cancel python model job when dbt exit #690

Closed
Changes from all commits
8 changes: 5 additions & 3 deletions dbt/adapters/databricks/auth.py
@@ -15,21 +15,23 @@

 class token_auth(CredentialsProvider):
     _token: str
+    _host: str
Collaborator:

Why store this on the token? It's already on the DatabricksCredentials.

Author:

Seems like I can't get DatabricksCredentials in DatabricksConnectionManager.

I can just use self.credentials_provider in DatabricksConnectionManager, but there is no host in credentials_provider, so I put a host in the token_auth class.

Could you point me to a way to get DatabricksCredentials in DatabricksConnectionManager?

Collaborator:

The BaseDatabricksHelper has a copy of DatabricksCredentials.

Collaborator:

hmm, but that's an instance...let me think.

Collaborator:

I'm going to pull down a copy of this PR and see if I can figure it out.

Author:

Yes.. I don't think I can get the instance in DatabricksConnectionManager...

Thank you very much!

Collaborator:

Oh, did you already fix this?

Author:

No, I couldn't find a way, so I'm still putting the host in token_auth so that it can be retrieved from DatabricksConnectionManager.credentials_provider.


-    def __init__(self, token: str) -> None:
+    def __init__(self, token: str, host: str) -> None:
         self._token = token
+        self._host = host

     def auth_type(self) -> str:
         return "token"

     def as_dict(self) -> dict:
-        return {"token": self._token}
+        return {"token": self._token, "host": self._host}

     @staticmethod
     def from_dict(raw: Optional[dict]) -> Optional[CredentialsProvider]:
         if not raw:
             return None
-        return token_auth(raw["token"])
+        return token_auth(raw["token"], raw["host"])

     def __call__(self, _: Optional[Config] = None) -> HeaderFactory:
         static_credentials = {"Authorization": f"Bearer {self._token}"}
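For illustration, a minimal usage sketch of the extended token_auth (not part of the diff; the token and workspace hostname below are placeholders):

# Hypothetical usage sketch; values are placeholders, not real credentials.
from dbt.adapters.databricks.auth import token_auth

provider = token_auth("dapi-example-token", "example.cloud.databricks.com")

# With the change above, the host now round-trips through serialization, so
# code that only has the serialized provider can still recover the workspace.
serialized = provider.as_dict()            # {"token": ..., "host": ...}
restored = token_auth.from_dict(serialized)
assert restored is not None
assert restored.as_dict()["host"] == "example.cloud.databricks.com"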
10 changes: 10 additions & 0 deletions dbt/adapters/databricks/connections.py
@@ -77,6 +77,8 @@
 from dbt_common.utils import cast_to_str
 from requests import Session

+from dbt.adapters.databricks.python_submissions import BaseDatabricksHelper
+
 if TYPE_CHECKING:
     from agate import Table

@@ -475,6 +477,14 @@ class DatabricksConnectionManager(SparkConnectionManager):
     TYPE: str = "databricks"
     credentials_provider: Optional[TCredentialProvider] = None

+    def cancel_open(self) -> List[str]:
+        for run_id in BaseDatabricksHelper.run_ids:
+            logger.debug(f"Cancel python model job: {run_id}")
+            BaseDatabricksHelper.cancel_run_id(run_id, self.credentials_provider.as_dict()['token'], self.credentials_provider.as_dict()['host'])
Collaborator:

Ah, this is where you need to retrieve it, and here you don't have an instance... maybe we can use a singleton pattern? (See the sketch after this diff.)

Collaborator:

give me an hour to take a crack at refactoring this; I have an idea :)

Author:

A singleton might be a way; let me try some code.

Author:

yeah! thank you very much!

+        BaseDatabricksHelper.run_ids.clear()
+        return super().cancel_open()
Collaborator:

Below you raise an exception on a non-200, but that will interrupt cancelling the other operations. Better to log a warning on non-200, I think.



     def compare_dbr_version(self, major: int, minor: int) -> int:
         version = (major, minor)

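The singleton idea floated in the thread above is not implemented in this diff. A minimal sketch of what it could look like, assuming an invented CredentialsHolder class used purely for illustration (not an actual dbt-databricks API):

# Hypothetical sketch of the singleton idea discussed above; the class and
# method names are invented for illustration and are not part of this PR.
import threading
from typing import Optional, Tuple


class CredentialsHolder:
    """Process-wide holder so code without a credentials instance
    (e.g. cancel_open) can still reach the token and host."""

    _lock = threading.Lock()
    _token: Optional[str] = None
    _host: Optional[str] = None

    @classmethod
    def set(cls, token: str, host: str) -> None:
        with cls._lock:
            cls._token, cls._host = token, host

    @classmethod
    def get(cls) -> Tuple[Optional[str], Optional[str]]:
        with cls._lock:
            return cls._token, cls._host

# For example, DatabricksCredentials.authenticate() could call CredentialsHolder.set(...)
# once, and DatabricksConnectionManager.cancel_open() could later call
# CredentialsHolder.get() instead of reading token and host back out of
# credentials_provider.as_dict().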
2 changes: 1 addition & 1 deletion dbt/adapters/databricks/credentials.py
@@ -255,7 +255,7 @@ def authenticate(self, in_provider: Optional[TCredentialProvider]) -> TCredentialProvider:
         self._lock.acquire()
         try:
             if self.token:
-                provider = token_auth(self.token)
+                provider = token_auth(self.token, self.host)
                 self._credentials_provider = provider.as_dict()
                 return provider

54 changes: 54 additions & 0 deletions dbt/adapters/databricks/python_submissions.py
@@ -20,13 +20,19 @@
 from dbt.adapters.databricks.credentials import TCredentialProvider
 from dbt.adapters.databricks.logging import logger

+import threading
+

 DEFAULT_POLLING_INTERVAL = 10
 SUBMISSION_LANGUAGE = "python"
 DEFAULT_TIMEOUT = 60 * 60 * 24


 class BaseDatabricksHelper(PythonJobHelper):
+
+    run_ids = list()
+    _lock = threading.Lock()  # to avoid concurrent issue
+
     def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None:
         self.credentials = credentials
         self.identifier = parsed_model["alias"]
@@ -45,6 +51,7 @@ def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None:
         self.extra_headers = {
             "User-Agent": f"dbt-databricks/{version}",
         }
+        host = credentials.host

     @property
     def cluster_id(self) -> str:
@@ -96,6 +103,7 @@ def _upload_notebook(self, path: str, compiled_code: str) -> None:
         if response.status_code != 200:
             raise DbtRuntimeError(f"Error creating python notebook.\n {response.content!r}")

+
     def _submit_job(self, path: str, cluster_spec: dict) -> str:
         job_spec = {
             "run_name": f"{self.schema}-{self.identifier}-{uuid.uuid4()}",
@@ -146,6 +154,7 @@ def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> None:

         # submit job
         run_id = self._submit_job(whole_file_path, cluster_spec)
+        self._insert_run_ids(run_id)

         self.polling(
             status_func=self.session.get,
@@ -173,6 +182,51 @@ def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> None:
"match the line number in your code due to dbt templating)\n"
f"{utils.remove_ansi(json_run_output['error_trace'])}"
)
self._remove_run_ids(run_id)


def _remove_run_ids(self, run_id: str):
self._lock.acquire()
try:
self.run_ids.remove(run_id)
except Exception as e:
logger.warning(e)
finally:
self._lock.release()


def _insert_run_ids(self, run_id: str):
self._lock.acquire()
try:
self.run_ids.append(run_id)
except Exception as e:
logger.warning(e)
finally:
self._lock.release()

@staticmethod
def cancel_run_id(run_id: str, token: str, host: str) -> None:
retry_strategy = Retry(total=4, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry_strategy)
session = Session()
session.mount("https://", adapter)
extra_headers = {
"User-Agent": f"dbt-databricks/{version}",
"Authorization": f"Bearer {token}"
}

response = session.post(
f"https://{host}/api/2.0/jobs/runs/cancel",
headers=extra_headers,
json={
"run_id": run_id
},
)

print(response.status_code)
if response.status_code != 200:
logger.warning(f"Cancel run id failed.\n {response.content!r}")


def submit(self, compiled_code: str) -> None:
raise NotImplementedError(
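As an aside on the lock handling in the new helpers: the acquire/try/finally/release pattern can also be written with the lock as a context manager. A minimal sketch of that alternative (the _RunIdRegistry name is invented for illustration; this is not what the PR implements):

# Sketch of equivalent run-id bookkeeping using the lock as a context manager.
import threading
from typing import List


class _RunIdRegistry:
    run_ids: List[str] = []
    _lock = threading.Lock()

    @classmethod
    def insert(cls, run_id: str) -> None:
        with cls._lock:  # released automatically, even if append raises
            cls.run_ids.append(run_id)

    @classmethod
    def remove(cls, run_id: str) -> None:
        with cls._lock:
            if run_id in cls.run_ids:  # avoids the ValueError instead of logging it
                cls.run_ids.remove(run_id)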