Generic Spark Integration #101

Merged: 7 commits, Apr 17, 2020
2 changes: 1 addition & 1 deletion flytekit/__init__.py
@@ -1,4 +1,4 @@
from __future__ import absolute_import
import flytekit.plugins

-__version__ = '0.7.0b3'
+__version__ = '0.7.0b4'
150 changes: 150 additions & 0 deletions flytekit/common/tasks/generic_spark_task.py
@@ -0,0 +1,150 @@
from __future__ import absolute_import

try:
    from inspect import getfullargspec as _getargspec
except ImportError:
    from inspect import getargspec as _getargspec

from flytekit import __version__
import sys as _sys
import six as _six
from flytekit.common.tasks import task as _base_tasks
from flytekit.common.types import helpers as _helpers, primitives as _primitives

from flytekit.models import literals as _literal_models, task as _task_models
from google.protobuf.json_format import MessageToDict as _MessageToDict
from flytekit.common import interface as _interface
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.exceptions import scopes as _exception_scopes

from flytekit.configuration import internal as _internal_config

input_types_supported = {
    _primitives.Integer,
    _primitives.Boolean,
    _primitives.Float,
    _primitives.String,
    _primitives.Datetime,
    _primitives.Timedelta,
}


class SdkGenericSparkTask(_base_tasks.SdkTask):
    """
    This class includes the additional logic for building a task that executes as a Spark Job.
    """

    def __init__(
        self,
        task_type,
        discovery_version,
        retries,
        interruptible,
        task_inputs,
        deprecated,
        discoverable,
        timeout,
        spark_type,
        main_class,
        main_application_file,
        spark_conf,
        hadoop_conf,
        environment,
    ):
        """
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: number of retries to attempt
        :param bool interruptible: whether or not the task is interruptible
        :param task_inputs: the inputs declaration (e.g. the object returned by flytekit.sdk.tasks.inputs) used to register this task's inputs
        :param Text deprecated: deprecation message, if any
        :param bool discoverable: whether the task's outputs can be cached and discovered
        :param datetime.timedelta timeout: maximum duration of a single execution attempt
        :param flytekit.sdk.spark_types.SparkType spark_type: type of Spark job (SCALA or JAVA)
        :param Text main_class: main class to execute for Scala/Java jobs
        :param Text main_application_file: main application file
        :param dict[Text,Text] spark_conf: key-value pairs of Spark configuration for the job
        :param dict[Text,Text] hadoop_conf: key-value pairs of Hadoop configuration for the job
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """

        spark_job = _task_models.SparkJob(
            spark_conf=spark_conf,
            hadoop_conf=hadoop_conf,
            spark_type=spark_type,
            application_file=main_application_file,
            main_class=main_class,
            executor_path=_sys.executable,
        ).to_flyte_idl()

        super(SdkGenericSparkTask, self).__init__(
            task_type,
            _task_models.TaskMetadata(
                discoverable,
                _task_models.RuntimeMetadata(
                    _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    __version__,
                    'spark'
                ),
                timeout,
                _literal_models.RetryStrategy(retries),
                interruptible,
                discovery_version,
                deprecated
            ),
            _interface.TypedInterface({}, {}),
            _MessageToDict(spark_job),
        )

        # Add Inputs
        if task_inputs is not None:
            task_inputs(self)

        # Container after the Inputs have been updated.
        self._container = self._get_container_definition(
            environment=environment
        )

    def _validate_inputs(self, inputs):
        """
        :param dict[Text, flytekit.models.interface.Variable] inputs: Input variables to validate
        :raises: flytekit.common.exceptions.user.FlyteValidationException
        """
        for k, v in _six.iteritems(inputs):
            sdk_type = _helpers.get_sdk_type_from_literal_type(v.type)
            if sdk_type not in input_types_supported:
                raise _user_exceptions.FlyteValidationException(
                    "Input Type '{}' not supported. Only Primitives are supported for Scala/Java Spark.".format(sdk_type)
                )
        super(SdkGenericSparkTask, self)._validate_inputs(inputs)

    @_exception_scopes.system_entry_point
    def add_inputs(self, inputs):
        """
        Adds the inputs to this task. This can be called multiple times, but it will fail if an input with a given
        name is added more than once or if a name collides with an output.
        :param dict[Text, flytekit.models.interface.Variable] inputs: names and variables
        """
        self._validate_inputs(inputs)
        self.interface.inputs.update(inputs)

Review comment (Contributor): I think this is where you will have to check the types and raise an alert. You can make this a common method for now; when we add support for all types, we can remove this check.

Reply (@akhurana001, Contributor Author, Apr 13, 2020): done, added it in the validate_inputs above.

    def _get_container_definition(
        self,
        environment=None,
    ):
        """
        :rtype: Container
        """

        args = []
        for k, v in _six.iteritems(self.interface.inputs):
            args.append("--{}".format(k))
            args.append("{{{{.Inputs.{}}}}}".format(k))

        return _task_models.Container(
            image=_internal_config.IMAGE.get(),
            command=[],
            args=args,
            resources=_task_models.Resources([], []),
            env=environment,
            config={}
        )
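For illustration only, here is a minimal sketch (not part of this PR) of what the argument-building loop in _get_container_definition produces for a task with a single input named partitions; the {{.Inputs.partitions}} placeholder is substituted by the Flyte engine when the Spark driver is launched.

# Hypothetical sketch; the input name "partitions" is illustrative.
args = []
for name in ["partitions"]:
    args.append("--{}".format(name))
    args.append("{{{{.Inputs.{}}}}}".format(name))
print(args)  # ['--partitions', '{{.Inputs.partitions}}']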
3 changes: 3 additions & 0 deletions flytekit/common/tasks/spark_task.py
@@ -61,6 +61,7 @@ def __init__(
        deprecated,
        discoverable,
        timeout,
        spark_type,
        spark_conf,
        hadoop_conf,
        environment,
@@ -88,6 +89,8 @@ def __init__(
            hadoop_conf=hadoop_conf,
            application_file="local://" + spark_exec_path,
            executor_path=_sys.executable,
            main_class="",
            spark_type=spark_type,
        ).to_flyte_idl()
        super(SdkSparkTask, self).__init__(
            task_function,
48 changes: 46 additions & 2 deletions flytekit/models/task.py
@@ -8,9 +8,10 @@
from flyteidl.plugins import spark_pb2 as _spark_task
from flytekit.plugins import flyteidl as _lazy_flyteidl
from google.protobuf import json_format as _json_format, struct_pb2 as _struct

from flytekit.sdk.spark_types import SparkType as _spark_type
from flytekit.models import common as _common, literals as _literals, interface as _interface
from flytekit.models.core import identifier as _identifier
from flytekit.common.exceptions import user as _user_exceptions


class Resources(_common.FlyteIdlEntity):
@@ -544,7 +545,7 @@ def from_flyte_idl(cls, pb2_object):

class SparkJob(_common.FlyteIdlEntity):

-    def __init__(self, application_file, spark_conf, hadoop_conf, executor_path):
+    def __init__(self, spark_type, application_file, main_class, spark_conf, hadoop_conf, executor_path):
"""
This defines a SparkJob target. It will execute the appropriate SparkJob.

Expand All @@ -553,10 +554,28 @@ def __init__(self, application_file, spark_conf, hadoop_conf, executor_path):
:param dict[Text, Text] hadoop_conf: A definition of key-value pairs for hadoop config for the job.
"""
        self._application_file = application_file
        self._spark_type = spark_type
        self._main_class = main_class
        self._executor_path = executor_path
        self._spark_conf = spark_conf
        self._hadoop_conf = hadoop_conf

    @property
    def main_class(self):
        """
        The main class to execute
        :rtype: Text
        """
        return self._main_class

    @property
    def spark_type(self):
        """
        Spark Job type
        :rtype: flytekit.sdk.spark_types.SparkType
        """
        return self._spark_type

    @property
    def application_file(self):
        """
@@ -593,8 +612,22 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins.spark_pb2.SparkJob
"""

        if self.spark_type == _spark_type.PYTHON:
            application_type = _spark_task.SparkApplication.PYTHON
        elif self.spark_type == _spark_type.JAVA:
            application_type = _spark_task.SparkApplication.JAVA
        elif self.spark_type == _spark_type.SCALA:
            application_type = _spark_task.SparkApplication.SCALA
        elif self.spark_type == _spark_type.R:
            application_type = _spark_task.SparkApplication.R
        else:
            raise _user_exceptions.FlyteValidationException("Invalid Spark Application Type Specified")

        return _spark_task.SparkJob(
            applicationType=application_type,
            mainApplicationFile=self.application_file,
            mainClass=self.main_class,
            executorPath=self.executor_path,
            sparkConf=self.spark_conf,
            hadoopConf=self.hadoop_conf,
@@ -606,9 +639,20 @@ def from_flyte_idl(cls, pb2_object):
        :param flyteidl.plugins.spark_pb2.SparkJob pb2_object:
        :rtype: SparkJob
        """

        application_type = _spark_type.PYTHON
        if pb2_object.applicationType == _spark_task.SparkApplication.JAVA:
            application_type = _spark_type.JAVA
        elif pb2_object.applicationType == _spark_task.SparkApplication.SCALA:
            application_type = _spark_type.SCALA
        elif pb2_object.applicationType == _spark_task.SparkApplication.R:
            application_type = _spark_type.R

        return cls(
            spark_type=application_type,
            spark_conf=pb2_object.sparkConf,
            application_file=pb2_object.mainApplicationFile,
            main_class=pb2_object.mainClass,
            hadoop_conf=pb2_object.hadoopConf,
            executor_path=pb2_object.executorPath,
        )
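As a rough usage sketch (not part of the diff) of the SparkJob model changes above, the following shows the enum-to-IDL mapping round-tripping through to_flyte_idl and from_flyte_idl; the configuration values and executor path are illustrative, and executor_path would normally come from sys.executable.

# Hypothetical round-trip sketch for the updated SparkJob model; values are illustrative.
from flytekit.models import task as _task_models
from flytekit.sdk.spark_types import SparkType

job = _task_models.SparkJob(
    spark_type=SparkType.SCALA,
    application_file="local:///opt/spark/examples/jars/spark-examples.jar",
    main_class="org.apache.spark.examples.SparkPi",
    spark_conf={"spark.executor.instances": "2"},
    hadoop_conf={},
    executor_path="/usr/bin/python",
)

idl_job = job.to_flyte_idl()   # flyteidl.plugins.spark_pb2.SparkJob with applicationType=SCALA
restored = _task_models.SparkJob.from_flyte_idl(idl_job)
assert restored.main_class == "org.apache.spark.examples.SparkPi"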
8 changes: 8 additions & 0 deletions flytekit/sdk/spark_types.py
@@ -0,0 +1,8 @@
import enum


class SparkType(enum.Enum):
    PYTHON = 1
    SCALA = 2
    JAVA = 3
    R = 4
43 changes: 42 additions & 1 deletion flytekit/sdk/tasks.py
@@ -6,9 +6,10 @@
from flytekit.common import constants as _common_constants
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.tasks import sdk_runnable as _sdk_runnable_tasks, sdk_dynamic as _sdk_dynamic, \
-    spark_task as _sdk_spark_tasks, hive_task as _sdk_hive_tasks, sidecar_task as _sdk_sidecar_tasks
+    spark_task as _sdk_spark_tasks, generic_spark_task as _sdk_generic_spark_task, hive_task as _sdk_hive_tasks, sidecar_task as _sdk_sidecar_tasks
from flytekit.common.tasks import task as _task
from flytekit.common.types import helpers as _type_helpers
from flytekit.sdk.spark_types import SparkType as _spark_type
from flytekit.models import interface as _interface_model


@@ -474,6 +475,7 @@ def wrapper(fn):
            discovery_version=cache_version,
            retries=retries,
            interruptible=interruptible,
            spark_type=_spark_type.PYTHON,
            deprecated=deprecated,
            discoverable=cache,
            timeout=timeout or _datetime.timedelta(seconds=0),
@@ -488,6 +490,45 @@ def wrapper(fn):
    return wrapper


Review comment (Contributor): I don't think we need this method, right? We can directly use SdkGenericSparkTask? Unless you are defaulting some things (I see timeout).

Reply (@akhurana001, Contributor Author): Yeah, we are defaulting. In addition, I also want users to be able to look up all supported task_types from a single place.

def generic_spark_task(
        spark_type,
        main_class,
        main_application_file,
        cache_version='',
        retries=0,
        interruptible=None,
        inputs=None,
        deprecated='',
        cache=False,
        timeout=None,
        spark_conf=None,
        hadoop_conf=None,
        environment=None,
):
    """
    Create a generic Spark task. This task will connect to a Spark cluster, configure the environment,
    and then execute the given main_class as the Spark driver program.
    """

    return _sdk_generic_spark_task.SdkGenericSparkTask(
        task_type=_common_constants.SdkTaskType.SPARK_TASK,
        discovery_version=cache_version,
        retries=retries,
        interruptible=interruptible,
        deprecated=deprecated,
        discoverable=cache,
        timeout=timeout or _datetime.timedelta(seconds=0),
        spark_type=spark_type,
        task_inputs=inputs,
        main_class=main_class or "",
        main_application_file=main_application_file or "",
        spark_conf=spark_conf or {},
        hadoop_conf=hadoop_conf or {},
        environment=environment or {},
    )
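To make the defaulting point from the review thread above concrete, here is a hypothetical sketch (not from the PR) of calling the wrapper with only the required arguments, so cache_version, retries, timeout, and the conf dictionaries all fall back to the defaults visible in the signature; the class name and jar path are made up.

# Hypothetical sketch relying on generic_spark_task's defaults; names and paths are illustrative.
from flytekit.sdk.spark_types import SparkType
from flytekit.sdk.tasks import generic_spark_task

java_spark = generic_spark_task(
    spark_type=SparkType.JAVA,
    main_class="com.example.MyJob",
    main_application_file="local:///opt/jobs/my-job.jar",
)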


def qubole_spark_task(*args, **kwargs):
"""
:rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask
38 changes: 38 additions & 0 deletions tests/flytekit/common/workflows/scala_spark.py
@@ -0,0 +1,38 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from flytekit.sdk.tasks import generic_spark_task, inputs, python_task
from flytekit.sdk.types import Types
from flytekit.sdk.spark_types import SparkType
from flytekit.sdk.workflow import workflow_class, Input


scala_spark = generic_spark_task(
    spark_type=SparkType.SCALA,
    inputs=inputs(partitions=Types.Integer),
    main_class="org.apache.spark.examples.SparkPi",
    main_application_file="local:///opt/spark/examples/jars/spark-examples.jar",
    spark_conf={
        'spark.driver.memory': "1000M",
        'spark.executor.memory': "1000M",
        'spark.executor.cores': '1',
        'spark.executor.instances': '2',
    },
    cache_version='1',
)


@inputs(date_triggered=Types.Datetime)
@python_task(cache_version='1')
def print_every_time(workflow_parameters, date_triggered):
print("My input : {}".format(date_triggered))


@workflow_class
class SparkTasksWorkflow(object):
    triggered_date = Input(Types.Datetime)
    partitions = Input(Types.Integer)
    spark_task = scala_spark(partitions=partitions)
    print_always = print_every_time(
        date_triggered=triggered_date)