From d77a7739b53186d6c822dbfba29bc74dce66e8aa Mon Sep 17 00:00:00 2001 From: Maximilian Hoffman Date: Thu, 22 Apr 2021 13:13:48 -0700 Subject: [PATCH] Sqlalchemy Task (#445) --- dev-requirements.txt | 8 +- doc-requirements.txt | 46 +++++----- plugins/setup.py | 1 + .../flytekitplugins/sqlalchemy/__init__.py | 1 + .../flytekitplugins/sqlalchemy/task.py | 85 +++++++++++++++++++ plugins/sqlalchemy/setup.py | 34 ++++++++ plugins/tests/sqlalchemy/__init__.py | 0 plugins/tests/sqlalchemy/test_sql_tracker.py | 31 +++++++ plugins/tests/sqlalchemy/test_task.py | 85 +++++++++++++++++++ requirements-spark2.txt | 58 ++++++------- requirements.txt | 38 ++++----- 11 files changed, 310 insertions(+), 77 deletions(-) create mode 100644 plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py create mode 100644 plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py create mode 100644 plugins/sqlalchemy/setup.py create mode 100644 plugins/tests/sqlalchemy/__init__.py create mode 100644 plugins/tests/sqlalchemy/test_sql_tracker.py create mode 100644 plugins/tests/sqlalchemy/test_task.py diff --git a/dev-requirements.txt b/dev-requirements.txt index 5acc7855b8..873831d1c2 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -27,7 +27,7 @@ flake8-black==0.2.1 # via -r dev-requirements.in flake8-isort==4.0.0 # via -r dev-requirements.in -flake8==3.9.0 +flake8==3.9.1 # via # -r dev-requirements.in # flake8-black @@ -71,9 +71,9 @@ pyparsing==2.4.7 # via # -c requirements.txt # packaging -pytest==6.2.2 +pytest==6.2.3 # via -r dev-requirements.in -regex==2021.3.17 +regex==2021.4.4 # via # -c requirements.txt # black @@ -85,7 +85,7 @@ toml==0.10.2 # black # coverage # pytest -typed-ast==1.4.2 +typed-ast==1.4.3 # via # -c requirements.txt # black diff --git a/doc-requirements.txt b/doc-requirements.txt index 84195f76de..07231a1bd6 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -16,7 +16,7 @@ appnope==0.1.2 # via # ipykernel # ipython -astroid==2.5.2 +astroid==2.5.3 # via sphinx-autoapi async-generator==1.10 # via nbclient @@ -39,9 +39,9 @@ black==20.8b1 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.40 +boto3==1.17.55 # via sagemaker-training -botocore==1.20.40 +botocore==1.20.55 # via # boto3 # s3transfer @@ -60,7 +60,7 @@ click==7.1.2 # flytekit # hmsclient # papermill -croniter==1.0.10 +croniter==1.0.12 # via flytekit cryptography==3.4.7 # via @@ -70,7 +70,7 @@ css-html-js-minify==2.5.5 # via sphinx-material dataclasses-json==0.5.2 # via flytekit -decorator==4.4.2 +decorator==5.0.7 # via # ipython # retry @@ -88,15 +88,15 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.26 +flyteidl==0.18.37 # via flytekit -furo==2021.3.20b30 +furo==2021.4.11b34 # via -r doc-requirements.in gevent==21.1.2 # via sagemaker-training greenlet==1.0.0 # via gevent -grpcio==1.36.1 +grpcio==1.37.0 # via # -r doc-requirements.in # flytekit @@ -106,11 +106,11 @@ idna==2.10 # via requests imagesize==1.2.0 # via sphinx -importlib-metadata==3.9.1 +importlib-metadata==4.0.1 # via keyring inotify_simple==1.2.1 # via sagemaker-training -ipykernel==5.5.0 +ipykernel==5.5.3 # via flytekit ipython-genutils==0.2.0 # via @@ -154,7 +154,7 @@ markupsafe==1.1.1 # via jinja2 marshmallow-enum==1.5.1 # via dataclasses-json -marshmallow==3.11.0 +marshmallow==3.11.1 # via # dataclasses-json # marshmallow-enum @@ -172,7 +172,7 @@ nbclient==0.5.3 # papermill nbconvert==6.0.7 # via flytekit -nbformat==5.1.2 +nbformat==5.1.3 # via # nbclient # nbconvert @@ -190,7 +190,7 @@ packaging==20.9 # via # bleach # sphinx -pandas==1.2.3 +pandas==1.2.4 # via flytekit pandocfilters==1.4.3 # via nbconvert @@ -198,7 +198,7 @@ papermill==2.3.3 # via flytekit paramiko==2.7.2 # via sagemaker-training -parso==0.8.1 +parso==0.8.2 # via jedi pathspec==0.8.1 # via @@ -210,7 +210,7 @@ pickleshare==0.7.5 # via ipython prompt-toolkit==3.0.18 # via ipython -protobuf==3.15.6 +protobuf==3.15.8 # via # flyteidl # flytekit @@ -267,7 +267,7 @@ pyzmq==22.0.3 # via jupyter-client readthedocs-sphinx-search==0.1.0 # via -r doc-requirements.in -regex==2021.3.17 +regex==2021.4.4 # via # black # docker-image-py @@ -283,9 +283,9 @@ retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -s3transfer==0.3.6 +s3transfer==0.4.1 # via boto3 -sagemaker-training==3.7.4 +sagemaker-training==3.9.1 # via flytekit scantree==0.0.1 # via dirhash @@ -314,7 +314,7 @@ sortedcontainers==2.3.0 # via flytekit soupsieve==2.2.1 # via beautifulsoup4 -sphinx-autoapi==1.7.0 +sphinx-autoapi==1.8.0 # via -r doc-requirements.in sphinx-code-include==1.1.1 # via -r doc-requirements.in @@ -326,7 +326,7 @@ sphinx-material==0.0.32 # via -r doc-requirements.in sphinx-prompt==1.4.0 # via -r doc-requirements.in -sphinx==3.5.3 +sphinx==3.5.4 # via # -r doc-requirements.in # furo @@ -368,7 +368,7 @@ tornado==6.1 # via # ipykernel # jupyter-client -tqdm==4.59.0 +tqdm==4.60.0 # via papermill traitlets==5.0.5 # via @@ -379,7 +379,7 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.2 +typed-ast==1.4.3 # via black typing-extensions==3.7.4.3 # via @@ -414,7 +414,7 @@ zipp==3.4.1 # via importlib-metadata zope.event==4.5.0 # via gevent -zope.interface==5.3.0 +zope.interface==5.4.0 # via gevent # The following packages are considered to be unsafe in a requirements file: diff --git a/plugins/setup.py b/plugins/setup.py index e03e135be2..abd2d8c7c8 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -15,6 +15,7 @@ "flytekitplugins-awssagemaker": "awssagemaker", "flytekitplugins-kftensorflow": "kftensorflow", "flytekitplugins-pandera": "pandera", + "flytekitplugins-sqlalchemy": "sqlalchemy", } diff --git a/plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py b/plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py new file mode 100644 index 0000000000..aaf8ade06f --- /dev/null +++ b/plugins/sqlalchemy/flytekitplugins/sqlalchemy/__init__.py @@ -0,0 +1 @@ +from .task import SQLAlchemyConfig, SQLAlchemyTask diff --git a/plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py b/plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py new file mode 100644 index 0000000000..38f62cb8b7 --- /dev/null +++ b/plugins/sqlalchemy/flytekitplugins/sqlalchemy/task.py @@ -0,0 +1,85 @@ +import typing +from dataclasses import dataclass + +import pandas as pd +from sqlalchemy import create_engine + +from flytekit import current_context, kwtypes +from flytekit.core.base_sql_task import SQLTask +from flytekit.core.python_function_task import PythonInstanceTask +from flytekit.models.security import Secret +from flytekit.types.schema import FlyteSchema + + +@dataclass +class SQLAlchemyConfig(object): + """ + Use this configuration to configure task. String should be standard + sqlalchemy connector format + (https://docs.sqlalchemy.org/en/14/core/engines.html#database-urls). + Database can be found: + - within the container + - or from a publicly accessible source + + Args: + uri: default sqlalchemy connector + connect_args: sqlalchemy kwarg overrides -- ex: host + secret_connect_args: flyte secrets loaded into sqlalchemy connect args + -- ex: {"password": {"name": SECRET_NAME, "group": SECRET_GROUP}} + """ + + uri: str + connect_args: typing.Optional[typing.Dict[str, typing.Any]] = None + secret_connect_args: typing.Optional[typing.Dict[str, Secret]] = None + + +class SQLAlchemyTask(PythonInstanceTask[SQLAlchemyConfig], SQLTask[SQLAlchemyConfig]): + """ + Makes it possible to run client side SQLAlchemy queries that optionally return a FlyteSchema object + + TODO: How should we use pre-built containers for running portable tasks like this. Should this always be a + referenced task type? + """ + + _SQLALCHEMY_TASK_TYPE = "sqlalchemy" + + def __init__( + self, + name: str, + query_template: str, + task_config: SQLAlchemyConfig, + inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, + output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None, + **kwargs, + ): + output_schema = output_schema_type if output_schema_type else FlyteSchema + outputs = kwtypes(results=output_schema) + self._uri = task_config.uri + self._connect_args = task_config.connect_args or {} + self._secret_connect_args = task_config.secret_connect_args + + super().__init__( + name=name, + task_config=task_config, + task_type=self._SQLALCHEMY_TASK_TYPE, + query_template=query_template, + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + @property + def output_columns(self) -> typing.Optional[typing.List[str]]: + c = self.python_interface.outputs["results"].column_names() + return c if c else None + + def execute(self, **kwargs) -> typing.Any: + if self._secret_connect_args is not None: + for key, secret in self._secret_connect_args.items(): + value = current_context().secrets.get(secret.group, secret.key) + self._connect_args[key] = value + engine = create_engine(self._uri, connect_args=self._connect_args, echo=False) + print(f"Connecting to db {self._uri}") + with engine.begin() as connection: + df = pd.read_sql_query(self.get_query(**kwargs), connection) + return df diff --git a/plugins/sqlalchemy/setup.py b/plugins/sqlalchemy/setup.py new file mode 100644 index 0000000000..e39b139552 --- /dev/null +++ b/plugins/sqlalchemy/setup.py @@ -0,0 +1,34 @@ +from setuptools import setup + +PLUGIN_NAME = "sqlalchemy" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=0.17.0,<1.0.0", "sqlalchemy>=1.4.7"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="dolthub", + author_email="max@dolthub.com", + description="SQLAlchemy plugin for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/tests/sqlalchemy/__init__.py b/plugins/tests/sqlalchemy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/tests/sqlalchemy/test_sql_tracker.py b/plugins/tests/sqlalchemy/test_sql_tracker.py new file mode 100644 index 0000000000..aad37cd254 --- /dev/null +++ b/plugins/tests/sqlalchemy/test_sql_tracker.py @@ -0,0 +1,31 @@ +from collections import OrderedDict + +from flytekit.common.translator import get_serializable +from flytekit.core import context_manager +from flytekit.core.context_manager import Image, ImageConfig +from plugins.tests.sqlalchemy.test_task import tk as not_tk + + +def test_sql_lhs(): + assert not_tk.lhs == "tk" + + +def test_sql_command(): + default_img = Image(name="default", fqn="test", tag="tag") + serialization_settings = context_manager.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + srz_t = get_serializable(OrderedDict(), serialization_settings, not_tk) + assert srz_t.container.args[-7:] == [ + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "plugins.tests.sqlalchemy.test_task", + "task-name", + "tk", + ] diff --git a/plugins/tests/sqlalchemy/test_task.py b/plugins/tests/sqlalchemy/test_task.py new file mode 100644 index 0000000000..6adeca7519 --- /dev/null +++ b/plugins/tests/sqlalchemy/test_task.py @@ -0,0 +1,85 @@ +import contextlib +import os +import shutil +import sqlite3 +import tempfile + +import pandas +import pytest +from flytekitplugins.sqlalchemy import SQLAlchemyConfig, SQLAlchemyTask + +from flytekit import kwtypes, task, workflow +from flytekit.types.schema import FlyteSchema + +tk = SQLAlchemyTask( + "test", + query_template="select * from tracks", + task_config=SQLAlchemyConfig( + uri="sqlite://", + ), +) + + +@pytest.fixture(scope="function") +def sql_server(): + try: + d = tempfile.TemporaryDirectory() + db_path = os.path.join(d.name, "tracks.db") + with contextlib.closing(sqlite3.connect(db_path)) as con: + con.execute("create table tracks (TrackId bigint, Name text)") + con.execute("insert into tracks values (0, 'Sue'), (1, 'L'), (2, 'M'), (3, 'Ji'), (4, 'Po')") + con.commit() + yield f"sqlite:///{db_path}" + finally: + if os.path.exists(d.name): + shutil.rmtree(d.name) + + +def test_task_static(sql_server): + tk = SQLAlchemyTask( + "test", + query_template="select * from tracks", + task_config=SQLAlchemyConfig( + uri=sql_server, + ), + ) + + assert tk.output_columns is None + + df = tk() + assert df is not None + + +def test_task_schema(sql_server): + sql_task = SQLAlchemyTask( + "test", + query_template="select TrackId, Name from tracks limit {{.inputs.limit}}", + inputs=kwtypes(limit=int), + output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)], + task_config=SQLAlchemyConfig( + uri=sql_server, + ), + ) + + assert sql_task.output_columns is not None + df = sql_task(limit=1) + assert df is not None + + +def test_workflow(sql_server): + @task + def my_task(df: pandas.DataFrame) -> int: + return len(df[df.columns[0]]) + + sql_task = SQLAlchemyTask( + "test", + query_template="select * from tracks limit {{.inputs.limit}}", + inputs=kwtypes(limit=int), + task_config=SQLAlchemyConfig(uri=sql_server), + ) + + @workflow + def wf(limit: int) -> int: + return my_task(df=sql_task(limit=limit)) + + assert wf(limit=5) == 5 diff --git a/requirements-spark2.txt b/requirements-spark2.txt index c58e52ca25..f4f4e6115c 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -10,6 +10,10 @@ ansiwrap==0.8.4 # via papermill appdirs==1.4.4 # via black +appnope==0.1.2 + # via + # ipykernel + # ipython async-generator==1.10 # via nbclient attrs==20.3.0 @@ -24,9 +28,9 @@ black==20.8b1 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.39 +boto3==1.17.55 # via sagemaker-training -botocore==1.20.39 +botocore==1.20.55 # via # boto3 # s3transfer @@ -45,15 +49,13 @@ click==7.1.2 # flytekit # hmsclient # papermill -croniter==1.0.10 +croniter==1.0.12 # via flytekit cryptography==3.4.7 - # via - # paramiko - # secretstorage + # via paramiko dataclasses-json==0.5.2 # via flytekit -decorator==4.4.2 +decorator==5.0.7 # via # ipython # retry @@ -69,36 +71,32 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.26 +flyteidl==0.18.37 # via flytekit gevent==21.1.2 # via sagemaker-training greenlet==1.0.0 # via gevent -grpcio==1.36.1 +grpcio==1.37.0 # via flytekit hmsclient==0.1.1 # via flytekit idna==2.10 # via requests -importlib-metadata==3.7.3 +importlib-metadata==4.0.1 # via keyring inotify_simple==1.2.1 # via sagemaker-training -ipykernel==5.5.0 +ipykernel==5.5.3 # via flytekit ipython-genutils==0.2.0 # via # nbformat # traitlets -ipython==7.21.0 +ipython==7.22.0 # via ipykernel jedi==0.18.0 # via ipython -jeepney==0.6.0 - # via - # keyring - # secretstorage jinja2==2.11.3 # via nbconvert jmespath==0.10.0 @@ -126,7 +124,7 @@ markupsafe==1.1.1 # via jinja2 marshmallow-enum==1.5.1 # via dataclasses-json -marshmallow==3.10.0 +marshmallow==3.11.1 # via # dataclasses-json # marshmallow-enum @@ -144,14 +142,14 @@ nbclient==0.5.3 # papermill nbconvert==6.0.7 # via flytekit -nbformat==5.1.2 +nbformat==5.1.3 # via # nbclient # nbconvert # papermill nest-asyncio==1.5.1 # via nbclient -numpy==1.20.1 +numpy==1.20.2 # via # flytekit # pandas @@ -160,7 +158,7 @@ numpy==1.20.1 # scipy packaging==20.9 # via bleach -pandas==1.2.3 +pandas==1.2.4 # via flytekit pandocfilters==1.4.3 # via nbconvert @@ -168,7 +166,7 @@ papermill==2.3.3 # via flytekit paramiko==2.7.2 # via sagemaker-training -parso==0.8.1 +parso==0.8.2 # via jedi pathspec==0.8.1 # via @@ -180,7 +178,7 @@ pickleshare==0.7.5 # via ipython prompt-toolkit==3.0.18 # via ipython -protobuf==3.15.6 +protobuf==3.15.8 # via # flyteidl # flytekit @@ -228,7 +226,7 @@ pyyaml==5.4.1 # via papermill pyzmq==22.0.3 # via jupyter-client -regex==2021.3.17 +regex==2021.4.4 # via # black # docker-image-py @@ -237,22 +235,20 @@ requests==2.25.1 # flytekit # papermill # responses -responses==0.13.1 +responses==0.13.2 # via flytekit retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -s3transfer==0.3.6 +s3transfer==0.4.1 # via boto3 -sagemaker-training==3.7.3 +sagemaker-training==3.9.1 # via flytekit scantree==0.0.1 # via dirhash scipy==1.6.2 # via sagemaker-training -secretstorage==3.3.1 - # via keyring six==1.15.0 # via # bcrypt @@ -289,7 +285,7 @@ tornado==6.1 # via # ipykernel # jupyter-client -tqdm==4.59.0 +tqdm==4.60.0 # via papermill traitlets==5.0.5 # via @@ -300,7 +296,7 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.2 +typed-ast==1.4.3 # via black typing-extensions==3.7.4.3 # via @@ -330,7 +326,7 @@ zipp==3.4.1 # via importlib-metadata zope.event==4.5.0 # via gevent -zope.interface==5.3.0 +zope.interface==5.4.0 # via gevent # The following packages are considered to be unsafe in a requirements file: diff --git a/requirements.txt b/requirements.txt index 0473ae8da4..2595fcfc44 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,9 +28,9 @@ black==20.8b1 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.40 +boto3==1.17.55 # via sagemaker-training -botocore==1.20.40 +botocore==1.20.55 # via # boto3 # s3transfer @@ -49,13 +49,13 @@ click==7.1.2 # flytekit # hmsclient # papermill -croniter==1.0.10 +croniter==1.0.12 # via flytekit cryptography==3.4.7 # via paramiko dataclasses-json==0.5.2 # via flytekit -decorator==4.4.2 +decorator==5.0.7 # via # ipython # retry @@ -71,23 +71,23 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.26 +flyteidl==0.18.37 # via flytekit gevent==21.1.2 # via sagemaker-training greenlet==1.0.0 # via gevent -grpcio==1.36.1 +grpcio==1.37.0 # via flytekit hmsclient==0.1.1 # via flytekit idna==2.10 # via requests -importlib-metadata==3.9.1 +importlib-metadata==4.0.1 # via keyring inotify_simple==1.2.1 # via sagemaker-training -ipykernel==5.5.0 +ipykernel==5.5.3 # via flytekit ipython-genutils==0.2.0 # via @@ -124,7 +124,7 @@ markupsafe==1.1.1 # via jinja2 marshmallow-enum==1.5.1 # via dataclasses-json -marshmallow==3.11.0 +marshmallow==3.11.1 # via # dataclasses-json # marshmallow-enum @@ -142,7 +142,7 @@ nbclient==0.5.3 # papermill nbconvert==6.0.7 # via flytekit -nbformat==5.1.2 +nbformat==5.1.3 # via # nbclient # nbconvert @@ -158,7 +158,7 @@ numpy==1.20.2 # scipy packaging==20.9 # via bleach -pandas==1.2.3 +pandas==1.2.4 # via flytekit pandocfilters==1.4.3 # via nbconvert @@ -166,7 +166,7 @@ papermill==2.3.3 # via flytekit paramiko==2.7.2 # via sagemaker-training -parso==0.8.1 +parso==0.8.2 # via jedi pathspec==0.8.1 # via @@ -178,7 +178,7 @@ pickleshare==0.7.5 # via ipython prompt-toolkit==3.0.18 # via ipython -protobuf==3.15.6 +protobuf==3.15.8 # via # flyteidl # flytekit @@ -226,7 +226,7 @@ pyyaml==5.4.1 # via papermill pyzmq==22.0.3 # via jupyter-client -regex==2021.3.17 +regex==2021.4.4 # via # black # docker-image-py @@ -241,9 +241,9 @@ retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -s3transfer==0.3.6 +s3transfer==0.4.1 # via boto3 -sagemaker-training==3.7.4 +sagemaker-training==3.9.1 # via flytekit scantree==0.0.1 # via dirhash @@ -285,7 +285,7 @@ tornado==6.1 # via # ipykernel # jupyter-client -tqdm==4.59.0 +tqdm==4.60.0 # via papermill traitlets==5.0.5 # via @@ -296,7 +296,7 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.2 +typed-ast==1.4.3 # via black typing-extensions==3.7.4.3 # via @@ -326,7 +326,7 @@ zipp==3.4.1 # via importlib-metadata zope.event==4.5.0 # via gevent -zope.interface==5.3.0 +zope.interface==5.4.0 # via gevent # The following packages are considered to be unsafe in a requirements file: