Skip to content

Commit

Permalink
Allow struct/dataclass types to be used as default arguments (#1024)
Browse files Browse the repository at this point in the history
* Allow struct/dataclass types to be used as default arguments

Signed-off-by: Roberto Ruiz <[email protected]>

* Add test case

Signed-off-by: Eduardo Apolinario <[email protected]>

* Lint

Signed-off-by: Eduardo Apolinario <[email protected]>

Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
RobertoRRW and eapolinario authored May 27, 2022
1 parent 065b899 commit 5c1395a
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 1 deletion.
5 changes: 4 additions & 1 deletion flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,10 @@ def convert_to_literal(

if self._literal_type.simple or self._literal_type.enum_type:
if self._literal_type.simple and self._literal_type.simple == SimpleType.STRUCT:
o = cast(DataClassJsonMixin, self._python_type).from_json(value)
if type(value) != self._python_type:
o = cast(DataClassJsonMixin, self._python_type).from_json(value)
else:
o = value
return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type)
return Literal(scalar=self._converter.convert(value, self._python_type))

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from dataclasses import dataclass

from dataclasses_json import dataclass_json

from flytekit import task, workflow


@dataclass_json
@dataclass
class DataclassA:
a: str
b: int


@task
def t(dca: DataclassA):
print(dca)


@workflow
def wf(dca: DataclassA = DataclassA("hello", 42)):
t(dca=dca)
16 changes: 16 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,19 @@ def test_nested_workflow(working_dir, wf_path, monkeypatch: pytest.MonkeyPatch):
)
assert result.stdout.strip() == "wow"
assert result.exit_code == 0


def test_dataclasses_default_arguments():
runner = CliRunner()
dir_name = os.path.dirname(os.path.realpath(__file__))
result = runner.invoke(
pyflyte.main,
[
"run",
os.path.join(dir_name, "dataclasses_default_arguments", "wf.py"),
"wf",
],
catch_exceptions=False,
)
print(result.stdout)
assert result.exit_code == 0

0 comments on commit 5c1395a

Please sign in to comment.