Skip to content

Commit

Permalink
Schema overhaul (#785)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
pingsutw authored Jan 13, 2022
1 parent ad3d64f commit 35a5724
Show file tree
Hide file tree
Showing 33 changed files with 1,934 additions and 35 deletions.
12 changes: 6 additions & 6 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ bcrypt==3.2.0
# via
# -c requirements.txt
# paramiko
# secretstorage
binaryornot==0.4.4
# via
# -c requirements.txt
Expand Down Expand Up @@ -77,7 +78,6 @@ cryptography==36.0.1
# via
# -c requirements.txt
# paramiko
# secretstorage
dataclasses-json==0.5.6
# via
# -c requirements.txt
Expand Down Expand Up @@ -118,7 +118,7 @@ docstring-parser==0.13
# flytekit
filelock==3.4.0
# via virtualenv
flyteidl==0.21.13
flyteidl==0.21.17
# via
# -c requirements.txt
# flytekit
Expand Down Expand Up @@ -181,6 +181,10 @@ marshmallow-jsonschema==0.13.0
# via
# -c requirements.txt
# flytekit
secretstorage==3.3.1
# via
# -c requirements.txt
# keyring
mock==4.0.3
# via -r dev-requirements.in
mypy==0.930
Expand Down Expand Up @@ -315,10 +319,6 @@ retry==0.9.2
# via
# -c requirements.txt
# flytekit
secretstorage==3.3.1
# via
# -c requirements.txt
# keyring
six==1.16.0
# via
# -c requirements.txt
Expand Down
1 change: 1 addition & 0 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types import directory, file, schema
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetType

__version__ = "0.0.0+develop"

Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def find_plugin(cls, path: str) -> typing.Type[DataPersistence]:
Returns a plugin for the given protocol, else raise a TypeError
"""
for k, p in cls._PLUGINS.items():
if path.startswith(k):
if path.startswith(k) or path.startswith(k.replace("://", "")):
return p
raise TypeError(f"No plugin found for matching protocol of path {path}")

Expand Down
6 changes: 5 additions & 1 deletion flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,11 @@ def transform_function_to_interface(fn: Callable, docstring: Optional[Docstring]
For now the fancy object, maybe in the future a dumb object.
"""
type_hints = typing.get_type_hints(fn)
try:
# include_extras can only be used in python >= 3.9
type_hints = typing.get_type_hints(fn, include_extras=True)
except TypeError:
type_hints = typing.get_type_hints(fn)
signature = inspect.signature(fn)
return_annotation = type_hints.get("return", None)

Expand Down
41 changes: 40 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
from abc import ABC, abstractmethod
from typing import NamedTuple, Optional, Type, cast

try:
from typing import Annotated, get_args, get_origin
except ImportError:
from typing_extensions import Annotated, get_origin, get_args

from dataclasses_json import DataClassJsonMixin, dataclass_json
from google.protobuf import json_format as _json_format
from google.protobuf import reflection as _proto_reflection
Expand Down Expand Up @@ -37,8 +42,9 @@
Primitive,
Scalar,
Schema,
StructuredDatasetMetadata,
)
from flytekit.models.types import LiteralType, SimpleType
from flytekit.models.types import LiteralType, SimpleType, StructuredDatasetType

T = typing.TypeVar("T")
DEFINITIONS = "definitions"
Expand Down Expand Up @@ -275,6 +281,7 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):
from flytekit.types.directory.types import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.schema.types import FlyteSchema
from flytekit.types.structured.structured_dataset import StructuredDataset

for f in dataclasses.fields(python_type):
v = python_val.__getattribute__(f.name)
Expand All @@ -283,6 +290,7 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):
issubclass(field_type, FlyteSchema)
or issubclass(field_type, FlyteFile)
or issubclass(field_type, FlyteDirectory)
or issubclass(field_type, StructuredDataset)
):
lv = TypeEngine.to_literal(FlyteContext.current_context(), v, field_type, None)
# dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a
Expand All @@ -295,6 +303,13 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):
# as determined by the transformer.
if issubclass(field_type, FlyteFile) or issubclass(field_type, FlyteDirectory):
python_val.__setattr__(f.name, field_type(path=lv.scalar.blob.uri))
elif issubclass(field_type, StructuredDataset):
python_val.__setattr__(
f.name,
field_type(
uri=lv.scalar.structured_dataset.uri,
),
)

elif dataclasses.is_dataclass(field_type):
self._serialize_flyte_type(v, field_type)
Expand All @@ -303,6 +318,7 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) ->
from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine

if not dataclasses.is_dataclass(expected_python_type):
return python_val
Expand Down Expand Up @@ -348,6 +364,21 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) ->
),
expected_python_type,
)
elif issubclass(expected_python_type, StructuredDataset):
return StructuredDatasetTransformerEngine().to_python_value(
FlyteContext.current_context(),
Literal(
scalar=Scalar(
structured_dataset=StructuredDataset(
metadata=StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(format=python_val.file_format)
),
uri=python_val.uri,
)
)
),
expected_python_type,
)
else:
for f in dataclasses.fields(expected_python_type):
value = python_val.__getattribute__(f.name)
Expand Down Expand Up @@ -492,6 +523,11 @@ def register_restricted_type(
cls._RESTRICTED_TYPES.append(type)
cls.register(RestrictedTypeTransformer(name, type))

@classmethod
def register_additional_type(cls, transformer: TypeTransformer, additional_type: Type, override=False):
    """
    Map an extra python type onto an existing transformer in the registry.

    :param TypeTransformer transformer: the transformer to store for ``additional_type``.
    :param Type additional_type: the python type used as the registry key.
    :param bool override: when truthy, replace any transformer already stored under
        ``additional_type``; otherwise an existing entry is left untouched.
    """
    # Keep the first registration unless the caller explicitly asks to replace it.
    if additional_type in cls._REGISTRY and not override:
        return
    cls._REGISTRY[additional_type] = transformer

@classmethod
def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
"""
Expand All @@ -517,6 +553,9 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
"""
# Step 1
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]

if python_type in cls._REGISTRY:
return cls._REGISTRY[python_type]

Expand Down
2 changes: 2 additions & 0 deletions flytekit/models/core/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import typing

from flyteidl.core import types_pb2 as _types_pb2
Expand Down
72 changes: 71 additions & 1 deletion flytekit/models/literals.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
from datetime import datetime as _datetime

import pytz as _pytz
Expand All @@ -9,6 +10,7 @@
from flytekit.models.core import types as _core_types
from flytekit.models.types import OutputReference as _OutputReference
from flytekit.models.types import SchemaType as _SchemaType
from flytekit.models.types import StructuredDatasetType


class RetryStrategy(_common.FlyteIdlEntity):
Expand Down Expand Up @@ -546,6 +548,54 @@ def from_flyte_idl(cls, pb2_object):
return cls(uri=pb2_object.uri, type=_SchemaType.from_flyte_idl(pb2_object.type))


class StructuredDatasetMetadata(_common.FlyteIdlEntity):
    """
    Python-side wrapper around the ``StructuredDatasetMetadata`` protobuf message.

    It carries a single optional value: the ``StructuredDatasetType`` of the dataset.
    """

    def __init__(self, structured_dataset_type: StructuredDatasetType = None):
        """
        :param StructuredDatasetType structured_dataset_type: type information for the dataset; may be None.
        """
        self._structured_dataset_type = structured_dataset_type

    @property
    def structured_dataset_type(self) -> StructuredDatasetType:
        """
        :rtype: StructuredDatasetType
        """
        return self._structured_dataset_type

    def to_flyte_idl(self) -> _literals_pb2.StructuredDatasetMetadata:
        """
        Build the protobuf message for this object.

        :rtype: flyteidl.core.literals_pb2.StructuredDatasetMetadata
        """
        # A falsy (e.g. unset) type is passed to the message constructor as None.
        if self._structured_dataset_type:
            idl_type = self.structured_dataset_type.to_flyte_idl()
        else:
            idl_type = None
        return _literals_pb2.StructuredDatasetMetadata(structured_dataset_type=idl_type)

    @classmethod
    def from_flyte_idl(cls, pb2_object: _literals_pb2.StructuredDatasetMetadata) -> "StructuredDatasetMetadata":
        """
        Build the Python model from its protobuf message.

        :param flyteidl.core.literals_pb2.StructuredDatasetMetadata pb2_object:
        :rtype: StructuredDatasetMetadata
        """
        # NOTE(review): unlike to_flyte_idl, no presence check is made on the field here, so the
        # type is always converted -- confirm this asymmetry is intended.
        dataset_type = StructuredDatasetType.from_flyte_idl(pb2_object.structured_dataset_type)
        return cls(structured_dataset_type=dataset_type)


class StructuredDataset(_common.FlyteIdlEntity):
    """
    A strongly typed schema that defines the interface of data retrieved from the underlying storage medium.

    Python-side wrapper around the ``StructuredDataset`` protobuf message: a ``uri`` plus optional metadata.
    """

    def __init__(self, uri: str, metadata: typing.Optional[StructuredDatasetMetadata] = None):
        """
        :param Text uri: value stored as the dataset's uri.
        :param StructuredDatasetMetadata metadata: optional metadata for the dataset; may be None.
        """
        self._metadata = metadata
        self._uri = uri

    @property
    def uri(self) -> str:
        """
        :rtype: Text
        """
        return self._uri

    @property
    def metadata(self) -> StructuredDatasetMetadata:
        """
        :rtype: StructuredDatasetMetadata
        """
        return self._metadata

    def to_flyte_idl(self) -> _literals_pb2.StructuredDataset:
        """
        Build the protobuf message for this object.

        :rtype: flyteidl.core.literals_pb2.StructuredDataset
        """
        uri = self.uri
        # Falsy (e.g. missing) metadata is passed to the message constructor as None.
        if self.metadata:
            idl_metadata = self.metadata.to_flyte_idl()
        else:
            idl_metadata = None
        return _literals_pb2.StructuredDataset(uri=uri, metadata=idl_metadata)

    @classmethod
    def from_flyte_idl(cls, pb2_object: _literals_pb2.StructuredDataset) -> "StructuredDataset":
        """
        Build the Python model from its protobuf message.

        :param flyteidl.core.literals_pb2.StructuredDataset pb2_object:
        :rtype: StructuredDataset
        """
        # NOTE(review): metadata is converted unconditionally (no presence check) -- confirm intended.
        uri = pb2_object.uri
        metadata = StructuredDatasetMetadata.from_flyte_idl(pb2_object.metadata)
        return cls(uri=uri, metadata=metadata)


class LiteralCollection(_common.FlyteIdlEntity):
def __init__(self, literals):
"""
Expand Down Expand Up @@ -615,6 +665,7 @@ def __init__(
none_type: Void = None,
error=None,
generic: Struct = None,
structured_dataset: StructuredDataset = None,
):
"""
Scalar wrapper around Flyte types. Only one can be specified.
Expand All @@ -626,6 +677,7 @@ def __init__(
:param Void none_type:
:param error:
:param google.protobuf.struct_pb2.Struct generic:
:param StructuredDataset structured_dataset:
"""

self._primitive = primitive
Expand All @@ -635,6 +687,7 @@ def __init__(
self._none_type = none_type
self._error = error
self._generic = generic
self._structured_dataset = structured_dataset

@property
def primitive(self):
Expand Down Expand Up @@ -685,13 +738,26 @@ def generic(self):
"""
return self._generic

@property
def structured_dataset(self) -> StructuredDataset:
    """
    The structured dataset stored on this scalar; None when it was not supplied.
    :rtype: StructuredDataset
    """
    return self._structured_dataset

@property
def value(self):
"""
Returns whichever value is set
:rtype: T
"""
return self.primitive or self.blob or self.binary or self.schema or self.none_type or self.error
return (
self.primitive
or self.blob
or self.binary
or self.schema
or self.none_type
or self.error
or self.generic
or self.structured_dataset
)

def to_flyte_idl(self):
"""
Expand All @@ -705,6 +771,7 @@ def to_flyte_idl(self):
none_type=self.none_type.to_flyte_idl() if self.none_type is not None else None,
error=self.error if self.error is not None else None,
generic=self.generic,
structured_dataset=self.structured_dataset.to_flyte_idl() if self.structured_dataset is not None else None,
)

@classmethod
Expand All @@ -722,6 +789,9 @@ def from_flyte_idl(cls, pb2_object):
none_type=Void.from_flyte_idl(pb2_object.none_type) if pb2_object.HasField("none_type") else None,
error=pb2_object.error if pb2_object.HasField("error") else None,
generic=pb2_object.generic if pb2_object.HasField("generic") else None,
structured_dataset=StructuredDataset.from_flyte_idl(pb2_object.structured_dataset)
if pb2_object.HasField("structured_dataset")
else None,
)


Expand Down
Loading

0 comments on commit 35a5724

Please sign in to comment.