Skip to content

Commit

Permalink
Accept actual type instead of string representation in Argument class (
Browse files Browse the repository at this point in the history
…#761)

This refactoring was needed while working on #751.

Submitting it as a separate PR for easier reviewing.
  • Loading branch information
RobbeSneyders authored Jan 8, 2024
1 parent b7962a4 commit 8e82844
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/fondant/component/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def _add_and_parse_args(cls, operation_spec: OperationSpec):

parser.add_argument(
f"--{arg.name}",
type=arg.python_type, # type: ignore
type=arg.parser,
required=input_required,
default=default,
help=arg.description,
Expand Down
51 changes: 20 additions & 31 deletions src/fondant/core/component_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import json
import pkgutil
import pydoc
import re
import types
import typing as t
Expand Down Expand Up @@ -33,36 +34,24 @@ class Argument:
"""

name: str
description: str
type: str
default: t.Any = None
type: t.Type
description: t.Optional[str] = None
default: t.Optional[t.Any] = None
optional: t.Optional[bool] = False

def __post_init__(self):
self.default = None if self.default == "None" else self.default

@property
def python_type(self) -> t.Any:
lookup = {
"str": str,
"int": int,
"float": float,
"bool": bool,
"dict": json.loads,
"list": json.loads,
}
map_fn = lookup[self.type]
return lambda value: map_fn(value) if value != "None" else None # type: ignore
self.parser = json.loads if self.type in [dict, list] else self.type

@property
def kubeflow_type(self) -> str:
lookup = {
"str": "STRING",
"int": "NUMBER_INTEGER",
"float": "NUMBER_DOUBLE",
"bool": "BOOLEAN",
"dict": "STRUCT",
"list": "LIST",
str: "STRING",
int: "NUMBER_INTEGER",
float: "NUMBER_DOUBLE",
bool: "BOOLEAN",
dict: "STRUCT",
list: "LIST",
}
return lookup[self.type]

Expand Down Expand Up @@ -208,7 +197,7 @@ def args(self) -> t.Mapping[str, Argument]:
name: Argument(
name=name,
description=arg_info["description"],
type=arg_info["type"],
type=pydoc.locate(arg_info["type"]), # type: ignore
default=arg_info["default"] if "default" in arg_info else None,
optional=arg_info.get("default") == "None",
)
Expand All @@ -228,48 +217,48 @@ def default_arguments(self) -> t.Dict[str, Argument]:
"input_manifest_path": Argument(
name="input_manifest_path",
description="Path to the input manifest",
type="str",
type=str,
optional=True,
),
"operation_spec": Argument(
name="operation_spec",
description="The operation specification as a dictionary",
type="str",
type=str,
),
"input_partition_rows": Argument(
name="input_partition_rows",
description="The number of rows to load per partition. \
Set to override the automatic partitioning",
type="int",
type=int,
optional=True,
),
"cache": Argument(
name="cache",
description="Set to False to disable caching, True by default.",
type="bool",
type=bool,
default=True,
),
"cluster_type": Argument(
name="cluster_type",
description="The cluster type to use for the execution",
type="str",
type=str,
default="default",
),
"client_kwargs": Argument(
name="client_kwargs",
description="Keyword arguments to pass to the Dask client",
type="dict",
type=dict,
default={},
),
"metadata": Argument(
name="metadata",
description="Metadata arguments containing the run id and base path",
type="str",
type=str,
),
"output_manifest_path": Argument(
name="output_manifest_path",
description="Path to the output manifest",
type="str",
type=str,
),
}

Expand Down

0 comments on commit 8e82844

Please sign in to comment.