Skip to content

Commit

Permalink
Complex dataclass unit tests (#773)
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Dec 14, 2021
1 parent fed03d3 commit ec19076
Showing 1 changed file with 237 additions and 0 deletions.
237 changes: 237 additions & 0 deletions tests/flytekit/unit/core/test_complex_nesting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
import os
import tempfile
from dataclasses import dataclass
from typing import List

import pytest
from dataclasses_json import dataclass_json

from flytekit.core.context_manager import ExecutionState, FlyteContextManager, Image, ImageConfig, SerializationSettings
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.type_engine import TypeEngine
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile


@dataclass_json
@dataclass
class MyProxyConfiguration:
# File and directory paths kept as 'str' so Flyte doesn't manage these static resources
splat_data_dir: str
apriori_file: str


@dataclass_json
@dataclass
class MyProxyParameters:
id: str
job_i_step: int


@dataclass_json
@dataclass
class MyAprioriConfiguration:
static_data_dir: FlyteDirectory
external_data_dir: FlyteDirectory


@dataclass_json
@dataclass
class MyInput:
main_product: FlyteFile
apriori_config: MyAprioriConfiguration
proxy_config: MyProxyConfiguration
proxy_params: MyProxyParameters


@pytest.fixture
def folders_and_files_setup():
tmp_dir = tempfile.TemporaryDirectory()
fd, path = tempfile.mkstemp(dir=tmp_dir.name)
tmp_dir_static_data = tempfile.TemporaryDirectory()
tmp_dir_external_data = tempfile.TemporaryDirectory()

try:
with os.fdopen(fd, "w") as tmp:
tmp.write("Hello world")
yield path, tmp_dir_static_data.name, tmp_dir_external_data.name
finally:
tmp_dir.cleanup()
tmp_dir_static_data.cleanup()
tmp_dir_external_data.cleanup()


@pytest.fixture
def two_sample_inputs(folders_and_files_setup):
(file_path, static_data_path, external_data_path) = folders_and_files_setup

main_product = FlyteFile(file_path)
apriori = MyAprioriConfiguration(
static_data_dir=FlyteDirectory(static_data_path),
external_data_dir=FlyteDirectory(external_data_path),
)
proxy_c = MyProxyConfiguration(splat_data_dir="/tmp/proxy_splat", apriori_file="/opt/config/a_file")
proxy_p = MyProxyParameters(id="pp_id", job_i_step=1)

my_input = MyInput(
main_product=main_product,
apriori_config=apriori,
proxy_config=proxy_c,
proxy_params=proxy_p,
)

my_input_2 = MyInput(
main_product=main_product,
apriori_config=apriori,
proxy_config=proxy_c,
proxy_params=proxy_p,
)

yield my_input, my_input_2


def test_dataclass_complex_transform(two_sample_inputs):
my_input = two_sample_inputs[0]
my_input_2 = two_sample_inputs[1]

ctx = FlyteContextManager.current_context()
literal_type = TypeEngine.to_literal_type(MyInput)
first_literal = TypeEngine.to_literal(ctx, my_input, MyInput, literal_type)
assert first_literal.scalar.generic["apriori_config"] is not None

converted_back_1 = TypeEngine.to_python_value(ctx, first_literal, MyInput)
assert converted_back_1.apriori_config is not None

second_literal = TypeEngine.to_literal(ctx, converted_back_1, MyInput, literal_type)
assert second_literal.scalar.generic["apriori_config"] is not None

converted_back_2 = TypeEngine.to_python_value(ctx, second_literal, MyInput)
assert converted_back_2.apriori_config is not None

input_list = [my_input, my_input_2]
input_list_type = TypeEngine.to_literal_type(List[MyInput])
literal_list = TypeEngine.to_literal(ctx, input_list, List[MyInput], input_list_type)
assert literal_list.collection.literals[0].scalar.generic["apriori_config"] is not None
assert literal_list.collection.literals[1].scalar.generic["apriori_config"] is not None


def test_two(two_sample_inputs):
my_input = two_sample_inputs[0]
my_input_2 = two_sample_inputs[1]

@dynamic
def dt1(a: List[MyInput]) -> List[FlyteFile]:
x = []
for aa in a:
x.append(aa.main_product)
return x

with FlyteContextManager.with_context(
FlyteContextManager.current_context().with_serialization_settings(
SerializationSettings(
project="test_proj",
domain="test_domain",
version="abc",
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
env={},
)
)
) as ctx:
with FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(
mode=ExecutionState.Mode.TASK_EXECUTION,
additional_context={
"dynamic_addl_distro": "s3://my-s3-bucket/fast/123",
"dynamic_dest_dir": "/User/flyte/workflows",
},
)
)
) as ctx:
input_literal_map = TypeEngine.dict_to_literal_map(
ctx, d={"a": [my_input, my_input_2]}, guessed_python_types={"a": List[MyInput]}
)
dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
assert len(dynamic_job_spec.literals["o0"].collection.literals) == 2


def test_str_input(folders_and_files_setup):
proxy_c = MyProxyConfiguration(splat_data_dir="/tmp/proxy_splat", apriori_file="/opt/config/a_file")
proxy_p = MyProxyParameters(id="pp_id", job_i_step=1)

# Intentionally passing in the wrong type
my_input = MyInput(
main_product=folders_and_files_setup[0], # noqa
apriori_config=MyAprioriConfiguration(
static_data_dir=FlyteDirectory("gs://my-bucket/one"),
external_data_dir=FlyteDirectory("gs://my-bucket/two"),
),
proxy_config=proxy_c,
proxy_params=proxy_p,
)
ctx = FlyteContextManager.current_context()
literal_type = TypeEngine.to_literal_type(MyInput)
first_literal = TypeEngine.to_literal(ctx, my_input, MyInput, literal_type)
assert first_literal.scalar.generic is not None


def test_dc_dyn_directory(folders_and_files_setup):
proxy_c = MyProxyConfiguration(splat_data_dir="/tmp/proxy_splat", apriori_file="/opt/config/a_file")
proxy_p = MyProxyParameters(id="pp_id", job_i_step=1)

my_input_gcs = MyInput(
main_product=FlyteFile(folders_and_files_setup[0]),
apriori_config=MyAprioriConfiguration(
static_data_dir=FlyteDirectory("gs://my-bucket/one"),
external_data_dir=FlyteDirectory("gs://my-bucket/two"),
),
proxy_config=proxy_c,
proxy_params=proxy_p,
)

my_input_gcs_2 = MyInput(
main_product=FlyteFile(folders_and_files_setup[0]),
apriori_config=MyAprioriConfiguration(
static_data_dir=FlyteDirectory("gs://my-bucket/three"),
external_data_dir=FlyteDirectory("gs://my-bucket/four"),
),
proxy_config=proxy_c,
proxy_params=proxy_p,
)

@dynamic
def dt1(a: List[MyInput]) -> List[FlyteDirectory]:
x = []
for aa in a:
x.append(aa.apriori_config.external_data_dir)

return x

with FlyteContextManager.with_context(
FlyteContextManager.current_context().with_serialization_settings(
SerializationSettings(
project="test_proj",
domain="test_domain",
version="abc",
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
env={},
)
)
) as ctx:
with FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(
mode=ExecutionState.Mode.TASK_EXECUTION,
additional_context={
"dynamic_addl_distro": "s3://my-s3-bucket/fast/123",
"dynamic_dest_dir": "/User/flyte/workflows",
},
)
)
) as ctx:
input_literal_map = TypeEngine.dict_to_literal_map(
ctx, d={"a": [my_input_gcs, my_input_gcs_2]}, guessed_python_types={"a": List[MyInput]}
)
dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
assert dynamic_job_spec.literals["o0"].collection.literals[0].scalar.blob.uri == "gs://my-bucket/two"
assert dynamic_job_spec.literals["o0"].collection.literals[1].scalar.blob.uri == "gs://my-bucket/four"

0 comments on commit ec19076

Please sign in to comment.