Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FlyteRemote data context does not get used #993

Merged
merged 12 commits into from
May 12, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,7 +1441,10 @@ class LiteralsResolver(collections.UserDict):
"""

def __init__(
self, literals: typing.Dict[str, Literal], variable_map: Optional[Dict[str, _interface_models.Variable]] = None
self,
literals: typing.Dict[str, Literal],
variable_map: Optional[Dict[str, _interface_models.Variable]] = None,
ctx: Optional[FlyteContext] = None,
):
"""
:param literals: A Python map of strings to Flyte Literal models.
Expand All @@ -1457,6 +1460,7 @@ def __init__(
self._variable_map = variable_map
self._native_values = {}
self._type_hints = {}
self._ctx = ctx

def __str__(self) -> str:
if len(self._literals) == len(self._native_values):
Expand Down Expand Up @@ -1535,7 +1539,7 @@ def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any:
raise e
else:
ValueError("as_type argument not supplied and Variable map not specified in LiteralsResolver")
val = TypeEngine.to_python_value(FlyteContext.current_context(), self._literals[attr], as_type)
val = TypeEngine.to_python_value(self._ctx or FlyteContext.current_context(), self._literals[attr], as_type)
self._native_values[attr] = val
return val

Expand Down
33 changes: 16 additions & 17 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
from datetime import datetime, timedelta

from flyteidl.core import literals_pb2 as literals_pb2

from flytekit import Literal
from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions
from flytekit.configuration import Config, FastSerializationSettings, ImageConfig, SerializationSettings
from flytekit.core import constants, context_manager, tracker, utils
from flytekit.core import constants, tracker, utils
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.python_auto_container import PythonAutoContainerTask
Expand Down Expand Up @@ -144,10 +143,12 @@ def __init__(
data_config=config.data_config,
)

# Save the file access object locally, but also make it available for use from the context.
FlyteContextManager.with_context(
FlyteContextManager.current_context().with_file_access(self._file_access).build()
)
# Save the file access object locally, build a context for it and save that as well.
self._ctx = FlyteContextManager.current_context().with_file_access(self._file_access).build()

@property
def context(self) -> FlyteContext:
return self._ctx

@property
def client(self) -> SynchronousFlyteClient:
Expand Down Expand Up @@ -426,7 +427,7 @@ def _serialize_and_register(
remote_logger.info(f"{entity.name} already exists")
# Let us also create a default launch-plan, ideally the default launchplan should be added
# to the orderedDict, but we do not.
default_lp = LaunchPlan.get_default_launch_plan(FlyteContextManager.current_context(), entity)
default_lp = LaunchPlan.get_default_launch_plan(self.context, entity)
lp_entity = get_serializable_launch_plan(
OrderedDict(),
settings or serialization_settings,
Expand Down Expand Up @@ -495,7 +496,7 @@ def register_workflow(
serialization_settings = b.build()
ident = self._serialize_and_register(entity, serialization_settings, version, options)
if default_launch_plan:
default_lp = LaunchPlan.get_default_launch_plan(FlyteContextManager.current_context(), entity)
default_lp = LaunchPlan.get_default_launch_plan(self.context, entity)
self.register_launch_plan(
default_lp, version=ident.version, project=ident.project, domain=ident.domain, options=options
)
Expand Down Expand Up @@ -999,12 +1000,11 @@ def execute_local_workflow(
raise ValueError("Need image config since we are registering")
self.register_workflow(entity, ss, version=version, options=options)

ctx = context_manager.FlyteContext.current_context()
try:
flyte_lp = self.fetch_launch_plan(**resolved_identifiers_dict)
except FlyteEntityNotExistException:
remote_logger.info("Try to register default launch plan because it wasn't found in Flyte Admin!")
default_lp = LaunchPlan.get_default_launch_plan(ctx, entity)
default_lp = LaunchPlan.get_default_launch_plan(self.context, entity)
self.register_launch_plan(
default_lp,
project=resolved_identifiers.project,
Expand Down Expand Up @@ -1378,13 +1378,12 @@ def _assign_inputs_and_outputs(
interface: TypedInterface,
):
"""Helper for assigning synced inputs and outputs to an execution object."""
with self.remote_context():
input_literal_map = self._get_input_literal_map(execution_data)
execution._inputs = LiteralsResolver(input_literal_map.literals, interface.inputs)
input_literal_map = self._get_input_literal_map(execution_data)
execution._inputs = LiteralsResolver(input_literal_map.literals, interface.inputs, self.context)

if execution.is_done and not execution.error:
output_literal_map = self._get_output_literal_map(execution_data)
execution._outputs = LiteralsResolver(output_literal_map.literals, interface.outputs)
if execution.is_done and not execution.error:
output_literal_map = self._get_output_literal_map(execution_data)
execution._outputs = LiteralsResolver(output_literal_map.literals, interface.outputs, self.context)
return execution

def _get_input_literal_map(self, execution_data: ExecutionDataResponse) -> literal_models.LiteralMap:
Expand Down
3 changes: 2 additions & 1 deletion tests/flytekit/unit/core/test_literals_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def test_interface():
"my_df": interface_models.Variable(type=df_literal_type, description=""),
}

lr = LiteralsResolver(lm, variable_map=variable_map)
lr = LiteralsResolver(lm, variable_map=variable_map, ctx=ctx)
assert lr._ctx is ctx

with pytest.raises(ValueError):
lr["not"] # noqa
Expand Down