diff --git a/flytekit/core/array_node.py b/flytekit/core/array_node.py index 0cb2c8d25c..9e58d8b1d5 100644 --- a/flytekit/core/array_node.py +++ b/flytekit/core/array_node.py @@ -218,8 +218,8 @@ def execution_mode(self) -> _core_workflow.ArrayNode.ExecutionMode: return self._execution_mode @property - def is_original_sub_node_interface(self) -> bool: - return True + def sub_node_interface_status(self) -> _core_workflow.ArrayNode.SubNodeInterfaceStatus: + return _core_workflow.ArrayNode.SUB_NODE_INTERFACE_ORIGINAL def __call__(self, *args, **kwargs): if not self._bindings: diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 1e2af4b3be..3074dc99ae 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -167,8 +167,8 @@ def execution_mode(self) -> _core_workflow.ArrayNode.ExecutionMode: return self._execution_mode @property - def is_original_sub_node_interface(self) -> bool: - return False + def sub_node_interface_status(self) -> _core_workflow.ArrayNode.SubNodeInterfaceStatus: + return _core_workflow.ArrayNode.SUB_NODE_INTERFACE_LIST def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]: return self.python_function_task.get_extended_resources(settings) diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index cc1ca2694f..c61841ad24 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -388,7 +388,7 @@ def __init__( min_successes=None, min_success_ratio=None, execution_mode=None, - is_original_sub_node_interface=False, + sub_node_interface_status=False, ) -> None: """ TODO: docstring @@ -399,7 +399,7 @@ def __init__( self._min_successes = min_successes self._min_success_ratio = min_success_ratio self._execution_mode = execution_mode - self._is_original_sub_node_interface = is_original_sub_node_interface + self._sub_node_interface_status = sub_node_interface_status @property def node(self) -> "Node": @@ -412,7 +412,7 @@ def to_flyte_idl(self) -> _core_workflow.ArrayNode: min_successes=self._min_successes, min_success_ratio=self._min_success_ratio, execution_mode=self._execution_mode, - is_original_sub_node_interface=self._is_original_sub_node_interface, + sub_node_interface_status=self._sub_node_interface_status, ) @classmethod diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index b156404848..cdcdfaca1b 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -595,7 +595,7 @@ def get_serializable_array_node( min_successes=array_node.min_successes, min_success_ratio=array_node.min_success_ratio, execution_mode=array_node.execution_mode, - is_original_sub_node_interface=array_node.is_original_sub_node_interface, + sub_node_interface_status=array_node.sub_node_interface_status, ) @@ -630,7 +630,7 @@ def get_serializable_array_node_map_task( min_successes=entity.min_successes, min_success_ratio=entity.min_success_ratio, execution_mode=entity.execution_mode, - is_original_sub_node_interface=entity.is_original_sub_node_interface, + sub_node_interface_status=entity.sub_node_interface_status, ) diff --git a/tests/flytekit/unit/core/test_array_node.py b/tests/flytekit/unit/core/test_array_node.py index 719e067b00..71d22789d0 100644 --- a/tests/flytekit/unit/core/test_array_node.py +++ b/tests/flytekit/unit/core/test_array_node.py @@ -2,6 +2,7 @@ from collections import OrderedDict import pytest +from flyteidl.core import workflow_pb2 as _core_workflow from flytekit import LaunchPlan, task, workflow from flytekit.core.context_manager import FlyteContextManager @@ -95,6 +96,10 @@ def test_lp_serialization(target, overrides_metadata, serialization_settings): assert len(wf_spec.template.nodes) == 1 parent_node = wf_spec.template.nodes[0] + assert parent_node.array_node._min_success_ratio == 0.9 + assert parent_node.array_node._parallelism == 10 + assert parent_node.array_node._sub_node_interface_status == _core_workflow.ArrayNode.SUB_NODE_INTERFACE_ORIGINAL + assert parent_node.array_node._execution_mode == _core_workflow.ArrayNode.FULL_STATE assert parent_node.inputs[0].var == "a" assert len(parent_node.inputs[0].binding.collection.bindings) == 3 for binding in parent_node.inputs[0].binding.collection.bindings: diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index f025929a13..372b089e62 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -407,7 +407,7 @@ def wf1(x: typing.List[int]): assert array_node.metadata.interruptible assert array_node.array_node._min_success_ratio == 0.9 assert array_node.array_node._parallelism == 10 - assert not array_node.array_node._is_original_sub_node_interface + assert array_node.array_node._sub_node_interface_status == _core_workflow.ArrayNode.SUB_NODE_INTERFACE_LIST assert array_node.array_node._execution_mode == _core_workflow.ArrayNode.MINIMAL_STATE task_spec = od[arraynode_maptask] assert task_spec.template.metadata.retries.retries == 2