[Temporary] tensorflow plugin implementation (#141)
pingsutw authored Sep 14, 2020
1 parent fd57a48 commit bfa2040
Showing 7 changed files with 345 additions and 0 deletions.
8 changes: 8 additions & 0 deletions README.md
@@ -67,6 +67,14 @@ If `@pytorch_task` is to be used, one should install the `pytorch` plugin.
pip install "flytekit[pytorch]"
```

### TensorFlow

If `@tensorflow_task` is to be used, one should install the `tensorflow` plugin.

```bash
pip install "flytekit[tensorflow]"
```
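
For illustration, a task using this plugin might look like the following. This is a sketch based on the `@tensorflow_task` decorator added in this commit; the task, input, and output names are hypothetical.

```python
from flytekit.sdk.tasks import inputs, outputs, tensorflow_task
from flytekit.sdk.types import Types


@inputs(epochs=Types.Integer)
@outputs(model_uri=Types.String)
@tensorflow_task(workers_count=2, ps_replicas_count=1, chief_replicas_count=1)
def train(wf_params, epochs, model_uri):
    # Distributed TensorFlow training code goes here.
    pass
```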

### Full Installation

To install all or multiple available plugins, one can specify them individually:
1 change: 1 addition & 0 deletions flytekit/common/constants.py
@@ -20,6 +20,7 @@ class SdkTaskType(object):
    SENSOR_TASK = "sensor-task"
    PRESTO_TASK = "presto"
    PYTORCH_TASK = "pytorch"
    TENSORFLOW_TASK = "tensorflow"
    # Raw container task is just a name, it defaults to using the regular container task (like python etc), but sets the data_config in the container
    RAW_CONTAINER_TASK = "raw-container"
    SAGEMAKER_TRAINING_JOB_TASK = "sagemaker_training_job_task"
84 changes: 84 additions & 0 deletions flytekit/common/tasks/tensorflow_task.py
@@ -0,0 +1,84 @@
from __future__ import absolute_import

try:
    from inspect import getfullargspec as _getargspec
except ImportError:
    from inspect import getargspec as _getargspec

import six as _six
from flytekit.common import constants as _constants
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.tasks import output as _task_output, sdk_runnable as _sdk_runnable
from flytekit.common.types import helpers as _type_helpers
from flytekit.models import literals as _literal_models, task as _task_models
from google.protobuf.json_format import MessageToDict as _MessageToDict


class SdkRunnableTensorflowContainer(_sdk_runnable.SdkRunnableContainer):

    @property
    def args(self):
        """
        Override args to remove the injection of command prefixes
        :rtype: list[Text]
        """
        return self._args


class SdkTensorFlowTask(_sdk_runnable.SdkRunnableTask):
    def __init__(
        self,
        task_function,
        task_type,
        discovery_version,
        retries,
        interruptible,
        deprecated,
        discoverable,
        timeout,
        workers_count,
        ps_replicas_count,
        chief_replicas_count,
        per_replica_storage_request,
        per_replica_cpu_request,
        per_replica_gpu_request,
        per_replica_memory_request,
        per_replica_storage_limit,
        per_replica_cpu_limit,
        per_replica_gpu_limit,
        per_replica_memory_limit,
        environment
    ):
        tensorflow_job = _task_models.TensorFlowJob(
            workers_count=workers_count,
            ps_replicas_count=ps_replicas_count,
            chief_replicas_count=chief_replicas_count
        ).to_flyte_idl()
        super(SdkTensorFlowTask, self).__init__(
            task_function=task_function,
            task_type=task_type,
            discovery_version=discovery_version,
            retries=retries,
            interruptible=interruptible,
            deprecated=deprecated,
            storage_request=per_replica_storage_request,
            cpu_request=per_replica_cpu_request,
            gpu_request=per_replica_gpu_request,
            memory_request=per_replica_memory_request,
            storage_limit=per_replica_storage_limit,
            cpu_limit=per_replica_cpu_limit,
            gpu_limit=per_replica_gpu_limit,
            memory_limit=per_replica_memory_limit,
            discoverable=discoverable,
            timeout=timeout,
            environment=environment,
            custom=_MessageToDict(tensorflow_job)
        )

    def _get_container_definition(
        self,
        **kwargs
    ):
        """
        :rtype: SdkRunnableTensorflowContainer
        """
        return super(SdkTensorFlowTask, self)._get_container_definition(cls=SdkRunnableTensorflowContainer, **kwargs)
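
For orientation, the `custom` payload handed to the base class is simply the TFJob replica configuration rendered as a dict. A minimal sketch (assuming `flyteidl` is installed) of what `_MessageToDict` produces for the `TensorFlowJob` model added in `flytekit/models/task.py` below:

```python
from google.protobuf.json_format import MessageToDict

from flytekit.models import task as _task_models

# Serialize the replica counts the same way SdkTensorFlowTask.__init__ does.
job = _task_models.TensorFlowJob(workers_count=2, ps_replicas_count=1, chief_replicas_count=1)
print(MessageToDict(job.to_flyte_idl()))
# Expected camelCase keys, matching the assertions in the unit test at the end of this commit:
# {'workers': 2, 'psReplicas': 1, 'chiefReplicas': 1}
```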
36 changes: 36 additions & 0 deletions flytekit/models/task.py
@@ -15,6 +15,7 @@
from flytekit.models import interface as _interface
from flytekit.models import literals as _literals
from flytekit.models.core import identifier as _identifier
from flyteidl.plugins import tensorflow_pb2 as _tensorflow_task
from flytekit.plugins import flyteidl as _lazy_flyteidl
from flytekit.sdk.spark_types import SparkType as _spark_type

@@ -881,3 +882,38 @@ def to_flyte_idl(self):
    @classmethod
    def from_flyte_idl(cls, pb2_object):
        return cls(workers_count=pb2_object.workers,)


class TensorFlowJob(_common.FlyteIdlEntity):

    def __init__(self, workers_count, ps_replicas_count, chief_replicas_count):
        self._workers_count = workers_count
        self._ps_replicas_count = ps_replicas_count
        self._chief_replicas_count = chief_replicas_count

    @property
    def workers_count(self):
        return self._workers_count

    @property
    def ps_replicas_count(self):
        return self._ps_replicas_count

    @property
    def chief_replicas_count(self):
        return self._chief_replicas_count

    def to_flyte_idl(self):
        return _tensorflow_task.DistributedTensorflowTrainingTask(
            workers=self.workers_count,
            ps_replicas=self.ps_replicas_count,
            chief_replicas=self.chief_replicas_count
        )

    @classmethod
    def from_flyte_idl(cls, pb2_object):
        return cls(
            workers_count=pb2_object.workers,
            ps_replicas_count=pb2_object.ps_replicas,
            chief_replicas_count=pb2_object.chief_replicas
        )
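
A quick round-trip sketch (assuming the `flyteidl` package that provides `tensorflow_pb2` is installed): the model converts to the `DistributedTensorflowTrainingTask` message and back without losing the replica counts.

```python
from flytekit.models.task import TensorFlowJob

# Build the model, serialize to the protobuf message, then rebuild the model from it.
job = TensorFlowJob(workers_count=2, ps_replicas_count=1, chief_replicas_count=1)
restored = TensorFlowJob.from_flyte_idl(job.to_flyte_idl())

assert restored.workers_count == 2
assert restored.ps_replicas_count == 1
assert restored.chief_replicas_count == 1
```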
2 changes: 2 additions & 0 deletions flytekit/plugins/__init__.py
@@ -17,6 +17,8 @@

torch = _lazy_loader.lazy_load_module("torch") # type: _lazy_loader._LazyLoadModule

tensorflow = _lazy_loader.lazy_load_module("tensorflow")  # type: _lazy_loader._LazyLoadModule

_lazy_loader.LazyLoadPlugin("spark", ["pyspark>=2.4.0,<3.0.0"], [pyspark])

_lazy_loader.LazyLoadPlugin("spark3", ["pyspark>=3.0.0"], [pyspark])
163 changes: 163 additions & 0 deletions flytekit/sdk/tasks.py
@@ -7,6 +7,7 @@
from flytekit.common.tasks import generic_spark_task as _sdk_generic_spark_task
from flytekit.common.tasks import hive_task as _sdk_hive_tasks
from flytekit.common.tasks import pytorch_task as _sdk_pytorch_tasks
from flytekit.common.tasks import tensorflow_task as _sdk_tensorflow_tasks
from flytekit.common.tasks import sdk_dynamic as _sdk_dynamic
from flytekit.common.tasks import sdk_runnable as _sdk_runnable_tasks
from flytekit.common.tasks import sidecar_task as _sdk_sidecar_tasks
@@ -1336,3 +1337,165 @@ def wrapper(fn):
        return wrapper(_task_function)
    else:
        return wrapper


def tensorflow_task(
    _task_function=None,
    cache_version='',
    retries=0,
    interruptible=False,
    deprecated='',
    cache=False,
    timeout=None,
    workers_count=1,
    ps_replicas_count=None,
    chief_replicas_count=None,
    per_replica_storage_request="",
    per_replica_cpu_request="",
    per_replica_gpu_request="",
    per_replica_memory_request="",
    per_replica_storage_limit="",
    per_replica_cpu_limit="",
    per_replica_gpu_limit="",
    per_replica_memory_limit="",
    environment=None,
    cls=None
):
"""
Decorator to create a Tensorflow Task definition. This task will submit TFJob (see https://github.com/kubeflow/tf-operator)
defined by the code within the _task_function to k8s cluster.
.. code-block:: python
@inputs(int_list=[Types.Integer])
@outputs(result=Types.Integer
@tensorflow_task(
workers_count=2,
ps_replicas_count=1,
chief_replicas_count=1,
per_replica_cpu_request="500m",
per_replica_memory_request="4Gi",
per_replica_memory_limit="8Gi",
per_replica_gpu_limit="1",
)
def my_tensorflow_job(wf_params, int_list, result):
pass
:param _task_function: this is the decorated method and shouldn't be declared explicitly. The function must
take a first argument, and then named arguments matching those defined in @inputs and @outputs. No keyword
arguments are allowed for wrapped task functions.
:param Text cache_version: [optional] string representing logical version for discovery. This field should be
updated whenever the underlying algorithm changes.
.. note::
This argument is required to be a non-empty string if `cache` is True.
:param int retries: [optional] integer determining number of times task can be retried on
:py:exc:`flytekit.sdk.exceptions.RecoverableException` or transient platform failures. Defaults
to 0.
.. note::
If retries > 0, the task must be able to recover from any remote state created within the user code. It is
strongly recommended that tasks are written to be idempotent.
:param bool interruptible: [optional] boolean describing if the task is interruptible.
:param Text deprecated: [optional] string that should be provided if this task is deprecated. The string
will be logged as a warning so it should contain information regarding how to update to a newer task.
:param bool cache: [optional] boolean describing if the outputs of this task should be cached and
re-usable.
:param datetime.timedelta timeout: [optional] describes how long the task should be allowed to
run at max before triggering a retry (if retries are enabled). By default, tasks are allowed to run
indefinitely. If a null timedelta is passed (i.e. timedelta(seconds=0)), the task will not timeout.
:param int workers_count: integer determining the number of worker replicas spawned in the cluster for this job
:param int ps_replicas_count: integer determining the number of parameter server replicas spawned in the cluster for this job
:param int chief_replicas_count: integer determining the number of chief server replicas spawned in the cluster for this job
:param Text per_replica_storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space
for each replica spawned for this job (i.e. both for parameter, chief server and workers). Default is set by platform-level configuration.
.. note::
This is currently not supported by the platform.
:param Text per_replica_cpu_request: [optional] Kubernetes resource string for lower-bound of cores for each replica
spawned for this job (i.e. both for parameter, chief server and workers).
This can be set to a fractional portion of a CPU. Default is set by platform-level configuration.
TODO: Add links to resource string documentation for Kubernetes
:param Text per_replica_gpu_request: [optional] Kubernetes resource string for lower-bound of desired GPUs for each
replica spawned for this job (i.e. both for parameter, chief server and workers).
Default is set by platform-level configuration.
TODO: Add links to resource string documentation for Kubernetes
:param Text per_replica_memory_request: [optional] Kubernetes resource string for lower-bound of physical memory
necessary for each replica spawned for this job (i.e. both for parameter, chief server and workers). Default is set by platform-level configuration.
TODO: Add links to resource string documentation for Kubernetes
:param Text per_replica_storage_limit: [optional] Kubernetes resource string for upper-bound of disk storage space
for each replica spawned for this job (i.e. both for parameter, chief server and workers).
This amount is not guaranteed! If not specified, it is set equal to storage_request.
.. note::
This is currently not supported by the platform.
:param Text per_replica_cpu_limit: [optional] Kubernetes resource string for upper-bound of cores for each replica
spawned for this job (i.e. both for parameter, chief server and workers).
This can be set to a fractional portion of a CPU. This amount is not guaranteed! If not specified,
it is set equal to cpu_request.
:param Text per_replica_gpu_limit: [optional] Kubernetes resource string for upper-bound of desired GPUs for each
replica spawned for this job (i.e. both for parameter, chief server and workers).
This amount is not guaranteed! If not specified, it is set equal to gpu_request.
:param Text per_replica_memory_limit: [optional] Kubernetes resource string for upper-bound of physical memory
necessary for each replica spawned for this job (i.e. both for parameter, chief server and workers).
This amount is not guaranteed! If not specified, it is set equal to memory_request.
:param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
:param cls: This can be used to override the task implementation with a user-defined extension. The class
provided must be a subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. A user can use this to
inject bespoke logic into the base Flyte programming model.
:rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask
"""
    def wrapper(fn):
        return (cls or _sdk_tensorflow_tasks.SdkTensorFlowTask)(
            task_function=fn,
            task_type=_common_constants.SdkTaskType.TENSORFLOW_TASK,
            discovery_version=cache_version,
            retries=retries,
            interruptible=interruptible,
            deprecated=deprecated,
            discoverable=cache,
            timeout=timeout or _datetime.timedelta(seconds=0),
            workers_count=workers_count,
            ps_replicas_count=ps_replicas_count,
            chief_replicas_count=chief_replicas_count,
            per_replica_storage_request=per_replica_storage_request,
            per_replica_cpu_request=per_replica_cpu_request,
            per_replica_gpu_request=per_replica_gpu_request,
            per_replica_memory_request=per_replica_memory_request,
            per_replica_storage_limit=per_replica_storage_limit,
            per_replica_cpu_limit=per_replica_cpu_limit,
            per_replica_gpu_limit=per_replica_gpu_limit,
            per_replica_memory_limit=per_replica_memory_limit,
            environment=environment or {}
        )

    if _task_function:
        return wrapper(_task_function)
    else:
        return wrapper
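
Like the other task decorators in this module, `tensorflow_task` can be applied either bare or with arguments; a short sketch of both forms (the task names and bodies are placeholders):

```python
from flytekit.sdk.tasks import inputs, outputs, tensorflow_task
from flytekit.sdk.types import Types


@inputs(in1=Types.Integer)
@outputs(out1=Types.String)
@tensorflow_task  # bare form: the function is passed as _task_function and all defaults apply
def default_tf_task(wf_params, in1, out1):
    pass


@inputs(in1=Types.Integer)
@outputs(out1=Types.String)
@tensorflow_task(workers_count=2, chief_replicas_count=1)  # parameterized form returns the wrapper
def tuned_tf_task(wf_params, in1, out1):
    pass
```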
51 changes: 51 additions & 0 deletions tests/flytekit/unit/sdk/tasks/test_tensorflow_task.py
@@ -0,0 +1,51 @@
from __future__ import absolute_import
from flytekit.sdk.tasks import tensorflow_task, inputs, outputs
from flytekit.sdk.types import Types
from flytekit.common import constants as _common_constants
from flytekit.common.tasks import sdk_runnable as _sdk_runnable, tensorflow_task as _tensorflow_task
from flytekit.models import types as _type_models
from flytekit.models.core import identifier as _identifier
import datetime as _datetime


@inputs(in1=Types.Integer)
@outputs(out1=Types.String)
@tensorflow_task(workers_count=2, ps_replicas_count=1, chief_replicas_count=1)
def simple_tensorflow_task(wf_params, in1, out1):
    pass


simple_tensorflow_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version")


def test_simple_tensorflow_task():
    assert isinstance(simple_tensorflow_task, _tensorflow_task.SdkTensorFlowTask)
    assert isinstance(simple_tensorflow_task, _sdk_runnable.SdkRunnableTask)
    assert simple_tensorflow_task.interface.inputs['in1'].description == ''
    assert simple_tensorflow_task.interface.inputs['in1'].type == \
        _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER)
    assert simple_tensorflow_task.interface.outputs['out1'].description == ''
    assert simple_tensorflow_task.interface.outputs['out1'].type == \
        _type_models.LiteralType(simple=_type_models.SimpleType.STRING)
    assert simple_tensorflow_task.type == _common_constants.SdkTaskType.TENSORFLOW_TASK
    assert simple_tensorflow_task.task_function_name == 'simple_tensorflow_task'
    assert simple_tensorflow_task.task_module == __name__
    assert simple_tensorflow_task.metadata.timeout == _datetime.timedelta(seconds=0)
    assert simple_tensorflow_task.metadata.deprecated_error_message == ''
    assert simple_tensorflow_task.metadata.discoverable is False
    assert simple_tensorflow_task.metadata.discovery_version == ''
    assert simple_tensorflow_task.metadata.retries.retries == 0
    assert len(simple_tensorflow_task.container.resources.limits) == 0
    assert len(simple_tensorflow_task.container.resources.requests) == 0
    assert simple_tensorflow_task.custom['workers'] == 2
    assert simple_tensorflow_task.custom['psReplicas'] == 1
    assert simple_tensorflow_task.custom['chiefReplicas'] == 1

    # Should strip out the venv component of the args.
    assert simple_tensorflow_task._get_container_definition().args[0] == 'pyflyte-execute'

    pb2 = simple_tensorflow_task.to_flyte_idl()
    assert pb2.custom['workers'] == 2
    assert pb2.custom['psReplicas'] == 1
    assert pb2.custom['chiefReplicas'] == 1
