Skip to content

Commit

Permalink
Fix profile for async endpoint (#157)
Browse files Browse the repository at this point in the history
* Updates async endpoint to use set_profile_name function

* Adds checkfirst flag to avoid table exists error

* Fixes profile name and potential fix for sqlalchemy error

* Adds profile back to command args

* Fixes whitespace

* Adds status endpoint
  • Loading branch information
racheldaniel authored Jan 27, 2023
1 parent fc3a38e commit cec0320
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 21 deletions.
22 changes: 8 additions & 14 deletions dbt_server/helpers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
import os
from dbt_server.exceptions import InternalException
from pydantic import BaseModel


class Args(BaseModel):
profile: str = None


def extract_compiled_code_from_node(result_node_dict):
Expand All @@ -23,13 +18,12 @@ def extract_compiled_code_from_node(result_node_dict):
return compiled_code


def set_profile_name(args=None):
# If no profile name is passed in args, we will attempt to set it from env vars
# If no profile is set, dbt will default to reading from dbt_project.yml
def get_profile_name(args=None):
# If no profile name is passed in args, we will attempt to get it from env vars
# If profile is None, dbt will default to reading from dbt_project.yml
if args and hasattr(args, "profile") and args.profile:
return args
if os.getenv("DBT_PROFILE_NAME"):
if args is None:
args = Args()
args.profile = os.getenv("DBT_PROFILE_NAME")
return args
return args.profile
env_profile_name = os.getenv("DBT_PROFILE_NAME")
if env_profile_name:
return env_profile_name
return None
1 change: 0 additions & 1 deletion dbt_server/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from sqlalchemy import Column, String
from enum import Enum

from .database import Base


Expand Down
8 changes: 7 additions & 1 deletion dbt_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@
from dbt_server.logging import DBT_SERVER_LOGGER as logger, configure_uvicorn_access_log
from dbt_server.state import LAST_PARSED
from dbt_server.exceptions import StateNotFoundException
from sqlalchemy.exc import OperationalError

# The default checkfirst=True should handle this, however we still
# see a table exists error from time to time
try:
models.Base.metadata.create_all(bind=engine, checkfirst=True)
except OperationalError as err:
logger.debug(f"Handled error when creating database: {str(err)}")

models.Base.metadata.create_all(bind=engine)
dbt_service.disable_tracking()


Expand Down
11 changes: 9 additions & 2 deletions dbt_server/services/dbt_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
dbtCoreCompilationException,
UnsupportedQueryException,
)
from dbt_server.helpers import set_profile_name
from dbt_server.helpers import get_profile_name
from pydantic import BaseModel

ALLOW_INTROSPECTION = str(os.environ.get("__DBT_ALLOW_INTROSPECTION", "1")).lower() in (
"true",
Expand All @@ -54,6 +55,9 @@

CONFIG_GLOBAL_LOCK = threading.Lock()

class Args(BaseModel):
profile: str = None


def inject_dd_trace_into_core_lib():

Expand Down Expand Up @@ -115,7 +119,10 @@ def get_sql_parser(config, manifest):
@tracer.wrap
def create_dbt_config(project_path, args=None):
try:
args = set_profile_name(args)
if not args:
args = Args()
if hasattr(args, "profile"):
args.profile = get_profile_name(args)
# This needs a lock to prevent two threads from mutating an adapter concurrently
with CONFIG_GLOBAL_LOCK:
return dbt_get_dbt_config(project_path, args)
Expand Down
5 changes: 4 additions & 1 deletion dbt_server/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dbt_server.services import filesystem_service, dbt_service
from dbt_server.exceptions import StateNotFoundException
from dbt_server.logging import DBT_SERVER_LOGGER as logger
from dbt_server.helpers import get_profile_name
from dbt_server import crud, tracer
from dbt.lib import load_profile_project
from dbt.cli.main import dbtRunner
Expand Down Expand Up @@ -218,7 +219,9 @@ def execute_async_command(self, task_id, command, db):

logger.info(f"Running dbt ({task_id}) - deserializing manifest {self.serialize_path}")

profile, project = load_profile_project(self.root_path, os.getenv("DBT_PROFILE_NAME", "user"),)
# TODO: If a command contains a --profile flag, how should we access/pass it?
profile_name = get_profile_name()
profile, project = load_profile_project(self.root_path, profile_name)

crud.set_task_running(db, db_task)

Expand Down
16 changes: 14 additions & 2 deletions dbt_server/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class SQLConfig(BaseModel):
class dbtCommandArgs(BaseModel):
command: List[Any]
state_id: Optional[str]
# TODO: Need to handle this differently
profile: Optional[str]


@app.exception_handler(InvalidConfigurationException)
Expand Down Expand Up @@ -270,5 +272,15 @@ def get_manifest_metadata(state):
}


class Task(BaseModel):
task_id: str
@app.get("/status/{task_id}")
def get_task_status(
task_id: str,
db: Session = Depends(crud.get_db),
):
task = crud.get_task(db, task_id)
return JSONResponse(
status_code=200,
content={
"status": task.state
}
)

0 comments on commit cec0320

Please sign in to comment.