Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds dbt bootstrap subcommand #1238

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def _catalog_filter_table(cls, table, manifest):
"""
return table.where(_catalog_filter_schemas(manifest))

def get_catalog(self, manifest):
def get_unfiltered_catalog(self, manifest):
"""Get the catalog for this manifest by running the get catalog macro.
Returns an agate.Table of catalog information.
"""
Expand All @@ -696,6 +696,14 @@ def get_catalog(self, manifest):
finally:
self.release_connection(GET_CATALOG_MACRO_NAME)

return table

def get_catalog(self, manifest):
"""Get the catalog for this manifest by running the get catalog macro.
Returns an agate.Table of catalog information filtered to schemas in the manifest
"""
table = self.get_unfiltered_catalog(manifest)

results = self._catalog_filter_table(table, manifest)
return results

Expand Down
30 changes: 30 additions & 0 deletions core/dbt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import dbt.task.debug as debug_task
import dbt.task.clean as clean_task
import dbt.task.deps as deps_task
import dbt.task.bootstrap as bootstrap_task
import dbt.task.init as init_task
import dbt.task.seed as seed_task
import dbt.task.test as test_task
Expand Down Expand Up @@ -440,6 +441,35 @@ def parse_args(args):
"listed in packages.yml")
sub.set_defaults(cls=deps_task.DepsTask, which='deps')

bootstrap_sub = subs.add_parser(
'bootstrap',
parents=[base_subparser],
help="Bootstrap schema.yml files from database catalog")
bootstrap_sub.set_defaults(cls=bootstrap_task.BootstrapTask, which='bootsrap')

bootstrap_sub.add_argument(
'--schemas',
required=True,
nargs='+',
help="""
Required. Specify the schemas to inspect when bootstrapping
schema.yml files.
"""
)
bootstrap_sub.add_argument(
'--single-file',
action='store_true',
dest='single_file',
help='Store all of the schema information in a single schema.yml file'
)
bootstrap_sub.add_argument(
'--print-only',
action='store_true',
dest='print_only',
help="Print generated yml to console. Don't attempt to create schema.yml files."
)


sub = subs.add_parser(
'archive',
parents=[base_subparser],
Expand Down
153 changes: 153 additions & 0 deletions core/dbt/task/bootstrap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from __future__ import print_function
import os
import oyaml as yaml # NOTE: New dependency

from dbt.adapters.factory import get_adapter
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.node_types import NodeType
from dbt.utils import is_enabled as check_is_enabled
from dbt.task.generate import unflatten # NOTE: Should we move this somewhere else?

import dbt.ui.printer
from dbt.task.base_task import BaseTask


class BootstrapTask(BaseTask):
def _get_manifest(self):
compiler = dbt.compilation.Compiler(self.config)
compiler.initialize()

all_projects = compiler.get_all_projects()

manifest = dbt.loader.GraphLoader.load_all(self.config, all_projects)
return manifest

def _convert_single_relation_dict_to_yml(self, relation_dict):
to_yaml = {}
to_yaml["version"] = 2
to_yaml["models"] = relation_dict
# NOTE: Do we want to increase the indentation?
# https://stackoverflow.com/questions/25108581/python-yaml-dump-bad-indentation
return yaml.dump(to_yaml, default_flow_style=False)

def write_relation(self, design_file_path, relation_dict):
if os.path.isfile(design_file_path):
logger.info(
dbt.ui.printer.yellow(
"Warning: File {} already exists. Skipping".format(design_file_path)
)
)
return

logger.info("Creating design file: {}".format(design_file_path))

yml = self._convert_single_relation_dict_to_yml(relation_dict)
with open(design_file_path, "w") as f:
f.write(yml)

def print_relation(self, relation_dict):
yml = self._convert_single_relation_dict_to_yml(relation_dict)
logger.info(yml)

def prep_metadata(self, meta_dict):
columns = []
for colname in meta_dict["columns"]:
column = {}
column["name"] = colname
columns.append(column)

model = {}
model["name"] = meta_dict["metadata"]["name"]
if meta_dict["metadata"]["comment"]:
description = meta_dict["metadata"]["comment"]
else:
description = "TODO: Replace me"

model["description"] = description
model["columns"] = columns

return model

def run(self):
single_file = self.args.single_file
print_only = self.args.print_only
schemas = self.args.schemas

logger.info("Bootstrapping the following schemas:")
for schema in schemas:
logger.info("- {}".format(schema))

# Look up all of the relations in the DB
manifest = self._get_manifest()
adapter = get_adapter(self.config)
all_relations = adapter.get_unfiltered_catalog(manifest)

selected_relations = all_relations.where(
lambda row: row["table_schema"] in schemas
)

zipped_relations = [
dict(zip(selected_relations.column_names, row))
for row in selected_relations
]

relations_to_design = unflatten(zipped_relations)

if len(relations_to_design) == 0:
logger.info(
dbt.ui.printer.yellow(
"Warning: No relations found in selected schemas: {}."
"\nAborting.".format(schemas)
)
)
return {}

for schema, relations in relations_to_design.items():
schema_path = os.path.join("models", schema)
if print_only:
pass
elif os.path.isdir(schema_path):
logger.info(
dbt.ui.printer.yellow(
"Warning: Directory {} already exists. \n"
"Proceeding with caution.".format(schema_path)
)
)
else:
os.mkdir(schema_path)

all_models = []

for relation, meta_data in relations.items():

relation_dict = self.prep_metadata(meta_data)
all_models.append(relation_dict)

if not single_file:
if print_only:
logger.info("-" * 20)
logger.info(
"Design for relation: {}.{}".format(schema, relation)
)
logger.info("-" * 20)
self.print_relation([relation_dict])
else:
design_file_name = "{}.yml".format(relation)
design_file_path = os.path.join(schema_path, design_file_name)
self.write_relation(design_file_path, [relation_dict])

if single_file:
if print_only:
logger.info("-" * 20)
logger.info("Design for schmea: {}".format(schema))
logger.info("-" * 20)
self.print_relation(all_models)
else:
design_file_name = "{}.yml".format(schema)
design_file_path = os.path.join(schema_path, design_file_name)
self.write_relation(design_file_path, all_models)

return all_models

def interpret_results(self, results):
return len(results) != 0
1 change: 1 addition & 0 deletions test/integration/041_bootstrap_test/models/model_a.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT 1 AS alpha, 2 AS beta;
1 change: 1 addition & 0 deletions test/integration/041_bootstrap_test/models/model_b.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT *, 3 AS gamma FROM {{ref('model_a')}}
38 changes: 38 additions & 0 deletions test/integration/041_bootstrap_test/test_bootstrap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from test.integration.base import DBTIntegrationTest, use_profile
import os


class TestBootstrap(DBTIntegrationTest):
@property
def schema(self):
return "config_041"

def unique_schema(self):
return super(TestBootstrap, self).unique_schema()

def tearDown(self):
files = os.listdir(self.models())
for f in files:
if f.endswith(".yml"):
os.remove(self.dir('models/'+f))

@staticmethod
def dir(path):
return "test/integration/010_bootstrap_test/" + path.lstrip("/")

@property
def models(self):
return self.dir("models")

def check_bootstrap_completeness(self):
self.run_dbt(["run"])
results = self.run_dbt(["bootstrap", '--schemas', self.schema])

self.assertTrue(os.path.isfile(self.path('models/model_a.yml')))
self.assertTrue(os.path.isfile(self.path('models/model_b.yml')))

import ipdb; ipdb.set_trace()


def test_late_binding_view(self):
pass