From bbf0a25c031daf57e4a6770fbddd2ba8b4865c50 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Wed, 28 Apr 2021 17:44:15 -0700 Subject: [PATCH] Initial DoltTable implementation Signed-off-by: Max Hoffman --- dev-requirements.txt | 8 +- doc-requirements.txt | 32 ++--- plugins/dolt/flytekitplugins/dolt/__init__.py | 1 + plugins/dolt/flytekitplugins/dolt/schema.py | 105 ++++++++++++++ plugins/dolt/scripts/flytekit_install_dolt.sh | 12 ++ plugins/dolt/setup.py | 35 +++++ plugins/setup.py | 1 + plugins/tests/dolt/__init__.py | 0 plugins/tests/dolt/test_wf.py | 128 ++++++++++++++++++ requirements-spark2.txt | 24 ++-- requirements.txt | 24 ++-- 11 files changed, 318 insertions(+), 52 deletions(-) create mode 100644 plugins/dolt/flytekitplugins/dolt/__init__.py create mode 100644 plugins/dolt/flytekitplugins/dolt/schema.py create mode 100644 plugins/dolt/scripts/flytekit_install_dolt.sh create mode 100644 plugins/dolt/setup.py create mode 100644 plugins/tests/dolt/__init__.py create mode 100644 plugins/tests/dolt/test_wf.py diff --git a/dev-requirements.txt b/dev-requirements.txt index 873831d1c2..4aaf5e3fcc 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -12,7 +12,7 @@ attrs==20.3.0 # via # -c requirements.txt # pytest -black==20.8b1 +black==21.4b2 # via # -c requirements.txt # -r dev-requirements.in @@ -86,12 +86,8 @@ toml==0.10.2 # coverage # pytest typed-ast==1.4.3 - # via - # -c requirements.txt - # black - # mypy + # via mypy typing-extensions==3.7.4.3 # via # -c requirements.txt - # black # mypy diff --git a/doc-requirements.txt b/doc-requirements.txt index 07231a1bd6..d3d99642b4 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -16,7 +16,7 @@ appnope==0.1.2 # via # ipykernel # ipython -astroid==2.5.3 +astroid==2.5.6 # via sphinx-autoapi async-generator==1.10 # via nbclient @@ -24,7 +24,7 @@ attrs==20.3.0 # via # jsonschema # scantree -babel==2.9.0 +babel==2.9.1 # via sphinx backcall==0.2.0 # via ipython @@ -35,13 +35,13 @@ 
beautifulsoup4==4.9.3 # furo # sphinx-code-include # sphinx-material -black==20.8b1 +black==21.4b2 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.55 +boto3==1.17.60 # via sagemaker-training -botocore==1.20.55 +botocore==1.20.60 # via # boto3 # s3transfer @@ -68,7 +68,7 @@ cryptography==3.4.7 # paramiko css-html-js-minify==2.5.5 # via sphinx-material -dataclasses-json==0.5.2 +dataclasses-json==0.5.3 # via flytekit decorator==5.0.7 # via @@ -88,7 +88,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.37 +flyteidl==0.18.39 # via flytekit furo==2021.4.11b34 # via -r doc-requirements.in @@ -277,19 +277,19 @@ requests==2.25.1 # papermill # responses # sphinx -responses==0.13.2 +responses==0.13.3 # via flytekit retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -s3transfer==0.4.1 +s3transfer==0.4.2 # via boto3 -sagemaker-training==3.9.1 +sagemaker-training==3.9.2 # via flytekit scantree==0.0.1 # via dirhash -scipy==1.6.2 +scipy==1.6.3 # via sagemaker-training six==1.15.0 # via @@ -314,13 +314,13 @@ sortedcontainers==2.3.0 # via flytekit soupsieve==2.2.1 # via beautifulsoup4 -sphinx-autoapi==1.8.0 +sphinx-autoapi==1.8.1 # via -r doc-requirements.in sphinx-code-include==1.1.1 # via -r doc-requirements.in sphinx-copybutton==0.3.1 # via -r doc-requirements.in -sphinx-gallery==0.8.2 +sphinx-gallery==0.9.0 # via -r doc-requirements.in sphinx-material==0.0.32 # via -r doc-requirements.in @@ -379,12 +379,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.3 - # via black typing-extensions==3.7.4.3 - # via - # black - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json unidecode==1.2.0 diff --git a/plugins/dolt/flytekitplugins/dolt/__init__.py b/plugins/dolt/flytekitplugins/dolt/__init__.py new file mode 100644 index 0000000000..fd9379e283 --- /dev/null +++ b/plugins/dolt/flytekitplugins/dolt/__init__.py @@ -0,0 +1 @@ +from .schema import DoltConfig, DoltTable, 
DoltTableNameTransformer
diff --git a/plugins/dolt/flytekitplugins/dolt/schema.py b/plugins/dolt/flytekitplugins/dolt/schema.py
new file mode 100644
index 0000000000..01182b806c
--- /dev/null
+++ b/plugins/dolt/flytekitplugins/dolt/schema.py
@@ -0,0 +1,123 @@
+import tempfile
+import typing
+from dataclasses import dataclass
+from typing import Type
+
+import dolt_integrations.core as dolt_int
+import doltcli as dolt
+import pandas
+from dataclasses_json import dataclass_json
+from google.protobuf.struct_pb2 import Struct
+
+from flytekit import FlyteContext
+from flytekit.extend import TypeEngine, TypeTransformer
+from flytekit.models import types as _type_models
+from flytekit.models.literals import Literal, Scalar
+from flytekit.models.types import LiteralType
+
+
+@dataclass_json
+@dataclass
+class DoltConfig:
+    """Where a table lives in Dolt and the options handed to ``dolt_integrations`` save/load."""
+
+    db_path: str
+    tablename: typing.Optional[str] = None
+    sql: typing.Optional[str] = None
+    io_args: typing.Optional[dict] = None
+    branch_conf: typing.Optional[dolt_int.Branch] = None
+    meta_conf: typing.Optional[dolt_int.Meta] = None
+    remote_conf: typing.Optional[dolt_int.Remote] = None
+
+
+@dataclass_json
+@dataclass
+class DoltTable:
+    """A ``DoltConfig`` together with the table contents, when they are held in memory."""
+
+    config: DoltConfig
+    data: typing.Optional[pandas.DataFrame] = None
+
+
+class DoltTableNameTransformer(TypeTransformer[DoltTable]):
+    """Converts ``DoltTable`` values to and from Flyte literals, saving/loading the rows via Dolt."""
+
+    def __init__(self):
+        super().__init__(name="DoltTable", t=DoltTable)
+
+    def get_literal_type(self, t: Type[DoltTable]) -> LiteralType:
+        """Return the STRUCT literal type that carries a ``DoltTable``."""
+        return LiteralType(simple=_type_models.SimpleType.STRUCT, metadata={})
+
+    def to_literal(
+        self,
+        ctx: FlyteContext,
+        python_val: DoltTable,
+        python_type: typing.Type[DoltTable],
+        expected: LiteralType,
+    ) -> Literal:
+        """Save ``python_val.data`` to Dolt when present, then return the config as a Struct literal.
+
+        Raises:
+            AssertionError: if ``python_val`` is not a ``DoltTable``.
+        """
+        if not isinstance(python_val, DoltTable):
+            raise AssertionError(f"Value cannot be converted to a table: {python_val}")
+
+        conf = python_val.config
+        # tablename is a field of the config; DoltTable itself only has ``config`` and ``data``.
+        if python_val.data is not None and conf.tablename is not None:
+            db = dolt.Dolt(conf.db_path)
+            with tempfile.NamedTemporaryFile() as f:
+                python_val.data.to_csv(f.name, index=False)
+                dolt_int.save(
+                    db=db,
+                    tablename=conf.tablename,
+                    filename=f.name,
+                    branch_conf=conf.branch_conf,
+                    meta_conf=conf.meta_conf,
+                    remote_conf=conf.remote_conf,
+                    save_args=conf.io_args,
+                )
+
+        # Serialize the config only: a protobuf Struct cannot hold a pandas.DataFrame, and
+        # to_python_value reloads the rows from Dolt using nothing but the "config" entry.
+        s = Struct()
+        s.update({"config": conf.to_dict()})
+        return Literal(Scalar(generic=s))
+
+    def to_python_value(
+        self,
+        ctx: FlyteContext,
+        lv: Literal,
+        expected_python_type: typing.Type[DoltTable],
+    ) -> DoltTable:
+        """Load the table described by the literal's "config" entry and return it as a ``DoltTable``.
+
+        Raises:
+            ValueError: if the literal has no "config" entry to load from.
+        """
+        if not (lv and lv.scalar and lv.scalar.generic and lv.scalar.generic["config"]):
+            raise ValueError("DoltTable requires a DoltConfig to load a python value")
+
+        conf = DoltConfig(**lv.scalar.generic["config"])
+        db = dolt.Dolt(conf.db_path)
+
+        with tempfile.NamedTemporaryFile() as f:
+            dolt_int.load(
+                db=db,
+                tablename=conf.tablename,
+                sql=conf.sql,
+                filename=f.name,
+                branch_conf=conf.branch_conf,
+                meta_conf=conf.meta_conf,
+                remote_conf=conf.remote_conf,
+                load_args=conf.io_args,
+            )
+            df = pandas.read_csv(f)
+
+        # Return a DoltTable as annotated, rather than attaching the frame to the Literal.
+        return DoltTable(config=conf, data=df)
+
+
+TypeEngine.register(DoltTableNameTransformer())
diff --git a/plugins/dolt/scripts/flytekit_install_dolt.sh b/plugins/dolt/scripts/flytekit_install_dolt.sh new file mode 100644 index 0000000000..c2f4841789 --- /dev/null +++ b/plugins/dolt/scripts/flytekit_install_dolt.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +# Fetches and install Dolt. 
To be invoked by the Dockerfile
+
+# echo commands to the terminal output
+set -eox pipefail
+
+# Install Dolt
+
+apt-get update -y \
+    && apt-get install -y curl \
+    && sudo bash -c 'curl -L https://github.com/dolthub/dolt/releases/latest/download/install.sh | sudo bash'
diff --git a/plugins/dolt/setup.py b/plugins/dolt/setup.py new file mode 100644 index 0000000000..ddda6f8813 --- /dev/null +++ b/plugins/dolt/setup.py @@ -0,0 +1,35 @@ +from setuptools import setup + +PLUGIN_NAME = "dolt" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=0.16.0b0,<1.0.0", "dolt_integrations>=0.1.3"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="dolthub", + author_email="max@dolthub.com", + description="Dolt plugin for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + scripts=["scripts/flytekit_install_dolt.sh"], +) diff --git a/plugins/setup.py b/plugins/setup.py index abd2d8c7c8..9e54cea7a1 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -16,6 +16,7 @@ "flytekitplugins-kftensorflow": "kftensorflow", "flytekitplugins-pandera": "pandera", "flytekitplugins-sqlalchemy": "sqlalchemy", + "flytekitplugins-dolt": "dolt", } diff --git a/plugins/tests/dolt/__init__.py b/plugins/tests/dolt/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git
a/plugins/tests/dolt/test_wf.py b/plugins/tests/dolt/test_wf.py
new file mode 100644
index 0000000000..c8b41e5bcb
--- /dev/null
+++ b/plugins/tests/dolt/test_wf.py
@@ -0,0 +1,139 @@
+import os
+import shutil
+import tempfile
+
+import doltcli as dolt
+import pandas
+import pytest
+from flytekitplugins.dolt.schema import DoltConfig, DoltTable
+
+from flytekit import task, workflow
+
+
+@pytest.fixture(scope="function")
+def doltdb_path():
+    """Yield a database path inside a fresh temporary directory; the directory is removed on teardown."""
+    # mkdtemp leaves cleanup entirely to the rmtree below. A TemporaryDirectory object would
+    # also try to clean up (with a ResourceWarning) once it is garbage collected.
+    tmp_dir = tempfile.mkdtemp()
+    try:
+        yield os.path.join(tmp_dir, "foo")
+    finally:
+        shutil.rmtree(tmp_dir)
+
+
+@pytest.fixture(scope="function")
+def dolt_config(doltdb_path):
+    """Yield a ``DoltConfig`` for the temporary database, targeting table ``foo``."""
+    yield DoltConfig(
+        db_path=doltdb_path,
+        tablename="foo",
+    )
+
+
+@pytest.fixture(scope="function")
+def db(doltdb_path):
+    """Yield a new Dolt database whose committed ``bar`` table holds the row ('Dilly', 3)."""
+    database = dolt.Dolt.init(doltdb_path)
+    database.sql("create table bar (name text primary key, count bigint)")
+    database.sql("insert into bar values ('Dilly', 3)")
+    database.sql("select dolt_commit('-am', 'Initialize bar table')")
+    yield database
+
+
+def test_dolt_table_write(db, dolt_config):
+    """A ``DoltTable`` returned with data comes back from the workflow with the same rows."""
+
+    @task
+    def my_dolt(a: int) -> DoltTable:
+        df = pandas.DataFrame([("Alice", a)], columns=["name", "count"])
+        return DoltTable(data=df, config=dolt_config)
+
+    @workflow
+    def my_wf(a: int) -> DoltTable:
+        return my_dolt(a=a)
+
+    x = my_wf(a=5)
+    assert x
+    assert (x.data == pandas.DataFrame([("Alice", 5)], columns=["name", "count"])).all().all()
+
+
+def test_dolt_table_read(db, dolt_config):
+    """A ``DoltTable`` passed as a workflow input is loaded from the ``bar`` table."""
+
+    @task
+    def my_dolt(t: DoltTable) -> str:
+        df = t.data
+        return df.name.values[0]
+
+    @workflow
+    def my_wf(t: DoltTable) -> str:
+        return my_dolt(t=t)
+
+    dolt_config.tablename = "bar"
+    x = my_wf(t=DoltTable(config=dolt_config))
+    assert x == "Dilly"
+
+
+def test_dolt_table_read_task_config(db, dolt_config):
+    """A ``DoltTable`` built inside one task is loaded from ``bar`` by the next task."""
+
+    @task
+    def my_dolt(t: DoltTable) -> str:
+        df = t.data
+        return df.name.values[0]
+
+    @task
+    def my_table() -> DoltTable:
+        dolt_config.tablename = "bar"
+        t = DoltTable(config=dolt_config)
+        return t
+
+    @workflow
+    def my_wf() -> str:
+        t = my_table()
+        return my_dolt(t=t)
+
+    x = my_wf()
+    assert x == "Dilly"
+
+
+def test_dolt_table_read_mixed_config(db, dolt_config):
+    """A ``DoltConfig`` passed as a workflow input is wrapped by one task and read by another."""
+
+    @task
+    def my_dolt(t: DoltTable) -> str:
+        df = t.data
+        return df.name.values[0]
+
+    @task
+    def my_table(conf: DoltConfig) -> DoltTable:
+        return DoltTable(config=conf)
+
+    @workflow
+    def my_wf(conf: DoltConfig) -> str:
+        t = my_table(conf=conf)
+        return my_dolt(t=t)
+
+    dolt_config.tablename = "bar"
+    x = my_wf(conf=dolt_config)
+
+    assert x == "Dilly"
+
+
+def test_dolt_sql_read(db, dolt_config):
+    """With no tablename set, the rows are loaded through the config's ``sql`` query."""
+
+    @task
+    def my_dolt(t: DoltTable) -> str:
+        df = t.data
+        return df.name.values[0]
+
+    @workflow
+    def my_wf(t: DoltTable) -> str:
+        return my_dolt(t=t)
+
+    dolt_config.tablename = None
+    dolt_config.sql = "select * from bar"
+    x = my_wf(t=DoltTable(config=dolt_config))
+    assert x == "Dilly"
diff --git a/requirements-spark2.txt b/requirements-spark2.txt index f4f4e6115c..1a791f19cd 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -24,13 +24,13 @@ backcall==0.2.0 # via ipython bcrypt==3.2.0 # via paramiko -black==20.8b1 +black==21.4b2 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.55 +boto3==1.17.60 # via sagemaker-training -botocore==1.20.55 +botocore==1.20.60 # via # boto3 # s3transfer @@ -53,7 +53,7 @@ croniter==1.0.12 # via flytekit cryptography==3.4.7 # via paramiko -dataclasses-json==0.5.2 +dataclasses-json==0.5.3 # via flytekit decorator==5.0.7 # via @@ -71,7 +71,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.37 +flyteidl==0.18.39 # via flytekit gevent==21.1.2 # via sagemaker-training @@ -235,19 +235,19 @@ requests==2.25.1 # flytekit # papermill # responses -responses==0.13.2 +responses==0.13.3 # via flytekit retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -s3transfer==0.4.1 +s3transfer==0.4.2 # via boto3 -sagemaker-training==3.9.1 +sagemaker-training==3.9.2 # via flytekit scantree==0.0.1 # via dirhash -scipy==1.6.2 
+scipy==1.6.3 # via sagemaker-training six==1.15.0 # via @@ -296,12 +296,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.3 - # via black typing-extensions==3.7.4.3 - # via - # black - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11 diff --git a/requirements.txt b/requirements.txt index 2595fcfc44..26889d6590 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,13 +24,13 @@ backcall==0.2.0 # via ipython bcrypt==3.2.0 # via paramiko -black==20.8b1 +black==21.4b2 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.55 +boto3==1.17.60 # via sagemaker-training -botocore==1.20.55 +botocore==1.20.60 # via # boto3 # s3transfer @@ -53,7 +53,7 @@ croniter==1.0.12 # via flytekit cryptography==3.4.7 # via paramiko -dataclasses-json==0.5.2 +dataclasses-json==0.5.3 # via flytekit decorator==5.0.7 # via @@ -71,7 +71,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.37 +flyteidl==0.18.39 # via flytekit gevent==21.1.2 # via sagemaker-training @@ -235,19 +235,19 @@ requests==2.25.1 # flytekit # papermill # responses -responses==0.13.2 +responses==0.13.3 # via flytekit retry==0.9.2 # via flytekit retrying==1.3.3 # via sagemaker-training -s3transfer==0.4.1 +s3transfer==0.4.2 # via boto3 -sagemaker-training==3.9.1 +sagemaker-training==3.9.2 # via flytekit scantree==0.0.1 # via dirhash -scipy==1.6.2 +scipy==1.6.3 # via sagemaker-training six==1.15.0 # via @@ -296,12 +296,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.3 - # via black typing-extensions==3.7.4.3 - # via - # black - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11