-
Notifications
You must be signed in to change notification settings - Fork 302
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generic Spark Integration #101
Changes from all commits
ac991a3
bb43cac
a3f497f
e465e95
ed010a1
f21e412
beca629
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from __future__ import absolute_import | ||
import flytekit.plugins | ||
|
||
__version__ = '0.7.0b3' | ||
__version__ = '0.7.0b4' |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
from __future__ import absolute_import | ||
|
||
try: | ||
from inspect import getfullargspec as _getargspec | ||
except ImportError: | ||
from inspect import getargspec as _getargspec | ||
|
||
from flytekit import __version__ | ||
import sys as _sys | ||
import six as _six | ||
from flytekit.common.tasks import task as _base_tasks | ||
from flytekit.common.types import helpers as _helpers, primitives as _primitives | ||
|
||
from flytekit.models import literals as _literal_models, task as _task_models | ||
from google.protobuf.json_format import MessageToDict as _MessageToDict | ||
from flytekit.common import interface as _interface | ||
from flytekit.common.exceptions import user as _user_exceptions | ||
from flytekit.common.exceptions import scopes as _exception_scopes | ||
|
||
from flytekit.configuration import internal as _internal_config | ||
|
||
input_types_supported = { _primitives.Integer, | ||
_primitives.Boolean, | ||
_primitives.Float, | ||
_primitives.String, | ||
_primitives.Datetime, | ||
_primitives.Timedelta, | ||
} | ||
|
||
|
||
class SdkGenericSparkTask( _base_tasks.SdkTask): | ||
""" | ||
This class includes the additional logic for building a task that executes as a Spark Job. | ||
|
||
""" | ||
def __init__( | ||
self, | ||
task_type, | ||
discovery_version, | ||
retries, | ||
interruptible, | ||
task_inputs, | ||
deprecated, | ||
discoverable, | ||
timeout, | ||
spark_type, | ||
main_class, | ||
main_application_file, | ||
spark_conf, | ||
hadoop_conf, | ||
environment, | ||
): | ||
""" | ||
:param Text task_type: string describing the task type | ||
:param Text discovery_version: string describing the version for task discovery purposes | ||
:param int retries: Number of retries to attempt | ||
:param bool interruptible: Whether or not task is interruptible | ||
:param Text deprecated: | ||
:param bool discoverable: | ||
:param datetime.timedelta timeout: | ||
:param Text spark_type: Type of Spark Job: Scala/Java | ||
:param Text main_class: Main class to execute for Scala/Java jobs | ||
:param Text main_application_file: Main application file | ||
:param dict[Text,Text] spark_conf: | ||
:param dict[Text,Text] hadoop_conf: | ||
:param dict[Text,Text] environment: [optional] environment variables to set when executing this task. | ||
""" | ||
|
||
spark_job = _task_models.SparkJob( | ||
spark_conf=spark_conf, | ||
hadoop_conf=hadoop_conf, | ||
spark_type = spark_type, | ||
application_file=main_application_file, | ||
main_class=main_class, | ||
executor_path=_sys.executable, | ||
).to_flyte_idl() | ||
|
||
super(SdkGenericSparkTask, self).__init__( | ||
task_type, | ||
_task_models.TaskMetadata( | ||
discoverable, | ||
_task_models.RuntimeMetadata( | ||
_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, | ||
__version__, | ||
'spark' | ||
), | ||
timeout, | ||
_literal_models.RetryStrategy(retries), | ||
interruptible, | ||
discovery_version, | ||
deprecated | ||
), | ||
_interface.TypedInterface({}, {}), | ||
_MessageToDict(spark_job), | ||
) | ||
|
||
# Add Inputs | ||
if task_inputs is not None: | ||
task_inputs(self) | ||
|
||
# Container after the Inputs have been updated. | ||
self._container = self._get_container_definition( | ||
environment=environment | ||
) | ||
|
||
def _validate_inputs(self, inputs): | ||
""" | ||
:param dict[Text, flytekit.models.interface.Variable] inputs: Input variables to validate | ||
:raises: flytekit.common.exceptions.user.FlyteValidationException | ||
""" | ||
for k, v in _six.iteritems(inputs): | ||
sdk_type =_helpers.get_sdk_type_from_literal_type(v.type) | ||
if sdk_type not in input_types_supported: | ||
raise _user_exceptions.FlyteValidationException( | ||
"Input Type '{}' not supported. Only Primitives are supported for Scala/Java Spark.".format(sdk_type) | ||
) | ||
super(SdkGenericSparkTask, self)._validate_inputs(inputs) | ||
|
||
@_exception_scopes.system_entry_point | ||
def add_inputs(self, inputs): | ||
""" | ||
Adds the inputs to this task. This can be called multiple times, but it will fail if an input with a given | ||
name is added more than once, a name collides with an output, or if the name doesn't exist as an arg name in | ||
the wrapped function. | ||
:param dict[Text, flytekit.models.interface.Variable] inputs: names and variables | ||
""" | ||
self._validate_inputs(inputs) | ||
self.interface.inputs.update(inputs) | ||
|
||
def _get_container_definition( | ||
self, | ||
environment=None, | ||
): | ||
""" | ||
:rtype: Container | ||
""" | ||
|
||
args = [] | ||
for k, v in _six.iteritems(self.interface.inputs): | ||
args.append("--{}".format(k)) | ||
args.append("{{{{.Inputs.{}}}}}".format(k)) | ||
|
||
return _task_models.Container( | ||
image=_internal_config.IMAGE.get(), | ||
command=[], | ||
args=args, | ||
resources=_task_models.Resources([], []), | ||
env=environment, | ||
config={} | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import enum | ||
|
||
|
||
class SparkType(enum.Enum): | ||
PYTHON = 1 | ||
SCALA = 2 | ||
JAVA = 3 | ||
R = 4 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,9 +6,10 @@ | |
from flytekit.common import constants as _common_constants | ||
from flytekit.common.exceptions import user as _user_exceptions | ||
from flytekit.common.tasks import sdk_runnable as _sdk_runnable_tasks, sdk_dynamic as _sdk_dynamic, \ | ||
spark_task as _sdk_spark_tasks, hive_task as _sdk_hive_tasks, sidecar_task as _sdk_sidecar_tasks | ||
spark_task as _sdk_spark_tasks, generic_spark_task as _sdk_generic_spark_task, hive_task as _sdk_hive_tasks, sidecar_task as _sdk_sidecar_tasks | ||
from flytekit.common.tasks import task as _task | ||
from flytekit.common.types import helpers as _type_helpers | ||
from flytekit.sdk.spark_types import SparkType as _spark_type | ||
from flytekit.models import interface as _interface_model | ||
|
||
|
||
|
@@ -474,6 +475,7 @@ def wrapper(fn): | |
discovery_version=cache_version, | ||
retries=retries, | ||
interruptible=interruptible, | ||
spark_type= _spark_type.PYTHON, | ||
deprecated=deprecated, | ||
discoverable=cache, | ||
timeout=timeout or _datetime.timedelta(seconds=0), | ||
|
@@ -488,6 +490,45 @@ def wrapper(fn): | |
return wrapper | ||
|
||
|
||
def generic_spark_task( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I dont think we need this method right? Unless you are defaulting somethings. (i see timeout) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, we are defaulting. In-addition, I also want users to be able to look-up all task_types supported from a single place. |
||
spark_type, | ||
main_class, | ||
main_application_file, | ||
cache_version='', | ||
retries=0, | ||
interruptible=None, | ||
inputs=None, | ||
deprecated='', | ||
cache=False, | ||
timeout=None, | ||
spark_conf=None, | ||
hadoop_conf=None, | ||
environment=None, | ||
): | ||
""" | ||
Create a generic spark task. This task will connect to a Spark cluster, configure the environment, | ||
and then execute the mainClass code as the Spark driver program. | ||
|
||
""" | ||
|
||
return _sdk_generic_spark_task.SdkGenericSparkTask( | ||
task_type=_common_constants.SdkTaskType.SPARK_TASK, | ||
discovery_version=cache_version, | ||
retries=retries, | ||
interruptible=interruptible, | ||
deprecated=deprecated, | ||
discoverable=cache, | ||
timeout=timeout or _datetime.timedelta(seconds=0), | ||
spark_type = spark_type, | ||
task_inputs= inputs, | ||
main_class = main_class or "", | ||
main_application_file = main_application_file or "", | ||
spark_conf=spark_conf or {}, | ||
hadoop_conf=hadoop_conf or {}, | ||
environment=environment or {}, | ||
) | ||
|
||
|
||
def qubole_spark_task(*args, **kwargs): | ||
""" | ||
:rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
from flytekit.sdk.tasks import generic_spark_task, inputs, python_task | ||
from flytekit.sdk.types import Types | ||
from flytekit.sdk.spark_types import SparkType | ||
from flytekit.sdk.workflow import workflow_class, Input | ||
|
||
|
||
scala_spark = generic_spark_task( | ||
spark_type=SparkType.SCALA, | ||
inputs=inputs(partitions=Types.Integer), | ||
main_class="org.apache.spark.examples.SparkPi", | ||
main_application_file="local:///opt/spark/examples/jars/spark-examples.jar", | ||
spark_conf={ | ||
'spark.driver.memory': "1000M", | ||
'spark.executor.memory': "1000M", | ||
'spark.executor.cores': '1', | ||
'spark.executor.instances': '2', | ||
}, | ||
cache_version='1' | ||
) | ||
|
||
|
||
@inputs(date_triggered=Types.Datetime) | ||
@python_task(cache_version='1') | ||
def print_every_time(workflow_parameters, date_triggered): | ||
print("My input : {}".format(date_triggered)) | ||
|
||
|
||
@workflow_class | ||
class SparkTasksWorkflow(object): | ||
triggered_date = Input(Types.Datetime) | ||
partitions = Input(Types.Integer) | ||
spark_task = scala_spark(partitions=partitions) | ||
print_always = print_every_time( | ||
date_triggered=triggered_date) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is where you will have to check the types and raise an alert. You can make this as a common method for now, when we add support for all types, we can remove this check
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, added it in the validate_inputs above