From b257569b31f83feb8e40b6371a72e32cefb63a63 Mon Sep 17 00:00:00 2001
From: Lisa <30621230+aeioulisa@users.noreply.github.com>
Date: Tue, 7 Dec 2021 01:51:09 +0800
Subject: [PATCH] Fix mypy errors in flytekit/types (#757)

Signed-off-by: Lisa <aeioulisa@gmail.com>
Signed-off-by: Kevin Su <pingsutw@apache.org>
---
 flytekit/types/directory/__init__.py  |  3 ++-
 flytekit/types/directory/types.py     |  6 +++---
 flytekit/types/file/__init__.py       | 30 ++++++++++++++++++---------
 flytekit/types/file/file.py           | 10 ++++-----
 flytekit/types/pickle/pickle.py       |  6 +++---
 flytekit/types/schema/types.py        | 29 +++++++++++++++-----------
 flytekit/types/schema/types_pandas.py |  4 ++--
 7 files changed, 52 insertions(+), 36 deletions(-)

diff --git a/flytekit/types/directory/__init__.py b/flytekit/types/directory/__init__.py
index 4edb5f9205..97d8ab57ce 100644
--- a/flytekit/types/directory/__init__.py
+++ b/flytekit/types/directory/__init__.py
@@ -18,7 +18,8 @@
 
 # The following section provides some predefined aliases for commonly used FlyteDirectory formats.
 
-TensorboardLogs = FlyteDirectory[typing.TypeVar("tensorboard")]
+tensorboard = typing.TypeVar("tensorboard")
+TensorboardLogs = FlyteDirectory[tensorboard]
 """
     This type can be used to denote that the output is a folder that contains logs that can be loaded in tensorboard.
     this is usually the SummaryWriter output in pytorch or Keras callbacks which record the history readable by
diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py
index a0112c1eea..81f4fb0fd0 100644
--- a/flytekit/types/directory/types.py
+++ b/flytekit/types/directory/types.py
@@ -133,7 +133,7 @@ def __fspath__(self):
     def extension(cls) -> str:
         return ""
 
-    def __class_getitem__(cls, item: typing.Type) -> typing.Type[FlyteDirectory]:
+    def __class_getitem__(cls, item: typing.Union[typing.Type, str]) -> typing.Type[FlyteDirectory]:
         if item is None:
             return cls
         item_string = str(item)
@@ -290,7 +290,7 @@ def _downloader():
 
         expected_format = self.get_format(expected_python_type)
 
-        fd = FlyteDirectory[expected_format](local_folder, _downloader)
+        fd = FlyteDirectory.__class_getitem__(expected_format)(local_folder, _downloader)
         fd._remote_source = uri
 
         return fd
@@ -300,7 +300,7 @@ def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteDirec
             literal_type.blob is not None
             and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART
         ):
-            return FlyteDirectory[typing.TypeVar(literal_type.blob.format)]
+            return FlyteDirectory.__class_getitem__(literal_type.blob.format)
         raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
 
 
diff --git a/flytekit/types/file/__init__.py b/flytekit/types/file/__init__.py
index 2b65efbcd6..81796fc49e 100644
--- a/flytekit/types/file/__init__.py
+++ b/flytekit/types/file/__init__.py
@@ -28,59 +28,69 @@
 # This makes their usage extremely simple for the users. Please keep the list sorted.
 
 
-HDF5EncodedFile = FlyteFile[typing.TypeVar("hdf5")]
+hdf5 = typing.TypeVar("hdf5")
+HDF5EncodedFile = FlyteFile[hdf5]
 """
     This can be used to denote that the returned file is of type hdf5 and can be received by other tasks that
     accept an hdf5 format. This is usually useful for serializing Tensorflow models
 """
 
-HTMLPage = FlyteFile[typing.TypeVar("html")]
+html = typing.TypeVar("html")
+HTMLPage = FlyteFile[html]
 """
     Can be used to receive or return an PNGImage. The underlying type is a FlyteFile, type. This is just a
     decoration and useful for attaching content type information with the file and automatically documenting code.
 """
 
-JoblibSerializedFile = FlyteFile[typing.TypeVar("joblib")]
+joblib = typing.TypeVar("joblib")
+JoblibSerializedFile = FlyteFile[joblib]
 """
     This File represents a file that was serialized using `joblib.dump` method can be loaded back using `joblib.load`
 """
 
-JPEGImageFile = FlyteFile[typing.TypeVar("jpeg")]
+jpeg = typing.TypeVar("jpeg")
+JPEGImageFile = FlyteFile[jpeg]
 """
     Can be used to receive or return an JPEGImage. The underlying type is a FlyteFile, type. This is just a
     decoration and useful for attaching content type information with the file and automatically documenting code.
 """
 
-PDFFile = FlyteFile[typing.TypeVar("pdf")]
+pdf = typing.TypeVar("pdf")
+PDFFile = FlyteFile[pdf]
 """
     Can be used to receive or return an PDFFile. The underlying type is a FlyteFile, type. This is just a
     decoration and useful for attaching content type information with the file and automatically documenting code.
 """
 
-PNGImageFile = FlyteFile[typing.TypeVar("png")]
+png = typing.TypeVar("png")
+PNGImageFile = FlyteFile[png]
 """
     Can be used to receive or return an PNGImage. The underlying type is a FlyteFile, type. This is just a
     decoration and useful for attaching content type information with the file and automatically documenting code.
 """
 
-PythonPickledFile = FlyteFile[typing.TypeVar("python-pickle")]
+python_pickle = typing.TypeVar("python_pickle")
+PythonPickledFile = FlyteFile[python_pickle]
 """
     This type can be used when a serialized python pickled object is returned and shared between tasks. This only
     adds metadata to the file in Flyte, but does not really carry any object information
 """
 
-PythonNotebook = FlyteFile[typing.TypeVar("ipynb")]
+ipynb = typing.TypeVar("ipynb")
+PythonNotebook = FlyteFile[ipynb]
 """
     This type is used to identify a python notebook file
 """
 
-SVGImageFile = FlyteFile[typing.TypeVar("svg")]
+svg = typing.TypeVar("svg")
+SVGImageFile = FlyteFile[svg]
 """
     Can be used to receive or return an SVGImage. The underlying type is a FlyteFile, type. This is just a
     decoration and useful for attaching content type information with the file and automatically documenting code.
 """
 
-CSVFile = FlyteFile[typing.TypeVar("csv")]
+csv = typing.TypeVar("csv")
+CSVFile = FlyteFile[csv]
 """
     Can be used to receive or return a CSVFile. The underlying type is a FlyteFile, type. This is just a
     decoration and useful for attaching content type information with the file and automatically documenting code.
diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py
index 1fbcee049a..eb64e5943d 100644
--- a/flytekit/types/file/file.py
+++ b/flytekit/types/file/file.py
@@ -137,7 +137,7 @@ def t2() -> flytekit_typing.FlyteFile["csv"]:
     def extension(cls) -> str:
         return ""
 
-    def __class_getitem__(cls, item: typing.Type) -> typing.Type[FlyteFile]:
+    def __class_getitem__(cls, item: typing.Union[str, typing.Type]) -> typing.Type[FlyteFile]:
         if item is None:
             return cls
         item_string = str(item)
@@ -220,10 +220,10 @@ def __init__(self):
         super().__init__(name="FlyteFilePath", t=FlyteFile)
 
     @staticmethod
-    def get_format(t: typing.Union[typing.Type[FlyteFile]]) -> str:
+    def get_format(t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> str:
         if t is os.PathLike:
             return ""
-        return t.extension()
+        return typing.cast(FlyteFile, t).extension()
 
     def _blob_type(self, format: str) -> BlobType:
         return BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE)
@@ -342,14 +342,14 @@ def _downloader():
             return ctx.file_access.get_data(uri, local_path, is_multipart=False)
 
         expected_format = FlyteFilePathTransformer.get_format(expected_python_type)
-        ff = FlyteFile[expected_format](local_path, _downloader)
+        ff = FlyteFile.__class_getitem__(expected_format)(local_path, _downloader)
         ff._remote_source = uri
 
         return ff
 
     def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]:
         if literal_type.blob is not None and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE:
-            return FlyteFile[typing.TypeVar(literal_type.blob.format)]
+            return FlyteFile.__class_getitem__(literal_type.blob.format)
 
         raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
 
diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py
index 8251d111bb..9219d3a8b4 100644
--- a/flytekit/types/pickle/pickle.py
+++ b/flytekit/types/pickle/pickle.py
@@ -20,10 +20,10 @@ class FlytePickle(typing.Generic[T]):
     """
 
     @classmethod
-    def python_type(cls) -> None:
-        return None
+    def python_type(cls) -> typing.Type:
+        return type(None)
 
-    def __class_getitem__(cls, python_type: typing.Type) -> typing.Type[T]:
+    def __class_getitem__(cls, python_type: typing.Type) -> typing.Type:
         if python_type is None:
             return cls
 
diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py
index cb421f98e3..335b6dc482 100644
--- a/flytekit/types/schema/types.py
+++ b/flytekit/types/schema/types.py
@@ -6,6 +6,7 @@
 from abc import abstractmethod
 from dataclasses import dataclass, field
 from enum import Enum
+from pathlib import Path
 from typing import Type
 
 import numpy as _np
@@ -13,11 +14,13 @@
 from marshmallow import fields
 
 from flytekit.core.context_manager import FlyteContext, FlyteContextManager
-from flytekit.core.type_engine import T, TypeEngine, TypeTransformer
+from flytekit.core.type_engine import TypeEngine, TypeTransformer
 from flytekit.models.literals import Literal, Scalar, Schema
 from flytekit.models.types import LiteralType, SchemaType
 from flytekit.plugins import pandas
 
+T = typing.TypeVar("T")
+
 
 class SchemaFormat(Enum):
     """
@@ -37,7 +40,7 @@ class SchemaOpenMode(Enum):
     WRITE = "w"
 
 
-def generate_ordered_files(directory: os.PathLike, n: int) -> str:
+def generate_ordered_files(directory: os.PathLike, n: int) -> typing.Generator[str, None, None]:
     for i in range(n):
         yield os.path.join(directory, f"{i:05}")
 
@@ -73,12 +76,12 @@ def all(self, **kwargs) -> T:
 
 
 class SchemaWriter(typing.Generic[T]):
-    def __init__(self, to_path: str, cols: typing.Dict[str, type], fmt: SchemaFormat):
+    def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
         self._to_path = to_path
         self._fmt = fmt
         self._columns = cols
         # TODO This should be change to send a stop instead of hardcoded to 1024
-        self._file_name_gen = generate_ordered_files(self._to_path, 1024)
+        self._file_name_gen = generate_ordered_files(Path(self._to_path), 1024)
 
     @property
     def to_path(self) -> str:
@@ -107,14 +110,14 @@ def iter(self, **kwargs) -> typing.Generator[T, None, None]:
         with os.scandir(self._from_path) as it:
             for entry in it:
                 if not entry.name.startswith(".") and entry.is_file():
-                    yield self._read(entry.path, **kwargs)
+                    yield self._read(Path(entry.path), **kwargs)
 
     def all(self, **kwargs) -> T:
-        files = []
+        files: typing.List[os.PathLike] = []
         with os.scandir(self._from_path) as it:
             for entry in it:
                 if not entry.name.startswith(".") and entry.is_file():
-                    files.append(entry.path)
+                    files.append(Path(entry.path))
 
         return self._read(*files, **kwargs)
 
@@ -279,13 +282,15 @@ def open(
         if not h.handles_remote_io:
             # The Schema Handler does not manage its own IO, and this it will expect the files are on local file-system
             if self._supported_mode == SchemaOpenMode.READ and not self._downloaded:
+                if self._downloader is None:
+                    raise AssertionError("downloader cannot be None in read mode!")
                 # Only for readable objects if they are not downloaded already, we should download them
                 # Write objects should already have everything written to
                 self._downloader(self.remote_path, self.local_path)
                 self._downloaded = True
             if mode == SchemaOpenMode.WRITE:
-                return h.writer(self.local_path, self.columns(), self.format())
-            return h.reader(self.local_path, self.columns(), self.format())
+                return h.writer(typing.cast(str, self.local_path), self.columns(), self.format())
+            return h.reader(typing.cast(str, self.local_path), self.columns(), self.format())
 
         # Remote IO is handled. So we will just pass the remote reference to the object
         if mode == SchemaOpenMode.WRITE:
@@ -384,10 +389,10 @@ def downloader(x, y):
             supported_mode=SchemaOpenMode.READ,
         )
 
-    def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
+    def guess_python_type(self, literal_type: LiteralType) -> Type[FlyteSchema]:
         if not literal_type.schema:
             raise ValueError(f"Cannot reverse {literal_type}")
-        columns: dict[Type] = {}
+        columns: typing.Dict[str, Type] = {}
         for literal_column in literal_type.schema.columns:
             if literal_column.type == SchemaType.SchemaColumn.SchemaColumnType.INTEGER:
                 columns[literal_column.name] = int
@@ -403,7 +408,7 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
                 columns[literal_column.name] = bool
             else:
                 raise ValueError(f"Unknown schema column type {literal_column}")
-        return FlyteSchema[columns]
+        return FlyteSchema.__class_getitem__(columns)
 
 
 TypeEngine.register(FlyteSchemaTransformer())
diff --git a/flytekit/types/schema/types_pandas.py b/flytekit/types/schema/types_pandas.py
index 41a5423c08..0edf024b08 100644
--- a/flytekit/types/schema/types_pandas.py
+++ b/flytekit/types/schema/types_pandas.py
@@ -15,7 +15,7 @@
 class ParquetIO(object):
     PARQUET_ENGINE = "pyarrow"
 
-    def _read(self, chunk: os.PathLike, columns: typing.List[str], **kwargs) -> pandas.DataFrame:
+    def _read(self, chunk: os.PathLike, columns: typing.Optional[typing.List[str]], **kwargs) -> pandas.DataFrame:
         return pandas.read_parquet(chunk, columns=columns, engine=self.PARQUET_ENGINE, **kwargs)
 
     def read(self, *files: os.PathLike, columns: typing.List[str] = None, **kwargs) -> pandas.DataFrame:
@@ -59,7 +59,7 @@ def write(
 class FastParquetIO(ParquetIO):
     PARQUET_ENGINE = "fastparquet"
 
-    def _read(self, chunk: os.PathLike, columns: typing.List[str], **kwargs) -> pandas.DataFrame:
+    def _read(self, chunk: os.PathLike, columns: typing.Optional[typing.List[str]], **kwargs) -> pandas.DataFrame:
         from fastparquet import ParquetFile as _ParquetFile
         from fastparquet import thrift_structures as _ts