diff --git a/Makefile b/Makefile index 4b3278bec0..2f71c97ebe 100644 --- a/Makefile +++ b/Makefile @@ -51,7 +51,10 @@ test: lint unit_test .PHONY: unit_test unit_test: - pytest -m "not sandbox_test" tests/flytekit/unit + # Skip tensorflow tests and run them with the necessary env var set so that a working (albeit slower) + # library is used to serialize/deserialize protobufs is used. + pytest -m "not sandbox_test" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/tensorflow && \ + PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python pytest tests/flytekit/unit/extras/tensorflow requirements-spark2.txt: export CUSTOM_COMPILE_COMMAND := make requirements-spark2.txt requirements-spark2.txt: requirements-spark2.in install-piptools diff --git a/flytekit/__init__.py b/flytekit/__init__.py index a18d476291..9a2ed7f180 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -154,7 +154,6 @@ LiteralType BlobType """ - import sys from typing import Generator diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 98969e41b3..35e2a3e39d 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -15,10 +15,10 @@ from dataclasses_json import DataClassJsonMixin, dataclass_json from google.protobuf import json_format as _json_format -from google.protobuf import reflection as _proto_reflection from google.protobuf import struct_pb2 as _struct from google.protobuf.json_format import MessageToDict as _MessageToDict from google.protobuf.json_format import ParseDict as _ParseDict +from google.protobuf.message import Message from google.protobuf.struct_pb2 import Struct from marshmallow_enum import EnumField, LoadDumpOptions from marshmallow_jsonschema import JSONSchema @@ -551,11 +551,11 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]: raise ValueError(f"Dataclass transformer cannot reverse {literal_type}") -class ProtobufTransformer(TypeTransformer[_proto_reflection.GeneratedProtocolMessageType]): +class ProtobufTransformer(TypeTransformer[Message]): PB_FIELD_KEY = "pb_type" def __init__(self): - super().__init__("Protobuf-Transformer", _proto_reflection.GeneratedProtocolMessageType) + super().__init__("Protobuf-Transformer", Message) @staticmethod def tag(expected_python_type: Type[T]) -> str: diff --git a/flytekit/extras/tensorflow/__init__.py b/flytekit/extras/tensorflow/__init__.py index c42b1fd3d5..f51da24dae 100644 --- a/flytekit/extras/tensorflow/__init__.py +++ b/flytekit/extras/tensorflow/__init__.py @@ -18,6 +18,9 @@ import tensorflow _tensorflow_installed = True +except TypeError as e: + logger.warn(f"Unsupported version of tensorflow installed. Error message from protobuf library: {e}") + _tensorflow_installed = False except (ImportError, OSError): _tensorflow_installed = False