Skip to content

Commit

Permalink
Add athena plugin (#504)
Browse files Browse the repository at this point in the history
Signed-off-by: Haytham Abuelfutuh <[email protected]>
  • Loading branch information
Katrina Rogan authored and EngHabu committed Jun 25, 2021
1 parent 7a13759 commit 675fb1d
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 1 deletion.
2 changes: 1 addition & 1 deletion flytekit/models/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class PrestoQuery(_common.FlyteIdlEntity):
def __init__(self, routing_group, catalog, schema, statement):
def __init__(self, routing_group=None, catalog=None, schema=None, statement=None):
"""
Initializes a new PrestoQuery.
Expand Down
1 change: 1 addition & 0 deletions plugins/flytekit-athena/flytekitplugins/athena/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .task import AthenaConfig, AthenaTask
78 changes: 78 additions & 0 deletions plugins/flytekit-athena/flytekitplugins/athena/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type

from google.protobuf.json_format import MessageToDict

from flytekit.extend import SerializationSettings, SQLTask
from flytekit.models.presto import PrestoQuery
from flytekit.types.schema import FlyteSchema


@dataclass
class AthenaConfig:
    """
    Settings used to configure an Athena task.

    Every field is optional and defaults to ``None``.

    Attributes:
        database: The database to query against.
        workgroup: The optional workgroup used to separate query execution.
        catalog: The catalog to set for the given Presto query.
    """

    database: Optional[str] = None
    workgroup: Optional[str] = None
    catalog: Optional[str] = None


class AthenaTask(SQLTask[AthenaConfig]):
    """
    The simplest form of an Athena task; it can be used even for queries that do not produce any output.

    When ``output_schema_type`` is supplied, the task declares a single output named ``results`` of that type.
    """

    # Athena tasks are registered under the "presto" task type because the presto handler in the
    # backend is what executes them.
    _TASK_TYPE = "presto"

    def __init__(
        self,
        name: str,
        query_template: str,
        task_config: Optional[AthenaConfig] = None,
        inputs: Optional[Dict[str, Type]] = None,
        output_schema_type: Optional[Type[FlyteSchema]] = None,
        **kwargs,
    ):
        """
        Args:
            name: Name of this task, should be unique in the project
            query_template: The actual query to run. We use Flyte's Golang templating format for Query templating.
                Refer to the templating documentation
            task_config: AthenaConfig object; a default ``AthenaConfig()`` is used when omitted
            inputs: Name and type of inputs specified as an ordered dictionary
            output_schema_type: If some data is produced by this query, then you can specify the output schema type
            **kwargs: All other args required by Parent type - SQLTask
        """
        # Fall back to an all-defaults config so later reads of task_config never see None.
        effective_config = AthenaConfig() if task_config is None else task_config
        # Only declare the "results" output when the caller told us what schema the query yields.
        declared_outputs = None if output_schema_type is None else {"results": output_schema_type}

        super().__init__(
            name=name,
            task_config=effective_config,
            query_template=query_template,
            inputs=inputs,
            outputs=declared_outputs,
            task_type=self._TASK_TYPE,
            **kwargs,
        )
        self._output_schema_type = output_schema_type

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """
        Build the task's custom payload: a PrestoQuery message converted to a plain dict.

        The Athena database is sent as the Presto ``schema`` and the workgroup as the ``routing_group``;
        ``settings`` is accepted but not used here.
        """
        config = self.task_config
        # This task is executed using the presto handler in the backend, hence the PrestoQuery model.
        presto_query = PrestoQuery(
            routing_group=config.workgroup,
            catalog=config.catalog,
            schema=config.database,
            statement=self.query_template,
        )
        return MessageToDict(presto_query.to_flyte_idl())
34 changes: 34 additions & 0 deletions plugins/flytekit-athena/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from setuptools import setup

# Packaging script for the Athena flytekit plugin.

# Short plugin identifier; the distribution name and the package name below are derived from it.
PLUGIN_NAME = "athena"

# Distribution name handed to setup(): evaluates to "flytekitplugins-athena".
microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

# Runtime requirements: flytekit at 0.19.0 or newer, but below 1.0.0.
plugin_requires = ["flytekit>=0.19.0,<1.0.0"]

# Development placeholder version.
# NOTE(review): presumably overwritten by release tooling at publish time - confirm.
__version__ = "0.0.0+develop"

setup(
name=microlib_name,
version=__version__,
author="flyteorg",
author_email="[email protected]",
description="This package holds the Athena plugins for flytekit",
# The plugin is installed under the "flytekitplugins" namespace package.
namespace_packages=["flytekitplugins"],
# Only the single package "flytekitplugins.athena" is shipped.
packages=[f"flytekitplugins.{PLUGIN_NAME}"],
install_requires=plugin_requires,
license="apache2",
python_requires=">=3.7",
classifiers=[
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
)
1 change: 1 addition & 0 deletions plugins/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"flytekitplugins-pandera": "pandera",
"flytekitplugins-dolt": "flytekit-dolt",
"flytekitplugins-sqlalchemy": "flytekitplugins-sqlalchemy",
"flytekitplugins-athena": "flytekit-athena",
}


Expand Down
Empty file.
73 changes: 73 additions & 0 deletions plugins/tests/athena/test_athena.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from collections import OrderedDict

import pytest
from flytekitplugins.athena import AthenaConfig, AthenaTask

from flytekit import kwtypes, workflow
from flytekit.extend import Image, ImageConfig, SerializationSettings, get_serializable
from flytekit.types.schema import FlyteSchema


def test_serialization():
    """Serialize an AthenaTask and a workflow that calls it, then check the resulting specs."""
    query = """
insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
select *
from blah
where ds = '{{ .Inputs.ds }}'
"""
    config = AthenaConfig(database="mnist", catalog="my_catalog", workgroup="my_wg")
    athena_task = AthenaTask(
        name="flytekit.demo.athena_task.query",
        inputs=kwtypes(ds=str),
        task_config=config,
        query_template=query,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return athena_task(ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    # Task level: the custom payload must carry the query and every AthenaConfig field.
    serialized_task = get_serializable(OrderedDict(), settings, athena_task)
    custom = serialized_task.template.custom
    assert "{{ .rawOutputDataPrefix" in custom["statement"]
    assert "insert overwrite directory" in custom["statement"]
    assert custom["schema"] == "mnist"
    assert custom["catalog"] == "my_catalog"
    assert custom["routingGroup"] == "my_wg"
    task_interface = serialized_task.template.interface
    assert len(task_interface.inputs) == 1
    assert len(task_interface.outputs) == 1

    # Workflow level: the single workflow output is bound to the task node's "results" output.
    serialized_wf = get_serializable(OrderedDict(), settings, my_wf)
    assert serialized_wf.template.interface.outputs["o0"].type.schema is not None
    wf_output = serialized_wf.template.outputs[0]
    assert wf_output.var == "o0"
    assert wf_output.binding.promise.node_id == "n0"
    assert wf_output.binding.promise.var == "results"


def test_local_exec():
    """Build an AthenaTask without an explicit config and check that calling it locally raises."""
    query = """
insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
select *
from blah
where ds = '{{ .Inputs.ds }}'
"""
    athena_task = AthenaTask(
        name="flytekit.demo.athena_task.query2",
        inputs=kwtypes(ds=str),
        query_template=query,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    interface = athena_task.interface
    assert len(interface.inputs) == 1
    assert len(interface.outputs) == 1

    # will not run locally
    with pytest.raises(Exception):
        athena_task()

0 comments on commit 675fb1d

Please sign in to comment.