diff --git a/dev-requirements.txt b/dev-requirements.txt index 9f96cf8b50..bb83f44571 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -28,8 +28,6 @@ black==21.4b2 # -c requirements.txt # -r dev-requirements.in # flake8-black -cached-property==1.5.2 - # via docker-compose certifi==2020.12.5 # via # -c requirements.txt @@ -100,7 +98,7 @@ flake8==3.9.1 # -r dev-requirements.in # flake8-black # flake8-isort -flyteidl==0.18.40 +flyteidl==0.18.41 # via # -c requirements.txt # flytekit @@ -115,11 +113,7 @@ idna==2.10 importlib-metadata==4.0.1 # via # -c requirements.txt - # flake8 - # jsonschema # keyring - # pluggy - # pytest iniconfig==1.1.1 # via pytest isort==5.8.0 @@ -316,15 +310,10 @@ toml==0.10.2 # coverage # pytest typed-ast==1.4.3 - # via - # -c requirements.txt - # black - # mypy + # via mypy typing-extensions==3.7.4.3 # via # -c requirements.txt - # black - # importlib-metadata # mypy # typing-inspect typing-inspect==0.6.0 diff --git a/doc-requirements.txt b/doc-requirements.txt index 12191d4c38..21b20c0ab0 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -39,9 +39,9 @@ black==21.4b2 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.61 +boto3==1.17.62 # via sagemaker-training -botocore==1.20.61 +botocore==1.20.62 # via # boto3 # s3transfer @@ -88,7 +88,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.40 +flyteidl==0.18.41 # via flytekit git+git://github.com/flyteorg/furo@main # via -r doc-requirements.in @@ -107,9 +107,7 @@ idna==2.10 imagesize==1.2.0 # via sphinx importlib-metadata==4.0.1 - # via - # jsonschema - # keyring + # via keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -252,7 +250,7 @@ python-dateutil==2.8.1 # flytekit # jupyter-client # pandas -python-slugify[unidecode]==4.0.1 +python-slugify[unidecode]==5.0.0 # via sphinx-material pytimeparse==1.1.8 # via flytekit @@ -381,15 +379,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.3 - # via - # astroid - # black typing-extensions==3.7.4.3 - # via - # black - # importlib-metadata - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json unidecode==1.2.0 diff --git a/plugins/flytekitplugins.dolt/flytekitplugins/dolt/__init__.py b/plugins/flytekitplugins.dolt/flytekitplugins/dolt/__init__.py new file mode 100644 index 0000000000..fd9379e283 --- /dev/null +++ b/plugins/flytekitplugins.dolt/flytekitplugins/dolt/__init__.py @@ -0,0 +1 @@ +from .schema import DoltConfig, DoltTable, DoltTableNameTransformer diff --git a/plugins/flytekitplugins.dolt/flytekitplugins/dolt/schema.py b/plugins/flytekitplugins.dolt/flytekitplugins/dolt/schema.py new file mode 100644 index 0000000000..05d20fab07 --- /dev/null +++ b/plugins/flytekitplugins.dolt/flytekitplugins/dolt/schema.py @@ -0,0 +1,105 @@ +import tempfile +import typing +from dataclasses import dataclass +from typing import Type + +import dolt_integrations.core as dolt_int +import doltcli as dolt +import pandas +from dataclasses_json import dataclass_json +from google.protobuf.struct_pb2 import Struct + +from flytekit import FlyteContext +from flytekit.extend import TypeEngine, TypeTransformer +from flytekit.models import types as _type_models +from flytekit.models.literals import Literal, Scalar +from flytekit.models.types import LiteralType + + +@dataclass_json +@dataclass +class DoltConfig: + db_path: str + tablename: typing.Optional[str] = None + sql: typing.Optional[str] = None + io_args: typing.Optional[dict] = None + branch_conf: typing.Optional[dolt_int.Branch] = None + meta_conf: typing.Optional[dolt_int.Meta] = None + remote_conf: typing.Optional[dolt_int.Remote] = None + + +@dataclass_json +@dataclass +class DoltTable: + config: DoltConfig + data: typing.Optional[pandas.DataFrame] = None + + +class DoltTableNameTransformer(TypeTransformer[DoltTable]): + def __init__(self): + super().__init__(name="DoltTable", t=DoltTable) + + def get_literal_type(self, t: Type[DoltTable]) -> LiteralType: + return LiteralType(simple=_type_models.SimpleType.STRUCT, metadata={}) + + def to_literal( + self, + ctx: FlyteContext, + python_val: DoltTable, + python_type: typing.Type[DoltTable], + expected: LiteralType, + ) -> Literal: + + if not isinstance(python_val, DoltTable): + raise AssertionError(f"Value cannot be converted to a table: {python_val}") + + conf = python_val.config + if python_val.data is not None and python_val.config.tablename is not None: + db = dolt.Dolt(conf.db_path) + with tempfile.NamedTemporaryFile() as f: + python_val.data.to_csv(f.name, index=False) + dolt_int.save( + db=db, + tablename=conf.tablename, + filename=f.name, + branch_conf=conf.branch_conf, + meta_conf=conf.meta_conf, + remote_conf=conf.remote_conf, + save_args=conf.io_args, + ) + + s = Struct() + s.update(python_val.to_dict()) + return Literal(Scalar(generic=s)) + + def to_python_value( + self, + ctx: FlyteContext, + lv: Literal, + expected_python_type: typing.Type[DoltTable], + ) -> DoltTable: + + if not (lv and lv.scalar and lv.scalar.generic and lv.scalar.generic["config"]): + return pandas.DataFrame() + + conf = DoltConfig(**lv.scalar.generic["config"]) + db = dolt.Dolt(conf.db_path) + + with tempfile.NamedTemporaryFile() as f: + dolt_int.load( + db=db, + tablename=conf.tablename, + sql=conf.sql, + filename=f.name, + branch_conf=conf.branch_conf, + meta_conf=conf.meta_conf, + remote_conf=conf.remote_conf, + load_args=conf.io_args, + ) + df = pandas.read_csv(f) + lv.data = df + + return lv + + +TypeEngine.register(DoltTableNameTransformer()) diff --git a/plugins/flytekitplugins.dolt/scripts/flytekit_install_dolt.sh b/plugins/flytekitplugins.dolt/scripts/flytekit_install_dolt.sh new file mode 100644 index 0000000000..c2f4841789 --- /dev/null +++ b/plugins/flytekitplugins.dolt/scripts/flytekit_install_dolt.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +# Fetches and install Dolt. To be invoked by the Dockerfile + +# echo commands to the terminal output +set -eox pipefail + +# Install Dolt + +apt-get update -y \ + && apt-get install curl \ + && sudo bash -c 'curl -L https://github.com/dolthub/dolt/releases/latest/download/install.sh | sudo bash' diff --git a/plugins/flytekitplugins.dolt/setup.py b/plugins/flytekitplugins.dolt/setup.py new file mode 100644 index 0000000000..bf479b0a29 --- /dev/null +++ b/plugins/flytekitplugins.dolt/setup.py @@ -0,0 +1,66 @@ +import shlex +import subprocess +import urllib.request + +from setuptools import setup +from setuptools.command.develop import develop + +PLUGIN_NAME = "dolt" + +microlib_name = f"flytekitplugins.{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=0.16.0b0,<1.0.0", "dolt_integrations>=0.1.3"] + +__version__ = "0.0.0+develop" + + +class PostDevelopCommand(develop): + """Post-installation for development mode.""" + + def run(self): + develop.run(self) + install, _ = urllib.request.urlretrieve( + "https://github.com/liquidata-inc/dolt/releases/latest/download/install.sh" + ) + subprocess.call(shlex.split(f"chmod +x {install}")) + subprocess.call(shlex.split(f"sudo {install}")) + + pref = "dolt config --global --add" + subprocess.call( + shlex.split(f"{pref} user.email bojack@horseman.com"), + ) + subprocess.call( + shlex.split(f"{pref} user.name 'Bojack Horseman'"), + ) + subprocess.call( + shlex.split(f"{pref} metrics.host eventsapi.awsdev.ld-corp.com"), + ) + subprocess.call(shlex.split(f"{pref} metrics.port 443")) + + +setup( + name=microlib_name, + version=__version__, + author="dolthub", + author_email="max@dolthub.com", + description="Dolt plugin for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + cmdclass=dict(develop=PostDevelopCommand), + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + scripts=["scripts/flytekit_install_dolt.sh"], +) diff --git a/plugins/setup.py b/plugins/setup.py index 8f685c550d..48f302bc5a 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -16,6 +16,7 @@ "flytekitplugins-kftensorflow": "kftensorflow", "flytekitplugins-pandera": "pandera", "flytekitplugins-sqlalchemy": "sqlalchemy", + "flytekitplugins-dolt": "flytekitplugins.dolt", } diff --git a/plugins/tests/dolt/__init__.py b/plugins/tests/dolt/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/tests/dolt/test_wf.py b/plugins/tests/dolt/test_wf.py new file mode 100644 index 0000000000..c8b41e5bcb --- /dev/null +++ b/plugins/tests/dolt/test_wf.py @@ -0,0 +1,128 @@ +import os +import shutil +import tempfile + +import doltcli as dolt +import pandas +import pytest +from flytekitplugins.dolt.schema import DoltConfig, DoltTable + +from flytekit import task, workflow + + +@pytest.fixture(scope="function") +def doltdb_path(): + d = tempfile.TemporaryDirectory() + try: + db_path = os.path.join(d.name, "foo") + yield db_path + finally: + shutil.rmtree(d.name) + + +@pytest.fixture(scope="function") +def dolt_config(doltdb_path): + yield DoltConfig( + db_path=doltdb_path, + tablename="foo", + ) + + +@pytest.fixture(scope="function") +def db(doltdb_path): + try: + db = dolt.Dolt.init(doltdb_path) + db.sql("create table bar (name text primary key, count bigint)") + db.sql("insert into bar values ('Dilly', 3)") + db.sql("select dolt_commit('-am', 'Initialize bar table')") + yield db + finally: + pass + + +def test_dolt_table_write(db, dolt_config): + @task + def my_dolt(a: int) -> DoltTable: + df = pandas.DataFrame([("Alice", a)], columns=["name", "count"]) + return DoltTable(data=df, config=dolt_config) + + @workflow + def my_wf(a: int) -> DoltTable: + return my_dolt(a=a) + + x = my_wf(a=5) + assert x + assert (x.data == pandas.DataFrame([("Alice", 5)], columns=["name", "count"])).all().all() + + +def test_dolt_table_read(db, dolt_config): + @task + def my_dolt(t: DoltTable) -> str: + df = t.data + return df.name.values[0] + + @workflow + def my_wf(t: DoltTable) -> str: + return my_dolt(t=t) + + dolt_config.tablename = "bar" + x = my_wf(t=DoltTable(config=dolt_config)) + assert x == "Dilly" + + +def test_dolt_table_read_task_config(db, dolt_config): + @task + def my_dolt(t: DoltTable) -> str: + df = t.data + return df.name.values[0] + + @task + def my_table() -> DoltTable: + dolt_config.tablename = "bar" + t = DoltTable(config=dolt_config) + return t + + @workflow + def my_wf() -> str: + t = my_table() + return my_dolt(t=t) + + x = my_wf() + assert x == "Dilly" + + +def test_dolt_table_read_mixed_config(db, dolt_config): + @task + def my_dolt(t: DoltTable) -> str: + df = t.data + return df.name.values[0] + + @task + def my_table(conf: DoltConfig) -> DoltTable: + return DoltTable(config=conf) + + @workflow + def my_wf(conf: DoltConfig) -> str: + t = my_table(conf=conf) + return my_dolt(t=t) + + dolt_config.tablename = "bar" + x = my_wf(conf=dolt_config) + + assert x == "Dilly" + + +def test_dolt_sql_read(db, dolt_config): + @task + def my_dolt(t: DoltTable) -> str: + df = t.data + return df.name.values[0] + + @workflow + def my_wf(t: DoltTable) -> str: + return my_dolt(t=t) + + dolt_config.tablename = None + dolt_config.sql = "select * from bar" + x = my_wf(t=DoltTable(config=dolt_config)) + assert x == "Dilly" diff --git a/requirements-spark2.txt b/requirements-spark2.txt index 94ff91c6d1..376300bc9c 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -28,9 +28,9 @@ black==21.4b2 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.61 +boto3==1.17.62 # via sagemaker-training -botocore==1.20.61 +botocore==1.20.62 # via # boto3 # s3transfer @@ -71,7 +71,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.40 +flyteidl==0.18.41 # via flytekit gevent==21.1.2 # via sagemaker-training @@ -84,9 +84,7 @@ hmsclient==0.1.1 idna==2.10 # via requests importlib-metadata==4.0.1 - # via - # jsonschema - # keyring + # via keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -298,13 +296,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.3 - # via black typing-extensions==3.7.4.3 - # via - # black - # importlib-metadata - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11 diff --git a/requirements.txt b/requirements.txt index 6e6be1b068..bd47439a0e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,9 +28,9 @@ black==21.4b2 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.61 +boto3==1.17.62 # via sagemaker-training -botocore==1.20.61 +botocore==1.20.62 # via # boto3 # s3transfer @@ -71,7 +71,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.18.40 +flyteidl==0.18.41 # via flytekit gevent==21.1.2 # via sagemaker-training @@ -84,9 +84,7 @@ hmsclient==0.1.1 idna==2.10 # via requests importlib-metadata==4.0.1 - # via - # jsonschema - # keyring + # via keyring inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.3 @@ -298,13 +296,8 @@ traitlets==5.0.5 # nbclient # nbconvert # nbformat -typed-ast==1.4.3 - # via black typing-extensions==3.7.4.3 - # via - # black - # importlib-metadata - # typing-inspect + # via typing-inspect typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11