From cec0320bb62df42e790d3f50177e9ecfa5e6a1ec Mon Sep 17 00:00:00 2001 From: Rachel <41338402+racheldaniel@users.noreply.github.com> Date: Fri, 27 Jan 2023 09:51:19 -0600 Subject: [PATCH] Fix profile for async endpoint (#157) * Updates async endpoint to use set_profile_name function * Adds checkfirst flag to avoid table exists error * Fixes profile name and potential fix for sqlalchemy error * Adds profile back to command args * Fixes whitespace * Adds status endpoint --- dbt_server/helpers.py | 22 ++++++++-------------- dbt_server/models.py | 1 - dbt_server/server.py | 8 +++++++- dbt_server/services/dbt_service.py | 11 +++++++++-- dbt_server/state.py | 5 ++++- dbt_server/views.py | 16 ++++++++++++++-- 6 files changed, 42 insertions(+), 21 deletions(-) diff --git a/dbt_server/helpers.py b/dbt_server/helpers.py index c281630..3449ece 100644 --- a/dbt_server/helpers.py +++ b/dbt_server/helpers.py @@ -1,10 +1,5 @@ import os from dbt_server.exceptions import InternalException -from pydantic import BaseModel - - -class Args(BaseModel): - profile: str = None def extract_compiled_code_from_node(result_node_dict): @@ -23,13 +18,12 @@ def extract_compiled_code_from_node(result_node_dict): return compiled_code -def set_profile_name(args=None): - # If no profile name is passed in args, we will attempt to set it from env vars - # If no profile is set, dbt will default to reading from dbt_project.yml +def get_profile_name(args=None): + # If no profile name is passed in args, we will attempt to get it from env vars + # If profile is None, dbt will default to reading from dbt_project.yml if args and hasattr(args, "profile") and args.profile: - return args - if os.getenv("DBT_PROFILE_NAME"): - if args is None: - args = Args() - args.profile = os.getenv("DBT_PROFILE_NAME") - return args + return args.profile + env_profile_name = os.getenv("DBT_PROFILE_NAME") + if env_profile_name: + return env_profile_name + return None diff --git a/dbt_server/models.py b/dbt_server/models.py index 050a25a..51bbe44 100644 --- a/dbt_server/models.py +++ b/dbt_server/models.py @@ -1,6 +1,5 @@ from sqlalchemy import Column, String from enum import Enum - from .database import Base diff --git a/dbt_server/server.py b/dbt_server/server.py index bee6271..2f48751 100644 --- a/dbt_server/server.py +++ b/dbt_server/server.py @@ -11,9 +11,15 @@ from dbt_server.logging import DBT_SERVER_LOGGER as logger, configure_uvicorn_access_log from dbt_server.state import LAST_PARSED from dbt_server.exceptions import StateNotFoundException +from sqlalchemy.exc import OperationalError +# The default checkfirst=True should handle this, however we still +# see a table exists error from time to time +try: + models.Base.metadata.create_all(bind=engine, checkfirst=True) +except OperationalError as err: + logger.debug(f"Handled error when creating database: {str(err)}") -models.Base.metadata.create_all(bind=engine) dbt_service.disable_tracking() diff --git a/dbt_server/services/dbt_service.py b/dbt_server/services/dbt_service.py index 8a641f4..65f06ca 100644 --- a/dbt_server/services/dbt_service.py +++ b/dbt_server/services/dbt_service.py @@ -44,7 +44,8 @@ dbtCoreCompilationException, UnsupportedQueryException, ) -from dbt_server.helpers import set_profile_name +from dbt_server.helpers import get_profile_name +from pydantic import BaseModel ALLOW_INTROSPECTION = str(os.environ.get("__DBT_ALLOW_INTROSPECTION", "1")).lower() in ( "true", @@ -54,6 +55,9 @@ CONFIG_GLOBAL_LOCK = threading.Lock() +class Args(BaseModel): + profile: str = None + def inject_dd_trace_into_core_lib(): @@ -115,7 +119,10 @@ def get_sql_parser(config, manifest): @tracer.wrap def create_dbt_config(project_path, args=None): try: - args = set_profile_name(args) + if not args: + args = Args() + if hasattr(args, "profile"): + args.profile = get_profile_name(args) # This needs a lock to prevent two threads from mutating an adapter concurrently with CONFIG_GLOBAL_LOCK: return dbt_get_dbt_config(project_path, args) diff --git a/dbt_server/state.py b/dbt_server/state.py index 71015af..3afffdf 100644 --- a/dbt_server/state.py +++ b/dbt_server/state.py @@ -2,6 +2,7 @@ from dbt_server.services import filesystem_service, dbt_service from dbt_server.exceptions import StateNotFoundException from dbt_server.logging import DBT_SERVER_LOGGER as logger +from dbt_server.helpers import get_profile_name from dbt_server import crud, tracer from dbt.lib import load_profile_project from dbt.cli.main import dbtRunner @@ -218,7 +219,9 @@ def execute_async_command(self, task_id, command, db): logger.info(f"Running dbt ({task_id}) - deserializing manifest {self.serialize_path}") - profile, project = load_profile_project(self.root_path, os.getenv("DBT_PROFILE_NAME", "user"),) + # TODO: If a command contains a --profile flag, how should we access/pass it? + profile_name = get_profile_name() + profile, project = load_profile_project(self.root_path, profile_name) crud.set_task_running(db, db_task) diff --git a/dbt_server/views.py b/dbt_server/views.py index 56af38c..c7d3e8f 100644 --- a/dbt_server/views.py +++ b/dbt_server/views.py @@ -79,6 +79,8 @@ class SQLConfig(BaseModel): class dbtCommandArgs(BaseModel): command: List[Any] state_id: Optional[str] + # TODO: Need to handle this differently + profile: Optional[str] @app.exception_handler(InvalidConfigurationException) @@ -270,5 +272,15 @@ def get_manifest_metadata(state): } -class Task(BaseModel): - task_id: str +@app.get("/status/{task_id}") +def get_task_status( + task_id: str, + db: Session = Depends(crud.get_db), +): + task = crud.get_task(db, task_id) + return JSONResponse( + status_code=200, + content={ + "status": task.state + } + ) \ No newline at end of file