Skip to content

Commit

Permalink
Open HashMethod to all types
Browse files Browse the repository at this point in the history
  • Loading branch information
eapolinario committed Sep 22, 2022
1 parent 2ccaed7 commit 6fb9c59
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 25 deletions.
11 changes: 2 additions & 9 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,10 @@ class TypeTransformer(typing.Generic[T]):
Base transformer type that should be implemented for every python native type that can be handled by flytekit
"""

def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True, hash_overridable: bool = False):
def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True):
self._t = t
self._name = name
self._type_assertions_enabled = enable_type_assertions
# `hash_overridable` indicates that the literals produced by this type transformer can set their hashes if needed.
# See (link to documentation where this feature is explained).
self._hash_overridable = hash_overridable

@property
def name(self):
Expand All @@ -88,10 +85,6 @@ def type_assertions_enabled(self) -> bool:
"""
return self._type_assertions_enabled

@property
def hash_overridable(self) -> bool:
return self._hash_overridable

def assert_type(self, t: Type[T], v: T):
if not hasattr(t, "__origin__") and not isinstance(v, t):
raise TypeTransformerFailedError(f"Type of Val '{v}' is not an instance of {t}")
Expand Down Expand Up @@ -742,7 +735,7 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type

# In case the value is an annotated type we inspect the annotations and look for hash-related annotations.
hash = None
if transformer.hash_overridable and get_origin(python_type) is Annotated:
if get_origin(python_type) is Annotated:
# We are now dealing with one of two cases:
# 1. The annotated type is a `HashMethod`, which indicates that we should we should produce the hash using
# the method indicated in the annotation.
Expand Down
1 change: 0 additions & 1 deletion flytekit/types/schema/types_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ class PandasDataFrameTransformer(TypeTransformer[pandas.DataFrame]):
def __init__(self):
super().__init__("PandasDataFrame<->GenericSchema", pandas.DataFrame)
self._parquet_engine = ParquetIO()
self._hash_overridable = True

@staticmethod
def _get_schema_type() -> SchemaType:
Expand Down
3 changes: 0 additions & 3 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,6 @@ def __init__(self):
super().__init__("StructuredDataset Transformer", StructuredDataset)
self._type_assertions_enabled = False

# Instances of StructuredDataset opt-in to the ability of being cached.
self._hash_overridable = True

@classmethod
def register_renderer(cls, python_type: Type, renderer: Renderable):
cls.Renderers[python_type] = renderer
Expand Down
16 changes: 8 additions & 8 deletions tests/flytekit/unit/core/test_local_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,9 @@ def my_wf(a: int, b: str) -> (int, str):
assert n_cached_task_calls == 2


def test_set_integer_literal_hash_is_not_cached():
def test_set_integer_literal_hash_is_cached():
"""
Test to confirm that the local cache is not set in the case of integers, even if we
Test to confirm that the local cache is set in the case of integers, even if we
return an annotated integer. In order to make this very explicit, we define a constant hash
function, i.e. the same value is returned by it regardless of the input.
"""
Expand All @@ -289,13 +289,13 @@ def wf(a: int) -> int:
assert n_cached_task_calls == 0
assert wf(a=3) == 3
assert n_cached_task_calls == 1
# Confirm that the value is not cached, even though we set a hash function that
# returns a constant value and that the task has only one input.
assert wf(a=2) == 2
assert n_cached_task_calls == 2
# Confirm that the value is cached due to the fact the hash value is constant, regardless
# of the value passed to the cacheable task.
assert wf(a=2) == 3
assert n_cached_task_calls == 1
# Confirm that the cache is hit if we execute the workflow with the same value as previous run.
assert wf(a=2) == 2
assert n_cached_task_calls == 2
assert wf(a=2) == 3
assert n_cached_task_calls == 1


def test_pass_annotated_to_downstream_tasks():
Expand Down
7 changes: 3 additions & 4 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,17 +1254,16 @@ def t1(a: int) -> int:
assert t1(a=3) == 9


def test_literal_hash_int_not_set():
def test_literal_hash_int_can_be_set():
"""
Test to confirm that annotating an integer with `HashMethod` does not force the literal to have its
hash set.
Test to confirm that annotating an integer with `HashMethod` is allowed.
"""
ctx = FlyteContext.current_context()
lv = TypeEngine.to_literal(
ctx, 42, Annotated[int, HashMethod(str)], LiteralType(simple=model_types.SimpleType.INTEGER)
)
assert lv.scalar.primitive.integer == 42
assert lv.hash is None
assert lv.hash == "42"


def test_literal_hash_to_python_value():
Expand Down

0 comments on commit 6fb9c59

Please sign in to comment.