Skip to content

Commit

Permalink
Single out tensorflow tests and fix protobuf type transformer tests
Browse files Browse the repository at this point in the history
Signed-off-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario committed Dec 8, 2022
1 parent 2cd0a67 commit e1bfe58
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ test: lint unit_test

.PHONY: unit_test
unit_test:
pytest -m "not sandbox_test" tests/flytekit/unit
# Skip tensorflow tests and run them with the necessary env var set so that a working (albeit slower)
# library is used to serialize/deserialize protobufs is used.
pytest -m "not sandbox_test" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/tensorflow && \
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python pytest tests/flytekit/unit/extras/tensorflow

requirements-spark2.txt: export CUSTOM_COMPILE_COMMAND := make requirements-spark2.txt
requirements-spark2.txt: requirements-spark2.in install-piptools
Expand Down
1 change: 0 additions & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@
LiteralType
BlobType
"""

import sys
from typing import Generator

Expand Down
6 changes: 3 additions & 3 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

from dataclasses_json import DataClassJsonMixin, dataclass_json
from google.protobuf import json_format as _json_format
from google.protobuf import reflection as _proto_reflection
from google.protobuf import struct_pb2 as _struct
from google.protobuf.json_format import MessageToDict as _MessageToDict
from google.protobuf.json_format import ParseDict as _ParseDict
from google.protobuf.message import Message
from google.protobuf.struct_pb2 import Struct
from marshmallow_enum import EnumField, LoadDumpOptions
from marshmallow_jsonschema import JSONSchema
Expand Down Expand Up @@ -551,11 +551,11 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
raise ValueError(f"Dataclass transformer cannot reverse {literal_type}")


class ProtobufTransformer(TypeTransformer[_proto_reflection.GeneratedProtocolMessageType]):
class ProtobufTransformer(TypeTransformer[Message]):
PB_FIELD_KEY = "pb_type"

def __init__(self):
super().__init__("Protobuf-Transformer", _proto_reflection.GeneratedProtocolMessageType)
super().__init__("Protobuf-Transformer", Message)

@staticmethod
def tag(expected_python_type: Type[T]) -> str:
Expand Down
3 changes: 3 additions & 0 deletions flytekit/extras/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import tensorflow

_tensorflow_installed = True
except TypeError as e:
logger.warn(f"Unsupported version of tensorflow installed. Error message from protobuf library: {e}")
_tensorflow_installed = False
except (ImportError, OSError):
_tensorflow_installed = False

Expand Down

0 comments on commit e1bfe58

Please sign in to comment.