
Add support for env vars to pyflyte run (#1617)
* Add support for env vars to pyflyte run

Signed-off-by: Kevin Su <[email protected]>

* bump idl

Signed-off-by: Kevin Su <[email protected]>

* update doc

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
pingsutw authored May 15, 2023
1 parent 8ef79e5 commit b410108
Showing 8 changed files with 112 additions and 45 deletions.
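The change wires two new options into pyflyte run: --overwrite-cache, and --envs, which takes a JSON object of environment variables to set on the execution. A minimal invocation might look like the sketch below; the workflow file, workflow name, and input are placeholders rather than anything from this commit.

pyflyte run --remote \
  --overwrite-cache \
  --envs '{"LOG_LEVEL": "debug", "FOO": "bar"}' \
  example.py my_wf --n 5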
2 changes: 1 addition & 1 deletion doc-requirements.txt
@@ -244,7 +244,7 @@ flask==2.2.3
# via mlflow
flatbuffers==23.1.21
# via tensorflow
flyteidl==1.3.16
flyteidl==1.5.4
# via flytekit
fonttools==4.38.0
# via matplotlib
17 changes: 16 additions & 1 deletion flytekit/clis/sdk_in_container/run.py
@@ -550,12 +550,25 @@ def get_workflow_command_base_params() -> typing.List[click.Option]:
default=False,
help="Whether to dump a code snippet instructing how to load the workflow execution using flyteremote",
),
click.Option(
param_decls=["--overwrite-cache", "overwrite_cache"],
required=False,
is_flag=True,
default=False,
help="Whether to overwrite the cache if it already exists",
),
click.Option(
param_decls=["--envs", "envs"],
required=False,
type=JsonParamType(),
help="Environment variables to set in the container",
),
]


def load_naive_entity(module_name: str, entity_name: str, project_root: str) -> typing.Union[WorkflowBase, PythonTask]:
"""
Load the workflow of a the script file.
Load the workflow of a script file.
N.B.: it assumes that the file is self-contained, in other words, there are no relative imports.
"""
flyte_ctx_builder = context_manager.FlyteContextManager.current_context().new_builder()
@@ -670,6 +683,8 @@ def _run(*args, **kwargs):
wait=run_level_params.get("wait_execution"),
options=options,
type_hints=entity.python_interface.inputs,
overwrite_cache=run_level_params.get("overwrite_cache"),
envs=run_level_params.get("envs"),
)

console_url = remote.generate_console_url(execution)
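For context, --envs relies on a JSON-parsing click parameter type. The following is an illustrative sketch only, using a hand-rolled JsonDictType rather than flytekit's actual JsonParamType, to show how the raw CLI string becomes the dict that _run forwards to remote.execute(..., envs=...).

import json

import click


class JsonDictType(click.ParamType):
    name = "json"

    def convert(self, value, param, ctx):
        # Parse the CLI string into a Python dict, e.g. '{"FOO": "bar"}' -> {"FOO": "bar"}
        try:
            return json.loads(value)
        except json.JSONDecodeError as e:
            self.fail(f"{value!r} is not valid JSON: {e}", param, ctx)


@click.command()
@click.option("--envs", type=JsonDictType(), required=False,
              help="Environment variables to set in the container")
def run(envs):
    # envs arrives as a plain dict, ready to hand to FlyteRemote.execute(..., envs=envs)
    click.echo(envs)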
18 changes: 18 additions & 0 deletions flytekit/models/common.py
@@ -1,8 +1,10 @@
import abc as _abc
import json as _json
import re
from typing import Dict

from flyteidl.admin import common_pb2 as _common_pb2
from flyteidl.core import literals_pb2 as _literals_pb2
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct

@@ -485,3 +487,19 @@ def to_flyte_idl(self):
@classmethod
def from_flyte_idl(cls, pb2):
return cls(output_location_prefix=pb2.output_location_prefix)


class Envs(FlyteIdlEntity):
def __init__(self, envs: Dict[str, str]):
self._envs = envs

@property
def envs(self) -> Dict[str, str]:
return self._envs

def to_flyte_idl(self) -> _common_pb2.Envs:
return _common_pb2.Envs(values=[_literals_pb2.KeyValuePair(key=k, value=v) for k, v in self.envs.items()])

@classmethod
def from_flyte_idl(cls, pb2: _common_pb2.Envs) -> "Envs":
return cls(envs={kv.key: kv.value for kv in pb2.values})
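A quick round trip through the new model, assuming an installed flyteidl release that ships admin.Envs (hence the pin bump to 1.5.4 above):

from flytekit.models.common import Envs

envs = Envs({"LOG_LEVEL": "debug", "FOO": "bar"})
pb = envs.to_flyte_idl()            # admin.Envs carrying repeated KeyValuePair entries
restored = Envs.from_flyte_idl(pb)  # back to the flytekit model
assert restored.envs == {"LOG_LEVEL": "debug", "FOO": "bar"}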
35 changes: 25 additions & 10 deletions flytekit/models/execution.py
@@ -2,6 +2,7 @@

import datetime
import typing
from typing import Optional

import flyteidl
import flyteidl.admin.execution_pb2 as _execution_pb2
@@ -47,10 +48,10 @@ def __init__(
mode: int,
principal: str,
nesting: int,
scheduled_at: typing.Optional[datetime.datetime] = None,
parent_node_execution: typing.Optional[_identifier.NodeExecutionIdentifier] = None,
reference_execution: typing.Optional[_identifier.WorkflowExecutionIdentifier] = None,
system_metadata: typing.Optional[SystemMetadata] = None,
scheduled_at: Optional[datetime.datetime] = None,
parent_node_execution: Optional[_identifier.NodeExecutionIdentifier] = None,
reference_execution: Optional[_identifier.WorkflowExecutionIdentifier] = None,
system_metadata: Optional[SystemMetadata] = None,
):
"""
:param mode: An enum value from ExecutionMetadata.ExecutionMode which specifies how the job started.
@@ -173,9 +174,10 @@ def __init__(
annotations=None,
auth_role=None,
raw_output_data_config=None,
max_parallelism=None,
security_context: typing.Optional[security.SecurityContext] = None,
overwrite_cache: bool = None,
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
overwrite_cache: Optional[bool] = None,
envs: Optional[_common_models.Envs] = None,
):
"""
:param flytekit.models.core.identifier.Identifier launch_plan: Launch plan unique identifier to execute
@@ -189,7 +191,9 @@
:param max_parallelism int: Controls the maximum number of tasknodes that can be run in parallel for the entire
workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and
parallelism/concurrency of MapTasks is independent from this.
:param security_context: Optional security context to use for this execution.
:param overwrite_cache: Optional flag to overwrite the cache for this execution.
:param envs: flytekit.models.common.Envs environment variables to set for this execution.
"""
self._launch_plan = launch_plan
self._metadata = metadata
@@ -201,7 +205,8 @@
self._raw_output_data_config = raw_output_data_config
self._max_parallelism = max_parallelism
self._security_context = security_context
self.overwrite_cache = overwrite_cache
self._overwrite_cache = overwrite_cache
self._envs = envs

@property
def launch_plan(self):
@@ -268,6 +273,14 @@ def max_parallelism(self) -> int:
def security_context(self) -> typing.Optional[security.SecurityContext]:
return self._security_context

@property
def overwrite_cache(self) -> Optional[bool]:
return self._overwrite_cache

@property
def envs(self) -> Optional[_common_models.Envs]:
return self._envs

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.execution_pb2.ExecutionSpec
@@ -276,7 +289,7 @@ def to_flyte_idl(self):
launch_plan=self.launch_plan.to_flyte_idl(),
metadata=self.metadata.to_flyte_idl(),
notifications=self.notifications.to_flyte_idl() if self.notifications else None,
disable_all=self.disable_all,
disable_all=self.disable_all, # type: ignore
labels=self.labels.to_flyte_idl(),
annotations=self.annotations.to_flyte_idl(),
auth_role=self._auth_role.to_flyte_idl() if self.auth_role else None,
@@ -286,6 +299,7 @@
max_parallelism=self.max_parallelism,
security_context=self.security_context.to_flyte_idl() if self.security_context else None,
overwrite_cache=self.overwrite_cache,
envs=self.envs.to_flyte_idl() if self.envs else None,
)

@classmethod
@@ -310,6 +324,7 @@ def from_flyte_idl(cls, p):
if p.security_context
else None,
overwrite_cache=p.overwrite_cache,
envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None,
)


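To see where the new fields land on the wire, here is a hedged sketch that builds an ExecutionSpec by hand; the identifier values and principal are placeholders, and empty Labels/Annotations are supplied only so that to_flyte_idl() can be called directly.

from flytekit.models import common as common_models
from flytekit.models.core.identifier import Identifier, ResourceType
from flytekit.models.execution import ExecutionMetadata, ExecutionSpec

spec = ExecutionSpec(
    launch_plan=Identifier(ResourceType.LAUNCH_PLAN, "proj", "dev", "my_wf", "v1"),
    metadata=ExecutionMetadata(ExecutionMetadata.ExecutionMode.MANUAL, "pyflyte-run", 0),
    labels=common_models.Labels({}),
    annotations=common_models.Annotations({}),
    overwrite_cache=True,
    envs=common_models.Envs({"LOG_LEVEL": "debug"}),
)
pb = spec.to_flyte_idl()  # admin.ExecutionSpec with overwrite_cache and the new envs field populated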
[Diffs for the remaining four changed files are not shown.]
