Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add athena plugin #504

Merged
merged 11 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flytekit/models/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class PrestoQuery(_common.FlyteIdlEntity):
def __init__(self, routing_group, catalog, schema, statement):
def __init__(self, routing_group=None, catalog=None, schema=None, statement=None):
"""
Initializes a new PrestoQuery.
Expand Down
1 change: 1 addition & 0 deletions plugins/athena/flytekitplugins/athena/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .task import AthenaConfig, AthenaTask
68 changes: 68 additions & 0 deletions plugins/athena/flytekitplugins/athena/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type

from google.protobuf.json_format import MessageToDict

from flytekit.extend import SerializationSettings, SQLTask
from flytekit.models.presto import PrestoQuery
from flytekit.types.schema import FlyteSchema


@dataclass
class AthenaConfig(object):
"""
AthenaConfig should be used to configure a Athena Task. Currently there are no customizable options for the config.
"""

pass


class AthenaTask(SQLTask[AthenaConfig]):
"""
This is the simplest form of a Athena Task, that can be used even for tasks that do not produce any output.
"""

# This task is executed using the presto handler in the backend.
_TASK_TYPE = "presto"

def __init__(
self,
name: str,
query_template: str,
task_config: Optional[AthenaConfig] = None,
inputs: Optional[Dict[str, Type]] = None,
output_schema_type: Optional[Type[FlyteSchema]] = None,
**kwargs,
):
"""
Args:
name: Name of this task, should be unique in the project
config: Type AthenaConfig object
inputs: Name and type of inputs specified as an ordered dictionary
query_template: The actual query to run. We use Flyte's Golang templating format for Query templating.
Refer to the templating documentation
output_schema_type: If some data is produced by this query, then you can specify the output schema type
**kwargs: All other args required by Parent type - SQLTask
"""
outputs = None
if output_schema_type is not None:
outputs = {
"results": output_schema_type,
}
if task_config is None:
task_config = AthenaConfig()
super().__init__(
name=name,
task_config=task_config,
query_template=query_template,
inputs=inputs,
outputs=outputs,
task_type=self._TASK_TYPE,
**kwargs,
)
self._output_schema_type = output_schema_type

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
# This task is executed using the presto handler in the backend.
job = PrestoQuery(statement=self.query_template)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I stand corrected, I think it needs somethings set... they are just named differently... I think the minimum is the Database. We used Schema here to map it to...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep you're right! ty

return MessageToDict(job.to_flyte_idl())
34 changes: 34 additions & 0 deletions plugins/athena/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from setuptools import setup

PLUGIN_NAME = "athena"

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=0.18.0,<1.0.0"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bump


__version__ = "0.0.0+develop"

setup(
name=microlib_name,
version=__version__,
author="flyteorg",
author_email="[email protected]",
description="This package holds the Athena plugins for flytekit",
namespace_packages=["flytekitplugins"],
packages=[f"flytekitplugins.{PLUGIN_NAME}"],
install_requires=plugin_requires,
license="apache2",
python_requires=">=3.7",
classifiers=[
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
)
Empty file.
69 changes: 69 additions & 0 deletions plugins/tests/athena/test_athena.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from collections import OrderedDict

import pytest

from flytekit import kwtypes, workflow
from flytekit.extend import Image, ImageConfig, SerializationSettings, get_serializable
from flytekit.types.schema import FlyteSchema
from plugins.athena.flytekitplugins.athena.task import AthenaTask


def test_serialization():
athena_task = AthenaTask(
name="flytekit.demo.athena_task.query",
inputs=kwtypes(ds=str),
query_template="""
insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
select *
from blah
where ds = '{{ .Inputs.ds }}'
""",
# the schema literal's backend uri will be equal to the value of .raw_output_data
output_schema_type=FlyteSchema,
)

@workflow
def my_wf(ds: str) -> FlyteSchema:
return athena_task(ds=ds)

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = SerializationSettings(
project="proj",
domain="dom",
version="123",
image_config=ImageConfig(default_image=default_img, images=[default_img]),
env={},
)
task_spec = get_serializable(OrderedDict(), serialization_settings, athena_task)
assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["statement"]
assert "insert overwrite directory" in task_spec.template.custom["statement"]
assert len(task_spec.template.interface.inputs) == 1
assert len(task_spec.template.interface.outputs) == 1

admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
assert admin_workflow_spec.template.outputs[0].var == "o0"
assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"


def test_local_exec():
athena_task = AthenaTask(
name="flytekit.demo.athena_task.query2",
inputs=kwtypes(ds=str),
query_template="""
insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
select *
from blah
where ds = '{{ .Inputs.ds }}'
""",
# the schema literal's backend uri will be equal to the value of .raw_output_data
output_schema_type=FlyteSchema,
)

assert len(athena_task.interface.inputs) == 1
assert len(athena_task.interface.outputs) == 1

# will not run locally
with pytest.raises(Exception):
athena_task()