diff --git a/flytekit/annotated/python_function_task.py b/flytekit/annotated/python_function_task.py index 48431a8ac3..8a7cf4a088 100644 --- a/flytekit/annotated/python_function_task.py +++ b/flytekit/annotated/python_function_task.py @@ -79,7 +79,7 @@ def __init__( :param task_type: String task type to be associated with this Task :param container_image: String FQN for the image. :param Resources requests: custom resource request settings. - :param Resources requests: custom resource limit settings. + :param Resources limits: custom resource limit settings. """ super().__init__( task_type=task_type, name=name, task_config=task_config, **kwargs, @@ -87,7 +87,7 @@ def __init__( self._container_image = container_image # TODO(katrogan): Implement resource overrides self._resources = ResourceSpec( - requests=requests if requests else Resources(), limits=limits if requests else Resources() + requests=requests if requests else Resources(), limits=limits if limits else Resources() ) self._environment = environment diff --git a/tests/flytekit/unit/annotated/test_type_hints.py b/tests/flytekit/unit/annotated/test_type_hints.py index bf6c7d88af..ac692b1bcc 100644 --- a/tests/flytekit/unit/annotated/test_type_hints.py +++ b/tests/flytekit/unit/annotated/test_type_hints.py @@ -940,6 +940,11 @@ def t1(a: int) -> str: a = a + 2 return "now it's " + str(a) + @task(requests=Resources(cpu="3")) + def t2(a: int) -> str: + a = a + 200 + return "now it's " + str(a) + @workflow def my_wf(a: int) -> str: x = t1(a=a) @@ -964,6 +969,12 @@ def my_wf(a: int) -> str: _resource_models.ResourceEntry(_resource_models.ResourceName.MEMORY, "400M"), ] + sdk_task2 = t2.get_registerable_entity() + assert sdk_task2.container.resources.requests == [ + _resource_models.ResourceEntry(_resource_models.ResourceName.CPU, "3") + ] + assert sdk_task2.container.resources.limits == [] + def test_wf_explicitly_returning_empty_task(): @task diff --git a/tests/flytekit/unit/taskplugins/pytorch/test_pytorch_task.py b/tests/flytekit/unit/taskplugins/pytorch/test_pytorch_task.py index ab0010bf1e..447309dbe2 100644 --- a/tests/flytekit/unit/taskplugins/pytorch/test_pytorch_task.py +++ b/tests/flytekit/unit/taskplugins/pytorch/test_pytorch_task.py @@ -23,6 +23,6 @@ def my_pytorch_task(x: int, y: str) -> int: ) assert my_pytorch_task.get_custom(reg) == {"workers": 10} - assert my_pytorch_task.resources.limits is None + assert my_pytorch_task.resources.limits == Resources() assert my_pytorch_task.resources.requests == Resources(cpu="1") assert my_pytorch_task.task_type == "pytorch" diff --git a/tests/flytekit/unit/taskplugins/tensorflow/test_tensorflow_task.py b/tests/flytekit/unit/taskplugins/tensorflow/test_tensorflow_task.py index 74b9b77386..d15a37adcd 100644 --- a/tests/flytekit/unit/taskplugins/tensorflow/test_tensorflow_task.py +++ b/tests/flytekit/unit/taskplugins/tensorflow/test_tensorflow_task.py @@ -29,6 +29,6 @@ def my_tensorflow_task(x: int, y: str) -> int: ) assert my_tensorflow_task.get_custom(reg) == {"workers": 10, "psReplicas": 1, "chiefReplicas": 1} - assert my_tensorflow_task.resources.limits is None + assert my_tensorflow_task.resources.limits == Resources() assert my_tensorflow_task.resources.requests == Resources(cpu="1") assert my_tensorflow_task.task_type == "tensorflow"