diff --git a/dev-requirements.txt b/dev-requirements.txt index b69068da2a..b975007b60 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile +# This file is autogenerated by pip-compile with python 3.8 # To update, run: # # make dev-requirements.txt @@ -75,6 +75,8 @@ dirhash==0.2.1 # flytekit distro==1.5.0 # via docker-compose +docker[ssh]==5.0.0 + # via docker-compose docker-compose==1.29.2 # via # pytest-docker @@ -83,26 +85,24 @@ docker-image-py==0.1.10 # via # -c requirements.txt # flytekit -docker[ssh]==5.0.0 - # via docker-compose dockerpty==0.4.1 # via docker-compose docopt==0.6.2 # via docker-compose -flake8-black==0.2.1 - # via -r dev-requirements.in -flake8-isort==4.0.0 - # via -r dev-requirements.in flake8==3.9.2 # via # -r dev-requirements.in # flake8-black # flake8-isort -flyteidl==0.19.2 +flake8-black==0.2.1 + # via -r dev-requirements.in +flake8-isort==4.0.0 + # via -r dev-requirements.in +flyteidl==0.19.5 # via # -c requirements.txt # flytekit -grpcio==1.38.0 +grpcio==1.38.1 # via # -c requirements.txt # flytekit @@ -116,7 +116,7 @@ importlib-metadata==4.5.0 # keyring iniconfig==1.1.1 # via pytest -isort==5.8.0 +isort==5.9.1 # via # -r dev-requirements.in # flake8-isort @@ -136,6 +136,12 @@ markupsafe==2.0.1 # via # -c requirements.txt # jinja2 +marshmallow==3.12.1 + # via + # -c requirements.txt + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema marshmallow-enum==1.5.1 # via # -c requirements.txt @@ -144,29 +150,23 @@ marshmallow-jsonschema==0.12.0 # via # -c requirements.txt # flytekit -marshmallow==3.12.1 - # via - # -c requirements.txt - # dataclasses-json - # marshmallow-enum - # marshmallow-jsonschema mccabe==0.6.1 # via flake8 mock==4.0.3 # via -r dev-requirements.in +mypy==0.910 + # via -r dev-requirements.in mypy-extensions==0.4.3 # via # -c requirements.txt # black # mypy # typing-inspect -mypy==0.902 - # via -r dev-requirements.in natsort==7.1.1 # via # -c requirements.txt # flytekit -numpy==1.20.3 +numpy==1.21.0 # via # -c requirements.txt # pandas @@ -175,7 +175,7 @@ packaging==20.9 # via # -c requirements.txt # pytest -pandas==1.2.4 +pandas==1.2.5 # via # -c requirements.txt # flytekit @@ -224,22 +224,22 @@ pyrsistent==0.17.3 # via # -c requirements.txt # jsonschema -pytest-docker==0.10.3 - # via pytest-flyte -git+git://github.com/flyteorg/pytest-flyte@main#egg=pytest-flyte - # via -r dev-requirements.in pytest==6.2.4 # via # -r dev-requirements.in # pytest-docker # pytest-flyte +pytest-docker==0.10.3 + # via pytest-flyte +git+git://github.com/flyteorg/pytest-flyte@main#egg=pytest-flyte + # via -r dev-requirements.in python-dateutil==2.8.1 # via # -c requirements.txt # croniter # flytekit # pandas -python-dotenv==0.17.1 +python-dotenv==0.18.0 # via docker-compose python-json-logger==2.0.1 # via diff --git a/doc-requirements.txt b/doc-requirements.txt index a2cbda0117..3f68e03462 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile +# This file is autogenerated by pip-compile with python 3.8 # To update, run: # # make doc-requirements.txt @@ -16,7 +16,7 @@ appnope==0.1.2 # via # ipykernel # ipython -astroid==2.5.8 +astroid==2.6.0 # via sphinx-autoapi async-generator==1.10 # via nbclient @@ -39,9 +39,9 @@ black==21.6b0 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.94 +boto3==1.17.100 # via sagemaker-training -botocore==1.20.94 +botocore==1.20.100 # via # boto3 # s3transfer @@ -88,7 +88,7 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.19.2 +flyteidl==0.19.5 # via flytekit git+git://github.com/flyteorg/furo@main # via -r doc-requirements.in @@ -96,7 +96,7 @@ gevent==21.1.2 # via sagemaker-training greenlet==1.1.0 # via gevent -grpcio==1.38.0 +grpcio==1.38.1 # via # -r doc-requirements.in # flytekit @@ -112,12 +112,12 @@ inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.5 # via flytekit +ipython==7.24.1 + # via ipykernel ipython-genutils==0.2.0 # via # nbformat # traitlets -ipython==7.24.1 - # via ipykernel jedi==0.18.0 # via ipython jinja2==3.0.1 @@ -152,15 +152,15 @@ lxml==4.6.3 # via sphinx-material markupsafe==2.0.1 # via jinja2 -marshmallow-enum==1.5.1 - # via dataclasses-json -marshmallow-jsonschema==0.12.0 - # via flytekit marshmallow==3.12.1 # via # dataclasses-json # marshmallow-enum # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.12.0 + # via flytekit matplotlib-inline==0.1.2 # via ipython mistune==0.8.4 @@ -175,7 +175,7 @@ nbclient==0.5.3 # via # nbconvert # papermill -nbconvert==6.0.7 +nbconvert==6.1.0 # via flytekit nbformat==5.1.3 # via @@ -184,7 +184,7 @@ nbformat==5.1.3 # papermill nest-asyncio==1.5.1 # via nbclient -numpy==1.20.3 +numpy==1.21.0 # via # flytekit # pandas @@ -195,7 +195,7 @@ packaging==20.9 # via # bleach # sphinx -pandas==1.2.4 +pandas==1.2.5 # via flytekit pandocfilters==1.4.3 # via nbconvert @@ -213,7 +213,7 @@ pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython -prompt-toolkit==3.0.18 +prompt-toolkit==3.0.19 # via ipython protobuf==3.17.3 # via @@ -225,10 +225,10 @@ psutil==5.8.0 # via sagemaker-training ptyprocess==0.7.0 # via pexpect -py4j==0.10.9 - # via pyspark py==1.10.0 # via retry +py4j==0.10.9 + # via pyspark pyarrow==3.0.0 # via flytekit pycparser==2.20 @@ -296,7 +296,7 @@ sagemaker-training==3.9.2 # via flytekit scantree==0.0.1 # via dirhash -scipy==1.6.3 +scipy==1.7.0 # via sagemaker-training six==1.16.0 # via @@ -321,6 +321,17 @@ sortedcontainers==2.4.0 # via flytekit soupsieve==2.2.1 # via beautifulsoup4 +sphinx==3.5.4 + # via + # -r doc-requirements.in + # furo + # sphinx-autoapi + # sphinx-code-include + # sphinx-copybutton + # sphinx-fontawesome + # sphinx-gallery + # sphinx-material + # sphinx-prompt sphinx-autoapi==1.8.1 # via -r doc-requirements.in sphinx-code-include==1.1.1 @@ -335,17 +346,6 @@ sphinx-material==0.0.32 # via -r doc-requirements.in sphinx-prompt==1.4.0 # via -r doc-requirements.in -sphinx==3.5.4 - # via - # -r doc-requirements.in - # furo - # sphinx-autoapi - # sphinx-code-include - # sphinx-copybutton - # sphinx-fontawesome - # sphinx-gallery - # sphinx-material - # sphinx-prompt sphinxcontrib-applehelp==1.0.2 # via sphinx sphinxcontrib-devhelp==1.0.2 diff --git a/flytekit/common/translator.py b/flytekit/common/translator.py index dea8c9f3bc..73ef4c68bd 100644 --- a/flytekit/common/translator.py +++ b/flytekit/common/translator.py @@ -21,6 +21,7 @@ from flytekit.models.core import workflow as _core_wf from flytekit.models.core import workflow as workflow_model from flytekit.models.core.workflow import BranchNode as BranchNodeModel +from flytekit.models.core.workflow import TaskNodeOverrides FlyteLocalEntity = Union[ PythonTask, @@ -272,7 +273,9 @@ def get_serializable_node( inputs=entity.bindings, upstream_node_ids=[n.id for n in upstream_sdk_nodes], output_aliases=[], - task_node=workflow_model.TaskNode(reference_id=task_spec.template.id), + task_node=workflow_model.TaskNode( + reference_id=task_spec.template.id, overrides=TaskNodeOverrides(resources=entity._resources) + ), ) if entity._aliases: node_model._output_aliases = entity._aliases diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 15ce1bca7c..24c6505f08 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -1,10 +1,13 @@ from __future__ import annotations +import typing from typing import Any, List from flytekit.common.utils import _dnsify +from flytekit.core.resources import Resources from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model +from flytekit.models.task import Resources as _resources_model class Node(object): @@ -30,6 +33,7 @@ def __init__( self._flyte_entity = flyte_entity self._aliases: _workflow_model.Alias = None self._outputs = None + self._resources: typing.Optional[_resources_model] = None def runs_before(self, other: Node): """ @@ -81,4 +85,33 @@ def with_overrides(self, *args, **kwargs): self._aliases = [] for k, v in alias_dict.items(): self._aliases.append(_workflow_model.Alias(var=k, alias=v)) + if "requests" in kwargs or "limits" in kwargs: + requests = _convert_resource_overrides(kwargs["requests"], "requests") + limits = _convert_resource_overrides(kwargs["limits"], "limits") + self._resources = _resources_model(requests=requests, limits=limits) return self + + +def _convert_resource_overrides( + resources: typing.Optional[Resources], resource_name: str +) -> [_resources_model.ResourceEntry]: + if resources is None: + return [] + if not isinstance(resources, Resources): + raise AssertionError(f"{resource_name} should be specified as flytekit.Resources") + resource_entries = [] + if resources.cpu is not None: + resource_entries.append(_resources_model.ResourceEntry(_resources_model.ResourceName.CPU, resources.cpu)) + + if resources.mem is not None: + resource_entries.append(_resources_model.ResourceEntry(_resources_model.ResourceName.MEMORY, resources.mem)) + + if resources.gpu is not None: + resource_entries.append(_resources_model.ResourceEntry(_resources_model.ResourceName.GPU, resources.gpu)) + + if resources.storage is not None: + resource_entries.append( + _resources_model.ResourceEntry(_resources_model.ResourceName.STORAGE, resources.storage) + ) + + return resource_entries diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 9bec4768a8..8dc3ac9a70 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -1,3 +1,5 @@ +import typing + from flyteidl.core import workflow_pb2 as _core_workflow from flytekit.models import common as _common @@ -7,6 +9,7 @@ from flytekit.models.core import identifier as _identifier from flytekit.models.literals import Binding as _Binding from flytekit.models.literals import RetryStrategy as _RetryStrategy +from flytekit.models.task import Resources class IfBlock(_common.FlyteIdlEntity): @@ -370,16 +373,39 @@ def from_flyte_idl(cls, pb2_object): ) +class TaskNodeOverrides(_common.FlyteIdlEntity): + def __init__(self, resources: typing.Optional[Resources] = None): + self._resources = resources + + @property + def resources(self) -> Resources: + return self._resources + + def to_flyte_idl(self): + return _core_workflow.TaskNodeOverrides( + resources=self.resources.to_flyte_idl() if self.resources is not None else None, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + resources = Resources.from_flyte_idl(pb2_object.resources) + if bool(resources.requests) or bool(resources.limits): + return cls(resources=resources) + return cls(resources=None) + + class TaskNode(_common.FlyteIdlEntity): - def __init__(self, reference_id): + def __init__(self, reference_id, overrides: typing.Optional[TaskNodeOverrides] = None): """ Refers to the task that the Node is to execute. NB: This is currently a oneof in protobuf, but there's only one option currently. This code should be updated when more options are available. :param flytekit.models.core.identifier.Identifier reference_id: A globally unique identifier for the task. + :param flyteidl.core.workflow_pb2.TaskNodeOverrides """ self._reference_id = reference_id + self._overrides = overrides @property def reference_id(self): @@ -389,11 +415,18 @@ def reference_id(self): """ return self._reference_id + @property + def overrides(self) -> TaskNodeOverrides: + return self._overrides + def to_flyte_idl(self): """ :rtype: flyteidl.core.workflow_pb2.TaskNode """ - return _core_workflow.TaskNode(reference_id=self.reference_id.to_flyte_idl()) + return _core_workflow.TaskNode( + reference_id=self.reference_id.to_flyte_idl(), + overrides=self.overrides.to_flyte_idl() if self.overrides is not None else None, + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -401,7 +434,13 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.workflow_pb2.TaskNode pb2_object: :rtype: TaskNode """ - return cls(reference_id=_identifier.Identifier.from_flyte_idl(pb2_object.reference_id)) + overrides = TaskNodeOverrides.from_flyte_idl(pb2_object.overrides) + if overrides.resources is None: + overrides = None + return cls( + reference_id=_identifier.Identifier.from_flyte_idl(pb2_object.reference_id), + overrides=overrides, + ) class WorkflowNode(_common.FlyteIdlEntity): diff --git a/requirements-spark2.txt b/requirements-spark2.txt index f28f4cd5af..9c4a827a96 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile +# This file is autogenerated by pip-compile with python 3.8 # To update, run: # # make requirements-spark2.txt @@ -29,9 +29,9 @@ black==21.6b0 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.94 +boto3==1.17.100 # via sagemaker-training -botocore==1.20.94 +botocore==1.20.100 # via # boto3 # s3transfer @@ -72,13 +72,13 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.19.2 +flyteidl==0.19.5 # via flytekit gevent==21.1.2 # via sagemaker-training greenlet==1.1.0 # via gevent -grpcio==1.38.0 +grpcio==1.38.1 # via flytekit hmsclient==0.1.1 # via flytekit @@ -90,12 +90,12 @@ inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.5 # via flytekit +ipython==7.24.1 + # via ipykernel ipython-genutils==0.2.0 # via # nbformat # traitlets -ipython==7.24.1 - # via ipykernel jedi==0.18.0 # via ipython jinja2==3.0.1 @@ -123,15 +123,15 @@ keyring==23.0.1 # via flytekit markupsafe==2.0.1 # via jinja2 -marshmallow-enum==1.5.1 - # via dataclasses-json -marshmallow-jsonschema==0.12.0 - # via flytekit marshmallow==3.12.1 # via # dataclasses-json # marshmallow-enum # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.12.0 + # via flytekit matplotlib-inline==0.1.2 # via ipython mistune==0.8.4 @@ -146,7 +146,7 @@ nbclient==0.5.3 # via # nbconvert # papermill -nbconvert==6.0.7 +nbconvert==6.1.0 # via flytekit nbformat==5.1.3 # via @@ -155,7 +155,7 @@ nbformat==5.1.3 # papermill nest-asyncio==1.5.1 # via nbclient -numpy==1.20.3 +numpy==1.21.0 # via # flytekit # pandas @@ -164,7 +164,7 @@ numpy==1.20.3 # scipy packaging==20.9 # via bleach -pandas==1.2.4 +pandas==1.2.5 # via flytekit pandocfilters==1.4.3 # via nbconvert @@ -182,7 +182,7 @@ pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython -prompt-toolkit==3.0.18 +prompt-toolkit==3.0.19 # via ipython protobuf==3.17.3 # via @@ -194,10 +194,10 @@ psutil==5.8.0 # via sagemaker-training ptyprocess==0.7.0 # via pexpect -py4j==0.10.7 - # via pyspark py==1.10.0 # via retry +py4j==0.10.7 + # via pyspark pyarrow==3.0.0 # via flytekit pycparser==2.20 @@ -255,7 +255,7 @@ sagemaker-training==3.9.2 # via flytekit scantree==0.0.1 # via dirhash -scipy==1.6.3 +scipy==1.7.0 # via sagemaker-training six==1.16.0 # via diff --git a/requirements.txt b/requirements.txt index 9a2d2ad643..a210c489bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile +# This file is autogenerated by pip-compile with python 3.8 # To update, run: # # make requirements.txt @@ -29,9 +29,9 @@ black==21.6b0 # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.94 +boto3==1.17.100 # via sagemaker-training -botocore==1.20.94 +botocore==1.20.100 # via # boto3 # s3transfer @@ -72,13 +72,13 @@ entrypoints==0.3 # via # nbconvert # papermill -flyteidl==0.19.2 +flyteidl==0.19.5 # via flytekit gevent==21.1.2 # via sagemaker-training greenlet==1.1.0 # via gevent -grpcio==1.38.0 +grpcio==1.38.1 # via flytekit hmsclient==0.1.1 # via flytekit @@ -90,12 +90,12 @@ inotify_simple==1.2.1 # via sagemaker-training ipykernel==5.5.5 # via flytekit +ipython==7.24.1 + # via ipykernel ipython-genutils==0.2.0 # via # nbformat # traitlets -ipython==7.24.1 - # via ipykernel jedi==0.18.0 # via ipython jinja2==3.0.1 @@ -123,15 +123,15 @@ keyring==23.0.1 # via flytekit markupsafe==2.0.1 # via jinja2 -marshmallow-enum==1.5.1 - # via dataclasses-json -marshmallow-jsonschema==0.12.0 - # via flytekit marshmallow==3.12.1 # via # dataclasses-json # marshmallow-enum # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.12.0 + # via flytekit matplotlib-inline==0.1.2 # via ipython mistune==0.8.4 @@ -146,7 +146,7 @@ nbclient==0.5.3 # via # nbconvert # papermill -nbconvert==6.0.7 +nbconvert==6.1.0 # via flytekit nbformat==5.1.3 # via @@ -155,7 +155,7 @@ nbformat==5.1.3 # papermill nest-asyncio==1.5.1 # via nbclient -numpy==1.20.3 +numpy==1.21.0 # via # flytekit # pandas @@ -164,7 +164,7 @@ numpy==1.20.3 # scipy packaging==20.9 # via bleach -pandas==1.2.4 +pandas==1.2.5 # via flytekit pandocfilters==1.4.3 # via nbconvert @@ -182,7 +182,7 @@ pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython -prompt-toolkit==3.0.18 +prompt-toolkit==3.0.19 # via ipython protobuf==3.17.3 # via @@ -194,10 +194,10 @@ psutil==5.8.0 # via sagemaker-training ptyprocess==0.7.0 # via pexpect -py4j==0.10.9 - # via pyspark py==1.10.0 # via retry +py4j==0.10.9 + # via pyspark pyarrow==3.0.0 # via flytekit pycparser==2.20 @@ -255,7 +255,7 @@ sagemaker-training==3.9.2 # via flytekit scantree==0.0.1 # via dirhash -scipy==1.6.3 +scipy==1.7.0 # via sagemaker-training six==1.16.0 # via diff --git a/setup.py b/setup.py index 69647a491a..5419fc240d 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ ] }, install_requires=[ - "flyteidl>=0.19.2,<1.0.0", + "flyteidl>=0.19.5,<1.0.0", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=2.0.0,<4.0.0", diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index af2fa0ebfc..a6e1c45a4f 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -3,6 +3,7 @@ import pytest +from flytekit import Resources, map_task from flytekit.common.exceptions.user import FlyteAssertion from flytekit.common.translator import get_serializable from flytekit.core import context_manager @@ -11,6 +12,7 @@ from flytekit.core.node_creation import create_node from flytekit.core.task import task from flytekit.core.workflow import workflow +from flytekit.models.task import Resources as _resources_models def test_normal_task(): @@ -162,3 +164,37 @@ def my_wf(a: int, b: str) -> (str, typing.List[str], int): return t2_node.o0, subwf_node.o0, subwf_node.o1 my_wf(a=5, b="hello") + + +def test_resource_overrides(): + @task + def t1(a: str) -> str: + return f"*~*~*~{a}*~*~*~" + + @workflow + def my_wf(a: typing.List[str]) -> typing.List[str]: + mappy = map_task(t1) + map_node = create_node(mappy, a=a).with_overrides( + requests=Resources(cpu="1", mem="100"), limits=Resources(cpu="2", mem="200") + ) + return map_node.o0 + + serialization_settings = context_manager.SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + ) + wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) + assert len(wf_spec.template.nodes) == 1 + assert wf_spec.template.nodes[0].task_node.overrides is not None + assert wf_spec.template.nodes[0].task_node.overrides.resources.requests == [ + _resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "1"), + _resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "100"), + ] + + assert wf_spec.template.nodes[0].task_node.overrides.resources.limits == [ + _resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "2"), + _resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "200"), + ] diff --git a/tests/flytekit/unit/models/core/test_workflow.py b/tests/flytekit/unit/models/core/test_workflow.py index 2cc2e51420..de83f66f78 100644 --- a/tests/flytekit/unit/models/core/test_workflow.py +++ b/tests/flytekit/unit/models/core/test_workflow.py @@ -6,6 +6,7 @@ from flytekit.models.core import condition as _condition from flytekit.models.core import identifier as _identifier from flytekit.models.core import workflow as _workflow +from flytekit.models.task import Resources _generic_id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version") @@ -225,3 +226,35 @@ def test_branch_node(): bn2 = _workflow.BranchNode.from_flyte_idl(bn.to_flyte_idl()) assert bn == bn2 assert bn.if_else.case.then_node == obj + + +def test_task_node_overrides(): + overrides = _workflow.TaskNodeOverrides( + Resources( + requests=[Resources.ResourceEntry(Resources.ResourceName.CPU, "1")], + limits=[Resources.ResourceEntry(Resources.ResourceName.CPU, "2")], + ) + ) + assert overrides.resources.requests == [Resources.ResourceEntry(Resources.ResourceName.CPU, "1")] + assert overrides.resources.limits == [Resources.ResourceEntry(Resources.ResourceName.CPU, "2")] + + obj = _workflow.TaskNodeOverrides.from_flyte_idl(overrides.to_flyte_idl()) + assert overrides == obj + + +def test_task_node_with_overrides(): + task_node = _workflow.TaskNode( + reference_id=_generic_id, + overrides=_workflow.TaskNodeOverrides( + Resources( + requests=[Resources.ResourceEntry(Resources.ResourceName.CPU, "1")], + limits=[Resources.ResourceEntry(Resources.ResourceName.CPU, "2")], + ) + ), + ) + + assert task_node.overrides.resources.requests == [Resources.ResourceEntry(Resources.ResourceName.CPU, "1")] + assert task_node.overrides.resources.limits == [Resources.ResourceEntry(Resources.ResourceName.CPU, "2")] + + obj = _workflow.TaskNode.from_flyte_idl(task_node.to_flyte_idl()) + assert task_node == obj