diff --git a/flytekit/types/directory/__init__.py b/flytekit/types/directory/__init__.py index 4edb5f9205d..97d8ab57ceb 100644 --- a/flytekit/types/directory/__init__.py +++ b/flytekit/types/directory/__init__.py @@ -18,7 +18,8 @@ # The following section provides some predefined aliases for commonly used FlyteDirectory formats. -TensorboardLogs = FlyteDirectory[typing.TypeVar("tensorboard")] +tensorboard = typing.TypeVar("tensorboard") +TensorboardLogs = FlyteDirectory[tensorboard] """ This type can be used to denote that the output is a folder that contains logs that can be loaded in tensorboard. this is usually the SummaryWriter output in pytorch or Keras callbacks which record the history readable by diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index a0112c1eeac..81f4fb0fd04 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -133,7 +133,7 @@ def __fspath__(self): def extension(cls) -> str: return "" - def __class_getitem__(cls, item: typing.Type) -> typing.Type[FlyteDirectory]: + def __class_getitem__(cls, item: typing.Union[typing.Type, str]) -> typing.Type[FlyteDirectory]: if item is None: return cls item_string = str(item) @@ -290,7 +290,7 @@ def _downloader(): expected_format = self.get_format(expected_python_type) - fd = FlyteDirectory[expected_format](local_folder, _downloader) + fd = FlyteDirectory.__class_getitem__(expected_format)(local_folder, _downloader) fd._remote_source = uri return fd @@ -300,7 +300,7 @@ def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteDirec literal_type.blob is not None and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART ): - return FlyteDirectory[typing.TypeVar(literal_type.blob.format)] + return FlyteDirectory.__class_getitem__(literal_type.blob.format) raise ValueError(f"Transformer {self} cannot reverse {literal_type}") diff --git a/flytekit/types/file/__init__.py b/flytekit/types/file/__init__.py index 2b65efbcd6f..81796fc49ed 100644 --- a/flytekit/types/file/__init__.py +++ b/flytekit/types/file/__init__.py @@ -28,59 +28,69 @@ # This makes their usage extremely simple for the users. Please keep the list sorted. -HDF5EncodedFile = FlyteFile[typing.TypeVar("hdf5")] +hdf5 = typing.TypeVar("hdf5") +HDF5EncodedFile = FlyteFile[hdf5] """ This can be used to denote that the returned file is of type hdf5 and can be received by other tasks that accept an hdf5 format. This is usually useful for serializing Tensorflow models """ -HTMLPage = FlyteFile[typing.TypeVar("html")] +html = typing.TypeVar("html") +HTMLPage = FlyteFile[html] """ Can be used to receive or return an PNGImage. The underlying type is a FlyteFile, type. This is just a decoration and useful for attaching content type information with the file and automatically documenting code. """ -JoblibSerializedFile = FlyteFile[typing.TypeVar("joblib")] +joblib = typing.TypeVar("joblib") +JoblibSerializedFile = FlyteFile[joblib] """ This File represents a file that was serialized using `joblib.dump` method can be loaded back using `joblib.load` """ -JPEGImageFile = FlyteFile[typing.TypeVar("jpeg")] +jpeg = typing.TypeVar("jpeg") +JPEGImageFile = FlyteFile[jpeg] """ Can be used to receive or return an JPEGImage. The underlying type is a FlyteFile, type. This is just a decoration and useful for attaching content type information with the file and automatically documenting code. """ -PDFFile = FlyteFile[typing.TypeVar("pdf")] +pdf = typing.TypeVar("pdf") +PDFFile = FlyteFile[pdf] """ Can be used to receive or return an PDFFile. The underlying type is a FlyteFile, type. This is just a decoration and useful for attaching content type information with the file and automatically documenting code. """ -PNGImageFile = FlyteFile[typing.TypeVar("png")] +png = typing.TypeVar("png") +PNGImageFile = FlyteFile[png] """ Can be used to receive or return an PNGImage. The underlying type is a FlyteFile, type. This is just a decoration and useful for attaching content type information with the file and automatically documenting code. """ -PythonPickledFile = FlyteFile[typing.TypeVar("python-pickle")] +python_pickle = typing.TypeVar("python_pickle") +PythonPickledFile = FlyteFile[python_pickle] """ This type can be used when a serialized python pickled object is returned and shared between tasks. This only adds metadata to the file in Flyte, but does not really carry any object information """ -PythonNotebook = FlyteFile[typing.TypeVar("ipynb")] +ipynb = typing.TypeVar("ipynb") +PythonNotebook = FlyteFile[ipynb] """ This type is used to identify a python notebook file """ -SVGImageFile = FlyteFile[typing.TypeVar("svg")] +svg = typing.TypeVar("svg") +SVGImageFile = FlyteFile[svg] """ Can be used to receive or return an SVGImage. The underlying type is a FlyteFile, type. This is just a decoration and useful for attaching content type information with the file and automatically documenting code. """ -CSVFile = FlyteFile[typing.TypeVar("csv")] +csv = typing.TypeVar("csv") +CSVFile = FlyteFile[csv] """ Can be used to receive or return a CSVFile. The underlying type is a FlyteFile, type. This is just a decoration and useful for attaching content type information with the file and automatically documenting code. diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 1fbcee049a0..eb64e5943d1 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -137,7 +137,7 @@ def t2() -> flytekit_typing.FlyteFile["csv"]: def extension(cls) -> str: return "" - def __class_getitem__(cls, item: typing.Type) -> typing.Type[FlyteFile]: + def __class_getitem__(cls, item: typing.Union[str, typing.Type]) -> typing.Type[FlyteFile]: if item is None: return cls item_string = str(item) @@ -220,10 +220,10 @@ def __init__(self): super().__init__(name="FlyteFilePath", t=FlyteFile) @staticmethod - def get_format(t: typing.Union[typing.Type[FlyteFile]]) -> str: + def get_format(t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> str: if t is os.PathLike: return "" - return t.extension() + return typing.cast(FlyteFile, t).extension() def _blob_type(self, format: str) -> BlobType: return BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE) @@ -342,14 +342,14 @@ def _downloader(): return ctx.file_access.get_data(uri, local_path, is_multipart=False) expected_format = FlyteFilePathTransformer.get_format(expected_python_type) - ff = FlyteFile[expected_format](local_path, _downloader) + ff = FlyteFile.__class_getitem__(expected_format)(local_path, _downloader) ff._remote_source = uri return ff def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]: if literal_type.blob is not None and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE: - return FlyteFile[typing.TypeVar(literal_type.blob.format)] + return FlyteFile.__class_getitem__(literal_type.blob.format) raise ValueError(f"Transformer {self} cannot reverse {literal_type}") diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index 8251d111bb2..9219d3a8b42 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -20,10 +20,10 @@ class FlytePickle(typing.Generic[T]): """ @classmethod - def python_type(cls) -> None: - return None + def python_type(cls) -> typing.Type: + return type(None) - def __class_getitem__(cls, python_type: typing.Type) -> typing.Type[T]: + def __class_getitem__(cls, python_type: typing.Type) -> typing.Type: if python_type is None: return cls diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index cb421f98e37..335b6dc482c 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -6,6 +6,7 @@ from abc import abstractmethod from dataclasses import dataclass, field from enum import Enum +from pathlib import Path from typing import Type import numpy as _np @@ -13,11 +14,13 @@ from marshmallow import fields from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import T, TypeEngine, TypeTransformer +from flytekit.core.type_engine import TypeEngine, TypeTransformer from flytekit.models.literals import Literal, Scalar, Schema from flytekit.models.types import LiteralType, SchemaType from flytekit.plugins import pandas +T = typing.TypeVar("T") + class SchemaFormat(Enum): """ @@ -37,7 +40,7 @@ class SchemaOpenMode(Enum): WRITE = "w" -def generate_ordered_files(directory: os.PathLike, n: int) -> str: +def generate_ordered_files(directory: os.PathLike, n: int) -> typing.Generator[str, None, None]: for i in range(n): yield os.path.join(directory, f"{i:05}") @@ -73,12 +76,12 @@ def all(self, **kwargs) -> T: class SchemaWriter(typing.Generic[T]): - def __init__(self, to_path: str, cols: typing.Dict[str, type], fmt: SchemaFormat): + def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): self._to_path = to_path self._fmt = fmt self._columns = cols # TODO This should be change to send a stop instead of hardcoded to 1024 - self._file_name_gen = generate_ordered_files(self._to_path, 1024) + self._file_name_gen = generate_ordered_files(Path(self._to_path), 1024) @property def to_path(self) -> str: @@ -107,14 +110,14 @@ def iter(self, **kwargs) -> typing.Generator[T, None, None]: with os.scandir(self._from_path) as it: for entry in it: if not entry.name.startswith(".") and entry.is_file(): - yield self._read(entry.path, **kwargs) + yield self._read(Path(entry.path), **kwargs) def all(self, **kwargs) -> T: - files = [] + files: typing.List[os.PathLike] = [] with os.scandir(self._from_path) as it: for entry in it: if not entry.name.startswith(".") and entry.is_file(): - files.append(entry.path) + files.append(Path(entry.path)) return self._read(*files, **kwargs) @@ -279,13 +282,15 @@ def open( if not h.handles_remote_io: # The Schema Handler does not manage its own IO, and this it will expect the files are on local file-system if self._supported_mode == SchemaOpenMode.READ and not self._downloaded: + if self._downloader is None: + raise AssertionError("downloader cannot be None in read mode!") # Only for readable objects if they are not downloaded already, we should download them # Write objects should already have everything written to self._downloader(self.remote_path, self.local_path) self._downloaded = True if mode == SchemaOpenMode.WRITE: - return h.writer(self.local_path, self.columns(), self.format()) - return h.reader(self.local_path, self.columns(), self.format()) + return h.writer(typing.cast(str, self.local_path), self.columns(), self.format()) + return h.reader(typing.cast(str, self.local_path), self.columns(), self.format()) # Remote IO is handled. So we will just pass the remote reference to the object if mode == SchemaOpenMode.WRITE: @@ -384,10 +389,10 @@ def downloader(x, y): supported_mode=SchemaOpenMode.READ, ) - def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + def guess_python_type(self, literal_type: LiteralType) -> Type[FlyteSchema]: if not literal_type.schema: raise ValueError(f"Cannot reverse {literal_type}") - columns: dict[Type] = {} + columns: typing.Dict[str, Type] = {} for literal_column in literal_type.schema.columns: if literal_column.type == SchemaType.SchemaColumn.SchemaColumnType.INTEGER: columns[literal_column.name] = int @@ -403,7 +408,7 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]: columns[literal_column.name] = bool else: raise ValueError(f"Unknown schema column type {literal_column}") - return FlyteSchema[columns] + return FlyteSchema.__class_getitem__(columns) TypeEngine.register(FlyteSchemaTransformer()) diff --git a/flytekit/types/schema/types_pandas.py b/flytekit/types/schema/types_pandas.py index 41a5423c08c..0edf024b083 100644 --- a/flytekit/types/schema/types_pandas.py +++ b/flytekit/types/schema/types_pandas.py @@ -15,7 +15,7 @@ class ParquetIO(object): PARQUET_ENGINE = "pyarrow" - def _read(self, chunk: os.PathLike, columns: typing.List[str], **kwargs) -> pandas.DataFrame: + def _read(self, chunk: os.PathLike, columns: typing.Optional[typing.List[str]], **kwargs) -> pandas.DataFrame: return pandas.read_parquet(chunk, columns=columns, engine=self.PARQUET_ENGINE, **kwargs) def read(self, *files: os.PathLike, columns: typing.List[str] = None, **kwargs) -> pandas.DataFrame: @@ -59,7 +59,7 @@ def write( class FastParquetIO(ParquetIO): PARQUET_ENGINE = "fastparquet" - def _read(self, chunk: os.PathLike, columns: typing.List[str], **kwargs) -> pandas.DataFrame: + def _read(self, chunk: os.PathLike, columns: typing.Optional[typing.List[str]], **kwargs) -> pandas.DataFrame: from fastparquet import ParquetFile as _ParquetFile from fastparquet import thrift_structures as _ts