Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add node resource overrides #523

Merged
merged 9 commits into from
Jun 25, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 25 additions & 25 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
# make dev-requirements.txt
Expand Down Expand Up @@ -75,6 +75,8 @@ dirhash==0.2.1
# flytekit
distro==1.5.0
# via docker-compose
docker[ssh]==5.0.0
# via docker-compose
docker-compose==1.29.2
# via
# pytest-docker
Expand All @@ -83,26 +85,24 @@ docker-image-py==0.1.10
# via
# -c requirements.txt
# flytekit
docker[ssh]==5.0.0
# via docker-compose
dockerpty==0.4.1
# via docker-compose
docopt==0.6.2
# via docker-compose
flake8-black==0.2.1
# via -r dev-requirements.in
flake8-isort==4.0.0
# via -r dev-requirements.in
flake8==3.9.2
# via
# -r dev-requirements.in
# flake8-black
# flake8-isort
flyteidl==0.19.2
flake8-black==0.2.1
# via -r dev-requirements.in
flake8-isort==4.0.0
# via -r dev-requirements.in
flyteidl==0.19.5
# via
# -c requirements.txt
# flytekit
grpcio==1.38.0
grpcio==1.38.1
# via
# -c requirements.txt
# flytekit
Expand All @@ -116,7 +116,7 @@ importlib-metadata==4.5.0
# keyring
iniconfig==1.1.1
# via pytest
isort==5.8.0
isort==5.9.1
# via
# -r dev-requirements.in
# flake8-isort
Expand All @@ -136,6 +136,12 @@ markupsafe==2.0.1
# via
# -c requirements.txt
# jinja2
marshmallow==3.12.1
# via
# -c requirements.txt
# dataclasses-json
# marshmallow-enum
# marshmallow-jsonschema
marshmallow-enum==1.5.1
# via
# -c requirements.txt
Expand All @@ -144,29 +150,23 @@ marshmallow-jsonschema==0.12.0
# via
# -c requirements.txt
# flytekit
marshmallow==3.12.1
# via
# -c requirements.txt
# dataclasses-json
# marshmallow-enum
# marshmallow-jsonschema
mccabe==0.6.1
# via flake8
mock==4.0.3
# via -r dev-requirements.in
mypy==0.910
# via -r dev-requirements.in
mypy-extensions==0.4.3
# via
# -c requirements.txt
# black
# mypy
# typing-inspect
mypy==0.902
# via -r dev-requirements.in
natsort==7.1.1
# via
# -c requirements.txt
# flytekit
numpy==1.20.3
numpy==1.21.0
# via
# -c requirements.txt
# pandas
Expand All @@ -175,7 +175,7 @@ packaging==20.9
# via
# -c requirements.txt
# pytest
pandas==1.2.4
pandas==1.2.5
# via
# -c requirements.txt
# flytekit
Expand Down Expand Up @@ -224,22 +224,22 @@ pyrsistent==0.17.3
# via
# -c requirements.txt
# jsonschema
pytest-docker==0.10.3
# via pytest-flyte
git+git://github.com/flyteorg/pytest-flyte@main#egg=pytest-flyte
# via -r dev-requirements.in
pytest==6.2.4
# via
# -r dev-requirements.in
# pytest-docker
# pytest-flyte
pytest-docker==0.10.3
# via pytest-flyte
git+git://github.com/flyteorg/pytest-flyte@main#egg=pytest-flyte
# via -r dev-requirements.in
python-dateutil==2.8.1
# via
# -c requirements.txt
# croniter
# flytekit
# pandas
python-dotenv==0.17.1
python-dotenv==0.18.0
# via docker-compose
python-json-logger==2.0.1
# via
Expand Down
60 changes: 30 additions & 30 deletions doc-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
# make doc-requirements.txt
Expand All @@ -16,7 +16,7 @@ appnope==0.1.2
# via
# ipykernel
# ipython
astroid==2.5.8
astroid==2.6.0
# via sphinx-autoapi
async-generator==1.10
# via nbclient
Expand All @@ -39,9 +39,9 @@ black==21.6b0
# via papermill
bleach==3.3.0
# via nbconvert
boto3==1.17.94
boto3==1.17.100
# via sagemaker-training
botocore==1.20.94
botocore==1.20.100
# via
# boto3
# s3transfer
Expand Down Expand Up @@ -88,15 +88,15 @@ entrypoints==0.3
# via
# nbconvert
# papermill
flyteidl==0.19.2
flyteidl==0.19.5
# via flytekit
git+git://github.com/flyteorg/furo@main
# via -r doc-requirements.in
gevent==21.1.2
# via sagemaker-training
greenlet==1.1.0
# via gevent
grpcio==1.38.0
grpcio==1.38.1
# via
# -r doc-requirements.in
# flytekit
Expand All @@ -112,12 +112,12 @@ inotify_simple==1.2.1
# via sagemaker-training
ipykernel==5.5.5
# via flytekit
ipython==7.24.1
# via ipykernel
ipython-genutils==0.2.0
# via
# nbformat
# traitlets
ipython==7.24.1
# via ipykernel
jedi==0.18.0
# via ipython
jinja2==3.0.1
Expand Down Expand Up @@ -152,15 +152,15 @@ lxml==4.6.3
# via sphinx-material
markupsafe==2.0.1
# via jinja2
marshmallow-enum==1.5.1
# via dataclasses-json
marshmallow-jsonschema==0.12.0
# via flytekit
marshmallow==3.12.1
# via
# dataclasses-json
# marshmallow-enum
# marshmallow-jsonschema
marshmallow-enum==1.5.1
# via dataclasses-json
marshmallow-jsonschema==0.12.0
# via flytekit
matplotlib-inline==0.1.2
# via ipython
mistune==0.8.4
Expand All @@ -175,7 +175,7 @@ nbclient==0.5.3
# via
# nbconvert
# papermill
nbconvert==6.0.7
nbconvert==6.1.0
# via flytekit
nbformat==5.1.3
# via
Expand All @@ -184,7 +184,7 @@ nbformat==5.1.3
# papermill
nest-asyncio==1.5.1
# via nbclient
numpy==1.20.3
numpy==1.21.0
# via
# flytekit
# pandas
Expand All @@ -195,7 +195,7 @@ packaging==20.9
# via
# bleach
# sphinx
pandas==1.2.4
pandas==1.2.5
# via flytekit
pandocfilters==1.4.3
# via nbconvert
Expand All @@ -213,7 +213,7 @@ pexpect==4.8.0
# via ipython
pickleshare==0.7.5
# via ipython
prompt-toolkit==3.0.18
prompt-toolkit==3.0.19
# via ipython
protobuf==3.17.3
# via
Expand All @@ -225,10 +225,10 @@ psutil==5.8.0
# via sagemaker-training
ptyprocess==0.7.0
# via pexpect
py4j==0.10.9
# via pyspark
py==1.10.0
# via retry
py4j==0.10.9
# via pyspark
pyarrow==3.0.0
# via flytekit
pycparser==2.20
Expand Down Expand Up @@ -296,7 +296,7 @@ sagemaker-training==3.9.2
# via flytekit
scantree==0.0.1
# via dirhash
scipy==1.6.3
scipy==1.7.0
# via sagemaker-training
six==1.16.0
# via
Expand All @@ -321,6 +321,17 @@ sortedcontainers==2.4.0
# via flytekit
soupsieve==2.2.1
# via beautifulsoup4
sphinx==3.5.4
# via
# -r doc-requirements.in
# furo
# sphinx-autoapi
# sphinx-code-include
# sphinx-copybutton
# sphinx-fontawesome
# sphinx-gallery
# sphinx-material
# sphinx-prompt
sphinx-autoapi==1.8.1
# via -r doc-requirements.in
sphinx-code-include==1.1.1
Expand All @@ -335,17 +346,6 @@ sphinx-material==0.0.32
# via -r doc-requirements.in
sphinx-prompt==1.4.0
# via -r doc-requirements.in
sphinx==3.5.4
# via
# -r doc-requirements.in
# furo
# sphinx-autoapi
# sphinx-code-include
# sphinx-copybutton
# sphinx-fontawesome
# sphinx-gallery
# sphinx-material
# sphinx-prompt
sphinxcontrib-applehelp==1.0.2
# via sphinx
sphinxcontrib-devhelp==1.0.2
Expand Down
5 changes: 4 additions & 1 deletion flytekit/common/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from flytekit.models.core import workflow as _core_wf
from flytekit.models.core import workflow as workflow_model
from flytekit.models.core.workflow import BranchNode as BranchNodeModel
from flytekit.models.core.workflow import TaskNodeOverrides

FlyteLocalEntity = Union[
PythonTask,
Expand Down Expand Up @@ -272,7 +273,9 @@ def get_serializable_node(
inputs=entity.bindings,
upstream_node_ids=[n.id for n in upstream_sdk_nodes],
output_aliases=[],
task_node=workflow_model.TaskNode(reference_id=task_spec.template.id),
task_node=workflow_model.TaskNode(
reference_id=task_spec.template.id, overrides=TaskNodeOverrides(resources=entity._resources)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These aren't the Resource models though right? They're the internal dataclass. Don't they need to be translated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh i see. can we update the type hint on line 36 of node.py then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes, thanks for the catch

),
)
if entity._aliases:
node_model._output_aliases = entity._aliases
Expand Down
28 changes: 28 additions & 0 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from typing import Any, List

from flytekit.common.utils import _dnsify
from flytekit.core.resources import Resources
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.task import Resources as _resources_model


class Node(object):
Expand All @@ -30,6 +32,7 @@ def __init__(
self._flyte_entity = flyte_entity
self._aliases: _workflow_model.Alias = None
self._outputs = None
self._resources: Resources = None
katrogan marked this conversation as resolved.
Show resolved Hide resolved

def runs_before(self, other: Node):
"""
Expand Down Expand Up @@ -81,4 +84,29 @@ def with_overrides(self, *args, **kwargs):
self._aliases = []
for k, v in alias_dict.items():
self._aliases.append(_workflow_model.Alias(var=k, alias=v))
if "requests" in kwargs or "limits" in kwargs:
requests = _convert_resource_overrides(kwargs.get("requests", Resources()), "requests")
limits = _convert_resource_overrides(kwargs.get("limits", Resources()), "limits")
self._resources = _resources_model(requests=requests, limits=limits)
return self


def _convert_resource_overrides(resources: Resources, resource_name: str) -> [_resources_model.ResourceEntry]:
if not isinstance(resources, Resources):
raise AssertionError(f"{resource_name} should be specified as flytekit.Resources")
resource_entries = []
if resources.cpu is not None:
resource_entries.append(_resources_model.ResourceEntry(_resources_model.ResourceName.CPU, resources.cpu))

if resources.mem is not None:
resource_entries.append(_resources_model.ResourceEntry(_resources_model.ResourceName.MEMORY, resources.mem))

if resources.gpu is not None:
resource_entries.append(_resources_model.ResourceEntry(_resources_model.ResourceName.GPU, resources.gpu))

if resources.storage is not None:
resource_entries.append(
_resources_model.ResourceEntry(_resources_model.ResourceName.STORAGE, resources.storage)
)

return resource_entries
Loading