Remove singleton from structured dataset transformer engine #848

Merged · 11 commits · Feb 8, 2022
4 changes: 3 additions & 1 deletion .github/workflows/pythonbuild.yml
@@ -39,9 +39,11 @@ jobs:
python -m pip install --upgrade pip==21.2.4 setuptools wheel
make setup${{ matrix.spark-version-suffix }}
pip freeze
+ - name: Test FlyteSchema compatibility
Contributor Author: Can revert this; either way, it doesn't really matter. I think I like not running the compat stuff with coverage.
+ run: |
+ FLYTE_SDK_USE_STRUCTURED_DATASET=FALSE python -m pytest tests/flytekit_compatibility
- name: Test with coverage
run: |
- coverage run -m pytest tests/flytekit_compatibility
+ FLYTE_SDK_USE_STRUCTURED_DATASET=TRUE coverage run -m pytest tests/flytekit/unit
- name: Integration Tests with coverage
# https://github.com/actions/runner/issues/241#issuecomment-577360161
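The workflow now exercises both sides of the toggle: the compatibility tests run with FLYTE_SDK_USE_STRUCTURED_DATASET=FALSE, and the unit tests run under coverage with it set to TRUE. A minimal sketch of the kind of import-time environment flag this relies on (the helper below is illustrative; flytekit's actual flag lives in flytekit.configuration.sdk):

```python
# Illustrative stand-in for an env-var boolean gate like USE_STRUCTURED_DATASET;
# not flytekit's actual configuration code.
import os


def _env_flag(name: str, default: bool = False) -> bool:
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("true", "1", "yes")


USE_STRUCTURED_DATASET = _env_flag("FLYTE_SDK_USE_STRUCTURED_DATASET")
```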
1 change: 1 addition & 0 deletions flytekit/__init__.py
@@ -193,6 +193,7 @@
from flytekit.types.structured.structured_dataset import (
StructuredDataset,
StructuredDatasetFormat,
+ StructuredDatasetTransformerEngine,
StructuredDatasetType,
)

10 changes: 5 additions & 5 deletions flytekit/types/structured/basic_dfs.py
@@ -12,13 +12,13 @@
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
- FLYTE_DATASET_TRANSFORMER,
LOCAL,
PARQUET,
S3,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
+ StructuredDatasetTransformerEngine,
)

T = TypeVar("T")
@@ -107,7 +107,7 @@ def decode(


for protocol in [LOCAL, S3]: # Should we add GCS
- FLYTE_DATASET_TRANSFORMER.register_handler(PandasToParquetEncodingHandler(protocol), default_for_type=True)
- FLYTE_DATASET_TRANSFORMER.register_handler(ParquetToPandasDecodingHandler(protocol), default_for_type=True)
- FLYTE_DATASET_TRANSFORMER.register_handler(ArrowToParquetEncodingHandler(protocol), default_for_type=True)
- FLYTE_DATASET_TRANSFORMER.register_handler(ParquetToArrowDecodingHandler(protocol), default_for_type=True)
+ StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol), default_for_type=True)
+ StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol), default_for_type=True)
+ StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol), default_for_type=True)
+ StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol), default_for_type=True)
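The loop above now routes the default pandas and Arrow Parquet handlers for local and S3 storage through the classmethod instead of the removed module-level singleton. A hedged sketch of the user-facing effect (the task below is a toy example, not from this PR):

```python
# Toy example: because Parquet handlers are registered as the defaults for
# pd.DataFrame above, a task can return a dataframe with no extra ceremony.
import pandas as pd
from flytekit import task


@task
def make_df() -> pd.DataFrame:
    # On return, the engine looks up the default format/protocol registered
    # for pd.DataFrame and encodes the value with the matching handler.
    return pd.DataFrame({"a": [1, 2, 3]})
```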
10 changes: 5 additions & 5 deletions flytekit/types/structured/bigquery.py
@@ -12,11 +12,11 @@
from flytekit.types.structured.structured_dataset import (
BIGQUERY,
DF,
- FLYTE_DATASET_TRANSFORMER,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
StructuredDatasetMetadata,
+ StructuredDatasetTransformerEngine,
)


@@ -110,7 +110,7 @@ def decode(
return pa.Table.from_pandas(_read_from_bq(flyte_value))


- FLYTE_DATASET_TRANSFORMER.register_handler(PandasToBQEncodingHandlers(), default_for_type=False)
- FLYTE_DATASET_TRANSFORMER.register_handler(BQToPandasDecodingHandler(), default_for_type=False)
- FLYTE_DATASET_TRANSFORMER.register_handler(ArrowToBQEncodingHandlers(), default_for_type=False)
- FLYTE_DATASET_TRANSFORMER.register_handler(BQToArrowDecodingHandler(), default_for_type=False)
+ StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers(), default_for_type=False)
+ StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler(), default_for_type=False)
+ StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers(), default_for_type=False)
+ StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler(), default_for_type=False)
47 changes: 26 additions & 21 deletions flytekit/types/structured/structured_dataset.py
@@ -24,8 +24,7 @@

from flytekit.configuration.sdk import USE_STRUCTURED_DATASET
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
- from flytekit.core.type_engine import TypeTransformer
- from flytekit.extend import TypeEngine
+ from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.loggers import logger
from flytekit.models import literals
from flytekit.models import types as type_models
@@ -106,15 +105,15 @@ def all(self) -> DF:
if self._dataframe_type is None:
raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
ctx = FlyteContextManager.current_context()
- return FLYTE_DATASET_TRANSFORMER.open_as(
+ return flyte_dataset_transformer.open_as(
ctx, self.literal, self._dataframe_type, updated_metadata=self.metadata
)

def iter(self) -> Generator[DF, None, None]:
if self._dataframe_type is None:
raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
ctx = FlyteContextManager.current_context()
- return FLYTE_DATASET_TRANSFORMER.iter_as(
+ return flyte_dataset_transformer.iter_as(
ctx, self.literal, self._dataframe_type, updated_metadata=self.metadata
)

@@ -170,7 +169,7 @@ class StructuredDatasetEncoder(ABC):
def __init__(self, python_type: Type[T], protocol: str, supported_format: Optional[str] = None):
"""
Extend this abstract class, implement the encode function, and register your concrete class with the
- FLYTE_DATASET_TRANSFORMER defined at this module level in order for the core flytekit type engine to handle
+ StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
dataframe libraries. This is the encoding interface, meaning it is used when there is a Python value that the
flytekit type engine is trying to convert into a Flyte Literal. For the other way, see
the StructuredDatasetEncoder
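With the docstring now pointing extension authors at StructuredDatasetTransformerEngine, the end-to-end flow looks roughly like the sketch below; MyDF, the protocol choice, and the empty encode body are hypothetical:

```python
# Hedged sketch of extending the encoder ABC and registering it via the new
# classmethod. MyDF and the encode body are placeholders, not real code.
from flytekit.core.context_manager import FlyteContext
from flytekit.models import literals
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
    PARQUET,
    StructuredDataset,
    StructuredDatasetEncoder,
    StructuredDatasetTransformerEngine,
)


class MyDF:  # stand-in for a third-party dataframe type
    ...


class MyDFToParquetEncodingHandler(StructuredDatasetEncoder):
    def __init__(self, protocol: str):
        super().__init__(MyDF, protocol, supported_format=PARQUET)

    def encode(
        self,
        ctx: FlyteContext,
        structured_dataset: StructuredDataset,
        structured_dataset_type: StructuredDatasetType,
    ) -> literals.StructuredDataset:
        raise NotImplementedError  # serialize structured_dataset.dataframe here


# No engine instance is needed; register is a classmethod after this PR.
StructuredDatasetTransformerEngine.register(MyDFToParquetEncodingHandler("s3"), default_for_type=False)
```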
@@ -230,7 +229,7 @@ class StructuredDatasetDecoder(ABC):
def __init__(self, python_type: Type[DF], protocol: str, supported_format: Optional[str] = None):
"""
Extend this abstract class, implement the decode function, and register your concrete class with the
- FLYTE_DATASET_TRANSFORMER defined at this module level in order for the core flytekit type engine to handle
+ StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
dataframe libraries. This is the decoder interface, meaning it is used when there is a Flyte Literal value,
and we have to get a Python value out of it. For the other way, see the StructuredDatasetEncoder

@@ -337,7 +336,8 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]):

Handlers = Union[StructuredDatasetEncoder, StructuredDatasetDecoder]

- def _finder(self, handler_map, df_type: Type, protocol: str, format: str):
+ @staticmethod
+ def _finder(handler_map, df_type: Type, protocol: str, format: str):
try:
return handler_map[df_type][protocol][format]
except KeyError:
@@ -352,18 +352,21 @@ def _finder(self, handler_map, df_type: Type, protocol: str, format: str):
...
raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt {format}")

- def get_encoder(self, df_type: Type, protocol: str, format: str):
- return self._finder(self.ENCODERS, df_type, protocol, format)
+ @classmethod
+ def get_encoder(cls, df_type: Type, protocol: str, format: str):
+ return cls._finder(StructuredDatasetTransformerEngine.ENCODERS, df_type, protocol, format)

- def get_decoder(self, df_type: Type, protocol: str, format: str):
- return self._finder(self.DECODERS, df_type, protocol, format)
+ @classmethod
+ def get_decoder(cls, df_type: Type, protocol: str, format: str):
+ return cls._finder(StructuredDatasetTransformerEngine.DECODERS, df_type, protocol, format)

- def _handler_finder(self, h: Handlers) -> Dict[str, Handlers]:
+ @classmethod
+ def _handler_finder(cls, h: Handlers) -> Dict[str, Handlers]:
# Maybe think about default dict in the future, but is typing as nice?
if isinstance(h, StructuredDatasetEncoder):
- top_level = self.ENCODERS
+ top_level = cls.ENCODERS
elif isinstance(h, StructuredDatasetDecoder):
- top_level = self.DECODERS
+ top_level = cls.DECODERS
else:
raise TypeError(f"We don't support this type of handler {h}")
if h.python_type not in top_level:
@@ -376,7 +379,8 @@ def __init__(self):
super().__init__("StructuredDataset Transformer", StructuredDataset)
self._type_assertions_enabled = False

- def register_handler(self, h: Handlers, default_for_type: Optional[bool] = True, override: Optional[bool] = False):
+ @classmethod
+ def register(cls, h: Handlers, default_for_type: Optional[bool] = True, override: Optional[bool] = False):
"""
Call this with any handler to register it with this dataframe meta-transformer

@@ -386,21 +390,22 @@ def register_handler(self, h: Handlers, default_for_type: Optional[bool] = True,
logger.info(f"Structured datasets not enabled, not registering handler {h}")
return

- lowest_level = self._handler_finder(h)
+ lowest_level = cls._handler_finder(h)
if h.supported_format in lowest_level and override is False:
raise ValueError(f"Already registered a handler for {(h.python_type, h.protocol, h.supported_format)}")
lowest_level[h.supported_format] = h
logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {h.protocol}, fmt {h.supported_format}")

if default_for_type:
# TODO: Add logging, think about better ux, maybe default False and warn if doesn't exist.
- self.DEFAULT_FORMATS[h.python_type] = h.supported_format
- self.DEFAULT_PROTOCOLS[h.python_type] = h.protocol
+ cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
+ cls.DEFAULT_PROTOCOLS[h.python_type] = h.protocol

# Register with the type engine as well
# The semantics as of now are such that it doesn't matter which order these transformers are loaded in, as
# long as the older Pandas/FlyteSchema transformer do not also specify the override
- TypeEngine.register_additional_type(self, h.python_type, override=True)
+ engine = StructuredDatasetTransformerEngine()
Contributor Author: This is the main change: every time we register a new handler, we create a new instance of this class. Not the best, but okay. We could also cache it as a global variable (one already exists at the bottom of this .py file); either way.

+ TypeEngine.register_additional_type(engine, h.python_type, override=True)

def assert_type(self, t: Type[StructuredDataset], v: typing.Any):
return
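As the author's comment above notes, registration now builds a fresh engine instance each time; a cached instance would also work. A hedged sketch of that alternative (the Engine class and _cached attribute are made-up names, not this PR's code):

```python
# Hedged sketch of the caching alternative mentioned in the comment: create
# the engine lazily once and reuse it for every registration.
from typing import List, Optional


class Engine:
    _cached: Optional["Engine"] = None  # hypothetical cache slot
    _handlers: List[object] = []

    @classmethod
    def _instance(cls) -> "Engine":
        if cls._cached is None:
            cls._cached = cls()
        return cls._cached

    @classmethod
    def register(cls, handler: object) -> None:
        cls._handlers.append(handler)
        engine = cls._instance()  # reused across calls, not rebuilt per call
        print(f"registered {handler!r} with engine {id(engine)}")


Engine.register("h1")
Engine.register("h2")  # prints the same engine id: one cached instance
```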
@@ -724,7 +729,7 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]:

if USE_STRUCTURED_DATASET.get():
logger.debug("Structured dataset module load... using structured datasets!")
- FLYTE_DATASET_TRANSFORMER = StructuredDatasetTransformerEngine()
- TypeEngine.register(FLYTE_DATASET_TRANSFORMER)
+ flyte_dataset_transformer = StructuredDatasetTransformerEngine()
+ TypeEngine.register(flyte_dataset_transformer)
else:
logger.debug("Structured dataset module load... not using structured datasets")
3 changes: 2 additions & 1 deletion plugins/flytekit-papermill/dev-requirements.in
@@ -1 +1,2 @@
- flytekitplugins-spark>=0.30.0b4
+ git+https://github.com/flyteorg/flytekit@add-sd-make-class-methods#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark
eapolinario marked this conversation as resolved.
+ # vcs+protocol://repo_url/#egg=pkg&subdirectory=flyte
Collaborator: Why is this needed?

13 changes: 7 additions & 6 deletions plugins/flytekit-papermill/dev-requirements.txt
@@ -40,9 +40,9 @@ docstring-parser==0.13
# via flytekit
flyteidl==0.21.23
# via flytekit
- flytekit==0.26.0
+ flytekit==0.30.0
# via flytekitplugins-spark
- flytekitplugins-spark==0.30.0b4
+ flytekitplugins-spark @ git+https://github.com/flyteorg/flytekit@add-sd-make-class-methods#subdirectory=plugins/flytekit-spark
# via -r dev-requirements.in
grpcio==1.43.0
# via flytekit
@@ -87,11 +87,11 @@ protobuf==3.19.3
# flytekit
py==1.11.0
# via retry
- py4j==0.10.9.2
+ py4j==0.10.9.3
# via pyspark
pyarrow==6.0.1
# via flytekit
- pyspark==3.2.0
+ pyspark==3.2.1
# via flytekitplugins-spark
python-dateutil==2.8.1
# via
@@ -123,7 +123,6 @@ retry==0.9.2
six==1.16.0
# via
# cookiecutter
- # flytekit
# grpcio
# python-dateutil
# responses
@@ -134,7 +133,9 @@ statsd==3.3.0
text-unidecode==1.3
# via python-slugify
typing-extensions==4.0.1
- # via typing-inspect
+ # via
+ # flytekit
+ # typing-inspect
typing-inspect==0.7.1
# via dataclasses-json
urllib3==1.26.8
@@ -7,11 +7,11 @@
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
- FLYTE_DATASET_TRANSFORMER,
PARQUET,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
+ StructuredDatasetTransformerEngine,
)


@@ -45,5 +45,5 @@ def decode(


for protocol in ["/", "s3"]:
- FLYTE_DATASET_TRANSFORMER.register_handler(SparkToParquetEncodingHandler(protocol), default_for_type=True)
- FLYTE_DATASET_TRANSFORMER.register_handler(ParquetToSparkDecodingHandler(protocol), default_for_type=True)
+ StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler(protocol), default_for_type=True)
+ StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler(protocol), default_for_type=True)