Skip to content

Commit

Permalink
FlyteRemote data context does not get used (#993)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored May 12, 2022
1 parent 0fd7601 commit cc2a4e7
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 19 deletions.
8 changes: 6 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,7 +1441,10 @@ class LiteralsResolver(collections.UserDict):
"""

def __init__(
self, literals: typing.Dict[str, Literal], variable_map: Optional[Dict[str, _interface_models.Variable]] = None
self,
literals: typing.Dict[str, Literal],
variable_map: Optional[Dict[str, _interface_models.Variable]] = None,
ctx: Optional[FlyteContext] = None,
):
"""
:param literals: A Python map of strings to Flyte Literal models.
Expand All @@ -1457,6 +1460,7 @@ def __init__(
self._variable_map = variable_map
self._native_values = {}
self._type_hints = {}
self._ctx = ctx

def __str__(self) -> str:
if len(self._literals) == len(self._native_values):
Expand Down Expand Up @@ -1535,7 +1539,7 @@ def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any:
raise e
else:
ValueError("as_type argument not supplied and Variable map not specified in LiteralsResolver")
val = TypeEngine.to_python_value(FlyteContext.current_context(), self._literals[attr], as_type)
val = TypeEngine.to_python_value(self._ctx or FlyteContext.current_context(), self._literals[attr], as_type)
self._native_values[attr] = val
return val

Expand Down
32 changes: 16 additions & 16 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions
from flytekit.configuration import Config, FastSerializationSettings, ImageConfig, SerializationSettings
from flytekit.core import constants, context_manager, tracker, utils
from flytekit.core import constants, tracker, utils
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.python_auto_container import PythonAutoContainerTask
Expand Down Expand Up @@ -144,10 +144,12 @@ def __init__(
data_config=config.data_config,
)

# Save the file access object locally, but also make it available for use from the context.
FlyteContextManager.with_context(
FlyteContextManager.current_context().with_file_access(self._file_access).build()
)
# Save the file access object locally, build a context for it and save that as well.
self._ctx = FlyteContextManager.current_context().with_file_access(self._file_access).build()

@property
def context(self) -> FlyteContext:
return self._ctx

@property
def client(self) -> SynchronousFlyteClient:
Expand Down Expand Up @@ -426,7 +428,7 @@ def _serialize_and_register(
remote_logger.info(f"{entity.name} already exists")
# Let us also create a default launch-plan, ideally the default launchplan should be added
# to the orderedDict, but we do not.
default_lp = LaunchPlan.get_default_launch_plan(FlyteContextManager.current_context(), entity)
default_lp = LaunchPlan.get_default_launch_plan(self.context, entity)
lp_entity = get_serializable_launch_plan(
OrderedDict(),
settings or serialization_settings,
Expand Down Expand Up @@ -495,7 +497,7 @@ def register_workflow(
serialization_settings = b.build()
ident = self._serialize_and_register(entity, serialization_settings, version, options)
if default_launch_plan:
default_lp = LaunchPlan.get_default_launch_plan(FlyteContextManager.current_context(), entity)
default_lp = LaunchPlan.get_default_launch_plan(self.context, entity)
self.register_launch_plan(
default_lp, version=ident.version, project=ident.project, domain=ident.domain, options=options
)
Expand Down Expand Up @@ -999,12 +1001,11 @@ def execute_local_workflow(
raise ValueError("Need image config since we are registering")
self.register_workflow(entity, ss, version=version, options=options)

ctx = context_manager.FlyteContext.current_context()
try:
flyte_lp = self.fetch_launch_plan(**resolved_identifiers_dict)
except FlyteEntityNotExistException:
remote_logger.info("Try to register default launch plan because it wasn't found in Flyte Admin!")
default_lp = LaunchPlan.get_default_launch_plan(ctx, entity)
default_lp = LaunchPlan.get_default_launch_plan(self.context, entity)
self.register_launch_plan(
default_lp,
project=resolved_identifiers.project,
Expand Down Expand Up @@ -1378,13 +1379,12 @@ def _assign_inputs_and_outputs(
interface: TypedInterface,
):
"""Helper for assigning synced inputs and outputs to an execution object."""
with self.remote_context():
input_literal_map = self._get_input_literal_map(execution_data)
execution._inputs = LiteralsResolver(input_literal_map.literals, interface.inputs)
input_literal_map = self._get_input_literal_map(execution_data)
execution._inputs = LiteralsResolver(input_literal_map.literals, interface.inputs, self.context)

if execution.is_done and not execution.error:
output_literal_map = self._get_output_literal_map(execution_data)
execution._outputs = LiteralsResolver(output_literal_map.literals, interface.outputs)
if execution.is_done and not execution.error:
output_literal_map = self._get_output_literal_map(execution_data)
execution._outputs = LiteralsResolver(output_literal_map.literals, interface.outputs, self.context)
return execution

def _get_input_literal_map(self, execution_data: ExecutionDataResponse) -> literal_models.LiteralMap:
Expand Down
3 changes: 2 additions & 1 deletion tests/flytekit/unit/core/test_literals_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def test_interface():
"my_df": interface_models.Variable(type=df_literal_type, description=""),
}

lr = LiteralsResolver(lm, variable_map=variable_map)
lr = LiteralsResolver(lm, variable_map=variable_map, ctx=ctx)
assert lr._ctx is ctx

with pytest.raises(ValueError):
lr["not"] # noqa
Expand Down

0 comments on commit cc2a4e7

Please sign in to comment.