Skip to content

Commit

Permalink
add requests and limits parameter to ContainerTask (#438)
Browse files Browse the repository at this point in the history
* add requests and limits parameter

Signed-off-by: Miguel Toledo <[email protected]>

* fix typo

Signed-off-by: Miguel Toledo <[email protected]>

* signoff

Signed-off-by: Miguel Toledo <[email protected]>
  • Loading branch information
migueltol22 authored Mar 29, 2021
1 parent 12cae6a commit 36536b8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
14 changes: 14 additions & 0 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from flytekit.core.base_task import PythonTask, TaskMetadata
from flytekit.core.context_manager import SerializationSettings
from flytekit.core.interface import Interface
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.models import task as _task_model


Expand All @@ -31,6 +32,8 @@ def __init__(
metadata: Optional[TaskMetadata] = None,
arguments: List[str] = None,
outputs: Dict[str, Type] = None,
requests: Optional[Resources] = None,
limits: Optional[Resources] = None,
input_data_dir: str = None,
output_data_dir: str = None,
metadata_format: MetadataFormat = MetadataFormat.JSON,
Expand All @@ -52,6 +55,13 @@ def __init__(
self._output_data_dir = output_data_dir
self._md_format = metadata_format
self._io_strategy = io_strategy
self._resources = ResourceSpec(
requests=requests if requests else Resources(), limits=limits if limits else Resources()
)

@property
def resources(self) -> ResourceSpec:
return self._resources

def execute(self, **kwargs) -> Any:
print(kwargs)
Expand All @@ -78,4 +88,8 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe
io_strategy=self._io_strategy.value if self._io_strategy else None,
),
environment=env,
cpu_request=self.resources.requests.cpu,
cpu_limit=self.resources.limits.cpu,
memory_request=self.resources.requests.mem,
memory_limit=self.resources.limits.mem,
)
3 changes: 2 additions & 1 deletion tests/flytekit/unit/common_tests/test_translator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing
from collections import OrderedDict

from flytekit import ContainerTask
from flytekit import ContainerTask, Resources
from flytekit.common.translator import get_serializable
from flytekit.core import context_manager
from flytekit.core.base_task import kwtypes
Expand Down Expand Up @@ -92,6 +92,7 @@ def t1(a: int) -> (int, str):
output_data_dir="/tmp",
command=["cat"],
arguments=["/tmp/a"],
requests=Resources(mem="400Mi", cpu="1"),
)

sdk_task = get_serializable(OrderedDict(), serialization_settings, t2, fast=True)
Expand Down

0 comments on commit 36536b8

Please sign in to comment.