diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index bf549860586..13d1a6e7316 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -5,6 +5,7 @@ from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.context_manager import SerializationSettings from flytekit.core.interface import Interface +from flytekit.core.resources import Resources, ResourceSpec from flytekit.models import task as _task_model @@ -31,6 +32,8 @@ def __init__( metadata: Optional[TaskMetadata] = None, arguments: List[str] = None, outputs: Dict[str, Type] = None, + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, input_data_dir: str = None, output_data_dir: str = None, metadata_format: MetadataFormat = MetadataFormat.JSON, @@ -52,6 +55,13 @@ def __init__( self._output_data_dir = output_data_dir self._md_format = metadata_format self._io_strategy = io_strategy + self._resources = ResourceSpec( + requests=requests if requests else Resources(), limits=limits if limits else Resources() + ) + + @property + def resources(self) -> ResourceSpec: + return self._resources def execute(self, **kwargs) -> Any: print(kwargs) @@ -78,4 +88,8 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe io_strategy=self._io_strategy.value if self._io_strategy else None, ), environment=env, + cpu_request=self.resources.requests.cpu, + cpu_limit=self.resources.limits.cpu, + memory_request=self.resources.requests.mem, + memory_limit=self.resources.limits.mem, ) diff --git a/tests/flytekit/unit/common_tests/test_translator.py b/tests/flytekit/unit/common_tests/test_translator.py index 70927fbfdfc..7dec6bbe8a6 100644 --- a/tests/flytekit/unit/common_tests/test_translator.py +++ b/tests/flytekit/unit/common_tests/test_translator.py @@ -1,7 +1,7 @@ import typing from collections import OrderedDict -from flytekit import ContainerTask +from flytekit import ContainerTask, Resources from flytekit.common.translator import get_serializable from flytekit.core import context_manager from flytekit.core.base_task import kwtypes @@ -92,6 +92,7 @@ def t1(a: int) -> (int, str): output_data_dir="/tmp", command=["cat"], arguments=["/tmp/a"], + requests=Resources(mem="400Mi", cpu="1"), ) sdk_task = get_serializable(OrderedDict(), serialization_settings, t2, fast=True)