Skip to content

Commit

Permalink
Fix mypy errors in flytekit/types (#757)
Browse files Browse the repository at this point in the history
Signed-off-by: Lisa <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
aeioulisa authored and eapolinario committed Jan 28, 2022
1 parent a9d99eb commit 9550870
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 36 deletions.
3 changes: 2 additions & 1 deletion flytekit/types/directory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

# The following section provides some predefined aliases for commonly used FlyteDirectory formats.

TensorboardLogs = FlyteDirectory[typing.TypeVar("tensorboard")]
tensorboard = typing.TypeVar("tensorboard")
TensorboardLogs = FlyteDirectory[tensorboard]
"""
This type can be used to denote that the output is a folder that contains logs that can be loaded in tensorboard.
this is usually the SummaryWriter output in pytorch or Keras callbacks which record the history readable by
Expand Down
6 changes: 3 additions & 3 deletions flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __fspath__(self):
def extension(cls) -> str:
return ""

def __class_getitem__(cls, item: typing.Type) -> typing.Type[FlyteDirectory]:
def __class_getitem__(cls, item: typing.Union[typing.Type, str]) -> typing.Type[FlyteDirectory]:
if item is None:
return cls
item_string = str(item)
Expand Down Expand Up @@ -290,7 +290,7 @@ def _downloader():

expected_format = self.get_format(expected_python_type)

fd = FlyteDirectory[expected_format](local_folder, _downloader)
fd = FlyteDirectory.__class_getitem__(expected_format)(local_folder, _downloader)
fd._remote_source = uri

return fd
Expand All @@ -300,7 +300,7 @@ def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteDirec
literal_type.blob is not None
and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART
):
return FlyteDirectory[typing.TypeVar(literal_type.blob.format)]
return FlyteDirectory.__class_getitem__(literal_type.blob.format)
raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


Expand Down
30 changes: 20 additions & 10 deletions flytekit/types/file/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,59 +28,69 @@
# This makes their usage extremely simple for the users. Please keep the list sorted.


HDF5EncodedFile = FlyteFile[typing.TypeVar("hdf5")]
hdf5 = typing.TypeVar("hdf5")
HDF5EncodedFile = FlyteFile[hdf5]
"""
This can be used to denote that the returned file is of type hdf5 and can be received by other tasks that
accept an hdf5 format. This is usually useful for serializing Tensorflow models
"""

HTMLPage = FlyteFile[typing.TypeVar("html")]
html = typing.TypeVar("html")
HTMLPage = FlyteFile[html]
"""
Can be used to receive or return an PNGImage. The underlying type is a FlyteFile, type. This is just a
decoration and useful for attaching content type information with the file and automatically documenting code.
"""

JoblibSerializedFile = FlyteFile[typing.TypeVar("joblib")]
joblib = typing.TypeVar("joblib")
JoblibSerializedFile = FlyteFile[joblib]
"""
This File represents a file that was serialized using `joblib.dump` method can be loaded back using `joblib.load`
"""

JPEGImageFile = FlyteFile[typing.TypeVar("jpeg")]
jpeg = typing.TypeVar("jpeg")
JPEGImageFile = FlyteFile[jpeg]
"""
Can be used to receive or return an JPEGImage. The underlying type is a FlyteFile, type. This is just a
decoration and useful for attaching content type information with the file and automatically documenting code.
"""

PDFFile = FlyteFile[typing.TypeVar("pdf")]
pdf = typing.TypeVar("pdf")
PDFFile = FlyteFile[pdf]
"""
Can be used to receive or return an PDFFile. The underlying type is a FlyteFile, type. This is just a
decoration and useful for attaching content type information with the file and automatically documenting code.
"""

PNGImageFile = FlyteFile[typing.TypeVar("png")]
png = typing.TypeVar("png")
PNGImageFile = FlyteFile[png]
"""
Can be used to receive or return an PNGImage. The underlying type is a FlyteFile, type. This is just a
decoration and useful for attaching content type information with the file and automatically documenting code.
"""

PythonPickledFile = FlyteFile[typing.TypeVar("python-pickle")]
python_pickle = typing.TypeVar("python_pickle")
PythonPickledFile = FlyteFile[python_pickle]
"""
This type can be used when a serialized python pickled object is returned and shared between tasks. This only
adds metadata to the file in Flyte, but does not really carry any object information
"""

PythonNotebook = FlyteFile[typing.TypeVar("ipynb")]
ipynb = typing.TypeVar("ipynb")
PythonNotebook = FlyteFile[ipynb]
"""
This type is used to identify a python notebook file
"""

SVGImageFile = FlyteFile[typing.TypeVar("svg")]
svg = typing.TypeVar("svg")
SVGImageFile = FlyteFile[svg]
"""
Can be used to receive or return an SVGImage. The underlying type is a FlyteFile, type. This is just a
decoration and useful for attaching content type information with the file and automatically documenting code.
"""

CSVFile = FlyteFile[typing.TypeVar("csv")]
csv = typing.TypeVar("csv")
CSVFile = FlyteFile[csv]
"""
Can be used to receive or return a CSVFile. The underlying type is a FlyteFile, type. This is just a
decoration and useful for attaching content type information with the file and automatically documenting code.
Expand Down
10 changes: 5 additions & 5 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def t2() -> flytekit_typing.FlyteFile["csv"]:
def extension(cls) -> str:
return ""

def __class_getitem__(cls, item: typing.Type) -> typing.Type[FlyteFile]:
def __class_getitem__(cls, item: typing.Union[str, typing.Type]) -> typing.Type[FlyteFile]:
if item is None:
return cls
item_string = str(item)
Expand Down Expand Up @@ -220,10 +220,10 @@ def __init__(self):
super().__init__(name="FlyteFilePath", t=FlyteFile)

@staticmethod
def get_format(t: typing.Union[typing.Type[FlyteFile]]) -> str:
def get_format(t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> str:
if t is os.PathLike:
return ""
return t.extension()
return typing.cast(FlyteFile, t).extension()

def _blob_type(self, format: str) -> BlobType:
return BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE)
Expand Down Expand Up @@ -342,14 +342,14 @@ def _downloader():
return ctx.file_access.get_data(uri, local_path, is_multipart=False)

expected_format = FlyteFilePathTransformer.get_format(expected_python_type)
ff = FlyteFile[expected_format](local_path, _downloader)
ff = FlyteFile.__class_getitem__(expected_format)(local_path, _downloader)
ff._remote_source = uri

return ff

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]:
if literal_type.blob is not None and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE:
return FlyteFile[typing.TypeVar(literal_type.blob.format)]
return FlyteFile.__class_getitem__(literal_type.blob.format)

raise ValueError(f"Transformer {self} cannot reverse {literal_type}")

Expand Down
6 changes: 3 additions & 3 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ class FlytePickle(typing.Generic[T]):
"""

@classmethod
def python_type(cls) -> None:
return None
def python_type(cls) -> typing.Type:
return type(None)

def __class_getitem__(cls, python_type: typing.Type) -> typing.Type[T]:
def __class_getitem__(cls, python_type: typing.Type) -> typing.Type:
if python_type is None:
return cls

Expand Down
29 changes: 17 additions & 12 deletions flytekit/types/schema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,21 @@
from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Type

import numpy as _np
from dataclasses_json import config, dataclass_json
from marshmallow import fields

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import T, TypeEngine, TypeTransformer
from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.models.literals import Literal, Scalar, Schema
from flytekit.models.types import LiteralType, SchemaType
from flytekit.plugins import pandas

T = typing.TypeVar("T")


class SchemaFormat(Enum):
"""
Expand All @@ -37,7 +40,7 @@ class SchemaOpenMode(Enum):
WRITE = "w"


def generate_ordered_files(directory: os.PathLike, n: int) -> str:
def generate_ordered_files(directory: os.PathLike, n: int) -> typing.Generator[str, None, None]:
for i in range(n):
yield os.path.join(directory, f"{i:05}")

Expand Down Expand Up @@ -73,12 +76,12 @@ def all(self, **kwargs) -> T:


class SchemaWriter(typing.Generic[T]):
def __init__(self, to_path: str, cols: typing.Dict[str, type], fmt: SchemaFormat):
def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
self._to_path = to_path
self._fmt = fmt
self._columns = cols
# TODO This should be change to send a stop instead of hardcoded to 1024
self._file_name_gen = generate_ordered_files(self._to_path, 1024)
self._file_name_gen = generate_ordered_files(Path(self._to_path), 1024)

@property
def to_path(self) -> str:
Expand Down Expand Up @@ -107,14 +110,14 @@ def iter(self, **kwargs) -> typing.Generator[T, None, None]:
with os.scandir(self._from_path) as it:
for entry in it:
if not entry.name.startswith(".") and entry.is_file():
yield self._read(entry.path, **kwargs)
yield self._read(Path(entry.path), **kwargs)

def all(self, **kwargs) -> T:
files = []
files: typing.List[os.PathLike] = []
with os.scandir(self._from_path) as it:
for entry in it:
if not entry.name.startswith(".") and entry.is_file():
files.append(entry.path)
files.append(Path(entry.path))

return self._read(*files, **kwargs)

Expand Down Expand Up @@ -279,13 +282,15 @@ def open(
if not h.handles_remote_io:
# The Schema Handler does not manage its own IO, and this it will expect the files are on local file-system
if self._supported_mode == SchemaOpenMode.READ and not self._downloaded:
if self._downloader is None:
raise AssertionError("downloader cannot be None in read mode!")
# Only for readable objects if they are not downloaded already, we should download them
# Write objects should already have everything written to
self._downloader(self.remote_path, self.local_path)
self._downloaded = True
if mode == SchemaOpenMode.WRITE:
return h.writer(self.local_path, self.columns(), self.format())
return h.reader(self.local_path, self.columns(), self.format())
return h.writer(typing.cast(str, self.local_path), self.columns(), self.format())
return h.reader(typing.cast(str, self.local_path), self.columns(), self.format())

# Remote IO is handled. So we will just pass the remote reference to the object
if mode == SchemaOpenMode.WRITE:
Expand Down Expand Up @@ -384,10 +389,10 @@ def downloader(x, y):
supported_mode=SchemaOpenMode.READ,
)

def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
def guess_python_type(self, literal_type: LiteralType) -> Type[FlyteSchema]:
if not literal_type.schema:
raise ValueError(f"Cannot reverse {literal_type}")
columns: dict[Type] = {}
columns: typing.Dict[str, Type] = {}
for literal_column in literal_type.schema.columns:
if literal_column.type == SchemaType.SchemaColumn.SchemaColumnType.INTEGER:
columns[literal_column.name] = int
Expand All @@ -403,7 +408,7 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
columns[literal_column.name] = bool
else:
raise ValueError(f"Unknown schema column type {literal_column}")
return FlyteSchema[columns]
return FlyteSchema.__class_getitem__(columns)


TypeEngine.register(FlyteSchemaTransformer())
4 changes: 2 additions & 2 deletions flytekit/types/schema/types_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class ParquetIO(object):
PARQUET_ENGINE = "pyarrow"

def _read(self, chunk: os.PathLike, columns: typing.List[str], **kwargs) -> pandas.DataFrame:
def _read(self, chunk: os.PathLike, columns: typing.Optional[typing.List[str]], **kwargs) -> pandas.DataFrame:
return pandas.read_parquet(chunk, columns=columns, engine=self.PARQUET_ENGINE, **kwargs)

def read(self, *files: os.PathLike, columns: typing.List[str] = None, **kwargs) -> pandas.DataFrame:
Expand Down Expand Up @@ -59,7 +59,7 @@ def write(
class FastParquetIO(ParquetIO):
PARQUET_ENGINE = "fastparquet"

def _read(self, chunk: os.PathLike, columns: typing.List[str], **kwargs) -> pandas.DataFrame:
def _read(self, chunk: os.PathLike, columns: typing.Optional[typing.List[str]], **kwargs) -> pandas.DataFrame:
from fastparquet import ParquetFile as _ParquetFile
from fastparquet import thrift_structures as _ts

Expand Down

0 comments on commit 9550870

Please sign in to comment.