diff --git a/flytekit/models/presto.py b/flytekit/models/presto.py index 2f5f998153..74387d56f3 100644 --- a/flytekit/models/presto.py +++ b/flytekit/models/presto.py @@ -6,7 +6,7 @@ class PrestoQuery(_common.FlyteIdlEntity): - def __init__(self, routing_group, catalog, schema, statement): + def __init__(self, routing_group=None, catalog=None, schema=None, statement=None): """ Initializes a new PrestoQuery. diff --git a/plugins/flytekit-athena/flytekitplugins/athena/__init__.py b/plugins/flytekit-athena/flytekitplugins/athena/__init__.py new file mode 100644 index 0000000000..45b94aa8fe --- /dev/null +++ b/plugins/flytekit-athena/flytekitplugins/athena/__init__.py @@ -0,0 +1 @@ +from .task import AthenaConfig, AthenaTask diff --git a/plugins/flytekit-athena/flytekitplugins/athena/task.py b/plugins/flytekit-athena/flytekitplugins/athena/task.py new file mode 100644 index 0000000000..5399483d61 --- /dev/null +++ b/plugins/flytekit-athena/flytekitplugins/athena/task.py @@ -0,0 +1,78 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Type + +from google.protobuf.json_format import MessageToDict + +from flytekit.extend import SerializationSettings, SQLTask +from flytekit.models.presto import PrestoQuery +from flytekit.types.schema import FlyteSchema + + +@dataclass +class AthenaConfig(object): + """ + AthenaConfig should be used to configure a Athena Task. + """ + + # The database to query against + database: Optional[str] = None + # The optional workgroup to separate query execution. + workgroup: Optional[str] = None + # The catalog to set for the given Presto query + catalog: Optional[str] = None + + +class AthenaTask(SQLTask[AthenaConfig]): + """ + This is the simplest form of a Athena Task, that can be used even for tasks that do not produce any output. + """ + + # This task is executed using the presto handler in the backend. + _TASK_TYPE = "presto" + + def __init__( + self, + name: str, + query_template: str, + task_config: Optional[AthenaConfig] = None, + inputs: Optional[Dict[str, Type]] = None, + output_schema_type: Optional[Type[FlyteSchema]] = None, + **kwargs, + ): + """ + Args: + name: Name of this task, should be unique in the project + config: Type AthenaConfig object + inputs: Name and type of inputs specified as an ordered dictionary + query_template: The actual query to run. We use Flyte's Golang templating format for Query templating. + Refer to the templating documentation + output_schema_type: If some data is produced by this query, then you can specify the output schema type + **kwargs: All other args required by Parent type - SQLTask + """ + outputs = None + if output_schema_type is not None: + outputs = { + "results": output_schema_type, + } + if task_config is None: + task_config = AthenaConfig() + super().__init__( + name=name, + task_config=task_config, + query_template=query_template, + inputs=inputs, + outputs=outputs, + task_type=self._TASK_TYPE, + **kwargs, + ) + self._output_schema_type = output_schema_type + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + # This task is executed using the presto handler in the backend. + job = PrestoQuery( + statement=self.query_template, + schema=self.task_config.database, + routing_group=self.task_config.workgroup, + catalog=self.task_config.catalog, + ) + return MessageToDict(job.to_flyte_idl()) diff --git a/plugins/flytekit-athena/setup.py b/plugins/flytekit-athena/setup.py new file mode 100644 index 0000000000..a179663fe8 --- /dev/null +++ b/plugins/flytekit-athena/setup.py @@ -0,0 +1,34 @@ +from setuptools import setup + +PLUGIN_NAME = "athena" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=0.19.0,<1.0.0"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package holds the Athena plugins for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/setup.py b/plugins/setup.py index fddf7156a2..2d61e39ac0 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -17,6 +17,7 @@ "flytekitplugins-pandera": "pandera", "flytekitplugins-dolt": "flytekit-dolt", "flytekitplugins-sqlalchemy": "flytekitplugins-sqlalchemy", + "flytekitplugins-athena": "flytekit-athena", } diff --git a/plugins/tests/athena/__init__.py b/plugins/tests/athena/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/tests/athena/test_athena.py b/plugins/tests/athena/test_athena.py new file mode 100644 index 0000000000..9fd8c60762 --- /dev/null +++ b/plugins/tests/athena/test_athena.py @@ -0,0 +1,73 @@ +from collections import OrderedDict + +import pytest +from flytekitplugins.athena import AthenaConfig, AthenaTask + +from flytekit import kwtypes, workflow +from flytekit.extend import Image, ImageConfig, SerializationSettings, get_serializable +from flytekit.types.schema import FlyteSchema + + +def test_serialization(): + athena_task = AthenaTask( + name="flytekit.demo.athena_task.query", + inputs=kwtypes(ds=str), + task_config=AthenaConfig(database="mnist", catalog="my_catalog", workgroup="my_wg"), + query_template=""" + insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet + select * + from blah + where ds = '{{ .Inputs.ds }}' + """, + # the schema literal's backend uri will be equal to the value of .raw_output_data + output_schema_type=FlyteSchema, + ) + + @workflow + def my_wf(ds: str) -> FlyteSchema: + return athena_task(ds=ds) + + default_img = Image(name="default", fqn="test", tag="tag") + serialization_settings = SerializationSettings( + project="proj", + domain="dom", + version="123", + image_config=ImageConfig(default_image=default_img, images=[default_img]), + env={}, + ) + task_spec = get_serializable(OrderedDict(), serialization_settings, athena_task) + assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["statement"] + assert "insert overwrite directory" in task_spec.template.custom["statement"] + assert "mnist" == task_spec.template.custom["schema"] + assert "my_catalog" == task_spec.template.custom["catalog"] + assert "my_wg" == task_spec.template.custom["routingGroup"] + assert len(task_spec.template.interface.inputs) == 1 + assert len(task_spec.template.interface.outputs) == 1 + + admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) + assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None + assert admin_workflow_spec.template.outputs[0].var == "o0" + assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0" + assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results" + + +def test_local_exec(): + athena_task = AthenaTask( + name="flytekit.demo.athena_task.query2", + inputs=kwtypes(ds=str), + query_template=""" + insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet + select * + from blah + where ds = '{{ .Inputs.ds }}' + """, + # the schema literal's backend uri will be equal to the value of .raw_output_data + output_schema_type=FlyteSchema, + ) + + assert len(athena_task.interface.inputs) == 1 + assert len(athena_task.interface.outputs) == 1 + + # will not run locally + with pytest.raises(Exception): + athena_task()