Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add task status callback #164

Merged
merged 13 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20230206-120426.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Add new task status callback functionality to the async dbt endpoint
time: 2023-02-06T12:04:26.954999-05:00
custom:
Author: jp-dbt
Issue: "165"
PR: "164"
21 changes: 4 additions & 17 deletions dbt_server/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,13 @@ def create_task(db: Session, task: schemas.Task):
return db_task


def set_task_running(db: Session, task: schemas.Task):
def set_task_state(db: Session, task: schemas.Task, state: models.TaskState, error: str):
db_task = get_task(db, task.task_id)
db_task.state = models.TaskState.RUNNING
db.commit()
db.refresh(db_task)
return db_task


def set_task_done(db: Session, task: schemas.Task):
db_task = get_task(db, task.task_id)
db_task.state = models.TaskState.FINISHED
db.commit()
db.refresh(db_task)
return db_task
db_task.state = state

if error:
db_task.error = error

def set_task_errored(db: Session, task: schemas.Task, error: str):
db_task = get_task(db, task.task_id)
db_task.state = models.TaskState.ERROR
db_task.error = error
db.commit()
db.refresh(db_task)
return db_task
24 changes: 20 additions & 4 deletions dbt_server/services/dbt_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
import dbt.tracking
import dbt.lib
import dbt.adapters.factory
import requests
from requests.adapters import HTTPAdapter

from sqlalchemy.orm import Session
from urllib3 import Retry

# These exceptions were removed in v1.4
try:
Expand Down Expand Up @@ -49,7 +52,7 @@
from dbt_server.services import filesystem_service
from dbt_server.logging import DBT_SERVER_LOGGER as logger
from dbt_server.helpers import get_profile_name
from dbt_server import crud, tracer
from dbt_server import crud, tracer, models
from dbt.lib import load_profile_project
from dbt.cli.main import dbtRunner

Expand Down Expand Up @@ -249,6 +252,7 @@ def execute_async_command(
manifest: Any,
db: Session,
state_id: Optional[str] = None,
callback_url: Optional[str] = None
) -> None:
db_task = crud.get_task(db, task_id)
# For commands, only the log file destination directory is sent to --log-path
Expand All @@ -270,20 +274,20 @@ def execute_async_command(
profile_name = get_profile_name()
profile, project = load_profile_project(root_path, profile_name)

crud.set_task_running(db, db_task)
update_task_status(db, db_task, callback_url, models.TaskState.RUNNING, None)

logger.info(f"Running dbt ({task_id}) - kicking off task")

try:
dbt = dbtRunner(project, profile, manifest)
_, _ = dbt.invoke(new_command)
except RuntimeException as e:
crud.set_task_errored(db, db_task, str(e))
update_task_status(db, db_task, callback_url, models.TaskState.ERROR, str(e))
raise e

logger.info(f"Running dbt ({task_id}) - done")

crud.set_task_done(db, db_task)
update_task_status(db, db_task, callback_url, models.TaskState.FINISHED, None)


@tracer.wrap
Expand All @@ -301,3 +305,15 @@ def execute_sync_command(command: List, root_path: str, manifest: Any):

dbt = dbtRunner(project, profile, manifest)
return dbt.invoke(command)


def update_task_status(
    db, db_task, callback_url: Optional[str], status, error: Optional[str]
) -> None:
    """Persist a task's new state and optionally notify a callback URL.

    The state (and error message, if any) is first written through
    ``crud.set_task_state``. When ``callback_url`` is truthy, a JSON body of
    the form ``{"task_id": ..., "status": ...}`` is then POSTed to it.

    Args:
        db: Database session handed to ``crud.set_task_state``.
        db_task: Task record being updated; its ``task_id`` is sent in the
            callback payload.
        callback_url: URL to notify, or ``None``/empty to skip the callback.
        status: New task state to store and to report in the callback.
        error: Error message to store with the task, or ``None``.

    Raises:
        Any exception raised by ``requests`` while posting the callback is
        not caught here and propagates to the caller.
    """
    crud.set_task_state(db, db_task, status, error)

    if not callback_url:
        return

    # urllib3 does not retry POST by default, so it has to be allowed
    # explicitly for the retry policy to apply to the callback request.
    retries = Retry(total=5, allowed_methods=frozenset(["POST"]))
    adapter = HTTPAdapter(max_retries=retries)

    # The context manager closes the session (and its pooled connections)
    # once the callback has been sent, instead of leaking it per call.
    with requests.Session() as session:
        # Mount the retrying adapter for both schemes; mounting it only on
        # "http://" leaves https callback URLs without any retries.
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        # NOTE(review): no timeout is passed, so an unresponsive callback
        # host can block this call indefinitely -- consider adding one.
        session.post(
            callback_url,
            json={"task_id": db_task.task_id, "status": status},
        )

4 changes: 2 additions & 2 deletions dbt_server/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,9 @@ def execute_query(self, query):
return dbt_service.execute_sql(self.manifest, self.root_path, query)

@tracer.wrap
def execute_async_command(self, task_id, command, db) -> None:
def execute_async_command(self, task_id, command, db, callback_url) -> None:
return dbt_service.execute_async_command(
command, task_id, self.root_path, self.manifest, db, self.state_id
command, task_id, self.root_path, self.manifest, db, self.state_id, callback_url
)

@tracer.wrap
Expand Down
23 changes: 17 additions & 6 deletions dbt_server/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,12 @@ class SQLConfig(BaseModel):
profile: Optional[str] = None


class dbtCommandArgs(BaseModel):
class DbtCommandArgs(BaseModel):
command: List[Any]
state_id: Optional[str]
# TODO: Need to handle this differently
profile: Optional[str]
callback_url: Optional[str]


@app.exception_handler(InvalidConfigurationException)
Expand Down Expand Up @@ -193,9 +194,9 @@ def parse_project(args: ParseArgs):
)


@app.post("/async/dbt", response_model=schemas.Task)
@app.post("/async/dbt")
async def dbt_entry_async(
args: dbtCommandArgs,
args: DbtCommandArgs,
background_tasks: BackgroundTasks,
db: Session = Depends(crud.get_db),
):
Expand All @@ -216,12 +217,22 @@ async def dbt_entry_async(
if db_task:
raise HTTPException(status_code=400, detail="Task already registered")

background_tasks.add_task(state.execute_async_command, task_id, args.command, db)
return crud.create_task(db, task)
background_tasks.add_task(state.execute_async_command, task_id, args.command, db, args.callback_url)
created_task = crud.create_task(db, task)
return JSONResponse(
status_code=200,
content={
"task_id": created_task.task_id,
"state_id": state.state_id,
"state": created_task.state,
"command": created_task.command,
"log_path": created_task.log_path
},
)


@app.post("/sync/dbt")
async def dbt_entry_sync(args: dbtCommandArgs):
async def dbt_entry_sync(args: DbtCommandArgs):
# example body: {"command":["list", "--output", "json"]}
state = StateController.load_state(args)
# TODO: See what if any useful info is returned when there's no success
Expand Down
3 changes: 1 addition & 2 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,4 @@ ipdb==0.13.9
pytest==7.1.2
pre-commit==2.20.0
pre-commit-hooks==4.3.0
reorder_python_imports==3.8.2
requests==2.26.0
reorder_python_imports==3.8.2
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ sse-starlette==0.9.0
gunicorn==20.1.0
uvicorn[standard]==0.18.3
websockets==10.0
psutil==5.9.2
psutil==5.9.2
requests==2.26.0
4 changes: 2 additions & 2 deletions tests/e2e/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from unittest import TestCase

from dbt_server.state import StateController, LAST_PARSED
from dbt_server.views import dbtCommandArgs
from dbt_server.views import DbtCommandArgs
from .helpers import profiles_dir
from .fixtures import Profiles
import os
Expand All @@ -30,7 +30,7 @@ def tearDown(self):
def test_load_state(self):
# CURRENTLY USING SNOWFLAKE DUE TO DBT VERSION MISMATCH WITH POSTGRES
with profiles_dir(Profiles.Snowflake):
args = dbtCommandArgs(command=["run"], state_id=self.state_id)
args = DbtCommandArgs(command=["run"], state_id=self.state_id)
result = StateController.load_state(args)

assert result.state_id == self.state_id
Expand Down
8 changes: 4 additions & 4 deletions tests/e2e/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_dbt_entry_async_project_path(self, mock_tracking):
state.serialize_manifest()
state.update_cache()

args = views.dbtCommandArgs(command=["run", "--threads", 1])
args = views.DbtCommandArgs(command=["run", "--threads", 1])
response = self.client.post("/async/dbt", json=args.dict())

self.assertEqual(response.status_code, 200)
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_dbt_entry_state_id(self, mock_tracking, mock_execute):
state.serialize_manifest()
state.update_cache()

args = views.dbtCommandArgs(command=["run", "--threads", 1])
args = views.DbtCommandArgs(command=["run", "--threads", 1])
response = self.client.post("/async/dbt", json=args.dict())

self.assertEqual(response.status_code, 200)
Expand All @@ -132,7 +132,7 @@ def test_dbt_entry_no_state_found(self):
Test that calling the async/dbt endpoint without first calling parse
results in a properly handled StateNotFoundException
"""
args = views.dbtCommandArgs(command=["run", "--threads", 1])
args = views.DbtCommandArgs(command=["run", "--threads", 1])
response = self.client.post("/async/dbt", json=args.dict())
self.assertEqual(response.status_code, 422)

Expand Down Expand Up @@ -164,7 +164,7 @@ def test_dbt_entry_sync_project_path(self, mock_tracking):
state.serialize_manifest()
state.update_cache()

args = views.dbtCommandArgs(command=["run", "--threads", 1])
args = views.DbtCommandArgs(command=["run", "--threads", 1])
response = self.client.post("/sync/dbt", json=args.dict())

self.assertEqual(response.status_code, 200)
Expand Down