Skip to content

Commit

Permalink
Sqlalchemy Task (flyteorg#445)
Browse files Browse the repository at this point in the history
  • Loading branch information
max-hoffman authored Apr 22, 2021
1 parent 986cb42 commit 0edc7dc
Show file tree
Hide file tree
Showing 11 changed files with 310 additions and 77 deletions.
8 changes: 4 additions & 4 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ flake8-black==0.2.1
# via -r dev-requirements.in
flake8-isort==4.0.0
# via -r dev-requirements.in
flake8==3.9.0
flake8==3.9.1
# via
# -r dev-requirements.in
# flake8-black
Expand Down Expand Up @@ -71,9 +71,9 @@ pyparsing==2.4.7
# via
# -c requirements.txt
# packaging
pytest==6.2.2
pytest==6.2.3
# via -r dev-requirements.in
regex==2021.3.17
regex==2021.4.4
# via
# -c requirements.txt
# black
Expand All @@ -85,7 +85,7 @@ toml==0.10.2
# black
# coverage
# pytest
typed-ast==1.4.2
typed-ast==1.4.3
# via
# -c requirements.txt
# black
Expand Down
46 changes: 23 additions & 23 deletions doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ appnope==0.1.2
# via
# ipykernel
# ipython
astroid==2.5.2
astroid==2.5.3
# via sphinx-autoapi
async-generator==1.10
# via nbclient
Expand All @@ -39,9 +39,9 @@ black==20.8b1
# via papermill
bleach==3.3.0
# via nbconvert
boto3==1.17.40
boto3==1.17.55
# via sagemaker-training
botocore==1.20.40
botocore==1.20.55
# via
# boto3
# s3transfer
Expand All @@ -60,7 +60,7 @@ click==7.1.2
# flytekit
# hmsclient
# papermill
croniter==1.0.10
croniter==1.0.12
# via flytekit
cryptography==3.4.7
# via
Expand All @@ -70,7 +70,7 @@ css-html-js-minify==2.5.5
# via sphinx-material
dataclasses-json==0.5.2
# via flytekit
decorator==4.4.2
decorator==5.0.7
# via
# ipython
# retry
Expand All @@ -88,15 +88,15 @@ entrypoints==0.3
# via
# nbconvert
# papermill
flyteidl==0.18.26
flyteidl==0.18.37
# via flytekit
furo==2021.3.20b30
furo==2021.4.11b34
# via -r doc-requirements.in
gevent==21.1.2
# via sagemaker-training
greenlet==1.0.0
# via gevent
grpcio==1.36.1
grpcio==1.37.0
# via
# -r doc-requirements.in
# flytekit
Expand All @@ -106,11 +106,11 @@ idna==2.10
# via requests
imagesize==1.2.0
# via sphinx
importlib-metadata==3.9.1
importlib-metadata==4.0.1
# via keyring
inotify_simple==1.2.1
# via sagemaker-training
ipykernel==5.5.0
ipykernel==5.5.3
# via flytekit
ipython-genutils==0.2.0
# via
Expand Down Expand Up @@ -154,7 +154,7 @@ markupsafe==1.1.1
# via jinja2
marshmallow-enum==1.5.1
# via dataclasses-json
marshmallow==3.11.0
marshmallow==3.11.1
# via
# dataclasses-json
# marshmallow-enum
Expand All @@ -172,7 +172,7 @@ nbclient==0.5.3
# papermill
nbconvert==6.0.7
# via flytekit
nbformat==5.1.2
nbformat==5.1.3
# via
# nbclient
# nbconvert
Expand All @@ -190,15 +190,15 @@ packaging==20.9
# via
# bleach
# sphinx
pandas==1.2.3
pandas==1.2.4
# via flytekit
pandocfilters==1.4.3
# via nbconvert
papermill==2.3.3
# via flytekit
paramiko==2.7.2
# via sagemaker-training
parso==0.8.1
parso==0.8.2
# via jedi
pathspec==0.8.1
# via
Expand All @@ -210,7 +210,7 @@ pickleshare==0.7.5
# via ipython
prompt-toolkit==3.0.18
# via ipython
protobuf==3.15.6
protobuf==3.15.8
# via
# flyteidl
# flytekit
Expand Down Expand Up @@ -267,7 +267,7 @@ pyzmq==22.0.3
# via jupyter-client
readthedocs-sphinx-search==0.1.0
# via -r doc-requirements.in
regex==2021.3.17
regex==2021.4.4
# via
# black
# docker-image-py
Expand All @@ -283,9 +283,9 @@ retry==0.9.2
# via flytekit
retrying==1.3.3
# via sagemaker-training
s3transfer==0.3.6
s3transfer==0.4.1
# via boto3
sagemaker-training==3.7.4
sagemaker-training==3.9.1
# via flytekit
scantree==0.0.1
# via dirhash
Expand Down Expand Up @@ -314,7 +314,7 @@ sortedcontainers==2.3.0
# via flytekit
soupsieve==2.2.1
# via beautifulsoup4
sphinx-autoapi==1.7.0
sphinx-autoapi==1.8.0
# via -r doc-requirements.in
sphinx-code-include==1.1.1
# via -r doc-requirements.in
Expand All @@ -326,7 +326,7 @@ sphinx-material==0.0.32
# via -r doc-requirements.in
sphinx-prompt==1.4.0
# via -r doc-requirements.in
sphinx==3.5.3
sphinx==3.5.4
# via
# -r doc-requirements.in
# furo
Expand Down Expand Up @@ -368,7 +368,7 @@ tornado==6.1
# via
# ipykernel
# jupyter-client
tqdm==4.59.0
tqdm==4.60.0
# via papermill
traitlets==5.0.5
# via
Expand All @@ -379,7 +379,7 @@ traitlets==5.0.5
# nbclient
# nbconvert
# nbformat
typed-ast==1.4.2
typed-ast==1.4.3
# via black
typing-extensions==3.7.4.3
# via
Expand Down Expand Up @@ -414,7 +414,7 @@ zipp==3.4.1
# via importlib-metadata
zope.event==4.5.0
# via gevent
zope.interface==5.3.0
zope.interface==5.4.0
# via gevent

# The following packages are considered to be unsafe in a requirements file:
Expand Down
1 change: 1 addition & 0 deletions plugins/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"flytekitplugins-awssagemaker": "awssagemaker",
"flytekitplugins-kftensorflow": "kftensorflow",
"flytekitplugins-pandera": "pandera",
"flytekitplugins-sqlalchemy": "sqlalchemy",
}


Expand Down
1 change: 1 addition & 0 deletions plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .task import SQLAlchemyConfig, SQLAlchemyTask
85 changes: 85 additions & 0 deletions plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import typing
from dataclasses import dataclass

import pandas as pd
from sqlalchemy import create_engine

from flytekit import current_context, kwtypes
from flytekit.core.base_sql_task import SQLTask
from flytekit.core.python_function_task import PythonInstanceTask
from flytekit.models.security import Secret
from flytekit.types.schema import FlyteSchema


@dataclass
class SQLAlchemyConfig(object):
"""
Use this configuration to configure task. String should be standard
sqlalchemy connector format
(https://docs.sqlalchemy.org/en/14/core/engines.html#database-urls).
Database can be found:
- within the container
- or from a publicly accessible source
Args:
uri: default sqlalchemy connector
connect_args: sqlalchemy kwarg overrides -- ex: host
secret_connect_args: flyte secrets loaded into sqlalchemy connect args
-- ex: {"password": {"name": SECRET_NAME, "group": SECRET_GROUP}}
"""

uri: str
connect_args: typing.Optional[typing.Dict[str, typing.Any]] = None
secret_connect_args: typing.Optional[typing.Dict[str, Secret]] = None


class SQLAlchemyTask(PythonInstanceTask[SQLAlchemyConfig], SQLTask[SQLAlchemyConfig]):
"""
Makes it possible to run client side SQLAlchemy queries that optionally return a FlyteSchema object
TODO: How should we use pre-built containers for running portable tasks like this. Should this always be a
referenced task type?
"""

_SQLALCHEMY_TASK_TYPE = "sqlalchemy"

def __init__(
self,
name: str,
query_template: str,
task_config: SQLAlchemyConfig,
inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None,
**kwargs,
):
output_schema = output_schema_type if output_schema_type else FlyteSchema
outputs = kwtypes(results=output_schema)
self._uri = task_config.uri
self._connect_args = task_config.connect_args or {}
self._secret_connect_args = task_config.secret_connect_args

super().__init__(
name=name,
task_config=task_config,
task_type=self._SQLALCHEMY_TASK_TYPE,
query_template=query_template,
inputs=inputs,
outputs=outputs,
**kwargs,
)

@property
def output_columns(self) -> typing.Optional[typing.List[str]]:
c = self.python_interface.outputs["results"].column_names()
return c if c else None

def execute(self, **kwargs) -> typing.Any:
if self._secret_connect_args is not None:
for key, secret in self._secret_connect_args.items():
value = current_context().secrets.get(secret.group, secret.key)
self._connect_args[key] = value
engine = create_engine(self._uri, connect_args=self._connect_args, echo=False)
print(f"Connecting to db {self._uri}")
with engine.begin() as connection:
df = pd.read_sql_query(self.get_query(**kwargs), connection)
return df
34 changes: 34 additions & 0 deletions plugins/sqlalchemy/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from setuptools import setup

PLUGIN_NAME = "sqlalchemy"

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=0.17.0,<1.0.0", "sqlalchemy>=1.4.7"]

__version__ = "0.0.0+develop"

setup(
name=microlib_name,
version=__version__,
author="dolthub",
author_email="[email protected]",
description="SQLAlchemy plugin for flytekit",
namespace_packages=["flytekitplugins"],
packages=[f"flytekitplugins.{PLUGIN_NAME}"],
install_requires=plugin_requires,
license="apache2",
python_requires=">=3.7",
classifiers=[
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
)
Empty file.
31 changes: 31 additions & 0 deletions plugins/tests/sqlalchemy/test_sql_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from collections import OrderedDict

from flytekit.common.translator import get_serializable
from flytekit.core import context_manager
from flytekit.core.context_manager import Image, ImageConfig
from plugins.tests.sqlalchemy.test_task import tk as not_tk


def test_sql_lhs():
assert not_tk.lhs == "tk"


def test_sql_command():
default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = context_manager.SerializationSettings(
project="project",
domain="domain",
version="version",
env=None,
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)
srz_t = get_serializable(OrderedDict(), serialization_settings, not_tk)
assert srz_t.container.args[-7:] == [
"--resolver",
"flytekit.core.python_auto_container.default_task_resolver",
"--",
"task-module",
"plugins.tests.sqlalchemy.test_task",
"task-name",
"tk",
]
Loading

0 comments on commit 0edc7dc

Please sign in to comment.