Skip to content

Commit

Permalink
Add node resource overrides (#523)
Browse files Browse the repository at this point in the history
  • Loading branch information
Katrina Rogan authored Jun 25, 2021
1 parent 7667a15 commit 72f342f
Show file tree
Hide file tree
Showing 10 changed files with 240 additions and 96 deletions.
50 changes: 25 additions & 25 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
# make dev-requirements.txt
Expand Down Expand Up @@ -75,6 +75,8 @@ dirhash==0.2.1
# flytekit
distro==1.5.0
# via docker-compose
docker[ssh]==5.0.0
# via docker-compose
docker-compose==1.29.2
# via
# pytest-docker
Expand All @@ -83,26 +85,24 @@ docker-image-py==0.1.10
# via
# -c requirements.txt
# flytekit
docker[ssh]==5.0.0
# via docker-compose
dockerpty==0.4.1
# via docker-compose
docopt==0.6.2
# via docker-compose
flake8-black==0.2.1
# via -r dev-requirements.in
flake8-isort==4.0.0
# via -r dev-requirements.in
flake8==3.9.2
# via
# -r dev-requirements.in
# flake8-black
# flake8-isort
flyteidl==0.19.2
flake8-black==0.2.1
# via -r dev-requirements.in
flake8-isort==4.0.0
# via -r dev-requirements.in
flyteidl==0.19.5
# via
# -c requirements.txt
# flytekit
grpcio==1.38.0
grpcio==1.38.1
# via
# -c requirements.txt
# flytekit
Expand All @@ -116,7 +116,7 @@ importlib-metadata==4.5.0
# keyring
iniconfig==1.1.1
# via pytest
isort==5.8.0
isort==5.9.1
# via
# -r dev-requirements.in
# flake8-isort
Expand All @@ -136,6 +136,12 @@ markupsafe==2.0.1
# via
# -c requirements.txt
# jinja2
marshmallow==3.12.1
# via
# -c requirements.txt
# dataclasses-json
# marshmallow-enum
# marshmallow-jsonschema
marshmallow-enum==1.5.1
# via
# -c requirements.txt
Expand All @@ -144,29 +150,23 @@ marshmallow-jsonschema==0.12.0
# via
# -c requirements.txt
# flytekit
marshmallow==3.12.1
# via
# -c requirements.txt
# dataclasses-json
# marshmallow-enum
# marshmallow-jsonschema
mccabe==0.6.1
# via flake8
mock==4.0.3
# via -r dev-requirements.in
mypy==0.910
# via -r dev-requirements.in
mypy-extensions==0.4.3
# via
# -c requirements.txt
# black
# mypy
# typing-inspect
mypy==0.902
# via -r dev-requirements.in
natsort==7.1.1
# via
# -c requirements.txt
# flytekit
numpy==1.20.3
numpy==1.21.0
# via
# -c requirements.txt
# pandas
Expand All @@ -175,7 +175,7 @@ packaging==20.9
# via
# -c requirements.txt
# pytest
pandas==1.2.4
pandas==1.2.5
# via
# -c requirements.txt
# flytekit
Expand Down Expand Up @@ -224,22 +224,22 @@ pyrsistent==0.17.3
# via
# -c requirements.txt
# jsonschema
pytest-docker==0.10.3
# via pytest-flyte
git+git://github.com/flyteorg/pytest-flyte@main#egg=pytest-flyte
# via -r dev-requirements.in
pytest==6.2.4
# via
# -r dev-requirements.in
# pytest-docker
# pytest-flyte
pytest-docker==0.10.3
# via pytest-flyte
git+git://github.com/flyteorg/pytest-flyte@main#egg=pytest-flyte
# via -r dev-requirements.in
python-dateutil==2.8.1
# via
# -c requirements.txt
# croniter
# flytekit
# pandas
python-dotenv==0.17.1
python-dotenv==0.18.0
# via docker-compose
python-json-logger==2.0.1
# via
Expand Down
60 changes: 30 additions & 30 deletions doc-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
# make doc-requirements.txt
Expand All @@ -16,7 +16,7 @@ appnope==0.1.2
# via
# ipykernel
# ipython
astroid==2.5.8
astroid==2.6.0
# via sphinx-autoapi
async-generator==1.10
# via nbclient
Expand All @@ -39,9 +39,9 @@ black==21.6b0
# via papermill
bleach==3.3.0
# via nbconvert
boto3==1.17.94
boto3==1.17.100
# via sagemaker-training
botocore==1.20.94
botocore==1.20.100
# via
# boto3
# s3transfer
Expand Down Expand Up @@ -88,15 +88,15 @@ entrypoints==0.3
# via
# nbconvert
# papermill
flyteidl==0.19.2
flyteidl==0.19.5
# via flytekit
git+git://github.com/flyteorg/furo@main
# via -r doc-requirements.in
gevent==21.1.2
# via sagemaker-training
greenlet==1.1.0
# via gevent
grpcio==1.38.0
grpcio==1.38.1
# via
# -r doc-requirements.in
# flytekit
Expand All @@ -112,12 +112,12 @@ inotify_simple==1.2.1
# via sagemaker-training
ipykernel==5.5.5
# via flytekit
ipython==7.24.1
# via ipykernel
ipython-genutils==0.2.0
# via
# nbformat
# traitlets
ipython==7.24.1
# via ipykernel
jedi==0.18.0
# via ipython
jinja2==3.0.1
Expand Down Expand Up @@ -152,15 +152,15 @@ lxml==4.6.3
# via sphinx-material
markupsafe==2.0.1
# via jinja2
marshmallow-enum==1.5.1
# via dataclasses-json
marshmallow-jsonschema==0.12.0
# via flytekit
marshmallow==3.12.1
# via
# dataclasses-json
# marshmallow-enum
# marshmallow-jsonschema
marshmallow-enum==1.5.1
# via dataclasses-json
marshmallow-jsonschema==0.12.0
# via flytekit
matplotlib-inline==0.1.2
# via ipython
mistune==0.8.4
Expand All @@ -175,7 +175,7 @@ nbclient==0.5.3
# via
# nbconvert
# papermill
nbconvert==6.0.7
nbconvert==6.1.0
# via flytekit
nbformat==5.1.3
# via
Expand All @@ -184,7 +184,7 @@ nbformat==5.1.3
# papermill
nest-asyncio==1.5.1
# via nbclient
numpy==1.20.3
numpy==1.21.0
# via
# flytekit
# pandas
Expand All @@ -195,7 +195,7 @@ packaging==20.9
# via
# bleach
# sphinx
pandas==1.2.4
pandas==1.2.5
# via flytekit
pandocfilters==1.4.3
# via nbconvert
Expand All @@ -213,7 +213,7 @@ pexpect==4.8.0
# via ipython
pickleshare==0.7.5
# via ipython
prompt-toolkit==3.0.18
prompt-toolkit==3.0.19
# via ipython
protobuf==3.17.3
# via
Expand All @@ -225,10 +225,10 @@ psutil==5.8.0
# via sagemaker-training
ptyprocess==0.7.0
# via pexpect
py4j==0.10.9
# via pyspark
py==1.10.0
# via retry
py4j==0.10.9
# via pyspark
pyarrow==3.0.0
# via flytekit
pycparser==2.20
Expand Down Expand Up @@ -296,7 +296,7 @@ sagemaker-training==3.9.2
# via flytekit
scantree==0.0.1
# via dirhash
scipy==1.6.3
scipy==1.7.0
# via sagemaker-training
six==1.16.0
# via
Expand All @@ -321,6 +321,17 @@ sortedcontainers==2.4.0
# via flytekit
soupsieve==2.2.1
# via beautifulsoup4
sphinx==3.5.4
# via
# -r doc-requirements.in
# furo
# sphinx-autoapi
# sphinx-code-include
# sphinx-copybutton
# sphinx-fontawesome
# sphinx-gallery
# sphinx-material
# sphinx-prompt
sphinx-autoapi==1.8.1
# via -r doc-requirements.in
sphinx-code-include==1.1.1
Expand All @@ -335,17 +346,6 @@ sphinx-material==0.0.32
# via -r doc-requirements.in
sphinx-prompt==1.4.0
# via -r doc-requirements.in
sphinx==3.5.4
# via
# -r doc-requirements.in
# furo
# sphinx-autoapi
# sphinx-code-include
# sphinx-copybutton
# sphinx-fontawesome
# sphinx-gallery
# sphinx-material
# sphinx-prompt
sphinxcontrib-applehelp==1.0.2
# via sphinx
sphinxcontrib-devhelp==1.0.2
Expand Down
5 changes: 4 additions & 1 deletion flytekit/common/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from flytekit.models.core import workflow as _core_wf
from flytekit.models.core import workflow as workflow_model
from flytekit.models.core.workflow import BranchNode as BranchNodeModel
from flytekit.models.core.workflow import TaskNodeOverrides

FlyteLocalEntity = Union[
PythonTask,
Expand Down Expand Up @@ -272,7 +273,9 @@ def get_serializable_node(
inputs=entity.bindings,
upstream_node_ids=[n.id for n in upstream_sdk_nodes],
output_aliases=[],
task_node=workflow_model.TaskNode(reference_id=task_spec.template.id),
task_node=workflow_model.TaskNode(
reference_id=task_spec.template.id, overrides=TaskNodeOverrides(resources=entity._resources)
),
)
if entity._aliases:
node_model._output_aliases = entity._aliases
Expand Down
33 changes: 33 additions & 0 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import typing
from typing import Any, List

from flytekit.common.utils import _dnsify
from flytekit.core.resources import Resources
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.task import Resources as _resources_model


class Node(object):
Expand All @@ -30,6 +33,7 @@ def __init__(
self._flyte_entity = flyte_entity
self._aliases: _workflow_model.Alias = None
self._outputs = None
self._resources: typing.Optional[_resources_model] = None

def runs_before(self, other: Node):
"""
Expand Down Expand Up @@ -81,4 +85,33 @@ def with_overrides(self, *args, **kwargs):
self._aliases = []
for k, v in alias_dict.items():
self._aliases.append(_workflow_model.Alias(var=k, alias=v))
if "requests" in kwargs or "limits" in kwargs:
requests = _convert_resource_overrides(kwargs["requests"], "requests")
limits = _convert_resource_overrides(kwargs["limits"], "limits")
self._resources = _resources_model(requests=requests, limits=limits)
return self


def _convert_resource_overrides(
resources: typing.Optional[Resources], resource_name: str
) -> [_resources_model.ResourceEntry]:
if resources is None:
return []
if not isinstance(resources, Resources):
raise AssertionError(f"{resource_name} should be specified as flytekit.Resources")
resource_entries = []
if resources.cpu is not None:
resource_entries.append(_resources_model.ResourceEntry(_resources_model.ResourceName.CPU, resources.cpu))

if resources.mem is not None:
resource_entries.append(_resources_model.ResourceEntry(_resources_model.ResourceName.MEMORY, resources.mem))

if resources.gpu is not None:
resource_entries.append(_resources_model.ResourceEntry(_resources_model.ResourceName.GPU, resources.gpu))

if resources.storage is not None:
resource_entries.append(
_resources_model.ResourceEntry(_resources_model.ResourceName.STORAGE, resources.storage)
)

return resource_entries
Loading

0 comments on commit 72f342f

Please sign in to comment.