diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 397fff3aa8..ba82ecb969 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -266,7 +266,10 @@ def convert_to_literal( if self._literal_type.simple or self._literal_type.enum_type: if self._literal_type.simple and self._literal_type.simple == SimpleType.STRUCT: - o = cast(DataClassJsonMixin, self._python_type).from_json(value) + if type(value) != self._python_type: + o = cast(DataClassJsonMixin, self._python_type).from_json(value) + else: + o = value return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type) return Literal(scalar=self._converter.convert(value, self._python_type)) diff --git a/tests/flytekit/unit/cli/pyflyte/dataclasses_default_arguments/wf.py b/tests/flytekit/unit/cli/pyflyte/dataclasses_default_arguments/wf.py new file mode 100644 index 0000000000..d9ba207cf2 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/dataclasses_default_arguments/wf.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass + +from dataclasses_json import dataclass_json + +from flytekit import task, workflow + + +@dataclass_json +@dataclass +class DataclassA: + a: str + b: int + + +@task +def t(dca: DataclassA): + print(dca) + + +@workflow +def wf(dca: DataclassA = DataclassA("hello", 42)): + t(dca=dca) diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 797e2e00d5..9d97e6524d 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -94,3 +94,19 @@ def test_nested_workflow(working_dir, wf_path, monkeypatch: pytest.MonkeyPatch): ) assert result.stdout.strip() == "wow" assert result.exit_code == 0 + + +def test_dataclasses_default_arguments(): + runner = CliRunner() + dir_name = os.path.dirname(os.path.realpath(__file__)) + result = runner.invoke( + pyflyte.main, + [ + "run", + os.path.join(dir_name, "dataclasses_default_arguments", "wf.py"), + "wf", + ], + catch_exceptions=False, + ) + print(result.stdout) + assert result.exit_code == 0