support for pyspark connection method
cccs-jc committed Mar 30, 2022
1 parent bbff5c7 commit 7eafc23
Showing 4 changed files with 186 additions and 1 deletion.
31 changes: 30 additions & 1 deletion dbt/adapters/spark/connections.py
@@ -42,6 +42,17 @@
import base64
import time


try:
    from pyspark.rdd import _load_from_socket
    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession
except ImportError:
    SparkSession = None
    _load_from_socket = None
    F = None


logger = AdapterLogger("Spark")

NUMBERS = DECIMALS + (int, float)
@@ -56,7 +67,7 @@ class SparkConnectionMethod(StrEnum):
    HTTP = 'http'
    ODBC = 'odbc'
    SESSION = 'session'

    PYSPARK = 'pyspark'

@dataclass
class SparkCredentials(Credentials):
@@ -77,6 +88,7 @@ class SparkCredentials(Credentials):
    use_ssl: bool = False
    server_side_parameters: Dict[str, Any] = field(default_factory=dict)
    retry_all: bool = False
    python_module: Optional[str] = None

    @classmethod
    def __pre_deserialize__(cls, data):
@@ -99,6 +111,18 @@ def __post_init__(self):
            )
        self.database = None

        if (
            self.method == SparkConnectionMethod.PYSPARK
        ) and not (
            _load_from_socket and SparkSession and F
        ):
            raise dbt.exceptions.RuntimeException(
                f"{self.method} connection method requires "
                "additional dependencies. \n"
                "Install the additional required dependencies with "
                "`pip install pyspark`"
            )

        if self.method == SparkConnectionMethod.ODBC:
            try:
                import pyodbc  # noqa: F401
@@ -462,6 +486,11 @@ def open(cls, connection):
                        SessionConnectionWrapper,
                    )
                    handle = SessionConnectionWrapper(Connection())
                elif creds.method == SparkConnectionMethod.PYSPARK:
                    from .pysparkcon import (  # noqa: F401
                        PysparkConnectionWrapper,
                    )
                    handle = PysparkConnectionWrapper(creds.python_module)
                else:
                    raise dbt.exceptions.DbtProfileError(
                        f"invalid credential method: {creds.method}"
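Note: the new `python_module` profile option is consumed by `PysparkConnectionWrapper` (added in the next file), which imports the named module and calls its `create_spark_context()` function to obtain the Spark session. A minimal sketch of such a provider module follows; the module path, app name, and builder options are illustrative assumptions, not part of this commit.

```python
# my_project/spark_context.py -- hypothetical module referenced by the
# `python_module` profile setting. Only the create_spark_context() name is
# implied by PysparkConnectionWrapper; everything else here is an example.
from pyspark.sql import SparkSession


def create_spark_context():
    # Build (or reuse) a SparkSession with whatever configuration the
    # deployment needs; the wrapper issues all dbt SQL through it.
    return (
        SparkSession.builder
        .appName("dbt-spark-pyspark-method")   # example setting
        .config("spark.sql.catalogImplementation", "hive")  # example setting
        .getOrCreate()
    )
```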
93 changes: 93 additions & 0 deletions dbt/adapters/spark/pysparkcon.py
@@ -0,0 +1,93 @@

from __future__ import annotations

import datetime as dt
from types import TracebackType
from typing import Any

import dbt.exceptions
from dbt.events import AdapterLogger
from dbt.utils import DECIMALS


from pyspark.rdd import _load_from_socket  # noqa: F401
import pyspark.sql.functions as F
from pyspark.sql import SparkSession


import importlib
import sqlalchemy
import re

logger = AdapterLogger("Spark")
NUMBERS = DECIMALS + (int, float)


class PysparkConnectionWrapper(object):
    """Wrap a Spark context"""

    def __init__(self, python_module):
        self.result = None
        if python_module:
            logger.debug(f"Loading spark context from python module {python_module}")
            module = importlib.import_module(python_module)
            create_spark_context = getattr(module, "create_spark_context")
            self.spark = create_spark_context()
        else:
            # Create a default pyspark context
            self.spark = SparkSession.builder.getOrCreate()

    def cursor(self):
        return self

    def rollback(self, *args, **kwargs):
        logger.debug("NotImplemented: rollback")

    def fetchall(self):
        try:
            rows = self.result.collect()
            logger.debug(rows)
        except Exception as e:
            logger.debug(f"raising error {e}")
            dbt.exceptions.raise_database_error(e)
        return rows

    def execute(self, sql, bindings=None):
        if sql.strip().endswith(";"):
            sql = sql.strip()[:-1]

        if bindings is not None:
            bindings = [self._fix_binding(binding) for binding in bindings]
            sql = sql % tuple(bindings)
        logger.debug(f"execute sql:{sql}")
        try:
            self.result = self.spark.sql(sql)
            logger.debug("Executed with no errors")
            if "show tables" in sql:
                self.result = self.result.withColumn("description", F.lit(""))
        except Exception as e:
            logger.debug(f"raising error {e}")
            dbt.exceptions.raise_database_error(e)

    @classmethod
    def _fix_binding(cls, value):
        """Convert complex datatypes to primitives that can be loaded by
        the Spark driver"""
        if isinstance(value, NUMBERS):
            return float(value)
        elif isinstance(value, dt.datetime):
            return "'" + value.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] + "'"
        elif isinstance(value, str):
            return "'" + value + "'"
        else:
            logger.debug(type(value))
            return "'" + str(value) + "'"

    @property
    def description(self):
        logger.debug(f"Description called returning list of columns: {self.result.columns}")
        ret = []
        # Not sure the type is ever used; specify it anyway
        string_type = sqlalchemy.types.String
        for column_name in self.result.columns:
            ret.append((column_name, string_type))
        return ret
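Note: a rough usage sketch of the wrapper outside of dbt, assuming `pyspark` is installed locally and no `python_module` is configured (so a default `SparkSession` is created); it exercises `execute`, `description`, and `fetchall` the way dbt's cursor interface would.

```python
# Minimal sketch, not part of the commit; assumes this patched dbt-spark
# and pyspark are importable in the current environment.
from dbt.adapters.spark.pysparkcon import PysparkConnectionWrapper

conn = PysparkConnectionWrapper(python_module=None)
cursor = conn.cursor()           # the wrapper acts as its own cursor

cursor.execute("show tables")    # result gains an empty `description` column
print(cursor.description)        # [(column_name, sqlalchemy String), ...]
for row in cursor.fetchall():    # collect() on the underlying DataFrame
    print(row)
```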

52 changes: 52 additions & 0 deletions dbt/adapters/spark/relation.py
@@ -5,6 +5,17 @@
from dbt.adapters.base.relation import BaseRelation, Policy
from dbt.exceptions import RuntimeException

from typing import Optional, TypeVar, Any, Type, Dict, Union, Iterator, Tuple, Set

Self = TypeVar("Self", bound="BaseRelation")
from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode
from dbt.utils import filter_null_values, deep_merge, classproperty
from dbt.events import AdapterLogger

import importlib


from datetime import timezone, datetime

logger = AdapterLogger("Spark")

@dataclass
class SparkQuotePolicy(Policy):
@@ -28,6 +39,8 @@ class SparkRelation(BaseRelation):
    is_delta: Optional[bool] = None
    is_hudi: Optional[bool] = None
    information: str = None
    source_meta: Dict[str, Any] = None
    meta: Dict[str, Any] = None

    def __post_init__(self):
        if self.database != self.schema and self.database:
@@ -40,3 +53,42 @@ def render(self):
                'include, but only one can be set'
            )
        return super().render()

    @classmethod
    def create_from_source(cls: Type[Self], source: ParsedSourceDefinition, **kwargs: Any) -> Self:
        source_quoting = source.quoting.to_dict(omit_none=True)
        source_quoting.pop("column", None)
        quote_policy = deep_merge(
            cls.get_default_quote_policy().to_dict(omit_none=True),
            source_quoting,
            kwargs.get("quote_policy", {}),
        )

        return cls.create(
            database=source.database,
            schema=source.schema,
            identifier=source.identifier,
            quote_policy=quote_policy,
            source_meta=source.source_meta,
            meta=source.meta,
            **kwargs,
        )

    def load_python_module(self, start_time, end_time):
        logger.debug(f"Creating pyspark view for {self.identifier}")
        from pyspark.sql import SparkSession
        spark = SparkSession._instantiatedSession
        if self.meta and self.meta.get('python_module'):
            path = self.meta.get('python_module')
            logger.debug(f"Loading python module {path}")
            module = importlib.import_module(path)
            create_dataframe = getattr(module, "create_dataframe")
            df = create_dataframe(spark, start_time, end_time)
            df.createOrReplaceTempView(self.identifier)
        elif self.source_meta and self.source_meta.get('python_module'):
            path = self.source_meta.get('python_module')
            logger.debug(f"Loading python module {path}")
            module = importlib.import_module(path)
            create_dataframe_for = getattr(module, "create_dataframe_for")
            df = create_dataframe_for(spark, self.identifier, start_time, end_time)
            df.createOrReplaceTempView(self.identifier)
11 changes: 11 additions & 0 deletions dbt/include/spark/macros/source.sql
@@ -0,0 +1,11 @@
{% macro source(source_name, identifier, start_dt = None, end_dt = None) %}
  {%- set relation = builtins.source(source_name, identifier) -%}

  {%- if execute and (relation.source_meta.python_module or relation.meta.python_module) -%}
    {%- do relation.load_python_module(start_dt, end_dt) -%}
    {# Return the view name only. Spark views do not support schema and catalog names #}
    {%- do return(relation.identifier) -%}
  {% else -%}
    {%- do return(relation) -%}
  {% endif -%}
{% endmacro %}
