forked from dbt-labs/dbt-spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support for pyspark connection method
- Loading branch information
cccs-jc
committed
Mar 30, 2022
1 parent
bbff5c7
commit 7eafc23
Showing
4 changed files
with
186 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
|
||
from __future__ import annotations

import datetime as dt
import importlib
import re

import sqlalchemy

import dbt.exceptions
from dbt.events import AdapterLogger
from dbt.utils import DECIMALS

from pyspark.rdd import _load_from_socket
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
|
||
logger = AdapterLogger("Spark") | ||
NUMBERS = DECIMALS + (int, float) | ||
|
||
|
||
class PysparkConnectionWrapper(object): | ||
"""Wrap a Spark context""" | ||
|
||
def __init__(self, python_module): | ||
self.result = None | ||
if python_module: | ||
logger.debug(f"Loading spark context from python module {python_module}") | ||
module = importlib.import_module(python_module) | ||
create_spark_context = getattr(module, "create_spark_context") | ||
self.spark = create_spark_context() | ||
else: | ||
# Create a default pyspark context | ||
self.spark = SparkSession.builder.getOrCreate() | ||
|
||
def cursor(self): | ||
return self | ||
|
||
def rollback(self, *args, **kwargs): | ||
logger.debug("NotImplemented: rollback") | ||
|
||
def fetchall(self): | ||
try: | ||
rows = self.result.collect() | ||
logger.debug(rows) | ||
except Exception as e: | ||
logger.debug(f"raising error {e}") | ||
dbt.exceptions.raise_database_error(e) | ||
return rows | ||
|
||
def execute(self, sql, bindings=None): | ||
if sql.strip().endswith(";"): | ||
sql = sql.strip()[:-1] | ||
|
||
if bindings is not None: | ||
bindings = [self._fix_binding(binding) for binding in bindings] | ||
sql = sql % tuple(bindings) | ||
logger.debug(f"execute sql:{sql}") | ||
try: | ||
self.result = self.spark.sql(sql) | ||
logger.debug("Executed with no errors") | ||
if "show tables" in sql: | ||
self.result = self.result.withColumn("description", F.lit("")) | ||
except Exception as e: | ||
logger.debug(f"raising error {e}") | ||
dbt.exceptions.raise_database_error(e) | ||
|
||
@classmethod | ||
def _fix_binding(cls, value): | ||
"""Convert complex datatypes to primitives that can be loaded by | ||
the Spark driver""" | ||
if isinstance(value, NUMBERS): | ||
return float(value) | ||
elif isinstance(value, datetime): | ||
return "'" + value.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] + "'" | ||
elif isinstance(value, str): | ||
return "'" + value + "'" | ||
else: | ||
logger.debug(type(value)) | ||
return "'" + str(value) + "'" | ||
|
||
@property | ||
def description(self): | ||
logger.debug(f"Description called returning list of columns: {self.result.columns}") | ||
ret = [] | ||
# Not sure the type is ever used by specifying it anyways | ||
string_type = sqlalchemy.types.String | ||
for column_name in self.result.columns: | ||
ret.append((column_name, string_type)) | ||
return ret | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
{% macro source(source_name, identifier, start_dt = None, end_dt = None) %} | ||
{%- set relation = builtins.source(source_name, identifier) -%} | ||
|
||
{%- if execute and (relation.source_meta.python_module or relation.meta.python_module) -%} | ||
{%- do relation.load_python_module(start_dt, end_dt) -%} | ||
{# Return the view name only. Spark view do not support schema and catalog names #} | ||
{%- do return(relation.identifier) -%} | ||
{% else -%} | ||
{%- do return(relation) -%} | ||
{% endif -%} | ||
{% endmacro %} |