From db9644f7fefaaa910aada21c927e4d1ef7ac8b10 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Thu, 20 Aug 2020 11:49:26 -0700 Subject: [PATCH 1/3] full execution inputs and outputs data --- flytekit/engines/flyte/engine.py | 66 +++++-- flytekit/models/execution.py | 30 ++- setup.py | 2 +- .../unit/engines/flyte/test_engine.py | 179 +++++++++++++++++- 4 files changed, 251 insertions(+), 26 deletions(-) diff --git a/flytekit/engines/flyte/engine.py b/flytekit/engines/flyte/engine.py index ba68fd7de5..458eca15c8 100644 --- a/flytekit/engines/flyte/engine.py +++ b/flytekit/engines/flyte/engine.py @@ -433,11 +433,16 @@ def get_inputs(self): :rtype: flytekit.models.literals.LiteralMap """ client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - url_blob = client.get_execution_data(self.sdk_workflow_execution.id) - if url_blob.inputs.bytes > 0: + execution_data = client.get_execution_data(self.sdk_workflow_execution.id) + + # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. + if bool(execution_data.full_inputs.literals): + return execution_data.full_inputs + + if execution_data.inputs.bytes > 0: with _common_utils.AutoDeletingTempDir() as t: tmp_name = _os.path.join(t.name, "inputs.pb") - _data_proxy.Data.get_data(url_blob.inputs.url, tmp_name) + _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name) return _literals.LiteralMap.from_flyte_idl( _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) ) @@ -448,11 +453,16 @@ def get_outputs(self): :rtype: flytekit.models.literals.LiteralMap """ client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - url_blob = client.get_execution_data(self.sdk_workflow_execution.id) - if url_blob.outputs.bytes > 0: + execution_data = client.get_execution_data(self.sdk_workflow_execution.id) + + # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. + if bool(execution_data.full_outputs.literals): + return execution_data.full_outputs + + if execution_data.outputs.bytes > 0: with _common_utils.AutoDeletingTempDir() as t: tmp_name = _os.path.join(t.name, "outputs.pb") - _data_proxy.Data.get_data(url_blob.outputs.url, tmp_name) + _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name) return _literals.LiteralMap.from_flyte_idl( _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) ) @@ -486,11 +496,16 @@ def get_inputs(self): :rtype: flytekit.models.literals.LiteralMap """ client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - url_blob = client.get_node_execution_data(self.sdk_node_execution.id) - if url_blob.inputs.bytes > 0: + execution_data = client.get_node_execution_data(self.sdk_node_execution.id) + + # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. + if bool(execution_data.full_inputs.literals): + return execution_data.full_inputs + + if execution_data.inputs.bytes > 0: with _common_utils.AutoDeletingTempDir() as t: tmp_name = _os.path.join(t.name, "inputs.pb") - _data_proxy.Data.get_data(url_blob.inputs.url, tmp_name) + _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name) return _literals.LiteralMap.from_flyte_idl( _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) ) @@ -501,11 +516,16 @@ def get_outputs(self): :rtype: flytekit.models.literals.LiteralMap """ client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - url_blob = client.get_node_execution_data(self.sdk_node_execution.id) - if url_blob.outputs.bytes > 0: + execution_data = client.get_node_execution_data(self.sdk_node_execution.id) + + # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. + if bool(execution_data.full_outputs.literals): + return execution_data.full_outputs + + if execution_data.outputs.bytes > 0: with _common_utils.AutoDeletingTempDir() as t: tmp_name = _os.path.join(t.name, "outputs.pb") - _data_proxy.Data.get_data(url_blob.outputs.url, tmp_name) + _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name) return _literals.LiteralMap.from_flyte_idl( _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) ) @@ -525,11 +545,16 @@ def get_inputs(self): :rtype: flytekit.models.literals.LiteralMap """ client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - url_blob = client.get_task_execution_data(self.sdk_task_execution.id) - if url_blob.inputs.bytes > 0: + execution_data = client.get_task_execution_data(self.sdk_task_execution.id) + + # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. + if bool(execution_data.full_inputs.literals): + return execution_data.full_inputs + + if execution_data.inputs.bytes > 0: with _common_utils.AutoDeletingTempDir() as t: tmp_name = _os.path.join(t.name, "inputs.pb") - _data_proxy.Data.get_data(url_blob.inputs.url, tmp_name) + _data_proxy.Data.get_data(execution_data.inputs.url, tmp_name) return _literals.LiteralMap.from_flyte_idl( _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) ) @@ -540,11 +565,16 @@ def get_outputs(self): :rtype: flytekit.models.literals.LiteralMap """ client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client - url_blob = client.get_task_execution_data(self.sdk_task_execution.id) - if url_blob.outputs.bytes > 0: + execution_data = client.get_task_execution_data(self.sdk_task_execution.id) + + # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. + if bool(execution_data.full_outputs.literals): + return execution_data.full_outputs + + if execution_data.outputs.bytes > 0: with _common_utils.AutoDeletingTempDir() as t: tmp_name = _os.path.join(t.name, "outputs.pb") - _data_proxy.Data.get_data(url_blob.outputs.url, tmp_name) + _data_proxy.Data.get_data(execution_data.outputs.url, tmp_name) return _literals.LiteralMap.from_flyte_idl( _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name) ) diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 2f07d37b28..d7163fcc8f 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -6,6 +6,7 @@ import pytz as _pytz from flytekit.models import common as _common_models +from flytekit.models import literals as _literals_models from flytekit.models.core import execution as _core_execution from flytekit.models.core import identifier as _identifier @@ -382,13 +383,17 @@ class _CommonDataResponse(_common_models.FlyteIdlEntity): superclass to reduce code duplication until things diverge in the future. """ - def __init__(self, inputs, outputs): + def __init__(self, inputs, outputs, full_inputs, full_outputs): """ :param _common_models.UrlBlob inputs: :param _common_models.UrlBlob outputs: + :param _literals_pb2.LiteralMap full_inputs: + :param _literals_pb2.LiteralMap full_outputs: """ self._inputs = inputs self._outputs = outputs + self._full_inputs = full_inputs + self._full_outputs = full_outputs @property def inputs(self): @@ -404,6 +409,20 @@ def outputs(self): """ return self._outputs + @property + def full_inputs(self): + """ + :rtype: _literals_pb2.LiteralMap + """ + return self._full_inputs + + @property + def full_outputs(self): + """ + :rtype: _literals_pb2.LiteralMap + """ + return self._full_outputs + class WorkflowExecutionGetDataResponse(_CommonDataResponse): @classmethod @@ -415,6 +434,8 @@ def from_flyte_idl(cls, pb2_object): return cls( inputs=_common_models.UrlBlob.from_flyte_idl(pb2_object.inputs), outputs=_common_models.UrlBlob.from_flyte_idl(pb2_object.outputs), + full_inputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_inputs), + full_outputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_outputs), ) def to_flyte_idl(self): @@ -423,6 +444,7 @@ def to_flyte_idl(self): """ return _execution_pb2.WorkflowExecutionGetDataResponse( inputs=self.inputs.to_flyte_idl(), outputs=self.outputs.to_flyte_idl(), + full_inputs=self.full_inputs.to_flyte_idl(), full_outputs=self.full_outputs.to_flyte_idl(), ) @@ -436,6 +458,8 @@ def from_flyte_idl(cls, pb2_object): return cls( inputs=_common_models.UrlBlob.from_flyte_idl(pb2_object.inputs), outputs=_common_models.UrlBlob.from_flyte_idl(pb2_object.outputs), + full_inputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_inputs), + full_outputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_outputs), ) def to_flyte_idl(self): @@ -444,6 +468,7 @@ def to_flyte_idl(self): """ return _task_execution_pb2.TaskExecutionGetDataResponse( inputs=self.inputs.to_flyte_idl(), outputs=self.outputs.to_flyte_idl(), + full_inputs=self.full_inputs.to_flyte_idl(), full_outputs=self.full_outputs.to_flyte_idl(), ) @@ -457,6 +482,8 @@ def from_flyte_idl(cls, pb2_object): return cls( inputs=_common_models.UrlBlob.from_flyte_idl(pb2_object.inputs), outputs=_common_models.UrlBlob.from_flyte_idl(pb2_object.outputs), + full_inputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_inputs), + full_outputs=_literals_models.LiteralMap.from_flyte_idl(pb2_object.full_outputs), ) def to_flyte_idl(self): @@ -465,4 +492,5 @@ def to_flyte_idl(self): """ return _node_execution_pb2.NodeExecutionGetDataResponse( inputs=self.inputs.to_flyte_idl(), outputs=self.outputs.to_flyte_idl(), + full_inputs=self.full_inputs.to_flyte_idl(), full_outputs=self.full_outputs.to_flyte_idl(), ) diff --git a/setup.py b/setup.py index cb19deaa8d..be2200642f 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ ] }, install_requires=[ - "flyteidl>=0.18.1,<1.0.0", + "flyteidl>=0.18.2,<1.0.0", "click>=6.6,<8.0", "croniter>=0.3.20,<4.0.0", "deprecated>=1.0,<2.0", diff --git a/tests/flytekit/unit/engines/flyte/test_engine.py b/tests/flytekit/unit/engines/flyte/test_engine.py index e40285e82b..8cdeb46e8c 100644 --- a/tests/flytekit/unit/engines/flyte/test_engine.py +++ b/tests/flytekit/unit/engines/flyte/test_engine.py @@ -25,6 +25,7 @@ _OUTPUT_MAP = literals.LiteralMap( {"b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=2)))} ) +_EMPTY_LITERAL_MAP = literals.LiteralMap(literals={}) @pytest.fixture(scope="function", autouse=True) @@ -251,12 +252,33 @@ def test_fetch_active_launch_plan(mock_client_factory): mock_client.get_active_launch_plan.assert_called_once_with(_common_models.NamedEntityIdentifier("p", "d", "n")) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) +def test_get_full_execution_inputs(mock_client_factory): + mock_client = MagicMock() + mock_client.get_execution_data = MagicMock( + return_value=_execution_models.WorkflowExecutionGetDataResponse( + None, None, _INPUT_MAP, _OUTPUT_MAP, + ) + ) + mock_client_factory.return_value = mock_client + + m = MagicMock() + type(m).id = PropertyMock(return_value=identifier.WorkflowExecutionIdentifier("project", "domain", "name",)) + + inputs = engine.FlyteWorkflowExecution(m).get_inputs() + assert len(inputs.literals) == 1 + assert inputs.literals["a"].scalar.primitive.integer == 1 + mock_client.get_execution_data.assert_called_once_with( + identifier.WorkflowExecutionIdentifier("project", "domain", "name") + ) + + @patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_get_execution_inputs(mock_client_factory, execution_data_locations): mock_client = MagicMock() mock_client.get_execution_data = MagicMock( return_value=_execution_models.WorkflowExecutionGetDataResponse( - execution_data_locations[0], execution_data_locations[1] + execution_data_locations[0], execution_data_locations[1], _EMPTY_LITERAL_MAP, _EMPTY_LITERAL_MAP ) ) mock_client_factory.return_value = mock_client @@ -272,12 +294,33 @@ def test_get_execution_inputs(mock_client_factory, execution_data_locations): ) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) +def test_get_full_execution_outputs(mock_client_factory): + mock_client = MagicMock() + mock_client.get_execution_data = MagicMock( + return_value=_execution_models.WorkflowExecutionGetDataResponse( + None, None, _INPUT_MAP, _OUTPUT_MAP + ) + ) + mock_client_factory.return_value = mock_client + + m = MagicMock() + type(m).id = PropertyMock(return_value=identifier.WorkflowExecutionIdentifier("project", "domain", "name",)) + + outputs = engine.FlyteWorkflowExecution(m).get_outputs() + assert len(outputs.literals) == 1 + assert outputs.literals["b"].scalar.primitive.integer == 2 + mock_client.get_execution_data.assert_called_once_with( + identifier.WorkflowExecutionIdentifier("project", "domain", "name") + ) + + @patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_get_execution_outputs(mock_client_factory, execution_data_locations): mock_client = MagicMock() mock_client.get_execution_data = MagicMock( return_value=_execution_models.WorkflowExecutionGetDataResponse( - execution_data_locations[0], execution_data_locations[1] + execution_data_locations[0], execution_data_locations[1], _EMPTY_LITERAL_MAP, _EMPTY_LITERAL_MAP ) ) mock_client_factory.return_value = mock_client @@ -293,12 +336,39 @@ def test_get_execution_outputs(mock_client_factory, execution_data_locations): ) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) +def test_get_full_node_execution_inputs(mock_client_factory): + mock_client = MagicMock() + mock_client.get_node_execution_data = MagicMock( + return_value=_execution_models.NodeExecutionGetDataResponse( + None, None, _INPUT_MAP, _OUTPUT_MAP, + ) + ) + mock_client_factory.return_value = mock_client + + m = MagicMock() + type(m).id = PropertyMock( + return_value=identifier.NodeExecutionIdentifier( + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + ) + ) + + inputs = engine.FlyteNodeExecution(m).get_inputs() + assert len(inputs.literals) == 1 + assert inputs.literals["a"].scalar.primitive.integer == 1 + mock_client.get_node_execution_data.assert_called_once_with( + identifier.NodeExecutionIdentifier( + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + ) + ) + + @patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_get_node_execution_inputs(mock_client_factory, execution_data_locations): mock_client = MagicMock() mock_client.get_node_execution_data = MagicMock( return_value=_execution_models.NodeExecutionGetDataResponse( - execution_data_locations[0], execution_data_locations[1] + execution_data_locations[0], execution_data_locations[1], _EMPTY_LITERAL_MAP, _EMPTY_LITERAL_MAP ) ) mock_client_factory.return_value = mock_client @@ -320,12 +390,39 @@ def test_get_node_execution_inputs(mock_client_factory, execution_data_locations ) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) +def test_get_full_node_execution_outputs(mock_client_factory): + mock_client = MagicMock() + mock_client.get_node_execution_data = MagicMock( + return_value=_execution_models.NodeExecutionGetDataResponse( + None, None, _INPUT_MAP, _OUTPUT_MAP + ) + ) + mock_client_factory.return_value = mock_client + + m = MagicMock() + type(m).id = PropertyMock( + return_value=identifier.NodeExecutionIdentifier( + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + ) + ) + + outputs = engine.FlyteNodeExecution(m).get_outputs() + assert len(outputs.literals) == 1 + assert outputs.literals["b"].scalar.primitive.integer == 2 + mock_client.get_node_execution_data.assert_called_once_with( + identifier.NodeExecutionIdentifier( + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + ) + ) + + @patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_get_node_execution_outputs(mock_client_factory, execution_data_locations): mock_client = MagicMock() mock_client.get_node_execution_data = MagicMock( return_value=_execution_models.NodeExecutionGetDataResponse( - execution_data_locations[0], execution_data_locations[1] + execution_data_locations[0], execution_data_locations[1], _EMPTY_LITERAL_MAP, _EMPTY_LITERAL_MAP ) ) mock_client_factory.return_value = mock_client @@ -347,12 +444,47 @@ def test_get_node_execution_outputs(mock_client_factory, execution_data_location ) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) +def test_get_full_task_execution_inputs(mock_client_factory): + mock_client = MagicMock() + mock_client.get_task_execution_data = MagicMock( + return_value=_execution_models.TaskExecutionGetDataResponse( + None, None, _INPUT_MAP, _OUTPUT_MAP + ) + ) + mock_client_factory.return_value = mock_client + + m = MagicMock() + type(m).id = PropertyMock( + return_value=identifier.TaskExecutionIdentifier( + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.NodeExecutionIdentifier( + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + ), + 0, + ) + ) + + inputs = engine.FlyteTaskExecution(m).get_inputs() + assert len(inputs.literals) == 1 + assert inputs.literals["a"].scalar.primitive.integer == 1 + mock_client.get_task_execution_data.assert_called_once_with( + identifier.TaskExecutionIdentifier( + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.NodeExecutionIdentifier( + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + ), + 0, + ) + ) + + @patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_get_task_execution_inputs(mock_client_factory, execution_data_locations): mock_client = MagicMock() mock_client.get_task_execution_data = MagicMock( return_value=_execution_models.TaskExecutionGetDataResponse( - execution_data_locations[0], execution_data_locations[1] + execution_data_locations[0], execution_data_locations[1], _EMPTY_LITERAL_MAP, _EMPTY_LITERAL_MAP ) ) mock_client_factory.return_value = mock_client @@ -382,12 +514,47 @@ def test_get_task_execution_inputs(mock_client_factory, execution_data_locations ) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) +def test_get_full_task_execution_outputs(mock_client_factory): + mock_client = MagicMock() + mock_client.get_task_execution_data = MagicMock( + return_value=_execution_models.TaskExecutionGetDataResponse( + None, None, _INPUT_MAP, _OUTPUT_MAP + ) + ) + mock_client_factory.return_value = mock_client + + m = MagicMock() + type(m).id = PropertyMock( + return_value=identifier.TaskExecutionIdentifier( + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.NodeExecutionIdentifier( + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + ), + 0, + ) + ) + + outputs = engine.FlyteTaskExecution(m).get_outputs() + assert len(outputs.literals) == 1 + assert outputs.literals["b"].scalar.primitive.integer == 2 + mock_client.get_task_execution_data.assert_called_once_with( + identifier.TaskExecutionIdentifier( + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.NodeExecutionIdentifier( + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + ), + 0, + ) + ) + + @patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_get_task_execution_outputs(mock_client_factory, execution_data_locations): mock_client = MagicMock() mock_client.get_task_execution_data = MagicMock( return_value=_execution_models.TaskExecutionGetDataResponse( - execution_data_locations[0], execution_data_locations[1] + execution_data_locations[0], execution_data_locations[1], _EMPTY_LITERAL_MAP, _EMPTY_LITERAL_MAP ) ) mock_client_factory.return_value = mock_client From 62448e617bda43512ba6c185e6528a0403a549ba Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Thu, 20 Aug 2020 11:58:46 -0700 Subject: [PATCH 2/3] update another test --- tests/flytekit/unit/models/test_execution.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/flytekit/unit/models/test_execution.py b/tests/flytekit/unit/models/test_execution.py index 62feaf4691..d5170c3619 100644 --- a/tests/flytekit/unit/models/test_execution.py +++ b/tests/flytekit/unit/models/test_execution.py @@ -4,10 +4,18 @@ from flytekit.models import common as _common_models from flytekit.models import execution as _execution +from flytekit.models import literals as _literals from flytekit.models.core import execution as _core_exec from flytekit.models.core import identifier as _identifier from tests.flytekit.common import parameterizers as _parameterizers +_INPUT_MAP = _literals.LiteralMap( + {"a": _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=1)))} +) +_OUTPUT_MAP = _literals.LiteralMap( + {"b": _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=2)))} +) + def test_execution_metadata(): obj = _execution.ExecutionMetadata(_execution.ExecutionMetadata.ExecutionMode.MANUAL, "tester", 1) @@ -104,28 +112,34 @@ def test_execution_spec(literal_value_pair): def test_workflow_execution_data_response(): input_blob = _common_models.UrlBlob("in", 1) output_blob = _common_models.UrlBlob("out", 2) - obj = _execution.WorkflowExecutionGetDataResponse(input_blob, output_blob) + obj = _execution.WorkflowExecutionGetDataResponse(input_blob, output_blob, _INPUT_MAP, _OUTPUT_MAP) obj2 = _execution.WorkflowExecutionGetDataResponse.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj2.inputs == input_blob assert obj2.outputs == output_blob + assert obj2.full_inputs == _INPUT_MAP + assert obj2.full_outputs == _OUTPUT_MAP def test_node_execution_data_response(): input_blob = _common_models.UrlBlob("in", 1) output_blob = _common_models.UrlBlob("out", 2) - obj = _execution.NodeExecutionGetDataResponse(input_blob, output_blob) + obj = _execution.NodeExecutionGetDataResponse(input_blob, output_blob, _INPUT_MAP, _OUTPUT_MAP) obj2 = _execution.NodeExecutionGetDataResponse.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj2.inputs == input_blob assert obj2.outputs == output_blob + assert obj2.full_inputs == _INPUT_MAP + assert obj2.full_outputs == _OUTPUT_MAP def test_task_execution_data_response(): input_blob = _common_models.UrlBlob("in", 1) output_blob = _common_models.UrlBlob("out", 2) - obj = _execution.TaskExecutionGetDataResponse(input_blob, output_blob) + obj = _execution.TaskExecutionGetDataResponse(input_blob, output_blob, _INPUT_MAP, _OUTPUT_MAP) obj2 = _execution.TaskExecutionGetDataResponse.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj2.inputs == input_blob assert obj2.outputs == output_blob + assert obj2.full_inputs == _INPUT_MAP + assert obj2.full_outputs == _OUTPUT_MAP From 1764a8dacd9a31602f82fde9ee91b0035ec6f09e Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Thu, 20 Aug 2020 13:44:45 -0700 Subject: [PATCH 3/3] fmt --- flytekit/models/execution.py | 18 +++++++++----- .../unit/engines/flyte/test_engine.py | 24 +++++-------------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index d7163fcc8f..7fc6b8d9ec 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -443,8 +443,10 @@ def to_flyte_idl(self): :rtype: _execution_pb2.WorkflowExecutionGetDataResponse """ return _execution_pb2.WorkflowExecutionGetDataResponse( - inputs=self.inputs.to_flyte_idl(), outputs=self.outputs.to_flyte_idl(), - full_inputs=self.full_inputs.to_flyte_idl(), full_outputs=self.full_outputs.to_flyte_idl(), + inputs=self.inputs.to_flyte_idl(), + outputs=self.outputs.to_flyte_idl(), + full_inputs=self.full_inputs.to_flyte_idl(), + full_outputs=self.full_outputs.to_flyte_idl(), ) @@ -467,8 +469,10 @@ def to_flyte_idl(self): :rtype: _task_execution_pb2.TaskExecutionGetDataResponse """ return _task_execution_pb2.TaskExecutionGetDataResponse( - inputs=self.inputs.to_flyte_idl(), outputs=self.outputs.to_flyte_idl(), - full_inputs=self.full_inputs.to_flyte_idl(), full_outputs=self.full_outputs.to_flyte_idl(), + inputs=self.inputs.to_flyte_idl(), + outputs=self.outputs.to_flyte_idl(), + full_inputs=self.full_inputs.to_flyte_idl(), + full_outputs=self.full_outputs.to_flyte_idl(), ) @@ -491,6 +495,8 @@ def to_flyte_idl(self): :rtype: _node_execution_pb2.NodeExecutionGetDataResponse """ return _node_execution_pb2.NodeExecutionGetDataResponse( - inputs=self.inputs.to_flyte_idl(), outputs=self.outputs.to_flyte_idl(), - full_inputs=self.full_inputs.to_flyte_idl(), full_outputs=self.full_outputs.to_flyte_idl(), + inputs=self.inputs.to_flyte_idl(), + outputs=self.outputs.to_flyte_idl(), + full_inputs=self.full_inputs.to_flyte_idl(), + full_outputs=self.full_outputs.to_flyte_idl(), ) diff --git a/tests/flytekit/unit/engines/flyte/test_engine.py b/tests/flytekit/unit/engines/flyte/test_engine.py index 8cdeb46e8c..b9fb0c25b2 100644 --- a/tests/flytekit/unit/engines/flyte/test_engine.py +++ b/tests/flytekit/unit/engines/flyte/test_engine.py @@ -256,9 +256,7 @@ def test_fetch_active_launch_plan(mock_client_factory): def test_get_full_execution_inputs(mock_client_factory): mock_client = MagicMock() mock_client.get_execution_data = MagicMock( - return_value=_execution_models.WorkflowExecutionGetDataResponse( - None, None, _INPUT_MAP, _OUTPUT_MAP, - ) + return_value=_execution_models.WorkflowExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP,) ) mock_client_factory.return_value = mock_client @@ -298,9 +296,7 @@ def test_get_execution_inputs(mock_client_factory, execution_data_locations): def test_get_full_execution_outputs(mock_client_factory): mock_client = MagicMock() mock_client.get_execution_data = MagicMock( - return_value=_execution_models.WorkflowExecutionGetDataResponse( - None, None, _INPUT_MAP, _OUTPUT_MAP - ) + return_value=_execution_models.WorkflowExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP) ) mock_client_factory.return_value = mock_client @@ -340,9 +336,7 @@ def test_get_execution_outputs(mock_client_factory, execution_data_locations): def test_get_full_node_execution_inputs(mock_client_factory): mock_client = MagicMock() mock_client.get_node_execution_data = MagicMock( - return_value=_execution_models.NodeExecutionGetDataResponse( - None, None, _INPUT_MAP, _OUTPUT_MAP, - ) + return_value=_execution_models.NodeExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP,) ) mock_client_factory.return_value = mock_client @@ -394,9 +388,7 @@ def test_get_node_execution_inputs(mock_client_factory, execution_data_locations def test_get_full_node_execution_outputs(mock_client_factory): mock_client = MagicMock() mock_client.get_node_execution_data = MagicMock( - return_value=_execution_models.NodeExecutionGetDataResponse( - None, None, _INPUT_MAP, _OUTPUT_MAP - ) + return_value=_execution_models.NodeExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP) ) mock_client_factory.return_value = mock_client @@ -448,9 +440,7 @@ def test_get_node_execution_outputs(mock_client_factory, execution_data_location def test_get_full_task_execution_inputs(mock_client_factory): mock_client = MagicMock() mock_client.get_task_execution_data = MagicMock( - return_value=_execution_models.TaskExecutionGetDataResponse( - None, None, _INPUT_MAP, _OUTPUT_MAP - ) + return_value=_execution_models.TaskExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP) ) mock_client_factory.return_value = mock_client @@ -518,9 +508,7 @@ def test_get_task_execution_inputs(mock_client_factory, execution_data_locations def test_get_full_task_execution_outputs(mock_client_factory): mock_client = MagicMock() mock_client.get_task_execution_data = MagicMock( - return_value=_execution_models.TaskExecutionGetDataResponse( - None, None, _INPUT_MAP, _OUTPUT_MAP - ) + return_value=_execution_models.TaskExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP) ) mock_client_factory.return_value = mock_client