Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add athena plugin #504

Merged
merged 11 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flytekit/models/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class PrestoQuery(_common.FlyteIdlEntity):
def __init__(self, routing_group, catalog, schema, statement):
def __init__(self, routing_group=None, catalog=None, schema=None, statement=None):
"""
Initializes a new PrestoQuery.

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .task import AthenaConfig, AthenaTask
78 changes: 78 additions & 0 deletions plugins/flytekit-athena/flytekitplugins/athena/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type

from google.protobuf.json_format import MessageToDict

from flytekit.extend import SerializationSettings, SQLTask
from flytekit.models.presto import PrestoQuery
from flytekit.types.schema import FlyteSchema


@dataclass
class AthenaConfig(object):
"""
AthenaConfig should be used to configure a Athena Task.
"""

# The database to query against
database: Optional[str] = None
# The optional workgroup to separate query execution.
workgroup: Optional[str] = None
# The catalog to set for the given Presto query
catalog: Optional[str] = None


class AthenaTask(SQLTask[AthenaConfig]):
"""
This is the simplest form of a Athena Task, that can be used even for tasks that do not produce any output.
"""

# This task is executed using the presto handler in the backend.
_TASK_TYPE = "presto"

def __init__(
self,
name: str,
query_template: str,
task_config: Optional[AthenaConfig] = None,
inputs: Optional[Dict[str, Type]] = None,
output_schema_type: Optional[Type[FlyteSchema]] = None,
**kwargs,
):
"""
Args:
name: Name of this task, should be unique in the project
config: Type AthenaConfig object
inputs: Name and type of inputs specified as an ordered dictionary
query_template: The actual query to run. We use Flyte's Golang templating format for Query templating.
Refer to the templating documentation
output_schema_type: If some data is produced by this query, then you can specify the output schema type
**kwargs: All other args required by Parent type - SQLTask
"""
outputs = None
if output_schema_type is not None:
outputs = {
"results": output_schema_type,
}
if task_config is None:
task_config = AthenaConfig()
super().__init__(
name=name,
task_config=task_config,
query_template=query_template,
inputs=inputs,
outputs=outputs,
task_type=self._TASK_TYPE,
**kwargs,
)
self._output_schema_type = output_schema_type

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
# This task is executed using the presto handler in the backend.
job = PrestoQuery(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My concern is these settings @wild-endeavor. These are not standard, but I guess ok

statement=self.query_template,
schema=self.task_config.database,
routing_group=self.task_config.workgroup,
catalog=self.task_config.catalog,
)
return MessageToDict(job.to_flyte_idl())
34 changes: 34 additions & 0 deletions plugins/flytekit-athena/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from setuptools import setup

PLUGIN_NAME = "athena"

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=0.19.0,<1.0.0"]

__version__ = "0.0.0+develop"

setup(
name=microlib_name,
version=__version__,
author="flyteorg",
author_email="[email protected]",
description="This package holds the Athena plugins for flytekit",
namespace_packages=["flytekitplugins"],
packages=[f"flytekitplugins.{PLUGIN_NAME}"],
install_requires=plugin_requires,
license="apache2",
python_requires=">=3.7",
classifiers=[
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
)
1 change: 1 addition & 0 deletions plugins/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"flytekitplugins-pandera": "pandera",
"flytekitplugins-dolt": "flytekit-dolt",
"flytekitplugins-sqlalchemy": "flytekitplugins-sqlalchemy",
"flytekitplugins-athena": "flytekit-athena",
}


Expand Down
Empty file.
73 changes: 73 additions & 0 deletions plugins/tests/athena/test_athena.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from collections import OrderedDict

import pytest
from flytekitplugins.athena import AthenaConfig, AthenaTask

from flytekit import kwtypes, workflow
from flytekit.extend import Image, ImageConfig, SerializationSettings, get_serializable
from flytekit.types.schema import FlyteSchema


def test_serialization():
athena_task = AthenaTask(
name="flytekit.demo.athena_task.query",
inputs=kwtypes(ds=str),
task_config=AthenaConfig(database="mnist", catalog="my_catalog", workgroup="my_wg"),
query_template="""
insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
select *
from blah
where ds = '{{ .Inputs.ds }}'
""",
# the schema literal's backend uri will be equal to the value of .raw_output_data
output_schema_type=FlyteSchema,
)

@workflow
def my_wf(ds: str) -> FlyteSchema:
return athena_task(ds=ds)

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = SerializationSettings(
project="proj",
domain="dom",
version="123",
image_config=ImageConfig(default_image=default_img, images=[default_img]),
env={},
)
task_spec = get_serializable(OrderedDict(), serialization_settings, athena_task)
assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["statement"]
assert "insert overwrite directory" in task_spec.template.custom["statement"]
assert "mnist" == task_spec.template.custom["schema"]
assert "my_catalog" == task_spec.template.custom["catalog"]
assert "my_wg" == task_spec.template.custom["routingGroup"]
assert len(task_spec.template.interface.inputs) == 1
assert len(task_spec.template.interface.outputs) == 1

admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
assert admin_workflow_spec.template.outputs[0].var == "o0"
assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"


def test_local_exec():
athena_task = AthenaTask(
name="flytekit.demo.athena_task.query2",
inputs=kwtypes(ds=str),
query_template="""
insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
select *
from blah
where ds = '{{ .Inputs.ds }}'
""",
# the schema literal's backend uri will be equal to the value of .raw_output_data
output_schema_type=FlyteSchema,
)

assert len(athena_task.interface.inputs) == 1
assert len(athena_task.interface.outputs) == 1

# will not run locally
with pytest.raises(Exception):
athena_task()