Skip to content

Commit

Permalink
rename register_handler to register
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor committed Feb 8, 2022
1 parent bd3be41 commit d201095
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 22 deletions.
8 changes: 4 additions & 4 deletions flytekit/types/structured/basic_dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def decode(


for protocol in [LOCAL, S3]: # Should we add GCS
StructuredDatasetTransformerEngine.register_handler(PandasToParquetEncodingHandler(protocol), default_for_type=True)
StructuredDatasetTransformerEngine.register_handler(ParquetToPandasDecodingHandler(protocol), default_for_type=True)
StructuredDatasetTransformerEngine.register_handler(ArrowToParquetEncodingHandler(protocol), default_for_type=True)
StructuredDatasetTransformerEngine.register_handler(ParquetToArrowDecodingHandler(protocol), default_for_type=True)
StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol), default_for_type=True)
StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol), default_for_type=True)
StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol), default_for_type=True)
StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol), default_for_type=True)
8 changes: 4 additions & 4 deletions flytekit/types/structured/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def decode(
return pa.Table.from_pandas(_read_from_bq(flyte_value))


StructuredDatasetTransformerEngine.register_handler(PandasToBQEncodingHandlers(), default_for_type=False)
StructuredDatasetTransformerEngine.register_handler(BQToPandasDecodingHandler(), default_for_type=False)
StructuredDatasetTransformerEngine.register_handler(ArrowToBQEncodingHandlers(), default_for_type=False)
StructuredDatasetTransformerEngine.register_handler(BQToArrowDecodingHandler(), default_for_type=False)
StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers(), default_for_type=False)
StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler(), default_for_type=False)
StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers(), default_for_type=False)
StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler(), default_for_type=False)
2 changes: 1 addition & 1 deletion flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def __init__(self):
self._type_assertions_enabled = False

@classmethod
def register_handler(cls, h: Handlers, default_for_type: Optional[bool] = True, override: Optional[bool] = False):
def register(cls, h: Handlers, default_for_type: Optional[bool] = True, override: Optional[bool] = False):
"""
Call this with any handler to register it with this dataframe meta-transformer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ def decode(


for protocol in ["/", "s3"]:
StructuredDatasetTransformerEngine.register_handler(SparkToParquetEncodingHandler(protocol), default_for_type=True)
StructuredDatasetTransformerEngine.register_handler(ParquetToSparkDecodingHandler(protocol), default_for_type=True)
StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler(protocol), default_for_type=True)
StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler(protocol), default_for_type=True)
14 changes: 7 additions & 7 deletions tests/flytekit/unit/core/test_structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,15 @@ def __init__(self, protocol):
def encode(self):
...

StructuredDatasetTransformerEngine.register_handler(TempEncoder("gs"), default_for_type=False)
StructuredDatasetTransformerEngine.register(TempEncoder("gs"), default_for_type=False)
with pytest.raises(ValueError):
StructuredDatasetTransformerEngine.register_handler(TempEncoder("gs://"), default_for_type=False)
StructuredDatasetTransformerEngine.register(TempEncoder("gs://"), default_for_type=False)

class TempEncoder:
pass

with pytest.raises(TypeError, match="We don't support this type of handler"):
StructuredDatasetTransformerEngine.register_handler(TempEncoder, default_for_type=False)
StructuredDatasetTransformerEngine.register(TempEncoder, default_for_type=False)


def test_to_literal():
Expand Down Expand Up @@ -186,7 +186,7 @@ def encode(
) -> literals.StructuredDataset:
return literals.StructuredDataset(uri="")

StructuredDatasetTransformerEngine.register_handler(TempEncoder("myavro"), default_for_type=True)
StructuredDatasetTransformerEngine.register(TempEncoder("myavro"), default_for_type=True)
lt = TypeEngine.to_literal_type(MyDF)
assert lt.structured_dataset_type.format == "myavro"

Expand All @@ -199,7 +199,7 @@ def encode(

# Test that looking up encoders/decoders falls back to the "" encoder/decoder
empty_format_temp_encoder = TempEncoder("")
StructuredDatasetTransformerEngine.register_handler(empty_format_temp_encoder, default_for_type=False)
StructuredDatasetTransformerEngine.register(empty_format_temp_encoder, default_for_type=False)

res = StructuredDatasetTransformerEngine.get_encoder(MyDF, "tmpfs", "rando")
assert res is empty_format_temp_encoder
Expand All @@ -224,7 +224,7 @@ def decode(
) -> typing.Union[typing.Generator[pd.DataFrame, None, None]]:
yield pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

StructuredDatasetTransformerEngine.register_handler(
StructuredDatasetTransformerEngine.register(
MockPandasDecodingHandlers(pd.DataFrame, "tmpfs"), default_for_type=False
)
sd = StructuredDataset()
Expand All @@ -244,7 +244,7 @@ def decode(
) -> pd.DataFrame:
pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

StructuredDatasetTransformerEngine.register_handler(
StructuredDatasetTransformerEngine.register(
MockPandasDecodingHandlers(pd.DataFrame, "tmpfs"), default_for_type=False, override=True
)
sd = StructuredDataset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def decode(
return pd_df


StructuredDatasetTransformerEngine.register_handler(MockBQEncodingHandlers(pd.DataFrame, BIGQUERY), False, True)
StructuredDatasetTransformerEngine.register_handler(MockBQDecodingHandlers(pd.DataFrame, BIGQUERY), False, True)
StructuredDatasetTransformerEngine.register(MockBQEncodingHandlers(pd.DataFrame, BIGQUERY), False, True)
StructuredDatasetTransformerEngine.register(MockBQDecodingHandlers(pd.DataFrame, BIGQUERY), False, True)


class NumpyEncodingHandlers(StructuredDatasetEncoder):
Expand Down Expand Up @@ -95,8 +95,8 @@ def decode(


for protocol in [LOCAL, S3]:
StructuredDatasetTransformerEngine.register_handler(NumpyEncodingHandlers(np.ndarray, protocol, PARQUET))
StructuredDatasetTransformerEngine.register_handler(NumpyDecodingHandlers(np.ndarray, protocol, PARQUET))
StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray, protocol, PARQUET))
StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray, protocol, PARQUET))


@task
Expand Down

0 comments on commit d201095

Please sign in to comment.