Skip to content

Commit

Permalink
Fix adapter reset race condition in lib.py (#5921)
Browse files (browse the repository at this point in the history)
* (#5919) Fix adapter reset race condition in lib.py

* run black

* changie

(cherry picked from commit 4e8aa00)
  • Loading branch information
drewbanin authored and github-actions[bot] committed Sep 26, 2022
1 parent 07415ca commit 8506e73
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 27 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Fixes-20220923-174504.yaml
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,7 @@
kind: Fixes
body: Fix race condition when invoking dbt via lib.py concurrently
time: 2022-09-23T17:45:04.405026-04:00
custom:
Author: drewbanin
Issue: "5919"
PR: "5921"
61 changes: 34 additions & 27 deletions core/dbt/lib.py
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,16 @@
# TODO: this file is one big TODO
import os
from dbt.exceptions import RuntimeException
from dbt import flags
from collections import namedtuple
from dataclasses import dataclass

RuntimeArgs = namedtuple("RuntimeArgs", "project_dir profiles_dir single_threaded profile target")

@dataclass
class RuntimeArgs:
project_dir: str
profiles_dir: str
single_threaded: bool
profile: str
target: str


def get_dbt_config(project_dir, args=None, single_threaded=False):
@@ -17,27 +23,30 @@ def get_dbt_config(project_dir, args=None, single_threaded=False):
else:
profiles_dir = os.path.expanduser("~/.dbt")

profile = args.profile if hasattr(args, "profile") else None
target = args.target if hasattr(args, "target") else None

# Construct a phony config
config = RuntimeConfig.from_args(
RuntimeArgs(project_dir, profiles_dir, single_threaded, profile, target)
runtime_args = RuntimeArgs(
project_dir=project_dir,
profiles_dir=profiles_dir,
single_threaded=single_threaded,
profile=getattr(args, "profile", None),
target=getattr(args, "target", None),
)
# Clear previously registered adapters--
# this fixes caching behavior on the dbt-server

# Construct a RuntimeConfig from phony args
config = RuntimeConfig.from_args(runtime_args)

# Set global flags from arguments
flags.set_from_args(args, config)
dbt.adapters.factory.reset_adapters()
# Load the relevant adapter

# This is idempotent, so we can call it repeatedly
dbt.adapters.factory.register_adapter(config)
# Set invocation id

# Make sure we have a valid invocation_id
dbt.events.functions.set_invocation_id()

return config


def get_task_by_type(type):
# TODO: we need to tell dbt-server what tasks are available
from dbt.task.run import RunTask
from dbt.task.list import ListTask
from dbt.task.seed import SeedTask
@@ -70,16 +79,13 @@ def create_task(type, args, manifest, config):
def no_op(*args, **kwargs):
pass

# TODO: yuck, let's rethink tasks a little
task = task(args, config)

# Wow! We can monkeypatch taskCls.load_manifest to return _our_ manifest
task.load_manifest = no_op
task.manifest = manifest
return task


def _get_operation_node(manifest, project_path, sql):
def _get_operation_node(manifest, project_path, sql, node_name):
from dbt.parser.manifest import process_node
from dbt.parser.sql import SqlBlockParser
import dbt.adapters.factory
@@ -92,26 +98,28 @@ def _get_operation_node(manifest, project_path, sql):
)

adapter = dbt.adapters.factory.get_adapter(config)
# TODO : This needs a real name?
sql_node = block_parser.parse_remote(sql, "name")
sql_node = block_parser.parse_remote(sql, node_name)
process_node(config, manifest, sql_node)
return config, sql_node, adapter


def compile_sql(manifest, project_path, sql):
def compile_sql(manifest, project_path, sql, node_name="query"):
from dbt.task.sql import SqlCompileRunner

config, node, adapter = _get_operation_node(manifest, project_path, sql)
config, node, adapter = _get_operation_node(manifest, project_path, sql, node_name)

runner = SqlCompileRunner(config, adapter, node, 1, 1)

return runner.safe_run(manifest)


def execute_sql(manifest, project_path, sql):
def execute_sql(manifest, project_path, sql, node_name="query"):
from dbt.task.sql import SqlExecuteRunner

config, node, adapter = _get_operation_node(manifest, project_path, sql)
config, node, adapter = _get_operation_node(manifest, project_path, sql, node_name)

runner = SqlExecuteRunner(config, adapter, node, 1, 1)
# TODO: use same interface for runner

return runner.safe_run(manifest)


@@ -128,5 +136,4 @@ def deserialize_manifest(manifest_msgpack):


def serialize_manifest(manifest):
# TODO: what should this take as an arg?
return manifest.to_msgpack()

0 comments on commit 8506e73

Please sign in to comment.