Commit 3d6437b
* Generic Spark Integration
* Image
* PR comments
* PR comments
* Separate task-type for generic-spark
* PR comments
1 parent: 24ceff1

7 changed files with 288 additions and 4 deletions.
@@ -1,4 +1,4 @@
 from __future__ import absolute_import
 import flytekit.plugins

-__version__ = '0.7.0b3'
+__version__ = '0.7.0b4'
@@ -0,0 +1,150 @@
from __future__ import absolute_import

try:
    from inspect import getfullargspec as _getargspec
except ImportError:
    from inspect import getargspec as _getargspec

from flytekit import __version__
import sys as _sys
import six as _six
from flytekit.common.tasks import task as _base_tasks
from flytekit.common.types import helpers as _helpers, primitives as _primitives

from flytekit.models import literals as _literal_models, task as _task_models
from google.protobuf.json_format import MessageToDict as _MessageToDict
from flytekit.common import interface as _interface
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.exceptions import scopes as _exception_scopes

from flytekit.configuration import internal as _internal_config

input_types_supported = {
    _primitives.Integer,
    _primitives.Boolean,
    _primitives.Float,
    _primitives.String,
    _primitives.Datetime,
    _primitives.Timedelta,
}


class SdkGenericSparkTask(_base_tasks.SdkTask):
    """
    This class includes the additional logic for building a task that executes as a Spark Job.
    """
    def __init__(
        self,
        task_type,
        discovery_version,
        retries,
        interruptible,
        task_inputs,
        deprecated,
        discoverable,
        timeout,
        spark_type,
        main_class,
        main_application_file,
        spark_conf,
        hadoop_conf,
        environment,
    ):
        """
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Whether or not task is interruptible
        :param Text deprecated: String that can be used to mark the task as deprecated
        :param bool discoverable: Whether the task's results are cached for discovery
        :param datetime.timedelta timeout: Maximum duration the task may run before being aborted
        :param Text spark_type: Type of Spark Job: Scala/Java
        :param Text main_class: Main class to execute for Scala/Java jobs
        :param Text main_application_file: Main application file
        :param dict[Text,Text] spark_conf: Spark configuration key/value pairs for the job
        :param dict[Text,Text] hadoop_conf: Hadoop configuration key/value pairs for the job
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """

        spark_job = _task_models.SparkJob(
            spark_conf=spark_conf,
            hadoop_conf=hadoop_conf,
            spark_type=spark_type,
            application_file=main_application_file,
            main_class=main_class,
            executor_path=_sys.executable,
        ).to_flyte_idl()

        super(SdkGenericSparkTask, self).__init__(
            task_type,
            _task_models.TaskMetadata(
                discoverable,
                _task_models.RuntimeMetadata(
                    _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    __version__,
                    'spark'
                ),
                timeout,
                _literal_models.RetryStrategy(retries),
                interruptible,
                discovery_version,
                deprecated
            ),
            _interface.TypedInterface({}, {}),
            _MessageToDict(spark_job),
        )

        # Add Inputs
        if task_inputs is not None:
            task_inputs(self)

        # Container after the Inputs have been updated.
        self._container = self._get_container_definition(
            environment=environment
        )

    def _validate_inputs(self, inputs):
        """
        :param dict[Text, flytekit.models.interface.Variable] inputs: Input variables to validate
        :raises: flytekit.common.exceptions.user.FlyteValidationException
        """
        for k, v in _six.iteritems(inputs):
            sdk_type = _helpers.get_sdk_type_from_literal_type(v.type)
            if sdk_type not in input_types_supported:
                raise _user_exceptions.FlyteValidationException(
                    "Input Type '{}' not supported. Only Primitives are supported for Scala/Java Spark.".format(sdk_type)
                )
        super(SdkGenericSparkTask, self)._validate_inputs(inputs)

    @_exception_scopes.system_entry_point
    def add_inputs(self, inputs):
        """
        Adds the inputs to this task. This can be called multiple times, but it will fail if an input with a given
        name is added more than once or if a name collides with an output.
        :param dict[Text, flytekit.models.interface.Variable] inputs: names and variables
        """
        self._validate_inputs(inputs)
        self.interface.inputs.update(inputs)

    def _get_container_definition(
        self,
        environment=None,
    ):
        """
        :rtype: Container
        """

        args = []
        for k, v in _six.iteritems(self.interface.inputs):
            args.append("--{}".format(k))
            args.append("{{{{.Inputs.{}}}}}".format(k))

        return _task_models.Container(
            image=_internal_config.IMAGE.get(),
            command=[],
            args=args,
            resources=_task_models.Resources([], []),
            env=environment,
            config={}
        )
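
For illustration, a minimal stand-alone sketch of the argument templating in _get_container_definition above (the input names here are hypothetical):

# Sketch of the loop above: each declared input becomes a "--<name>"
# flag plus a "{{.Inputs.<name>}}" placeholder that Flyte fills in
# with the literal input value at execution time.
input_names = ["partitions", "date_triggered"]  # hypothetical inputs

args = []
for k in input_names:
    args.append("--{}".format(k))
    args.append("{{{{.Inputs.{}}}}}".format(k))

print(args)
# ['--partitions', '{{.Inputs.partitions}}',
#  '--date_triggered', '{{.Inputs.date_triggered}}']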
@@ -0,0 +1,8 @@
import enum


class SparkType(enum.Enum):
    PYTHON = 1
    SCALA = 2
    JAVA = 3
    R = 4
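
Since the enum alone doesn't show how the job types differ downstream, here is a small, hypothetical sketch (the helper requires_main_class is not part of this commit), reflecting the SdkGenericSparkTask docstring above, which says main_class applies to Scala/Java jobs:

import enum


class SparkType(enum.Enum):
    PYTHON = 1
    SCALA = 2
    JAVA = 3
    R = 4


# Hypothetical helper, not part of this commit: only the JVM job
# types take a main_class ("Main class to execute for Scala/Java
# jobs" per the SdkGenericSparkTask docstring above).
def requires_main_class(spark_type):
    return spark_type in (SparkType.SCALA, SparkType.JAVA)


assert requires_main_class(SparkType.SCALA)
assert not requires_main_class(SparkType.PYTHON)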
@@ -0,0 +1,38 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from flytekit.sdk.tasks import generic_spark_task, inputs, python_task
from flytekit.sdk.types import Types
from flytekit.sdk.spark_types import SparkType
from flytekit.sdk.workflow import workflow_class, Input


scala_spark = generic_spark_task(
    spark_type=SparkType.SCALA,
    inputs=inputs(partitions=Types.Integer),
    main_class="org.apache.spark.examples.SparkPi",
    main_application_file="local:///opt/spark/examples/jars/spark-examples.jar",
    spark_conf={
        'spark.driver.memory': "1000M",
        'spark.executor.memory': "1000M",
        'spark.executor.cores': '1',
        'spark.executor.instances': '2',
    },
    cache_version='1'
)


@inputs(date_triggered=Types.Datetime)
@python_task(cache_version='1')
def print_every_time(workflow_parameters, date_triggered):
    print("My input : {}".format(date_triggered))


@workflow_class
class SparkTasksWorkflow(object):
    triggered_date = Input(Types.Datetime)
    partitions = Input(Types.Integer)
    spark_task = scala_spark(partitions=partitions)
    print_always = print_every_time(
        date_triggered=triggered_date)
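
The workflow above declares only primitive inputs (Integer, Datetime), which is exactly what SdkGenericSparkTask._validate_inputs permits. A minimal, self-contained sketch of that membership check, with stand-in type names and ValueError substituting for flytekit's primitive classes and FlyteValidationException (both substitutions are assumptions for illustration):

# Stand-in sketch of the _validate_inputs membership test: only the
# primitive types accepted by generic Scala/Java Spark tasks pass.
INPUT_TYPES_SUPPORTED = {"Integer", "Boolean", "Float", "String",
                         "Datetime", "Timedelta"}


def validate_inputs(inputs):
    for name, sdk_type in inputs.items():
        if sdk_type not in INPUT_TYPES_SUPPORTED:
            raise ValueError(
                "Input Type '{}' not supported. Only Primitives are "
                "supported for Scala/Java Spark.".format(sdk_type))


validate_inputs({"partitions": "Integer"})   # accepted
# validate_inputs({"data": "Blob"})          # would raise ValueError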