Generic Spark Integration #101

Merged: 7 commits, Apr 17, 2020
Changes from 2 commits
112 changes: 112 additions & 0 deletions flytekit/common/tasks/generic_spark_task.py
@@ -0,0 +1,112 @@
from __future__ import absolute_import

try:
    from inspect import getfullargspec as _getargspec
except ImportError:
    from inspect import getargspec as _getargspec

from flytekit import __version__
import sys as _sys
import six as _six
from flytekit.common.tasks import task as _base_tasks
from flytekit.models import literals as _literal_models, task as _task_models
from google.protobuf.json_format import MessageToDict as _MessageToDict
from flytekit.common import interface as _interface
from flytekit.models import interface as _interface_model
from flytekit.configuration import internal as _internal_config

class SdkGenericSparkTask(_base_tasks.SdkTask):
    """
    This class includes the additional logic for building a task that executes as a Spark Job.
    """

    def __init__(
            self,
            task_type,
            discovery_version,
            retries,
            interruptible,
            task_inputs,
            deprecated,
            discoverable,
            timeout,
            spark_type,
            main_class,
            main_application_file,
            spark_conf,
            hadoop_conf,
            environment,
    ):
        """
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Whether or not the task is interruptible
        :param dict task_inputs: Task inputs, mapping each input name to its SDK type (e.g. Types.Integer)
        :param Text deprecated: Deprecation message for this task, if any
        :param bool discoverable: Whether the task's outputs should be cached and discoverable
        :param datetime.timedelta timeout: Maximum amount of time the task may run before timing out
        :param Text spark_type: Type of Spark Job: Scala/Java
        :param Text main_class: Main class to execute for Scala/Java jobs
        :param Text main_application_file: Main application file
        :param dict[Text,Text] spark_conf: Key-value pairs of Spark configuration for the job
        :param dict[Text,Text] hadoop_conf: Key-value pairs of Hadoop configuration for the job
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """

        spark_job = _task_models.SparkJob(
            spark_conf=spark_conf,
            hadoop_conf=hadoop_conf,
            type=spark_type,
            application_file=main_application_file,
            main_class=main_class,
            executor_path=_sys.executable,
        ).to_flyte_idl()

        # No output support
        input_variables = {k: _interface_model.Variable(v.to_flyte_literal_type(), k) for k, v in _six.iteritems(task_inputs)}

        super(SdkGenericSparkTask, self).__init__(
            task_type,
            _task_models.TaskMetadata(
                discoverable,
                _task_models.RuntimeMetadata(
                    _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    __version__,
                    'spark'
                ),
                timeout,
                _literal_models.RetryStrategy(retries),
                interruptible,
                discovery_version,
                deprecated
            ),
            _interface.TypedInterface(input_variables, {}),
            _MessageToDict(spark_job),
            container=self._get_container_definition(
                task_inputs=task_inputs,
                environment=environment
            )
        )

    def _get_container_definition(
            self,
            task_inputs=None,
            environment=None,
    ):
        """
        :rtype: Container
        """

        args = []
        for k, v in _six.iteritems(task_inputs):
            args.append("--{}".format(k))
            args.append("{{{{.Inputs.{}}}}}".format(k))

        return _task_models.Container(
            image=_internal_config.IMAGE.get(),
            command=[],
            args=args,
            resources=_task_models.Resources([], []),
            env=environment,
            config={}
        )

Contributor (on _get_container_definition):
I think this is where you will have to check the types and raise an alert. You can make this a common method for now; when we add support for all types, we can remove this check.

Contributor Author (@akhurana001, Apr 13, 2020):
Done, added it in the validate_inputs above.
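To make the argument templating above concrete, here is a minimal standalone sketch, assuming a single input named partitions as in the Scala example workflow later in this PR. Only the dict keys matter when building the arguments; the {{.Inputs.*}} placeholders are filled in by Flyte with the actual input values at execution time, and the real method additionally wraps the args in a _task_models.Container with the configured image.

# Standalone sketch of the arg construction performed by _get_container_definition.
task_inputs = {"partitions": None}  # only the keys are used when building args
args = []
for k in task_inputs:
    args.append("--{}".format(k))
    args.append("{{{{.Inputs.{}}}}}".format(k))
print(args)  # ['--partitions', '{{.Inputs.partitions}}']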
2 changes: 2 additions & 0 deletions flytekit/common/tasks/spark_task.py
@@ -88,6 +88,8 @@ def __init__(
            hadoop_conf=hadoop_conf,
            application_file="local://" + spark_exec_path,
            executor_path=_sys.executable,
            main_class="",
            type="PYTHON",
        ).to_flyte_idl()
        super(SdkSparkTask, self).__init__(
            task_function,
44 changes: 42 additions & 2 deletions flytekit/models/task.py
@@ -541,10 +541,9 @@ def from_flyte_idl(cls, pb2_object):
            template=TaskTemplate.from_flyte_idl(pb2_object.template)
        )


class SparkJob(_common.FlyteIdlEntity):

-    def __init__(self, application_file, spark_conf, hadoop_conf, executor_path):
+    def __init__(self, type, application_file, main_class, spark_conf, hadoop_conf, executor_path):
        """
        This defines a SparkJob target. It will execute the appropriate SparkJob.

@@ -553,10 +552,28 @@ def __init__(self, application_file, spark_conf, hadoop_conf, executor_path):
        :param dict[Text, Text] hadoop_conf: A definition of key-value pairs for hadoop config for the job.
        """
        self._application_file = application_file
        self._type = type
        self._main_class = main_class
        self._executor_path = executor_path
        self._spark_conf = spark_conf
        self._hadoop_conf = hadoop_conf

    @property
    def main_class(self):
        """
        The main class to execute
        :rtype: Text
        """
        return self._main_class

    @property
    def type(self):
        """
        Spark Job Type
        :rtype: Text
        """
        return self._type

    @property
    def application_file(self):
        """
@@ -593,8 +610,20 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins.spark_pb2.SparkJob
"""

# Default to Python
application_type = _spark_task.SparkApplication.PYTHON
if self.type == "SCALA":
application_type = _spark_task.SparkApplication.SCALA
elif self.type == "JAVA":
application_type = _spark_task.SparkApplication.JAVA
elif self.type == "R":
application_type = _spark_task.SparkApplication.R

return _spark_task.SparkJob(
applicationType=application_type,
mainApplicationFile=self.application_file,
mainClass=self.main_class,
executorPath=self.executor_path,
sparkConf=self.spark_conf,
hadoopConf=self.hadoop_conf,
@@ -606,9 +635,20 @@ def from_flyte_idl(cls, pb2_object):
        :param flyteidl.plugins.spark_pb2.SparkJob pb2_object:
        :rtype: SparkJob
        """
        # Default to Python
        type = "PYTHON"
        if pb2_object.applicationType == _spark_task.SparkApplication.SCALA:
            type = "SCALA"
        elif pb2_object.applicationType == _spark_task.SparkApplication.JAVA:
            type = "JAVA"
        elif pb2_object.applicationType == _spark_task.SparkApplication.R:
            type = "R"

        return cls(
            type=type,
            spark_conf=pb2_object.sparkConf,
            application_file=pb2_object.mainApplicationFile,
            main_class=pb2_object.mainClass,
            hadoop_conf=pb2_object.hadoopConf,
            executor_path=pb2_object.executorPath,
        )
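To make the new string-to-enum mapping concrete, here is a small round-trip sketch built only from the SparkJob constructor and the to_flyte_idl / from_flyte_idl methods shown above; the configuration values are illustrative only.

# Round-trip sketch for the SparkJob model above; the values are illustrative.
from flytekit.models.task import SparkJob

job = SparkJob(
    type="SCALA",
    application_file="local:///opt/spark/examples/jars/spark-examples.jar",
    main_class="org.apache.spark.examples.SparkPi",
    spark_conf={"spark.executor.instances": "2"},
    hadoop_conf={},
    executor_path="/usr/bin/python",
)
idl = job.to_flyte_idl()  # applicationType is set to SparkApplication.SCALA
assert SparkJob.from_flyte_idl(idl).type == "SCALA"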
26 changes: 24 additions & 2 deletions flytekit/sdk/tasks.py
@@ -6,7 +6,7 @@
from flytekit.common import constants as _common_constants
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.tasks import sdk_runnable as _sdk_runnable_tasks, sdk_dynamic as _sdk_dynamic, \
-    spark_task as _sdk_spark_tasks, hive_task as _sdk_hive_tasks, sidecar_task as _sdk_sidecar_tasks
+    spark_task as _sdk_spark_tasks, generic_spark_task as _sdk_generic_spark_task, hive_task as _sdk_hive_tasks, sidecar_task as _sdk_sidecar_tasks
from flytekit.common.tasks import task as _task
from flytekit.common.types import helpers as _type_helpers
from flytekit.models import interface as _interface_model
@@ -406,12 +406,16 @@ def spark_task(
        cache_version='',
        retries=0,
        interruptible=None,
        inputs=None,

Contributor (@kumare3, Apr 12, 2020):
Is this handled differently compared to the @inputs decorator? If so, we should document that correctly, as this is very confusing. Also, at the moment I think only a few input types are supported, like primitives and blobs, right? If so, can we verify this statically and raise an exception?

Contributor Author:
Moved to use the inputs decorator. Added the check for primitives for now.

Collaborator:
I wouldn't mix the decorator with the generic usage. The decorator patterns are part of the SDK's basic Python programming model, so the assumption is that they will always be Python. It's easy enough to expose a helper function or class for Scala Spark jobs elsewhere.

Collaborator:
(If you want to make it available through this file, that's fine too. I just think mixing a decorator and common object definitions will eventually cause a problem.)

Collaborator:
Continuing Ketan's thought, I would also just avoid adding the inputs arg that way altogether, especially since, as your test below shows, you use the inputs decorator as a helper function, which is a bit confusing. Better to use a new interface and just a normal {'a': Types.Integer} style annotation.

Contributor Author:
I originally wasn't using the decorator but changed it after Ketan's comment, to be in sync with how we do this for Presto as well: https://github.com/lyft/flytekit/blob/master/tests/flytekit/common/workflows/presto.py#L11

Contributor Author:
Updated to stop using the existing spark_task decorator. I do think we should make the generic_spark_task available from this file as well. Added it separately as a helper function.

        deprecated='',
        cache=False,
        timeout=None,
        spark_conf=None,
        hadoop_conf=None,
        environment=None,
        spark_type=None,
        main_class=None,
        main_application_file=None,
        cls=None
):
    """
@@ -485,7 +489,25 @@ def wrapper(fn):
    if _task_function:
        return wrapper(_task_function)
    else:
        if spark_type is None or spark_type == "PYTHON":
            return wrapper
        else:
            return _sdk_generic_spark_task.SdkGenericSparkTask(
                task_type=_common_constants.SdkTaskType.SPARK_TASK,
                discovery_version=cache_version,
                retries=retries,
                interruptible=interruptible,
                deprecated=deprecated,
                discoverable=cache,
                timeout=timeout or _datetime.timedelta(seconds=0),
                spark_type=spark_type,
                task_inputs=inputs or {},
                main_class=main_class or "",
                main_application_file=main_application_file or "",
                spark_conf=spark_conf or {},
                hadoop_conf=hadoop_conf or {},
                environment=environment or {},
            )


def qubole_spark_task(*args, **kwargs):
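The primitive-only input check referenced in the review thread above (validate_inputs) is not part of the two commits shown in this diff. Purely as a hypothetical sketch of what such a check could look like, assuming each SDK type exposes to_flyte_literal_type() (as used in SdkGenericSparkTask) and that primitive types populate LiteralType.simple:

# Hypothetical sketch only; this is not the validate_inputs that was merged.
import six as _six

from flytekit.common.exceptions import user as _user_exceptions


def validate_inputs(task_inputs):
    for name, sdk_type in _six.iteritems(task_inputs):
        if sdk_type.to_flyte_literal_type().simple is None:
            raise _user_exceptions.FlyteValidationException(
                "Input '{}' is of type {}; only primitive inputs are currently "
                "supported for generic Spark tasks.".format(name, sdk_type)
            )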
40 changes: 40 additions & 0 deletions tests/flytekit/common/workflows/scala_spark.py
@@ -0,0 +1,40 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from flytekit.sdk.tasks import spark_task, inputs, python_task
from flytekit.sdk.types import Types
from flytekit.sdk.workflow import workflow_class, Input


scala_spark = spark_task(spark_type="SCALA",
                         inputs={"partitions": Types.Integer},
                         main_class="org.apache.spark.examples.SparkPi",
                         main_application_file="local:///opt/spark/examples/jars/spark-examples.jar",
                         spark_conf={
                             'spark.driver.memory': "1000M",
                             'spark.executor.memory': "1000M",
                             'spark.executor.cores': '1',
                             'spark.executor.instances': '2',
                         },
                         cache_version='1'
                         )

Contributor (on spark_type):
I think we should do one of two things: make spark_type an enum, which it seems it already is in the proto definition, or make special task wrappers like scala_spark etc. which fix the type.

Contributor Author:
Made an enum.


@inputs(date_triggered=Types.Datetime)
@python_task(cache_version='1')
def print_every_time(workflow_parameters, date_triggered):
    print("My input : {}".format(date_triggered))


@workflow_class
class SparkTasksWorkflow(object):
    triggered_date = Input(Types.Datetime)
    partitions = Input(Types.Integer)
    sparkTask = scala_spark(partitions=partitions)
    print_always = print_every_time(
        date_triggered=triggered_date)

Collaborator (on sparkTask): nit: no camel case
Contributor Author: done


if __name__ == '__main__':
    print(scala_spark)
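The spark_type enum referred to in the review thread above ("made an enum") is added in a later commit and is not part of the two commits shown here. Purely as a hypothetical sketch, it could be as simple as a set of string constants mirroring the SparkApplication values in flyteidl:

# Hypothetical sketch only; the enum actually added later in this PR is not shown
# in this diff.  One plausible shape, mirroring flyteidl's SparkApplication values:
class SparkType(object):
    PYTHON = "PYTHON"
    SCALA = "SCALA"
    JAVA = "JAVA"
    R = "R"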