Generic Spark Integration (#101)
* Generic Spark Integration

* Image

* PR comments

* PR comments

* Separate task-type for generic-spark

* PR comments
akhurana001 authored Apr 17, 2020
1 parent 24ceff1 commit 3d6437b
Showing 7 changed files with 288 additions and 4 deletions.
2 changes: 1 addition & 1 deletion flytekit/__init__.py
@@ -1,4 +1,4 @@
from __future__ import absolute_import
import flytekit.plugins

__version__ = '0.7.0b3'
__version__ = '0.7.0b4'
150 changes: 150 additions & 0 deletions flytekit/common/tasks/generic_spark_task.py
@@ -0,0 +1,150 @@
from __future__ import absolute_import

try:
    from inspect import getfullargspec as _getargspec
except ImportError:
    from inspect import getargspec as _getargspec

from flytekit import __version__
import sys as _sys
import six as _six
from flytekit.common.tasks import task as _base_tasks
from flytekit.common.types import helpers as _helpers, primitives as _primitives

from flytekit.models import literals as _literal_models, task as _task_models
from google.protobuf.json_format import MessageToDict as _MessageToDict
from flytekit.common import interface as _interface
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.exceptions import scopes as _exception_scopes

from flytekit.configuration import internal as _internal_config

input_types_supported = {
    _primitives.Integer,
    _primitives.Boolean,
    _primitives.Float,
    _primitives.String,
    _primitives.Datetime,
    _primitives.Timedelta,
}


class SdkGenericSparkTask(_base_tasks.SdkTask):
    """
    This class includes the additional logic for building a task that executes as a Spark Job.
    """
    def __init__(
            self,
            task_type,
            discovery_version,
            retries,
            interruptible,
            task_inputs,
            deprecated,
            discoverable,
            timeout,
            spark_type,
            main_class,
            main_application_file,
            spark_conf,
            hadoop_conf,
            environment,
    ):
        """
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Whether or not the task is interruptible
        :param Text deprecated:
        :param bool discoverable:
        :param datetime.timedelta timeout:
        :param Text spark_type: Type of Spark Job: Scala/Java
        :param Text main_class: Main class to execute for Scala/Java jobs
        :param Text main_application_file: Main application file
        :param dict[Text,Text] spark_conf:
        :param dict[Text,Text] hadoop_conf:
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """

        spark_job = _task_models.SparkJob(
            spark_conf=spark_conf,
            hadoop_conf=hadoop_conf,
            spark_type=spark_type,
            application_file=main_application_file,
            main_class=main_class,
            executor_path=_sys.executable,
        ).to_flyte_idl()

        super(SdkGenericSparkTask, self).__init__(
            task_type,
            _task_models.TaskMetadata(
                discoverable,
                _task_models.RuntimeMetadata(
                    _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    __version__,
                    'spark'
                ),
                timeout,
                _literal_models.RetryStrategy(retries),
                interruptible,
                discovery_version,
                deprecated
            ),
            _interface.TypedInterface({}, {}),
            _MessageToDict(spark_job),
        )

        # Add inputs, if any were declared for this task.
        if task_inputs is not None:
            task_inputs(self)

        # Set the container definition only after the inputs have been updated.
        self._container = self._get_container_definition(
            environment=environment
        )

    def _validate_inputs(self, inputs):
        """
        :param dict[Text, flytekit.models.interface.Variable] inputs: Input variables to validate
        :raises: flytekit.common.exceptions.user.FlyteValidationException
        """
        for k, v in _six.iteritems(inputs):
            sdk_type = _helpers.get_sdk_type_from_literal_type(v.type)
            if sdk_type not in input_types_supported:
                raise _user_exceptions.FlyteValidationException(
                    "Input Type '{}' not supported. Only Primitives are supported for Scala/Java Spark.".format(sdk_type)
                )
        super(SdkGenericSparkTask, self)._validate_inputs(inputs)

    @_exception_scopes.system_entry_point
    def add_inputs(self, inputs):
        """
        Adds the inputs to this task. This can be called multiple times, but it will fail if an input with a given
        name is added more than once, a name collides with an output, or if the name doesn't exist as an arg name in
        the wrapped function.
        :param dict[Text, flytekit.models.interface.Variable] inputs: names and variables
        """
        self._validate_inputs(inputs)
        self.interface.inputs.update(inputs)

    def _get_container_definition(
            self,
            environment=None,
    ):
        """
        :rtype: Container
        """

        args = []
        for k, v in _six.iteritems(self.interface.inputs):
            args.append("--{}".format(k))
            args.append("{{{{.Inputs.{}}}}}".format(k))

        return _task_models.Container(
            image=_internal_config.IMAGE.get(),
            command=[],
            args=args,
            resources=_task_models.Resources([], []),
            env=environment,
            config={}
        )
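
For context, the arguments that _get_container_definition builds are plain Flyte input templates. The standalone sketch below (not part of the commit) mirrors the templating loop above; the input names "partitions" and "date" are hypothetical, and at execution time Flyte substitutes the {{.Inputs.*}} placeholders with the actual input values.

def render_args(input_names):
    # Mirrors the loop in _get_container_definition: each declared input becomes a
    # "--<name> {{.Inputs.<name>}}" pair on the container command line.
    args = []
    for name in input_names:
        args.append("--{}".format(name))
        args.append("{{{{.Inputs.{}}}}}".format(name))
    return args


print(render_args(["partitions", "date"]))
# ['--partitions', '{{.Inputs.partitions}}', '--date', '{{.Inputs.date}}']
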
3 changes: 3 additions & 0 deletions flytekit/common/tasks/spark_task.py
Original file line number Diff line number Diff line change
@@ -61,6 +61,7 @@ def __init__(
            deprecated,
            discoverable,
            timeout,
            spark_type,
            spark_conf,
            hadoop_conf,
            environment,
@@ -88,6 +89,8 @@ def __init__(
            hadoop_conf=hadoop_conf,
            application_file="local://" + spark_exec_path,
            executor_path=_sys.executable,
            main_class="",
            spark_type=spark_type,
        ).to_flyte_idl()
        super(SdkSparkTask, self).__init__(
            task_function,
48 changes: 46 additions & 2 deletions flytekit/models/task.py
@@ -8,9 +8,10 @@
from flyteidl.plugins import spark_pb2 as _spark_task
from flytekit.plugins import flyteidl as _lazy_flyteidl
from google.protobuf import json_format as _json_format, struct_pb2 as _struct

from flytekit.sdk.spark_types import SparkType as _spark_type
from flytekit.models import common as _common, literals as _literals, interface as _interface
from flytekit.models.core import identifier as _identifier
from flytekit.common.exceptions import user as _user_exceptions


class Resources(_common.FlyteIdlEntity):
@@ -544,7 +545,7 @@ def from_flyte_idl(cls, pb2_object):

class SparkJob(_common.FlyteIdlEntity):

    def __init__(self, application_file, spark_conf, hadoop_conf, executor_path):
    def __init__(self, spark_type, application_file, main_class, spark_conf, hadoop_conf, executor_path):
        """
        This defines a SparkJob target. It will execute the appropriate SparkJob.
@@ -553,10 +554,28 @@ def __init__(self, application_file, spark_conf, hadoop_conf, executor_path):
        :param dict[Text, Text] hadoop_conf: A definition of key-value pairs for hadoop config for the job.
        """
        self._application_file = application_file
        self._spark_type = spark_type
        self._main_class = main_class
        self._executor_path = executor_path
        self._spark_conf = spark_conf
        self._hadoop_conf = hadoop_conf

    @property
    def main_class(self):
        """
        The main class to execute
        :rtype: Text
        """
        return self._main_class

    @property
    def spark_type(self):
        """
        Spark Job Type
        :rtype: Text
        """
        return self._spark_type

    @property
    def application_file(self):
        """
@@ -593,8 +612,22 @@ def to_flyte_idl(self):
        """
        :rtype: flyteidl.plugins.spark_pb2.SparkJob
        """

        if self.spark_type == _spark_type.PYTHON:
            application_type = _spark_task.SparkApplication.PYTHON
        elif self.spark_type == _spark_type.JAVA:
            application_type = _spark_task.SparkApplication.JAVA
        elif self.spark_type == _spark_type.SCALA:
            application_type = _spark_task.SparkApplication.SCALA
        elif self.spark_type == _spark_type.R:
            application_type = _spark_task.SparkApplication.R
        else:
            raise _user_exceptions.FlyteValidationException("Invalid Spark Application Type Specified")

        return _spark_task.SparkJob(
            applicationType=application_type,
            mainApplicationFile=self.application_file,
            mainClass=self.main_class,
            executorPath=self.executor_path,
            sparkConf=self.spark_conf,
            hadoopConf=self.hadoop_conf,
@@ -606,9 +639,20 @@ def from_flyte_idl(cls, pb2_object):
        """
        :param flyteidl.plugins.spark_pb2.SparkJob pb2_object:
        :rtype: SparkJob
        """

        application_type = _spark_type.PYTHON
        if pb2_object.applicationType == _spark_task.SparkApplication.JAVA:
            application_type = _spark_type.JAVA
        elif pb2_object.applicationType == _spark_task.SparkApplication.SCALA:
            application_type = _spark_type.SCALA
        elif pb2_object.applicationType == _spark_task.SparkApplication.R:
            application_type = _spark_type.R

        return cls(
            spark_type=application_type,
            spark_conf=pb2_object.sparkConf,
            application_file=pb2_object.mainApplicationFile,
            main_class=pb2_object.mainClass,
            hadoop_conf=pb2_object.hadoopConf,
            executor_path=pb2_object.executorPath,
        )
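
As a quick illustration of the new fields (a sketch, not part of this commit), constructing the SparkJob model and serializing it shows how spark_type maps onto the protobuf applicationType; the jar path, class name, conf values, and executor path below are placeholders.

from flytekit.models.task import SparkJob
from flytekit.sdk.spark_types import SparkType

job = SparkJob(
    spark_type=SparkType.SCALA,
    application_file="local:///opt/spark/examples/jars/spark-examples.jar",  # placeholder jar
    main_class="org.apache.spark.examples.SparkPi",  # placeholder driver class
    spark_conf={"spark.executor.instances": "2"},
    hadoop_conf={},
    executor_path="/usr/bin/python",  # placeholder executor path
)

idl = job.to_flyte_idl()
# idl.applicationType is SparkApplication.SCALA; mainClass and mainApplicationFile
# carry the values passed above.
print(idl)
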
8 changes: 8 additions & 0 deletions flytekit/sdk/spark_types.py
@@ -0,0 +1,8 @@
import enum


class SparkType(enum.Enum):
    PYTHON = 1
    SCALA = 2
    JAVA = 3
    R = 4
43 changes: 42 additions & 1 deletion flytekit/sdk/tasks.py
@@ -6,9 +6,10 @@
from flytekit.common import constants as _common_constants
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.tasks import sdk_runnable as _sdk_runnable_tasks, sdk_dynamic as _sdk_dynamic, \
    spark_task as _sdk_spark_tasks, hive_task as _sdk_hive_tasks, sidecar_task as _sdk_sidecar_tasks
    spark_task as _sdk_spark_tasks, generic_spark_task as _sdk_generic_spark_task, hive_task as _sdk_hive_tasks, sidecar_task as _sdk_sidecar_tasks
from flytekit.common.tasks import task as _task
from flytekit.common.types import helpers as _type_helpers
from flytekit.sdk.spark_types import SparkType as _spark_type
from flytekit.models import interface as _interface_model


@@ -474,6 +475,7 @@ def wrapper(fn):
            discovery_version=cache_version,
            retries=retries,
            interruptible=interruptible,
            spark_type=_spark_type.PYTHON,
            deprecated=deprecated,
            discoverable=cache,
            timeout=timeout or _datetime.timedelta(seconds=0),
@@ -488,6 +490,45 @@ def wrapper(fn):
    return wrapper


def generic_spark_task(
        spark_type,
        main_class,
        main_application_file,
        cache_version='',
        retries=0,
        interruptible=None,
        inputs=None,
        deprecated='',
        cache=False,
        timeout=None,
        spark_conf=None,
        hadoop_conf=None,
        environment=None,
):
    """
    Create a generic spark task. This task will connect to a Spark cluster, configure the environment,
    and then execute the mainClass code as the Spark driver program.
    """

    return _sdk_generic_spark_task.SdkGenericSparkTask(
        task_type=_common_constants.SdkTaskType.SPARK_TASK,
        discovery_version=cache_version,
        retries=retries,
        interruptible=interruptible,
        deprecated=deprecated,
        discoverable=cache,
        timeout=timeout or _datetime.timedelta(seconds=0),
        spark_type=spark_type,
        task_inputs=inputs,
        main_class=main_class or "",
        main_application_file=main_application_file or "",
        spark_conf=spark_conf or {},
        hadoop_conf=hadoop_conf or {},
        environment=environment or {},
    )


def qubole_spark_task(*args, **kwargs):
    """
    :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask
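
A hedged usage sketch for the new generic_spark_task helper with a Java job; the driver class, jar path, and Spark conf below are placeholders rather than values from this commit. The Scala test workflow added below shows the same pattern end-to-end.

from flytekit.sdk.spark_types import SparkType
from flytekit.sdk.tasks import generic_spark_task

java_spark = generic_spark_task(
    spark_type=SparkType.JAVA,
    main_class="com.example.MySparkJob",  # placeholder driver class
    main_application_file="local:///opt/jobs/my-spark-job.jar",  # placeholder jar
    spark_conf={"spark.executor.instances": "2"},
    cache_version='1',
)
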
38 changes: 38 additions & 0 deletions tests/flytekit/common/workflows/scala_spark.py
@@ -0,0 +1,38 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from flytekit.sdk.tasks import generic_spark_task, inputs, python_task
from flytekit.sdk.types import Types
from flytekit.sdk.spark_types import SparkType
from flytekit.sdk.workflow import workflow_class, Input


scala_spark = generic_spark_task(
    spark_type=SparkType.SCALA,
    inputs=inputs(partitions=Types.Integer),
    main_class="org.apache.spark.examples.SparkPi",
    main_application_file="local:///opt/spark/examples/jars/spark-examples.jar",
    spark_conf={
        'spark.driver.memory': "1000M",
        'spark.executor.memory': "1000M",
        'spark.executor.cores': '1',
        'spark.executor.instances': '2',
    },
    cache_version='1'
)


@inputs(date_triggered=Types.Datetime)
@python_task(cache_version='1')
def print_every_time(workflow_parameters, date_triggered):
    print("My input : {}".format(date_triggered))


@workflow_class
class SparkTasksWorkflow(object):
    triggered_date = Input(Types.Datetime)
    partitions = Input(Types.Integer)
    spark_task = scala_spark(partitions=partitions)
    print_always = print_every_time(
        date_triggered=triggered_date)
